In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and select mode 2, which reloads imported
# modules before each cell executes, so edits to library code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Optional vector formats for inline figures (currently disabled).
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kbrgan
import kbrgan.kernel as kernel
import kbrgan.main as main
import kbrgan.embed as embed
import kbrgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import autograd.numpy as np
import scipy.stats as stats
import torch

In [None]:
# Global plot styling: larger text and thicker lines for readable figures.
# Further font keys (e.g. 'family', 'weight') can be added to this dict.
font = dict(size=18)
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Font type 42 makes the PDF and PS backends embed TrueType fonts.
plt.rc('pdf', fonttype=42)
plt.rc('ps', fonttype=42)

## Kernel embedding in 2D

In [None]:
# Problem setup: n sample points in d dimensions.
d, n = 2, 500
# Fix torch's global random number generator so torch draws are repeatable.
seed = 7
torch.manual_seed(seed);

In [None]:
# Toy data: n standard-normal draws in d dimensions, shifted to be centred at (-3, 3).
X = torch.tensor([-3.0, 3.0]) + torch.randn(n, d)
# NumPy version of the data, used for plotting.
Xnp = X.numpy()
# Scatter of the first two coordinates as black circles (markers only, no line).
plt.plot(Xnp[:, 0], Xnp[:, 1], color='k', marker='o', linestyle='None');

In [None]:
# Data-driven length scale for the kernel.
# NOTE(review): presumably util.meddistance returns the median pairwise
# distance of the rows of X (subsampling at most 1000 points) - confirm in kbrgan.util.
med = util.meddistance(X.numpy(), subsample=1000)
print('median distance: ', med)

# Gaussian kernel with sigma2 = med^2 / 2.
# Alternatives that were tried here: kernel.PTKPoly(c=1.0, d=2), kernel.PTKLinear()
k = kernel.PTKGauss(sigma2=0.5 * med**2)

# Mean embedding built from the sample X under the kernel k.
em = embed.PTImplicitKEmb(k, X)

## Optimize points jointly to minimize the moment matching loss

In [None]:
"""
Sample from a mean embedding by performing kernel moment matching.
""";
# Number of points to optimize.
n_sample = 10

# Y = stack of output samples to be optimized, initialized from a standard normal.
# The draw is made with torch rather than NumPy: only torch's generator is
# seeded in this notebook (torch.manual_seed above), so a NumPy draw would give
# different starting points on every fresh run.
# An alternative initialization is a randomly chosen subset of X plus small noise.
Y = torch.randn(n_sample, d, requires_grad=True)

# Detached snapshot of the starting points, kept for the comparison plot.
Y0 = Y.detach().clone()
Y0np = Y0.numpy()

In [None]:
# Optimizer acting directly on the points Y.
# Alternatives that were tried: torch.optim.Adam([Y], lr=1e-2), torch.optim.SGD([Y], lr=5e-3)
optimizer = torch.optim.RMSprop([Y], lr=5e-2)

n_iter = 500
losses = []

# This term involves only X and the embedding weights, not Y, so it is
# evaluated once outside the loop.
Kxx_term = k.eval(X, X).mv(em.weights).dot(em.weights)

# No regularization is applied. A small penalty such as
# 1e-5*torch.mean(torch.sum(Y**2, 1)) can be used instead to prevent things
# from blowing up.
reg = 0

for t in range(n_iter):
    optimizer.zero_grad()

    # moment matching loss between the points Y and the embedding
    yy_term = k.eval(Y, Y).mean()
    xy_term = k.eval(Y, X).mv(em.weights).mean()
    loss = yy_term - 2.0*xy_term + Kxx_term + reg
    losses.append(loss.item())

    # backpropagate, then take one optimizer step on Y
    loss.backward()
    optimizer.step()
    

In [None]:
# Loss trajectory of the joint optimization, with iterations numbered from 1.
# The x-axis is derived from the recorded losses instead of n_iter: n_iter is
# reassigned in the kernel herding section, so relying on it here could raise
# a length mismatch when this cell is re-run after that section.
plt.plot(range(1, len(losses) + 1), losses, 'b-')
plt.ylabel('Moment matching loss')
plt.xlabel('Iteration');

In [None]:
# Copy the optimized points out of the autograd graph for plotting.
Ynp = Y.detach().clone().numpy()

# Overlay the data, the starting points and the optimized points in one figure.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(Xnp[:, 0], Xnp[:, 1], 'ko', label='Data', alpha=0.4)
ax.plot(Y0np[:, 0], Y0np[:, 1], 'sb', label='Initial', markersize=11, alpha=0.7)
ax.plot(Ynp[:, 0], Ynp[:, 1], '^r', label='Optimized', markersize=11, alpha=0.7)
ax.legend()

## Kernel herding (greedy optimization)

In [None]:
# n_sample: number of points to produce.
# n_iter: number of optimization iterations for each point yt.
n_sample, n_iter = 10, 500


def fn_make_optimizer(params):
    """Return a new RMSprop optimizer (learning rate 0.05) over `params`.

    Passed to embed.kernel_herding as the optimizer factory.
    Alternatives that were tried: Adam with lr=1e-2, SGD with lr=1e-3.
    """
    optimizer = torch.optim.RMSprop(params, lr=0.05)
    return optimizer


# The two returned values are plotted below as the 'Optimized' and 'Initial' points.
Y_greedy, Y0 = embed.kernel_herding(em, n_sample, fn_make_optimizer, n_iter)

In [None]:
# NumPy copies of the herding result for plotting. They get their own names so
# that Ynp / Y0np from the joint-optimization section are not overwritten;
# otherwise re-running that section's plot would show the herding start points
# as its 'Initial' markers.
Y_greedy_np = Y_greedy.detach().numpy()
Y0_greedy_np = Y0.detach().numpy()

# Overlay the data, the herding start points and the herding output.
plt.figure(figsize=(8, 6))
plt.plot(Xnp[:, 0], Xnp[:, 1], 'ko', label='Data', alpha=0.4)
plt.plot(Y0_greedy_np[:, 0], Y0_greedy_np[:, 1], 'sb', markersize=11, alpha=0.7, label='Initial')
plt.plot(Y_greedy_np[:, 0], Y_greedy_np[:, 1], '^r', markersize=11, alpha=0.7, label='Optimized')
plt.legend();