In [None]:
import os
from time import time

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# MNIST training split (downloaded on first use); ToTensor() converts each
# PIL image into a float tensor.
DATA_ROOT = '../Data'
train_data = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Hyperparameters
hparam = {"num_epochs" : 100,   # Number of training epochs
          "batch_size" : 60,    # Images per mini-batch
          "latent_size" : 100,  # Size of z latent vector (i.e. size of generator input)
          "lr" : 0.0001,        # Learning rate for both Adam optimizers
          "b1" : 0.9,           # Adam beta1 (decay rate of the first-moment estimate)
          "b2" : 0.999,         # Adam beta2 (decay rate of the second-moment estimate)
          "seed" : 0,           # RNG seed so weight init, shuffling and noise are reproducible
         }

# Seed every RNG the notebook draws from before the DataLoader and the models
# are created, so "Restart & Run All" reproduces the same run.
np.random.seed(hparam["seed"])
torch.manual_seed(hparam["seed"])

In [None]:
# drop_last=True guarantees that every batch holds exactly hparam["batch_size"]
# images. The training cell builds its label tensors with that fixed size, and
# BatchNorm1d cannot normalise a batch of a single sample in training mode.
# With 60,000 MNIST training images and batch_size=60, nothing is dropped.
dataloader = DataLoader(train_data,
                        batch_size=hparam["batch_size"],
                        shuffle=True,
                        drop_last=True)

In [None]:
# Sanity check: show the first image of one shuffled training batch.
im, _ = next(iter(dataloader))
first_image = im[0].reshape(28, 28).numpy()
plt.imshow(first_image, cmap="gray")
plt.show()

In [None]:
class Generator(nn.Module):
    """Fully connected generator: latent vector -> single-channel 28x28 image.

    Architecture: three Linear -> BatchNorm1d -> ReLU stages (128, 256, 512
    units) followed by a Linear projection to 28*28 values.

    Args:
        latent_size: length of each input noise vector ``z``.
        batch_size: stored as ``self.batch_size`` for backward compatibility;
            ``forward`` takes the batch dimension from ``z`` instead.
    """

    def __init__(self, latent_size=hparam["latent_size"], batch_size=hparam["batch_size"]):
        super().__init__()

        self.batch_size = batch_size

        # BatchNorm1d's second *positional* parameter is ``eps``, not
        # ``momentum``, so momentum is passed by keyword (0.1 is also the
        # PyTorch default) and eps keeps its default value.
        self.model = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.BatchNorm1d(128, momentum=0.1),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, momentum=0.1),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, momentum=0.1),
            nn.ReLU(),
            # NOTE(review): no output activation, so generated pixel values are
            # unbounded while ToTensor() real images lie in [0, 1]; consider
            # appending nn.Sigmoid().
            nn.Linear(512, 28*28),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (N, latent_size) to images of shape (N, 1, 28, 28).

        N is read from ``z`` rather than from a fixed batch size, so any number
        of samples (e.g. a short final batch or a few images to visualise) works.
        """
        img = self.model(z)
        return img.view(z.size(0), 1, 28, 28)

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator: 28x28 image -> probability that it is real.

    Architecture: three Linear -> BatchNorm1d -> ReLU stages (512, 256, 128
    units) followed by Linear(128, 1) and a Sigmoid.

    Args:
        latent_size: unused; accepted only so the signature mirrors ``Generator``.
        batch_size: stored as ``self.batch_size`` for backward compatibility;
            ``forward`` takes the batch dimension from its input instead.
    """

    def __init__(self, latent_size=hparam["latent_size"], batch_size=hparam["batch_size"]):
        super().__init__()

        self.batch_size = batch_size

        # BatchNorm1d's second *positional* parameter is ``eps``, not
        # ``momentum``, so momentum is passed by keyword (0.1 is also the
        # PyTorch default) and eps keeps its default value.
        self.model = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.BatchNorm1d(512, momentum=0.1),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256, momentum=0.1),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128, momentum=0.1),
            nn.ReLU(),
            nn.Linear(128, 1),
            # NOTE(review): Sigmoid + nn.BCELoss could be replaced by raw logits
            # + nn.BCEWithLogitsLoss for better numerical stability.
            nn.Sigmoid()
        )

    def forward(self, z):
        """Score a batch of images, returning an (N, 1) tensor of probabilities.

        ``z`` must have the batch as its first dimension and 28*28 values per
        sample in the remaining dimensions, e.g. (N, 1, 28, 28). ``reshape``
        is used instead of ``view`` so non-contiguous inputs are accepted too.
        """
        img_flat = z.reshape(z.size(0), -1)
        return self.model(img_flat)

In [None]:
# Define Models: build both networks with their default (hparam-derived) sizes
# and move their parameters to the training device, generator first.
gen_net, dis_net = Generator().to(device), Discriminator().to(device)

In [None]:
# Initialize Models
def init_weights(m):
    """Re-initialise a single module in place; meant to be passed to ``Module.apply``.

    * ``nn.Linear``: weights drawn from N(0, 0.02); biases keep PyTorch's default.
    * ``nn.BatchNorm1d`` with learnable affine parameters: scale drawn from
      N(1, 0.02), shift set to 0. Non-affine BatchNorm layers have no
      parameters and are skipped.
    Every other module type is left untouched.

    Args:
        m: the module to initialise.
    """
    # nn.init functions run under torch.no_grad, so the parameters can be
    # passed directly (no ``.data`` needed).
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm1d) and m.affine:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

# Module.apply walks every submodule recursively and returns the module itself,
# so the custom initialisation reaches each Linear / BatchNorm1d layer.
gen_net.apply(init_weights)
# Kept as the cell's last expression so Jupyter echoes the discriminator's layout.
dis_net.apply(init_weights)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid outputs. The loss module
# holds no parameters, so there is nothing to move to `device`.
criterion = nn.BCELoss()

In [None]:
# Independent Adam optimizers for the two networks, sharing lr and (beta1, beta2).
adam_settings = {"lr": hparam["lr"], "betas": (hparam["b1"], hparam["b2"])}
gen_optimizer = torch.optim.Adam(gen_net.parameters(), **adam_settings)
dis_optimizer = torch.optim.Adam(dis_net.parameters(), **adam_settings)

In [None]:
# Full-size target tensors; each iteration slices them to the actual batch
# length so a shorter batch still gets matching shapes.
true_label = torch.ones((hparam["batch_size"], 1), device=device)
false_label = torch.zeros((hparam["batch_size"], 1), device=device)

LOG_EVERY = 250                      # iterations between log lines / image and plot snapshots
GEN_IMAGE_DIR = "generated_images"   # sample grids: im<k>.png
LOSS_PLOT_DIR = "loss_images"        # loss curves: loss<k>.png
# save_image and savefig do not create missing parent directories.
os.makedirs(GEN_IMAGE_DIR, exist_ok=True)
os.makedirs(LOSS_PLOT_DIR, exist_ok=True)


def _save_loss_plot(steps, dis_history, gen_history, path):
    """Write a two-panel figure of discriminator and generator loss vs. iteration to ``path``."""
    fig, axs = plt.subplots(2)
    fig.subplots_adjust(hspace=1)
    panels = ((axs[0], dis_history, "Discriminator Loss"),
              (axs[1], gen_history, "Generator Loss"))
    for ax, history, title in panels:
        ax.plot(steps, np.array(history))
        ax.set_title(title)
        ax.set_xlabel("Number of Iterations")
        ax.set_ylabel("Loss")
        ax.grid(True)
    fig.savefig(path, dpi=300)
    # Close only this figure: hundreds of open figures would leak memory, and
    # closing 'all' would also discard figures the user has open.
    plt.close(fig)


t1 = time()
dis_losses = []
gen_losses = []
loss_steps = []   # global iteration index at which each loss pair was logged

k = 0             # number of snapshots written so far; used in output file names
global_step = 0   # iterations completed across all epochs
for epoch in range(hparam["num_epochs"]):
    for i, (data, _) in enumerate(dataloader):
        data = data.to(device)
        n = data.size(0)
        real_targets = true_label[:n]
        fake_targets = false_label[:n]

        #  ----------------------------------
        #  Train Generator
        #  ----------------------------------

        gen_optimizer.zero_grad()

        # Latent noise drawn uniformly from [0, 1), created directly on the device.
        z = torch.rand(n, hparam["latent_size"], device=device)

        gen_data = gen_net(z)
        # The generator is rewarded when the discriminator labels its samples as real.
        generator_loss = criterion(dis_net(gen_data), real_targets)
        # retain_graph is unnecessary: the discriminator step below only uses
        # gen_data.detach(), so this graph is never traversed again.
        generator_loss.backward()
        gen_optimizer.step()

        #  ----------------------------------
        #  Train Discriminator
        #  ----------------------------------

        # Also discards the gradients the generator's backward pass left on dis_net.
        dis_optimizer.zero_grad()

        real_loss = criterion(dis_net(data), real_targets)
        # detach() keeps this loss from back-propagating into the generator.
        fake_loss = criterion(dis_net(gen_data.detach()), fake_targets)
        discriminator_loss = real_loss + fake_loss

        discriminator_loss.backward()
        dis_optimizer.step()

        #  -----------------------------------------------------------------

        if i % LOG_EVERY == 0:
            dis_value = discriminator_loss.item()
            gen_value = generator_loss.item()
            print(
                "[Epoch %d/%d]\t [Batch %d/%d]\t [Dis loss: %f]\t [Gen loss: %f]"
                % (epoch, hparam["num_epochs"], i, len(dataloader), dis_value, gen_value))

            dis_losses.append(dis_value)
            gen_losses.append(gen_value)
            # Record the true global iteration, so the x-axis stays correct even
            # when len(dataloader) is not a multiple of LOG_EVERY.
            loss_steps.append(global_step)

            save_image(gen_data.detach()[:36], f"{GEN_IMAGE_DIR}/im{k}.png", nrow=6)
            _save_loss_plot(loss_steps, dis_losses, gen_losses, f"{LOSS_PLOT_DIR}/loss{k}.png")

            k += 1

        global_step += 1

print("\nTraining Time: " + str(time()-t1))