In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [8]:
class Critic(nn.Module):
  def __init__(self, channels_img, features_d):
    super(Critic, self).__init__()
    self.disc = nn.Sequential(
        # Input: N x channels_img x 64 x 64
        nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16
        self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8
        self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4
        nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        ),
        # In paper they have used LayerNorm which is similar to instancenorm
        # doesn't normalize across the batches
        # affine = True, for learnable parameters
        nn.InstanceNorm2d(out_channels, affine=True), # LayerNorm <--> InstanceNorm
        nn.LeakyReLU(0.2),
    )

  def forward(self, x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self, z_dim, channels_img, features_g):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
        # Input: N x z_dim x 1 x 1
        self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
        self._block(features_g*16, features_g*8, 4, 2, 1), # 8 x 8
        self._block(features_g*8, features_g*4, 4, 2, 1), # 16 x 16
        self._block(features_g*4, features_g*2, 4, 2, 1), # 32 x 32
        # here we are not using the block as in the generators last conv layer
        # bias has to be excluded
        nn.ConvTranspose2d(
            features_g*2,
            channels_img,
            kernel_size = 4,
            stride = 2,
            padding = 1
        ),
        nn.Tanh(), # Pixels value between [-1, -1]
    )
  
  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d( 
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False
        ), # upsampling
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )

  def forward(self, x):
    return self.gen(x)

def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
  N, in_channels, H, W = 8, 3, 64, 64
  z_dim = 100
  x = torch.randn((N, in_channels, H, W))
  disc = Critic(in_channels, 8)
  initialize_weights(disc)
  assert disc(x).shape == (N, 1, 1, 1)
  gen = Generator(z_dim, in_channels, 8)
  initialize_weights(gen)
  z = torch.randn((N, z_dim, 1, 1))
  assert gen(z).shape == (N, in_channels, H, W)
# test()

In [9]:
### GRADIENT PENALTY (Better than clipping) ###

def gradient_penalty(critic, real, fake, device='cpu'):
  BATCH_SIZE, C, H, W = real.shape
  print(f'BATCH_SIZE: {BATCH_SIZE}, fake.shape[0]: {fake.shape[0]}')
  epsilon = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
  interpolated_images = real * epsilon + fake * (1 - epsilon)

  # calculate critic scores
  mixed_scores = critic(interpolated_images)

  gradient = torch.autograd.grad(
      inputs=interpolated_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True,
  )[0]

  gradient = gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim=1)
  gradient_penalty = torch.mean((gradient_norm - 1)**2)
  
  return gradient_penalty


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

transforms = transforms.Compose(
    [
      transforms.Resize(IMAGE_SIZE),
      transforms.ToTensor(),
      transforms.Normalize(
          [0.5 for _ in range(CHANNELS_IMG)],[0.5 for _ in range(CHANNELS_IMG)]
      ),
    ]
    
)

dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms,
                         download=True)
# comment mnist above and uncomment below if train on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Critic(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f'logs/real')
writer_fake = SummaryWriter(f'logs/fake')
step = 0

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(dataloader):
    real = real.to(device)
    # noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
    # fake = gen(noise)

    for _ in range(CRITIC_ITERATIONS):
      noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
      fake = gen(noise)
      critic_real = critic(real).reshape(-1)
      critic_fake = critic(fake).reshape(-1)
      gp = gradient_penalty(critic, real, fake, device=device)
      loss_critic = (
          -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
      )
      critic.zero_grad()
      loss_critic.backward(retain_graph=True)
      opt_critic.step()
    
    ## Train Generator: min -E[critic(gen_fake)]
    gen_fake = critic(fake).reshape(-1)
    loss_gen = -torch.mean(gen_fake)
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Print losses occasionally and print to tensorboard
    if batch_idx % 100 == 0:
        print(
            f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
              Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
        )

        with torch.no_grad():
            fake = gen(fixed_noise)
            # take out (up to) 32 examples
            img_grid_real = torchvision.utils.make_grid(
                real[:32], normalize=True
            )
            img_grid_fake = torchvision.utils.make_grid(
                fake[:32], normalize=True
            )

            writer_real.add_image("Real", img_grid_real, global_step=step)
            writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        step += 1


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, fake.shape[0]: 32
BATCH_SIZE: 32, f

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs