In [None]:
import numpy as np
import torch
import torchvision
from torchvision.datasets import CIFAR10, CIFAR100

from connectivity_representation_learning import *

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

### Load CIFAR-10 data

In [None]:
def imshow(img, ax=None):
    """Draw an image tensor with matplotlib.

    Parameters
    ----------
    img : torch.Tensor
        Image (or image grid) tensor. Its axes are reordered with
        ``(1, 2, 0)`` before plotting, so it is presumably laid out
        channels-first -- e.g. the output of ``torchvision.utils.make_grid``.
    ax : matplotlib axes, optional
        Axes to draw on. When ``None`` (the default) the image is drawn
        through ``plt.imshow``.
    """
    # Drop autograd tracking and move to host memory before handing to NumPy.
    npimg = img.detach().cpu().numpy()
    # Move the leading axis last, the layout imshow expects for colour images.
    channels_last = np.transpose(npimg, (1, 2, 0))
    # Compare against None explicitly: relying on truthiness would silently
    # ignore any axes-like object that happens to evaluate as falsy.
    if ax is None:
        plt.imshow(channels_last)
    else:
        ax.imshow(channels_last)

In [None]:
# The only preprocessing step is the conversion of each image to a tensor.
cifar_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

# Training split, read from a local copy under ../datasets
# (download=False: the files are expected to be there already).
cifar = CIFAR10(
    '../datasets', train=True, transform=cifar_transform, download=False
)

In [None]:
# Only a subset of the dataset is used, to speed up learning.
n = 1000
# shuffle=False with batch_size=n: the first batch is the first n images
# of `cifar`, in dataset order.
loader = torch.utils.data.DataLoader(cifar, batch_size=n, shuffle=False)
dataiter = iter(loader)
# Use the built-in next(): the Python-2-style `.next()` method is not
# available on DataLoader iterators in recent PyTorch releases.
data, _ = next(dataiter)

### Model setup

In [None]:
# Architecture description handed to `Model` as its first argument.
config_layers = dict(
    type='conv2d',
    input_size=(32, 32),
    filters=[3, 16, 32, 64],
    emb_size=160,
)

# Hardware switch forwarded to `Model`.
# NOTE(review): hard-coded; presumably requires a CUDA device -- confirm.
use_cuda = True

# Hyperparameters forwarded to every `Model` built in the training cell.
eta = 2.0
tol = 1e-4
lr = 0.001
batch_size = 10

# Base number of epochs passed to `Model.train`.
n_epochs = 10

### Training

In [None]:
# Settings common to the three models; only `dim_batch` and
# `connectivity_penalty` vary from one model to the next.
shared_kwargs = dict(
    eta=eta,
    tol=tol,
    lr=lr,
    batch_size=batch_size,
    use_cuda=use_cuda,
)

# 1) No connectivity penalty. Trained for ten times as many epochs as the
#    penalised models below.
model_no_penalty = Model(
    config_layers, dim_batch=1, connectivity_penalty=0.0, **shared_kwargs
)
model_no_penalty.train(data, n_epochs * 10)
torch.save(model_no_penalty.state_dict(), 'cifar10_no_penalty.pt')

# 2) Connectivity penalty of 20, dim_batch=1.
model_penalty = Model(
    config_layers, dim_batch=1, connectivity_penalty=20.0, **shared_kwargs
)
model_penalty.train(data, n_epochs)
torch.save(model_penalty.state_dict(), 'cifar10_penalty.pt')

# 3) Same penalty as (2), but with dim_batch=16.
model_penalty_branches = Model(
    config_layers, dim_batch=16, connectivity_penalty=20.0, **shared_kwargs
)
model_penalty_branches.train(data, n_epochs)
torch.save(model_penalty_branches.state_dict(), 'cifar10_penalty_branches.pt')

In [None]:
# Uncomment to restore previously saved weights into the existing model
# objects instead of using the freshly trained ones.
# model_no_penalty.load_state_dict(torch.load('cifar10_no_penalty.pt'))
# model_penalty.load_state_dict(torch.load('cifar10_penalty.pt'))
# model_penalty_branches.load_state_dict(torch.load('cifar10_penalty_branches.pt'))

# One model per figure row, in this order.
models = [
    model_no_penalty,
    model_penalty,
    model_penalty_branches,
]

### Reconstruction of training images

In [None]:
# One row per model: inputs on the left, reconstructions on the right.
fig, axess = plt.subplots(figsize=(12, 5), nrows=3, ncols=2)

# Every row shows the same four training images, so slice and tile them once.
images = data[:4]
originals_grid = torchvision.utils.make_grid(images)

for row, (ax_original, ax_reconstructed) in enumerate(axess):
    model = models[row]

    imshow(originals_grid, ax_original)

    # Run the autoencoder on whichever device the model lives on.
    images_reconstructed = model.autoencoder(images.to(model.device))
    imshow(torchvision.utils.make_grid(images_reconstructed), ax_reconstructed)

plt.tight_layout()
plt.show()

### Reconstruction of unseen images

In [None]:
# `data` holds the first len(data) images of `cifar` (taken in dataset order),
# so sampling only from the remaining indices guarantees the images shown
# here were not part of the training subset.
unseen = torch.utils.data.Subset(cifar, range(len(data), len(cifar)))
loader = torch.utils.data.DataLoader(unseen, batch_size=4, shuffle=True)
dataiter = iter(loader)
# Built-in next(): DataLoader iterators in recent PyTorch releases have no
# `.next()` method.
images, _ = next(dataiter)

# One row per model: inputs on the left, reconstructions on the right.
fig, axess = plt.subplots(figsize=(12, 5), nrows=3, ncols=2)

# All rows show the same batch of unseen images. `images` must NOT be
# reassigned from `data` here, otherwise the figure would display training
# images instead of the batch drawn above.
originals_grid = torchvision.utils.make_grid(images)

for i, (ax_original, ax_reconstructed) in enumerate(axess):
    model = models[i]

    imshow(originals_grid, ax_original)

    # Run the autoencoder on whichever device the model lives on.
    images_reconstructed = model.autoencoder(images.to(model.device))
    imshow(torchvision.utils.make_grid(images_reconstructed), ax_reconstructed)

plt.tight_layout()
plt.show()