In [None]:
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torchvision.utils import make_grid

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt

In [None]:
# ---- Configuration ----------------------------------------------------------
# Compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Data / model dimensions.
img_size = 64            # not referenced anywhere else in the visible notebook
img_dim = 1 * 28 * 28    # flattened image length: 1 * 28 * 28 = 784
latent_dim = 100         # length of the noise vector fed to the generator

# Training schedule.
epochs = 100
batch_size = 64

# Adam optimizer settings (learning rate and the two beta coefficients).
lr = 0.0002
b1 = 0.5
b2 = 0.999

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: given a tensor of images, a number of images, and
    the size of each image, plots the images in a uniform grid (5 per row).

    Parameters:
        image_tensor: tensor that can be viewed as (-1, *size). Values are expected
            in [-1, 1], which is the range produced by Normalize([0.5], [0.5]) on the
            data and by the Tanh at the end of the generator.
        num_images: maximum number of images drawn, taken from the front of the batch.
        size: (channels, height, width) of a single image.
    '''
    image_cpu = image_tensor.detach().cpu()
    # Map [-1, 1] back to [0, 1]. imshow clips floating-point RGB data to [0, 1],
    # so without this step every negative pixel would be drawn as black.
    image_cpu = ((image_cpu + 1) / 2).clamp(0, 1)
    image_unflat = image_cpu.view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


# Disabled helper, kept for reference only (fully commented out so that no part
# of it is parsed as belonging to show_tensor_images):
#
#def make_grad_hook():
#    '''
#    Function to keep track of gradients for visualization purposes,
#    which fills the grads list when using model.apply(grad_hook).
#    '''
#    grads = []
#    def grad_hook(m):
#        if isinstance(m, nn.Linear):
#            grads.append(m.weight.grad)
#    return grads, grad_hook


In [None]:
# MNIST training split. Each image is converted to a tensor and then normalized
# with mean 0.5 / std 0.5; the dataset is downloaded if it is not on disk yet.
dataloader = DataLoader(
    dataset=datasets.MNIST(
        root="../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
class GAN_generator(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    The stack is three Linear + LeakyReLU stages (widths 128, 256, 512)
    followed by a Linear projection to ``img_dim`` and a Tanh, so every
    output value lies in [-1, 1].
    """

    def __init__(self, latent_dim, img_dim):
        """Build the layer stack for the given latent and output sizes."""
        super().__init__()

        stages = []
        in_features = latent_dim
        for out_features in (128, 256, 512):
            stages.append(nn.Linear(in_features, out_features))
            stages.append(nn.LeakyReLU())
            in_features = out_features
        stages.append(nn.Linear(in_features, img_dim))
        stages.append(nn.Tanh())

        # Registered as ``net`` so parameter names stay ``net.<index>.*``.
        self.net = nn.Sequential(*stages)

    def forward(self, input_tensor):
        """Pass ``input_tensor`` through the stack and return the result."""
        return self.net(input_tensor)

In [None]:
# Instantiate the generator and move its parameters to the selected device.
gen = GAN_generator(latent_dim=latent_dim, img_dim=img_dim).to(device)

In [None]:
gen

GAN_generator(
  (net): Sequential(
    (0): Linear(in_features=100, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=256, out_features=512, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=512, out_features=784, bias=True)
    (7): Tanh()
  )
)

In [None]:
class GAN_discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images.

    The stack is three Linear + LeakyReLU stages (widths 512, 256, 128)
    followed by a Linear layer with a single output and a Sigmoid, so the
    result is one value in (0, 1) per input row.
    """

    def __init__(self, input_dim):
        """Build the layer stack for inputs with ``input_dim`` features."""
        super().__init__()

        widths = (input_dim, 512, 256, 128)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU())
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())

        # Registered as ``disc`` so parameter names stay ``disc.<index>.*``.
        self.disc = nn.Sequential(*layers)

    def forward(self, input_tensor):
        """Pass ``input_tensor`` through the stack and return the result."""
        return self.disc(input_tensor)
  

In [None]:
# Instantiate the discriminator and move its parameters to the selected device.
dis = GAN_discriminator(input_dim=img_dim).to(device)

In [None]:
dis

GAN_discriminator(
  (disc): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=128, out_features=1, bias=True)
    (7): Sigmoid()
  )
)

In [None]:

# Binary cross-entropy; it is applied to the discriminator's sigmoid outputs.
loss_function = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate and betas.
optimizer_G = optim.Adam(gen.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(dis.parameters(), lr=lr, betas=(b1, b2))


In [None]:
# Sanity check on a single batch. Only one batch is fetched: looping over the
# whole loader here would cost a full pass over the dataset (plus a device
# transfer per batch) just to look at a shape.
real, _ = next(iter(dataloader))
real = real.to(device)
cur_batch_size = real.shape[0]
# Every image must flatten to exactly img_dim values for the Linear layers.
assert real[0].numel() == img_dim, "image size does not match img_dim"

In [None]:
def get_noise(n_samples, latent_dim, device=None):
    """Sample standard-normal noise vectors.

    Parameters:
        n_samples: number of noise vectors (rows) to draw.
        latent_dim: length of each noise vector (columns).
        device: device the tensor is created on. When omitted, "cuda" is used
            if PyTorch reports an available GPU and "cpu" otherwise, so the
            default call also works on CPU-only machines.

    Returns:
        A tensor of shape (n_samples, latent_dim) drawn with torch.randn.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.randn(n_samples, latent_dim, device=device)

In [None]:
def get_disc_loss(generator, discriminator, criterion, real, num_images, latent_dim, device):
    """Compute the discriminator loss for one batch.

    Parameters:
        generator: module mapping noise of shape (num_images, latent_dim) to fakes.
        discriminator: module scoring a batch of (real or fake) samples.
        criterion: loss called as criterion(prediction, target).
        real: batch of real samples, already in the form the discriminator expects.
        num_images: number of fake samples to generate.
        latent_dim: length of each noise vector.
        device: device on which the noise is created.

    Returns:
        The mean of the loss on fakes (target all zeros) and the loss on reals
        (target all ones).
    """
    noise = get_noise(num_images, latent_dim, device=device)
    # The fakes are only inputs here; no gradient must reach the generator, so
    # its forward pass runs without recording an autograd graph at all.
    with torch.no_grad():
        gen_out = generator(noise)
    disc_fake_out = discriminator(gen_out.detach())
    fake_loss = criterion(disc_fake_out, torch.zeros_like(disc_fake_out))
    disc_real_out = discriminator(real)
    real_loss = criterion(disc_real_out, torch.ones_like(disc_real_out))
    return (fake_loss + real_loss) / 2

def get_gen_loss(generator, discriminator, criterion, num_images, latent_dim, device):
    """Compute the generator loss for one batch.

    Fresh noise is drawn, passed through the generator and then through the
    discriminator; the discriminator's output is compared against a target of
    all ones, so the loss is small when the fakes are scored as real.

    Parameters:
        generator: module mapping noise of shape (num_images, latent_dim) to fakes.
        discriminator: module scoring the generated batch.
        criterion: loss called as criterion(prediction, target).
        num_images: number of fake samples to generate.
        latent_dim: length of each noise vector.
        device: device on which the noise is created.

    Returns:
        The value returned by criterion.
    """
    latent = get_noise(num_images, latent_dim, device=device)
    verdict = discriminator(generator(latent))
    return criterion(verdict, torch.ones_like(verdict))

In [None]:
# Training

# TensorBoard writers: one run directory each for real and generated image grids.
writer_real = SummaryWriter("logs/GAN_MNIST/real")
writer_fake = SummaryWriter("logs/GAN_MNIST/fake")
step = 0           # global iteration counter across epochs (TensorBoard x-axis)
display_step = 50  # report and visualize every this many batches
step_bins = 20     # consecutive iterations averaged into one loss-curve point


def _binned_mean(values, bin_size):
    """Average `values` over consecutive bins of `bin_size`, dropping the incomplete tail."""
    num_bins = len(values) // bin_size
    binned = torch.tensor(values[:num_bins * bin_size], dtype=torch.float32)
    # Explicit (num_bins, bin_size) shape: view(-1, bin_size) raises on an
    # empty tensor because the -1 dimension is ambiguous for zero elements.
    return binned.view(num_bins, bin_size).mean(1)


# Going into training mode
gen.train()
dis.train()

generator_losses = []
discriminator_losses = []

for epoch in range(epochs):

    for batch_idx, (real, _) in enumerate(dataloader):

        cur_batch_size = real.shape[0]
        # Flatten each image to a vector for the fully connected networks.
        real = real.view(cur_batch_size, -1).to(device)

        # Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        optimizer_D.zero_grad()
        disc_loss = get_disc_loss(gen, dis, loss_function, real, cur_batch_size, latent_dim, device)
        # get_disc_loss detaches the fakes, so this graph is not needed again
        # and does not have to be retained.
        disc_loss.backward()
        optimizer_D.step()
        discriminator_losses.append(disc_loss.item())

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        optimizer_G.zero_grad()
        gen_loss = get_gen_loss(gen, dis, loss_function, cur_batch_size, latent_dim, device)
        gen_loss.backward()
        optimizer_G.step()
        generator_losses.append(gen_loss.item())

        # Print losses occasionally and print to tensorboard
        if batch_idx % display_step == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, epochs, batch_idx, len(dataloader),
                     disc_loss.item(), gen_loss.item()))

            # Preview samples only; no gradients are needed for them.
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, latent_dim, device=device)
                fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)

            # Log image grids; normalize=True rescales each grid to [0, 1].
            # The (1, 28, 28) image shape matches img_dim = 1*28*28.
            fake_grid = make_grid(fake.view(-1, 1, 28, 28)[:32], normalize=True)
            real_grid = make_grid(real.view(-1, 1, 28, 28)[:32], normalize=True)
            writer_fake.add_image("MNIST Fake Images", fake_grid, global_step=step)
            writer_real.add_image("MNIST Real Images", real_grid, global_step=step)

            # Loss curves need at least one complete bin of step_bins iterations;
            # on the very first batches there is nothing to plot yet.
            num_bins = len(generator_losses) // step_bins
            if num_bins > 0:
                plt.plot(
                    range(num_bins),
                    _binned_mean(generator_losses, step_bins).tolist(),
                    label="Generator Loss")

                plt.plot(
                    range(num_bins),
                    _binned_mean(discriminator_losses, step_bins).tolist(),
                    label="Discriminator Loss")

                plt.legend()
                plt.show()

        step += 1

# Make sure everything logged so far is written to disk.
writer_real.flush()
writer_fake.flush()