In [1]:
data_dir = '/home/victor/data-ssd'

In [2]:
import sys
sys.path.append('../')

%load_ext autoreload
%autoreload 1
%aimport alphagan

In [3]:
import sys
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

In [4]:
class ChannelsToLinear(nn.Linear):
    def forward(self, x):
        return super().forward(x.squeeze(-1).squeeze(-1))
class LinearToChannels(nn.Linear):
    def forward(self, x):
        return super().forward(x).unsqueeze(-1).unsqueeze(-1)      

In [5]:
latent_dim = 32

In [6]:
k = 3
encoder = nn.Sequential(
    nn.Conv2d(  3,  16, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 16,  32, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 32,  64, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 64, 128, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d(128, 256, k, 1, k//2), nn.AvgPool2d(2),
    ChannelsToLinear(256, latent_dim)
)

In [9]:
generator = nn.Sequential(
    LinearToChannels(latent_dim, 1024), nn.ReLU(),
    nn.ConvTranspose2d(1024, 64, 4, 1), nn.ReLU(),
    nn.ConvTranspose2d(  64, 32, 2, 2), nn.ReLU(),
    nn.ConvTranspose2d(  32, 16, 2, 2), nn.ReLU(),
    nn.ConvTranspose2d(  16,  3, 2, 2), nn.Sigmoid(),
)

In [10]:
k = 3
D = nn.Sequential(
    nn.Conv2d(  3,  16, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 16,  32, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 32,  64, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d( 64, 128, k, 1, k//2), nn.AvgPool2d(2), nn.ReLU(),
    nn.Conv2d(128, 256, k, 1, k//2), nn.AvgPool2d(2),
    ChannelsToLinear(256, 1), nn.Sigmoid()
)

In [12]:
C_h = 512
C = nn.Sequential(
    nn.Linear(latent_dim, C_h), nn.ReLU(),
    nn.Linear(C_h, C_h), nn.ReLU(),
    nn.Linear(C_h, 1), nn.Sigmoid(),
)

In [23]:
batch_size = 128
assert encoder(Variable(torch.randn(batch_size,3,32,32))).size() == (batch_size,latent_dim)
assert generator(Variable(torch.randn(batch_size,latent_dim))).size() == (batch_size,3,32,32)
assert D(Variable(torch.randn(batch_size,3,32,32))).size() == (batch_size,1)
assert C(Variable(torch.randn(batch_size,latent_dim))).size() == (batch_size,1)

In [26]:
model = alphagan.AlphaGAN(encoder, generator, D, C, latent_dim)

In [14]:
cifar = datasets.CIFAR100(
    data_dir,
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean= [ 0.5071, 0.4865, 0.4409 ],
            std = [ 0.2673, 0.2564, 0.2762 ]
        ),
    ]),
    target_transform=None,
    download=False)
cifar = torch.stack(list(zip(*cifar))[0])
cifar.size()

In [27]:
X = DataLoader(cifar[:512], batch_size=batch_size, shuffle=True)

In [28]:
model.fit(
    X,
    log_fn = lambda x:print(x),
    n_epochs=1
)

{'train_encoder_loss': nan, 'train_generator_loss': 42.84655, 'train_D_loss': 1.6198111e-06, 'train_C_loss': nan}

