## *NECESSARY IMPORTS* #

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import  transforms,datasets 
from torchsummary import summary
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from time import time

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

## *DATA IMPORTS AND DEFINING THE DATALOADERS* #

In [None]:
# Preprocessing applied to every image loaded from disk.
transform = transforms.Compose(
    [
        transforms.Resize(300),      # resize so the shorter side is 300 px
        transforms.CenterCrop(64),   # keep only the central 64x64 patch
        transforms.ToTensor(),       # PIL image -> float tensor (C,H,W) in [0,1]
    ]
)

In [None]:
# Root folder read by ImageFolder (one sub-directory per class).
data_dir = 'data'
dataset = datasets.ImageFolder(data_dir,transform=transform)
# ImageFolder lists its samples class by class. Shuffle so every batch mixes
# classes and the batch composition changes from epoch to epoch; without it
# the networks see the same class-sorted batches in the same order each epoch.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

In [None]:
# Number of batches the loader yields in one pass over the dataset.
print(len(train_loader))

## *DISPLAYING THE IMAGES*

In [None]:
figsize = (16,16)  # size (in inches) of the preview figure used by PlotBatch
def PlotBatch(dataloader):
    """Display the images of the first batch yielded by ``dataloader`` as a grid.

    Args:
        dataloader: Iterable whose items are sequences with an image tensor
            of shape (batch, channels, height, width) at index 0.
    """
    sample_data = next(iter(dataloader))[0]
    # Tile the batch into a single (C, H, W) image. The data is only drawn,
    # never fed to a network, so it is not copied to the compute device first.
    grid = torchvision.utils.make_grid(sample_data, normalize=False)
    plt.figure(figsize=figsize)
    plt.axis('off')  # hide the axis ticks and frame around the grid
    plt.title("IMAGES")
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())

In [None]:
# Preview the first batch of training images.
PlotBatch(train_loader)

## *DEFINING THE NETWORKS* 
### * *GENERATOR NETWORK* 
### * *DISCRIMINATOR NETWORK* 

## *GENERATOR NETWORK* #

In [None]:
class Generator(nn.Module):
    """Map a latent noise vector to a 3x64x64 image with values in [-1, 1].

    A linear layer expands the noise to a 512x4x4 feature map, which four
    stride-2 transposed convolutions upsample to 64x64. The final Tanh bounds
    the output to [-1, 1].

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100, the
            size used by the rest of this notebook.
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        # input is a noise matrix of shape (batch_size, latent_dim)
        self.linear = nn.Linear(latent_dim, 512*16).to(device)
        # the linear output is reshaped to (512,4,4) in forward()
        self.model = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # input_shape is (512,4,4)
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # input_shape is (256,8,8)
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # input_shape is (128,16,16)
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # input_shape is (64,32,32)
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
            # output_shape is (3,64,64)
        ).to(device)

    def forward(self, x):
        """Generate images from noise ``x`` of shape (batch_size, latent_dim).

        Returns:
            Tensor of shape (batch_size, 3, 64, 64) with values in [-1, 1].
        """
        op = self.linear(x)
        op = op.reshape(-1,512,4,4)
        op = self.model(op)
        return op

In [None]:
# Build the generator on the selected device and print a per-layer summary
# for a 100-element noise input (torchsummary adds the batch dimension itself).
generator = Generator().to(device)
summary(generator, (100,)) 

## *DISCRIMINATOR NETWORK* 

In [None]:
class Discriminator(nn.Module):
    """Score 3x64x64 images with the probability that they are real.

    Six stride-2 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1) while widening the channels from 3
    to 512. The resulting 512-vector goes through dropout, a linear layer
    and a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) of each convolution, in forward order.
        channel_plan = [(3, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, 512)]
        layers = []
        for position, (c_in, c_out) in enumerate(channel_plan):
            layers.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1)
            )
            # Every convolution except the first is followed by batch norm.
            if position > 0:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        # After the last block the feature map has shape (512,1,1).
        self.model = nn.Sequential(*layers).to(device)

        head = [nn.Dropout(0.2), nn.Linear(512, 1), nn.Sigmoid()]
        self.classifier = nn.Sequential(*head).to(device)

    def forward(self, x):
        """Return a (batch_size, 1) tensor of "real" probabilities for images ``x``."""
        features = self.model(x).reshape(-1, 512)
        return self.classifier(features)

In [None]:
# Build the discriminator on the selected device and print a per-layer
# summary for a single 3x64x64 image (torchsummary adds the batch dimension itself).
discriminator = Discriminator().to(device)
summary(discriminator,(3,64,64))

## *HYPERPARAMETERS, LOSS FUNCTION AND OPTIMIZER*

In [None]:
# Training hyperparameters: learning rate, batch size and number of epochs.
lr, batch_size, epochs = 2e-4, 16, 2

# Re-create both networks so training starts from freshly initialised weights.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network, both using the same learning rate.
gen_optim = optim.Adam(generator.parameters(), lr=lr)
dis_optim = optim.Adam(discriminator.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

In [None]:
def train(epochs):
    """Alternately train the module-level discriminator and generator.

    Works on the globals ``train_loader``, ``generator``, ``discriminator``,
    ``gen_optim``, ``dis_optim``, ``criterion`` and ``device``. For every
    batch the discriminator is updated first (real images labelled 1,
    generated images labelled 0). The generator is then updated so that its
    images are scored as real. Every 100th batch one generated image is shown.

    Args:
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(discriminator_loss, generator_loss)``: two lists holding the
        mean per-batch loss of each network, one entry per epoch.
    """
    discriminator_loss = []
    generator_loss = []
    for epoch in range(1, epochs + 1):
        total_dis_loss = 0.0
        total_gen_loss = 0.0
        batches_used = 0
        for i, batch in enumerate(train_loader, 1):
            n_samples = batch[0].shape[0]
            if n_samples < 2:
                # The discriminator's last BatchNorm sees a 1x1 feature map,
                # so a one-sample batch gives it a single value per channel
                # and it raises in training mode. Skip such a trailing batch.
                continue
            batches_used += 1

            # The loader's ToTensor transform already yields pixels in [0,1].
            # Map them to [-1,1] to match the generator's Tanh output range
            # (dividing by 255 again would make every real image almost black).
            real_images = batch[0].to(device) * 2.0 - 1.0
            target_for_true = torch.ones(n_samples, device=device)
            target_for_fake = torch.zeros(n_samples, device=device)

            # ---- discriminator step ----
            # The fake batch is detached: this step must not backpropagate
            # into the generator.
            noise = torch.randn(n_samples, 100, device=device)
            fake_images = generator(noise).detach()
            model_output_for_true_data = discriminator(real_images).reshape(n_samples)
            model_output_for_fake_data = discriminator(fake_images).reshape(n_samples)
            loss_of_true_data = criterion(model_output_for_true_data, target_for_true)
            loss_of_fake_data = criterion(model_output_for_fake_data, target_for_fake)
            loss_of_dis_model = (loss_of_true_data + loss_of_fake_data) / 2
            total_dis_loss += loss_of_dis_model.item()
            dis_optim.zero_grad()
            loss_of_dis_model.backward()
            dis_optim.step()

            # ---- generator step ----
            # Fresh fakes are scored by the updated discriminator and the
            # generator is rewarded when they are classified as real.
            noise = torch.randn(n_samples, 100, device=device)
            model_output_for_fake_data2 = discriminator(generator(noise)).reshape(n_samples)
            loss_of_gen_model = criterion(model_output_for_fake_data2, target_for_true)
            total_gen_loss += loss_of_gen_model.item()
            gen_optim.zero_grad()
            loss_of_gen_model.backward()
            gen_optim.step()

            if i % 100 == 0:
                with torch.no_grad():
                    sample = generator(torch.randn(1, 100, device=device)).squeeze(0)
                # Tanh output lies in [-1,1]; imshow expects floats in [0,1].
                sample = ((sample + 1.0) / 2.0).clamp(0.0, 1.0).cpu()
                plt.imshow(sample.permute(1, 2, 0).numpy())
                plt.show()

        # Average over the batches actually processed (guards an empty loader).
        denom = max(batches_used, 1)
        discriminator_loss.append(total_dis_loss / denom)
        generator_loss.append(total_gen_loss / denom)
        print("[%d/%d], dis_loss = %f, gen_loss = %f" % (epoch,epochs,discriminator_loss[-1], generator_loss[-1]))
    return discriminator_loss, generator_loss
            

In [None]:
# Run training for the configured number of epochs and keep the per-epoch mean losses.
dis_loss, gen_loss = train(epochs)

In [None]:
# Per-epoch mean loss of both networks on a shared set of axes.
plt.title("dis_loss vs gen_loss")
plt.xlabel('Number of epochs')
plt.ylabel('Loss')
plt.plot(
    range(1, epochs + 1),
    dis_loss,
    label="Discriminator loss",
)
plt.plot(
    range(1, epochs + 1),
    gen_loss,
    label="Generator loss",
)
plt.legend()
plt.show()