In [None]:
import os
import math

import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import transforms, datasets

import q2_sampler

from q3_sampler import svhn_sampler
from q3_model import Critic, Generator
from q2_solution_modified import lp_reg, vf_wasserstein_distance, vf_squared_hellinger

In [None]:
# Example of usage of the code provided and recommended hyper parameters for training GANs.
data_root = './'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_iter = 50000  # N training iterations (one generator update each)
n_critic_updates = 5  # N critic updates per generator update
lp_coeff = 10  # Lipschitz penalty coefficient
train_batch_size = 64
test_batch_size = 64
lr = 1e-4
beta1 = 0.5
beta2 = 0.9
z_dim = 100

# Custom
verbose = 20  # print a log line every `verbose` training iterations

# Iteration the training loop starts from. 0 means "train from scratch"; the
# optional checkpoint-loading cell overwrites it with the saved value. Defining
# it here keeps the training cell runnable on a fresh kernel.
epoch = 0

train_loader, valid_loader, test_loader = svhn_sampler(data_root, train_batch_size, test_batch_size)

generator = Generator(z_dim=z_dim).to(device)
critic = Critic().to(device)

# Both networks use Adam with the same learning rate and betas.
optim_critic = optim.Adam(critic.parameters(), lr=lr, betas=(beta1, beta2))
optim_generator = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))

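## (Optional) Load a saved checkpoint

Run the next cell only if you want to restore previously saved models and optimizers instead of training from scratch.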

In [None]:
# Load states written by the checkpoint-saving cell. map_location remaps the
# stored tensors onto `device`, so a checkpoint saved on a GPU can be restored
# on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load('wgan_lp_50000.ckpt', map_location=device)

# Iteration the checkpoint was saved at; the training loop resumes from here.
epoch = ckpt['epoch']

generator.load_state_dict(ckpt['generator_state_dict'])
optim_generator.load_state_dict(ckpt['optimizer_generator_state_dict'])

critic.load_state_dict(ckpt['critic_state_dict'])
optim_critic.load_state_dict(ckpt['optimizer_critic_state_dict'])

## Train Loop

In [None]:
train_iterator = iter(train_loader)
c_losses = []  # critic loss after every critic update (Python floats)
g_losses = []  # generator loss after every generator update (Python floats)

# `epoch` is the iteration to start from: 0 for a fresh run, or the value
# restored by the checkpoint-loading cell.
for step in range(epoch, n_iter):
    for _ in range(n_critic_updates):
        # Update critic
        optim_critic.zero_grad()

        # Next real batch; restart the loader when it is exhausted.
        try:
            x = next(train_iterator)[0].to(device)
        except StopIteration:
            train_iterator = iter(train_loader)
            x = next(train_iterator)[0].to(device)

        # Size the fake batch to match the real one: the last loader batch may be
        # smaller than train_batch_size (unless the loader drops it), and lp_reg
        # presumably mixes x and y sample-wise -- confirm in q2_solution_modified.
        z = torch.randn((x.size(0), z_dim)).to(device)
        y = generator(z)

        loss = vf_wasserstein_distance(x, y, critic) + lp_coeff * lp_reg(x, y, critic, device=device)
        loss.backward()
        optim_critic.step()
        # Record a plain float: `loss[0]` fails on a 0-dim tensor, and keeping
        # graph-attached tensors in a list leaks memory across iterations.
        c_losses.append(loss.item())

    # Update generator
    optim_generator.zero_grad()
    z = torch.randn((train_batch_size, z_dim)).to(device)
    y = generator(z)
    # Mean critic score on the fakes. Its sign convention must agree with the
    # one used by vf_wasserstein_distance in the critic loss above.
    fake_scores = critic(y)

    loss = fake_scores.mean()
    loss.backward()
    optim_generator.step()
    g_losses.append(loss.item())

    # Logging
    if step % verbose == 0:
        print(f"Step #{step}. Generator Loss: {g_losses[-1]:.6f}. Critic Loss: {c_losses[-1]:.6f}.")

In [None]:
# Save everything needed to resume training or reuse the models later. The file
# name must match the one read by the checkpoint-loading cell.
torch.save({
    'epoch': n_iter,  # training iterations completed when the loop finishes
    'critic_state_dict': critic.state_dict(),
    'generator_state_dict': generator.state_dict(),
    'optimizer_critic_state_dict': optim_critic.state_dict(),
    'optimizer_generator_state_dict': optim_generator.state_dict(),
    'critic_losses': c_losses,
    'generator_losses': g_losses,
    }, 'wgan_lp_50000.ckpt')

# Sample and display real images

In [None]:
# Start a fresh pass over the loader and keep the first 32 real images of the
# first batch as a NumPy array (channels-first, as the display cell expects).
train_iterator = iter(train_loader)
real_images = next(train_iterator)[0]
x = real_images.numpy()[:32]

In [None]:
# Channels-first -> channels-last for matplotlib, then map back to [0, 1]
# (assumes the loader normalised pixels to [-1, 1] -- confirm in q3_sampler).
reals = np.transpose(x, (0, 2, 3, 1)) * 0.5 + 0.5
real_grid = reals.reshape(8, 4, 32, 32, 3)

fig, axes = plt.subplots(8, 4, figsize=(4, 8))

# axes.flat walks the 8 x 4 grid row by row, the same order as `reals`.
for ax, img in zip(axes.flat, reals):
    ax.imshow(img)
    ax.axis('off')


# Sample and display fake (generated) images

In [None]:
# 32 fresh latent codes -> 32 generated images, moved to NumPy for plotting.
z = torch.randn((32, z_dim)).to(device)
generated = generator(z).detach()
x = generated.cpu().numpy()

# Channels-last and rescaled to [0, 1], mirroring the real-image display.
fakes = np.transpose(x, (0, 2, 3, 1)) * 0.5 + 0.5
fakes_grid = fakes.reshape(8, 4, 32, 32, 3)

fig, axes = plt.subplots(8, 4, figsize=(4, 8))
for ax, img in zip(axes.flat, fakes):
    ax.imshow(img)
    ax.axis('off')


# Q3.2

## First, display the original (unperturbed) sample

In [None]:
# A single latent code; the perturbation cells below reuse this `z` as their base point.
z = torch.randn((1, z_dim)).to(device)
x = generator(z).detach().cpu().numpy()

orig = np.transpose(x, (0, 2, 3, 1)) * 0.5 + 0.5

fig, ax = plt.subplots()
ax.imshow(orig[0])
ax.axis('off')


## Now, perturb each latent dimension in turn

In [None]:
# Perturbation size, expressed as a fraction of the spread of the sampled code z.
eps_ratio = 1

z_range = z.max() - z.min()
eps = eps_ratio * z_range  # 0-dim tensor on the same device as z

# One row per latent dimension: row i adds eps to dimension i only.
# Build the identity directly on `device`: multiplying a CPU tensor by the
# 0-dim `eps` that lives on CUDA raises a device-mismatch error.
eps = torch.eye(z_dim, z_dim, device=device) * eps

In [None]:
# (z_dim, z_dim) + (1, z_dim) broadcasts: row i is z with dimension i shifted by eps.
z_pert = eps + z

In [None]:
# Sanity check: one perturbed copy of z per latent dimension.
assert z_pert.shape == (z_dim, z_dim), z_pert.shape
z_pert.shape

In [None]:
# Decode every perturbed code: one image per latent dimension. Inference only,
# so skip building an autograd graph.
with torch.no_grad():
    x_pert = generator(z_pert).cpu().numpy()

pert = np.transpose(x_pert, (0, 2, 3, 1)) * 0.5 + 0.5

In [None]:
# Lay the 100 perturbed samples out as a 20 x 5 grid: cell (i, j) shows the
# effect of perturbing latent dimension 5 * i + j.
pert_grid = pert.reshape(20, 5, 32, 32, 3)

fig, axes = plt.subplots(20, 5, figsize=(5, 20), squeeze=True)
for ax, img in zip(axes.flat, pert):
    ax.imshow(img)
    ax.axis('off')


In [None]:
# Hand-picked cells of pert_grid as (row, col); latent dimension = 5 * row + col.
# The labels describe what was seen for one particular random z. No seed is set
# in this notebook, so they will not necessarily match on a fresh run.
select_positions = [
    [(18, 3), (2, 3), (19, 0)],  # background color, 0, 7
    [(1, 1), (3, 1), (2, 0)],    # erased, 1, 8
]
select_pert = [[pert_grid[r, c] for r, c in row] for row in select_positions]

In [None]:
# 2 x 3 panel: one subplot row per inner list of select_pert.
fig, axes = plt.subplots(2, 3, figsize=(3, 2), squeeze=True)
for ax_row, img_row in zip(axes, select_pert):
    for ax, img in zip(ax_row, img_row):
        ax.imshow(img)
        ax.axis('off')


# Q3.3: Interpolating between two generated samples

In [None]:
# Two independent latent codes to interpolate between (z0 is drawn first).
z0, z1 = (torch.randn((1, z_dim)).to(device) for _ in range(2))

In [None]:
x0 = generator(z0)
# Channels-last and rescaled to [0, 1] for display.
x0_np = x0.detach().cpu().numpy()
img0 = x0_np.transpose(0, 2, 3, 1) * 0.5 + 0.5

In [None]:
x1 = generator(z1)
# Channels-last and rescaled to [0, 1] for display.
x1_np = x1.detach().cpu().numpy()
img1 = x1_np.transpose(0, 2, 3, 1) * 0.5 + 0.5

In [None]:
# Image generated from z0.
fig, ax = plt.subplots()
ax.imshow(img0[0])
ax.axis('off')


In [None]:
# Image generated from z1.
fig, ax = plt.subplots()
ax.imshow(img1[0])
ax.axis('off')


# Q3.3a: Interpolation in latent space

In [None]:
# 11 evenly spaced mixing weights 0.0, 0.1, ..., 1.0 as an (11, 1) column so they
# broadcast against the (1, z_dim) latent codes. linspace fixes the count and hits
# both endpoints exactly; arange with a float step depends on rounding near `end`.
alphas = torch.linspace(0, 1, 11).unsqueeze(1).to(device)

In [None]:
# Row k mixes the two codes with weight alphas[k] on z0, so the first row equals
# z1 (alpha = 0) and the last row equals z0 (alpha = 1).
z_interp = alphas * z0 + (1 - alphas) * z1

x_interp = generator(z_interp)
interp = x_interp.detach().cpu().numpy().transpose(0, 2, 3, 1) * 0.5 + 0.5

In [None]:
# One strip of 11 frames, from z1 (left) to z0 (right).
fig, axes = plt.subplots(1, 11, figsize=(11, 1), squeeze=True)
for ax, img in zip(axes, interp):
    ax.imshow(img)
    ax.axis('off')


# Q3.3b: Interpolation in pixel space

In [None]:
# Endpoint images taken from the latent interpolation batch: the first row is the
# image for z1 (alpha = 0) and the last row is the image for z0 (alpha = 1).
x1 = x_interp[0]
x0 = x_interp[-1]
# Same weights as alphas, reshaped to broadcast over each image's (C, H, W).
beta = alphas.view(11, 1, 1, 1)

In [None]:
# Blend the two endpoint images directly in pixel space, with the same weights and
# ordering as the latent interpolation (beta = 0 gives x1). A separate name keeps
# x_interp holding the generator outputs that x0/x1 are read from, so the previous
# cell can be re-run safely.
x_pixel_interp = beta * x0 + (1 - beta) * x1
interp = np.transpose(x_pixel_interp.detach().cpu().numpy(), (0, 2, 3, 1))
interp = interp * 0.5 + 0.5

In [None]:
# One strip of 11 pixel-space blends, from x1 (left) to x0 (right).
fig, axes = plt.subplots(1, 11, figsize=(11, 1), squeeze=True)
for ax, img in zip(axes, interp):
    ax.imshow(img)
    ax.axis('off')
