In [None]:
# 1. imports

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from tqdm.auto import tqdm

In [None]:
# 2. Helper functions

def getDevice():
  """Pick the torch device string: 'cuda' if a CUDA GPU is available, else 'cpu'."""
  return 'cuda' if torch.cuda.is_available() else 'cpu'


def show(tensor, num, value_range=(-1, 1)):
  """Plot the first `num` images of an (N, C, H, W) batch in a 4-column grid.

  Args:
    tensor: image batch; may live on any device and may require grad.
    num: number of images to plot.
    value_range: (low, high) range of the input pixels, mapped linearly onto
      [0, 1] for display. The default matches the generator's Tanh output and
      the Normalize((0.5,), (0.5,)) data transform; clipping [-1, 1] data to
      [0, 1] directly would turn every negative pixel black. Pass None to plot
      the raw values (clipped to [0, 1]).
  """
  data = tensor.detach().cpu()[:num]
  if value_range is not None:
    low, high = value_range
    data = (data - low) / (high - low)
  # make_grid turns single-channel images into 3-channel RGB; permute CHW -> HWC for imshow.
  grid = make_grid(data, nrow=4).permute(1, 2, 0)
  plt.imshow(grid.clip(0, 1))
  plt.show()

def show2(tensor, num, size=28):
  """Plot the first `num` images of a batch of (possibly flattened) 1-channel images.

  Every sample is reshaped to (1, size, size). The default 28 matches raw MNIST;
  pass size=img_size for this notebook's 64x64 generator output (a fixed 28x28
  view cannot hold 64x64 images). Pixels are clipped to [0, 1] for display.
  """
  data = tensor.detach().cpu().reshape(-1, 1, size, size)
  # make_grid turns single-channel images into 3-channel RGB; permute CHW -> HWC for imshow.
  grid = make_grid(data[:num], nrow=4).permute(1, 2, 0)
  plt.imshow(grid.clip(0, 1))
  plt.show()

def encodeOneHot(labels, class_dim):
  """Turn integer class ids into one-hot int64 vectors of length `class_dim`."""
  return torch.nn.functional.one_hot(labels, num_classes=class_dim)

In [None]:
# 3. hyperparameters

device = getDevice()
seed = 42            # global torch RNG seed (weight init, noise, DataLoader shuffling)
noise_dim = 100      # length of the generator's input noise vector
label_dim = 10       # number of classes (MNIST digits 0-9)
batch_size = 64
img_size = 64        # MNIST images are resized to img_size x img_size
img_dim = (img_size, img_size)
img_ch = 1 # greyscale
lr = 1e-4            # Adam learning rate for both generator and critic
epochs = 100         # number of training epochs
save_step = 900      # save a checkpoint every `save_step` batches

# Seed torch (CPU and CUDA) before any model or noise is created so runs are
# reproducible. cuDNN kernels may still be nondeterministic on GPU.
torch.manual_seed(seed)


In [None]:
# 5. Generator

# a. Generator model

def genBlock(inp, out, f, s, p):
  """Generator upsampling unit: ConvTranspose2d(kernel f, stride s, padding p)
  -> BatchNorm2d -> in-place ReLU, mapping `inp` channels to `out` channels."""
  layers = [
      nn.ConvTranspose2d(inp, out, f, s, p),
      nn.BatchNorm2d(out),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)

class Generator(nn.Module):
  """Conditional generator: (noise, class label) -> image in [-1, 1].

  The label is embedded into `embed_size` features and concatenated with the
  noise as extra channels of a 1x1 input. Transposed convolutions then upsample
  1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels, so the output is always 64x64 with
  img_ch channels.

  Args:
    labels_dim: number of classes (rows of the label embedding table).
    embed_size: length of the label embedding.
    z_dim: length of the noise vector.
    d_dim: base channel width; hidden layers use d_dim*16 ... d_dim*2 channels.

  NOTE: the output layer no longer has BatchNorm/ReLU, so checkpoints saved with
  the previous layer layout will not load into this model.
  """
  def __init__(self, labels_dim, embed_size=10, z_dim=noise_dim, d_dim=16):
    super(Generator, self).__init__()
    self.z_dim = z_dim
    self.embed_size = embed_size
    self.labels_dim = labels_dim
    self.embed = nn.Embedding(labels_dim, embed_size)

    self.gen = nn.Sequential(
        genBlock(z_dim + embed_size, d_dim*16, 4, 1, 0),  # 1x1   -> 4x4   (d_dim*16 ch)
        genBlock(d_dim*16, d_dim*8, 4, 2, 1),             # 4x4   -> 8x8   (d_dim*8 ch)
        genBlock(d_dim*8, d_dim*4, 4, 2, 1),              # 8x8   -> 16x16 (d_dim*4 ch)
        genBlock(d_dim*4, d_dim*2, 4, 2, 1),              # 16x16 -> 32x32 (d_dim*2 ch)
        # Output layer: plain transposed conv. A ReLU before Tanh would clamp the
        # output to [0, 1), so the generator could never reproduce the -1
        # background of the [-1, 1]-normalized training images.
        nn.ConvTranspose2d(d_dim*2, img_ch, 4, 2, 1),    # 32x32 -> 64x64 (img_ch ch)
        nn.Tanh() # [-1,1]
    )

  def forward(self, noise, labels):
    """Generate images.

    Args:
      noise: (batch, z_dim) or (batch, z_dim, 1, 1) noise tensor.
      labels: (batch,) integer class ids.

    Returns:
      (batch, img_ch, 64, 64) tensor with values in [-1, 1].
    """
    noise = noise.reshape(len(noise), self.z_dim, 1, 1)
    embedding = self.embed(labels).view(-1, self.embed_size, 1, 1)
    return self.gen(torch.cat([noise, embedding], dim=1))


# b. noise generator

def gen_noise(batch_size, z_dim, device=None):
  """Sample a (batch_size, z_dim) batch of standard-normal noise vectors.

  Args:
    batch_size: number of noise vectors.
    z_dim: length of each vector.
    device: device to allocate on. The previous version silently ignored this
      argument; it is now honored, and falls back to getDevice() when omitted.
  """
  if device is None:
    device = getDevice()
  return torch.randn(batch_size, z_dim, device=device)

In [None]:
# Smoke test: the generator maps (noise, labels) to (batch_size, img_ch, 64, 64) images.
gen = Generator(label_dim).to(device)
noise = gen_noise(batch_size, noise_dim, device)
# gen_labels is only defined in the critic cell further down, so sample the
# labels directly here to keep this cell runnable on a fresh kernel.
label = torch.randint(0, label_dim, (batch_size,), device=device)
gen(noise, label).shape

torch.Size([64, 1, 64, 64])

In [None]:
# c. Critic

def discBlock(inp, out, f, s, p):
  """Critic downsampling unit: Conv2d(kernel f, stride s, padding p)
  -> InstanceNorm2d -> LeakyReLU(0.2), mapping `inp` channels to `out` channels."""
  conv = nn.Conv2d(inp, out, f, s, p)
  norm = nn.InstanceNorm2d(out)
  act = nn.LeakyReLU(0.2)
  return nn.Sequential(conv, norm, act)

class Critic(nn.Module):
  """Conditional WGAN critic: (image, class label) -> unbounded score per sample.

  The label is embedded into an img_size x img_size map and stacked onto the
  image as one extra channel, so the input is (batch, img_ch + 1, img_size, img_size).
  Four stride-2 blocks halve the spatial size. The final 4x4/stride-2 conv then
  reduces it, e.g. 64 -> 32 -> 16 -> 8 -> 4 -> 1 for img_size=64.
  """
  def __init__(self, img_size, labels_dim, d_dim=16):
    super().__init__()
    self.img_size = img_size
    self.labels_dim = labels_dim
    self.embed = nn.Embedding(labels_dim, img_size ** 2)

    # Channel widths of the four downsampling blocks (input has image + label-map channels).
    widths = [img_ch + 1, d_dim, d_dim * 2, d_dim * 4, d_dim * 8]
    blocks = [discBlock(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])]
    self.crit = nn.Sequential(*blocks, nn.Conv2d(d_dim * 8, 1, 4, 2, 0))

  def forward(self, image, labels):
    """Score `image` (batch, img_ch, img_size, img_size) under integer `labels` (batch,).

    Returns a (batch, 1) tensor when img_size=64.
    """
    label_map = self.embed(labels).view(-1, 1, self.img_size, self.img_size)
    scores = self.crit(torch.cat([image, label_map], dim=1))
    return scores.view(len(scores), -1)

def gen_labels(batch_size, labels_num, device=None):
  """Sample `batch_size` class ids uniformly from [0, labels_num) as an int64 tensor.

  Args:
    batch_size: number of labels.
    labels_num: number of classes.
    device: device to allocate on; falls back to getDevice() when omitted, so
      the result matches gen_noise's default device.
  """
  if device is None:
    device = getDevice()
  return torch.randint(0, labels_num, (batch_size,), device=device)

In [None]:
# Smoke test: generate a batch of fake images with random labels and check that
# the critic returns one score per sample, i.e. shape (batch_size, 1).
gen = Generator(label_dim).to(device)
noise = gen_noise(batch_size, noise_dim, device)
label = gen_labels(batch_size, label_dim, device)
fake = gen(noise, label)

crit = Critic(img_size, label_dim).to(device)
crit(fake, label).shape

torch.Size([64, 1])

In [None]:
# Training data: MNIST digits resized to img_size x img_size and scaled to [-1, 1].
# ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], the range
# of the generator's Tanh output.
# NOTE(review): `tranformation` is misspelled; kept as-is because it is a notebook global.
tranformation = transforms.Compose([
    transforms.Resize((img_size,img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    # alternative (unused): MNIST mean/std normalization, which does not map to [-1, 1]
    #transforms.Normalize((0.1307,), (0.3081,))
])

# MNIST training split (torchvision's default), downloaded into ./data if missing.
ds = datasets.MNIST(
    root='data',
    download=True,
    transform=tranformation
)

# Reshuffled every epoch; the last batch may be smaller than batch_size
# (drop_last is not set), which the training loop handles via len(real).
dataloader = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    shuffle=True
)

In [None]:
# 4. model initializer & optimizers
# Fresh generator and critic for training (replacing the smoke-test instances
# above), each with its own Adam optimizer: lr from the hyperparameter cell,
# betas=(0.5, 0.9).

gen = Generator(label_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5,0.9))

crit = Critic(img_size, label_dim).to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(0.5,0.9))

# Unused: the training loop computes the WGAN-GP losses directly from critic scores.
#criterion = nn.BCEWithLogitsLoss()

In [None]:
# Constant real (1) / fake (0) targets. Only a BCE-style loss would need these;
# the WGAN-GP training loop does not use them.
real_labels = torch.ones(batch_size, 1, dtype=torch.int32, device=device)
fake_labels = torch.zeros(batch_size, 1, dtype=torch.int32, device=device)

# Fixed noise plus one label per class 0..9, reused for every progress plot so
# successive plots are directly comparable.
test_z = gen_noise(10, noise_dim, device)
test_labels = torch.from_numpy(np.arange(10))
print(f'test_z: {test_z.shape}')
print(f'test_y: {test_labels.shape}')

test_z: torch.Size([10, 100])
test_y: torch.Size([10])


In [None]:
# save and load checkpoints

# On Colab, checkpoints go to the mounted Google Drive. Anywhere else
# (google.colab is not importable), fall back to a local folder instead of crashing.
import os

try:
  from google.colab import drive
except ImportError:
  root_path = './data/'
else:
  drive.mount('/content/gdrive')
  root_path = '/content/gdrive/MyDrive/Colab/models/'

# torch.save does not create missing directories.
os.makedirs(root_path, exist_ok=True)

def save_checkpoint(name, epoch=None):
  """Save generator and critic (weights + optimizer state) under root_path.

  Writes f'{root_path}G-{name}.pkl' and f'{root_path}C-{name}.pkl'. Each file
  holds a dict with 'epoch', 'model_state_dict' and 'optimizer_state_dict'.

  Args:
    name: checkpoint tag, e.g. 'latest'.
    epoch: epoch number to record. When omitted, the notebook-global `epoch`
      set by the training loop is used (None if training has not run yet),
      instead of raising NameError.
  """
  if epoch is None:
    epoch = globals().get('epoch')
  for prefix, model, opt in (('G', gen, gen_opt), ('C', crit, crit_opt)):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict()
    }, f'{root_path}{prefix}-{name}.pkl')

  print('Saved checkpoint')

def load_checkpoint(name):
  """Restore generator and critic (weights + optimizer state) saved by save_checkpoint.

  Tensors are mapped onto the device each model currently lives on, so a
  checkpoint written on a GPU can be loaded on a CPU-only machine.

  Returns:
    The 'epoch' value stored in the checkpoint (useful for resuming training).
  """
  checkpoint = None
  for prefix, model, opt in (('G', gen, gen_opt), ('C', crit, crit_opt)):
    checkpoint = torch.load(f'{root_path}{prefix}-{name}.pkl',
                            map_location=next(model.parameters()).device)
    model.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])

  print('Loaded checkpoint')
  return checkpoint.get('epoch')

In [None]:
# 9. Calculating loss

# calculate gradient penalty

def get_gp(real, fake, labels, crit, alpha, gamma=1):
  """WGAN-GP gradient penalty: gamma * mean((||d crit(x_hat, labels) / d x_hat||_2 - 1)^2).

  x_hat = alpha * real + (1 - alpha) * fake. `alpha` must broadcast against the
  images, e.g. (batch, 1, 1, 1).

  NOTE: the following cell redefines get_gp with gamma=10 as the default, so
  this version is shadowed once that cell runs. gamma=1 here keeps this
  definition's original (unscaled) result.
  """
  # Detach and re-enable grad so the penalty is taken w.r.t. x_hat itself; this
  # works whether or not alpha/real/fake require grad.
  interpolated_imgs = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
  crit_scores = crit(interpolated_imgs, labels)

  c_grad = torch.autograd.grad(
      inputs=interpolated_imgs,
      outputs=crit_scores,
      grad_outputs=torch.ones_like(crit_scores),
      create_graph=True,   # the penalty must itself be differentiable w.r.t. the critic
      retain_graph=True,
  )[0]

  # One L2 norm per sample over all of its pixels.
  g_norm = c_grad.reshape(len(c_grad), -1).norm(2, dim=1)
  return gamma * ((g_norm - 1) ** 2).mean()

In [None]:
def get_gp(real, fake, labels, crit, alpha, gamma=10):
  """WGAN-GP gradient penalty: gamma * mean((||d crit(x_hat, labels) / d x_hat||_2 - 1)^2).

  Args:
    real, fake: image batches of identical shape (batch, C, H, W).
    labels: conditioning labels passed to `crit` for the interpolated images.
    crit: callable (images, labels) -> (batch, 1) scores.
    alpha: interpolation weights broadcastable to the images, e.g. (batch, 1, 1, 1).
    gamma: penalty weight.

  Returns:
    Scalar tensor, differentiable w.r.t. the critic's parameters.
  """
  # Detach and re-enable grad so the penalty is taken w.r.t. x_hat itself; this
  # works whether or not alpha/real/fake require grad.
  mix_images = (real * alpha + fake * (1 - alpha)).detach().requires_grad_(True)
  mix_scores = crit(mix_images, labels)                  # (batch, 1)

  gradient = torch.autograd.grad(
      inputs=mix_images,
      outputs=mix_scores,
      grad_outputs=torch.ones_like(mix_scores),
      retain_graph=True,
      create_graph=True,   # the penalty must itself be differentiable w.r.t. the critic
  )[0]                                                   # same shape as the images

  gradient = gradient.reshape(len(gradient), -1)         # (batch, C*H*W)
  gradient_norm = gradient.norm(2, dim=1)
  return gamma * ((gradient_norm - 1) ** 2).mean()

In [None]:
# image save location
import os

# Output folder for generated samples; only created when nothing exists at
# that path yet, so re-running the cell is harmless.
visual_dir = 'visual_test_generated'
if os.path.exists(visual_dir) is False:
  os.mkdir(visual_dir)

In [None]:
# Training loop (WGAN-GP): for every real batch, `crit_cycles` critic updates
# followed by one generator update.

gen_losses = []
crit_losses = []
crit_cycles = 5      # critic updates per generator update
gamma = 10           # gradient-penalty weight
show_step = 250      # print running-mean losses every `show_step` batches
visual_step = 900    # plot samples for the fixed test noise every `visual_step` batches


def _critic_step(real, labels):
  """Run `crit_cycles` critic updates on one real batch; return the mean critic loss."""
  cur_bs = len(real)
  mean_crit_loss = 0.0
  for _ in range(crit_cycles):
    crit_opt.zero_grad()

    # Fakes are conditioned on the *real* labels, so the real, fake and
    # interpolated images scored by the critic loss and the gradient penalty
    # share one label per sample. no_grad: this step never updates the generator.
    with torch.no_grad():
      fake = gen(gen_noise(cur_bs, noise_dim, device), labels)

    crit_real_pred = crit(real, labels)
    crit_fake_pred = crit(fake, labels)

    alpha = torch.rand(cur_bs, 1, 1, 1, device=device, requires_grad=True)
    gp = get_gp(real, fake, labels, crit, alpha, gamma)

    crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp
    # Every cycle builds a fresh graph, so retain_graph would only waste memory.
    crit_loss.backward()
    crit_opt.step()

    mean_crit_loss += crit_loss.item() / crit_cycles
  return mean_crit_loss


def _generator_step(cur_bs):
  """One generator update on freshly sampled noise and labels; return the generator loss."""
  gen_opt.zero_grad()
  noise = gen_noise(cur_bs, noise_dim, device)
  f_labels = gen_labels(cur_bs, label_dim, device)
  fake = gen(noise, f_labels)
  # WGAN generator loss: raise the critic's score of the fakes.
  gen_loss = -crit(fake, f_labels).mean()
  gen_loss.backward()
  gen_opt.step()
  return gen_loss.item()


def _show_test_samples():
  """Plot generator output for the fixed test noise/labels.

  eval + no_grad: use BatchNorm running statistics and do not let this small
  visualization batch update them or build an autograd graph.
  """
  gen.eval()
  with torch.no_grad():
    generated = gen(test_z.to(device), test_labels.to(device))
  gen.train()
  show(generated, len(generated))


for epoch in range(epochs):
  for cur_step, (real, labels) in enumerate(tqdm(dataloader)):
    real = real.to(device)
    labels = labels.to(device)

    crit_losses.append(_critic_step(real, labels))
    gen_losses.append(_generator_step(len(real)))

    # Loss log
    if cur_step > 0 and cur_step % show_step == 0:
      gen_mean = sum(gen_losses[-show_step:]) / show_step
      crit_mean = sum(crit_losses[-show_step:]) / show_step
      print(f"Epoch: {epoch}: Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")

    if cur_step > 0 and cur_step % visual_step == 0:
      _show_test_samples()

    if cur_step > 0 and cur_step % save_step == 0:
      print('Saving checkpoint: ', cur_step, save_step)
      save_checkpoint('latest') # set diff name for each chkp

In [None]:
# Sample a batch from the trained generator and plot the first 10 images.
# The device is passed explicitly; gen_labels requires it.
z_r = gen_noise(64, noise_dim, device)
t_l = gen_labels(64, label_dim, device)

# eval + no_grad: use BatchNorm running statistics and skip building a graph.
gen.eval()
with torch.no_grad():
  generated = gen(z_r, t_l)
gen.train()

# The generator outputs (N, img_ch, 64, 64) images, which `show` plots directly.
# A 28x28 reshape cannot hold 64x64 images.
show(generated, 10)