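Notebook to test kernel herding (solving the kernel moment-matching problem sequentially, one point at a time) with a GAN generator. Generated images are compared with target MNIST images through features extracted by an MNIST classifier.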

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import cadgan
import cadgan.kernel as kernel
import cadgan.glo as glo
import cadgan.main as main
import cadgan.plot as plot
import cadgan.embed as embed
import cadgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
# Global plotting style shared by every figure in this notebook.
# (Other font options tried: family='normal', weight='bold'.)
font = dict(size=18)
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Embed fonts as TrueType (Type 42) instead of Type 3 in exported PDF/PS figures.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Use the GPU when one is available (change the leading True to False to force CPU).
use_cuda = True and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Float tensors created without an explicit device (e.g. torch.randn, torch.ones)
# are allocated as this default tensor type.
# NOTE: torch.set_default_tensor_type is deprecated in recent PyTorch releases in
# favor of torch.set_default_dtype / torch.set_default_device; it is kept here for
# compatibility with older PyTorch versions.
torch.set_default_tensor_type(torch.cuda.FloatTensor if use_cuda else torch.FloatTensor)

# Fix the random seeds so the initial latent vectors drawn with torch.randn (and,
# presumably, the label-conditioned sample of target images) are reproducible
# under Restart & Run All.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

## Feature extractor for MNIST

In [None]:
# load the model
import cadgan.mnist.classify as mnist_classify
from cadgan.mnist.classify import MnistClassifier


# Build the MNIST classifier (load=True presumably restores trained weights),
# switch it to inference mode and move it to `device`. Both eval() and to()
# return the module itself, so the calls can be chained.
classifier = mnist_classify.MnistClassifier(load=True).eval().to(device)

def extractor(imgs, net=None):
    """
    Feature extractor built from the convolutional part of the MNIST classifier.

    Applies conv1 -> 2x2 max-pool -> ReLU -> conv2 -> 2x2 max-pool -> ReLU and
    flattens the result into one feature vector per image.

    imgs: Pytorch tensor holding a batch of images (first axis indexes images),
        in the layout expected by net.conv1.
    net: module providing the conv1 and conv2 layers. Defaults to the global
        `classifier` loaded in this cell.

    Return a Pytorch tensor of size imgs.shape[0] x d. For 28x28 MNIST images
        and this classifier, d = 320.
    """
    if net is None:
        net = classifier
    h = F.relu(F.max_pool2d(net.conv1(imgs), 2))
    h = F.relu(F.max_pool2d(net.conv2(h), 2))
    # Flatten per image. Unlike view(-1, 320), this keeps the batch axis intact
    # and does not hard-code the feature size, so inputs of another size cannot
    # silently have their rows merged or split across images.
    return h.reshape(h.shape[0], -1)


In [None]:
# load MNIST data
# MNIST test split, stored in (and downloaded on first use into) the folder
# given by glo.data_file('mnist'). Images are converted to tensors and
# standardized with the mean 0.1307 and standard deviation 0.3081.
mnist_folder = glo.data_file('mnist')
mnist_dataset = torchvision.datasets.MNIST(
    mnist_folder,
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
print(mnist_dataset)

In [None]:
# Sanity check: display one test image and the classifier's output for it.
xy = mnist_dataset[92]  # (image, label) pair; 92 is an arbitrary example index
x = xy[0].unsqueeze(0).to(device)  # add a batch axis of size 1

# Move the channel axis last and drop it (it has size 1 for grayscale MNIST);
# a gray colormap shows the digit as it looks.
xnp = np.transpose(xy[0].numpy(), (1, 2, 0)).squeeze()
plt.imshow(xnp, cmap='gray')
plt.title('label: {}'.format(xy[1]))

# This is the classifier's output, not the features from `extractor`.
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    print('classifier output: ', classifier(x))

## A generator for MNIST

In [None]:
import cadgan.mnist.dcgan as mnist_dcgan
import cadgan.mnist.util as mnist_util

# Load the MNIST DCGAN generator (load=True presumably restores trained weights).
g = mnist_dcgan.Generator(load=True)
# The generator is never trained in this notebook. Put it in eval mode right away:
# - the preview of the initial points below then uses the same mode as the
#   herding loop, which also calls g.eval();
# - forward passes cannot update normalization statistics, if the model has any.
g = g.eval()

# Dimension of the generator's latent (noise) space.
latent_dim = 100


def f_noise(n):
    """Draw n standard-normal latent vectors as an n x latent_dim float tensor."""
    return torch.randn(n, latent_dim).float()

## Optimize generated points to minimize the kernel moment-matching loss

The points come from a GAN generator and are optimized in its latent space, one point at a time (kernel herding; see below).

In [None]:
# Target set: (digit label, count) pairs drawn from the MNIST test split.
# Other settings tried:
#   [(1, 3), (9, 3)], [(6, 2), (8, 2)], [(1, 1), (2, 1), (3, 1), (4, 1)],
#   [(i, 5) for i in range(10)], [(0, 6), (5, 3)]
label_counts = [(6, 4)]
X = mnist_util.pt_sample_by_labels(mnist_dataset, label_counts).to(device)
n = X.shape[0]

# Uniform weights over the n target points (they sum to one).
weights = (torch.ones(n) / float(n)).to(device)
plot.show_torch_imgs(X)

In [None]:
# Kernel evaluated on the features returned by `extractor`.
# Alternatives tried:
#   kernel.PTKGauss(sigma2=50.0)
#   kernel.PTKPoly(c=1e-1, d=2)
#   kernel.PTKIMQ(c=1e+1, b=-0.5) or kernel.PTKIMQ()
# A kernel directly on the latent noise vectors was also considered:
#   kernel.PTKFuncCompose(kgauss, classifier)
k = kernel.PTKLinear()

# The features of the target points do not change during the optimization,
# so compute them once, without an autograd graph.
with torch.no_grad():
    FX = extractor(X)

In [None]:
# Size of the target batch (number of images first).
X.size()

In [None]:
# Initial latent vectors of the generated points, one row per point.
# (Other sizes tried: 2*2**3, 2*8.)
n_sample = 6
Z = f_noise(n_sample).to(device)

# Preview of the initial points in image space. Only a picture is needed here.
# The herding cells copy the rows of Z with detach().clone() before optimizing
# them, so Z itself needs no gradient and no autograd graph is built.
with torch.no_grad():
    Y0 = g(Z)
plot.show_torch_imgs(Y0, nrow=8)

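### Kernel herding with a generator

Kernel herding solves the kernel moment-matching problem sequentially: the generated points are optimized one at a time, each with the previously optimized points held fixed.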

In [None]:
# Settings for the kernel-herding run in the next cell.
#
# Inputs used there (all defined above):
#   X          target samples whose weighted kernel mean embedding is matched
#   weights    weights of the samples in X
#   g          generator; points are optimized in its latent space
#   extractor  feature extractor; the kernel is applied to its outputs
#   k          kernel (from cadgan.kernel) evaluated on extracted features
#   Z0, n_iter (below)
# The optimizer is built by fn_make_optimizer, a function
# params -> torch.optim optimizer.
#
# Outputs produced there:
#   Y       generated points (generator outputs) after optimization, one per row of Z0
#   Y0      generator outputs at the initial latent vectors
#   Z       optimized latent vectors
#   Losses  n_sample x n_iter array of herding losses

# Initial latent vectors: row i initializes generated point i, and the points
# are optimized one at a time, in this order.
Z0 = Z
# Number of optimizer steps spent on each generated point.
n_iter = 200

In [None]:
def kernel_generator_herding(X, weights, g, extractor, k, Z0,
                             fn_make_optimizer=None, n_iter=200, reg=1e-3,
                             n_sample=None):
    """
    Kernel herding in the latent space of a generator.

    Generated points are chosen one at a time. Point t (t = 1, ..., n_sample)
    is y_t = g(z_t). Here z_t starts at Z0[t-1] and is optimized with the
    earlier points y_1, ..., y_{t-1} held fixed. The loss is the part of the
    squared MMD between the uniform mixture of y_1, ..., y_t and the weighted
    samples X that depends on y_t, plus a penalty on z_t:

        -(2/t) sum_j w_j k(y_t, x_j) + (2/t^2) sum_{i<t} k(y_i, y_t)
            + (1/t^2) k(y_t, y_t) + reg * ||z_t||^2

    The kernel k is evaluated on extractor outputs.

    X: Pytorch tensor of target samples (first axis indexes samples).
    weights: Pytorch vector of weights for the samples in X.
    g: generator mapping a batch of latent vectors to a batch of samples.
        It is switched to eval mode in place.
    extractor: feature extractor applied to X and to generated samples.
    k: kernel object; k.eval(A, B) is expected to return the matrix of kernel
        values between the rows of A and the rows of B.
    Z0: Pytorch tensor of initial latent vectors; Z0[i] initializes point i+1.
        Z0 itself is not modified.
    fn_make_optimizer: function params -> torch.optim optimizer.
        Defaults to Adam with lr=5e-2.
    n_iter: number of optimizer steps for each point.
    reg: weight of the squared-norm penalty on each latent vector.
    n_sample: number of points to generate. Defaults to Z0.shape[0].

    Return (Y, Y0, Z, FY, Losses):
        Y: generated points g(Z), one per row of Z.
        Y0: generator outputs at the initial latent vectors.
        Z: optimized latent vectors, n_sample rows.
        FY: extracted features of Y.
        Losses: numpy array of size n_sample x n_iter; Losses[t-1, i] is the
            loss of point t at optimizer step i.
    None of the returned tensors carries an autograd graph.

    Raises ValueError if n_sample <= 0.
    """
    if n_sample is None:
        n_sample = Z0.shape[0]
    if n_sample <= 0:
        raise ValueError('n_sample must be > 0. Was {}'.format(n_sample))
    if fn_make_optimizer is None:
        fn_make_optimizer = lambda params: torch.optim.Adam(params, lr=5e-2)

    g.eval()
    # The target features are fixed: no gradient needed.
    with torch.no_grad():
        FX = extractor(X)

    Losses = np.zeros((n_sample, n_iter))
    Y0_list, Y_list, FY_list, Z_list = [], [], [], []
    for t in range(1, n_sample + 1):
        # Fresh leaf tensor, so optimizing it never touches Z0.
        zt = Z0[[t - 1]].detach().clone()
        zt.requires_grad = True
        with torch.no_grad():
            Y0_list.append(g(zt))

        # Features of the points chosen so far. They are stored detached, so
        # each backward pass only traverses the current point's graph and no
        # graph has to be retained across steps.
        FY_prev = torch.cat(FY_list, dim=0) if FY_list else None

        optimizer = fn_make_optimizer([zt])
        for it in range(n_iter):
            feat = extractor(g(zt))
            losst = -(2.0 / t) * k.eval(feat, FX).mv(weights)
            if FY_prev is not None:
                losst = losst + (2.0 / t**2) * torch.sum(k.eval(FY_prev, feat))
            losst = (losst + (1.0 / t**2) * k.eval(feat, feat).reshape(-1)
                     + reg * torch.sum(zt**2))

            Losses[t - 1, it] = losst.item()
            optimizer.zero_grad()
            losst.backward()
            optimizer.step()

        # Recompute the point from the final zt. The loop's last forward pass
        # ran before the last optimizer step, so reusing its output would pair
        # Y/FY with a stale latent vector.
        with torch.no_grad():
            yt = g(zt)
            Y_list.append(yt)
            FY_list.append(extractor(yt))
        Z_list.append(zt.detach())

    Y = torch.cat(Y_list, dim=0)
    FY = torch.cat(FY_list, dim=0)
    Z = torch.cat(Z_list, dim=0)
    Y0 = torch.cat(Y0_list, dim=0)
    assert Y.shape[0] == n_sample
    assert FY.shape[0] == n_sample
    return Y, Y0, Z, FY, Losses


# Optimizer and latent-norm penalty used for every generated point.
fn_make_optimizer = lambda params: torch.optim.Adam(params, lr=5e-2)
reg = 1e-3
Y, Y0, Z, FY, Losses = kernel_generator_herding(
    X, weights, g, extractor, k, Z0,
    fn_make_optimizer=fn_make_optimizer, n_iter=n_iter, reg=reg,
    n_sample=n_sample)
            
    

In [None]:
# One curve per generated point: its herding loss against the optimizer step.
ax = plt.gca()
ax.plot(Losses.T)
ax.set_xlabel('Optimization iteration')
ax.set_ylabel('Herding loss')

In [None]:
def _show_panel(imgs, title, **show_kwargs):
    """Draw a grid of images with plot.show_torch_imgs and title it."""
    plot.show_torch_imgs(imgs, **show_kwargs)
    return plt.title(title)


# Compare, in order: the target images, the generator outputs at the initial
# latent vectors, and the generator outputs after herding.
_show_panel(X, 'Input')
_show_panel(Y0.detach(), 'Initialized', nrow=8)
gen = Y.detach().cpu()
_show_panel(gen, 'Output')

In [None]:
# Optimized latent vectors as an image: row i is the latent vector of generated
# point i (0-based). Z has n_sample rows and latent_dim columns (6 x 100 here).
# With the default square pixels, such a matrix renders as a thin strip, so
# aspect='auto' stretches it to fill the axes.
plt.imshow(Z.cpu().detach().numpy(), aspect='auto')
plt.colorbar()
plt.xlabel('Latent dimension')
plt.ylabel('Point index')
plt.title('Optimized latent vectors')