In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
from collections import OrderedDict

import torch
import torchvision
import numpy as np
import pytorch_lightning as pl 

from torch import nn
from torch.nn import functional as F 
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
from pytorch_lightning.metrics.functional import accuracy

In [2]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)

    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)

In [3]:
class DataModule(pl.LightningDataModule):
    def __init__(self, conf):
        super().__init__()
        self.data_path = conf.data_path
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.dims = conf.dims 
        self.num_classes = conf.num_classes
    
    def prepare_data(self):
        MNIST(self.data_path, train=True, download=True)
        MNIST(self.data_path, train=False, download=True)
    
    def setup(self, stage):
        if stage == 'fit' or stage is None:
            data_full = MNIST(self.data_path, train=True, transform=self.transform)
            self.data_train, self.data_val = random_split(data_full, [55000, 5000])
        
        if stage == 'test' or stage is None:
            self.data_test = MNIST(self.data_path, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=32)
    
    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=32)
    
    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=32)

In [4]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
    
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [5]:
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [6]:
class GAN(pl.LightningModule):
    def __init__(self, conf):
        super().__init__()
        self.conf = conf
        data_shape = conf.dims
        self.generator = Generator(latent_dim=conf.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)
        self.validation_z = torch.randn(8, conf.latent_dim)
        self.example_input_array = torch.zeros(2, conf.latent_dim)
    
    def forward(self, z):
        return self.generator(z)
    
    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch, bathc_idx, optimizer_idx):
        imgs, _ = batch 

        # sample noise
        z = torch.randn(imgs.shape[0], self.conf.latent_dim)
        z = z.type_as(imgs)

        # train_generator
        if optimizer_idx == 0:
            # generate images
            self.generated_imgs =self(z)
            # log sampled images
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('generated_images', grid, 0)
            # ground truth result
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output
        
        # train discrimintor
        if optimizer_idx == 1:
            # how well can it label as real?
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)

            fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss)
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output
        
    def configure_optimizers(self):
        lr = self.conf.lr 
        b1 = self.conf.b1 
        b2 = self.conf.b2 
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []
    
    def on_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)

In [7]:
conf = Config(
    data_path=r'data/',
    dims=(1, 28, 28),
    num_classes=10,
    lr=2e-4,
    b1=0.5,
    b2=0.999,
    latent_dim=100,
)
conf

{'data_path': 'data/',
 'dims': (1, 28, 28),
 'num_classes': 10,
 'lr': 0.0002,
 'b1': 0.5,
 'b2': 0.999,
 'latent_dim': 100}

In [8]:
dm = DataModule(conf)
model = GAN(conf)
trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [9]:
trainer.fit(model, dm)


  | Name          | Type          | Params | In sizes | Out sizes     
----------------------------------------------------------------------------
0 | generator     | Generator     | 1 M    | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K  | ?        | ?             
Epoch 4:  99%|█████████▉| 1700/1719 [01:34<00:01, 17.97it/s, loss=2.978, v_num=14, g_loss=6.36, d_loss=0.00594]


1

In [10]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 76961), started 10:46:19 ago. (Use '!kill 76961' to kill it.)