<a href="https://colab.research.google.com/github/vijayshankarrealdeal/intro_to_pytorch-Gans/blob/main/WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Wasserstein GAN with Gradient Penalty

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed PyTorch's default random number generator with a fixed value (0).
torch.manual_seed(0)

<torch._C.Generator at 0x7f4a6c106d50>

In [None]:
def show_image(image_tensor, num_images=25, size=(1, 28, 28)):
  """Plot the leading images of a batch as a grid with five images per row.

  Args:
    image_tensor: batch of images. Values are shifted by +1 and halved before
      plotting, which maps the range [-1, 1] onto [0, 1].
    num_images: maximum number of images taken from the front of the batch.
    size: accepted for interface compatibility; this function never reads it.
  """
  # Rescale first, then drop autograd history and move the data to the CPU.
  rescaled = ((image_tensor + 1) / 2).detach().cpu()
  grid = make_grid(rescaled[:num_images], nrow=5)
  # Reorder axes to (1, 2, 0) so the channel axis comes last for imshow.
  channels_last = grid.permute(1, 2, 0)
  plt.imshow(channels_last.squeeze())
  plt.show()

In [None]:
def make_grad_hook():
  """Create a collector for the weight gradients of convolutional layers.

  Returns:
    A ``(grads, grad_hooks)`` pair. ``grad_hooks(m)`` takes a single module,
    so it can be handed to ``nn.Module.apply``. For every ``nn.Conv2d`` or
    ``nn.ConvTranspose2d`` it appends the module's current ``weight.grad``
    to the shared ``grads`` list (``None`` if no backward pass has populated
    it yet). All other module types are ignored.
  """
  grads = []

  def grad_hooks(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      # Conv modules name their kernel parameter ``weight``; reading
      # ``weights`` raised AttributeError on the first conv layer visited.
      grads.append(m.weight.grad)

  return grads, grad_hooks


In [None]:
class Generator(nn.Module):
  """Transposed-convolution generator mapping noise vectors to images.

  The noise is treated as a ``z_dim``-channel 1x1 feature map and upsampled
  by four transposed-convolution blocks. With no padding each block maps a
  side length ``s`` to ``(s - 1) * stride + kernel_size``, so the spatial
  size evolves 1 -> 3 -> 6 -> 13 -> 28.

  Args:
    z_dim: length of each input noise vector.
    img_cha: number of channels in the generated image.
    hidden_units: base channel width of the hidden blocks.
  """

  def __init__(self, z_dim=10, img_cha=1, hidden_units=64):
    super().__init__()
    self.z_dim = z_dim
    self.gen = nn.Sequential(
        self.make_gen_block(z_dim, hidden_units * 4),
        self.make_gen_block(hidden_units * 4, hidden_units * 2, kernel_size=4, stride=1),
        self.make_gen_block(hidden_units * 2, hidden_units),
        # stride=2 is required here: (13 - 1) * 2 + 4 = 28. With stride=1 the
        # output was only 16x16, which matches neither the 28x28 MNIST images
        # nor what the three stride-2 critic convolutions can reduce.
        self.make_gen_block(hidden_units, img_cha, kernel_size=4, stride=2, final_layer=True),
    )

  def make_gen_block(self, input_units, output_units, kernel_size=3, stride=2, final_layer=False):
    """Build one upsampling block.

    Args:
      input_units: channels entering the block.
      output_units: channels leaving the block.
      kernel_size: transposed-convolution kernel size.
      stride: transposed-convolution stride.
      final_layer: if True, the block is ConvTranspose2d followed by Tanh
        (outputs in [-1, 1]); otherwise ConvTranspose2d, BatchNorm2d, ReLU.

    Returns:
      An ``nn.Sequential`` holding the block's layers.
    """
    if not final_layer:
      return nn.Sequential(
          nn.ConvTranspose2d(input_units, output_units, kernel_size, stride),
          nn.BatchNorm2d(output_units),
          nn.ReLU(inplace=True),
      )
    else:
      return nn.Sequential(
          nn.ConvTranspose2d(input_units, output_units, kernel_size, stride),
          nn.Tanh(),
      )

  def forward(self, noise):
    """Generate images from ``noise``, which is viewed as (len(noise), z_dim, 1, 1)."""
    x = noise.view(len(noise), self.z_dim, 1, 1)
    return self.gen(x)

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
  """Sample a (n_samples, z_dim) tensor of standard-normal noise on ``device``."""
  shape = (n_samples, z_dim)
  return torch.randn(shape, device=device)

In [None]:
class Critic(nn.Module):
  """Convolutional critic that turns each image into a flat score vector.

  Three unpadded convolution blocks (kernel 4, stride 2 by default) are
  stacked; the last one has a single output channel and no normalisation or
  activation. For 28x28 inputs the spatial size goes 28 -> 13 -> 5 -> 1, so
  each image yields one score.

  Args:
    im_chan: number of channels in the input images.
    hidden_dim: channel width of the first block; the second uses twice that.
  """

  def __init__(self, im_chan=1, hidden_dim=64):
    super().__init__()
    stem = self.make_crit_block(im_chan, hidden_dim)
    middle = self.make_crit_block(hidden_dim, hidden_dim * 2)
    head = self.make_crit_block(hidden_dim * 2, 1, final_layer=True)
    self.critic = nn.Sequential(stem, middle, head)

  def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
    """Build one critic block.

    Every block starts with a Conv2d; non-final blocks additionally apply
    BatchNorm2d and LeakyReLU(0.2).
    """
    layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
    if not final_layer:
      layers.append(nn.BatchNorm2d(output_channels))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

  def forward(self, image):
    """Score ``image`` and flatten everything after the batch axis."""
    scores = self.critic(image)
    batch = scores.shape[0]
    return scores.view(batch, -1)

In [None]:
# Training configuration.
n_epochs = 100       # presumably the number of passes over the dataset (training loop not shown here)
z_dim = 64           # length of the generator's noise vectors
display_step = 50    # presumably how often progress is visualised (training loop not shown here)
batch_size = 128     # images per DataLoader batch
lr = 0.0002          # Adam learning rate for both networks
beta_1 = 0.5         # Adam first-moment decay
beta_2 = 0.999       # Adam second-moment decay
c_lambda = 10        # presumably the gradient-penalty weight (training loop not shown here)
crit_repeats = 5     # presumably critic updates per generator update (training loop not shown here)
# Use the GPU when one is present; otherwise fall back to the CPU so that
# `.to(device)` does not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Convert images to tensors, then normalise with mean 0.5 and std 0.5
# (maps [0, 1] pixel values onto [-1, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST is downloaded into the current directory if needed and served in
# shuffled batches of `batch_size`.
dataloader = DataLoader(
    dataset=MNIST(root='.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Build the generator and its Adam optimizer.
gen = Generator(z_dim=z_dim).to(device)
gen_optimizer = torch.optim.Adam(
    gen.parameters(),
    lr=lr,
    betas=(beta_1, beta_2),
)
# Build the critic and its Adam optimizer with the same settings.
critic = Critic().to(device)
critic_optimizer = torch.optim.Adam(
    critic.parameters(),
    lr=lr,
    betas=(beta_1, beta_2),
)


In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    Conv2d, ConvTranspose2d and BatchNorm2d weights are drawn from a normal
    distribution with mean 0.0 and std 0.02. BatchNorm2d biases are also set
    to 0. Any other module type is left untouched.
    """
    is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    is_batchnorm = isinstance(m, nn.BatchNorm2d)
    if is_conv or is_batchnorm:
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if is_batchnorm:
        torch.nn.init.constant_(m.bias, val=0)
# Module.apply visits every submodule, modifies it in place and returns the
# module it was called on, so the networks keep their identity.
gen.apply(weights_init)
critic.apply(weights_init)
# `crit` is a second name for the same critic object.
crit = critic