A notebook to test the idea of distributional autoencoders, i.e., autoencoders that take a distribution as input and map it to a representation that preserves the information in the distribution. In practice, a distribution is represented by a finite collection of points drawn from it (an empirical distribution).

In [None]:
# Render matplotlib figures inline, and automatically reload imported modules
# (e.g. the kbrgan package) after their source files are edited.
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Optional: switch inline figures to vector output.
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kbrgan
import kbrgan.kernel as kernel
import kbrgan.glo as glo
import kbrgan.main as main
import kbrgan.plot as plot
import kbrgan.embed as embed
import kbrgan.net as net
import kbrgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
# font options: a larger base font size and thicker lines for readable figures
font = {'size': 18}
matplotlib.rcParams.update({
    'font.size': font['size'],
    'lines.linewidth': 2,
    # embed Type 42 (TrueType) fonts in exported PDF/PS files
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Device selection: set `prefer_cuda = False` to force the CPU even when a GPU exists.
prefer_cuda = True
use_cuda = prefer_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
tensor_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Seed the RNGs used below so that Restart & Run All is reproducible:
# numpy drives np.random.choice and scipy.stats draws; torch drives weight
# initialization and DataLoader shuffling.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

## CIFAR10 data

In [None]:
import kbrgan.cifar10.util as cifar10_util

In [None]:
# load data: the CIFAR10 training split, downloaded (if absent) into the folder
# given by glo.data_file('cifar10'). ToTensor yields float CHW tensors in [0, 1].
trdata_folder = glo.data_file('cifar10')
to_tensor = transforms.Compose([transforms.ToTensor()])
trdata = torchvision.datasets.CIFAR10(
    trdata_folder,
    train=True,
    download=True,
    transform=to_tensor,
)

In [None]:
# pixel intensity range of one transformed training image
stats.describe(trdata[2][0].numpy().ravel())

In [None]:
# see the data
# Newer torchvision releases renamed CIFAR10.train_data to .data; accept either.
train_images = trdata.data if hasattr(trdata, 'data') else trdata.train_data
ntr = train_images.shape[0]
# per-image array shape, stored as (height, width, channels)
img_size = train_images.shape[1:]

# randomly select a few images together with their class labels
k = 3*8
inds = np.random.choice(ntr, size=k, replace=False)
samples = [trdata[i] for i in inds]
xs = [img for img, _ in samples]
ys = [label for _, label in samples]

print('{} randomly chosen images:'.format(k))
plot.show_torch_imgs(xs, figsize=(10, 6))

In [None]:
# Show the class/index correspondence returned by kbrgan's CIFAR10 helper
# (its exact format is defined in kbrgan.cifar10.util, not visible here).
print('CIFAR10 classes and their indices:')
cifar10_class_inds = cifar10_util.label_class_list()
display(cifar10_class_inds)

In [None]:
# Pick only some classes for simplicity.
# Alternatives: [1, 9], or list(range(10)) for all classes.
classes = [1]

# numpy arrays. Newer torchvision releases renamed CIFAR10.train_data /
# .train_labels to .data / .targets; accept either naming.
X = trdata.data if hasattr(trdata, 'data') else trdata.train_data
Y = np.array(trdata.targets if hasattr(trdata, 'targets') else trdata.train_labels)

# keep only the examples whose label is one of `classes`
tr_inds = np.isin(Y, classes)
Xtr = X[tr_inds]
Ytr = Y[tr_inds]
assert len(Ytr) > 0, 'no training examples found for classes {}'.format(classes)

In [None]:
# Rescale the raw uint8 pixels (0..255) linearly into [minmax[0], minmax[1]].
# The feature extractor below maps this same range to [-1, 1], so the two must agree.
minmax = (0.0, 1.0)
assert Xtr.dtype == np.uint8, 'expected raw uint8 CIFAR10 pixels, got {}'.format(Xtr.dtype)
Xtr_scaled = Xtr.astype(np.float32) / 255.0 * (minmax[1] - minmax[0]) + minmax[0]

# Images are stored as (N, H, W, C); convolutions expect (N, C, H, W).
Tr = torch.utils.data.TensorDataset(
    torch.tensor(Xtr_scaled.transpose(0, 3, 1, 2), device='cpu', dtype=torch.float),
    torch.tensor(Ytr, device='cpu', dtype=torch.float),
)

batch_size = 2**8
# drop_last=True keeps every minibatch at exactly `batch_size` images
train_loader = torch.utils.data.DataLoader(Tr, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# see the data of the selected classes

# draw a random subset of the filtered images along with their labels
k = 4*8
nTr = len(Tr)
inds = np.random.choice(nTr, size=k, replace=False)
picked = [Tr[i] for i in inds]
xs = [pair[0] for pair in picked]
ys = [pair[1] for pair in picked]

print('{} randomly chosen images:'.format(k))
plot.show_torch_imgs(xs, figsize=(12, 6), normalize=False)

In [None]:
# pixel intensity range of one image from the filtered dataset
stats.describe(Tr[8][0].numpy().ravel())

In [None]:
import argparse
# Script-style configuration. parse_args receives an explicit list so that it
# does not try to parse the Jupyter kernel's own command line.
parser = argparse.ArgumentParser(description='Train a DCGAN on CIFAR10')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
parser.add_argument('--prob_model_dir', type=str, help='path to the model directory')
parser.add_argument('--classes', type=int, help='a list of integers (0-9) denoting the classes to consider', nargs='+')

args = parser.parse_args(['--n_epochs', '2'])
vars(args)

In [None]:
# Scratch check: concatenate a list of integers into one string (evaluates to '231').
''.join(map(str, [2,3,1]))

In [None]:
# Scratch check: sorted() returns a new ascending list (evaluates to [2, 3, 6]).
sorted([3,2,6])

## Train a distributional autoencoder

In [None]:
class Extractor1(net.SerializableModule):
    """Convolutional feature extractor mapping a batch of images to flat feature vectors.

    Four 3x3 stride-2 convolutions, each followed by LeakyReLU(0.2) and (except
    the first) BatchNorm, take a 32x32 input down to a 96x2x2 map. A final 2x2
    max-pool with stride 1 reduces it to 96x1x1, which is flattened.

    Args:
        channels: number of input image channels.
        minmax: (min, max) value range of the input pixels. ``forward`` maps this
            range linearly onto [-1, 1] before the convolutions.

    Raises:
        ValueError: if ``minmax[0] == minmax[1]`` (the rescaling would divide by zero).
    """

    def __init__(self, channels=3, minmax=(0.0, 1.0)):
        super(Extractor1, self).__init__()
        if minmax[1] == minmax[0]:
            raise ValueError('minmax must span a non-empty range, got {}'.format(minmax))
        self.minmax = minmax

        def conv_block(in_filters, out_filters, bn=True):
            # 3x3 conv, stride 2, padding 1: halves the spatial size.
            block = [
                nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            if bn:
                # BatchNorm2d's second positional argument is eps, not momentum;
                # use the default eps instead of passing 0.8 positionally.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *conv_block(channels, 16, bn=False),  # 32x32 -> 16x16
            *conv_block(16, 32),                  # 16x16 -> 8x8
            *conv_block(32, 64),                  # 8x8   -> 4x4
            *conv_block(64, 96),                  # 4x4   -> 2x2
            nn.MaxPool2d(kernel_size=2, stride=1, padding=0),  # 2x2 -> 1x1
        )

    def forward(self, img):
        """Return a (batch, n_features) tensor of features for an image batch.

        ``img`` is presumably (batch, channels, H, W) with values in ``self.minmax``.
        """
        mi, ma = self.minmax[0], self.minmax[1]
        # linearly rescale [mi, ma] onto [-1, 1]
        img = (img - mi) / float(ma - mi) * 2.0 - 1
        out = self.model(img)
        return out.view(out.shape[0], -1)


Hyperparameters for training.

In [None]:
# number of epochs
n_epochs = 80

# a function to return the number of points to draw to construct
# an empirical distribution. Points are drawn from a minibatch.
# Sizes are 1 + Poisson(5): always positive integers, with mean 6.
func_subbatch_size = lambda n: 1+stats.poisson.rvs(mu=5, size=n)

# create a network; its input range must match the range the data was scaled to
network = Extractor1(channels=3, minmax=minmax)
network = network.to(device)

# output dimension of the network, probed on a single training image.
# eval() + no_grad() stop the probe from updating BatchNorm running statistics
# or building an autograd graph; train() then restores training mode.
network.eval()
with torch.no_grad():
    output_dim = network(Tr[[0]][0].to(device)).shape[1]
network.train()

# optimizer
learning_rate = 1e-2
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

# number of times to sample empirical distributions per minibatch
n_sample_per_minibatch = 10

# weight of the orthogonality penalty in the training loss
reg = 1e-2

print('output dimension: ', output_dim)

In [None]:
Iden = torch.eye(output_dim, dtype=torch.float, device=device)
list_losses = []
# training
network.train()
for epoch in range(n_epochs):
    for batch_idx, (batch, _) in enumerate(train_loader):
        BX = batch.to(device)
        subbatch_sizes = func_subbatch_size(n_sample_per_minibatch)
        # features of every image in the minibatch: (batch_size, output_dim)
        BY = network(BX)
        # minibatch mean embedding
        batch_embed = torch.mean(BY, dim=0)
        # orthogonality penalty: squared Frobenius distance of BY^T BY from the identity
        ortho_penalty = torch.sum((BY.t().mm(BY) - Iden)**2)

        # Loss: squared distance between each random sub-batch's mean embedding
        # and the minibatch mean embedding, averaged over the sub-batches.
        # Sub-batch features are rows of BY rather than a new forward pass:
        # in train mode BatchNorm would normalize each small sub-batch with its
        # own statistics, so the two embeddings would not share a feature map.
        minibatch_loss = 0
        for subbatch_size in subbatch_sizes:
            subbatch_inds = np.random.choice(BX.shape[0], subbatch_size, replace=False)
            # subbatch mean embedding
            sub_embed = torch.mean(BY[subbatch_inds], dim=0)
            sub_loss = torch.sum((batch_embed - sub_embed)**2)
            minibatch_loss += sub_loss/float(n_sample_per_minibatch)

        minibatch_loss += reg*ortho_penalty
        optimizer.zero_grad()
        minibatch_loss.backward()
        # update the parameters
        optimizer.step()
        list_losses.append(minibatch_loss.item())


In [None]:
# Skip the first `n_burnin` updates (the initial transient) when plotting;
# fall back to the whole curve if training produced fewer updates than that.
n_burnin = 50
losses = np.array(list_losses)
start = n_burnin if len(losses) > n_burnin else 0
inds = np.arange(start, len(losses))

fig, ax = plt.subplots()
ax.plot(inds, losses[inds], label='tr-loss')
ax.set_xlabel('#minibatch update')
ax.set_ylabel('Loss')
ax.set_title('Training loss')
ax.legend();

In [None]:
# Orthogonality penalty tensor left over from the last minibatch of the training loop.
ortho_penalty