In [None]:
#%config Completer.use_jedi = False

In [1]:
from time import time
import torch
from torch import nn
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import norm

import os
from collections import defaultdict
import pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

import scipy
from scipy.stats import multivariate_normal

import torch.utils.data as data

In [2]:
# Sanity check: later cells move models and batches to the GPU with .cuda().
torch.cuda.is_available()

True

In [4]:
# Report which GPU (device index 0) this kernel is running on.
torch.cuda.get_device_name(0)

'Tesla T4'

# Vanilla GAN

In [30]:
def load_pickle(path):
    """Load a pickled dataset dict and rescale its image arrays to [-1, 1].

    The pickle must hold the keys 'train', 'test', 'train_labels' and
    'test_labels'. Image arrays are mapped with 2 * (x / 255) - 1, so
    (presumably uint8) pixels in [0, 255] land in [-1, 1], the same range
    as the generator's Tanh output.

    NOTE(review): pickle.load can execute arbitrary code; only load
    trusted files.

    Returns:
        (train_data, test_data, train_labels, test_labels)
    """
    def rescale(images):
        return 2 * (images / 255) - 1

    with open(path, 'rb') as handle:
        payload = pickle.load(handle)

    return (
        rescale(payload['train']),
        rescale(payload['test']),
        payload['train_labels'],
        payload['test_labels'],
    )


def show_samples(samples, title=None, nrow=10, value_range=(-1, 1)):
    """Display a batch of channels-last images as a single grid figure.

    Args:
        samples: numpy array or torch tensor (any device) of shape
            (N, H, W, C); the last axis is treated as channels.
        title: optional figure title.
        nrow: number of images per grid row.
        value_range: (low, high) range of the input values. They are mapped
            linearly to [0, 1] and clamped, because imshow clips float RGB
            data outside [0, 1] (data in [-1, 1] would lose its dark half).
            Pass None to plot the raw values unchanged.
    """
    # torch.as_tensor accepts numpy arrays as well as tensors. Calling
    # samples.cpu() directly crashed on the numpy arrays visualize_data
    # passes in.
    images = torch.as_tensor(samples).detach().cpu().float()
    if value_range is not None:
        low, high = value_range
        images = ((images - low) / (high - low)).clamp(0, 1)
    images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW for make_grid
    grid_img = make_grid(images, nrow=nrow)
    plt.figure(figsize=(10, 10))
    if title is not None:
        plt.title(title)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis('off')
    plt.show()


def visualize_data(data, title, num=100):
    """Show a grid of up to `num` distinct, randomly chosen samples from `data`.

    Args:
        data: indexable array of channels-last images (see show_samples).
        title: figure title.
        num: how many samples to show. Capped at len(data), because
            np.random.choice(..., replace=False) raises when asked for more
            items than exist.
    """
    num = min(num, len(data))
    idxs = np.random.choice(len(data), size=num, replace=False)
    show_samples(data[idxs], title)
    

def visualize_batch(batch, nrow=10):
    """Show a batch of images that has no channel axis as a grid.

    A trailing length-1 channel axis is appended before delegating to
    show_samples, which expects channels-last input.
    """
    with_channel = batch[..., np.newaxis]
    show_samples(with_channel, nrow=nrow)


def plot_training_curves(train_losses, test_losses):
    """Plot every train and test loss series on one figure.

    Args:
        train_losses: dict mapping a loss name to a sequence of training
            values (presumably one per step).
        test_losses: dict mapping a loss name to a sequence of test values,
            plotted at x = 0, 1, 2, ... (labelled as epochs).

    Training values are spread evenly over the same x-range as the test
    values, from 0 to n_test - 1.
    """
    # Take each length from its own dict. Indexing test_losses with a key
    # taken from train_losses raised KeyError whenever the key sets differ.
    n_train = len(next(iter(train_losses.values())))
    n_test = len(next(iter(test_losses.values())))
    x_train = np.linspace(0, n_test - 1, n_train)
    x_test = np.arange(n_test)

    plt.figure()
    for key, value in train_losses.items():
        plt.plot(x_train, value, label=key + '_train')

    for key, value in test_losses.items():
        plt.plot(x_test, value, label=key + '_test')

    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()

In [7]:
# Images are returned rescaled to [-1, 1] by load_pickle.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted mnist.pkl.
train_data, test_data, train_labels, test_labels = load_pickle('mnist.pkl')

In [None]:
# Preview a random selection of training images (drawn without replacement).
visualize_data(train_data, 'MNIST samples')

In [50]:
class VanillaGAN():
    def __init__(self, G, D, noise_fn, data_fn,
                 batch_size=32, device='cpu', lr_D=1e-3, lr_G=2e-4):
        """A GAN class for holding and training a generator and discriminator.

        Args:
            G: a generator network, mapping latent vectors to samples
            D: a discriminator network; its output goes straight into
                nn.BCELoss, so it must already be a probability (e.g. the
                network ends in nn.Sigmoid)
            noise_fn: function f(num: int) -> pytorch tensor (latent vectors)
            data_fn: function f(num: int) -> pytorch tensor (real samples)
            batch_size: training batch size
            device: 'cpu' or a CUDA device
            lr_D: learning rate for the discriminator
            lr_G: learning rate for the generator
        """
        # nn.Module.to moves parameters in place and returns the module.
        self.G = G.to(device)
        self.D = D.to(device)
        self.noise_fn = noise_fn
        self.data_fn = data_fn
        self.batch_size = batch_size
        self.device = device
        # BCELoss expects probabilities, not logits (see D above).
        self.criterion = nn.BCELoss()
        # beta1 = 0.5 (lower than Adam's default 0.9).
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=lr_D, betas=(0.5, 0.999))
        self.optim_G = optim.Adam(self.G.parameters(),
                                  lr=lr_G, betas=(0.5, 0.999))
        # Full-batch targets, kept for any caller that reads them. The train
        # steps build targets with *_like instead, so a short batch (e.g. the
        # last DataLoader batch of an epoch) does not fail BCELoss's shape check.
        self.target_ones = torch.ones((batch_size, 1)).to(device)
        self.target_zeros = torch.zeros((batch_size, 1)).to(device)

    def generate_samples(self, latent_vec=None, num=None):
        """Sample from the generator without tracking gradients.

        Args:
            latent_vec: a latent tensor, or None to draw fresh noise
            num: number of samples to draw when latent_vec is None
                (defaults to self.batch_size)

        Returns:
            The generator output for the latent vectors.
        """
        num = self.batch_size if num is None else num
        latent_vec = self.noise_fn(num) if latent_vec is None else latent_vec
        with torch.no_grad():
            samples = self.G(latent_vec)
        return samples

    def train_step_G(self):
        """Train the generator one step and return its loss as a float.

        Fakes are scored against "real" (all-ones) targets, i.e. the
        non-saturating generator objective. D's parameters also receive
        gradients here. optim_G does not apply them, and optim_D.zero_grad()
        clears them at the start of the next discriminator step.
        """
        self.optim_G.zero_grad()
        latent_vec = self.noise_fn(self.batch_size)
        generated = self.G(latent_vec)
        classifications = self.D(generated)
        loss = self.criterion(classifications, torch.ones_like(classifications))
        loss.backward()
        self.optim_G.step()
        return loss.item()

    def train_step_D(self):
        """Train the discriminator one step.

        Returns:
            (loss_real, loss_fake) as Python floats; the optimized loss is
            their mean.
        """
        self.optim_D.zero_grad()

        # Real samples should be classified as 1.
        real_samples = self.data_fn(self.batch_size)
        pred_real = self.D(real_samples)
        loss_real = self.criterion(pred_real, torch.ones_like(pred_real))

        # Fake samples should be classified as 0. no_grad keeps G out of
        # this graph, so only D is updated here.
        latent_vec = self.noise_fn(self.batch_size)
        with torch.no_grad():
            fake_samples = self.G(latent_vec)
        pred_fake = self.D(fake_samples)
        loss_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake))

        loss = (loss_real + loss_fake) / 2
        loss.backward()
        self.optim_D.step()

        return loss_real.item(), loss_fake.item()

    def train_step(self):
        """Run one discriminator step, then one generator step.

        Returns:
            (loss_G, (loss_D_real, loss_D_fake))
        """
        loss_D = self.train_step_D()
        loss_G = self.train_step_G()
        return loss_G, loss_D

    def generate_images(self, latent_vec=None, num=None, image_shape=(28, 28)):
        """Sample from the generator and view the flat outputs as images.

        Args:
            latent_vec, num: forwarded to generate_samples
            image_shape: per-sample shape to view each output as; the
                default (28, 28) matches MNIST.

        Returns:
            Tensor of shape (N, *image_shape).
        """
        samples = self.generate_samples(latent_vec=latent_vec, num=num)
        return samples.view(-1, *image_shape)

In [51]:
def visualize_GAN(gan):
    """Render a small grid of 4 fresh samples drawn from the GAN's generator."""
    visualize_batch(gan.generate_images(num=4))

In [52]:
# make data_fn for MNIST

In [53]:
def loopy(dl):
    """Yield the items of `dl` endlessly, starting a new pass after each one.

    A fresh iterator is requested for every pass, so a re-iterable source
    such as a shuffling DataLoader produces a new order each time. An empty
    `dl` makes this loop forever without yielding.
    """
    while True:
        yield from dl

In [54]:
def get_data_fn(data_loader, device='cuda'):
    """Adapt a batch iterable into the `data_fn(num) -> tensor` VanillaGAN expects.

    Args:
        data_loader: re-iterable source of tensor batches (e.g. a
            DataLoader); it is cycled forever via loopy. An empty loader
            blocks forever, as before.
        device: device the returned samples are moved to (default 'cuda',
            matching the previous hard-coded .cuda()).

    Returns:
        A function data_fn(x) that returns exactly x samples per call.
        Batches are buffered and re-sliced, so a short final batch from the
        loader (or a loader batch size different from x) never produces a
        wrongly sized tensor.
    """
    # Build the endless iterator ONCE. The previous version called
    # next(loopy(data_loader)) on every call, which started a fresh pass over
    # the loader each time. It only ever returned the first batch of a new
    # pass, never progressed through an epoch, and paid the iterator
    # start-up cost on every training step.
    batches = loopy(data_loader)
    pending = None  # samples drawn from the loader but not yet returned

    def data_fn(x):
        nonlocal pending
        while pending is None or len(pending) < x:
            batch = next(batches)
            pending = batch if pending is None else torch.cat([pending, batch])
        out, pending = pending[:x], pending[x:]
        return out.to(device)

    return data_fn

In [55]:
###

In [None]:
def get_simple_model(hiddens):
    """Build an MLP with one Linear layer per consecutive pair in `hiddens`.

    A ReLU follows every Linear except the last, so the output is left
    unactivated and callers append their own head (e.g. Tanh or Sigmoid).

    Args:
        hiddens: layer sizes [in, h1, ..., out]; at least two entries.

    Returns:
        nn.Sequential containing the layers.
    """
    assert len(hiddens) > 1

    layers = []
    last_idx = len(hiddens) - 2
    for idx in range(len(hiddens) - 1):
        layers.append(nn.Linear(hiddens[idx], hiddens[idx + 1]))
        if idx < last_idx:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

In [56]:
# Batch size used for both the training DataLoader and the GAN (setup cell below).
BATCH_SIZE = 128

In [65]:
latent_size = 2
noise_fn = lambda x: torch.randn((x, latent_size), device='cuda')
# Flatten each image to a 784-vector. Assumes train_data is (N, 28, 28, 1)
# with a trailing channel axis; TODO confirm against mnist.pkl.
# drop_last=True keeps every batch at exactly BATCH_SIZE rows.
train_loader = DataLoader(
    train_data[..., 0].astype(np.float32).reshape(len(train_data), 28*28),
    batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
data_fn = get_data_fn(train_loader)

# The generator must emit a flat 28*28 image: D's first layer is
# Linear(28*28, ...) and generate_images views outputs as (-1, 28, 28).
# The previous sizes stopped at 128, so D(G(z)) raised a shape mismatch.
gen_hiddens = [latent_size, 16, 32, 128, 28*28]
dis_hiddens = [28*28, 128, 32, 16, 1]
# Unpack both layer stacks into one flat Sequential (consistent for G and D).
G = nn.Sequential(*get_simple_model(gen_hiddens), nn.Tanh()).cuda()
D = nn.Sequential(*get_simple_model(dis_hiddens), nn.Sigmoid()).cuda()

gan = VanillaGAN(G, D, noise_fn, data_fn, batch_size=BATCH_SIZE, device='cuda')

In [60]:
# Training schedule: each "epoch" in the loop below is BATCHES calls to
# gan.train_step() (one D step + one G step each), not a full pass over the data.
EPOCHS = 200
BATCHES = 100

In [69]:
# Per-epoch mean losses for G and for D on real / fake samples.
loss_g, loss_d_real, loss_d_fake = [], [], []
start = time()
for epoch in range(EPOCHS):
    loss_g_running, loss_d_real_running, loss_d_fake_running = 0.0, 0.0, 0.0
    # One "epoch" here is BATCHES (D step, G step) pairs.
    for _ in range(BATCHES):
        lg_, (ldr_, ldf_) = gan.train_step()
        loss_g_running += lg_
        loss_d_real_running += ldr_
        loss_d_fake_running += ldf_
    loss_g.append(loss_g_running / BATCHES)
    loss_d_real.append(loss_d_real_running / BATCHES)
    loss_d_fake.append(loss_d_fake_running / BATCHES)
    print(f"Epoch {epoch+1}/{EPOCHS} ({int(time() - start)}s):"
          f" G={loss_g[-1]:.3f},"
          f" Dr={loss_d_real[-1]:.3f},"
          f" Df={loss_d_fake[-1]:.3f}")
    visualize_GAN(gan)

In [None]:
# Per-epoch mean losses recorded by the training loop.
fig, ax = plt.subplots(figsize=(15, 10))
for series, name in ((loss_g, 'G loss'),
                     (loss_d_real, 'D real'),
                     (loss_d_fake, 'D fake')):
    ax.plot(series, label=name)
ax.legend()
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Grid of 100 samples from the trained generator (visualize_batch defaults to 10 per row).
visualize_batch(gan.generate_images(num=100))

In [None]:
# Decode an evenly spaced grid of 2-D latent points and show the images side
# by side, to see how the generator organises its latent space.
assert latent_size == 2, 'the latent-grid plot needs latent_size == 2'
a = 2
step = 0.25

print('GAN')
# mgrid with a float step excludes the upper bound: 16 values in [-2, 2).
x, y = np.mgrid[-a:a:step, -a:a:step]
pos = np.dstack((x, y))
n_row = pos.shape[0]  # one grid row per x value

# np.product was deprecated in NumPy 1.25 and removed in 2.0; reshape(-1, 2)
# flattens the (n_row, n_row, 2) grid into a list of latent points without it.
pos = pos.reshape(-1, 2)

samples = gan.generate_images(torch.from_numpy(pos).float().cuda())

visualize_batch(samples, n_row)