In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Re-import edited modules (e.g. the kbrgan package) before each cell runs,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Optional: uncomment one line to get vector-format inline figures instead of PNG.
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kbrgan
import kbrgan.kernel as kernel
import kbrgan.main as main
import kbrgan.emb as emb
import kbrgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import autograd.numpy as np
import scipy.stats as stats
import torch

In [None]:
# Global plotting style: larger fonts and thicker lines for readability.
# Other font keys (e.g. 'family', 'weight') can be added to this dict.
font = {'size': 18}

# pdf/ps fonttype 42 embeds TrueType (Type 42) fonts in exported PDF/PS
# figures instead of matplotlib's default Type 3 fonts.
matplotlib.rcParams.update({
    'font.size': font['size'],
    'lines.linewidth': 2,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

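## Kernel mean embedding in 2D

Draw a 2D standard-normal sample `X` and embed it with a Gaussian kernel whose bandwidth comes from the median pairwise distance. Then recover a small set of points from that embedding by kernel moment matching.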

In [None]:
# Problem configuration: dimension, number of data points, RNG seed.
d = 2
n = 500
seed = 7
# Seed every RNG the notebook draws from (torch for the data, numpy for any
# np.random calls) so Restart & Run All reproduces the same figures.
torch.manual_seed(seed)
np.random.seed(seed);

In [None]:
# Data: n i.i.d. draws from a d-dimensional standard normal, shape (n, d).
X = torch.randn(n, d)
Xnp = X.numpy()

fig, ax = plt.subplots()
ax.plot(Xnp[:, 0], Xnp[:, 1], 'ko', alpha=0.5)
ax.set(title='Data: {} draws from N(0, I)'.format(n), xlabel='$x_1$', ylabel='$x_2$')
# equal aspect so the isotropic Gaussian is not visually stretched
ax.set_aspect('equal');

In [None]:
# Bandwidth from the median pairwise distance of the data
# (presumably the usual median heuristic implemented by util.meddistance).
med = util.meddistance(Xnp, subsample=1000)
print('median distance: ', med)
sigma2 = med**2 / 2

# Gaussian kernel with that bandwidth, and the mean embedding of X under it.
k = kernel.PTKGauss(sigma2=sigma2)
em = emb.PTImplicitKEmb(k, X)

In [None]:
"""
Sample from a mean embedding by performing kernel moment matching.
""";
n_sample = 20
# standard deviation of the Gaussian jitter added to the initial points
init_noise_std = 0.1

# Y = stack of output samples to be optimized, shape (n_sample, d).
# Initialize from a random subset of X: distinct rows (randperm, i.e. without
# replacement), drawn from the torch RNG seeded above so the run is
# reproducible, then jittered independently in every coordinate
# (noise shape (n_sample, d), not (n_sample, 1) which would move each point
# along the diagonal only).
assert n_sample <= n, 'cannot pick more distinct initial points than data points'
init_idx = torch.randperm(n)[:n_sample]
# detach() + requires_grad_() makes Y a fresh leaf tensor for the optimizer
# (avoids the torch.tensor(<tensor>) copy-construct warning).
Y = (X[init_idx] + init_noise_std*torch.randn(n_sample, d)).detach().requires_grad_(True)

In [None]:
# Plain SGD on the sample locations. Drop-in alternatives that were also tried:
# torch.optim.Adam([Y], lr=1e-3) and torch.optim.RMSprop([Y], lr=1e-3).
optimizer = torch.optim.SGD([Y], lr=5e-3)
n_iter = 1000


def moment_matching_loss(Ys):
    """Kernel moment-matching objective for a set of candidate samples Ys.

    Returns mean_{j,l} k(y_j, y_l) - 2 * mean_j (k(Ys, X) @ em.weights)_j.
    NOTE(review): presumably this is the squared MMD between the empirical
    embedding of Ys and `em`, minus the Ys-independent term ||em||^2
    (assumes k.eval(A, B) returns the len(A) x len(B) Gram matrix -- confirm).
    """
    self_term = torch.mean(k.eval(Ys, Ys))
    cross_term = torch.mean(k.eval(Ys, X).mv(em.weights))
    return self_term - 2.0*cross_term


losses = []
for t in range(n_iter):
    loss = moment_matching_loss(Y)
    losses.append(loss.item())
    # clear stale gradients, backpropagate, then take one optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Objective value per iteration. NOTE(review): the loss drops any term that
# does not depend on Y, so presumably only its trend (not its absolute value)
# is meaningful.
fig, ax = plt.subplots()
ax.plot(np.arange(n_iter) + 1, losses, 'b-')
ax.set(title='Kernel moment matching', xlabel='Iteration', ylabel='Moment matching loss');

In [None]:
# Optimized sample locations as a plain array (drop the autograd graph).
Ynp = Y.detach().numpy()

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(Xnp[:, 0], Xnp[:, 1], 'ko', label='Data', alpha=0.5)
ax.plot(Ynp[:, 0], Ynp[:, 1], '*b', markersize=11, alpha=0.7, label='Optimized')
ax.set(title='Samples recovered from the kernel mean embedding',
       xlabel='$x_1$', ylabel='$x_2$')
ax.legend();

In [None]:
# Final optimized sample locations, one row per sample: shape (n_sample, d).
Ynp