## Conditional GAN for Data Augmentation

In [9]:
import os

import lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import FashionMNIST

# Use a larger batch when CUDA is available, otherwise fall back to 64.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# Default number of DataLoader worker processes used by FMNISTDataModule.
NUM_WORKERS = 1

## DataModule

In [10]:
class FMNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving FashionMNIST for the conditional GAN.

    Images are converted to tensors and scaled to ``[-1, 1]`` so real samples
    share the value range of the generator's ``Tanh`` output. The official
    training set is split 55,000 / 5,000 into train / validation subsets with
    a seeded generator; the official test set is used for testing.

    Args:
        data_dir: Directory where FashionMNIST is downloaded / read from.
        batch_size: Samples per batch for every dataloader.
        num_workers: Worker processes per dataloader.
        seed: Seed for the train/validation split, making it reproducible.
    """

    def __init__(
        self,
        data_dir: str = ".",
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
        seed: int = 42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        # ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to
        # [-1, 1], matching the generator's Tanh range. Mean/std
        # standardisation (e.g. 0.1307 / 0.3081) would give ~[-0.42, 2.82],
        # letting the discriminator separate real from fake by range alone.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Ensure the train and test splits are downloaded to ``data_dir``."""
        FashionMNIST(self.data_dir, train=True, download=True)
        FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (``None`` builds all of them)."""
        # Train/val datasets are needed both for fitting and for a standalone
        # validation run.
        if stage in ("fit", "validate", None):
            fmnist_full = FashionMNIST(self.data_dir, train=True, transform=self.transform)
            self.fmnist_train, self.fmnist_val = random_split(
                fmnist_full,
                [55000, 5000],
                generator=torch.Generator().manual_seed(self.seed),
            )

        if stage in ("test", None):
            self.fmnist_test = FashionMNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(
            self.fmnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Loader over the validation subset, in a fixed order."""
        # Evaluation results do not depend on order; shuffling only adds nondeterminism.
        return DataLoader(self.fmnist_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Loader over the official test split, in a fixed order."""
        return DataLoader(self.fmnist_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

## Generator

In [11]:
class Generator(nn.Module):
    """DCGAN-style generator mapping flat input vectors to ``im_chan x 28 x 28`` images.

    Each input vector (e.g. noise concatenated with a one-hot class vector) is
    viewed as an ``input_dim x 1 x 1`` feature map and upsampled by four
    transposed-convolution stages: 1 -> 4 -> 7 -> 14 -> 28 pixels.

    Args:
        input_dim: Number of elements in each input vector.
        im_chan: Channels of the generated image.
        hidden_dim: Base channel width; the hidden stages use 4x, 2x and 1x of it.
    """

    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        wide, mid, narrow = hidden_dim * 4, hidden_dim * 2, hidden_dim
        stages = [
            self.make_gen_block(input_dim, wide, stride=1, padding=0),  # 1x1 -> 4x4
            self.make_gen_block(wide, mid, kernel_size=3),  # 4x4 -> 7x7
            self.make_gen_block(mid, narrow),  # 7x7 -> 14x14
            self.make_gen_block(narrow, im_chan, final_layer=True),  # 14x14 -> 28x28
        ]
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """Return one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh, so outputs lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Generate one image per input row; each row must hold ``input_dim`` elements."""
        feature_map = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(feature_map)
    
def get_noise(n_samples, input_dim, device=None):
    """Sample standard-normal noise vectors.

    Args:
        n_samples: Number of vectors to draw (the batch size).
        input_dim: Length of each vector.
        device: Device to allocate on. ``None`` uses torch's default device;
            passing the target device avoids a separate host-to-device copy.

    Returns:
        Tensor of shape ``(n_samples, input_dim)`` drawn from N(0, 1).
    """
    return torch.randn(n_samples, input_dim, device=device)

## Discriminator

In [12]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator producing real/fake logits per image.

    Four convolution stages downsample a 28x28 input: 28 -> 14 -> 7 -> 4 -> 1.
    The last stage has no activation, so outputs are raw logits suited to a
    logits-based loss such as ``F.binary_cross_entropy_with_logits``.

    Args:
        im_chan: Input channels (image channels plus any label channels).
        hidden_dim: Base channel width; the hidden stages use 1x, 2x and 4x of it.
    """

    def __init__(self, im_chan=1, hidden_dim=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),  # 28 -> 14
            self.make_disc_block(hidden_dim, hidden_dim * 2),  # 14 -> 7
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4, kernel_size=3),  # 7 -> 4
            self.make_disc_block(hidden_dim * 4, 1, stride=1, padding=0, final_layer=True),  # 4 -> 1
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        """Return one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is a bare Conv2d. ``padding`` is honoured in both cases.
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        # Pass ``padding`` through here too; it used to be silently ignored on
        # this branch. Kept inside a Sequential so state_dict keys are unchanged.
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding),
        )

    def forward(self, image):
        """Return logits flattened per sample: shape ``(N, -1)``, i.e. ``(N, 1)`` for 28x28 inputs."""
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

In [13]:
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Compute input sizes for a class-conditional generator / discriminator pair.

    The generator consumes a noise vector concatenated with a one-hot class
    vector; the discriminator consumes the image stacked with one channel per
    class.

    Args:
        z_dim: Length of the noise vector.
        mnist_shape: Image shape whose first element is the channel count.
        n_classes: Number of classes.

    Returns:
        ``(generator_input_dim, discriminator_im_chan)``.
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes

## GAN (LightningModule)

In [14]:
class GAN(pl.LightningModule):
    """Class-conditional DCGAN trained with Lightning manual optimization.

    The generator is conditioned by concatenating a one-hot label to its noise
    vector. The discriminator is conditioned by stacking one constant
    ``height x width`` channel per class onto the (real or generated) image.

    Args:
        channels, width, height: Shape of the real images.
        n_classes: Number of label classes.
        latent_dim: Length of the noise vector.
        lr, b1, b2: Adam learning rate and betas, shared by both networks.
        batch_size: Only recorded via ``save_hyperparameters``.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        n_classes,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        # networks
        generator_input_dim, discriminator_im_chan = get_input_dimensions(
            latent_dim, (channels, width, height), n_classes
        )
        # The generator must emit ``channels`` channels so that, stacked with
        # the n_classes label maps, it matches the discriminator's input width.
        self.generator = Generator(input_dim=generator_input_dim, im_chan=channels)
        self.discriminator = Discriminator(im_chan=discriminator_im_chan)

        def weights_init(m):
            # N(0, 0.02) for conv weights and BatchNorm scales; BatchNorm shifts start at 0.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                torch.nn.init.normal_(m.weight, 0.0, 0.02)
            if isinstance(m, nn.BatchNorm2d):
                torch.nn.init.normal_(m.weight, 0.0, 0.02)
                torch.nn.init.constant_(m.bias, 0)

        self.generator.apply(weights_init)
        self.discriminator.apply(weights_init)

        # Fixed generator inputs for the per-epoch sample grid: one noise
        # vector per class, each paired with that class's one-hot label.
        self.validation_z = torch.cat(
            (
                torch.randn(n_classes, latent_dim),
                F.one_hot(torch.arange(n_classes), n_classes).float(),
            ),
            dim=1,
        )

        self.example_input_array = torch.zeros(2, generator_input_dim)

    def forward(self, z):
        """Generate images from generator inputs (noise + one-hot labels)."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy on raw discriminator logits."""
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def training_step(self, batch):
        """Run one generator update followed by one discriminator update.

        ``batch`` is an ``(images, integer_labels)`` pair.
        """
        imgs, labels = batch
        optimizer_g, optimizer_d = self.optimizers()

        # Conditioning inputs: one-hot vectors for the generator, and the same
        # labels broadcast to constant (n_classes, H, W) maps for the discriminator.
        one_hot_labels = F.one_hot(labels, self.hparams.n_classes).type_as(imgs)
        image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(
            1, 1, self.hparams.height, self.hparams.width
        )

        fake_noise = get_noise(imgs.shape[0], self.hparams.latent_dim).type_as(imgs)
        noise_and_labels = torch.cat((fake_noise, one_hot_labels), 1)

        # ---- generator: push the discriminator to label fakes as real ----
        # Lightning 2.x signature: toggle_optimizer(optimizer), no optimizer_idx.
        self.toggle_optimizer(optimizer_g)
        fake = self.generator(noise_and_labels)
        disc_fake_pred = self.discriminator(torch.cat((fake, image_one_hot_labels), 1))
        g_loss = self.adversarial_loss(disc_fake_pred, torch.ones_like(disc_fake_pred))
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # ---- discriminator: separate fakes (target 0) from reals (target 1) ----
        self.toggle_optimizer(optimizer_d)
        # Detach so the discriminator loss does not backpropagate into the generator.
        fake_pred = self.discriminator(torch.cat((fake.detach(), image_one_hot_labels), 1))
        fake_loss = self.adversarial_loss(fake_pred, torch.zeros_like(fake_pred))

        real_pred = self.discriminator(torch.cat((imgs, image_one_hot_labels), 1))
        real_loss = self.adversarial_loss(real_pred, torch.ones_like(real_pred))

        d_loss = (fake_loss + real_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        """Return separate Adam optimizers for the generator and the discriminator."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_train_epoch_end(self):
        """Log a grid of samples generated from the fixed ``validation_z``.

        Lightning 2.x removed the ``training_epoch_end`` hook and refuses to
        fit modules that define it, so this uses ``on_train_epoch_end``.
        """
        logger = self.logger
        # Only loggers whose experiment exposes ``add_image`` (e.g. TensorBoard) can store the grid.
        if logger is None or not hasattr(logger.experiment, "add_image"):
            return

        z = self.validation_z.type_as(self.generator.gen[0][0].weight)

        # Sampling only; no need to build an autograd graph.
        with torch.no_grad():
            sample_imgs = self(z)

        # Generator output is in [-1, 1] (Tanh); rescale to [0, 1] so the image
        # writer does not clip the negative half of the range.
        grid = torchvision.utils.make_grid(sample_imgs, nrow=2, normalize=True, value_range=(-1, 1))
        logger.experiment.add_image("generated_images", grid, self.current_epoch)

## Logger

In [15]:
from lightning.pytorch.loggers import TensorBoardLogger

# TensorBoard logger rooted at the "gan_logs" directory. It receives the g_loss/d_loss
# scalars logged with self.log and the sample image grid the GAN writes via
# logger.experiment.add_image.
logger = TensorBoardLogger("gan_logs")

## Training

In [None]:
# Build the data module and model, then train.
# trainer.fit() calls dm.prepare_data() and dm.setup("fit") itself. Calling
# them here as well would repeat the download check and the dataset split.
dm = FMNISTDataModule()
# dm.dims and dm.num_classes are set in FMNISTDataModule.__init__, so no setup is needed first.
# Pass the batch size so the GAN's recorded hyperparameter matches the data actually used.
model = GAN(*dm.dims, dm.num_classes, batch_size=dm.batch_size)
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=100,
    logger=logger,
)
trainer.fit(model, dm)