In [None]:
# @title <p>Essential Import
import os, shutil, json
from PIL import Image
from zipfile import ZipFile
import matplotlib.pyplot as plt
import numpy as np, pandas as pd, random as rd
import warnings
warnings.filterwarnings("ignore")

In [None]:
# @title <p>Torch Essential Import
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Discriminator(nn.Module):
  """Convolutional critic that maps an image batch to one raw score per image.

  The network is four strided conv blocks followed by a plain convolution.
  No sigmoid is applied, so the output is an unbounded score. For a
  64x64 input the result has shape (N, 1, 1, 1).

  Args:
    image_ch: number of channels of the input images.
    features_d: base channel width; the blocks use 1x, 2x, 4x and 8x of it.
  """

  def __init__(self, image_ch, features_d):
    super().__init__()
    # Channel widths at the block boundaries: image_ch -> d -> 2d -> 4d -> 8d.
    widths = [image_ch] + [features_d * mult for mult in (1, 2, 4, 8)]
    stages = [
        self._block(c_in, c_out, 4, 2, 1)
        for c_in, c_out in zip(widths[:-1], widths[1:])
    ]
    # Final projection down to a single channel (kernel 4, stride 2, no padding).
    stages.append(nn.Conv2d(widths[-1], 1, 4, 2, 0))
    self.disc = nn.Sequential(*stages)

  def _block(self, in_ch, out_ch, kernel_size, stride, padding):
    """Return a Conv2d (no bias) -> affine InstanceNorm2d -> LeakyReLU stage."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False),
        nn.InstanceNorm2d(out_ch, affine=True),
        nn.LeakyReLU(),
    )

  def forward(self, x):
    """Run the image batch `x` through the critic and return the raw scores."""
    scores = self.disc(x)
    return scores

class Generator(nn.Module):
  """Transposed-convolution generator that turns a noise tensor into an image.

  Four upsampling blocks are followed by a final transposed convolution and
  a Tanh, so outputs lie in [-1, 1]. A (N, noise_ch, 1, 1) input yields a
  (N, image_ch, 64, 64) output.

  Args:
    noise_ch: number of channels of the latent noise tensor.
    image_ch: number of channels of the generated image.
    features_g: base channel width; the blocks use 16x, 8x, 4x and 2x of it.
  """

  def __init__(self, noise_ch, image_ch, features_g):
    super().__init__()
    # Channel widths at the block boundaries: noise_ch -> 16g -> 8g -> 4g -> 2g.
    widths = [noise_ch] + [features_g * mult for mult in (16, 8, 4, 2)]
    # The first block expands 1x1 -> 4x4 (stride 1, no padding);
    # every later block doubles the spatial size (stride 2, padding 1).
    geometry = [(1, 0), (2, 1), (2, 1), (2, 1)]
    layers = [
        self._block(c_in, c_out, 4, stride, pad)
        for (c_in, c_out), (stride, pad) in zip(zip(widths[:-1], widths[1:]), geometry)
    ]
    layers.append(nn.ConvTranspose2d(widths[-1], image_ch, 4, 2, 1))
    layers.append(nn.Tanh())
    self.net = nn.Sequential(*layers)

  def _block(self, in_ch, out_ch, kernel_size, stride, padding):
    """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )

  def forward(self, x):
    """Generate images from the noise batch `x`."""
    images = self.net(x)
    return images

In [None]:
def initialize_weights(model):
  """Apply DCGAN-style initialisation to every conv / batch-norm layer of `model`.

  * Conv2d / ConvTranspose2d kernels are drawn from N(0, 0.02).
  * BatchNorm2d scales are drawn from N(1, 0.02) and their shifts set to 0.
    Drawing the scale around 0 (as for the kernels) would multiply every
    normalised activation by roughly zero at the start of training.

  Layers are modified in place; nothing is returned. BatchNorm2d layers
  built with affine=False have no weight/bias and are skipped.

  Args:
    model: an nn.Module whose sub-modules should be re-initialised.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
      # weight/bias are None when the layer was created with affine=False.
      if m.weight is not None:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
      if m.bias is not None:
        nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Smoke test: confirm both networks produce the tensor shapes the training
# loop relies on. The assertions make a Restart & Run All fail loudly if
# an architecture change breaks the 64x64 contract, instead of only printing.

# Critic: one 64x64 RGB image -> a single score.
x = torch.rand(1, 3, 64, 64)
model = Discriminator(3, 8)
out = model(x)
assert out.shape == (1, 1, 1, 1), f"unexpected critic output shape {tuple(out.shape)}"
print(f'Disc shape : {out.shape}')

# Generator: one 100-channel noise vector -> a 64x64 RGB image.
x = torch.rand(1, 100, 1, 1)
model = Generator(100, 3, 8)
out = model(x)
assert out.shape == (1, 3, 64, 64), f"unexpected generator output shape {tuple(out.shape)}"
print(f'Gen shape : {out.shape}')

Disc shape : torch.Size([1, 1, 1, 1])
Gen shape : torch.Size([1, 3, 64, 64])


In [None]:
# --- Optimisation ---
LR = 5e-5                # learning rate given to both RMSprop optimisers
BS = 64                  # DataLoader batch size
NUM_EPOCHS = 5           # passes over the dataset
CRITIC_ITERATIONS = 5    # critic updates per generator update
WEIGHT_CLIP = 0.01       # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

# --- Data / model dimensions ---
IMAGE_SIZE = 64          # size passed to T.Resize
IMAGE_CH = 1             # image channels
Z_DIM = 128              # channels of the latent noise tensor
FEATURES_CRITIC = 64     # base channel width of the critic
FEATURES_GEN = 64        # base channel width of the generator

# Resize, convert to a tensor, then normalise every channel with mean 0.5 and
# std 0.5, which maps ToTensor's [0, 1] range onto [-1, 1].
_channel_mean = [0.5] * IMAGE_CH
_channel_std = [0.5] * IMAGE_CH
transforms = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(_channel_mean, _channel_std),
])

In [None]:
# ---------------------------------------------------------------------------
# WGAN training with weight clipping.
#
# Per batch, the critic is updated CRITIC_ITERATIONS times to maximise
# mean(critic(real)) - mean(critic(fake)), with its parameters clamped to
# [-WEIGHT_CLIP, WEIGHT_CLIP] after each step. The generator is then updated
# once to maximise mean(critic(fake)).
# ---------------------------------------------------------------------------

# --- Data ---
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=BS, shuffle=True)

# --- Models ---
gen = Generator(Z_DIM, IMAGE_CH, FEATURES_GEN)
disc = Discriminator(IMAGE_CH, FEATURES_CRITIC)
initialize_weights(gen)
initialize_weights(disc)
# The batches and noise are moved to `device`, so the networks must live there
# too; otherwise the first forward pass fails on a CUDA machine.
gen = gen.to(device)
disc = disc.to(device)

# --- Optimisers (created after the models are on their final device) ---
opt_gen = torch.optim.RMSprop(gen.parameters(), lr=LR)
opt_disc = torch.optim.RMSprop(disc.parameters(), lr=LR)

# --- Logging ---
# A fixed latent batch so successive TensorBoard previews are comparable.
fixed_noise = torch.randn(BS, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
disc.train()

try:
  for epoch in range(NUM_EPOCHS):
    for idx, (real, _) in enumerate(tqdm(loader)):
      real = real.to(device)
      cur_batch_size = real.shape[0]

      # ---- Critic updates ----
      for _ in range(CRITIC_ITERATIONS):
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        disc_real = disc(real).reshape(-1)
        # detach(): the critic step needs no gradients w.r.t. the generator.
        # This keeps the generator graph intact for the generator step below
        # and removes the need for retain_graph=True.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc = -(torch.mean(disc_real) - torch.mean(disc_fake))
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Weight clipping applied to every critic parameter.
        for p in disc.parameters():
          p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

      # ---- Generator update (re-uses the last `fake`, scored by the updated critic) ----
      gen_fake = disc(fake).reshape(-1)
      loss_gen = -torch.mean(gen_fake)

      gen.zero_grad()
      loss_gen.backward()
      opt_gen.step()

      # ---- Periodic console + TensorBoard logging ----
      if idx % 100 == 0 and idx > 0:
        gen.eval()
        disc.eval()
        # Adjacent f-strings instead of a backslash continuation, which
        # would embed the source indentation in the printed line.
        print(
            f"Epoch [{epoch}/{NUM_EPOCHS} Batch {idx}/{len(loader)}] "
            f"Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
        )

        with torch.no_grad():
          preview = gen(fixed_noise)
          img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
          img_grid_fake = torchvision.utils.make_grid(preview[:32], normalize=True)

          writer_real.add_image("Real", img_grid_real, global_step=step)
          writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        step += 1
        gen.train()
        disc.train()
finally:
  # Flush and close the event files even if training is interrupted.
  writer_real.close()
  writer_fake.close()

  2%|▏         | 18/938 [10:41<9:06:14, 35.62s/it]


KeyboardInterrupt: ignored