## GAN on MNIST

In [1]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

Define hyperparameters (epochs, batch size, latent dimension)

In [2]:
num_epoch = 100    # passes over the training set
batch_size = 32    # images per mini-batch
latent_size = 64   # dimensionality of the generator's input noise vector

Define the generator model

In [3]:
class Generator(nn.Module):
    """Fully-connected generator mapping latent vectors to 1x28x28 images.

    The final Tanh keeps every pixel in [-1, 1].

    Args:
        in_dim: length of each latent vector passed to ``forward``.
    """

    def __init__(self, in_dim):
        super().__init__()

        dims = [in_dim, 64, 128, 256, 512, 1024]
        layers = []
        # Hidden stack: Linear -> ReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
        # Output head: project to 28 * 28 = 784 pixels and squash into [-1, 1].
        layers += [nn.Linear(dims[-1], 784), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map ``z`` of shape [batch_size, in_dim] to images of shape [batch_size, 1, 28, 28]."""
        pixels = self.model(z)
        return pixels.reshape(z.shape[0], 1, 28, 28)

generator = Generator(in_dim=latent_size)
generator.to(device)  # Module.to moves parameters in place; the returned module is echoed as cell output

Generator(
  (model): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): ReLU(inplace=True)
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): ReLU(inplace=True)
    (10): Linear(in_features=1024, out_features=784, bias=True)
    (11): Tanh()
  )
)

Define the discriminator model

In [4]:
class Discriminator(nn.Module):
    """Fully-connected discriminator scoring 28x28 images.

    The final Sigmoid keeps each score in [0, 1].
    """

    def __init__(self):
        super().__init__()

        widths = [784, 1024, 512, 256, 128]
        layers = []
        # Hidden stack: Linear -> ReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=True))
        # Output head: one unit per image, squashed by Sigmoid.
        layers.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])

        self.model = nn.Sequential(*layers)

    def forward(self, image):
        """Score ``image`` (e.g. [batch_size, 1, 28, 28]); returns shape [batch_size, 1]."""
        batch = image.shape[0]
        # Flatten everything after the batch dimension into one 784-long vector.
        flat = image.reshape(batch, -1)
        return self.model(flat)

discriminator = Discriminator()
discriminator.to(device)  # Module.to moves parameters in place; the returned module is echoed as cell output

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=256, out_features=128, bias=True)
    (7): ReLU(inplace=True)
    (8): Linear(in_features=128, out_features=1, bias=True)
    (9): Sigmoid()
  )
)

Load the MNIST training data (pixels scaled to [-1, 1] to match the generator's Tanh output)

In [5]:
# PIL image -> float tensor in [0, 1] -> (x - 0.5) / 0.5, i.e. [-1, 1],
# the same range the generator's Tanh output covers.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(28),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

dataset = torchvision.datasets.MNIST(
    "data/mnist/",
    train=True,
    download=True,
    transform=transform,
)

dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

Define the optimizers and loss function

In [6]:
# Independent Adam optimizers, one per network, sharing the same learning rate.
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

# The discriminator already ends in Sigmoid, so it outputs probabilities and
# plain BCELoss (not BCEWithLogitsLoss) is the matching criterion.
loss_fn = nn.BCELoss()

Train:

In [None]:
print(f'training on {device}')

for epoch in range(num_epoch):
    print(f'epoch {epoch}: ')
    for i, (gt_images, _) in enumerate(dataloader):
        gt_images = gt_images.to(device)

        # Size noise and labels from the batch actually received. The last
        # batch of an epoch is smaller than `batch_size` whenever the dataset
        # length is not a multiple of it, and fixed-size tensors would then
        # mismatch the discriminator output.
        cur_batch = gt_images.shape[0]
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Generator step: train G so that D labels its images as real. ----
        z = torch.randn(cur_batch, latent_size, device=device)
        pred_images = generator(z)

        g_optimizer.zero_grad()
        g_loss = loss_fn(discriminator(pred_images), real_labels)
        g_loss.backward()
        g_optimizer.step()

        # ---- Discriminator step: real images -> 1, generated images -> 0. ----
        # zero_grad() also clears the gradients that g_loss.backward() left
        # on the discriminator's parameters.
        d_optimizer.zero_grad()
        real_loss = loss_fn(discriminator(gt_images), real_labels)
        # detach() stops this loss from back-propagating into the generator.
        fake_loss = loss_fn(discriminator(pred_images.detach()), fake_labels)
        d_loss = 0.5 * (real_loss + fake_loss)

        # Watch real_loss and fake_loss: training looks stable when both
        # decrease together rather than one diverging from the other.

        d_loss.backward()
        d_optimizer.step()

        if i % 1000 == 0:
            print(
                f'  iter {i}: g_loss={g_loss.item():.4f} '
                f'd_loss={d_loss.item():.4f} '
                f'(real {real_loss.item():.4f}, fake {fake_loss.item():.4f})'
            )
            # The generator emits pixels in [-1, 1] (Tanh, matching the
            # Normalize(0.5, 0.5) input pipeline), but save_image expects
            # [0, 1] and clamps anything outside it. Undo the normalization
            # so saved samples are not clipped to black.
            samples = pred_images.detach() * 0.5 + 0.5
            for index, image in enumerate(samples):
                torchvision.utils.save_image(image, f"image_{index}.png")
