Toy case:

$X\sim N(\mu,\Sigma)$, $f(X)=[1,0]\,X+b$, $g(X)=BX+b$. Solve the following optimization problem:

$$ \min_{B,a,b} \text{HSIC}(f(X),g(X)) - \|B - [1,0]\|^2 $$

In [None]:
from IPython.core.getipython import get_ipython

# Re-import edited modules (e.g. jaxkern) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
_ipython = get_ipython()
_ipython.run_line_magic('load_ext', 'autoreload')
_ipython.run_line_magic('autoreload', '2')


In [None]:
import numpy as np
import numpy.random as npr
import scipy.stats as stats

import matplotlib.pyplot as plt
from jaxkern import (
    rbf_kernel, linear_kernel, estimate_sigma_median, hsic, mmd, squared_l2_norm)

import jax
from jax.experimental import optimizers
from jax import grad, random
import jax.numpy as jnp

In [None]:
def plt_kde(X, *, show_points=False):
    """Plot a Gaussian kernel density estimate of 2-D samples.

    Args:
        X: samples with one point per row and two columns. The KDE is fit on
            ``X.T`` because ``scipy.stats.gaussian_kde`` expects
            (n_dims, n_samples), and it is evaluated on a 2-D grid.
        show_points: if True, overlay the raw samples as small black dots.

    Returns:
        ``(fig, ax)`` for the newly created 6x6-inch figure.
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(6, 6)
    # Square window covering every coordinate, padded by 1 on each side.
    lo, hi = np.min(X) - 1, np.max(X) + 1
    grid = np.linspace(lo, hi, 100)
    XX, YY = np.meshgrid(grid, grid)
    XY = np.vstack([XX.ravel(), YY.ravel()])
    kde = stats.gaussian_kde(X.T)
    Z = kde(XY).reshape(XX.shape)
    # Row 0 of Z is y == lo, so it must be drawn at the bottom. The default
    # origin='upper' would show the density flipped vertically.
    ax.imshow(Z, cmap='Blues', origin='lower', extent=[lo, hi, lo, hi])
    if show_points:
        ax.scatter(X[:, 0], X[:, 1], c='k', s=3)
    ax.set_xlim([lo, hi])
    ax.set_ylim([lo, hi])
    return fig, ax

In [None]:
# Experiment hyperparameters.
n = 1000          # number of samples drawn
d = 2             # data dimension
lr = .2           # SGD step size
num_steps = 1000  # number of passes over the data (epochs)
batch_size = 128

# npr.seed(0) seeds numpy's global RNG, which scipy's rvs uses by default.
npr.seed(0)
mu = np.array([1, 1])
cov = np.array([[1, 0],
                [0, 1]])
Xdist = stats.multivariate_normal(mu, cov)
X = jax.device_put(Xdist.rvs(size=(n,)))

# RBF bandwidth from jaxkern's median-based estimator (presumably the median
# heuristic), converted to the gamma = 1 / (2 sigma^2) parameterisation.
sigma = estimate_sigma_median(X)
gamma = 1 / (2 * sigma**2)
print(f'sigma={sigma}')
print(f'gamma={gamma}')

def kernel(X, Y):
    """RBF kernel with the module-level ``gamma``; used for both HSIC inputs."""
    # NOTE(review): gamma is estimated on the 2-D samples X, but hsic applies
    # this kernel to the 1-D projections f(X), g(X). Confirm the bandwidth
    # is appropriate for them.
    gram = rbf_kernel(X, Y, gamma=gamma)
    return gram

# Fixed row vector defining f (the mean of the two coordinates).
a1 = jnp.array([[.5, .5]], dtype=jnp.float32)
# Trainable parameters: the row vector defining g, initialised as a copy of a1.
params = dict(a2=a1.copy())


def _f_single(params, x):
    """Project one sample with the fixed row ``a1`` (``params`` is unused)."""
    return a1 @ x


def _g_single(params, x):
    """Project one sample with the trainable row ``params['a2']``."""
    return params['a2'] @ x


# Batch over the leading axis of X, sharing params across samples.
f = jax.vmap(_f_single, in_axes=(None, 0), out_axes=0)
g = jax.vmap(_g_single, in_axes=(None, 0), out_axes=0)

def loss_fn(params, X):
    """Training objective for the trainable projection ``params['a2']``.

    The objective is the sum of two terms:
      * HSIC between f(X) and g(X), using ``kernel`` for both inputs;
      * ``squared_l2_norm(A @ mu - mu)``, where A stacks the fixed row ``a1``
        on top of ``params['a2']``.

    Args:
        params: dict with key ``'a2'``, the row vector defining g.
        X: batch of samples, one per row (f and g are vmapped over axis 0).

    Returns:
        Scalar loss (required, since ``grad`` is taken of it).
    """
    dependence = hsic(f(params, X), g(params, X), kernel, kernel)
    stacked = jnp.vstack((a1, params['a2']))
    # NOTE(review): the markdown header writes the penalty as -(B - [1,0])^2,
    # but this term is +||A mu - mu||^2. Confirm which one is intended.
    mean_gap = squared_l2_norm(stacked @ mu.T - mu.T)
    return dependence + mean_gap


loss_fn = jax.jit(loss_fn)
grad_fn = jax.jit(jax.grad(loss_fn))


# Mini-batches per epoch; a final short batch holds any leftover samples.
num_complete_batches, leftover = divmod(n, batch_size)
num_batches = num_complete_batches + bool(leftover)


def data_stream():
    """Yield mini-batches of rows of X forever, reshuffling every epoch.

    Uses a private RandomState(0), so the batch order is reproducible and
    independent of numpy's global RNG.
    """
    rng = npr.RandomState(0)
    while True:
        order = rng.permutation(n)
        for lo in range(0, num_batches * batch_size, batch_size):
            yield X[order[lo:lo + batch_size]]


batches = data_stream()

# Train g's projection a2 with plain SGD. Every 50 epochs, log the loss and the
# covariance of [f(X), g(X)], and scatter one mini-batch of (f, g) pairs.
import itertools
import time  # previously missing: time.time() below raised NameError

opt_init, opt_update, get_params = optimizers.sgd(lr)
opt_state = opt_init(params)
itercount = itertools.count()  # global SGD step index passed to opt_update

# Loss and gradient from a single jitted call (one forward pass instead of two).
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

# A single Axes; plt.subplots(()) raised ValueError (nrows must be a positive int).
fig, ax = plt.subplots()

print("\nStarting training...")
for i in range(num_steps):
    start_time = time.time()
    for _ in range(num_batches):
        batch = next(batches)
        params = get_params(opt_state)
        # l is the loss at the parameters *before* this step's update.
        l, params_grad = loss_and_grad_fn(params, batch)
        opt_state = opt_update(next(itercount), params_grad, opt_state)
    # JAX dispatches asynchronously: wait for the last update so epoch_time
    # measures the actual computation rather than just dispatch.
    get_params(opt_state)['a2'].block_until_ready()
    epoch_time = time.time() - start_time

    if i % 50 == 0:
        print(f'[{i:3}] time={epoch_time:4.4f}\t loss={l:5.8f}\t'
              f'a2={params["a2"]}')

        # [f(X), g(X)] = A X, so its covariance is A cov A^T.
        A = jnp.vstack((a1, params['a2']))
        ycov = A @ cov @ A.T
        print(ycov)

        # Only needed for the plot, so compute it here rather than every step.
        fgX = jnp.hstack([f(params, batch), g(params, batch)])
        ax.scatter(fgX[:, 0], fgX[:, 1], label=i)

ax.set_xlim((-5, 5))
ax.set_ylim((-5, 5))
ax.legend()
    
    


In [None]:

# Inspect the trained parameters and evaluate on the full dataset.
params = get_params(opt_state)
print(params)

fX, gX = f(params, X), g(params, X)
fgX = jnp.hstack([fX, gX])  # column 0: f(X), column 1: g(X)

l = loss_fn(params, X)
params_grad = grad_fn(params, X)
# Bare final expression so the notebook displays the loss and gradient.
l, params_grad