In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from utils import load_data
from models import Generator, Discriminator
from samplers import UniformConditionalDatasetSampler, UniformConditionalLatentSampler
from config import params_wgangp

In [2]:
# Name of the hyper-parameter set to use; a dict called `params_<configuration>`
# must already be in the notebook namespace (imported from `config`).
configuration = 'wgangp'
# Resolve the dict by name through the namespace rather than eval(): the
# result is the same object, but no arbitrary expression is ever executed if
# `configuration` is edited to something unexpected.
params = globals()[f'params_{configuration}']

In [3]:
# Seed every random number generator used by the notebook with the single
# configured value: NumPy, PyTorch, and all CUDA devices.
seed = params['seed']
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [4]:
# Load the 'grid' dataset through the project helper. It yields a
# (data, labels) pair; both are converted to tensors when the samplers are built.
data, labels = load_data('grid')

In [5]:
# Convert the raw arrays once: float features and integer (long) class labels.
data_tensor = torch.tensor(data).float()
label_tensor = torch.tensor(labels).long()

# Sampler over real (data, label) batches.
data_sampler = UniformConditionalDatasetSampler(data_tensor, label_tensor)
# Sampler over (latent noise, label) batches; it is given the raw labels.
noise_sampler = UniformConditionalLatentSampler(params['latent_dim'], labels)
# Number of classes as reported by the dataset sampler; used to size the networks.
num_classes = data_sampler.num_classes

In [6]:
# Build both networks and move them to the configured device.
# In the training loop G consumes noise concatenated with labels and D consumes
# data concatenated with labels, hence the `+ num_classes` on their input sizes.
device = params['device']

G = Generator(
    params['latent_dim'] + num_classes,
    params['model_dim'],
    params['data_dim'],
    spec_norm=params['spec_norm_g'],
).to(device)

D = Discriminator(
    params['model_dim'],
    params['data_dim'] + num_classes,
    spec_norm=params['spec_norm_d'],
).to(device)

In [7]:
# Critic and generator use separate Adam optimizers with identical
# hyper-parameters taken from the configuration.
adam_kwargs = dict(lr=params['learning_rate'], betas=params['betas'])
D_optimizer = optim.Adam(D.parameters(), **adam_kwargs)
G_optimizer = optim.Adam(G.parameters(), **adam_kwargs)

In [8]:
def compute_gp(real_data, real_labels, fake_data, fake_labels):
    """Compute the WGAN-GP gradient penalty for the critic ``D``.

    Real and fake (data, label) pairs are mixed sample-wise with one uniform
    random coefficient per sample, the critic is evaluated on the mixture, and
    the L2 norm of the critic's gradient with respect to that mixture is
    penalised for deviating from 1.

    Relies on two notebook-level globals: ``D`` (the critic) and
    ``params['gp_lambda']`` (the penalty weight).

    Args:
        real_data: Batch of real samples, 2-D with the batch on dim 0.
        real_labels: Conditioning tensor for the real samples, 2-D; it is
            concatenated to the data along dim 1 (assumed to be float /
            one-hot -- TODO confirm against the sampler).
        fake_data: Generated samples, same shape as ``real_data``.
        fake_labels: Conditioning tensor for the generated samples, same
            shape as ``real_labels``.

    Returns:
        Scalar tensor ``gp_lambda * mean((||grad||_2 - 1) ** 2)``. It is built
        with ``create_graph=True`` so it can be back-propagated into the
        critic's parameters.
    """
    # One mixing coefficient per sample, shared by the data and label parts.
    # Allocated directly on the inputs' device to avoid a host-to-device copy.
    epsilon = torch.rand(real_data.shape[0], 1, device=real_data.device)
    alpha = epsilon.expand_as(real_data)
    beta = epsilon.expand_as(real_labels)

    interpolate_data = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    interpolate_labels = (beta * real_labels + (1 - beta) * fake_labels).requires_grad_(True)
    interpolate = torch.cat([interpolate_data, interpolate_labels], dim=1)
    interpolate_pred = D(interpolate)

    # Gradient of the critic output w.r.t. the full (data + label) input.
    # create_graph=True keeps the graph so the penalty itself is differentiable;
    # retain_graph then defaults to the same value.
    gradients = torch.autograd.grad(outputs=interpolate_pred,
                                    inputs=interpolate,
                                    grad_outputs=torch.ones_like(interpolate_pred),
                                    create_graph=True)[0]

    # Two-sided penalty: per-sample gradient norm is pushed towards 1.
    gradient_penalty = (gradients.norm(2, dim=1) - 1).pow(2).mean() * params['gp_lambda']
    return gradient_penalty

In [9]:
# Per-iteration training history: one independent list per network,
# keyed 'D' (critic) and 'G' (generator).
losses = {name: [] for name in ('D', 'G')}

In [None]:
# Training loop. The critic D is updated on every iteration; the generator G
# is updated only on iterations where `it` is a multiple of params['n_critic'].
G.train()
D.train()

device = params['device']
batch_size = params['batch_size']
n_critic = params['n_critic']

for it in range(params['iterations']):
    # ---------------- critic update ----------------
    D.zero_grad()

    # critic score on a real (data, label) batch
    x_real, y_real = data_sampler.get_batch(batch_size)
    x_real = x_real.to(device)
    y_real = y_real.to(device)
    real_sample = torch.cat([x_real, y_real], dim=1)
    real_pred = D(real_sample)

    # critic score on a generated batch; the generator output is detached so
    # this step does not propagate gradients into G
    z_fake, y_fake = noise_sampler.get_batch(batch_size)
    z_fake = z_fake.to(device)
    y_fake = y_fake.to(device)
    x_fake = G(torch.cat([z_fake, y_fake], dim=1)).detach()
    fake_sample = torch.cat([x_fake, y_fake], dim=1)
    fake_pred = D(fake_sample)

    # critic objective = (fake score - real score) + gradient penalty
    gradient_penalty = compute_gp(x_real, y_real, x_fake, y_fake)
    D_loss = fake_pred.mean() - real_pred.mean() + gradient_penalty
    D_loss.backward()
    D_optimizer.step()

    # the value that gets logged leaves the gradient penalty out
    D_loss = fake_pred.mean() - real_pred.mean()

    # ---------------- generator update ----------------
    if it % n_critic == 0:
        G.zero_grad()

        # fresh noise batch, this time kept attached so gradients reach G
        z_fake, y_fake = noise_sampler.get_batch(batch_size)
        z_fake = z_fake.to(device)
        y_fake = y_fake.to(device)
        x_fake = G(torch.cat([z_fake, y_fake], dim=1))
        fake_sample = torch.cat([x_fake, y_fake], dim=1)
        fake_pred = D(fake_sample)

        # the generator maximises the critic score of its samples
        G_loss = -fake_pred.mean()
        G_loss.backward()
        G_optimizer.step()

    # One entry per iteration for both curves. On iterations that skip the
    # generator step, G_loss still holds its most recent value.
    losses['D'].append(D_loss.item())
    losses['G'].append(-G_loss.item())

In [None]:
# Draw both logged training curves on the same axes, generator first.
for key, label in (('G', 'Generator'), ('D', 'Discriminator')):
    plt.plot(losses[key], label=label)
plt.legend()
plt.show()

In [None]:
# Switch the generator to evaluation mode before sampling from it for the plot.
G.eval()

In [None]:
# Scatter-plot a large batch of generator samples (first two output dimensions).
plt.gcf().set_size_inches(5, 5)
z_fake, y_fake = noise_sampler.get_batch(100000)
z_fake, y_fake = z_fake.to(params['device']), y_fake.to(params['device'])
# Inference only: run the forward pass without autograd bookkeeping so the
# 100k-sample batch does not build (and hold in memory) a computation graph.
with torch.no_grad():
    generated = G(torch.cat([z_fake, y_fake], dim=1)).cpu().numpy()
# Nearly transparent green markers so that point density shows up as intensity.
plt.scatter(generated[:,0], generated[:,1], marker='.', color=(0, 1, 0, 0.01))
plt.axis('equal')
plt.show()