In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
# Preprocessing: convert PIL images to tensors in [0, 1], then shift/scale them
# with mean 0.5 and std 0.5 so pixel values land in [-1, 1] (the range of the
# generator's Tanh output).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

# MNIST dataset stored under ./mnist; downloaded on first use.
data = datasets.MNIST('mnist', transform=mnist_transform, download=True)

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to 28x28 images.

    The network is a three-layer MLP (100 -> 1024 -> 1024 -> 784) with
    LeakyReLU activations and a final Tanh, so every output pixel lies in
    [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so that parameter
        # initialisation and state_dict keys (main.0, main.2, main.4) are stable.
        layers = [
            nn.Linear(100, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 28 * 28),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: tensor whose first dimension is the batch and whose last
                dimension has 100 features.

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        flat_images = self.main(x)
        return flat_images.reshape(x.size(0), 1, 28, 28)


class WassersteinCritic(nn.Module):
    """Fully connected critic scoring 28x28 images with an unbounded scalar.

    The network is a three-layer MLP (784 -> 512 -> 256 -> 1). Each hidden
    layer is followed by Dropout(0.3) and then ReLU; the output layer has no
    activation, so scores are raw real numbers rather than probabilities.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept fixed so parameter initialisation and
        # state_dict keys (main.0, main.3, main.6) are stable.
        layers = [
            nn.Linear(28 * 28, 512),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(256, 1),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: tensor with the batch as first dimension and 28 * 28 elements
                per sample (e.g. shape (batch, 1, 28, 28)).

        Returns:
            Tensor of shape (batch, 1) holding one critic score per image.
        """
        flat_pixels = x.reshape(x.size(0), 28 * 28)
        return self.main(flat_pixels)

## Training

In [None]:
# Build the two networks (generator first, then critic, so that parameter
# initialisation consumes the RNG in a fixed order).
generator, discriminator = Generator(), WassersteinCritic()

In [None]:
# Mini-batch size; passed to the DataLoader as batch_size in the training cell.
batch = 100
# Stale scratch code kept for reference only (never executed). Note that it
# draws uniform noise (torch.rand), whereas the training cell samples Gaussian
# noise (torch.randn).
# batch_of_real_data, _ = data.get_training_batch(batch)
# batch_of_noise = torch.rand(batch, 100)

In [39]:
# Optimisers: Adam with beta1 lowered to 0.5; weight decay is the only
# regulariser applied to either network.
# NOTE(review): the critic is not explicitly Lipschitz-constrained (no weight
# clipping or gradient penalty), and the recorded losses keep growing in
# magnitude — confirm that weight decay alone is the intended constraint.
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.99), weight_decay=0.05)
generator_optimizer = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.99), weight_decay=0.1)

data_loader = torch.utils.data.DataLoader(data, batch_size=batch, shuffle=True)

# Create the sample directory up front so the first save_image call cannot
# fail after a whole epoch of training has already been spent.
samples_dir = './samples/wass'
os.makedirs(samples_dir, exist_ok=True)

for epoch in range(100):
    for batch_id, (x, _) in enumerate(data_loader):
        # Size the noise by the actual batch so a short final batch (dataset
        # size not divisible by `batch`) is still paired one-to-one.
        batch_of_noise = torch.randn(x.size(0), 100)

        # ---- critic step: minimise E[D(fake)] - E[D(real)] ----
        discriminator_optimizer.zero_grad()
        # Detach the fakes: the critic update needs no generator gradients, so
        # this skips a wasted backward pass through the generator.
        batch_of_generated = generator(batch_of_noise).detach()
        batch_of_generated_discrimination = discriminator(batch_of_generated)
        batch_of_real_discrimination = discriminator(x)

        discriminator_loss = batch_of_generated_discrimination.mean() - batch_of_real_discrimination.mean()
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # ---- generator step: maximise E[D(fake)] plus a diversity bonus ----
        generator_optimizer.zero_grad()
        # Re-run the same noise through the generator and the updated critic.
        batch_of_generated = generator(batch_of_noise)
        batch_of_generated_discrimination = discriminator(batch_of_generated)

        # The second term rewards per-pixel standard deviation across the
        # batch, pushing the generator away from emitting near-identical images.
        generator_loss = - batch_of_generated_discrimination.mean() - 0.3 * batch_of_generated.std(dim=0).sum()
        # Critic gradients produced by this backward pass are discarded by the
        # critic's zero_grad at the start of the next iteration.
        generator_loss.backward()
        generator_optimizer.step()

        if (batch_id + 1) % 100 == 0:
            print(f'------{batch_id + 1}:D:{discriminator_loss.item()}:G:{generator_loss.item()}---')
    # log (losses are those of the last batch of the epoch)
    print(f'---{epoch + 1}:D:{discriminator_loss.item()}:G:{generator_loss.item()}---')
    # Sampling is inference only, so no autograd graph is needed.
    with torch.no_grad():
        examples = generator(torch.randn(10, 100))
    # The generator emits values in [-1, 1] (Tanh), while save_image clamps to
    # [0, 1]; map back to [0, 1] first so the darker half is not crushed to black.
    save_image(examples * 0.5 + 0.5, f'{samples_dir}/epoch_{epoch + 1}_samples.png')

------100:D:-38.90008544921875:G:-499.43719482421875---
------200:D:-34.91748046875:G:-557.641357421875---
------300:D:-114.423828125:G:-480.1089782714844---
------400:D:-43.17816162109375:G:-540.6329345703125---
------500:D:-108.23867797851562:G:-397.5446472167969---
------600:D:-157.94915771484375:G:-437.50689697265625---
---1:D:-157.94915771484375:G:-437.50689697265625---
------100:D:-226.1434326171875:G:-592.5015258789062---
------200:D:-320.61865234375:G:-753.5255737304688---
------300:D:-339.61932373046875:G:-874.8909301757812---
------400:D:-70.4556884765625:G:-1281.43017578125---
------500:D:-93.70703125:G:-1057.31005859375---
------600:D:-148.328857421875:G:-909.2763061523438---
---2:D:-148.328857421875:G:-909.2763061523438---
------100:D:-18.561279296875:G:-1304.016357421875---
------200:D:6.4486083984375:G:-1360.2918701171875---
------300:D:-110.3538818359375:G:-1629.740234375---
------400:D:-103.1407470703125:G:-1588.10205078125---
------500:D:47.0556640625:G:-1677.07922363

In [40]:
# Persist both trained networks under ./models.
# os.path.join keeps the path portable: the previous hard-coded backslash only
# acted as a directory separator on Windows; elsewhere it produced a file
# literally named "models\generator.model" in the working directory.
# NOTE: torch.save on a whole module pickles the object, so the Generator and
# WassersteinCritic class definitions must be importable when loading.
os.makedirs('models', exist_ok=True)
torch.save(generator, os.path.join('models', 'generator.model'))
torch.save(discriminator, os.path.join('models', 'discriminator.model'))

## Check

In [41]:
# Draw 100 fresh samples from the trained generator and save them as one grid.
os.makedirs('./samples/wass', exist_ok=True)  # make sure the target directory exists
# Inference only: skip building an autograd graph.
with torch.no_grad():
    examples = generator(torch.randn(100, 100))
# Generator outputs lie in [-1, 1] (Tanh) but save_image clamps to [0, 1];
# rescale to [0, 1] so negative values are not all clipped to black.
save_image(examples * 0.5 + 0.5, './samples/wass/samples_100.png')