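# Exploring generative models for CIFAR10 (or subsets of it)

This notebook loads a pretrained DCGAN generator for selected CIFAR10 classes, draws samples from it, and interpolates between two points of its latent space.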

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload edited modules (e.g. the cadgan package) before each cell runs.
%load_ext autoreload
%autoreload 2

# Optional: switch the inline figure format to vector graphics (uncomment one).
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import cadgan
import cadgan.kernel as kernel
import cadgan.glo as glo
import cadgan.main as main
import cadgan.plot as plot
import cadgan.embed as embed
import cadgan.net as net
import cadgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
# Global matplotlib style for every figure in this notebook: 18pt text,
# 2pt-wide lines, and TrueType (Type 42) fonts embedded in PDF/PS exports
# instead of Type 3 fonts.
font = {'size': 18}
plt.rc('font', **font)
plt.rc('lines', linewidth=2)
for backend_key in ('pdf.fonttype', 'ps.fonttype'):
    matplotlib.rcParams[backend_key] = 42

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
# (Replace the right-hand side with False to force CPU execution.)
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    tensor_type = torch.cuda.FloatTensor
else:
    device = torch.device("cpu")
    tensor_type = torch.FloatTensor
# torch.set_default_tensor_type(tensor_type)

## CIFAR10 data

In [None]:
# Helpers for CIFAR10 class labels and class-subset loading.
import cadgan.cifar10.util as cifar10_util

# List the CIFAR10 classes with their indices so the `classes` selection in
# the next cell can be chosen.
print('CIFAR10 classes and their indices:')
cifar10_class_inds = cifar10_util.label_class_list()
display(cifar10_class_inds)

In [None]:
# CIFAR10 class indices to work with (e.g. [0] for the first class only);
# `Tr` holds the data of exactly these classes.
classes = [1]

Tr = cifar10_util.load_cifar10_class_subsets(classes)

In [None]:
# Show a random selection of the data of the selected classes.
# Tr[i][0] is displayed as an image and Tr[i][1] is its class label.
k = 3*8          # number of images to display
show_seed = 1    # fixed for a reproducible selection; set to None for a fresh draw
nTr = Tr.tensors[0].shape[0]
# Sampling without replacement fails when more images are requested than exist.
n_show = min(k, nTr)
rng = np.random.default_rng(show_seed)
inds = rng.choice(nTr, size=n_show, replace=False)

# Index the dataset once per selected item, then split images from labels.
items = [Tr[i] for i in inds]
xs = [item[0] for item in items]
ys = [item[1] for item in items]  # class labels

print('{} randomly chosen images:'.format(n_show))
plot.show_torch_imgs(xs, figsize=(12, 6), normalize=False)

## Load generators

These generators were trained with the code in `cadgan.cifar10`. For instance, running `cadgan.cifar10.dcgan` trains a DCGAN generator.

In [None]:
# Checkpoint naming: the DCGAN trained on the selected classes with these
# settings, e.g. classes=[1] -> cifar10_c1-dcgan-ep300_bs32.pt
n_epochs, batch_size = 300, 32
class_summary = ''.join(str(c) for c in classes)

folder_name = 'cifar10_c{}-dcgan'.format(class_summary)
model_name = '{}-ep{}_bs{}.pt'.format(folder_name, n_epochs, batch_size)
model_path = glo.prob_model_folder(folder_name, model_name)

print('model path: ', model_path)

In [None]:
# Google Drive file ids of pretrained checkpoints, keyed by checkpoint file name.
# Keying by name guarantees a downloaded file is never saved under the name of a
# different model (e.g. a class-1 model stored as a class-0 checkpoint).
known_model_drive_ids = {
    'cifar10_c1-dcgan-ep300_bs32.pt': '1vst7cCckUaIKNbYxVMNhgN70pMjXP85e',
}

if not os.path.exists(model_path):
    # To train the model manually instead, run e.g.
    #!python ../cadgan/cifar10/dcgan.py --classes=1 --n_epochs=300 --batch_size=32
    if model_name not in known_model_drive_ids:
        raise FileNotFoundError(
            'No checkpoint at {} and no known download for {}. Train it with '
            'cadgan/cifar10/dcgan.py using matching --classes/--n_epochs/--batch_size, '
            'or add its Google Drive file id to known_model_drive_ids.'.format(
                model_path, model_name))

    # Optional third-party dependency, only needed when downloading.
    from google_drive_downloader import GoogleDriveDownloader as gdd
    gdd.download_file_from_google_drive(file_id=known_model_drive_ids[model_name],
                                        dest_path=model_path)

In [None]:
# load the model
# The checkpoint appears to hold the whole pickled generator module (the result
# is used directly as a model below), not just a state_dict. Unpickling therefore
# needs the generator class to be importable; these imports are kept for that.
# NOTE(review): presumably the pickle refers to PatsornGenerator1 — confirm.
import cadgan.cifar10.dcgan as cifar10_dcgan
from cadgan.cifar10.dcgan import PatsornGenerator1
# map_location returns each storage unchanged, i.e. keeps it on the CPU where it
# was deserialized, so a checkpoint saved on a GPU also loads on a CPU-only machine.
# SECURITY: torch.load unpickles arbitrary objects and can execute code; only load
# checkpoints from a trusted source (this file may have been downloaded from
# Google Drive by the previous cell).
# NOTE(review): recent PyTorch versions default torch.load to weights_only=True,
# which rejects whole-module pickles; pass weights_only=False there if loading fails.
generator = torch.load(model_path, map_location=lambda storage, loc: storage)
print(generator)

## Sample from the generator

In [None]:
# Draw n_sample random images from the generator in eval mode, without
# tracking gradients, and display them.
n_sample = 4*8
with torch.no_grad():
    gen_eval = generator.eval()
    Xsam = gen_eval.sample(n_sample)

print('{} sampled images from the model'.format(n_sample))
plot.show_torch_imgs(Xsam, figsize=(12, 6), normalize=False)

## Interpolation in the latent space

Decode evenly spaced points on the straight line between two random noise vectors.

In [None]:
# Draw the two endpoint noise vectors (one sample each) for the
# latent-space interpolation, using the generator in eval mode.
with torch.no_grad():
    gev = generator.eval()
    z_from, z_to = gev.sample_noise(1), gev.sample_noise(1)

In [None]:
# Linear interpolation in latent space: the first image decodes z_from and the
# last decodes z_to.
n_interp = 4*8  # number of interpolation steps, including both endpoints

# Weights a_i go 0 -> 1. Build them on the noise's device/dtype so the matrix
# product also works when the generator (and its noise) lives on the GPU.
a = torch.linspace(0, 1, n_interp, device=z_from.device, dtype=z_from.dtype)
A = torch.stack((1.0 - a, a), 1)  # row i: (1 - a_i, a_i) -> z_from at i=0, z_to at the end

# Flatten each endpoint so this also handles noise that is not 2-D
# (e.g. shaped (1, nz, 1, 1)), then restore the per-sample noise shape.
Z_start = torch.cat((z_from, z_to)).reshape(2, -1)
Z = A.mm(Z_start).reshape((n_interp,) + tuple(z_from.shape[1:]))

# sample from the generator using the interpolated noise vectors
with torch.no_grad():
    Xsam = gev(Z)

# show the images
# NOTE(review): unlike the sampling cell, this relies on show_torch_imgs'
# default `normalize` setting — confirm that is intended.
plot.show_torch_imgs(Xsam, figsize=(12, 6))