In [None]:
import torch
import pdb
from torch.utils.data import DataLoader
from torch import nn 
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import matplotlib.pyplot as plt 

In [None]:
#Visualisation Function
def show(tensor, ch=1, size=(28,28), num=16):
    """Plot the first `num` images of a batch as a grid with 4 images per row.

    Args:
        tensor: batch of images; it is reshaped to (-1, ch, *size), so a
            flattened batch (e.g. N x 784) is accepted.
        ch: number of channels per image.
        size: (height, width) of each image.
        num: maximum number of images drawn.
    """
    # Drop autograd tracking and move to host memory before reshaping for plotting
    images = tensor.detach().cpu().view(-1, ch, *size)
    shown = images[:num]
    tiled = make_grid(shown, nrow=4)
    # Move the channel axis last, which is the layout imshow expects
    plt.imshow(tiled.permute(1, 2, 0))
    plt.show()

In [None]:
# Setup of the main parameters and hyperparameters
epochs = 500
current_step = 0             # number of batches processed so far (updated by the training loop)
info_step = 1875             # batches between two progress reports
mean_generator_loss = 0      # running mean over the current reporting window
mean_discriminator_loss = 0  # running mean over the current reporting window

z_dimension = 64             # length of the noise vector fed to the generator
lr = 0.0001
loss_fn = nn.BCEWithLogitsLoss()  # expects raw logits; the discriminator has no final sigmoid

batch_size = 125
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# MNIST training images as tensors, served in shuffled mini-batches.
dataloader = DataLoader(
    dataset=MNIST(root="./data", download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)


In [None]:
# Declare our models

# Generator
def generatorBlock(inp, out):
    """Build one hidden stage of the generator: Linear -> BatchNorm1d -> ReLU.

    Args:
        inp: number of input features.
        out: number of output features.

    Returns:
        An nn.Sequential holding the three layers in that order.
    """
    layers = [
        nn.Linear(inp, out),
        nn.BatchNorm1d(out),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    Four generatorBlock stages widen the features from `z_dimension` to
    1x, 2x, 4x and 8x `h_dimension`; a final Linear + Sigmoid projects to
    `i_dimension` outputs squashed into the 0..1 range.
    """

    def __init__(self, z_dimension=64, i_dimension=784, h_dimension=128):
        super().__init__()
        # Layer widths; with the defaults: 64 -> 128 -> 256 -> 512 -> 1024
        widths = [z_dimension] + [h_dimension * scale for scale in (1, 2, 4, 8)]
        stages = [generatorBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        self.gen = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], i_dimension),  # 1024 -> 784 with the defaults
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Return generated images for a batch of noise vectors."""
        return self.gen(noise)

def gen_noise(number, z_dimension):
    """Sample a batch of standard-normal noise vectors on the module-level `device`.

    Args:
        number: how many noise vectors to draw (the batch size).
        z_dimension: length of each noise vector.

    Returns:
        A (number, z_dimension) float tensor located on `device`.
    """
    # Allocate straight on the target device instead of sampling on the CPU
    # and copying across on every training step.
    return torch.randn(number, z_dimension, device=device)

# Discriminator
def discriminatorBlock(inp, out):
    """Build one hidden stage of the discriminator: Linear -> LeakyReLU(0.2).

    Args:
        inp: number of input features.
        out: number of output features.

    Returns:
        An nn.Sequential holding the two layers in that order.
    """
    linear = nn.Linear(inp, out)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(linear, activation)

class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    Three discriminatorBlock stages narrow the features from `i_dimension`
    to 4x, 2x and 1x `h_dimension`; a final Linear emits a single raw score
    per image (no sigmoid is applied here).
    """

    def __init__(self, i_dimension=784, h_dimension=256):
        super().__init__()
        # Layer widths; with the defaults: 784 -> 1024 -> 512 -> 256
        widths = [i_dimension, h_dimension*4, h_dimension*2, h_dimension]
        stages = [discriminatorBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        self.disc = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], 1),  # 256 -> 1 with the defaults
        )

    def forward(self, image):
        """Return one raw score per image in the batch."""
        return self.disc(image)

In [None]:
# Build both networks on the target device, then give each its own Adam optimiser.
generator = Generator(z_dimension=z_dimension).to(device)
discriminator = Discriminator().to(device)

gen_optimizer = torch.optim.Adam(params=generator.parameters(), lr=lr)
disc_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

In [None]:
# Sanity check: pull one batch from the loader and print the image/label
# tensor shapes plus the first ten labels.
x, y = next(iter(dataloader))
print(x.shape, y.shape)
print(y[:10])

# Preview what the generator currently produces from one batch of random
# noise (on a fresh top-to-bottom run it has not been trained yet).
noise = gen_noise(batch_size, z_dimension=z_dimension)
fake = generator(noise)
show(fake)

In [None]:
# Calculating the loss

# Generator Loss
def calculate_gen_loss(loss_fn, generator, discriminator, number, z_dimension):
    """Compute the generator loss for one freshly sampled batch of fakes.

    Args:
        loss_fn: callable taking (predictions, targets) and returning the loss.
        generator: model turning noise into fake images.
        discriminator: model scoring images.
        number: how many fake images to generate.
        z_dimension: length of each noise vector.

    Returns:
        The loss of the discriminator's scores on the fakes against all-ones
        targets, i.e. how far the fakes are from being judged real.
    """
    fake_images = generator(gen_noise(number=number, z_dimension=z_dimension))
    disc_scores = discriminator(fake_images)
    # The generator wants its fakes classified as real, hence targets of 1
    return loss_fn(disc_scores, torch.ones_like(disc_scores))

def calculate_disc_loss(loss_fn, generator, discriminator, number, real, z_dimension):
    """Compute the discriminator loss on one batch of real and fake images.

    Args:
        loss_fn: callable taking (predictions, targets) and returning the loss.
        generator: model turning noise into fake images.
        discriminator: model scoring images.
        number: how many fake images to generate.
        real: batch of real images in the layout the discriminator expects.
        z_dimension: length of each noise vector.

    Returns:
        The mean of the loss on fakes (target 0) and the loss on reals (target 1).
    """
    noise = gen_noise(number=number, z_dimension=z_dimension)
    # The discriminator update never backpropagates into the generator, so do
    # not record the generator's autograd graph at all (the fakes come out
    # already detached, with the same values as before).
    with torch.no_grad():
        fake = generator(noise)

    discriminator_fake = discriminator(fake)
    disc_fake_targets = torch.zeros_like(discriminator_fake)
    disc_fake_loss = loss_fn(discriminator_fake, disc_fake_targets)

    disc_real = discriminator(real)
    disc_real_targets = torch.ones_like(disc_real)
    disc_real_loss = loss_fn(disc_real, disc_real_targets)

    disc_loss = (disc_fake_loss + disc_real_loss)/2

    return disc_loss

In [None]:
# Main training loop: for every batch, update the discriminator first, then
# the generator, and report averaged losses plus sample images every
# `info_step` batches.
for epoch in range(epochs):
    for real, _ in tqdm(dataloader):
        # Size of this batch (taken from the data so a smaller final batch also works)
        current_batch_size = len(real)
        # Flatten every image to a vector and move the batch to the training device
        real = real.view(current_batch_size, -1).to(device)

        ## Discriminator
        disc_optimizer.zero_grad()
        discriminator_loss = calculate_disc_loss(loss_fn=loss_fn, generator=generator, discriminator=discriminator, number=current_batch_size, real=real, z_dimension=z_dimension)
        # Each graph is backpropagated exactly once, so it does not need to be retained
        discriminator_loss.backward()
        disc_optimizer.step()

        ## Generator
        gen_optimizer.zero_grad()
        gen_loss = calculate_gen_loss(loss_fn=loss_fn, generator=generator, discriminator=discriminator, number=current_batch_size, z_dimension=z_dimension)
        gen_loss.backward()
        gen_optimizer.step()

        ## Visualisation and stats
        mean_discriminator_loss += discriminator_loss.item()/info_step
        mean_generator_loss += gen_loss.item()/info_step
        # Count the batch before checking, so every reporting window averages
        # exactly `info_step` batches (including the first one).
        current_step += 1

        if current_step % info_step == 0:
            fake_noise = gen_noise(current_batch_size, z_dimension=z_dimension)
            # Preview only: no gradients are needed for plotting
            with torch.no_grad():
                fake = generator(fake_noise)
            show(fake)
            show(real)
            print(f'{epoch}: step {current_step}, generator loss:{mean_generator_loss}, discriminator loss:{mean_discriminator_loss}')
            mean_discriminator_loss, mean_generator_loss = 0, 0

In [None]:
# Draw 128 fresh noise vectors and display generator samples
# (show() plots only the first 16 by default).
fake_noise = gen_noise(128, z_dimension=z_dimension)
fake = generator(fake_noise)
show(fake)