In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kbrgan
import kbrgan.kernel as kernel
import kbrgan.main as main
import kbrgan.embed as embed
import kbrgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import autograd.numpy as np
import scipy.stats as stats
import torch

In [None]:
# Global matplotlib styling shared by every figure in this notebook.
font = dict(
    # family='normal',
    # weight='bold',
    size=18,
)
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Use font type 42 for both the PDF and the PostScript backends
# (TrueType embedding, per the matplotlib rcParams documentation).
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

## Kernel embedding in 2D

In [None]:
# Problem size and RNG seed for the 2D toy example.
d = 2       # dimensionality of the data
n = 500     # number of data points
seed = 7

# Seed every random source used in this notebook. torch generates the data;
# NumPy's global RNG is used when picking initial points with
# np.random.choice, so leaving it unseeded makes reruns irreproducible.
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Toy dataset: n i.i.d. standard-normal points in d dimensions.
X = torch.randn(n, d)
# NumPy counterpart of X, used for plotting.
Xnp = X.numpy()
# Show the first two coordinates as black dots (no connecting line).
plt.plot(Xnp[:, 0], Xnp[:, 1], color='k', marker='o', linestyle='None');

In [None]:
# Distance scale of the data, as returned by util.meddistance
# (presumably the median pairwise distance, computed on a subsample --
# TODO confirm against kbrgan.util).
med = util.meddistance(X.numpy(), subsample=1000)
print('median distance: ', med)

# Gaussian kernel with sigma2 set to half the squared median distance,
# and the kernel mean embedding built from the sample X.
k = kernel.PTKGauss(sigma2=0.5 * med**2)
em = embed.PTImplicitKEmb(k, X)

## Optimize points jointly to minimize the moment matching loss

In [None]:
"""
Sample from a mean embedding by performing kernel moment matching.
""";
# Number of points to draw by moment matching.
n_sample = 30

# Y = stack of output samples to be optimized.
# Initialize with randomly chosen rows of X (with replacement) plus small,
# independent Gaussian jitter on every coordinate. The indices come from
# torch's RNG so that torch.manual_seed makes the initialization reproducible.
Y = X[torch.randint(n, (n_sample,))]
# Build the leaf tensor with detach().requires_grad_() rather than
# torch.tensor(<tensor>), which emits a copy-construct warning.
Y = (Y + 0.1*torch.randn_like(Y)).detach().requires_grad_(True)

In [None]:
# Snapshot of the sample points before optimization, as a NumPy array.
Ynp = Y.detach().numpy()

# Data as faded black dots, initial sample points as blue triangles.
plt.figure(figsize=(8, 5))
plt.plot(Xnp[:, 0], Xnp[:, 1], color='k', marker='o', linestyle='None',
         alpha=0.4, label='Data')
plt.plot(Ynp[:, 0], Ynp[:, 1], color='b', marker='^', linestyle='None',
         alpha=0.7, markersize=11, label='Initialized')
plt.legend()

In [None]:
# Jointly optimize all rows of Y on the moment matching loss
#     mean(k(Y, Y)) - 2 * mean(k(Y, X) @ em.weights)
# Alternative optimizers (disabled):
#   torch.optim.Adam([Y], lr=1e-3)
#   torch.optim.SGD([Y], lr=5e-3)
optimizer = torch.optim.RMSprop([Y], lr=1e-3)

n_iter = 1000
# loss value recorded at every iteration, for the loss curve
losses = []
for t in range(n_iter):
    # clear gradients left over from the previous step
    optimizer.zero_grad()

    loss = (
        torch.mean(k.eval(Y, Y))
        - 2.0*torch.mean(k.eval(Y, X).mv(em.weights))
    )
    losses.append(loss.item())

    # backpropagate, then take one optimizer step on Y
    loss.backward()
    optimizer.step()
    

In [None]:
# Loss curve of the joint optimization.
# The x-axis is derived from the recorded losses themselves instead of the
# global n_iter: a later cell reassigns n_iter, after which re-running this
# cell would fail with mismatched x/y lengths.
plt.plot(np.arange(1, len(losses) + 1), losses, 'b-')
plt.ylabel('Moment matching loss')
plt.xlabel('Iteration');

In [None]:
# Current values of the jointly optimized sample points, as a NumPy array.
Ynp = Y.detach().numpy()

# Data as faded black dots, optimized sample points as blue triangles.
plt.figure(figsize=(8, 5))
plt.plot(Xnp[:, 0], Xnp[:, 1], color='k', marker='o', linestyle='None',
         alpha=0.4, label='Data')
plt.plot(Ynp[:, 0], Ynp[:, 1], color='b', marker='^', linestyle='None',
         alpha=0.7, markersize=11, label='Optimized')
plt.legend()

## Kernel herding (greedy optimization)

In [None]:
def kernel_herding(emb, n_sample, fn_make_optimizer=None, n_iter=200, ):
    """
    Greedily sample points from a kernel mean embedding (kernel herding).

    Points are added one at a time. With y_1, ..., y_{t-1} held fixed, the
    t-th point y_t is found by minimizing, with a gradient-based optimizer,
    the y_t-dependent part of t times the squared MMD between the t points
    and the embedding:

        (1/t) k(y_t, y_t) + (2/t) sum_{i<t} k(y_i, y_t) - 2 emb(y_t)

    where emb(y) denotes emb.eval(y) and the sum is empty for t = 1. This is
    the greedy counterpart of the joint moment matching loss
    mean(k(Y, Y)) - 2*mean(emb(Y)).

    emb: PTImplicitKEmb to sample from. Must provide `samples`,
        `get_kernel()` and `eval(points)`.
    n_sample: number of points to sample
    fn_make_optimizer: a function: params -> a torch.optim.XXX optimizer.
        A function that constructs an optimizer from a list of parameters.
        Defaults to RMSprop with lr=1e-3.
    n_iter: number of iterations for optimizing each y_i

    Return (Y, Y0),
        Y: a Pytorch tensor of size n_sample x dim. Optimization result,
            detached from the autograd graph.
        Y0: a Pytorch tensor of size n_sample x dim. Initial points picked

    Raise ValueError if n_sample <= 0.
    """
    if n_sample <= 0:
        raise ValueError('n_sample must be > 0. Was {}'.format(n_sample))
    if fn_make_optimizer is None:
        fn_make_optimizer = lambda params: torch.optim.RMSprop(params, lr=1e-3)

    def pick_one_row(X):
        # One random row of X plus small independent Gaussian jitter per
        # coordinate, returned as a fresh 1 x dim leaf tensor. The index comes
        # from torch's RNG so that torch.manual_seed controls it.
        row = X[torch.randint(X.shape[0], (1,))]
        # detach().requires_grad_() instead of torch.tensor(<tensor>), which
        # emits a copy-construct warning.
        return (row + 0.1*torch.randn_like(row)).detach().requires_grad_(True)

    X = emb.samples
    k = emb.get_kernel()

    # a stack of all initial points
    Y0 = []
    # points found so far (None until the first one is done)
    Y = None
    for t in range(1, n_sample+1):
        # Initialize by randomly picking a (jittered) point in X.
        yt = pick_one_row(X)
        assert yt.ndimension() == 2, 'dim of yt was {}'.format(yt.ndimension())
        # Store a detached copy: yt itself is updated in place below.
        Y0.append(yt.detach().clone())

        optimizert = fn_make_optimizer([yt])
        for it in range(n_iter):
            # Attraction to the embedding. Negative sign because the
            # optimizer minimizes, and evaluated at yt, the point being
            # optimized.
            losst = (1.0/t)*torch.sum(k.eval(yt, yt)) - 2.0*torch.sum(emb.eval(yt))
            if Y is not None:
                # Repulsion from the points already selected.
                losst = losst + (2.0/t)*torch.sum(k.eval(Y, yt))
            optimizert.zero_grad()
            losst.backward()
            optimizert.step()

        # Now we have yt. Freeze it and add it to the current set Y, so later
        # backward passes do not flow into points that are already fixed.
        yt = yt.detach()
        Y = yt if Y is None else torch.cat([Y, yt], dim=0)

    assert Y.shape[0] == n_sample
    Y0 = torch.cat(Y0, 0)
    return Y, Y0
            
    

In [None]:
# Herding configuration.
n_sample = 20   # number of points to draw greedily
n_iter = 300    # optimization iterations spent on each point yt

def fn_make_optimizer(params):
    """
    Construct the per-point optimizer from a list of parameters: Adam, lr=1e-3.

    Disabled alternatives: torch.optim.RMSprop(params, lr=1e-3),
    torch.optim.SGD(params, lr=1e-3).
    """
    return torch.optim.Adam(params, lr=1e-3)

Y_greedy, Y0 = kernel_herding(
    em, n_sample, fn_make_optimizer=fn_make_optimizer, n_iter=n_iter)

In [None]:
# Herding result and its starting points, as NumPy arrays.
Ynp = Y_greedy.detach().numpy()
Y0np = Y0.detach().numpy()

# Data as faded black dots, starting points as blue squares,
# herded points as red triangles.
plt.figure(figsize=(8, 5))
plt.plot(Xnp[:, 0], Xnp[:, 1], color='k', marker='o', linestyle='None',
         alpha=0.4, label='Data')
plt.plot(Y0np[:, 0], Y0np[:, 1], color='b', marker='s', linestyle='None',
         alpha=0.7, markersize=11, label='Initial')
plt.plot(Ynp[:, 0], Ynp[:, 1], color='r', marker='^', linestyle='None',
         alpha=0.7, markersize=11, label='Optimized')
plt.legend()