In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision import transforms
from tqdm.auto import tqdm
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator mapping a flat input vector to a 1x28x28 image.

    The input is reshaped to ``(batch, in_chans, 1, 1)`` and upsampled by four
    transposed-convolution stages (no padding): 1x1 -> 3x3 -> 6x6 -> 13x13 -> 28x28.

    Args:
        in_chans: length of each input vector (noise plus any conditioning entries).
        out_chans: base hidden width; the hidden stages use 4x, 2x and 1x this many
            channels before the final single-channel stage.
    """

    def __init__(self, in_chans=10, out_chans=64):
        super().__init__()
        width = out_chans
        stages = [
            self.make_gen_block(in_chans, width * 4),
            self.make_gen_block(width * 4, width * 2, kernel_size=4, stride=1),
            self.make_gen_block(width * 2, width),
            self.make_gen_block(width, 1, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, in_chans, out_chans, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final stage is
        ConvTranspose2d -> Tanh so generated pixels lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(in_chans, out_chans, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(out_chans), nn.ReLU(inplace=True))

    def forward(self, inp):
        """Generate images from input vectors of shape ``(batch, in_chans)``."""
        batch, chans = inp.shape[0], inp.shape[1]
        return self.gen(inp.view(batch, chans, 1, 1))

def make_noise(num_samples, in_dim, device):
    """Draw a ``(num_samples, in_dim)`` tensor of standard-normal noise on ``device``."""
    shape = (num_samples, in_dim)
    return torch.randn(shape, device=device)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator returning raw (pre-sigmoid) real/fake scores.

    Three unpadded convolutions, kernel 3 / stride 2 by default. For a 28x28 input
    the spatial size shrinks 28 -> 13 -> 6 -> 2, so ``forward`` returns 2*2 = 4
    scores per image (one per output location), flattened to ``(batch, 4)``.

    Args:
        in_chans: number of input channels (image channels plus any conditioning maps).
        hid_dim: width of the first hidden stage; the second stage uses twice this.
    """

    def __init__(self, in_chans=1, hid_dim=64):
        super().__init__()
        widths = [in_chans, hid_dim, 2 * hid_dim]
        stages = [self.make_disc_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        stages.append(self.make_disc_block(widths[-1], 1, final_layer=True))
        self.disc = nn.Sequential(*stages)

    def make_disc_block(self, in_chans, out_chans, kernel_size=3, stride=2, final_layer=False):
        """Build one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final stage is a
        lone Conv2d (still wrapped in a Sequential) with no activation.
        """
        layers = [nn.Conv2d(in_chans, out_chans, kernel_size, stride)]
        if not final_layer:
            layers += [nn.BatchNorm2d(out_chans), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, inp):
        """Score a batch of images; the score map is flattened to ``(batch, -1)``."""
        scores = self.disc(inp)
        return scores.view(len(scores), -1)

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
dset = MNIST(root='.', download=True, transform=tfms)
dls = torch.utils.data.DataLoader(dataset=dset, shuffle=True, batch_size=128)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def init(m, std=0.02):
    """Initialise a module's weights in place, DCGAN-style; intended for ``model.apply(init)``.

    - Conv2d / ConvTranspose2d weights are drawn from N(0, std).
    - BatchNorm2d scales are drawn from N(1, std) and shifts are set to 0.
    - Other module types are left untouched.

    A bare ``nn.init.normal_(m.weight)`` uses std=1, far larger than the DCGAN
    recommendation of 0.02, so the mean/std are now passed explicitly.

    Args:
        m: the module visited by ``nn.Module.apply``.
        std: standard deviation of the normal initialisation (default 0.02).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=std)
    elif isinstance(m, nn.BatchNorm2d):
        # Centre the BN scale at 1 so the layer starts close to an identity scaling.
        nn.init.normal_(m.weight, mean=1.0, std=std)
        nn.init.constant_(m.bias, 0.0)

# Hyper-parameters.
SEED = 0
z_dim = 64
n_classes = 10
mnist_chans = 1
lr = 0.0002  # DCGAN learning rate; pairs with the beta1 = 0.5 used below

# Seed once so weight init, noise draws and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

# Generator input: noise concatenated with a one-hot class vector.
# Discriminator input: image channels stacked with one constant map per class.
in_dim = z_dim + n_classes
disc_dim = mnist_chans + n_classes

gen = Generator(in_chans=in_dim).to(device).apply(init)
disc = Discriminator(in_chans=disc_dim).to(device).apply(init)

# Without an explicit lr, Adam would use its default of 1e-3 (5x the DCGAN setting).
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))


In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    '''
    Plot the first ``num_images`` images of a batch as a grid with ``nrow`` images per row.

    Pixel values are shifted from [-1, 1] to [0, 1] via (x + 1) / 2 before plotting.
    Intended for a batch shaped like (N, C, H, W) — NOTE(review): ``size`` is accepted
    but not used by this function. Calls ``plt.show()`` only when ``show`` is true.
    '''
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=nrow)
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()

In [None]:
n_epochs = 100
loss = nn.BCEWithLogitsLoss()
cur_step = 0
display_step = 500
generator_losses = []
discriminator_losses = []


def plot_losses(gen_losses, disc_losses, step_bins=20):
    """Plot both loss curves, each averaged over consecutive bins of ``step_bins`` steps.

    Trailing steps that do not fill a whole bin are dropped.
    """
    num_examples = (len(gen_losses) // step_bins) * step_bins
    x = range(num_examples // step_bins)
    for values, label in ((gen_losses, "Generator Loss"), (disc_losses, "Discriminator Loss")):
        binned = torch.tensor(values[:num_examples]).view(-1, step_bins).mean(1)
        plt.plot(x, binned, label=label)
    plt.legend()
    plt.show()


for epoch in range(n_epochs):
    for real, labels in tqdm(dls):
        real = real.to(device)

        # Class conditioning: a one-hot vector for the generator, and the same vector
        # broadcast to constant image-sized maps for the discriminator. F.one_hot
        # returns int64, so cast to float before concatenating with float tensors.
        one_hots = F.one_hot(labels.to(device), num_classes=n_classes).float()
        image_oh = one_hots[:, :, None, None].repeat(1, 1, real.shape[2], real.shape[3])

        noise = make_noise(len(real), z_dim, device)
        fake_ims = gen(torch.cat((noise, one_hots), dim=1))

        real_in = torch.cat((real, image_oh), dim=1)
        fake_in = torch.cat((fake_ims, image_oh), dim=1)

        # Discriminator step: fakes are detached so no gradient flows into the generator.
        disc_opt.zero_grad()
        preds_real = disc(real_in)
        preds_fake = disc(fake_in.detach())
        disc_loss = (loss(preds_real, torch.ones_like(preds_real)) +
                     loss(preds_fake, torch.zeros_like(preds_fake))) / 2
        disc_loss.backward()
        disc_opt.step()
        discriminator_losses.append(disc_loss.item())

        # Generator step: re-score the (non-detached) fakes with the updated
        # discriminator and push them towards the "real" label.
        gen_opt.zero_grad()
        preds = disc(fake_in)
        gen_loss = loss(preds, torch.ones_like(preds))
        gen_loss.backward()
        gen_opt.step()
        generator_losses.append(gen_loss.item())

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
            # Show generated samples next to real ones (not the raw latent noise).
            show_tensor_images(fake_ims)
            show_tensor_images(real)
            plot_losses(generator_losses, discriminator_losses)
        cur_step += 1


