# Compute $\text{KL}(q_{T_\varepsilon} \Vert p)$, the KL div after an SVGD step

In [1]:
from jax import config
config.update("jax_debug_nans", True)

import sys
import os
sys.path.append("/home/lauro/code/msc-thesis/svgd/kernel_learning/")
import json_tricks as json
import copy
from functools import partial

from tqdm import tqdm
import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd, value_and_grad
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp
import jax
import pandas as pd
import haiku as hk
from jax.experimental import optimizers

import config

import utils
import metrics
import time
import plot
import stein
import kernels
import svgd
import distributions
import nets

from jax.experimental import optimizers

key = random.PRNGKey(0)



# KL Divergence

In theory, 
$$\text{KSD}(q \Vert p)^2 \approx \dfrac{d}{d \varepsilon} \text{KL}(q_{T_\varepsilon} \Vert p) \Big \vert_{\varepsilon=0}$$
when $T_\varepsilon(x) = x + \varepsilon \phi^*(x)$.

We can estimate the RHS term by 
* taking one SVGD step $x'_i = x_i + \hat \phi^*(x_i)$, and estimate directly
$$\text{KL}(q_{x} \Vert p(x)) - \text{KL}(q_{x'} \Vert p)$$
    - could compute $q$ by inverting the Jacobian of $T$.t
* estimating $\text{KL}(q \Vert p)$ and differentiating estimate

Recall that
$$
\text{KL}(q_T \Vert p) = \text{KL}(q \Vert p_{T^{-1}}) = E_{x \sim q} [\log q(x) - \log p(T x) - \log \det J_T(x)]
$$

In [2]:
def estimate_kl(logq, logp, samples):
    return np.mean(vmap(logq)(samples) - vmap(logp)(samples))

def pushforward_log(logpdf: callable, tinv: callable):
    """
    Arguments
        logpdf computes log(p(_)), where p is a PDF.
        tinv is the inverse of an injective transformation T: R^d --> R^d, x --> z
    
    Returns
        $\log p_T(z)$, where z = T(x). That is, the pushforward log pdf 
        $$\log p_T(z) = \log p(T^{-1} z) + \log \det(J_{T^{-1} z})$$
    """
    def pushforward_logpdf(z):
        det = np.linalg.det(jacfwd(tinv)(z))
#         if np.abs(det) < 0.001:
#             raise LinalgError("Determinant too small: T is not injective.")
        return logpdf(tinv(z)) + np.log(np.abs(det))
    return pushforward_logpdf

def kl_diff(logq, logp, x, transform):
    """
    Arguments:
        logq: computes log(q(x))
        logp: computes log(p(x))
        x: n samples from q, shape (n, d)
        transform: function T: R^d --> R^d, x --> z
    """
    # KL(q || p)
    kl1 = estimate_kl(logq, logp, x)
    
    # KL(q_T || p) = KL(q || p_{T^{-1}})
    z = vmap(transform)(x)
    logp_pullback = pushforward_log(logp, transform)
    kl2 = estimate_kl(logq, logp_pullback, x)
    return kl1 - kl2

# @partial(jit, static_argnums=(2, 3, 4, 5))
def get_kl_diff_and_ksd(key, eps, proposal, target, kernel, n_samples):
    proposal.threadkey = key
    x = proposal.sample(n_samples)
    print(x.shape)
    def transform(x):
        return x + eps * stein.phistar_i(x, samples, target.logpdf, kernel, aux=False)
    return kl_diff(proposal.logpdf, target.logpdf, x, transform), stein.ksd_squared_u(x, target.logpdf, kernel)

### seems to work

In [3]:
proposal = distributions.Gaussian(0, 1)
target   = distributions.Gaussian(4, 1)
eps = 0.001
trafo = lambda x: x + eps
dkl = kl_diff(proposal.logpdf, target.logpdf, proposal.sample(800), trafo) / eps # estimate of - d/deps KL(q_T_eps || p) 

def steinop(x):
    return stein.stein_operator(lambda x: np.asarray(1.), x, target.logpdf, transposed=False)
xs = proposal.sample(800)
stein_dkl = np.mean(vmap(steinop)(xs))

print(dkl - stein_dkl < 0.1)

True


## KSD test

In [4]:
proposal = distributions.Gaussian(0, 1)
target   = distributions.Gaussian(0, 9)
eps = 0.01

In [8]:
# rbf kernel
kernel = kernels.get_rbf_kernel(bandwidth=1)
ksd = stein.ksd_squared_u(proposal.sample(400), target.logpdf, kernel)
qs = proposal.sample(400)
def trafo(x):
    return x + eps * stein.phistar_i(x, qs, target.logpdf, kernel, aux=False)
dkl = kl_diff(proposal.logpdf, target.logpdf, proposal.sample(800), trafo) / eps
print("KSD =", ksd)
print("d/deps KL =", dkl)
print("absolute difference:", dkl - ksd)
print("relative difference:", (dkl - ksd) / ((dkl + ksd) / 2) )

KSD = 0.14718269
d/deps KL = 0.16075373
absolute difference: 0.013571039
relative difference: 0.08814182


In [6]:
# constant kernel
kernel = kernels.constant_kernel
ksd = stein.ksd_squared_u(proposal.sample(800), target.logpdf, kernel)
qs = proposal.sample(400)
def trafo(x):
    return x + eps * stein.phistar_i(x, qs, target.logpdf, kernel, aux=False)
dkl = kl_diff(proposal.logpdf, target.logpdf, proposal.sample(800), trafo) / eps
print("KSD =", ksd)
print("d/deps KL =", dkl)
print("absolute difference:", dkl - ksd)
print("relative difference:", (dkl - ksd) / ((dkl + ksd) / 2) )

KSD = -6.80155e-06
d/deps KL = 5.9604645e-06
absolute difference: 1.2762015e-05
relative difference: -30.346527
