# **CNN - Generative Adversarial Network for FMNIST**
* **Basic concepts learnt from: A Deep Understanding of Deep Learning (with Python intro) - Mike X Cohen (Udemy) - https://www.udemy.com/course/deeplearning_x**
* **Extended learning and understanding by VigyannVeshi**

In [None]:
# important libraries

# basic deep learning libraries
import numpy as np
import torch as tr
import torch.nn as nn
import torch.nn.functional as F

# import torchvision and transformations libraries
import torchvision as tv
import torchvision.transforms as T

# import dataset/loader libraries
from torch.utils.data import TensorDataset,DataLoader
from sklearn.model_selection import train_test_split

import sys

# import plotting libraries
import matplotlib.pyplot as plt
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = tr.device('cuda') if tr.cuda.is_available() else tr.device('cpu')
device

In [None]:
# preprocessing applied to every image as it is loaded
transform = T.Compose([
    T.ToTensor(),                # image -> tensor
    T.Resize(size=64),           # upsample so the networks below see 64x64 inputs
    T.Normalize(mean=.5, std=.5),  # (x-.5)/.5 ; undone later with pic/2 + .5
])

# download FashionMNIST (if needed) into ./data and attach the transform
dataset = tv.datasets.FashionMNIST(
    root='./data',
    transform=transform,
    download=True,
)

In [None]:
# list the categories
print(dataset.classes)

# pick three categories (leave one line uncommented)
classes2keep = [ 'Trouser','Sneaker','Pullover' ]
# classes2keep = [ 'Trouser','Sneaker', 'Sandal'  ]

# find the dataset indices of every image that belongs to a kept class
# (dataset.classes.index raises ValueError if a name is not a dataset class)
index_chunks = []
for classname in classes2keep:
  classidx = dataset.classes.index(classname)
  index_chunks.append(tr.where(dataset.targets==classidx)[0])
  print(f'Added class {classname} (index {classidx})')

# concatenate once; tr.where already returns integer (long) indices, so no
# float accumulator / cast back to long is needed
images2use = tr.cat(index_chunks,0) if index_chunks else tr.empty(0,dtype=tr.long)

# transform to dataloaders: the sampler draws (in random order) only from the
# selected indices, so the full dataset object can be reused as-is.
# drop_last=True guarantees every batch holds exactly `batchsize` images.
batchsize   = 100
sampler     = tr.utils.data.sampler.SubsetRandomSampler(images2use)
data_loader = DataLoader(dataset,sampler=sampler,batch_size=batchsize,drop_last=True)

In [None]:
# view some images
# inspect a few random images from one batch of the filtered loader

X,y = next(iter(data_loader))

fig,axs = plt.subplots(3,6,figsize=(10,6))

for (i,ax) in enumerate(axs.flatten()):

  # extract that image (drop the singleton channel dimension)
  pic = tr.squeeze(X[i])
  pic = pic/2 + .5 # undo normalization

  # and its label
  label = dataset.classes[int(y[i])]

  # and show! The label is centred horizontally using the actual image width
  # (a fixed x-position would only be correct for one image size).
  ax.imshow(pic,cmap='gray')
  ax.text(pic.shape[-1]/2,0,label,ha='center',fontweight='bold',color='k',backgroundcolor='y')
  ax.axis('off')

plt.tight_layout()
plt.show()

**Architecture and meta-parameter choices were inspired by https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html**

In [None]:
# Create the discriminator and generator

class DiscriminatorNet(nn.Module):
    """Convolutional discriminator: image in, one "is real" probability out.

    Four stride-2 convolutions each halve the spatial size; for a 1x64x64
    input the final 4x4 convolution collapses the 4x4 feature map to a single
    value per image, which is squashed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # convolutional layers (no bias terms)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False)

        # batch normalization for the outputs of conv2..conv4
        self.bn2 = nn.BatchNorm2d(num_features=128)
        self.bn3 = nn.BatchNorm2d(num_features=256)
        self.bn4 = nn.BatchNorm2d(num_features=512)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid outputs for the images in x."""
        slope = 0.2  # negative slope shared by all leaky-ReLU activations

        # first stage has no batchnorm
        h = F.leaky_relu(self.conv1(x), negative_slope=slope)

        # remaining stages: convolution -> leaky ReLU -> batchnorm
        h = self.bn2(F.leaky_relu(self.conv2(h), negative_slope=slope))
        h = self.bn3(F.leaky_relu(self.conv3(h), negative_slope=slope))
        h = self.bn4(F.leaky_relu(self.conv4(h), negative_slope=slope))

        # final convolution, then squash and flatten to one column
        decision = tr.sigmoid(self.conv5(h))
        return decision.view(-1, 1)

# smoke test: push 10 random 1x64x64 "images" through an untrained
# discriminator and display the resulting outputs
dnet = DiscriminatorNet()
y = dnet(tr.randn((10, 1, 64, 64)))
y

In [None]:
class GeneratorNet(nn.Module):
    """Transposed-convolution generator: 100-channel noise in, image out.

    For a (batch, 100, 1, 1) input the first layer expands to 4x4 and each
    following stride-2 layer doubles the spatial size, giving a
    (batch, 1, 64, 64) output in the tanh range [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # transposed-convolution layers (no bias terms)
        self.conv1 = nn.ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False)
        self.conv2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv5 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False)

        # batchnorm for the outputs of conv1..conv4
        self.bn1 = nn.BatchNorm2d(num_features=512)
        self.bn2 = nn.BatchNorm2d(num_features=256)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.bn4 = nn.BatchNorm2d(num_features=64)

    def forward(self, x):
        """Map the noise tensor x to an image tensor with values in [-1, 1]."""
        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )

        # hidden stages: transposed convolution -> batchnorm -> ReLU
        for upconv, norm in hidden_stages:
            x = F.relu(norm(upconv(x)))

        # output stage has no batchnorm; tanh bounds the pixel values
        return tr.tanh(self.conv5(x))
    
# smoke test: generate 10 images from random noise with an untrained
# generator and show the first one
gnet = GeneratorNet()
y = gnet(tr.randn((10, 100, 1, 1)))
plt.imshow(y[0].detach().squeeze().numpy())
plt.show()

In [None]:
# set up training: loss function, fresh models, and one optimizer per model

# binary cross-entropy on the discriminator's sigmoid outputs
lossfun = nn.BCELoss()

# new (untrained) networks, moved to the selected device
dnet = DiscriminatorNet()
dnet.to(device)
gnet = GeneratorNet()
gnet.to(device)

# both optimizers use the same Adam settings
d_optimizer = tr.optim.Adam(dnet.parameters(), lr=2e-4, betas=(.5, .999))
g_optimizer = tr.optim.Adam(gnet.parameters(), lr=2e-4, betas=(.5, .999))

In [None]:
# number of epochs (chosen so that at most ~2500 batches are processed in total)
num_epochs = int(2500/len(data_loader))

losses  = []  # per batch: [discriminator loss, generator loss]
disDecs = []  # per batch: [fraction of real, fraction of fake] images judged "real"

# labels for real and fake images never change (the loader uses drop_last=True,
# so every batch has exactly `batchsize` images): create them once, on the device
real_labels = tr.ones(batchsize,1,device=device)
fake_labels = tr.zeros(batchsize,1,device=device)

# make sure batchnorm uses batch statistics, even when this cell is re-run
# after the generator has been switched to eval mode for visualization
dnet.train()
gnet.train()

for epochi in range(num_epochs):

  for data,_ in data_loader:

    # send data to GPU
    data = data.to(device)



    ### ---------------- Train the discriminator ---------------- ###

    # forward pass and loss for REAL pictures
    pred_real   = dnet(data)                     # output of discriminator
    d_loss_real = lossfun(pred_real,real_labels) # all labels are 1

    # forward pass and loss for FAKE pictures
    fake_data   = tr.randn(batchsize,100,1,1,device=device) # random numbers to seed the generator
    fake_images = gnet(fake_data)                           # output of generator
    # detach: this step only updates the discriminator, so there is no need
    # to backpropagate through the generator
    pred_fake   = dnet(fake_images.detach())                # pass through discriminator
    d_loss_fake = lossfun(pred_fake,fake_labels)            # all labels are 0

    # collect loss (using combined losses)
    d_loss = d_loss_real + d_loss_fake

    # backprop
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()



    ### ---------------- Train the generator ---------------- ###

    # create fake images and let the (just updated) discriminator judge them
    fake_images = gnet( tr.randn(batchsize,100,1,1,device=device) )
    pred_fake   = dnet(fake_images)

    # compute loss: the generator wants its images to be labelled "real"
    g_loss = lossfun(pred_fake,real_labels)

    # backprop
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()


    # collect losses and discriminator decisions (as plain Python floats)
    losses.append([d_loss.item(),g_loss.item()])

    d1 = tr.mean((pred_real>.5).float()).item()
    d2 = tr.mean((pred_fake>.5).float()).item()
    disDecs.append([d1,d2])


  # print out a status message
  msg = f'Finished epoch {epochi+1}/{num_epochs}'
  sys.stdout.write('\r' + msg)


# convert performance from list to numpy array
losses  = np.array(losses)
disDecs = np.array(disDecs)

In [None]:
# create a 1D smoothing filter
def smooth(x,k=15):
    """Return a centred moving average of the 1D sequence x.

    Each output sample is the mean of the input samples inside a window of
    k points around it. Near the two ends the window is truncated and the
    mean is taken over the samples that actually exist, so the curve is not
    pulled toward zero there. The output always has len(x) samples.

    Args:
        x: 1D sequence of numbers.
        k: window length in samples (positive integer).

    Returns:
        1D float numpy array with the same length as x.

    Raises:
        ValueError: if k is smaller than 1.
    """
    if k < 1:
        raise ValueError('k must be a positive integer')

    x = np.asarray(x,dtype=float)
    n = x.size
    if n == 0:
        return x

    window = np.ones(k)

    # take the full convolution and cut out the n samples aligned with x;
    # for k <= n this is the same alignment that mode='same' uses
    start = (k-1)//2
    sums   = np.convolve(x,window,mode='full')[start:start+n]
    # number of real samples under the window at each position (always >= 1)
    counts = np.convolve(np.ones(n),window,mode='full')[start:start+n]

    return sums/counts

In [None]:
# summarize training: smoothed losses, loss-vs-loss scatter, and discriminator decisions
fig,ax=plt.subplots(1,3,figsize=(18,5))

# left: smoothed loss of each model over training batches
ax[0].plot(smooth(losses[:,0]))
ax[0].plot(smooth(losses[:,1]))
ax[0].set_xlabel('Batches')
ax[0].set_ylabel('Loss')
ax[0].set_title("Model Loss")
ax[0].legend(['Discriminator','Generator'])

# middle: one dot per batch, discriminator loss against generator loss
ax[1].plot(losses[:,0],losses[:,1],'k.',alpha=0.1)
ax[1].set_xlabel('Discriminator Loss')
ax[1].set_ylabel("Generator loss")

# right: smoothed fraction of real / fake images the discriminator called "real".
# One value is recorded per batch, so the x-axis counts batches (not epochs).
ax[2].plot(smooth(disDecs[:,0]))
ax[2].plot(smooth(disDecs[:,1]))
ax[2].set_xlabel('Batches')
ax[2].set_ylabel('Probability ("real")')
ax[2].set_title('Discriminator output')
ax[2].legend(['Real','Fake'])

plt.show()

In [None]:
# generate the images from the generator network
# (eval mode for the batchnorm layers; no_grad because nothing is trained here,
#  so no autograd graph needs to be built)
gnet.eval()
with tr.no_grad():
  fake_data = gnet( tr.randn(batchsize,100,1,1).to(device) ).cpu()

# and visualize the first 18 generated images
fig,axs = plt.subplots(3,6,figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
  ax.imshow(fake_data[i].squeeze(),cmap='gray')
  ax.axis('off')

plt.suptitle(classes2keep,y=.95,fontweight='bold')
plt.show()

**Additional Explorations**

In [None]:
# 1) I've mentioned before that GANs can be quite sensitive to subtle changes in model architecture. Try running the code again with exactly one change: set the Adam 'betas' parameters to their default values (simply remove that argument from the code). How much of an impact does this have on the results? More generally, these sensitivities can be frustrating when trying to build new models; the best thing to do is search the web for similar kinds of models and be inspired by their decisions (but don't assume that a model is good just because it's on the web!).