In [1]:
import os

import numpy as np
import torch

# Use CUDA only when a device is actually present, so the notebook also runs
# on CPU-only machines (model.cuda() / pin_memory would otherwise fail).
use_gpu = torch.cuda.is_available()

# Directory containing the CIFAR-100 files; override via the DATA_DIR
# environment variable instead of editing this absolute path.
data_dir = os.environ.get('DATA_DIR', '/home/victor/data')

# Seed numpy (train/test index permutations) and torch (weight init,
# DataLoader shuffling, latent sampling) so runs are reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [2]:
# Make the parent directory importable; presumably that is where alphagan.py
# lives (confirm for your checkout).
import sys
sys.path.append('../')

# Mode 1 of autoreload re-imports only the modules registered with %aimport
# before each cell runs, so edits to alphagan are picked up without a kernel
# restart. The extension must be loaded before %autoreload / %aimport are used.
%load_ext autoreload
%autoreload 1
%aimport alphagan

In [3]:
from collections import defaultdict
from psutil import cpu_count

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.nn import init, Parameter
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Load the CIFAR-100 training split and stack every image into one float
# tensor (labels are dropped; only the images are used downstream).
cifar = torch.stack([
    image
    for image, _label in datasets.CIFAR100(
        data_dir,
        train=True,
        transform=transforms.ToTensor(),
        target_transform=None,
        download=False,
    )
])
cifar.size()

torch.Size([50000, 3, 32, 32])

In [5]:
# Load the CIFAR-100 test split the same way: stack images, discard labels.
cifar_test = torch.stack([
    image
    for image, _label in datasets.CIFAR100(
        data_dir,
        train=False,
        transform=transforms.ToTensor(),
        target_transform=None,
        download=False,
    )
])
cifar_test.size()

torch.Size([10000, 3, 32, 32])

In [6]:
# Minibatch size shared by both DataLoaders and the network shape smoke tests.
batch_size = 64

In [7]:
# Train on the whole CIFAR-100 training split (raise the `//1` divisor to
# subsample it); validate on a random 2*batch_size subset of the test split.
n_train, n_test = len(cifar)//1, batch_size*2

# psutil.cpu_count() returns None when the count cannot be determined, which
# would make DataLoader fail; fall back to 0 (load in the main process).
# Worker processes are only used in the GPU configuration.
num_workers = (cpu_count() or 0) if use_gpu else 0

train_idxs = torch.LongTensor(np.random.permutation(len(cifar))[:n_train])
X_train = DataLoader(cifar[train_idxs], batch_size=batch_size, shuffle=True,
                     num_workers=num_workers, pin_memory=use_gpu)
test_idxs = torch.LongTensor(np.random.permutation(len(cifar_test))[:n_test])
X_test = DataLoader(cifar_test[test_idxs], batch_size=batch_size, shuffle=False,
                    num_workers=num_workers, pin_memory=use_gpu)

In [8]:
# Adapters between conv feature maps with 1x1 spatial size and nn.Linear.
# nn.Linear acts on the last dimension, so the two trailing singleton spatial
# dims are removed before the layer (ChannelsToLinear) or re-added after it
# (LinearToChannels).
class ChannelsToLinear(nn.Linear):
    """nn.Linear applied to a (B, C, 1, 1) map; returns (B, out_features)."""
    def forward(self, x):
        flat = x.squeeze(-1).squeeze(-1)
        return super().forward(flat)


class LinearToChannels(nn.Linear):
    """nn.Linear whose output gains two trailing singleton dims: (B, out, 1, 1)."""
    def forward(self, x):
        out = super().forward(x)
        return out[..., None, None]

# Residual block that keeps the channel count and can optionally upsample or
# downsample the spatial resolution (never both).
class ResBlock(nn.Module):
    """Two-conv residual block.

    out = a2(shortcut(x) + norm2(conv2(a1(norm1(conv1(x))))))

    Args:
        c: number of input and output channels.
        activation: zero-argument factory returning an activation module.
        norm: factory taking a channel count, returning a normalisation module.
        upsample: integer factor; >1 makes conv1 a transposed conv with
            kernel == stride == upsample and the shortcut nearest upsampling.
        downsample: integer factor; >1 makes conv1 a strided 3x3 conv and the
            shortcut average pooling. At most one of upsample/downsample may
            exceed 1.
    """
    def __init__(self, c,
                 activation=nn.LeakyReLU, norm=nn.BatchNorm2d,
                 upsample=1, downsample=1):
        super().__init__()
        self.a1, self.a2 = activation(), activation()
        self.norm1, self.norm2 = norm(c), norm(c)

        # Shortcut transform; remains None when resolution is unchanged.
        self.resample = None
        assert upsample == 1 or downsample == 1
        if upsample > 1:
            self.conv1 = nn.ConvTranspose2d(c, c, upsample, upsample)
            self.resample = nn.UpsamplingNearest2d(scale_factor=upsample)
        else:
            self.conv1 = nn.Conv2d(c, c, 3, downsample, 1)
        if downsample > 1:
            self.resample = nn.AvgPool2d(downsample)

        self.conv2 = nn.Conv2d(c, c, 3, 1, 1)

        for conv in (self.conv1, self.conv2):
            init.xavier_uniform(conv.weight, 2)

    def forward(self, x):
        branch = self.norm1(self.conv1(x))
        branch = self.norm2(self.conv2(self.a1(branch)))
        shortcut = x if self.resample is None else self.resample(x)
        return self.a2(shortcut + branch)

In [9]:
# Size of the code vector produced by E and consumed by G and C.
latent_dim = 128

In [10]:
# Encoder E: (B, 3, 32, 32) image -> (B, latent_dim) code.
h = 128
pool = nn.AvgPool2d
norm = nn.BatchNorm2d
a = lambda: nn.LeakyReLU(.2)
E = nn.Sequential(
    nn.Conv2d(3,h,3,2,1), norm(h), a(),                  # 32x32 -> 16x16
    ResBlock(h, activation=a, norm=norm, downsample=2),  # 16x16 -> 8x8
    ResBlock(h, activation=a, norm=norm),
    ResBlock(h, activation=a, norm=norm),
    a(), pool(8),                                        # 8x8 -> 1x1
    ChannelsToLinear(h, latent_dim)
)

# ResBlocks initialise their own convs; initialise the stem conv (E[0]) and
# the output projection (E[8]) here.
for layer in (E[0], E[8]):
    init.xavier_uniform(layer.weight, 2)

# Shape smoke test. Run it in eval mode so the random-noise batch does not
# update the BatchNorm running statistics, then restore training mode.
t = Variable(torch.randn(batch_size,3,32,32))
E.eval()
try:
    assert E(t).size() == (batch_size,latent_dim)
finally:
    E.train()

In [11]:
# Generator G: (B, latent_dim) code -> (B, 3, 32, 32) image in (0, 1).
h = 128
pool = nn.AvgPool2d
norm = nn.BatchNorm2d
a = lambda: nn.LeakyReLU(.2)
G = nn.Sequential(
    LinearToChannels(latent_dim, h), norm(h), a(),      # code -> h x 1 x 1
    nn.ConvTranspose2d(h, h, 4, 1), norm(h), a(),       # 1x1 -> 4x4
    ResBlock(h, activation=a, norm=norm, upsample=2),   # 4x4 -> 8x8
    ResBlock(h, activation=a, norm=norm, upsample=2),   # 8x8 -> 16x16
    ResBlock(h, activation=a, norm=norm, upsample=2),   # 16x16 -> 32x32
    nn.Conv2d(h, 3, 1, 1), nn.Sigmoid()                 # -> 3 channels in (0, 1)
)

# ResBlocks initialise their own convs; initialise the input projection
# (G[0]), the first transposed conv (G[3]) and the output conv (G[9]) here.
for layer in (G[0], G[3], G[9]):
    init.xavier_uniform(layer.weight, 2)

# Shape smoke test in eval mode so BatchNorm running statistics are not
# updated by random noise; training mode is restored afterwards.
t = Variable(torch.randn(batch_size,latent_dim))
G.eval()
try:
    assert G(t).size() == (batch_size,3,32,32)
finally:
    G.train()



In [12]:
# Image discriminator D: (B, 3, 32, 32) image -> (B, 1) score in (0, 1).
h = 128
pool = nn.AvgPool2d
norm = nn.BatchNorm2d
a = lambda: nn.LeakyReLU(.2)
D = nn.Sequential(
    nn.Conv2d(3,h,3,2,1), norm(h), a(),                  # 32x32 -> 16x16
    ResBlock(h, activation=a, norm=norm, downsample=2),  # 16x16 -> 8x8
    ResBlock(h, activation=a, norm=norm),
    ResBlock(h, activation=a, norm=norm),
    a(), pool(8),                                        # 8x8 -> 1x1
    ChannelsToLinear(h, 1), nn.Sigmoid()
)

# ResBlocks initialise their own convs; initialise the stem conv (D[0]) and
# the output projection (D[8]) here.
for layer in (D[0], D[8]):
    init.xavier_uniform(layer.weight, 2)

# Shape smoke test in eval mode so BatchNorm running statistics are not
# updated by random noise; training mode is restored afterwards.
t = Variable(torch.randn(batch_size,3,32,32))
D.eval()
try:
    assert D(t).size() == (batch_size,1)
finally:
    D.train()

In [13]:
# Code discriminator C: (B, latent_dim) code -> (B, 1) score in (0, 1).
h = 256
a = lambda: nn.LeakyReLU(.2)
norm = nn.BatchNorm1d
C = nn.Sequential(
    nn.Linear(latent_dim, h), norm(h), a(),
    nn.Linear(h, h), norm(h), a(),
    nn.Linear(h, 1), nn.Sigmoid(),
)

# Initialise every Linear layer; selecting by type instead of position keeps
# this correct if layers are added or reordered.
for layer in C:
    if isinstance(layer, nn.Linear):
        init.xavier_uniform(layer.weight, 2)

# Shape smoke test in eval mode so BatchNorm running statistics are not
# updated by random noise; training mode is restored afterwards.
t = Variable(torch.randn(batch_size,latent_dim))
C.eval()
try:
    assert C(t).size() == (batch_size,1)
finally:
    C.train()

In [14]:
# Combine encoder E, generator G, image discriminator D and code discriminator
# C into the AlphaGAN trainer. lambd=10 is passed straight through (presumably
# a loss weight; confirm in alphagan.py).
model = alphagan.AlphaGAN(E, G, D, C, latent_dim, lambd=10)
model = model.cuda() if use_gpu else model

In [None]:
# One DataFrame per report received by log_fn, kept for plotting later.
diag = []


def log_fn(d):
    """Convert one report to a DataFrame, keep it in `diag`, and print it.

    Judging from the printed output, `d` maps 'train'/'valid' to per-loss
    values (presumably; confirm against alphagan.AlphaGAN.fit).
    """
    report = pd.DataFrame(d)
    diag.append(report)
    print(report)

In [None]:
# Train the model; each report is passed to log_fn. Argument semantics live in
# alphagan.AlphaGAN.fit. The notes below are presumptions to confirm there.
model.fit(
    X_train,
    X_test,
    n_epochs=20,
    n_batches=len(X_train) // 10,  # presumably batches per epoch (~1/10 of the loader)
    n_iter=(1, 2),                 # presumably per-network update counts; confirm
    report_every=1,
    log_fn=log_fn,
)



                            train     valid
adversarial_loss         0.763698  0.786654
code_adversarial_loss    0.666158  0.665527
code_discriminator_loss  1.636050  2.050261
discriminator_loss       1.247683  1.317447
reconstruction_loss      1.713441  1.626369


                            train     valid
adversarial_loss         0.768617  0.822387
code_adversarial_loss    0.592550  0.652246
code_discriminator_loss  1.583596  1.866234
discriminator_loss       1.291427  1.316001
reconstruction_loss      1.363657  1.649362


                            train     valid
adversarial_loss         0.782339  0.950987
code_adversarial_loss    0.607436  0.762274
code_discriminator_loss  1.549248  1.573001
discriminator_loss       1.271981  1.321378
reconstruction_loss      1.303240  1.654246


In [None]:
import itertools

# Plot every logged loss against the report index: one colour per loss
# component, solid = validation, dashed = training.
assert diag, 'no reports logged yet; run the training cell first'
fig, ax = plt.subplots(1,1,figsize=(14,8))
# One row per report; columns are (component, dataset) pairs obtained by
# stacking each logged frame.
diagnostic = pd.concat([pd.DataFrame(d.stack(), columns=[i]).T for i,d in enumerate(diag)])
# Cycle the palette (same order as before) so more than six loss components
# cannot exhaust it.
palette = itertools.cycle(reversed('rgbcmy'))
colors = defaultdict(lambda: next(palette))
for component, dataset in diagnostic:
    kw = {}
    if dataset=='valid':
        kw['label'] = component
    else:
        kw['ls'] = '--'
    ax.plot(diagnostic[(component, dataset)].values, c=colors[component], **kw)
ax.set_xlabel('report')
ax.set_ylabel('loss')
ax.set_title('AlphaGAN losses (solid: valid, dashed: train)')
ax.legend(bbox_to_anchor=(1, 0.7));

In [None]:
# Draw 16 samples from the generator. G ends in nn.Sigmoid, so pixels already
# lie in (0, 1); rescaling from (-1, 1) would wash them out into [0.5, 1].
# NOTE(review): BatchNorm behaves differently in train vs eval mode; confirm
# whether AlphaGAN switches modes before sampling.
z, x_sample = model(16, mode='sample')
fig, ax = plt.subplots(1,1,figsize=(16,4))
ax.imshow(make_grid(x_sample.data.cpu()).numpy().transpose(1,2,0), interpolation='nearest')
ax.set_title('Generated samples')
ax.axis('off');

In [None]:
# Training reconstructions: top row = 12 training images, bottom row = their
# reconstructions. ToTensor images and G's sigmoid outputs are both in [0, 1],
# so they are displayed without rescaling.
fig, ax = plt.subplots(1,1,figsize=(16,4))
# Index before slicing so only 12 images are copied, not the whole split.
x = cifar[train_idxs[:12]]
# presumably AlphaGAN moves x to the model's device; x_rec may live on GPU.
z, x_rec = model(x)
ax.imshow(make_grid(
    torch.cat((x, x_rec.cpu().data)), nrow=12
).numpy().transpose(1,2,0), interpolation='nearest')
ax.set_title('Training images (top) and reconstructions (bottom)')
ax.axis('off');

In [None]:
# The reconstruction plots show pixels without rescaling, which assumes values
# in [0, 1] (ToTensor output). Check that, and display the observed range.
assert 0 <= x.min() and x.max() <= 1, 'pixel values outside [0, 1]'
x.min(), x.max()

In [None]:
# Test reconstructions: top row = 12 held-out images, bottom row = their
# reconstructions. Shown on the same fixed [0, 1] scale as the training cell
# (normalize=True would rescale by this batch's min/max and distort contrast).
fig, ax = plt.subplots(1,1,figsize=(16,4))
# Index before slicing so only 12 images are copied.
x = cifar_test[test_idxs[:12]]
# presumably AlphaGAN moves x to the model's device; x_rec may live on GPU.
z, x_rec = model(x)
ax.imshow(make_grid(
    torch.cat((x, x_rec.cpu().data)), nrow=12
).numpy().transpose(1,2,0), interpolation='nearest')
ax.set_title('Test images (top) and reconstructions (bottom)')
ax.axis('off');