# Vanilla GAN on MNIST dataset

In [11]:
import torch
import torch.nn as nn
from torch.nn import Sequential
import torch.optim as optim
from torch.autograd.variable import Variable
from torch.nn import CrossEntropyLoss
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import imageio
import numpy as np
from matplotlib import pyplot as plt

#### loading MNIST data 

In [12]:
# Preprocessing: convert each image to a float tensor, then normalize the
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])
# Converter used later to turn image tensors back into PIL images.
to_image = transforms.ToPILImage()
# MNIST training split, stored under ./mnist_data (download=True is requested).
trainset = datasets.MNIST(root="./mnist_data",train=True,download=True,transform=transform)

# Shuffled mini-batches of 64 samples for the training loop.
trainloader = DataLoader(trainset,batch_size=64,shuffle=True)

In [13]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### Define Class Discriminator and Generator
The discriminator takes a flattened image of size 784 and predicts whether it is a real or a fake MNIST digit.
The generator takes a random noise vector (128-dimensional in the code below) and produces a flattened image of size 784 depicting a digit, in order to fool the discriminator.

In [14]:
class Discriminator(nn.Module):
    """Fully connected real/fake classifier for flattened 28x28 images.

    Three hidden stages (784 -> 1024 -> 512 -> 128), each a linear layer
    followed by LeakyReLU(0.2) and Dropout(0.3), feed a single sigmoid
    output unit.
    """

    def __init__(self):
        super().__init__()
        # Stages are created in order fc0..fc3 so parameter names stay stable.
        self.fc0 = self._hidden_stage(784, 1024)
        self.fc1 = self._hidden_stage(1024, 512)
        self.fc2 = self._hidden_stage(512, 128)
        self.fc3 = Sequential(nn.Linear(128, 1), nn.Sigmoid())

    @staticmethod
    def _hidden_stage(n_in, n_out):
        """Build one Linear -> LeakyReLU(0.2) -> Dropout(0.3) stage."""
        return Sequential(
            nn.Linear(n_in, n_out),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Return one sigmoid score per sample, shape ``(batch, 1)``.

        ``x`` is reshaped to ``(-1, 784)``, so any tensor whose element count
        is a multiple of 784 is accepted.
        """
        hidden = x.view(-1, 784)
        for stage in (self.fc0, self.fc1, self.fc2, self.fc3):
            hidden = stage(hidden)
        return hidden
    
    
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to 28x28 images.

    The network widens the noise through 256 -> 512 -> 1024 hidden units
    (LeakyReLU(0.2) after each) and projects to 784 values. A final Tanh
    bounds every pixel to [-1, 1], the same range the training images have
    after ``Normalize((0.5,), (0.5,))``; without it the output is unbounded
    and can never match the real data range.

    Args:
        n_features: Size of the input noise vector (default 128).
    """

    def __init__(self, n_features=128):
        super().__init__()
        self.fc0 = Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.fc2 = Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
        )
        # Tanh has no parameters, so state_dict keys are unchanged.
        self.fc3 = Sequential(
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape ``(batch, n_features)`` to ``(batch, 1, 28, 28)``."""
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x.view(-1, 1, 28, 28)

In [15]:
# Instantiate both networks and move their parameters to the selected device.
generator = Generator()
discriminator = Discriminator()
generator.to(device)
discriminator.to(device)
# One Adam optimizer per network, both with lr=1e-4 and weight_decay=1e-6.
g_optim = optim.Adam(generator.parameters(),lr=1e-4,weight_decay=1e-6)
d_optim = optim.Adam(discriminator.parameters(),lr=1e-4,weight_decay=1e-6)
# Binary cross-entropy on the discriminator's sigmoid output. The name is
# spelled "critetion" and is referenced under that spelling by the training
# functions, so it must not be renamed here alone.
critetion = nn.BCELoss()
# Per-epoch history filled by the training loop: average generator loss,
# average discriminator loss, and one grid of generated samples per epoch.
g_losses=[]
d_losses=[]
images = []

def noise(n,n_features=128):
    """Return an ``(n, n_features)`` batch of uniform [0, 1) noise on ``device``.

    The tensor is created directly on the target device; the deprecated
    ``Variable`` wrapper is not needed because plain tensors already support
    autograd.
    """
    return torch.rand(n, n_features, device=device)

def make_ones(size):
    """Return a ``(size, 1)`` tensor of ones on ``device`` ("real" targets).

    Allocated directly on the target device instead of wrapping a CPU tensor
    in the deprecated ``Variable`` and copying it over.
    """
    return torch.ones(size, 1, device=device)

def make_zeros(size):
    """Return a ``(size, 1)`` tensor of zeros on ``device`` ("fake" targets).

    Allocated directly on the target device instead of wrapping a CPU tensor
    in the deprecated ``Variable`` and copying it over.
    """
    return torch.zeros(size, 1, device=device)

In [16]:
def train_discriminator(optimizer,real_data,fake_data):
    """Run one discriminator update and return the summed real + fake loss.

    Args:
        optimizer: Optimizer over the discriminator's parameters.
        real_data: Batch of real images (target label 1).
        fake_data: Batch of generated images (target label 0). The training
            loop in this file passes it detached from the generator.

    Returns:
        A detached scalar tensor holding ``real_loss + fake_loss``. It is
        detached so that callers summing it over an epoch do not keep the
        autograd history of every batch alive.
    """
    optimizer.zero_grad()

    # Real batch: the discriminator should output 1.
    predicted_real = discriminator(real_data)
    real_loss = critetion(predicted_real, make_ones(real_data.size(0)))
    real_loss.backward()

    # Fake batch: the discriminator should output 0. Labels are sized from
    # the fake batch itself rather than assuming it matches the real one.
    predicted_fake = discriminator(fake_data)
    fake_loss = critetion(predicted_fake, make_zeros(fake_data.size(0)))
    fake_loss.backward()

    # Gradients of both backward passes have accumulated; apply them together.
    optimizer.step()

    return (real_loss + fake_loss).detach()

def train_generator(optimizer,fake_data):
    """Run one generator update and return its loss.

    The generator succeeds when the discriminator classifies its samples as
    real, so the loss is the binary cross-entropy between the discriminator's
    prediction on ``fake_data`` and a target of ones. (Using zeros here would
    train the generator to be recognised as fake.)

    Args:
        optimizer: Optimizer over the generator's parameters.
        fake_data: Generated images still attached to the generator's graph.

    Returns:
        A detached scalar tensor with the generator loss, safe to accumulate
        across batches without retaining autograd history.
    """
    n = fake_data.size(0)
    optimizer.zero_grad()
    predicted_fake = discriminator(fake_data)
    # Target is "real" (ones): the generator is rewarded for fooling D.
    loss = critetion(predicted_fake, make_ones(n))
    loss.backward()
    # Only the generator's parameters are stepped by this optimizer.
    optimizer.step()

    return loss.detach()

In [18]:
epochs = 3
# Fixed noise batch so the sample grids of different epochs are comparable.
test_noise = noise(64)
generator.train()
discriminator.train()
for epoch in range(epochs):
    g_error = 0.0
    d_error = 0.0
    n_batches = 0
    for inputs, _ in trainloader:
        real_data = inputs.to(device)
        n = inputs.size(0)

        # Discriminator step: fakes are detached so no gradient reaches G.
        fake_data = generator(noise(n)).detach()
        # float() turns the loss into a plain number, so the running sum
        # does not keep autograd history or device tensors alive.
        d_error += float(train_discriminator(d_optim, real_data, fake_data))

        # Generator step: fresh fakes that stay attached to G's graph.
        fake_data = generator(noise(n))
        g_error += float(train_generator(g_optim, fake_data))
        n_batches += 1

    # Snapshot of the generator's progress; no gradients are needed here.
    with torch.no_grad():
        img = generator(test_noise).cpu()
    # normalize=True rescales the grid to [0, 1] so it converts to an image
    # without wrapping out-of-range values.
    img = make_grid(img, normalize=True)
    images.append(img)

    # Average over the number of batches actually processed (not the last
    # batch index, which is one too small); guard against an empty loader.
    g_loss = g_error / max(n_batches, 1)
    d_loss = d_error / max(n_batches, 1)
    g_losses.append(g_loss)
    d_losses.append(d_loss)
    print("Epoch {}: g_loss is {:.6f} and d_loss is {:.6f} ".format(epoch, g_loss, d_loss))

print("Finished Training")
torch.save(generator.state_dict(), "mnist_generator.pt")

KeyboardInterrupt: 

In [None]:
# Convert each stored per-epoch sample grid to a PIL image and then to a
# NumPy array, and write the sequence out as an animated GIF.
# NOTE(review): ToPILImage presumably expects float tensors in [0, 1] —
# confirm the grids stored in `images` are in that range.
imgs = [np.array(to_image(i)) for i in images]
imageio.mimsave('progress.gif', imgs)

In [None]:
# Plot the per-epoch average losses of both networks.
# Each entry is converted with float() first: the history may hold 0-d
# tensors (possibly on the GPU or attached to an autograd graph), which
# matplotlib cannot convert to NumPy on its own. Plain floats pass through.
plt.plot([float(loss) for loss in g_losses], label="generator loss")
plt.plot([float(loss) for loss in d_losses], label="discriminator loss")
plt.legend()
plt.show()