<a href="https://colab.research.google.com/github/vijayshankarrealdeal/intro_to_pytorch-Gans/blob/main/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [56]:
import torch 
import torch.nn as nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [57]:
def show_image(image_tensor,num_images = 25,img_size = (1,28,28)):

  image_tensor = (image_tensor + 1)/2
  image_unflat = image_tensor.detach().cpu()
  image_grid = make_grid(image_unflat[:num_images], nrow=5)
  plt.imshow(image_grid.permute(1, 2, 0).squeeze())
  plt.show()

In [58]:
class Generator(nn.Module):
  def __init__(self,z_dim = 10,img_chan = 1,hidden_dim = 64):
    super(Generator,self).__init__()
    self.z_dim = z_dim
    self.gen = nn.Sequential(
        self.make_gen_block(z_dim,hidden_dim*4),
        self.make_gen_block(hidden_dim*4,hidden_dim*2,kernel_size=4, stride=1),
        self.make_gen_block(hidden_dim*2,hidden_dim),
        self.make_gen_block(hidden_dim,img_chan,kernel_size=4, stride=1,final_layer = True)
    )

  def make_gen_block(self,input_dim,output_dim,kernel_size = 3,stride = 2,final_layer = False):
    if not final_layer:
      return nn.Sequential(
          nn.ConvTranspose2d(input_dim,output_dim,kernel_size,stride),
          nn.BatchNorm2d(output_dim),
          nn.ReLU(inplace=True)
      )
    else:
      return nn.Sequential(
           nn.ConvTranspose2d(input_dim,output_dim,kernel_size,stride),
           nn.Tanh()
      )

  def unsqueeze_noise(self,noise):
    return noise.view(len(noise),self.z_dim,1,1)
  
  def forward(self,noise):
    x = self.unsqueeze_noise(noise)
    return self.gen(x)

In [59]:
def noise_vector(num_samples,z_dim,device = 'cpu'):
  return torch.randn(num_samples,z_dim,device = device)

In [60]:
class Discriminator(nn.Module):
  def __init__(self,image_ch = 1,hidden_units = 16):
    super(Discriminator,self).__init__()
    self.disc = nn.Sequential(
        self.make_disc_block(image_ch,hidden_units),
        self.make_disc_block(hidden_units,hidden_units*2),
        self.make_disc_block(hidden_units*2,hidden_units*4),
        self.make_disc_block(hidden_units*4,1,final_layer = True),
    )
  
  def make_disc_block(self,input_units,output_units,final_layer = False,kernel_size=4, stride=2):
    if not  final_layer:
      return nn.Sequential(
          nn.Conv2d(input_units,output_units,kernel_size, stride),
          nn.BatchNorm2d(output_units),
          nn.LeakyReLU(0.1))
    else:
      return nn.Sequential(
      nn.Conv2d(input_units,output_units,kernel_size, stride)
      )
  def forward(self,image):
    disc_pred = self.disc(image)
    return disc_pred.view(len(disc_pred),-1)

In [61]:
z_dim = 64
display_step = 500
batch_size = 128
beta1 = 0.5
lr = 0.002
beta2 = 0.9999
device = 'cuda'

In [62]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [63]:
loss_function = nn.BCEWithLogitsLoss()
gen = Generator(z_dim).to(device)
gen_optimizer = torch.optim.Adam(gen.parameters(),lr=lr,betas=(beta1,beta2))
disc = Discriminator()
disc_optimizers = torch.optim.Adam(disc.parameters(),lr = lr,betas=(beta1,beta2))

In [64]:
def weights_init(m):
  if isinstance(m,nn.Conv2d) or isinstance(m,nn.ConvTranspose2d):
    nn.init.normal_(m.weight, 0.0, 0.02)
  if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


In [65]:
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [66]:
n_epochs = 50
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):
  for real,_ in tqdm(dataloader):
    cur_batch_size = len(real)
    real = real.to(device)

    disc_optimizers.zero_grad()
    fake_noise = noise_vector(cur_batch_size,z_dim,device = device)
    fake = gen(fake_noise)
    
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = loss_function(disc_fake_pred,torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = loss_function(disc_real_pred,torch.ones_like(disc_real_pred))

    disc_loss = (disc_fake_loss + disc_real_loss)/2
    mean_discriminator_loss += disc_loss.item() / display_step

    disc_loss.backward(retain_graph=True)
    disc_optimizers.step()

    gen_optimizer.zero_grad()
    fake_noise_2 = noise_vector(cur_batch_size,z_dim,device = device)
    gen_fake = gen(fake_noise_2)
    disc_fake_pred = disc(gen_fake)
    gen_loss = loss_function(disc_fake_pred,torch.ones_like(disc_fake_pred))

    gen_loss.backward()
    gen_optimizer.step()
    mean_generator_loss += gen_loss.item() / display_step

        ## Visualization code ##
    if cur_step % display_step == 0 and cur_step > 0:
      print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
      show_tensor_images(fake)
      show_tensor_images(real)
      mean_generator_loss = 0
      mean_discriminator_loss = 0
    cur_step += 1



    
    

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

RuntimeError: ignored