In [1]:
import math

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
# NOTE: jax.ops.index_update / jax.ops.index were deprecated (JAX 0.2.22) and have since been removed;
# on current JAX this import fails -- the replacement is x.at[idx].set(y).
from jax.ops import index_update, index

from utils import batched_normsq

## measuring KL divergence
$$ D_{KL}(q \Vert p) = E_{x \sim q}\big[\log \frac{q(x)}{p(x)}\big] $$

In [None]:
def kl(x, p):
    """
    KL divergence D_KL(q || p) between the empirical distribution q of the samples x and p.

    IN:
    * x is an np array of shape (n, d) representing n samples of a variable in R^d
    * p is a callable that computes a pdf
    OUT:
    the KL-divergence between the empirical distribution of x and the distribution p.

    TODO: not implemented yet. The empirical distribution of x is a sum of point masses,
    so it has no density relative to a continuous p; a finite estimate presumably needs
    a density estimate of q first.
    """
    # Fail loudly instead of returning None, which would silently propagate into later cells.
    raise NotImplementedError("kl is not implemented yet")

def kernelized_stein_discrepancy(x, p, kernel):
    """
    Kernelized Stein discrepancy between the empirical distribution of x and p.

    IN:
    * x is an np array of shape (n, d) representing n samples of a variable in R^d
    * p is a callable that computes a pdf
    * kernel: presumably a function k(x, y) of two points in R^d, like the kernel
      passed to `update` -- TODO confirm once implemented
    OUT:
    the stein discrepancy between the empirical distribution of x and the distribution p.

    TODO: not implemented yet.
    """
    # Fail loudly instead of implicitly returning None.
    raise NotImplementedError("kernelized_stein_discrepancy is not implemented yet")

## updating slices of (immutable) JAX arrays

In [2]:
# a small all-zeros array to experiment with slice updates
z = np.zeros(shape=(3, 3))
z



DeviceArray([[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]], dtype=float32)

In [3]:
# In-place item assignment fails because JAX arrays are immutable; show the error instead of hiding it.
try:
    z[1, :] = 1
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

In [4]:
# .at[...].set(...) returns an updated copy; z itself is left unchanged
znew = z.at[1, :].set(1)
znew

DeviceArray([[0., 0., 0.],
             [1., 1., 1.],
             [0., 0., 0.]], dtype=float32)

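JAX arrays are immutable, so `index_update(z, ...)` returns a new array rather than modifying `z`. As I understand it, if we do `znew = index_update(z, ...)` inside a `jit`-compiled function and `z` isn't used again, XLA can perform the update "in-place", so it doesn't take up more memory. Note: `jax.ops.index_update` / `jax.ops.index` have since been removed from JAX; the current equivalent is `znew = z.at[1, :].set(1)`.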

## distance matrix

In [5]:
def old_pairwise_distances(x):
    """
    Full pairwise Euclidean distance matrix, built one row at a time.

    IN: n x d array: n observations of d-dimensional samples
        (a 1-D array is treated as n observations with d = 1)
    OUT: symmetric n x n distance matrix of distances (d(x_i, x_j))_ij. Here d = euclidean distance
    """
    if x.ndim == 1:
        x = np.expand_dims(x, axis=1)  # x is n x 1 matrix now
    assert x.ndim == 2
    n = x.shape[0]

    out = np.zeros((n, n))
    for i, xi in enumerate(x):
        # xi - x broadcasts to the n x d differences x_i - x_j (no np.tile needed);
        # batched_normsq gives one squared norm per row, i.e. ||x_i - x_j||^2 for j = 1, ..., n.
        # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
        out = out.at[i, :].set(batched_normsq(xi - x))
    return np.sqrt(out)


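More efficiently, compute only the $(n^2 - n)/2$ distances above the diagonal and rebuild the symmetric matrix from them: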

In [6]:
def pairwise_distances(x):
    """
    Condensed pairwise Euclidean distances (the strict upper triangle of the distance matrix).

    IN: n x d array: n observations of d-dimensional samples
        (a 1-D array is treated as n observations with d = 1)
    OUT: np array of shape (l,) where l = (n^2 - n) / 2
    Consists of distances d(x1, x2), d(x1, x3), ..., d(xn-1, xn)
    """
    if x.ndim == 1:
        x = np.expand_dims(x, axis=1)
    assert x.ndim == 2
    n = x.shape[0]

    # np.triu_indices lists the pairs (i, j), i < j, in row-major order
    # (0, 1), (0, 2), ..., (n-2, n-1): the order documented above and the order
    # in which get_distance_matrix scatters the values back.
    rows, cols = np.triu_indices(n, k=1)
    # one vectorized call instead of a Python loop that collects individual scalars
    return np.sqrt(batched_normsq(x[rows] - x[cols]))

def getn(l):
    """
    Invert l = (n^2 - n) / 2, the number of pairs (i, j) with i < j among n points.

    IN: l, a non-negative integer (e.g. the length of the output of `pairwise_distances`)
    OUT: n, the positive integer with (n^2 - n) / 2 == l
    Raises AssertionError if no such integer n exists.

    Solving n^2 - n - 2l = 0 gives n = (1 + sqrt(1 + 8l)) / 2. The square root is taken
    with exact integer arithmetic (math.isqrt), so large l is not corrupted by
    floating-point rounding.
    """
    l = int(l)
    assert l >= 0, f"l must be non-negative, got {l}"
    disc = 1 + 8 * l
    root = math.isqrt(disc)
    # n is an integer exactly when 1 + 8l is a perfect square (the root is then odd)
    assert root * root == disc, f"{l} is not of the form (n^2 - n) / 2"
    return (1 + root) // 2

def get_distance_matrix(distances):
    """
    Expand condensed pairwise distances back into a full symmetric matrix.

    IN: output from `pairwise_distances`, an array of length l = (n^2 - n) / 2
    OUT: a symmetric n x n distance matrix with entries d(x_i, x_j) and zeros on the diagonal
    """
    n = getn(distances.shape[0])
    # same row-major (i < j) order in which pairwise_distances lists the pairs
    upper = np.triu_indices(n, k=1)
    # JAX arrays are immutable: .at[...].set(...) returns a copy with the upper triangle filled in
    out = np.zeros((n, n)).at[upper].set(distances)
    # mirror the upper triangle into the lower one; the diagonal stays 0
    return out + out.T

In [7]:
# small test input; rows 2 and 3 are identical, so their distance is 0
x = np.array([
    [1, 2, 3],
    [1, 1, 1],
    [1, 1, 1],
    [2, 3, 4],
])
x

DeviceArray([[1, 2, 3],
             [1, 1, 1],
             [1, 1, 1],
             [2, 3, 4]], dtype=int32)

In [8]:
new_distances = get_distance_matrix(pairwise_distances(x))
old_distances = old_pairwise_distances(x)
# the two implementations take different code paths, so compare up to float rounding, not with ==
np.isclose(new_distances, old_distances)

DeviceArray([[ True,  True,  True,  True],
             [ True,  True,  True,  True],
             [ True,  True,  True,  True],
             [ True,  True,  True,  True]], dtype=bool)

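## new update rule

Each particle moves as $x_i \leftarrow x_i + \epsilon\, \phi^*(x_i)$, where (as implemented below, summing rather than averaging over the particles)
$$ \phi^*(x_i) = \sum_{j=1}^{n} \big[ k(x_j, x_i)\, \nabla_{x_j} \log p(x_j) + \nabla_{x_j} k(x_j, x_i) \big] $$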

In [9]:
def phi_j(x, y, logp, kernel):
    r"""
    Contribution of one particle x = x_j to the update direction at y = x_i.

    IN:
    x and y are arrays of length d
    kernel is a function that takes two arguments, k(x, y)
    logp is the log of a differentiable pdf p

    OUT:
    \nabla_x log(p(x)) * k(x, y) + \nabla_x k(x, y)

    that is, phi(x_i) = \sum_j phi_j(x_j, x_i)
    """
    assert x.ndim == 1 and y.ndim == 1
    # grad differentiates with respect to the first argument, so grad(kernel) is \nabla_x k(x, y)
    return grad(logp)(x) * kernel(x, y) + grad(kernel)(x, y)

# map phi_j over the rows of x and y together; logp and kernel are shared, not mapped
phi_j_batched = vmap(phi_j, in_axes=(0, 0, None, None), out_axes=0)

def update(x, logp, stepsize, kernel):
    r"""
    One particle update x_i <- x_i + stepsize * \phi^*(x_i), applied to every row of x.

    IN:
    x is an np array of shape n x d
    logp is the log of a differentiable pdf p
    stepsize is a float
    kernel is a differentiable function k(x, y) of two length-d arrays (e.g. an rbf kernel).
    It is called with exactly two arguments (see phi_j), so parameters such as a bandwidth h
    must already be bound, e.g. with a lambda or functools.partial.

    OUT:
    xnew = x + stepsize * \phi^*(x)
    that is, xnew is an array of shape n x d whose rows are the updated particles.
    Here \phi^*(x_i) = \sum_j phi_j(x_j, x_i) is a sum over all n particles (not a mean).

    note that this is an inefficient way to do things, since we're computing k(x, y) twice for each x, y combination.
    """
    assert x.ndim == 2
    n = x.shape[0]
    # phi_j_batched maps over the rows of both arguments, so pair every particle x_j with a copy of x_i
    steps = [
        stepsize * np.sum(phi_j_batched(x, np.tile(xi, (n, 1)), logp, kernel), axis=0)
        for xi in x
    ]
    return x + np.array(steps)

In [10]:
def rbf_old(x, h):
    r"""
    [Not used]
    x is a n x d matrix (n observations of features with dimension d)
    h is a scalar parameter

    OUT:
    a n x n "distance" matrix [k(x_i, x_j)]_{i, j \in 1, ..., n}
    (None if x has no rows)
    """
    # NOTE(review): `batched_rbf` is neither defined nor imported in the visible notebook,
    # so calling this function presumably raises NameError -- TODO confirm, or remove the function.
    if x.ndim == 1:
        x = np.expand_dims(x, axis=1)
    assert x.ndim == 2
    n = x.shape[0]

    # row i holds batched_rbf applied to the pairs (x_i, x_j), j = 1, ..., n; collect the rows
    # and stack once instead of re-concatenating a growing array on every iteration (quadratic)
    rows = [batched_rbf(np.tile(xi, (n, 1)), x, h) for xi in x]
    return np.stack(rows) if rows else None
