Notebook to test kernel herding with a GAN generator.

In [None]:
# Notebook setup: inline figures, and reload imported modules (e.g. kbrgan)
# automatically when their source files change.
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Alternative inline figure formats (vector output):
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kbrgan
import kbrgan.kernel as kernel
import kbrgan.glo as glo
import kbrgan.main as main
import kbrgan.plot as plot
import kbrgan.embed as embed
import kbrgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
# Global plotting style. 'family' / 'weight' entries can be added to `font`.
font = {'size': 18}
plt.rc('font', **font)
plt.rc('lines', linewidth=2)
# Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3 fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
# Use the GPU whenever one is present (replace with False to force the CPU).
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Newly created float tensors (and module parameters) default to `device`.
if use_cuda:
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    torch.set_default_tensor_type(torch.FloatTensor)


## Feature extractor for LSUN bedroom images (stem of a pretrained ResNet-18)

In [None]:
classifier = torchvision.models.resnet18(pretrained=True)
classifier = classifier.eval()
classifier = classifier.to(device)
# Frozen feature extractor: gradients are needed w.r.t. its input (to reach the
# latent vectors), never w.r.t. its own weights, so don't compute/accumulate them.
classifier.requires_grad_(False)


def extractor(imgs, size=96):
    """
    Feature extractor: the stem of the pretrained ResNet-18 ``classifier``.

    The images are bilinearly resized to ``size x size`` and passed through
    conv1 -> bn1 -> relu -> maxpool; the feature maps are then flattened to
    one row per image.

    Parameters
    ----------
    imgs : torch.Tensor
        Batch of images of shape (n, C, H, W); the ResNet stem expects C == 3.
    size : int
        Side length the images are resized to. The default 96 yields
        64 x 24 x 24 = 36864 features per image.

    Returns
    -------
    torch.Tensor
        2-D tensor with one flattened feature vector per input image.

    NOTE(review): no ImageNet mean/std normalization is applied even though
    the weights are ImageNet-pretrained -- confirm this is intended.
    """
    net = classifier
    # Functional resize: avoids building an nn.Upsample module on every call
    # and does not depend on `nn` being imported by a later cell.
    # nn.Upsample's default (align_corners=None) behaves as False here.
    x = F.interpolate(imgs, size=size, mode='bilinear', align_corners=False)
    x = net.conv1(x)
    x = net.bn1(x)
    x = net.relu(x)
    x = net.maxpool(x)
    # Earlier experiments also resized to 224 and/or continued through
    # net.layer1 / net.layer2 / net.layer3.
    # flatten(1) replaces the hard-coded view(-1, 64*24*24), which only
    # matched size=96.
    return x.flatten(1)

## A DCGAN generator for LSUN bedroom images (64x64 RGB)

In [None]:
import torch.nn as nn
class Generator(nn.Module):
    '''
    DCGAN-style generative network mapping latent vectors to 64x64 images.

    Input: latent tensor of shape (batch, 100, 1, 1).
    Output: tensor of shape (batch, 3, 64, 64), values in [-1, 1] (final Tanh).
    '''
    def __init__(self, dataset='celebA'):
        """
        Build the network and draw the transposed-convolution weights from
        N(0, 0.02).

        ``dataset`` is currently unused: the architecture is fixed
        (z_size=100, out_size=3, ngf=128).
        """
        super().__init__()

        z_size = 100
        out_size = 3
        ngf = 128

        self.z_size = z_size
        self.ngf = ngf
        self.out_size = out_size

        self.main = nn.Sequential(
            # input size is z_size x 1 x 1
            nn.ConvTranspose2d(self.z_size, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(inplace=True),
            # state size: (ngf * 8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(inplace=True),
            # state size: (ngf * 4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(inplace=True),
            # state size: (ngf * 2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(inplace=True),
            # state size: ngf x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.out_size, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size: out_size x 64 x 64
        )

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, input):
        """Map a (batch, 100, 1, 1) latent tensor to (batch, 3, 64, 64) images."""
        return self.main(input)

    def load(self, f):
        """
        Load weights from ``f``, a path or file-like object containing a state
        dict (e.g. written by ``torch.save(generator.state_dict(), f)``).

        The checkpoint is first mapped to the CPU so that weights saved from a
        GPU can also be loaded on a CPU-only machine; ``load_state_dict`` then
        copies them onto the device of this module's parameters.

        Returns True on success.
        """
        self.load_state_dict(torch.load(f, map_location='cpu'))
        return True

In [None]:
import kbrgan.mnist.dcgan as mnist_dcgan

gan_fname = 'GAN_{}_G.pkl'.format(24)
gan_fpath = glo.prob_model_folder('lsun_dcgan', gan_fname)

# load a model
g = Generator()
g.load(gan_fpath)
g = g.to(device)
# Use the generator as a fixed decoder.
# - eval(): BatchNorm uses its loaded running statistics, so each image depends
#   only on its own latent vector (not on the other points in the batch), and
#   later forward passes do not overwrite those statistics.
# - requires_grad_(False): only the latent vectors are optimized, so gradients
#   w.r.t. the generator weights are never needed.
g.eval()
g.requires_grad_(False)

latent_dim = 100


def f_noise(n):
    """Draw ``n`` standard-normal latent vectors, returned with shape (n, latent_dim)."""
    return torch.randn(n, latent_dim).float()

## Optimize points jointly to minimize the moment matching loss

The points are images produced by the GAN generator; the optimization is carried out over their latent vectors.

In [None]:
def sample_from_dir(dir_data, num_sample, size=64):
    """
    Load ``num_sample`` randomly chosen images from the directory ``dir_data``.

    Each image is converted to RGB, resized to ``size x size`` and scaled to
    [0, 1], in the same way the test images are loaded in this notebook.
    (The previous body referenced an undefined ``homo_data`` and ignored
    ``dir_data``, so it always raised NameError.)

    Parameters
    ----------
    dir_data : str
        Directory that contains only image files.
    num_sample : int
        Number of distinct images to draw (without replacement).
    size : int
        Side length of the returned images.

    Returns
    -------
    torch.Tensor
        Float tensor of shape (num_sample, 3, size, size).

    Raises
    ------
    ValueError
        If ``num_sample`` is not between 1 and the number of files in
        ``dir_data``.
    """
    fnames = sorted(os.listdir(dir_data))
    if not 0 < num_sample <= len(fnames):
        raise ValueError(
            'num_sample must be in [1, {}], got {}'.format(len(fnames), num_sample))
    # Local import: PIL is otherwise only imported in a later cell.
    from PIL import Image

    list_selected = []
    for fname in np.random.permutation(fnames)[:num_sample]:
        with Image.open(os.path.join(dir_data, fname)) as raw:
            img = raw.convert('RGB').resize((size, size))
        # HWC uint8 -> CHW float in [0, 1]
        chw = np.transpose(np.asarray(img), (2, 0, 1)) / 255.0
        list_selected.append(torch.from_numpy(chw).float())
    # stack all
    return torch.stack(list_selected)

In [None]:
from PIL import Image
num_input = 1
test_img_dir = glo.data_file('test_bedroom')
imgs = []
for path in np.random.permutation(os.listdir(test_img_dir))[0:num_input]:
    # `with` closes the file handle; convert('RGB') guarantees 3 channels so
    # grayscale/RGBA/palette files don't break the HWC -> CHW transpose.
    with Image.open(os.path.join(test_img_dir, path)) as raw:
        img = raw.convert('RGB').resize((64, 64))
    img = np.transpose(np.array(img), (2, 0, 1))
    # NOTE(review): pixels are scaled to [0, 1] here, while the Generator ends
    # with Tanh (outputs in [-1, 1]); confirm whether X should be mapped to
    # [-1, 1] (img / 127.5 - 1) before the features are compared.
    imgs.append(img / 255.0)

# Stack into one ndarray first; building a tensor from a list of ndarrays is slow.
X = torch.from_numpy(np.stack(imgs)).float()
X = X.to(device)
n = X.shape[0]

# A vector of (uniform) weights for all the points in X
weights = torch.ones(n) / float(n)
weights = weights.to(device)
plot.show_torch_imgs(X)

In [None]:
# initial points in the latent space
n_sample = 8
# noise vectors, shaped (n_sample, latent_dim, 1, 1) as the generator's first
# transposed convolution expects
Z = f_noise(n_sample).view(-1, latent_dim, 1, 1)
Z = Z.to(device)

# Z is the quantity being optimized, so it is a leaf that requires grad.
Z.requires_grad = True
# Y0 is only kept for display; don't build and retain an autograd graph
# through the whole generator for it.
with torch.no_grad():
    Y0 = g(Z)

# plot the initial points in the image space
plot.show_torch_imgs(Y0.detach(), nrow=8, figsize=(12, 6))

In [None]:
# Pre-extract the features of the target images X once; they stay fixed
# throughout the optimization.
with torch.no_grad():
    FX = extractor(X)

# Kernel on the extracted features. Other candidates that were tried:
# kernel.PTKPoly, kernel.PTKIMQ, kernel.PTKLaplace, kernel.PTKLinear, and a
# kernel.PTKGauss whose sigma2 came from util.meddistance on FX.
k = kernel.PTKL1Distance(sigma=2e+3)

# Only the latent vectors Z are optimized (SGD and Adam were alternatives).
optimizer = torch.optim.RMSprop([Z], lr=2e-2)

In [None]:
# optimization
n_iter = 1200
losses = []
sample_interval = 300
# avgpool = torch.nn.AvgPool2d(5, stride=1, padding=0)

mean_KFX = torch.mean(k.eval(FX, FX))
for t in range(n_iter):
    # need to resize since Mnist uses 28x28. The generator generates 32x32
    gens = g(Z)
#     resized = torch.stack([resize_gen_img(I) for I in gens], 0)
#     resized = avgpool(gens)
#     plot.show_torch_imgs(resized)
    F_gz = extractor(gens)
    KF_gz = k.eval(F_gz, F_gz)
#     print(KF_gz)
    
    # encourage the latent noise vectors to concentrate around 0
    Z_reg = 1e-2*torch.mean(torch.mean(Z**2, 1))
#     Z_reg = -torch.mean(torch.log(4.0**2-Z**2))
    loss = torch.mean(KF_gz)  - 2.0*torch.mean(k.eval(F_gz, FX).mv(weights)) + mean_KFX  + Z_reg
    losses.append(loss.item())
    
    optimizer.zero_grad()
    
    # compute the gradients
    loss.backward(retain_graph=True)
    # updates
    optimizer.step()
    
    #--------- plots the generated images ----
    if t%sample_interval==0:
        with torch.no_grad():
            gen = g(Z.detach().clone())
#             gen = Z.detach().clone()
#             gen = Z.grad.detach().clone()
            plot.show_torch_imgs(gen, figsize=(12, 6))
            plt.show()
    

In [None]:
# input points
figsize = (12, 6)
plot.show_torch_imgs(X)
plt.title('Input')
plot.show_torch_imgs(Y0.detach(), nrow=8, figsize=figsize)
plt.title('Initialized')
plot.show_torch_imgs(gen, figsize=figsize)
plt.title('Output')

In [None]:
# Herding loss recorded at every optimization step.
plt.plot(np.arange(len(losses)), losses)
plt.xlabel('Optimization iteration')
plt.ylabel('Herding loss')

In [None]:
# Optimized latent vectors: one row per generated point, one column per latent
# dimension. reshape (rather than squeeze) keeps the array 2-D even when
# n_sample == 1, which imshow requires.
plt.imshow(Z.detach().cpu().reshape(Z.shape[0], -1).numpy())
plt.colorbar()