In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam 

import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Seed PyTorch's RNG up front so weight initialisation, latent sampling and
# DataLoader shuffling start from the same state on every "Restart & Run All"
# (as far as the backend allows).
SEED = 42
torch.manual_seed(SEED)

# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

latent = 100    # dimensionality of the generator's input noise vector
img_size = 28   # images are img_size x img_size pixels

# discriminator: layer widths, starting from the flattened image
disc_hidden = [img_size*img_size, 1000, 500, 200]

# generator: layer widths, starting from the latent vector
gen_hidden = [latent, 200, 500, 1000]

In [None]:
# MNIST training split, stored under ../mnist_data.
# No transform is attached here; custom_mnist applies it per item instead.
train_set = datasets.MNIST(
    root='../mnist_data',
    train=True,
    download=True,
)


In [None]:
# Per-image preprocessing: resize to the configured size, convert to a tensor,
# then shift/scale with mean 0.5 and std 0.5 so the value range lines up with
# the generator's tanh output.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    # One-element tuples: `(0.5)` is just the float 0.5, whereas Normalize
    # documents its mean/std arguments as per-channel sequences.
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
class custom_mnist(Dataset):
    """Image-only view over a dataset of ``(image, label)`` pairs.

    Labels are dropped because the GAN training loop only needs images.

    Args:
        input_data: indexable collection whose items are ``(image, ...)``.
        transform_fn: optional callable applied to each image. When omitted,
            the notebook-level ``transform`` pipeline is looked up at access
            time, which is the original behaviour.
    """

    def __init__(self, input_data, transform_fn=None):
        super().__init__()
        self.data = input_data
        self.transform_fn = transform_fn

    def __getitem__(self, idx):
        """Return the transformed image stored at position ``idx``."""
        # Resolve the fallback lazily so re-running the `transform` cell is
        # still picked up by an already-constructed dataset.
        apply = transform if self.transform_fn is None else self.transform_fn
        return apply(self.data[idx][0])

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.data)

In [None]:
# Wrap the raw training split so that indexing yields transformed image tensors.
mnist_train = custom_mnist(train_set)

# Sanity check: show channel 0 of the first training image.
first_image = mnist_train[0]
plt.imshow(first_image[0])

In [None]:
class Discriminator(nn.Module):
    """Fully connected real-vs-fake classifier.

    Args:
        hidden: list of layer widths. ``hidden[0]`` is the flattened input
            size; a final output width of 1 is appended automatically.

    Every Linear layer except the last is followed by ``LeakyReLU(0.2)`` and
    ``Dropout(p=0.5)``. The last Linear feeds the sigmoid applied in
    ``forward``.
    """

    def __init__(self, hidden):
        super().__init__()
        self.h1 = hidden
        self.h2 = hidden[1:] + [1]

        layers = []
        for l1, l2 in zip(self.h1, self.h2):
            layers.append(nn.Linear(l1, l2))
            layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Dropout(p=0.5))

        # Drop the activation and dropout that trail the output layer.
        self.net = nn.ModuleList(layers[:-2])

    def forward(self, x):
        """Return a ``(N, 1)`` tensor of probabilities for a batch ``x``."""
        # Flatten to the width the first Linear layer expects, rather than
        # relying on a notebook-level global for the image size.
        out = torch.reshape(x, (-1, self.h1[0]))
        # Apply the registered layers directly instead of building a new
        # nn.Sequential container on every call.
        for layer in self.net:
            out = layer(out)
        return torch.sigmoid(out)



class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to square images.

    Args:
        hidden: list of layer widths. ``hidden[0]`` is the latent size; a
            final width of ``out_size * out_size`` is appended automatically.
        out_size: side length of the generated images. Defaults to the
            notebook-level ``img_size`` when omitted.

    Every Linear layer except the last is followed by ``LeakyReLU(0.2)``. The
    last Linear feeds the tanh applied in ``forward``.
    """

    def __init__(self, hidden, out_size=None):
        super().__init__()
        # Only fall back to the global when no explicit size was requested.
        self.out_size = img_size if out_size is None else out_size
        self.h1 = hidden
        self.h2 = hidden[1:] + [self.out_size * self.out_size]

        layers = []
        for l1, l2 in zip(self.h1, self.h2):
            layers.append(nn.Linear(l1, l2))
            layers.append(nn.LeakyReLU(0.2))

        # Drop the activation that trails the output layer.
        self.net = nn.ModuleList(layers[:-1])

    def forward(self, x):
        """Return a ``(N, 1, out_size, out_size)`` tensor with values in [-1, 1]."""
        out = x
        # Apply the registered layers directly instead of building a new
        # nn.Sequential container on every call.
        for layer in self.net:
            out = layer(out)
        out = torch.tanh(out)
        return torch.reshape(out, (-1, 1, self.out_size, self.out_size))

In [None]:
## Important: sample the latent vectors from the right distribution
# torch.rand  -> samples from the uniform distribution on [0, 1)
# torch.randn -> samples from the standard normal distribution (use this one)


def generate_fake_data(generator,num_data):
    """Draw `num_data` generator samples together with all-zero labels.

    The latent input is sampled from a standard normal of width `latent`.
    Returns `(images, labels)`, both on `device`; labels are `(num_data, 1)`.
    """
    noise = torch.randn(num_data, latent)
    fake_images = generator(noise.to(device))
    fake_labels = torch.zeros(num_data, 1)
    return fake_images.to(device), fake_labels.to(device)


def generate_real_data(data):
    """Move a real batch to `device` and pair it with all-one labels.

    Returns `(data, labels)`, where labels has shape `(len(data), 1)`.
    """
    real_images = data.to(device)
    real_labels = torch.ones(len(data), 1)
    return real_images, real_labels.to(device)


def display_generated_images(generator):
    """Show 16 freshly generated samples in a 4x4 grid with the axes hidden."""
    samples, _ = generate_fake_data(generator, 16)
    samples = samples.detach().to('cpu')
    _, grid = plt.subplots(4,4)

    # grid.flat walks the axes row by row, matching sample order 0..15.
    for cell, sample in zip(grid.flat, samples):
        cell.imshow(sample[0])
        cell.axis('off')

    plt.show()

In [None]:
# Build both networks from their layer-width lists, then move them to `device`.
disc = Discriminator(disc_hidden)
disc = disc.to(device)
gen = Generator(gen_hidden)
gen = gen.to(device)

# Print the layer structure of each network for inspection.
for network in (disc, gen):
    print(network)

In [None]:

# --- Optimisers, loss and hyper-parameters ---------------------------------
disc_optim = Adam(disc.parameters(), lr = 0.0001)
gen_optim = Adam(gen.parameters(), lr = 0.0001)

criterion = nn.BCELoss().to(device)
num_epochs = 40
batch = 64 # take 64 from Pdata and 64 from Pgen

# Per-iteration losses, stored as plain Python floats (see .item() below).
disc_loss_hist = []
gen_loss_hist = []

# With shuffle=True the loader reshuffles on every pass, so it only needs to
# be built once rather than once per epoch.
train_loader = DataLoader(mnist_train, batch_size=batch, shuffle=True)


for epoch in range(num_epochs):

    for data in train_loader:

        ## Train discriminator

        # Real samples are labelled 1, generated samples 0.
        pdata, label_data = generate_real_data(data)
        # Match the real batch size: the last batch of an epoch can be
        # smaller than `batch`.
        pgen, label_gen = generate_fake_data(gen, pdata.size(0))

        # Detach the fakes: the discriminator update must not backpropagate
        # into the generator.
        fin_data = torch.cat([pdata, pgen.detach()], dim=0)
        fin_out = torch.cat([label_data, label_gen], dim=0)
        # Label-smoothing experiment (disabled): fin_out = torch.abs(fin_out - 0.1)

        d_loss = criterion(disc(fin_data), fin_out)

        disc_optim.zero_grad()
        d_loss.backward()
        disc_optim.step()

        # .item() yields a float; appending the tensor itself would keep each
        # iteration's autograd graph alive and hand matplotlib tensors that
        # still require grad.
        disc_loss_hist.append(d_loss.item())


        ## Train generator

        # For generator, for fooling the discriminator keep generated data points output as 1
        pgen, label_gen = generate_fake_data(gen, 2*batch)
        label_gen = torch.ones(label_gen.shape).to(device)

        g_loss = criterion(disc(pgen), label_gen)

        gen_optim.zero_grad()
        g_loss.backward()
        gen_optim.step()

        gen_loss_hist.append(g_loss.item())


    # End-of-epoch report: epoch index, sample grid and loss curves so far.
    print(epoch)
    display_generated_images(gen)
    plt.plot(disc_loss_hist, label='disc')
    plt.plot(gen_loss_hist, label='gen')
    plt.legend(loc='best')
    plt.show()


In [None]:
# Latent endpoints for the interpolation further down: `a` holds one latent
# vector and `b` holds two (drawn in that order).
a, b = (torch.randn(rows, latent).to(device) for rows in (1, 2))

In [None]:
# Image generated from latent `a` (first sample, channel 0).
sample_a = gen(a)[0][0]
plt.imshow(sample_a.detach().to('cpu'))

In [None]:
# Image generated from the first latent vector in `b` (channel 0).
sample_b = gen(b)[0][0]
plt.imshow(sample_b.detach().to('cpu'))

In [None]:
# Number of interpolation steps. Deliberately not called `len`: rebinding that
# name hides the builtin len() for every cell that runs afterwards. If a cell
# assigning `len` was already run in this kernel, run `del len` or restart.
num_steps = 10

f, axarr = plt.subplots(1, num_steps)

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    for idx, alpha in enumerate(np.linspace(0, 1, num_steps)):
        # Plain Python float keeps the blend a pure tensor-scalar operation.
        alpha = float(alpha)

        # alpha = 0 reproduces `b`, alpha = 1 reproduces `a`.
        vec = alpha * a + (1 - alpha) * b
        axarr[idx].imshow(gen(vec)[0][0].to('cpu'))
        axarr[idx].axis('off')

plt.show()