# KSD Bandwidth sweep

In [1]:
import sys
import os
sys.path.append("/home/lauro/code/msc-thesis/svgd/kernel_learning")
import json_tricks as json
import copy

from tqdm import tqdm
import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd, value_and_grad
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp
import jax
import pandas as pd
import haiku as hk
from jax.experimental import optimizers

import config

import utils
import metrics
import time
import plot
import stein
import kernels
import svgd
import distributions

from jax.experimental import optimizers

key = random.PRNGKey(0)



In [2]:
# initialize kernel, proposal dist, and target dist
def get_rbf_fn(bandwidth):
    logh = np.asarray(bandwidth)
    logh = np.squeeze(bandwidth)
    def rbf(x, y): 
        x, y = np.asarray(x), np.asarray(y)
        return np.exp(- np.sum((x - y)**2 / bandwidth**2) / 2)
    return rbf

proposal = distributions.Gaussian([0,0],1)
target = distributions.Funnel(2)

In [3]:
# comparison kernels
def constant(x, y): return np.array(1.)
def null(x, y): return np.array(0.)
def get_tophat_fn(bandwidth):
    def tophat(x, y): return np.squeeze(np.where(np.linalg.norm(x-y)<bandwidth, 1., 0.))
    return tophat

In [7]:
def ksd_sweep(proposal, target, n=200, m=5, grid=2**np.linspace(-5, 5, 25)):
    kernels_list = [get_rbf_fn, get_tophat_fn, kernels.get_funnel_kernel]
    
    @jit
    def get_ksds(p, lists): #TODO make better
        for kernelfn, ksdlist in zip(kernels_list, lists):
            kernel = kernelfn(p)
            ksdlist.append(stein.ksd_squared_u(samples, target.logpdf, kernel, False))
        return lists

    ksds = []
    for _ in range(m):
        grid_ksds = [[], [], []]
        samples = proposal.sample(n)
        for p in tqdm(grid):
            grid_ksds = get_ksds(p, grid_ksds)
        ksds.append(grid_ksds) # (nkernels, m)
        
    grid_means = onp.mean(ksds, axis=1)
    grid_vars = onp.std(ksds, axis=1)
    return list(zip(grid_means, grid_vars))

(grid_means_rbf, grid_var_rbf), (grid_means_top, grid_var_top), (grid_means_fun, grid_var_fun) = ksd_sweep(proposal, target)

100%|██████████| 25/25 [00:43<00:00,  1.72s/it]
100%|██████████| 25/25 [00:00<00:00, 1093.70it/s]
100%|██████████| 25/25 [00:00<00:00, 1099.71it/s]
100%|██████████| 25/25 [00:00<00:00, 1141.75it/s]
100%|██████████| 25/25 [00:00<00:00, 1224.97it/s]


ValueError: too many values to unpack (expected 3)

In [None]:
samples = proposal.sample(400)
ksd_c = stein.ksd_squared_u(samples, target.logpdf, constant, False)

In [None]:
fig, ax = plt.figure(figsize=[7,7])
plt.errorbar(grid, grid_means_rbf, yerr=grid_var_rbf, fmt="--o", label="KSD-U")
plt.errorbar(grid, grid_means_top, yerr=grid_var_top, fmt="--o", label="KSD-U-tophat")
plt.errorbar(grid, grid_means_fun, yerr=grid_var_fun, fmt="--o", label="KSD-U-funnel")

plt.axhline(y=ksd_c, label="one", color="r")
plt.xscale("log")
plt.legend()