Toy case:

$X\sim N(\mu,\Sigma)$, $f(X)=[1,0]^T X+a$, $g(X)=BX+b$. Solve the following optimization problem:

$$ \min_{B,a,b} \text{HSIC}(f(X),g(X)) - BB^T $$

In [None]:
from IPython.core.getipython import get_ipython
# Load IPython's autoreload extension and switch it to mode '2'
# (per the IPython autoreload docs: re-import modules before each execution),
# so edits to the imported local modules are picked up without a restart.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')


In [None]:
import numpy as np
import numpy.random as npr
import scipy.stats as stats

import matplotlib.pyplot as plt

from kernel import gauss_kernel, linear_kernel
from mmd import mmd
from hsic import hsic

import jax
from jax.experimental import optimizers
from jax import grad, random
import jax.numpy as jnp

In [None]:
def plt_kde(X):
    """Plot a Gaussian kernel density estimate of 2-D samples as an image.

    A ``scipy.stats.gaussian_kde`` is fitted to ``X.T`` and evaluated on a
    100 x 100 square grid spanning ``[min(X) - 1, max(X) + 1]`` on both axes.

    Args:
        X: samples with one observation per row; the evaluation grid is
            two-dimensional, so X must have two columns.

    Returns:
        The ``(fig, ax)`` pair of the created matplotlib figure.
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(6, 6)

    # One shared range for both axes, padded by 1 on each side.
    lim = np.min(X) - 1, np.max(X) + 1
    grid = np.linspace(lim[0], lim[1], 100)
    XX, YY = np.meshgrid(grid, grid)
    XY = np.vstack([XX.ravel(), YY.ravel()])

    kernel = stats.gaussian_kde(X.T)
    Z = kernel(XY).reshape(XX.shape)

    # Row 0 of Z holds the *smallest* y value of the meshgrid, so the image
    # must be drawn with origin='lower'; the default ('upper') would put that
    # row at the top and show the density mirrored vertically.
    ax.imshow(Z, cmap='Blues', origin='lower',
              extent=[lim[0], lim[1], lim[0], lim[1]])
    ax.set_xlim([lim[0], lim[1]])
    ax.set_ylim([lim[0], lim[1]])
    return fig, ax

In [None]:
# Taken from https://github.com/IPL-UV/jaxkern

def sqeuclidean_distance(x, y):
    """Return the sum of squared element-wise differences between x and y."""
    delta = x - y
    return jnp.sum(jnp.square(delta))

def euclidean_distance(x, y):
    """Return the Euclidean (L2) distance between x and y."""
    squared = sqeuclidean_distance(x, y)
    return jnp.sqrt(squared)

def distmat(func, x, y):
    """Apply ``func`` to every pair (x[i], y[j]) along the leading axes.

    Returns an array whose entry [i, j] is ``func(x[i], y[j])``.
    """
    # Inner map sweeps over the entries of y for one fixed entry of x ...
    over_y = jax.vmap(func, in_axes=(None, 0))
    # ... and the outer map repeats that for every entry of x.
    over_x_and_y = jax.vmap(over_y, in_axes=(0, None))
    return over_x_and_y(x, y)

def cdist_sqeuclidean(x, y):
    """Pairwise squared Euclidean distances between entries of x and y.

    Entry [i, j] of the result is ``sqeuclidean_distance(x[i], y[j])``.
    """
    pairwise = distmat(sqeuclidean_distance, x, y)
    return pairwise

def cdist_euclidean(x, y):
    """Pairwise (non-squared) Euclidean distances between entries of x and y.

    Entry [i, j] of the result is ``euclidean_distance(x[i], y[j])``.
    """
    pairwise = distmat(euclidean_distance, x, y)
    return pairwise

def rbf_kernel(X, Y, gamma=1.):
    r"""Radial basis function (Gaussian) kernel matrix.

            k(x,y)=exp(-\gamma*||x-y||**2)
                where \gamma   = 1/(2*\sigma^2)
                      \sigma^2 = 1/(2*\gamma)

    The docstring is a raw string so that the LaTeX-style ``\gamma`` and
    ``\sigma`` are not parsed as (invalid) string escape sequences.

    Args:
        X: first set of points, one point per entry of the leading axis,
            e.g. shape (n, d).
        Y: second set of points, e.g. shape (m, d).
        gamma: inverse-width parameter multiplying the squared distance.

    Returns:
        Kernel matrix whose entry [i, j] is k(X[i], Y[j]), i.e. shape (n, m)
        (and (n, n) when X and Y hold the same number of points).
    """
    return jnp.exp(-gamma * cdist_sqeuclidean(X, Y))

def linear_kernel(X, Y):
    """Linear kernel: the matrix of dot products of X with the transpose of Y."""
    Y_transposed = Y.T
    return jnp.dot(X, Y_transposed)

def estimate_sigma_median(X):
    r"""Estimate the RBF length scale sigma with the median heuristic.

            bandwidth = median of the non-zero pairwise L2 distances in X
            \sigma    = \sqrt(bandwidth / 2)

    The docstring is a raw string so that ``\sigma`` / ``\sqrt`` are not
    parsed as (invalid) string escape sequences.

    Args:
        X: points to measure, one per entry of the leading axis,
            e.g. shape (n, d).

    Returns:
        Scalar sigma.
    """
    D = cdist_euclidean(X, X)
    # Drop exact zeros (the diagonal, plus any duplicated points) so that
    # self-distances do not drag the median down.
    D = D[jnp.nonzero(D)]
    bandwidth = jnp.median(D)
    sigma = jnp.sqrt(bandwidth / 2)
    return sigma

def hsic(X, Y, k, l):
    """Biased empirical HSIC estimate, tr(K H L H) / m**2.

    K is the Gram matrix of kernel ``k`` on the sample X, L is the Gram
    matrix of kernel ``l`` on the sample Y, and H = I - (1/m) 1 1^T is the
    centering matrix.

    Args:
        X: first sample; the leading axis indexes the m observations.
        Y: second sample, paired with X observation by observation.
        k: callable ``k(A, B)`` returning the kernel matrix between the
            observations of A and those of B; applied to (X, X).
        l: callable with the same contract as ``k``; applied to (Y, Y).

    Returns:
        Scalar HSIC statistic.
    """
    # Each Gram matrix is computed *within* one sample. Using k(X, Y) and
    # l(X, Y) instead would mix the two variables and is not HSIC (it also
    # breaks as soon as X and Y have different feature dimensions).
    K = k(X, X)
    L = l(Y, Y)
    m = len(K)
    # Scalar 1/m broadcasts over the identity, giving I - (1/m) 1 1^T.
    H = jnp.eye(m) - 1 / m
    statistic = jnp.trace(K.dot(H).dot(L).dot(H)) / (m**2)
    return statistic

In [None]:
# `time` is used for the per-step timing below but is not imported by the
# notebook's import cell, so bring it into scope here.
import time

# ---- Experiment configuration ------------------------------------------------
n = 100          # number of samples drawn from the Gaussian
d = 2            # input dimension (not referenced below)
lr = 1           # SGD step size
num_steps = 40   # number of full-batch gradient steps
batch_size = 128  # not referenced below; every step uses all of X

# ---- Data: n draws from N(mu, cov) -------------------------------------------
npr.seed(0)
mu = np.array([0, 0])
cov = np.array([[1, 0], [0, 1]])
Xdist = stats.multivariate_normal(mu, cov)
X = Xdist.rvs(size=(n,))
X = jax.device_put(X)

# ---- Kernel bandwidth from the median heuristic on X -------------------------
sigma = estimate_sigma_median(X)
gamma = 1 / (2 * (sigma**2))
print(f'sigma={sigma}')
print(f'gamma={gamma}')


def kernel(X, Y):
    """RBF kernel with the bandwidth `gamma` estimated above."""
    return rbf_kernel(X, Y, gamma=gamma)


# ---- Model: f is a fixed projection, g is the learnable projection -----------
A = jnp.array([[1, 0]], dtype=jnp.float32)
params = {
    'B': jnp.array([[1, 0]], dtype=jnp.float32),
#     'a': random.normal(key, (1,)),
#     'b': random.normal(key, (1,)),
}


def f(params, X):
    """Fixed map A @ x for a single sample x (`params` is unused)."""
    return jnp.dot(A, X)


def g(params, X):
    """Learnable map B @ x for a single sample x."""
    return jnp.dot(params['B'], X)


# Lift the per-sample maps to operate on a whole batch (leading axis of X).
f = jax.vmap(f, (None, 0), 0)
g = jax.vmap(g, (None, 0), 0)


def loss_fn(params, X):
    """HSIC(f(X), g(X)) plus the term -B B^T clipped from below at -1."""
    fX = f(params, X)
    gX = g(params, X)
    # NOTE: Python's built-in `max` needs a concrete comparison result, so
    # this works under `grad` in eager mode but presumably not under `jax.jit`.
    loss = hsic(fX, gX, kernel, kernel) + max(-(params['B']@params['B'].T)[0,0], -1)
    return loss


opt_init, opt_update, get_params = optimizers.sgd(lr)
opt_state = opt_init(params)

# Build the transformed function once; it returns the loss and its gradient
# from a single evaluation instead of evaluating the loss twice per step.
loss_and_grad = jax.value_and_grad(loss_fn)

print("\nStarting training...")
for i in range(num_steps):
    start_time = time.time()
    params = get_params(opt_state)

    # Scatter (f(X), g(X)) for the parameters used in this step.
    fX = f(params, X)
    gX = g(params, X)
    fgX = jnp.hstack([fX, gX])
    plt.scatter(fgX[:,0], fgX[:,1], label=i)

    l, grd = loss_and_grad(params, X)
    opt_state = opt_update(i, grd, opt_state)

    epoch_time = time.time() - start_time

    if i%5 == 0:
        print(f'[{i:3}] time={epoch_time:4.4f}\t loss={l:5.8f}\t A^TB={jnp.dot(A,params["B"].T)} B={params["B"]}')

# Refresh `params` so that it includes the final update; inside the loop it is
# read *before* each step and would otherwise lag one update behind.
params = get_params(opt_state)

# Axis limits are the same for every step, so set them once.
plt.xlim((-5,5))
plt.ylim((-5,5))
plt.legend()
    
    


In [None]:
# Show the B matrix currently held in `params`.
print(params['B'])
# print(params['a'])
# print(params['b'])

In [None]:
# Unclipped value of the -B B^T term for the current B
# (loss_fn bounds this same expression from below at -1).
-(params['B']@params['B'].T)[0,0]