Notebook for plain GAN and WGAN implementations, shamelessly adapted from https://github.com/eriklindernoren/PyTorch-GAN/

You can adapt this for a demo in your presentation -- of course, there will always be bonus points if a Python notebook is included ;-)

# Checking whether everything is working (including installing PyTorch)

In [None]:
# Check the current working directory: the relative paths used below
# (./datasets, gan/, wgan/) are all resolved against it.
! pwd

In [None]:
# Installing torch if you do not have it yet. Should work also for Windows
! pip3 install torch torchvision
! mkdir datasets

In [None]:
# Check that everything is installed nicely. Importing inside the kernel (instead of
# `! python -c ...`, which may run another interpreter found on PATH) reports the
# versions this notebook will actually use.
import torch
import torchvision

print('torch version: ', torch.__version__)
print('torchvision version: ', torchvision.__version__)

In [None]:
# This is just for creating the model + dataset + figures folder 

! mkdir -p gan
! mkdir -p gan/model/
! mkdir -p gan/fig

- We import all the necessary modules

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils import data

%matplotlib inline

- Load the dataset (MNIST in our case)

In [None]:
# Training split of MNIST, downloaded into ./datasets/mnist on first use.
# Each image is resized to 28 px, converted to a tensor, then normalized with
# mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1]
# (the same range as the generator's final Tanh).
mnist_trainset = datasets.MNIST(
    root='./datasets/mnist',
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    ),
)

- Infer the size of the dataset from the output below. Do you know why each image has this shape?

In [None]:
# Shapes of the raw label tensor and the raw image tensor stored by the dataset.
print(*(stored.shape for stored in (mnist_trainset.targets, mnist_trainset.data)))

In [None]:
batch_size = 64
# Iterate over the training set in batches of `batch_size`, reshuffled every epoch.
dataloader = data.DataLoader(
    dataset=mnist_trainset,
    batch_size=batch_size,
    shuffle=True,
)

- Check a batch of samples (what size should it have?)

In [None]:
# Grab one batch and print its mean, std, max and min (the Normalize above should
# keep values within [-1, 1]), followed by the image and label dtypes.
imgs, targets = next(iter(dataloader))
print(*(reduce_fn() for reduce_fn in (imgs.mean, imgs.std, imgs.max, imgs.min)))
print(imgs.dtype, targets.dtype)

# Building the Discriminator and Generator Networks

In [None]:
import torch
from torch import nn

class Discriminator(nn.Module):
    """MLP discriminator for the plain GAN.

    Each image is flattened to a 784-vector (28 * 28, matching the generator's
    1 x 28 x 28 output). It then passes through two LeakyReLU(0.2) hidden layers
    of width 256 and 128, and finally to a single Sigmoid output in (0, 1). The
    training loop reads that output as the probability that the image is real.
    """

    def __init__(self):
        super().__init__()
        # why taking 784 as the first dimension? (hint: 28 * 28)
        layers = [
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # BCELoss expects probabilities, not raw scores
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a (batch, 1) tensor of real/fake probabilities for ``img``."""
        img_flat = img.view(img.shape[0], -1)
        return self.model(img_flat)


class Generator(nn.Module):
    """MLP generator mapping latent noise ``z`` to 1 x 28 x 28 images.

    Args:
        latent_dim: length of the noise vector fed to ``forward``.

    Hidden widths go 128 -> 256 -> 512 -> 1024. Each hidden Linear is followed by
    BatchNorm1d and LeakyReLU(0.2). The output layer is Linear(1024, 784) + Tanh,
    so pixel values lie in [-1, 1].
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        widths = [latent_dim, 128, 256, 512, 1024]
        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers += [
                nn.Linear(in_features, out_features),
                # NOTE(review): 0.8 is BatchNorm1d's second positional argument,
                # i.e. `eps`, not `momentum`. Kept as-is to preserve behaviour.
                nn.BatchNorm1d(out_features, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Linear(widths[-1], 784), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, 1, 28, 28)."""
        flat_imgs = self.model(z)
        return flat_imgs.view(flat_imgs.size(0), 1, 28, 28)


In [None]:
# Build the networks and move them to the GPU when one is available (CPU otherwise).

use_cuda = torch.cuda.is_available()
latent_dim = 100  # length of the noise vector z fed to the generator
G = Generator(latent_dim=latent_dim)
D = Discriminator()
if use_cuda:
    # Module.cuda() moves the parameters in place and returns the module itself.
    G, D = G.cuda(), D.cuda()

- Do you still remember why we use `torch.nn.BCELoss` as a loss?
- Adam is one of the most popular (if not the most popular) stochastic gradient-descent optimization schemes for neural networks. Read more in the [PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) and in the [original paper](https://arxiv.org/abs/1412.6980).

In [None]:
# Tensor type used to move data onto the same device as the models:
# CUDA float tensors when a GPU is available, CPU float tensors otherwise.
if use_cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# Some visualization functions

def show_img(img):
    """Display a single grayscale image with the axes hidden.

    ``img`` is expected to be (1, 28, 28) or (28, 28); a leading singleton
    dimension is squeezed away before plotting.
    """
    pixels = img.squeeze(0).cpu().detach().numpy()
    plt.axis('off')
    plt.imshow(pixels, cmap="gray")
    
def show_imgs(imgs, n_height=8, n_width=8):
    """Plot a batch of grayscale images on an ``n_height`` x ``n_width`` grid.

    Args:
        imgs: tensor expected to be (n_height * n_width, 1, H, W), e.g. a batch
            of 64 MNIST images for the default 8 x 8 grid.
        n_height: number of grid rows.
        n_width: number of grid columns.

    Raises:
        ValueError: if the number of images does not fill the grid exactly.
    """
    if len(imgs) != n_height * n_width:
        raise ValueError(
            f"expected {n_height * n_width} images for a {n_height}x{n_width} grid, "
            f"got {len(imgs)}"
        )
    imgs = imgs.squeeze(1).cpu().detach().numpy()
    # figsize is (width, height) in inches: one inch per grid column / row.
    plt.figure(figsize=(n_width, n_height))
    grid = gridspec.GridSpec(n_height, n_width)
    grid.update(wspace=0.025, hspace=0.05)  # tight spacing between the tiles
    for cell, plottable_image in enumerate(imgs):
        ax = plt.subplot(grid[cell])  # flat GridSpec index, row-major from 0
        ax.axis('off')
        ax.set_aspect('equal')
        ax.imshow(plottable_image, cmap='gray')

def sample_img(G, data_loader, PATH="test.jpg"):
    """Generate a batch of images with ``G``, plot them on a grid and save the figure.

    Args:
        G: generator network, called on noise of shape (batch, latent_dim).
        data_loader: only used to pick the batch size (the size of its next
            batch), so that the grid matches the layout expected by ``show_imgs``.
        PATH: file the figure is saved to.

    ``G`` is switched to eval mode while sampling (BatchNorm then uses its
    running statistics). Afterwards it is restored to its previous mode, even
    if plotting fails. Relies on the notebook globals ``latent_dim`` and ``Tensor``.
    """
    # Only the batch size is needed; the real images never have to reach the device.
    batch_size = len(next(iter(data_loader))[0])
    was_training = G.training
    G.eval()
    try:
        # Sampling needs no autograd graph.
        with torch.no_grad():
            z = torch.zeros((batch_size, latent_dim)).normal_(0, 1).type(Tensor)
            gen_imgs_batch = G(z)
        show_imgs(gen_imgs_batch)
        plt.savefig(PATH)
        plt.show()
    finally:
        G.train(was_training)
    

In [None]:
# Plot a batch of real training images: what the generator should learn to imitate.
imgs, targets = next(iter(dataloader))
show_imgs(imgs=imgs)

### Training the plain GAN -- warning: without a GPU, finishing 200 epochs can take a very long time

In [None]:
# Parameters
model_name = "gan"  # output prefix: checkpoints in gan/model, figures in gan/fig
start_epo = 0       # offset added to the epoch number in checkpoint/figure file names
n_epochs = 200      # playing with this for some fun
print_freq = 10     # log, checkpoint and sample every `print_freq` epochs (plus first and last)

# Binary cross-entropy on D's sigmoid output: real images -> 1, generated -> 0.
loss_fn = torch.nn.BCELoss()

# Adam for both networks; betas are the exponential decay rates of the first and
# second moment estimates (read more on the links above).
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999)) for net in (G, D)
)

In [None]:
# Put G and D in training mode and print their architectures
# (Module.train() returns the module itself, so printing it lists the layers).
print(*(net.train() for net in (G, D)), sep="\n")

In [None]:
# Training: G and D are updated on alternating batches ((epoch + i) even -> G step,
# odd -> D step), so each network is updated on roughly half of the batches per epoch.

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Configure input
        batch_size = imgs.shape[0]
        real_imgs = imgs.type(Tensor)
        # Sample noise as generator input
        z = torch.randn(batch_size, latent_dim).type(Tensor)
        # Generate a batch of images
        gen_imgs = G(z)
        # Adversarial ground truths
        valid = torch.ones((batch_size, 1)).type(Tensor)
        fake = torch.zeros((batch_size, 1)).type(Tensor)
        if (epoch + i) % 2 == 0:
            # g_loss: measures the generator's ability to fool the discriminator
            g_loss = loss_fn(D(gen_imgs), valid)
            optimizer_G.zero_grad()
            # Each loss graph is backpropagated exactly once, so retain_graph is
            # not needed (it only kept the whole graph alive in memory).
            g_loss.backward()
            optimizer_G.step()
        else:
            # d_loss: measures the discriminator's ability to tell real from generated
            # samples; detach() keeps this loss from backpropagating into G.
            real_loss = loss_fn(D(real_imgs), valid)
            fake_loss = loss_fn(D(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

    if epoch == 0 or epoch % print_freq == (print_freq - 1) or epoch == n_epochs - 1:
        # Losses reported are those of the last G step and last D step of this epoch.
        print(
            f"Epoch {epoch + 1} / {n_epochs} | Discriminator loss: {d_loss.item()} | Generator loss: {g_loss.item()}"
        )
        # Save checkpoints; start_epo offsets the epoch number used in the file names.
        this_epo_str = str(epoch + start_epo).zfill(4)
        torch.save(G.state_dict(), f"{model_name}/model/G_{this_epo_str}")
        torch.save(D.state_dict(), f"{model_name}/model/D_{this_epo_str}")
        # Sample from the generator at this checkpoint epoch and save the figure.
        sample_img(G, dataloader, f"{model_name}/fig/{this_epo_str}.jpg")


### After training is finished, we have a Generator to play with

In [None]:
# Sample from the trained plain-GAN generator (figure saved to the default path).
sample_img(G, data_loader=dataloader)

# WGAN

In [None]:
# Same thing, but a bit different
! mkdir -p wgan/model/
! mkdir -p wgan/fig/

In [None]:
# Note that the architecture is now almost (99%) the same as the plain GAN above.

class Discriminator(nn.Module):
    """MLP critic for the WGAN.

    Same layers as the plain-GAN discriminator but without the final Sigmoid.
    The critic outputs an unbounded real-valued score per image, which the
    WGAN losses average directly instead of treating it as a probability.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            # No Sigmoid here: the WGAN critic outputs raw scores.
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a (batch, 1) tensor of critic scores for ``img``."""
        img_flat = img.view(img.shape[0], -1)
        return self.model(img_flat)

class Generator(nn.Module):
    """MLP generator for the WGAN, mapping noise ``z`` to 1 x 28 x 28 images.

    Args:
        latent_dim: length of the noise vector fed to ``forward``.

    Hidden widths go 128 -> 256 -> 512 -> 1024. Each hidden Linear is followed by
    BatchNorm1d and LeakyReLU(0.2). The output layer is Linear(1024, 784) + Tanh,
    so pixel values lie in [-1, 1].
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        widths = [latent_dim, 128, 256, 512, 1024]
        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers += [
                nn.Linear(in_features, out_features),
                # NOTE(review): 0.8 is BatchNorm1d's second positional argument,
                # i.e. `eps`, not `momentum`. Kept as-is to preserve behaviour.
                nn.BatchNorm1d(out_features, eps=0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Linear(widths[-1], 784), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, 1, 28, 28)."""
        flat_imgs = self.model(z)
        return flat_imgs.view(flat_imgs.size(0), 1, 28, 28)


In [None]:
# Re-create G and D from the WGAN class definitions above. Redefining a class does
# not change objects built from the old definition, so without this G and D would
# still be the already-trained plain-GAN models (with a Sigmoid discriminator).
G = Generator(latent_dim=latent_dim)
D = Discriminator()
if use_cuda:
    G, D = G.cuda(), D.cuda()

# And also different optimizers: RMSprop with lr = 5e-5 (the setting used in
# Algorithm 1 of the WGAN paper).
optimizer_G = torch.optim.RMSprop(G.parameters(), lr=0.00005)
optimizer_D = torch.optim.RMSprop(D.parameters(), lr=0.00005)

In [None]:
# Training parameters (same meaning as for the plain GAN; outputs go under wgan/)
model_name, start_epo = "wgan", 0
n_epochs, print_freq = 200, 10

In [None]:
# Put G and D in training mode and print their architectures
# (Module.train() returns the module itself, so printing it lists the layers).
print(*(net.train() for net in (G, D)), sep="\n")

In [None]:
# Training -- can you spot the very important differences here compared to plain GANs above?

# Critic weights are clipped to [-clip_value, clip_value] after every critic update
# (0.01 as in the WGAN paper). Clipping keeps the critic approximately Lipschitz,
# which the Wasserstein loss below relies on; without it the critic is unconstrained.
clip_value = 0.01

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.size(0)
        # Configure input
        real_imgs = imgs.type(Tensor)
        # Sample noise as generator input
        z = torch.randn(batch_size, latent_dim).type(Tensor)
        # Generate a batch of images
        gen_imgs = G(z)

        if (epoch + i) % 2 == 0:
            # g_loss: the generator tries to maximise the critic's score on its samples
            g_loss = -torch.mean(D(gen_imgs))
            optimizer_G.zero_grad()
            # Each graph is backpropagated once, so retain_graph is not needed.
            g_loss.backward()
            optimizer_G.step()
        else:
            # d_loss: the critic tries to score real images above generated ones;
            # detach() keeps this loss from backpropagating into G.
            d_loss = torch.mean(D(gen_imgs.detach())) - torch.mean(D(real_imgs))
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
            # Enforce the Lipschitz constraint by clipping every critic weight.
            with torch.no_grad():
                for p in D.parameters():
                    p.clamp_(-clip_value, clip_value)

    if epoch == 0 or epoch % print_freq == (print_freq - 1) or epoch == n_epochs - 1:
        # Losses reported are those of the last G step and last critic step of this epoch.
        print(
            f"Epoch {epoch + 1} / {n_epochs} | Discriminator loss: {d_loss.item()} | Generator loss: {g_loss.item()}"
        )
        # Save checkpoints; start_epo offsets the epoch number used in the file names.
        this_epo_str = str(epoch + start_epo).zfill(4)
        torch.save(G.state_dict(), f"{model_name}/model/G_{this_epo_str}")
        torch.save(D.state_dict(), f"{model_name}/model/D_{this_epo_str}")
        # Sample from the generator at this checkpoint epoch and save the figure.
        sample_img(G, dataloader, f"{model_name}/fig/{this_epo_str}.jpg")


In [None]:
# Sample from the trained WGAN generator (figure saved to the default path).
sample_img(G, data_loader=dataloader)