In [None]:
# @title <p>Essential Import
import os, shutil, json
from PIL import Image
from zipfile import ZipFile
import matplotlib.pyplot as plt
import numpy as np, pandas as pd, random as rd
import warnings
warnings.filterwarnings("ignore")

In [None]:
# @title <p>Torch Essential Import
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter
# Seed PyTorch's global random number generator.
torch.manual_seed(0)
# Use the GPU when CUDA reports one as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator that maps an image batch to one score per image.

  Four stride-2 conv + LeakyReLU stages widen the channels from
  ``features_d`` to ``features_d * 8``; a final 4x4 stride-2 convolution
  collapses the result to a single channel, which a Sigmoid squashes into
  the (0, 1) range. For 64x64 inputs the output shape is N x 1 x 1 x 1.

  Args:
    image_ch: number of channels of the input images.
    features_d: channel width of the first stage.
  """

  def __init__(self, image_ch, features_d):
    super().__init__()
    # Channel count at every stage boundary: input, then x1, x2, x4, x8.
    widths = [image_ch] + [features_d * mult for mult in (1, 2, 4, 8)]
    stages = [
        self._block(c_in, c_out, 4, 2, 1)
        for c_in, c_out in zip(widths[:-1], widths[1:])
    ]
    # Un-padded 4x4 conv reduces the last feature map to one score channel.
    head = nn.Conv2d(widths[-1], 1, 4, 2, 0)
    self.disc = nn.Sequential(*stages, head, nn.Sigmoid())

  def _block(self, in_ch, out_ch, kernel_size, stride, padding):
    """Return one stage: Conv2d followed by LeakyReLU(0.2)."""
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
    return nn.Sequential(conv, nn.LeakyReLU(0.2))

  def forward(self, x):
    """Run the image batch ``x`` through the stack and return the scores."""
    scores = self.disc(x)
    return scores


class Generator(nn.Module):
  """Transposed-convolution generator that turns a noise batch into images.

  Four ConvTranspose2d + ReLU stages shrink the channels from
  ``features_g * 16`` to ``features_g * 2`` while growing the spatial size;
  a last transposed convolution emits ``image_ch`` channels and Tanh bounds
  the values to [-1, 1]. For noise of shape N x noise_ch x 1 x 1 the output
  is N x image_ch x 64 x 64.

  Args:
    noise_ch: number of channels of the input noise tensor.
    image_ch: number of channels of the generated images.
    features_g: base channel width; stage widths are multiples of it.
  """

  def __init__(self, noise_ch, image_ch, features_g):
    super().__init__()
    # (in_channels, out_channels, kernel, stride, padding) for each stage.
    upsample_plan = [
        (noise_ch, features_g * 16, 4, 1, 0),
        (features_g * 16, features_g * 8, 4, 2, 1),
        (features_g * 8, features_g * 4, 4, 2, 1),
        (features_g * 4, features_g * 2, 4, 2, 1),
    ]
    stages = [self._block(*spec) for spec in upsample_plan]
    # Final upsampling step maps to the image channel count.
    to_image = nn.ConvTranspose2d(features_g * 2, image_ch, 4, 2, 1)
    self.net = nn.Sequential(*stages, to_image, nn.Tanh())

  def _block(self, in_ch, out_ch, kernel_size, stride, padding):
    """Return one stage: ConvTranspose2d followed by ReLU."""
    deconv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding)
    return nn.Sequential(deconv, nn.ReLU())

  def forward(self, x):
    """Run the noise batch ``x`` through the stack and return the images."""
    images = self.net(x)
    return images

In [None]:
def initialize_weights(model):
  """Re-initialise conv and batch-norm weights of ``model`` in place.

  Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found
  in ``model.modules()`` has its ``weight`` redrawn from a normal
  distribution with mean 0.0 and standard deviation 0.02. Biases are left
  untouched.

  Args:
    model: an ``nn.Module`` whose sub-modules are walked recursively.

  Returns:
    None. The parameters are modified in place.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      # BatchNorm2d(affine=False) carries no weight parameter; skip it
      # instead of failing on None.
      if m.weight is None:
        continue
      # normal_ is the supported in-place initialiser; the un-suffixed
      # nn.init.normal is a deprecated alias.
      nn.init.normal_(m.weight, 0.0, 0.02)

In [None]:
def test():
  """Smoke-test the output shapes of Discriminator and Generator.

  Feeds a random 8 x 3 x 64 x 64 batch to a small Discriminator and a random
  8 x 100 x 1 x 1 noise batch to a small Generator, prints the resulting
  shapes and asserts they are (8, 1, 1, 1) and (8, 3, 64, 64) respectively.

  Raises:
    AssertionError: if either network returns an unexpected shape.
  """
  batch, channels, height, width = 8, 3, 64, 64
  noise_dim = 100

  # Discriminator: image batch in, one score per image out.
  images = torch.randn((batch, channels, height, width))
  disc = Discriminator(channels, 8)
  disc_shape = disc(images).shape
  print(f"Disc shape : {disc_shape}")
  assert disc_shape == (batch, 1, 1, 1), "Discriminator test failed"

  # Generator: 1x1 noise in, full-size image batch out.
  gen = Generator(noise_dim, channels, 8)
  z = torch.randn((batch, noise_dim, 1, 1))
  gen_shape = gen(z).shape
  print(f"Gen shape : {gen_shape}")
  assert gen_shape == (batch, channels, height, width), "Generator test failed"

  print("Succes test passed!")

In [None]:
# Run the shape smoke test for both networks; it raises AssertionError on a mismatch.
test()

Disc shape : torch.Size([8, 1, 1, 1])
Gen shape : torch.Size([8, 3, 64, 64])
Succes test passed!


In [None]:
# --- Data / image settings ---
IMAGE_SIZE = 64   # target size passed to T.Resize
IMAGE_CH = 1      # number of image channels
BS = 64           # batch size used by the DataLoader and for noise batches

# --- Model widths ---
Z_DIM = 100           # channel count of the generator's noise input
FEATURES_GEN = 64     # base channel width of the Generator
FEATURES_CRITIC = 64  # base channel width of the Discriminator

# --- Optimisation ---
LR = 5e-5
NUM_EPOCHS = 5
# NOTE(review): the two constants below are not referenced by the training
# loop visible in this notebook — confirm whether they are still needed.
CRITIC_ITERATIONS = 5
WEIGHT_CLIP = 1e-2

# Resize, convert to a tensor, then normalise every channel with mean 0.5 and
# std 0.5.
_channel_stats = [0.5] * IMAGE_CH
transforms = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(list(_channel_stats), list(_channel_stats)),
])
del _channel_stats

In [None]:
from torchvision import datasets

# ---- Data: MNIST, downloaded into dataset/ on first use ----
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size = BS, shuffle=True)

# ---- Models, moved to the selected device and re-initialised ----
gen = Generator(Z_DIM, IMAGE_CH, FEATURES_GEN).to(device)
disc = Discriminator(IMAGE_CH, FEATURES_CRITIC).to(device)

initialize_weights(gen)
initialize_weights(disc)

# ---- Optimisers and loss ----
opt_gen = torch.optim.RMSprop(gen.parameters(), lr = LR)
opt_disc = torch.optim.RMSprop(disc.parameters(), lr = LR)
criterion = nn.BCELoss()

# ---- Logging: a fixed noise batch so previews are comparable over time ----
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # index of the next TensorBoard preview

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
  # Target labels not needed: the loop only uses the images.
  for idx, (real, _) in enumerate(loader):
    real = real.to(device)
    noise = torch.randn(BS, Z_DIM, 1, 1).to(device)
    fake = gen(noise)

    ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
    # Flatten like the other discriminator outputs; disc already lives on
    # `device`, so no extra transfer is needed.
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() keeps this backward pass from reaching the generator.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake)/2

    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))

    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Progress report and image preview every 100 batches. The preview lives
    # inside this branch so the extra generator pass and grid construction do
    # not run on every single batch.
    if idx % 100 == 0:
      print(
          f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {idx}/{len(loader)} \
          Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}"
      )

      with torch.no_grad():
        preview = gen(fixed_noise)

        # take out (up to) 32 examples
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(preview[:32], normalize=True)

        # Actually send the grids to TensorBoard; the writers were otherwise
        # created but never used.
        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

      step += 1

Epoch [0/5] Batch 0/938           Loss D: 0.7028, Loss G: 0.6520


KeyboardInterrupt: ignored

In [None]:
# Display `output`, the flattened discriminator scores for the last generated
# batch, as assigned inside the training loop.
# NOTE(review): the recorded values (~1.52) lie outside (0, 1), which a
# discriminator ending in nn.Sigmoid cannot produce — the saved output
# presumably comes from an earlier version of the model; re-run to confirm.
output

tensor([1.5232, 1.5237, 1.5237, 1.5240, 1.5244, 1.5243, 1.5240, 1.5237, 1.5242,
        1.5238, 1.5240, 1.5241, 1.5242, 1.5240, 1.5238, 1.5240, 1.5237, 1.5241,
        1.5237, 1.5238, 1.5240, 1.5239, 1.5243, 1.5241, 1.5241, 1.5239, 1.5243,
        1.5244, 1.5232, 1.5240, 1.5241, 1.5246, 1.5241, 1.5237, 1.5241, 1.5241,
        1.5238, 1.5236, 1.5239, 1.5238, 1.5247, 1.5236, 1.5239, 1.5239, 1.5238,
        1.5240, 1.5244, 1.5241, 1.5236, 1.5241, 1.5236, 1.5236, 1.5242, 1.5243,
        1.5245, 1.5239, 1.5239, 1.5239, 1.5243, 1.5237, 1.5239, 1.5240, 1.5246,
        1.5238], grad_fn=<ReshapeAliasBackward0>)