<a href="https://colab.research.google.com/github/rajaboja/learnings/blob/master/WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.optim import Adam
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator that turns a noise batch into images.

    forward() reshapes an (N, z_dim) noise batch to (N, z_dim, 1, 1) and upsamples it
    through four transposed-conv stages (spatial size 1 -> 3 -> 6 -> 13 -> 28), ending in a
    single-channel Tanh output, i.e. pixel values in [-1, 1].

    Args:
        z_dim: length of each input noise vector (number of input channels).
        out_chans: base channel width; hidden stages use 4x, 2x and 1x this width.
    """

    def __init__(self, z_dim=10, out_chans=64):
        super().__init__()
        width = out_chans
        self.gen = nn.Sequential(
            self.make_conv_block(z_dim, width * 4),
            self.make_conv_block(width * 4, width * 2, kernel_size=4, stride=1),
            self.make_conv_block(width * 2, width),
            self.make_conv_block(width, 1, kernel_size=4, final_layer=True),
        )

    def make_conv_block(self, in_chans, out_chans, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final stage is
        ConvTranspose2d -> Tanh (no normalisation).
        """
        layers = [nn.ConvTranspose2d(in_chans, out_chans, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(out_chans), nn.ReLU()]
        return nn.Sequential(*layers)

    def forward(self, inp):
        """Map noise of shape (N, C, ...) to images by treating it as an (N, C, 1, 1) map."""
        batch, channels = inp.shape[0], inp.shape[1]
        as_map = inp.view(batch, channels, 1, 1)
        return self.gen(as_map)


def make_noise(num_samples, z_dim, device):
    """Return a (num_samples, z_dim) tensor of standard-normal noise created on `device`."""
    shape = (num_samples, z_dim)
    return torch.randn(*shape, device=device)

In [None]:
class Critic(nn.Module):
    """Convolutional critic that scores each image with a single unbounded value.

    Three strided Conv2d stages (kernel 4, stride 2); for 28x28 inputs the spatial size
    goes 28 -> 13 -> 5 -> 1, so forward() returns shape (N, 1).

    NOTE(review): the WGAN-GP paper recommends not using batch normalisation in the
    critic, because the gradient penalty is defined per sample. Consider LayerNorm or
    InstanceNorm; BatchNorm2d is kept here to preserve the existing behaviour.

    Args:
        image_chans: channels of the input images.
        out_chans: channel width of the first stage (the second stage uses 2x).
    """

    def __init__(self, image_chans=1, out_chans=64):
        super().__init__()
        width = out_chans
        self.crit = nn.Sequential(
            self.make_crit_block(image_chans, width),
            self.make_crit_block(width, width * 2),
            self.make_crit_block(width * 2, 1, final_layer=True),
        )

    def make_crit_block(self, in_chans, out_chans, kernel_size=4, stride=2, final_layer=False):
        """Build one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final stage is a
        bare Conv2d so the score is not squashed.
        """
        conv = nn.Conv2d(in_chans, out_chans, kernel_size, stride)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(conv, nn.BatchNorm2d(out_chans), nn.LeakyReLU(negative_slope=0.2))

    def forward(self, inp):
        """Return critic scores flattened to (N, -1)."""
        scores = self.crit(inp)
        return scores.view(len(scores), -1)

In [None]:
z_dim = 64
bs = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dls = DataLoader(
    MNIST('.', download=True, transform=tfms),
    batch_size=bs,
    shuffle=True,
)


In [None]:
# Build both networks on the selected device (generator first, then critic).
gen, crit = Generator(z_dim=z_dim).to(device), Critic().to(device)

# Adam with beta1 = 0.5 for both networks; the learning rate is Adam's default.
gen_opt = Adam(params=gen.parameters(), betas=(0.5, 0.999))
crit_opt = Adam(params=crit.parameters(), betas=(0.5, 0.999))

In [None]:
def init(m):
    """Initialise one module's parameters in place; meant to be used via ``Module.apply``.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02).
    - BatchNorm2d: weights (gamma) ~ N(1, 0.02), biases (beta) = 0.
    - Any other module is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # Centre gamma on 1 so BatchNorm starts near an identity scale. Centring it on 0
        # would shrink every normalised activation ~50x and randomly flip channel signs.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

# Module.apply(fn) calls fn on every submodule recursively and returns the module itself,
# so `gen` and `crit` keep referring to the same (now re-initialised) objects.
gen, crit = gen.apply(init), crit.apply(init)

In [None]:
def gredient_penalty(crit, real, fake, epsilon):
    """Return the WGAN-GP gradient penalty of ``crit`` on real/fake interpolations.

    (The misspelt name is kept because callers use it.)

    Args:
        crit: critic module; called on a tensor shaped like ``real``.
        real: batch of real samples.
        fake: batch of generated samples, same shape as ``real``.
        epsilon: per-sample mixing weights, broadcastable against ``real``.
            It may or may not require grad.

    Returns:
        Scalar tensor: the batch mean of ``(||d crit(x) / dx||_2 - 1) ** 2`` where
        ``x = epsilon * real + (1 - epsilon) * fake``. The gradient is taken with
        ``create_graph=True`` so the penalty can be back-propagated into ``crit``'s
        parameters.
    """
    mix = epsilon * real + (1 - epsilon) * fake
    # autograd.grad needs `mix` to be part of the graph. If none of epsilon/real/fake
    # requires grad, `mix` is a plain leaf and the call would raise, so mark it here.
    if not mix.requires_grad:
        mix.requires_grad_(True)
    mix_score = crit(mix)

    grad = torch.autograd.grad(
        outputs=mix_score,
        inputs=mix,
        grad_outputs=torch.ones_like(mix_score),
        create_graph=True,
        retain_graph=True,
    )[0]

    # One flattened gradient vector per sample -> its L2 norm.
    grad = grad.reshape(len(grad), -1)
    norm = grad.norm(2, dim=1)
    penalty = torch.mean(torch.pow(norm - 1, 2))
    return penalty

In [None]:
def get_gen_loss(crit_fake_pred):
    """Generator loss: the negated mean critic score of generated samples."""
    return crit_fake_pred.mean().neg()


def get_crit_loss(crit_fake_pred, crit_real_pred, gp, weight):
    """Critic loss: mean(fake scores) - mean(real scores) + weight * gradient penalty."""
    score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
    return score_gap + gp * weight

In [None]:
from torchvision.utils import make_grid

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Plot the first ``num_images`` images of ``image_tensor`` in a grid with 5 per row.

    Pixel values are assumed to lie in [-1, 1] and are rescaled to [0, 1] for display.
    Batched input of shape (N, C, H, W), or a single (C, H, W) image, is shown as is.
    Flattened input (1-D or 2-D, e.g. (N, C*H*W)) is first reshaped to (-1, *size).
    '''
    # Detach and move to CPU first, so the display-only rescale builds no autograd graph.
    images = image_tensor.detach().cpu()
    images = (images + 1) / 2
    if images.dim() < 3:
        images = images.reshape(-1, *size)
    image_grid = make_grid(images[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


In [None]:
n_epochs = 100
crit_reps = 5      # critic updates per generator update
c_lambda = 10      # weight of the gradient penalty in the critic loss
cur_step = 0
display_step = 500
generator_losses = []
critic_losses = []

for i in range(n_epochs):
    for real, _ in tqdm(dls):
        real = real.to(device)

        # ---- Critic: `crit_reps` updates against the current real batch ----
        mean_iteration_critic_loss = 0
        for _ in range(crit_reps):
            crit_opt.zero_grad()
            noise = make_noise(len(real), z_dim, device)
            # Detached, so critic updates build no graph through the generator.
            gen_out = gen(noise).detach()
            crit_fake = crit(gen_out)
            crit_real = crit(real)

            # requires_grad=True makes the interpolated images part of the graph, so
            # gredient_penalty can differentiate the critic score with respect to them.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gp = gredient_penalty(crit, real, gen_out, epsilon)
            crit_loss = get_crit_loss(crit_fake, crit_real, gp, c_lambda)
            mean_iteration_critic_loss += crit_loss.item() / crit_reps

            # Every iteration builds a fresh graph, so nothing needs to be retained.
            # Freeing it now avoids keeping two graphs' activations alive at once.
            crit_loss.backward()
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        # ---- Generator: one update ----
        # (The critic's parameters also receive gradients here; crit_opt.zero_grad()
        # at the start of the next critic iteration discards them.)
        gen_opt.zero_grad()
        noise_1 = make_noise(len(real), z_dim, device)
        gen_out1 = gen(noise_1)
        gen_loss = get_gen_loss(crit(gen_out1))
        gen_loss.backward()
        gen_opt.step()

        generator_losses += [gen_loss.item()]

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(gen_out)
            show_tensor_images(real)
            # Smooth both curves by averaging the losses over bins of `step_bins` steps.
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1
