In [58]:
# Root directory passed to the torchvision CIFAR-100 loaders in the cells below.
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
data_dir = '/home/victor/data-ssd'

In [59]:
import sys
# Put the parent directory on the import path; `alphagan` is presumably found there.
sys.path.append('../')

# Enable IPython's autoreload extension. Mode 1 together with %aimport is meant
# to re-import only the registered module (alphagan) when its source changes.
# No plain `import alphagan` appears in the visible cells, so %aimport is
# presumably also what binds the name -- confirm against the IPython docs.
%load_ext autoreload
%autoreload 1
%aimport alphagan

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
from collections import defaultdict

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.nn import init, Parameter
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.models.resnet import BasicBlock
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
%matplotlib inline

In [61]:
# Preprocessing for the training split: tensor conversion followed by
# per-channel standardisation with the mean/std constants below.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2673, 0.2564, 0.2762],
    ),
])
cifar = datasets.CIFAR100(
    data_dir,
    train=True,
    transform=train_transform,
    target_transform=None,
    download=False,
)
# Keep only the first element of every dataset item (the transformed image)
# and stack them all into one tensor; `cifar` is rebound from dataset to tensor.
cifar = torch.stack([sample[0] for sample in cifar])
cifar.size()

torch.Size([50000, 3, 32, 32])

In [62]:
# Preprocessing for the held-out split: tensor conversion followed by
# per-channel standardisation with the mean/std constants below.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2673, 0.2564, 0.2762],
    ),
])
cifar_test = datasets.CIFAR100(
    data_dir,
    train=False,
    transform=test_transform,
    target_transform=None,
    download=False,
)
# Keep only the first element of every dataset item (the transformed image)
# and stack them all into one tensor; `cifar_test` is rebound from dataset to tensor.
cifar_test = torch.stack([sample[0] for sample in cifar_test])
cifar_test.size()

torch.Size([10000, 3, 32, 32])

In [63]:
# Mini-batch size used by both DataLoaders and by the random batches that
# sanity-check the output shape of each network.
batch_size = 64

In [64]:
# How many images to keep from each split. `// 1` keeps the whole training
# set; presumably the divisor is a knob for quick subsampled runs -- TODO confirm.
n_train = len(cifar) // 1
n_test = 128

# Training subset: shuffle all indices, keep the first n_train of them.
train_perm = np.random.permutation(len(cifar))
train_idxs = torch.LongTensor(train_perm[:n_train])
X_train = DataLoader(cifar[train_idxs], batch_size=batch_size, shuffle=True)

# Evaluation subset: same recipe on the test split, served in a fixed order.
test_perm = np.random.permutation(len(cifar_test))
test_idxs = torch.LongTensor(test_perm[:n_test])
X_test = DataLoader(cifar_test[test_idxs], batch_size=batch_size, shuffle=False)

In [81]:
# I think broadcasting should make these unnecessary in the next pytorch release
class ChannelsToLinear(nn.Linear):
    """Linear layer for inputs that carry two trailing size-1 axes.

    The last axis is squeezed twice before the ordinary ``nn.Linear``
    computation runs, e.g. ``(N, C, 1, 1)`` is fed to the layer as ``(N, C)``.
    ``squeeze`` leaves an axis alone when its size is not 1.
    """

    def forward(self, x):
        flat = x.squeeze(-1)
        flat = flat.squeeze(-1)
        return super().forward(flat)
class LinearToChannels(nn.Linear):
    """Linear layer whose output gains two trailing size-1 axes.

    An ``(N, in_features)`` input therefore comes out as
    ``(N, out_features, 1, 1)``.
    """

    def forward(self, x):
        out = super().forward(x)
        # Append the two singleton axes one at a time.
        for _ in range(2):
            out = out.unsqueeze(-1)
        return out

# Versatile ResNet block that can optionally upsample or downsample its input.
class ResBlock(nn.Module):
    """Residual block of two convolutions with optional spatial resampling.

    The block computes ``a2(shortcut(x) + norm2(conv2(a1(norm1(conv1(x))))))``
    and keeps the channel count ``c`` unchanged.

    Args:
        c: Number of input and output channels.
        activation: Zero-argument factory; called twice to build the two
            activation modules.
        norm: Factory called as ``norm(c)`` for each of the two
            normalisation layers.
        upsample: Integer factor. When > 1, ``conv1`` is a transposed
            convolution whose kernel size and stride both equal the factor,
            and the shortcut is nearest-neighbour upsampled by that factor.
        downsample: Integer factor used as the stride of the 3x3 ``conv1``.
            When > 1, the shortcut is average-pooled by the same factor.

    Raises:
        AssertionError: If ``upsample`` and ``downsample`` both differ from 1.
    """

    def __init__(self, c,
                 activation=nn.LeakyReLU, norm=nn.BatchNorm2d,
                 upsample=1, downsample=1):
        super().__init__()
        self.a1 = activation()
        self.a2 = activation()
        self.norm1 = norm(c)
        self.norm2 = norm(c)

        # Shortcut resampler; stays None when the spatial size is unchanged.
        self.resample = None
        assert upsample == 1 or downsample == 1
        if upsample > 1:
            self.conv1 = nn.ConvTranspose2d(c, c, upsample, upsample)
            self.resample = nn.UpsamplingNearest2d(scale_factor=upsample)
        else:
            self.conv1 = nn.Conv2d(c, c, 3, downsample, 1)
        if downsample > 1:
            # The strided, padded conv1 yields ceil(size / downsample) pixels.
            # ceil_mode makes the pooled shortcut agree for sizes that are not
            # a multiple of the factor (a floor-mode pool would come out one
            # pixel short and the residual addition would fail). Sizes that
            # divide evenly are pooled exactly as before.
            self.resample = nn.AvgPool2d(downsample, ceil_mode=True)

        self.conv2 = nn.Conv2d(c, c, 3, 1, 1)

        # Xavier-uniform initialisation with gain 2 for both convolutions.
        init.xavier_uniform(self.conv1.weight, 2)
        init.xavier_uniform(self.conv2.weight, 2)

    def forward(self, x):
        """Return the activated sum of the residual branch and the shortcut."""
        y = self.conv1(x)
        y = self.norm1(y)
        y = self.a1(y)
        y = self.conv2(y)
        y = self.norm2(y)

        # Explicit None test rather than relying on module truthiness.
        if self.resample is not None:
            x = self.resample(x)

        return self.a2(x + y)

In [82]:
# Size of the latent code: output width of the encoder E and input width of
# the generator G and the code discriminator C.
latent_dim = 128

In [96]:
# k = 3
# a = lambda: nn.LeakyReLU(.2)
# pool = nn.AvgPool2d
# norm = nn.BatchNorm2d
# E = nn.Sequential(
#     nn.Conv2d(  3,  64, k, 1, k//2), pool(2), norm(64), a(),
#     nn.Conv2d( 64, 128, k, 1, k//2), pool(2), norm(128), a(),
#     nn.Conv2d(128, 256, k, 1, k//2), pool(2), norm(256), a(),
#     nn.Conv2d(256, 512, k, 1, k//2), pool(4), norm(512), a(),
#     ChannelsToLinear(512, latent_dim)
# )
# for i,layer in enumerate(E):
#     if i%4==0:
#         init.xavier_uniform(layer.weight, 2)
h = 128
pool = nn.AvgPool2d
norm = nn.BatchNorm2d


def a():
    # A fresh LeakyReLU(0.2) per call so no activation module is shared.
    return nn.LeakyReLU(.2)


# Encoder: image -> latent vector. For 32x32 inputs the spatial size goes
# 32 -> 16 (strided stem) -> 8 (downsampling ResBlock) -> 1 (8x8 average pool).
E = nn.Sequential(
    nn.Conv2d(3, h, kernel_size=5, stride=2, padding=2),
    norm(h),
    a(),
    ResBlock(h, downsample=2),
    ResBlock(h),
    ResBlock(h),
    a(),
    pool(8),
    ChannelsToLinear(h, latent_dim),
)

# Shape sanity check on one random batch.
t = Variable(torch.randn(batch_size, 3, 32, 32))
assert E(t).size() == (batch_size, latent_dim)

In [84]:
# k=5
# a = lambda: nn.LeakyReLU(.005)
# norm = nn.BatchNorm2d
# h = 64
# G = nn.Sequential(
#     LinearToChannels(latent_dim, 512), norm(512), a(),
#     nn.ConvTranspose2d(512, 256, 4, 1), norm(256), a(),
#     nn.ConvTranspose2d(256, 128, 2, 2), norm(128), a(),
#     nn.Conv2d( 128, h, k, 1, k//2), norm(h), a(),
#     nn.ConvTranspose2d(  h, h, 2, 2), norm(h), a(),
#     nn.ConvTranspose2d(  h, h, 2, 2), norm(h), a(),
#     nn.Conv2d( h, 3, k, 1, k//2), #nn.Tanh()
# )
# for i,layer in enumerate(G):
#     if i%3==0:
#         init.xavier_uniform(layer.weight, 2)

h = 128
pool = nn.AvgPool2d
norm = nn.BatchNorm2d


def a():
    # A fresh LeakyReLU(0.2) per call so no activation module is shared.
    return nn.LeakyReLU(.2)


# Generator: latent vector -> image. Spatial size goes 1 -> 4 (transposed
# conv) -> 8 -> 16 -> 32 (three upsampling ResBlocks).
# NOTE(review): Tanh bounds the output to [-1, 1], while the images are
# standardised with a per-channel mean/std in the data-loading cells and can
# fall outside that range -- confirm this mismatch is intended.
G = nn.Sequential(
    LinearToChannels(latent_dim, h),
    norm(h),
    a(),
    nn.ConvTranspose2d(h, h, kernel_size=4, stride=1),
    norm(h),
    a(),
    ResBlock(h, upsample=2),
    ResBlock(h, upsample=2),
    ResBlock(h, upsample=2),
    nn.Conv2d(h, 3, kernel_size=1, stride=1),
    nn.Tanh(),
)

# Shape sanity check on one random batch of latent codes.
t = Variable(torch.randn(batch_size, latent_dim))
assert G(t).size() == (batch_size, 3, 32, 32)

In [97]:
# k = 3
# a = lambda: nn.LeakyReLU(.2)
# pool = nn.AvgPool2d
# norm = nn.BatchNorm2d
# D = nn.Sequential(
#     nn.Conv2d(  3,  64, k, 1, k//2), pool(2), norm(64), a(),
#     nn.Conv2d( 64, 128, k, 1, k//2), pool(2), norm(128), a(),
#     nn.Conv2d(128, 256, k, 1, k//2), pool(2), norm(256), a(),
#     nn.Conv2d(256, 512, k, 1, k//2), pool(4), norm(512), a(),
#     ChannelsToLinear(512, 1), nn.Sigmoid()
# )
# for i,layer in enumerate(D):
#     if i%4==0:
#         init.xavier_uniform(layer.weight, 2)
        
h = 128
pool = nn.AvgPool2d
norm = nn.BatchNorm2d


def a():
    # A fresh LeakyReLU(0.2) per call so no activation module is shared.
    return nn.LeakyReLU(.2)


# Image discriminator: same trunk layout as the encoder, ending in a single
# sigmoid unit per image.
D = nn.Sequential(
    nn.Conv2d(3, h, kernel_size=5, stride=2, padding=2),
    norm(h),
    a(),
    ResBlock(h, downsample=2),
    ResBlock(h),
    ResBlock(h),
    a(),
    pool(8),
    ChannelsToLinear(h, 1),
    nn.Sigmoid(),
)

# Shape sanity check on one random batch.
t = Variable(torch.randn(batch_size, 3, 32, 32))
assert D(t).size() == (batch_size, 1)

In [86]:
h = 128
norm = nn.BatchNorm1d


def a():
    # A fresh LeakyReLU(0.2) per call so no activation module is shared.
    return nn.LeakyReLU(.2)


# Code discriminator: MLP over latent vectors ending in one sigmoid unit.
C = nn.Sequential(
    nn.Linear(latent_dim, h),
    norm(h),
    a(),
    nn.Linear(h, h),
    norm(h),
    a(),
    nn.Linear(h, h),
    norm(h),
    a(),
    nn.Linear(h, 1),
    nn.Sigmoid(),
)

# Xavier-uniform init (gain 2) for every Linear layer, i.e. positions
# 0, 3, 6 and 9 of the Sequential.
for layer in C:
    if isinstance(layer, nn.Linear):
        init.xavier_uniform(layer.weight, 2)

# Shape sanity check on one random batch of latent codes.
t = Variable(torch.randn(batch_size, latent_dim))
assert C(t).size() == (batch_size, 1)

In [98]:
# Combine the encoder, generator, image discriminator and code discriminator
# built above into one AlphaGAN model.
# NOTE(review): the meaning of `lambd` is defined in the alphagan module,
# which is not visible here.
model = alphagan.AlphaGAN(E, G, D, C, latent_dim, lambd=10)

In [99]:
# Every diagnostics table handed to log_fn, in call order.
diag = []


def log_fn(d):
    """Convert ``d`` to a DataFrame, append it to ``diag`` and print it."""
    frame = pd.DataFrame(d)
    diag.append(frame)
    print(frame)

In [None]:
# Start training. The semantics of disc_iters / ae_iters / n_epochs are
# defined by alphagan.AlphaGAN.fit (not visible here); log_fn is presumably
# invoked with the diagnostics that end up in `diag` -- TODO confirm.
model.fit(
    X_train, X_test,
    log_fn = log_fn,
    disc_iters=16, ae_iters=16,
    n_epochs=40
)

In [None]:
import itertools

fig, ax = plt.subplots(1, 1, figsize=(14, 8))
# One row per logged table: each table is stacked into a single column
# labelled with its position in `diag`, then transposed into a row.
diagnostic = pd.concat([pd.DataFrame(d.stack(), columns=[i]).T for i, d in enumerate(diag)])
cols = list('rgbcmy')
# Hand out colours from the end of the palette first and wrap around, so a
# run that logs more than len(cols) components reuses colours instead of
# raising IndexError on an exhausted list.
palette = itertools.cycle(reversed(cols))
colors = defaultdict(lambda: next(palette))
for c in diagnostic:
    # Column keys are (component, dataset) pairs.
    component, dataset = c
    kw = {}
    if dataset == 'valid':
        # Only the 'valid' curve carries the legend label ...
        kw['label'] = component
    else:
        # ... every other dataset is drawn dashed in the same colour.
        kw['ls'] = '--'
    ax.plot(diagnostic[c].values, c=colors[component], **kw)
ax.legend(bbox_to_anchor=(1, 0.7))

In [None]:
# samples
# model(16, mode='sample') returns a pair; presumably latent codes and the
# 16 images generated from them -- TODO confirm against alphagan.
z, x = model(16, mode='sample')
sample_grid = make_grid(x.data, normalize=True)
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
# Reorder the grid's axes (0, 1, 2) -> (1, 2, 0) before handing it to imshow.
ax.imshow(sample_grid.numpy().transpose(1, 2, 0), interpolation='nearest')
# ax.imshow(make_grid(x.data, range=(0,1)).numpy().transpose(1,2,0), interpolation='nearest')

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
# training reconstructions
# Slice the index tensor before indexing the data: this selects the same 12
# images without first building a permuted copy of the whole training set.
x = cifar[train_idxs[:12]]
z, x_rec = model(x)
# Inputs followed by their reconstructions, 12 images per grid row.
recon_grid = make_grid(torch.cat((x, x_rec.data)), nrow=12, normalize=True)
ax.imshow(recon_grid.numpy().transpose(1, 2, 0), interpolation='nearest')

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
# test reconstructions
# The first 12 images of the sampled evaluation subset.
x = cifar_test[test_idxs[:12]]
z, x_rec = model(x)
# Inputs followed by their reconstructions, 12 images per grid row.
recon_grid = make_grid(torch.cat((x, x_rec.data)), nrow=12, normalize=True)
ax.imshow(recon_grid.numpy().transpose(1, 2, 0), interpolation='nearest')