<a href="https://colab.research.google.com/github/rajaboja/learnings/blob/master/DC_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from tqdm.auto import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator that turns noise vectors into images.

    Each noise vector is reshaped into a 1x1 feature map and upsampled by four
    transposed-convolution blocks. With the kernel/stride settings below the
    spatial size grows 1 -> 3 -> 6 -> 13 -> 28, i.e. MNIST-sized output. The
    last block ends in Tanh, so generated pixel values lie in [-1, 1].

    Args:
        z_dim: length of each input noise vector.
        hid_dim: base channel width; the hidden blocks use 4x, 2x and 1x this.
        image_chan: number of channels in the generated images.
    """

    def __init__(self, z_dim=10, hid_dim=64, image_chan=1):
        super().__init__()
        blocks = [
            self.make_gen_block(z_dim, hid_dim * 4),
            self.make_gen_block(hid_dim * 4, hid_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hid_dim * 2, hid_dim),
            self.make_gen_block(hid_dim, image_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, in_chans, out, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling block.

        Hidden blocks are ConvTranspose2d -> BatchNorm2d -> ReLU. The final
        block is ConvTranspose2d -> Tanh (no normalisation).
        """
        layers = [nn.ConvTranspose2d(in_chans, out, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(out), nn.ReLU()]
        return nn.Sequential(*layers)

    def forward(self, inp):
        """Map a (batch, z_dim) noise tensor to a batch of images."""
        batch_size, noise_len = inp.shape[0], inp.shape[1]
        # Treat each noise vector as a noise_len-channel 1x1 feature map.
        return self.gen(inp.view(batch_size, noise_len, 1, 1))

In [None]:
def make_noise(num_images, z_dim, device):
    """Return a (num_images, z_dim) tensor of standard-normal noise on ``device``."""
    shape = (num_images, z_dim)
    return torch.randn(shape, device=device)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator that maps an image batch to one score per image.

    The network is three stride-2 convolutions with kernel 4. The first two are
    followed by BatchNorm2d and LeakyReLU(0.2). The last emits a single channel
    with no activation, so the output is a raw logit. For 28x28 inputs the
    spatial size goes 28 -> 13 -> 5 -> 1.

    Args:
        im_chans: number of channels in the input images.
        hid_dim: width of the first hidden layer; the second uses twice this.
    """

    def __init__(self, im_chans=1, hid_dim=16):
        super().__init__()
        first = self.make_disc_block(im_chans, hid_dim)
        second = self.make_disc_block(hid_dim, hid_dim * 2)
        head = self.make_disc_block(hid_dim * 2, 1, final_layer=True)
        self.disc = nn.Sequential(first, second, head)

    def make_disc_block(self, im_chans, out, kernel_size=4, stride=2, final_layer=False):
        """Build one downsampling block.

        Hidden blocks are Conv2d -> BatchNorm2d -> LeakyReLU(0.2). The final
        block is a lone Conv2d, still wrapped in an nn.Sequential.
        """
        conv = nn.Conv2d(im_chans, out, kernel_size, stride)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(conv, nn.BatchNorm2d(out), nn.LeakyReLU(negative_slope=0.2))

    def forward(self, inp):
        """Return the discriminator output flattened to (batch, -1)."""
        scores = self.disc(inp)
        return scores.view(scores.size(0), -1)


In [None]:
# Shared loss: the discriminator's last layer is a bare Conv2d (no sigmoid),
# so its outputs are logits and BCEWithLogitsLoss applies the sigmoid itself.
loss = nn.BCEWithLogitsLoss()

# Training hyper-parameters.
n_epochs = 50
lr = 1e-3  # NOTE(review): the DCGAN paper used 2e-4 with Adam; confirm 1e-3 is intentional.
bs = 128
z_dim = 64
display_step = 500  # report averaged losses and sample grids every this many batches

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to
# [-1, 1], the same range as the generator's Tanh output.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST is downloaded into the current directory on first use.
dls = DataLoader(
    MNIST(root='.', transform=tfms, download=True),
    batch_size=bs,
    shuffle=True,
)

In [None]:
# Build both networks on the selected device. The generator is created first,
# so it consumes the RNG before the discriminator does.
gen, disc = Generator(z_dim).to(device), Discriminator().to(device)

# One Adam optimizer per network, with beta1 = 0.5 and the same learning rate.
gen_opt, disc_opt = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (gen, disc)
)

In [None]:
def weights_init(m):
    """Initialise one module DCGAN-style; meant to be passed to ``nn.Module.apply``.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02). Biases keep PyTorch's
      default initialisation.
    - BatchNorm2d (affine): scale ~ N(1, 0.02), shift = 0.

    All other module types are left untouched.

    Args:
        m: the module currently visited by ``apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # Centre the BN scale on 1 (identity). A scale drawn around 0 would
        # multiply every normalised activation by roughly +-0.02 at the start
        # of training. Non-affine BN has weight/bias = None, so it is skipped.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
                      
# Module.apply visits every submodule and returns the module itself.
# The generator is re-initialised first, then the discriminator.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a 5-column grid.

    Pixel values are mapped from [-1, 1] back to [0, 1] via ``(x + 1) / 2``
    before plotting.

    Args:
        image_tensor: batch of images; assumed (N, C, H, W) as expected by
            ``make_grid`` -- TODO confirm for other callers.
        num_images: how many images from the front of the batch to show.
        size: unused by this function; kept so existing calls keep working.
    """
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    # make_grid returns (C, H, W); imshow wants channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
# Alternating GAN training. For every real batch:
#   1. update the discriminator on that real batch plus a freshly generated fake batch;
#   2. update the generator so the discriminator scores the same fakes as real.
# Every `display_step` batches, print the averaged losses and show sample grids.
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0

for _epoch in range(n_epochs):
    for real, _ in tqdm(dls):
        real = real.to(device)

        # ---- Discriminator update ----
        disc_opt.zero_grad()
        fake_noise = make_noise(len(real), z_dim, device=device)
        fake = gen(fake_noise)
        # Detached, so this loss never backpropagates into the generator.
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = loss(disc_fake_pred, torch.zeros_like(disc_fake_pred))

        disc_real_pred = disc(real)
        disc_real_loss = loss(disc_real_pred, torch.ones_like(disc_real_pred))

        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        mean_discriminator_loss += disc_loss.item() / display_step

        # No retain_graph: this graph covers only disc(fake.detach()) and
        # disc(real). The generator graph behind `fake` is not traversed here
        # and is still intact for the generator update below.
        disc_loss.backward()
        disc_opt.step()

        # ---- Generator update ----
        gen_opt.zero_grad()
        disc_fake_pred = disc(fake)
        gen_loss = loss(disc_fake_pred, torch.ones_like(disc_fake_pred))
        # This also leaves gradients on the discriminator's parameters. They
        # are cleared by disc_opt.zero_grad() at the start of the next batch.
        gen_loss.backward()
        gen_opt.step()

        mean_generator_loss += gen_loss.item() / display_step

        # Count the batch before checking, so every reported window averages
        # exactly `display_step` batches (previously the first window held
        # display_step + 1 batches but was divided by display_step).
        cur_step += 1
        if cur_step % display_step == 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
