In [1]:
import jax.numpy as np
from jax import grad, jit, vmap, random
from jax import lax
from jax.ops import index_update, index
from utils import squared_distance_matrix, ard

from jax import jacfwd

# new svgd update

In [3]:
def ard_matrix(x, bandwidth):
    """
    Kernel matrix of an ARD (per-dimension bandwidth) Gaussian kernel.

    Arguments:
    * x, np array of shape (n, d)
    * bandwidth, np array of shape (d,), or a scalar shared by all dimensions

    Returns:
    * np array of shape (n, n) whose (i, j) entry is
      exp(-sum_l D_l[i, j] / h_l^2 / 2), with D_l = squared_distance_matrix(x[:, l]).
    """
    h = np.array(bandwidth)
    # one (n, n) squared-distance matrix per feature column, stacked to (d, n, n)
    per_dim_sq = vmap(squared_distance_matrix, in_axes=1)(x)
    has_per_dim_bandwidth = h.ndim > 0 and h.shape[0] > 1
    if has_per_dim_bandwidth:
        # (d,) -> (d, 1, 1) so each bandwidth scales its own slice of the stack
        h = h[:, np.newaxis, np.newaxis]
    exponent = np.sum(- per_dim_sq / h**2 / 2, axis=0)
    return np.exp(exponent)

In [4]:
def update_new(x, logp, stepsize, bandwidth):
    """
    One vectorized SVGD step with the ARD kernel `ard_matrix`.

    Arguments:
    * x, np array of shape (n, d): current particles
    * logp, callable mapping an (n, d) array to a scalar. Its gradient w.r.t. x is used
      as the per-particle score grad log p(x_i); this assumes logp sums over particles.
    * stepsize, float
    * bandwidth, kernel bandwidth passed to `ard_matrix` (shape (d,) or scalar)

    Returns:
    * np array of shape (n, d): x + stepsize * phi(x), where
      phi(x_i) = sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]
    """
    km = lambda x: ard_matrix(x, bandwidth)
    kxy = km(x)  # (n, n), symmetric
    dlogp = grad(logp)(x)  # (n, d)

    # Repulsive term r_i = sum_j dK[i, j] / dx_j.
    # K[i, j] is symmetric and depends on x_i, x_j only through per-dimension squared
    # distances, i.e. through x_i - x_j. Hence
    #   d/dx_i sum_{a,b} K[a, b] = 2 sum_j dK[i, j]/dx_i = -2 r_i.
    # Differentiating this scalar sum needs O(n^2 d) memory, instead of the
    # (n, n, n, d) Jacobian that jacfwd(km)(x) would materialize.
    repulsion = -0.5 * grad(lambda y: np.sum(km(y)))(x)  # (n, d)

    return x + stepsize * (np.einsum("il,ij->jl", dlogp, kxy) + repulsion)

In [16]:
rkey = random.PRNGKey(0)
# independent subkeys: drawing x and bandwidth from the same key gives correlated samples
x_key, bw_key = random.split(rkey)
x = random.normal(x_key, shape=(100, 1))
stepsize = 0.01
# the kernel only uses bandwidth**2; abs keeps it readable as a positive length scale
bandwidth = np.abs(random.normal(bw_key, shape=(1,)))

from jax.scipy.stats import norm

@jit
def logp(x):
    """
    Standard-normal log-density, summed over every entry of x.

    IN: np array x of any shape, e.g. one sample of shape (d,) or particles of shape (n, d)
    OUT: scalar sum_k log N(x_k; 0, 1). Because the terms are independent, grad(logp)(x)
    is the elementwise score -x, which the SVGD updates use as per-particle gradients.
    """
    return np.sum(norm.logpdf(x, loc=0, scale=1))

In [17]:
# `update` (the per-particle reference implementation) is defined in the "new update rule"
# section further down; run that cell first. It expects a two-argument kernel k(x, y),
# so close over the bandwidth instead of passing the bandwidth array itself.
np.all(np.abs(update_new(x, logp, stepsize, bandwidth)
              - update(x, logp, stepsize, lambda a, b: ard(a, b, bandwidth))) < 0.0001)

DeviceArray(True, dtype=bool)

### Checking `dkxy` against a manual double loop

In [None]:
# x can hold many particles (100 in the setup cell): show its shape and the first rows only
x.shape, x[:5]

In [None]:
import numpy as onp

# Reference values from an explicit double loop over particle pairs, using the scalar kernel:
# test[i, j, l] = d k(x_i, x_j) / d x_j[l]. The shape follows x instead of being hard-coded.
n_particles, n_dims = x.shape
dk_dxj = grad(ard, argnums=1)  # gradient w.r.t. the second argument
test = onp.zeros((n_particles, n_particles, n_dims))
for i, xi in enumerate(x):
    for j, xj in enumerate(x):
        test[i, j] = dk_dxj(xi, xj, bandwidth)

In [None]:
# Same quantity, vectorized: differentiate the whole kernel matrix and keep dK[i, j] / dx[j, :].
# jacfwd(km)(x) has shape (n, n, n, d); taking the diagonal over axes 1 and 2 moves that
# axis to the end, so dkxy has shape (n, d, n).
km = lambda pts: ard_matrix(pts, bandwidth)
kxy = km(x)
dkxy = np.diagonal(jacfwd(km)(x), axis1=1, axis2=2)

In [None]:
ax = 0
# The two sides come from different computation paths (per-pair grads vs. a full Jacobian),
# so compare with a floating-point tolerance instead of exact ==.
np.allclose(dkxy[:, ax, :], test[:, :, ax])

In [None]:
# Remaining mismatch: the layouts differ, dkxy is (n, d, n) while test is (n, n, d).
print(dkxy.shape, test.shape, sep="\n")

In [None]:
# dkxy is (n, d, n) and test is (n, n, d). Swapping the last two axes aligns them;
# a reshape to (n, n, d) would only reinterpret the flat buffer and mix up entries.
onp.transpose(onp.array(dkxy), (0, 2, 1))

In [None]:
# reference gradients from the explicit double loop; layout (n, n, d)
test

In [None]:
# sum over j: for each particle i, sum_j d k(x_i, x_j) / d x_j
test.sum(axis=1)

In [None]:
# Repulsion sum for particle 0 from first-argument gradients, over every particle in x
# (not only the first three): sum_j d k(x_j, x_0) / d x_j. For a symmetric kernel this
# should match row 0 of `test` summed over j.
sum(grad(ard)(xj, x[0], bandwidth) for xj in x)

In [None]:
# vectorized counterpart: sum dK[i, j] / dx_j over j (the last axis of dkxy)
dkxy.sum(axis=2)

## Efficiently traceable `squared_distance_matrix`

In [None]:
@jit
def squared_distance_matrix(x):
    """
    Pairwise squared Euclidean distances, traceable under `jit` and `vmap`.

    Arguments:
    * x, np array of shape (n, d), or shape (n,) which is treated as (n, 1)

    Returns:
    * np array of shape (n, n) with entry [i, j] = ||x_i - x_j||^2.

    Note: this redefines (shadows) the `squared_distance_matrix` imported from `utils`,
    so `ard_matrix` uses this version once this cell has run.
    """
    if x.ndim == 1:
        x = np.reshape(x, (x.shape[0], 1))
    # diff[i, j, :] = x_j - x_i via broadcasting; avoids building tiled (n, n, d) copies of x
    diff = x[np.newaxis, :, :] - x[:, np.newaxis, :]
    # squared norm over the feature axis (previously an undefined helper `vv_normsq`)
    return np.sum(diff ** 2, axis=-1)

In [None]:
x = np.arange(1, 4)  # the points 1, 2, 3 on the real line
squared_distance_matrix(x)

In [None]:
n, d = 10, 3
key = random.PRNGKey(0)
# n uniform samples in the d-dimensional unit cube
x = random.uniform(key, shape=(n, d))

In [None]:
# column vector of shape (3, 1). The shape is passed positionally because the
# `newshape` keyword of reshape is deprecated.
x = np.reshape(np.array([1, 2, 3]), (3, 1))

In [None]:
from scipy.spatial.distance import pdist, squareform

In [None]:
# absolute error: a signed difference would accept arbitrarily large negative deviations
np.all(np.abs(squareform(pdist(x)**2) - squared_distance_matrix(x)) < 0.01)

## measuring KL divergence
$$ D_{KL}(q \Vert p) = E_{x \sim q}\big[\log \frac{q(x)}{p(x)}\big] $$

In [None]:
def kl(x, p):
    """
    KL divergence D_KL(q || p) between the sample distribution q of x and p (not implemented yet).

    IN:
    * x is an np array of shape (n, d) representing n samples of a variable in R^d
    * p is a callable that computes a pdf
    OUT:
    the KL-divergence between the empirical distribution of x and the distribution p.

    TODO: the empirical distribution of x has no density, so q(x) must be estimated
    (e.g. with a kernel density estimate) before E_q[log q(x) / p(x)] can be computed.
    Raises NotImplementedError until then, rather than silently returning None.
    """
    raise NotImplementedError("kl is not implemented yet")

def kernelized_stein_discrepancy(x, p, kernel):
    """
    Kernelized Stein discrepancy between the samples x and p (not implemented yet).

    IN:
    * x is an np array of shape (n, d) representing n samples of a variable in R^d
    * p is a callable that computes a pdf
    * kernel is the kernel k(x, y) that defines the discrepancy (expected signature TBD)
    OUT:
    the stein discrepancy between the empirical distribution of x and the distribution p.

    Raises NotImplementedError until implemented, rather than silently returning None.
    """
    raise NotImplementedError("kernelized_stein_discrepancy is not implemented yet")

## Updating slices of JAX arrays

In [None]:
z = np.zeros(shape=(3, 3))  # a 3 x 3 array of zeros to experiment with
z

In [None]:
# z[1, :] = 1  # doesn't work: JAX arrays are immutable, so item assignment raises a TypeError

In [None]:
# functional update: returns a new array with row 1 set to 1 instead of modifying z
znew = index_update(z, index[1], 1)
znew

As I understand it, `index_update` always returns a new array. However, if we write `znew = index_update(z, ...)` inside a `jit`-compiled function and `z` is not used again afterwards, the compiler can reuse `z`'s buffer. The update then happens "in place" and takes no extra memory.

# distance matrix

In [None]:
def old_pairwise_distances(x):
    """
    Dense Euclidean distance matrix, built one row at a time (reference implementation).

    IN: n x d array: n observations of d-dimensional samples
        (a 1-d array of length n is treated as n observations with d = 1)
    OUT: symmetric n x n matrix of distances (d(x_i, x_j))_ij, where d is the Euclidean distance
    """
    if x.ndim == 1:
        x = np.expand_dims(x, axis=1) # x is n x 1 matrix now
    assert x.ndim == 2
    n = x.shape[0]
    if n == 0:
        return np.zeros((0, 0))

    # row i holds the squared distances ||x_i - x_j||^2 for j = 1, ..., n
    # (xi broadcasts against every row of x; replaces the undefined helper `batched_normsq`)
    rows = [np.sum((xi - x) ** 2, axis=1) for xi in x]
    return np.sqrt(np.stack(rows, axis=0))


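More efficiently, store only the upper triangle (the condensed form with $(n^2 - n)/2$ entries) and rebuild the full matrix only when it is needed: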

In [None]:
def pairwise_distances(x):
    """
    Condensed pairwise Euclidean distances (upper triangle, row-major order).

    IN: n x d array: n observations of d-dimensional samples
    OUT: np array of shape (l,) where l = (n^2 - n) / 2
    Consists of distances d(x1, x2), d(x1, x3), ..., d(xn-1, xn)
    """
    assert x.ndim == 2
    n = x.shape[0]
    # (rows[k], cols[k]) enumerates the pairs i < j in the same order as listed above
    rows, cols = np.triu_indices(n, k=1)
    diff = x[rows] - x[cols]
    return np.sqrt(np.sum(diff ** 2, axis=1))

def getn(l):
    """
    Invert l = (n^2 - n) / 2, the length of a condensed distance vector.

    IN: l, non-negative integer number of pairs
    OUT: n, the positive integer with n * (n - 1) / 2 == l
    Raises ValueError if no such integer n exists.
    """
    if l < 0:
        raise ValueError(f"l must be non-negative, got {l}")
    # Python float math (53-bit) rather than float32 arrays, then an exact integer check,
    # so large l neither loses precision nor relies on a float modulo test
    n = int(round((1 + (1 + 8 * l) ** 0.5) / 2))
    if n * (n - 1) // 2 != l:
        raise ValueError(f"{l} is not of the form (n^2 - n) / 2")
    return n

def get_distance_matrix(distances):
    """
    Expand a condensed distance vector back into a square matrix.

    IN: output from `pairwise_distances`, an array of length l = (n^2 - n) / 2
    OUT: a symmetric n x n distance matrix with entries d(x_i, x_j) and zeros on the diagonal
    """
    l = distances.shape[0]
    n = getn(l)
    # triu_indices(n, k=1) is row-major, matching the order d(x1, x2), d(x1, x3), ...
    upper = index_update(np.zeros((n, n)), index[np.triu_indices(n, k=1)], distances)
    return upper + upper.T

In [None]:
x = np.array([
    [1, 2, 3],
    [1, 1, 1],  # repeated row: exercises a zero off-diagonal distance
    [1, 1, 1],
    [2, 3, 4],
])
x

In [None]:
# Compare the condensed round trip with the dense reference using a float tolerance,
# reduced to a single bool instead of an elementwise == matrix.
np.allclose(get_distance_matrix(pairwise_distances(x)), old_pairwise_distances(x))

## new update rule

In [None]:
def phi_j(x, y, logp, kernel):
    r"""
    One summand of the SVGD update direction.

    IN:
    x and y are arrays of length d
    kernel is a function that takes two arguments, k(x, y)
    logp is the log of a differentiable pdf p

    OUT:
    \nabla_x log(p(x)) * k(x, y) + \nabla_x k(x, y)

    that is, phi(x_i) = \sum_j phi_j(x_j, x_i)
    """
    # raw docstring above: in a normal string "\nabla" would become a newline + "abla"
    assert x.ndim == 1 and y.ndim == 1
    return grad(logp)(x) * kernel(x, y) + grad(kernel)(x, y)

# maps over rows of the first two arguments; logp and kernel are shared (in_axes None)
phi_j_batched = vmap(phi_j, (0, 0, None, None), 0)

def update(x, logp, stepsize, kernel):
    r"""
    One SVGD step, computed particle by particle (reference implementation).

    IN:
    x is an np array of shape n x d
    logp is the log of a differentiable pdf p
    stepsize is a float
    kernel is a differentiable function k(x, y) of two length-d arrays (it is called as
    kernel(x, y) by phi_j). For a kernel with extra parameters such as a bandwidth h,
    close over them: lambda x, y: k(x, y, h)

    OUT:
    xnew = x + stepsize * \phi^*(x), with \phi^*(x_i) = \sum_j phi_j(x_j, x_i)
    that is, xnew is an array of shape n x d. The entries of x are the updated particles.

    note that this is an inefficient way to do things, since we're computing k(x, y) twice for each x, y combination.
    """
    assert x.ndim == 2

    n = x.shape[0]
    steps = []
    for xi in x:
        # pair every particle x_j with x_i and sum the n contributions phi_j(x_j, x_i)
        repeated = np.tile(xi, (n, 1))
        steps.append(stepsize * np.sum(phi_j_batched(x, repeated, logp, kernel), axis=0))
    return np.array(steps) + x

In [None]:
def rbf_old(x, h):
    r"""
    [Not used]
    x is a n x d matrix (n observations of features with dimension d)
    h is a scalar parameter

    OUT:
    a n x n "distance" matrix [k(x_i, x_j)]_{i, j \in 1, ..., n}, or None if x has no rows.

    Relies on `batched_rbf`, which is not defined in this notebook; presumably
    batched_rbf(a, b, h)[j] = k(a[j], b[j]) -- TODO confirm.
    """
    if x.ndim == 1:
        x = np.expand_dims(x, axis=1)
    assert x.ndim == 2
    n = x.shape[0]

    # Collect the rows and stack once at the end. Concatenating inside the loop copies
    # the growing matrix on every iteration, which is quadratic work.
    rows = []
    for xi in x:
        repeated = np.tile(xi, (n, 1))
        rows.append(batched_rbf(repeated, x, h))  # row i: k(x_i, x_j) for j = 1, ..., n
    if not rows:
        return None  # empty x: nothing to stack
    return np.stack(rows, axis=0)
