In [1]:
import os

# Root directory holding the torchvision CIFAR-100 files. Set the
# CIFAR_DATA_DIR environment variable to point elsewhere instead of editing
# this machine-specific default.
data_dir = os.environ.get('CIFAR_DATA_DIR', '/home/victor/data-ssd')

In [2]:
import sys
sys.path.append('../')

%load_ext autoreload
%autoreload 1
%aimport alphagan

In [47]:
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
%matplotlib inline

In [45]:
# Per-channel normalization statistics for CIFAR-100, shared by both splits.
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2673, 0.2564, 0.2762]


def load_cifar100_images(root, train):
    """Load one CIFAR-100 split from `root` as a single stacked image tensor.

    Each image goes through ToTensor and is standardized with
    CIFAR100_MEAN / CIFAR100_STD. Labels are discarded. The files must already
    exist under `root` (download=False).

    Args:
        root: directory passed to torchvision.datasets.CIFAR100.
        train: True for the training split, False for the test split.

    Returns:
        Tensor of shape (N, 3, 32, 32), one row per image in dataset order.
    """
    dataset = datasets.CIFAR100(
        root,
        train=train,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
        ]),
        target_transform=None,
        download=False)
    return torch.stack([image for image, _label in dataset])


cifar = load_cifar100_images(data_dir, train=True)
cifar_test = load_cifar100_images(data_dir, train=False)
assert cifar.size()[1:] == cifar_test.size()[1:] == (3, 32, 32)
cifar_test.size()

torch.Size([10000, 3, 32, 32])

In [90]:
# Seed everything stochastic so the subset choice, DataLoader shuffling and the
# weight initialization in later cells are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
rng = np.random.RandomState(SEED)

batch_size = 128
# Train on the full training split. The held-out subset is kept to a single
# batch of test images (presumably used for quick evaluation inside fit).
n_train = len(cifar)
n_test = 128

train_idxs = torch.LongTensor(rng.permutation(len(cifar))[:n_train])
X_train = DataLoader(cifar[train_idxs], batch_size=batch_size, shuffle=True)
test_idxs = torch.LongTensor(rng.permutation(len(cifar_test))[:n_test])
X_test = DataLoader(cifar_test[test_idxs], batch_size=batch_size, shuffle=False)

In [91]:
class ChannelsToLinear(nn.Linear):
    """nn.Linear applied to a feature map whose two trailing dims are size 1.

    Intended for inputs like (N, C, 1, 1): the trailing singleton dims are
    squeezed away, so the result is (N, out_features).
    """

    def forward(self, x):
        flattened = x.squeeze(-1).squeeze(-1)
        return super().forward(flattened)


class LinearToChannels(nn.Linear):
    """nn.Linear whose (N, out_features) output is reshaped to (N, out_features, 1, 1)."""

    def forward(self, x):
        projected = super().forward(x)
        return projected.unsqueeze(-1).unsqueeze(-1)

In [92]:
# Width of the latent code: output size of the encoder, input size of the
# generator and of the code discriminator C.
latent_dim = 32

In [93]:
k = 3  # kernel size; stride 1 with padding k//2 preserves spatial size
# Encoder: 3x32x32 image -> latent_dim vector. Average pooling shrinks the map
# 32 -> 16 -> 8 -> 4, and the final AvgPool2d(4) reduces 4x4 -> 1x1 so
# ChannelsToLinear can map the 512 channels to the latent code.
encoder = nn.Sequential(
    nn.Conv2d(  3,  16, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 16,  32, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 32,  64, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 64, 512, k, 1, k//2), nn.AvgPool2d(4), nn.ReLU(),
    ChannelsToLinear(512, latent_dim)
)
# Xavier-initialize every weighted layer. Selecting by type instead of by
# position (the old `i % 3 == 0`) picks the same layers but does not break when
# layers are added or reordered. ChannelsToLinear subclasses nn.Linear.
# NOTE(review): gain=2 exceeds init.calculate_gain('relu') == sqrt(2); confirm intended.
for layer in encoder:
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        init.xavier_uniform(layer.weight, 2)

In [94]:
k = 5  # kernel size of the shape-preserving convolutions (padding k//2)
# Generator: latent_dim vector -> 3x32x32 image.
#   LinearToChannels          -> 512 x 1 x 1
#   ConvTranspose2d(4, s=1)   ->  64 x 4 x 4
#   ConvTranspose2d(2, s=2)   ->  32 x 8 x 8
#   Conv2d(k, pad k//2)       ->  32 x 8 x 8
#   ConvTranspose2d(2, s=2)   ->  16 x 16 x 16
#   ConvTranspose2d(2, s=2)   ->   8 x 32 x 32
#   Conv2d(k, pad k//2)       ->   3 x 32 x 32
# There is no output nonlinearity, so pixel values are unbounded, like the
# mean/std-standardized training images.
generator = nn.Sequential(
    LinearToChannels(latent_dim, 512), nn.ReLU(),
    nn.ConvTranspose2d(512, 64, 4, 1), nn.ReLU(),
    nn.ConvTranspose2d(  64, 32, 2, 2), nn.ReLU(),
    nn.Conv2d( 32, 32, k, 1, k//2), nn.ReLU(),
    nn.ConvTranspose2d(  32, 16, 2, 2), nn.ReLU(),
    nn.ConvTranspose2d(  16,  8, 2, 2), nn.ReLU(),
    nn.Conv2d(  8,  3, k, 1, k//2),
)
# Xavier-initialize every weighted layer, selected by type rather than by the
# fragile `i % 2 == 0` position rule (same layers today). ConvTranspose2d is not
# a Conv2d subclass, so it is listed explicitly. LinearToChannels subclasses nn.Linear.
# NOTE(review): gain=2 exceeds init.calculate_gain('relu') == sqrt(2); confirm intended.
for layer in generator:
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        init.xavier_uniform(layer.weight, 2)

In [95]:
k = 3  # kernel size; stride 1 with padding k//2 preserves spatial size
# Image discriminator: 3x32x32 image -> one sigmoid score in (0, 1).
# Same conv/pool trunk as the encoder (32 -> 16 -> 8 -> 4 -> 1), then a linear
# head to a single unit.
D = nn.Sequential(
    nn.Conv2d(  3,  16, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 16,  32, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 32,  64, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 64, 512, k, 1, k//2), nn.AvgPool2d(4), nn.ReLU(),
    ChannelsToLinear(512, 1), nn.Sigmoid()
)
# Xavier-initialize every weighted layer, selected by type rather than by the
# fragile `i % 3 == 0` position rule (same layers today). ChannelsToLinear
# subclasses nn.Linear.
# NOTE(review): gain=2 exceeds init.calculate_gain('relu') == sqrt(2); confirm intended.
for layer in D:
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        init.xavier_uniform(layer.weight, 2)

In [96]:
# Latent-code discriminator (presumably alpha-GAN's code discriminator):
# a latent_dim vector -> one sigmoid score in (0, 1), through a two-hidden-layer
# MLP of widths C_h and C_h // 2. It keeps PyTorch's default initialization.
C_h = 512
C = nn.Sequential(
    nn.Linear(latent_dim, C_h),
    nn.ReLU(),
    nn.Linear(C_h, C_h // 2),
    nn.ReLU(),
    nn.Linear(C_h // 2, 1),
    nn.Sigmoid(),
)

In [97]:
# Smoke test: push one random batch through each network and check the output
# shape before starting a long training run. The failure message names the
# offending network.
shape_checks = [
    ('encoder',   encoder,   (batch_size, 3, 32, 32),  (batch_size, latent_dim)),
    ('generator', generator, (batch_size, latent_dim), (batch_size, 3, 32, 32)),
    ('D',         D,         (batch_size, 3, 32, 32),  (batch_size, 1)),
    ('C',         C,         (batch_size, latent_dim), (batch_size, 1)),
]
for net_name, net, in_shape, out_shape in shape_checks:
    got = net(Variable(torch.randn(*in_shape))).size()
    assert got == out_shape, '{}: expected output {}, got {}'.format(
        net_name, out_shape, tuple(got))

In [None]:
# Assemble the alpha-GAN from encoder, generator, image discriminator D and
# latent-code discriminator C.
# NOTE(review): lambd=10 is presumably the weight of the reconstruction term in
# alphagan.AlphaGAN; confirm against the module.
model = alphagan.AlphaGAN(encoder, generator, D, C, latent_dim, lambd=10)

In [None]:
# Every table handed to log_fn is kept here for inspection after training.
# NOTE: re-running this cell clears the log but does not rebuild `model`, so
# fit presumably continues from the current weights.
diagnostic = []


def log_fn(d):
    """Logging callback for model.fit: wrap `d` in a DataFrame, store it, print it."""
    frame = pd.DataFrame(d)
    diagnostic.append(frame)
    print(frame)


model.fit(X_train, X_test, log_fn=log_fn, n_epochs=20)




In [None]:
# Draw 16 samples from the model and show them as one image grid.
# NOTE(review): assumes model(n, mode='sample') returns (latent codes, images)
# with images shaped (n, 3, 32, 32); confirm in alphagan.
z, x = model(16, mode='sample')
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
# The images are mean/std standardized, not in [0, 1]. normalize=True
# min-max rescales the whole grid for display. .cpu() is needed before .numpy()
# if the model returns GPU tensors (it is a no-op on CPU).
grid = make_grid(x.data.cpu(), normalize=True)
ax.imshow(grid.numpy().transpose(1, 2, 0), interpolation='nearest')
ax.set_title('Samples from the generator')
ax.axis('off');

In [None]:
# Training reconstructions: top row shows n_show training images, bottom row
# shows the reconstructions returned by model(x).
n_show = 12
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
# Slice the index tensor first. `cifar[train_idxs][:n_show]` selects the same
# images but first copies the whole training set.
x = cifar[train_idxs[:n_show]]
z, x_rec = model(x)
# Originals and reconstructions share one grid, so normalize=True rescales them
# jointly. .cpu() keeps torch.cat/.numpy() working if the model runs on a GPU.
grid = make_grid(torch.cat((x, x_rec.data.cpu())), nrow=n_show, normalize=True)
ax.imshow(grid.numpy().transpose(1, 2, 0), interpolation='nearest')
ax.set_title('Training images (top) and reconstructions (bottom)')
ax.axis('off');

In [None]:
# Test reconstructions: top row shows n_show held-out images, bottom row shows
# the reconstructions returned by model(x).
n_show = 12
fig, ax = plt.subplots(1, 1, figsize=(16, 4))
# Slice the index tensor first instead of materializing cifar_test[test_idxs].
x = cifar_test[test_idxs[:n_show]]
z, x_rec = model(x)
# Originals and reconstructions share one grid, so normalize=True rescales them
# jointly. .cpu() keeps torch.cat/.numpy() working if the model runs on a GPU.
grid = make_grid(torch.cat((x, x_rec.data.cpu())), nrow=n_show, normalize=True)
ax.imshow(grid.numpy().transpose(1, 2, 0), interpolation='nearest')
ax.set_title('Test images (top) and reconstructions (bottom)')
ax.axis('off');