In [None]:
# Colab-only setup: attach Google Drive so that the project files stored
# under /content/drive become visible to this runtime.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
import os, sys

# Location of the project on the mounted Drive (local modules live here).
PROJECT_DIR = "/content/drive/MyDrive/HW5/2_GANS/5_2_2_implementations/MNIST"

# Make the project's local modules importable. The membership check keeps the
# cell idempotent: re-running it no longer stacks duplicate sys.path entries.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

# Directory the training loop saves its per-epoch checkpoints into.
os.makedirs(os.path.join(PROJECT_DIR, "checkpoint"), exist_ok=True)

In [None]:
# Reference :  https://deeplearning.cs.cmu.edu/S20/document/recitation/recitation13.pdf
import torch
from torch import nn
import matplotlib.pyplot as plt
from tqdm import trange, tqdm



from Discriminator import Discriminator
from Generator import Generator
from data_loader import train_loader
from gan_loss_template import gan_loss_discriminator
from gan_loss_template import non_saturating_gan_loss_generator
from gan_loss_template import  wgan_loss_discriminator
from gan_loss_template import wgan_gradient_penalty
from config import lr, batch_size, latent_dim, n_critic, n_epochs, device


In [None]:
# Echo the hyper-parameters pulled in from `config` so the run is self-describing.
print(
    f"lr : {lr}, batch_size :{batch_size}, latent_dim :{latent_dim}, "
    f"device : {device}, n_epochs :{n_epochs}"
)

In [None]:
# ----------------------------------------------
#  Networks: discriminator and generator
# ----------------------------------------------
# The discriminator is still constructed before the generator, so the order in
# which the two networks are initialised is the same as before.
discriminator = Discriminator().to(device)
generator = Generator(latent_dim).to(device)

# Show both architectures (discriminator first, then generator).
print(discriminator)
print(generator)

# ----------------------------------------------
#  Optimizers: one Adam instance per network, identical settings
# ----------------------------------------------
adam_settings = dict(lr=lr, betas=(0.5, 0.999))

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), **adam_settings)
optimizer_generator = torch.optim.Adam(generator.parameters(), **adam_settings)


In [None]:
# ------------
# training
# --------------
# Each batch updates the discriminator with the WGAN loss plus a gradient
# penalty; every `n_critic`-th batch additionally updates the generator.
# At the end of every epoch the latest losses are printed and both networks
# are checkpointed as generator_<epoch>.pt / discriminator_<epoch>.pt.
#
# To train with the standard GAN objective instead, swap in
# gan_loss_discriminator(discriminator, generator, real_samples, z) and
# non_saturating_gan_loss_generator(discriminator, generator, z).
checkpoint_dir = "/content/drive/MyDrive/HW5/2_GANS/5_2_2_implementations/MNIST/checkpoint"

# Most recent generator loss. It stays None until the first generator update,
# so the end-of-epoch report never hits an unbound name.
loss_generator = None

for epoch in range(n_epochs):

    for n, (real_samples, _) in enumerate(train_loader):

        real_samples = real_samples.to(device)

        # ---------------------
        #  Training the Discriminator
        # ---------------------

        # Sample one latent vector per real image. Sizing from the actual batch
        # (rather than the configured batch_size) keeps real and fake batches
        # the same size even when the loader yields a smaller final batch,
        # which the gradient penalty presumably requires.
        z = torch.normal(0, 1, (real_samples.size(0), latent_dim)).to(device)

        # Generate a batch of fake images.
        # NOTE(review): these are passed to the gradient penalty without
        # .detach(), so the discriminator backward pass also flows into the
        # generator. Detaching would save compute, but only if
        # wgan_gradient_penalty does not rely on the graph — confirm first.
        fake_samples = generator(z)

        optimizer_discriminator.zero_grad()

        loss_discriminator = wgan_loss_discriminator(discriminator, generator, real_samples, z)
        loss_discriminator += wgan_gradient_penalty(discriminator, real_samples, fake_samples)

        loss_discriminator.backward()
        optimizer_discriminator.step()

        # -----------------
        #  Training the Generator
        # -----------------
        if (n + 1) % n_critic == 0:

            optimizer_generator.zero_grad()

            # Re-generate from the same noise so the graph is fresh for this
            # backward pass.
            fake_samples = generator(z)

            # Can the generator fool the discriminator?
            disc_fakes = discriminator(fake_samples)

            # The generator maximises the discriminator's score on fakes.
            loss_generator = -torch.mean(disc_fakes)

            loss_generator.backward()
            optimizer_generator.step()

        del real_samples
        del z
        # NOTE(review): emptying the CUDA cache on every batch is kept as-is,
        # but it is usually unnecessary and can slow training — consider
        # removing it or moving it to the end of the epoch.
        torch.cuda.empty_cache()

    # -----------------
    #  End of epoch: report the losses and checkpoint both models
    # -----------------
    # This used to be triggered by `n == batch_size - 1`, which compares a
    # batch index against the batch size and therefore fired at an arbitrary
    # batch (or never). Running it after the inner loop guarantees exactly one
    # report/checkpoint per epoch, taken after the epoch's last update.
    d_loss = loss_discriminator.item()
    g_loss = loss_generator.item() if loss_generator is not None else float("nan")
    print(f" Epoch : {epoch+1}, D Loss : {d_loss} , G Loss : {g_loss}")

    PATH = checkpoint_dir + f"/generator_{epoch+1}.pt"
    torch.save(generator.state_dict(), PATH)

    PATH = checkpoint_dir + f"/discriminator_{epoch+1}.pt"
    torch.save(discriminator.state_dict(), PATH)


In [None]:
# ----------------
# Loading the trained generator and previewing generated digits
# ------------------
import numpy as np  # not used in this cell; kept because other cells may expect `np`

checkpoint_dir = "/content/drive/MyDrive/HW5/2_GANS/5_2_2_implementations/MNIST/checkpoint"

# Prefer a hand-picked "generator.pt" if one exists. The training cell only
# writes per-epoch files named generator_<epoch>.pt, so otherwise fall back to
# the checkpoint of the final epoch instead of failing on a missing file.
PATH = checkpoint_dir + "/generator.pt"
if not os.path.exists(PATH):
    PATH = checkpoint_dir + f"/generator_{n_epochs}.pt"

generator = Generator(latent_dim)
# map_location lets a checkpoint saved from a GPU runtime load on a CPU-only one.
generator.load_state_dict(torch.load(PATH, map_location=device))
generator.to(device)
generator.eval()

# Inference only: skip building the autograd graph while sampling.
with torch.no_grad():
    z = torch.randn(batch_size, latent_dim).to(device)
    generated_samples = generator(z)
generated_samples = generated_samples.cpu().detach().numpy()

# Show up to six samples in a single 2x3 grid. The figure is displayed once at
# the end; calling plt.pause() per subplot could split the grid into separate
# figures in a notebook and added a one-second delay per image.
fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
for ax, sample in zip(axes.flat, generated_samples):
    # Each sample is reshaped to a 28x28 image for display.
    ax.imshow(sample.reshape(28, 28), cmap="gray")
plt.show()

