In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
from jax import lax

import time

# compare times for jitted svgd: fixed param vs not fixed

In [2]:
from svgd import svgd, fixed_param_svgd

In [3]:
from jax.scipy.stats import norm
from svgd import kernel_param_update_rule

@jit
def logp(x):
    """
    Log-density of a standard normal N(0, 1), summed over every entry of x.

    IN: single scalar np array x. alternatively, [x] works too
    OUT: scalar logp(x)
    """
    elementwise = norm.logpdf(x, loc=0, scale=1)
    return np.squeeze(np.sum(elementwise))

n = 10  # number of particles
stepsize = 0.01
L = int(1 / stepsize)  # number of SVGD steps (L * stepsize is about 1)

# generate data: n initial particles from N(-10, 1), far away from the N(0, 1) target of logp
key = random.PRNGKey(1)
x = random.normal(key, (n, 1)) - 10

# Compute the kernel parameter once from the initial particles and freeze it:
# hfun ignores its argument and always hands back that same value.
kernel_param = kernel_param_update_rule(x)
def hfun(x):
    return kernel_param



First, time a single SVGD run with each variant:

In [4]:
# Time one complete SVGD run using the frozen kernel-parameter rule `hfun`.
# block_until_ready() makes %timeit wait for JAX's asynchronously dispatched result.
# NOTE(review): the recorded "1 loop each" (vs. 1000 loops below) suggests the calibration
# call was slow, presumably due to compilation, so only 7 calls were actually timed.
%timeit svgd(x, logp, stepsize, L, hfun).block_until_ready()

261 µs ± 65.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Same measurement for the variant that takes the fixed kernel_param value directly
# instead of an update rule.
%timeit fixed_param_svgd(x, logp, stepsize, L, kernel_param).block_until_ready()

246 µs ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Now let's see what happens when we sweep over a grid of `kernel_param` values. First, define the parameter search using regular `svgd`:

In [9]:
def get_mse(logp, n, stepsize, L, m, q, kernel_param_update_rule=None,
            target_mean=2/3, target_second_moment=5, target_cos=None):
    """
    Run `svgd` from m seeded initializations and average the squared errors of three
    particle statistics against reference values.

    IN:
    * logp: callable, computes log(p(x)).
    * n: integer, number of particles
    * stepsize: float
    * L: integer, number of SVGD steps
    * m: integer, number of samples (seeds) for averaging MSE
    * q: callable, takes a seed key and n and outputs samples of the initial distribution q0
    * kernel_param_update_rule: callable, takes the updated particles as input and outputs an updated set of kernel parameters.
    If kernel_param constant, set to lambda x: const. Required despite the None default.
    * target_mean: float, reference value for E[x] (default 2/3)
    * target_second_moment: float, reference value for E[x^2] (default 5)
    * target_cos: callable w -> reference value for E[cos(wx)] (default exp(-w^2 / 2))
      NOTE(review): exp(-w^2 / 2) is E[cos(wx)] for N(0, 1), but the mean and second moment
      of N(0, 1) are 0 and 1, not 2/3 and 5. The defaults appear to mix two different targets,
      so pass values that match `logp`.

    OUT:
    * dictionary of MSE values with keys "E[x]", "E[x^2]", "E[cos(wx)]"

    RAISES:
    * TypeError if kernel_param_update_rule is not callable
    """
    if not callable(kernel_param_update_rule):
        raise TypeError("kernel_param_update_rule must be callable, e.g. lambda x: const")
    if target_cos is None:
        target_cos = lambda w: np.exp(-w**2 / 2)

    mse1 = []
    mse2 = []
    mse3 = []
    for seed in range(m):
        # Separate keys for the initial particles and the test frequency w:
        # a JAX PRNG key must not be reused for two different draws.
        x_key, w_key = random.split(random.PRNGKey(seed))
        x = q(x_key, n)

        xout = svgd(x, logp, stepsize, L, kernel_param_update_rule)
        mse1.append((np.mean(xout) - target_mean)**2)
        mse2.append((np.mean(xout**2) - target_second_moment)**2)

        w = random.normal(w_key, (1,))
        mse3.append(np.mean(np.cos(w * xout) - target_cos(w))**2)

    return {
        "E[x]": np.mean(np.array(mse1)),
        "E[x^2]": np.mean(np.array(mse2)),
        "E[cos(wx)]": np.mean(np.array(mse3)),
    }

def default_q(key, n):
    """Default initial distribution q0: n particles from N(-10, 1), as an (n, 1) array."""
    return random.normal(key, shape=(n, 1)) - 10

def kernel_param_search(logp, n, stepsize, L, m, kernel_param_grid, q=default_q):
    """
    For every kernel parameter h in kernel_param_grid, estimate MSEs with `get_mse`
    (regular `svgd` driven by the constant update rule lambda x: h).

    IN:
    * logp: callable, computes log(p(x)).
    * n: integer, number of particles
    * stepsize: float
    * L: integer, number of SVGD steps
    * m: integer, number of samples for computing MSE
    * kernel_param_grid: one-dimensional np array
    * q: callable, takes as argument a seed key and n and outputs samples of the initial distribution q0

    OUT:
    * dictionary with keys "E[x]", "E[x^2]", "E[cos(wx)]"; each value is a list holding
      one MSE per entry of kernel_param_grid, in grid order.
    """
    metrics = ("E[x]", "E[x^2]", "E[cos(wx)]")
    mses = {metric: [] for metric in metrics}

    for h in kernel_param_grid:
        # NOTE(review): each iteration passes a brand-new function object; if `svgd` treats
        # the update rule as a static jit argument, this forces a retrace per grid point.
        mse = get_mse(logp, n, stepsize, L, m, q, lambda x: h)
        # Look results up by key instead of relying on the order of .values().
        for metric in metrics:
            mses[metric].append(mse[metric])

    return mses

In [10]:
def get_mse_fixed_param(logp, n, stepsize, L, m, q, kernel_param,
                        target_mean=2/3, target_second_moment=5, target_cos=None):
    """
    Run `fixed_param_svgd` with one fixed kernel parameter from m seeded initializations
    and average the squared errors of three particle statistics against reference values.

    IN:
    * logp: callable, computes log(p(x)).
    * n: integer, number of particles
    * stepsize: float
    * L: integer, number of SVGD steps
    * m: integer, number of samples (seeds) for averaging MSE
    * q: callable, takes a seed key and n and outputs samples of the initial distribution q0
    * kernel_param: kernel parameter passed unchanged to fixed_param_svgd
    * target_mean: float, reference value for E[x] (default 2/3)
    * target_second_moment: float, reference value for E[x^2] (default 5)
    * target_cos: callable w -> reference value for E[cos(wx)] (default exp(-w^2 / 2))
      NOTE(review): exp(-w^2 / 2) is E[cos(wx)] for N(0, 1), but the mean and second moment
      of N(0, 1) are 0 and 1, not 2/3 and 5. The defaults appear to mix two different targets.

    OUT:
    * dictionary of MSE values with keys "E[x]", "E[x^2]", "E[cos(wx)]"
    """
    if target_cos is None:
        target_cos = lambda w: np.exp(-w**2 / 2)

    mse1 = []
    mse2 = []
    mse3 = []
    for seed in range(m):
        # Separate keys for the initial particles and the test frequency w:
        # a JAX PRNG key must not be reused for two different draws.
        x_key, w_key = random.split(random.PRNGKey(seed))
        x = q(x_key, n)

        xout = fixed_param_svgd(x, logp, stepsize, L, kernel_param)
        mse1.append((np.mean(xout) - target_mean)**2)
        mse2.append((np.mean(xout**2) - target_second_moment)**2)

        w = random.normal(w_key, (1,))
        mse3.append(np.mean(np.cos(w * xout) - target_cos(w))**2)

    return {
        "E[x]": np.mean(np.array(mse1)),
        "E[x^2]": np.mean(np.array(mse2)),
        "E[cos(wx)]": np.mean(np.array(mse3)),
    }

def default_q(key, n):
    """Default initial distribution q0: n particles from N(-10, 1), as an (n, 1) array."""
    return random.normal(key, shape=(n, 1)) - 10

def fixed_kernel_param_search(logp, n, stepsize, L, m, kernel_param_grid, q=default_q):
    """
    For every kernel parameter h in kernel_param_grid, estimate MSEs with
    `get_mse_fixed_param` (fixed_param_svgd with h passed directly as a value).

    IN:
    * logp: callable, computes log(p(x)).
    * n: integer, number of particles
    * stepsize: float
    * L: integer, number of SVGD steps
    * m: integer, number of samples for computing MSE
    * kernel_param_grid: one-dimensional np array
    * q: callable, takes as argument a seed key and n and outputs samples of the initial distribution q0

    OUT:
    * dictionary with keys "E[x]", "E[x^2]", "E[cos(wx)]"; each value is a list holding
      one MSE per entry of kernel_param_grid, in grid order.
    """
    metrics = ("E[x]", "E[x^2]", "E[cos(wx)]")
    mses = {metric: [] for metric in metrics}

    for h in kernel_param_grid:
        mse = get_mse_fixed_param(logp, n, stepsize, L, m, q, h)
        # Look results up by key instead of relying on the order of .values().
        for metric in metrics:
            mses[metric].append(mse[metric])

    return mses

Get times:

In [11]:
# Settings for the kernel-parameter sweep
n = 10            # number of particles
stepsize = 0.01
L = int(1 / stepsize)
m = 10            # seeds averaged per MSE estimate

# 25 log-spaced kernel params from 2^-10 to 2^15 (consecutive exponents differ by 25/24).
# NOTE(review): the original note said "params smaller than 2^10 are generally awful",
# which does not match a grid that starts at 2^-10; confirm the intended lower bound.
kernel_param_grid = np.logspace(-10, 15, num=25, base=2)

def q(key, n):
    return random.normal(key, shape=(n, 1)) - 10

In [12]:
# Time the sweep with the jitted `svgd`. perf_counter is a monotonic, high-resolution
# clock (time.time can jump if the system clock is adjusted).
start_time = time.perf_counter()
mses = kernel_param_search(logp, n, stepsize, L, m, kernel_param_grid, q)
# JAX dispatches work asynchronously: wait for every MSE value before stopping the clock.
for values in mses.values():
    for value in values:
        value.block_until_ready()
elapsed_time = time.perf_counter() - start_time
print(elapsed_time)

34.6034951210022


In [13]:
# Time the sweep with `fixed_param_svgd`, using the same q0 as the sweep above.
start_time = time.perf_counter()
mses_f = fixed_kernel_param_search(logp, n, stepsize, L, m, kernel_param_grid, q=q)
# JAX dispatches work asynchronously: wait for every MSE value before stopping the clock.
for values in mses_f.values():
    for value in values:
        value.block_until_ready()
elapsed_time_f = time.perf_counter() - start_time
print(elapsed_time_f)

2.4158875942230225


Hypothesis: inefficiency comes from recompiling every time for a new `kernel_param`.

Now the question is: did I just add this inefficiency while `jit`ing `svgd`? Or was it already there? Let's check.

```
def old_svgd(x, logp, stepsize, L, kernel_param_update_rule):
...
```

In [14]:
from svgd import update
def old_svgd(x, logp, stepsize, L, kernel_param_update_rule, return_log=False):
    """
    Reference (pre-`jit`) SVGD: a plain Python loop calling `update` once per step.

    IN:
    * x is an np array of shape n x d
    * logp is the log of a differentiable pdf p (callable)
    * stepsize is a float
    * L is an integer (number of iterations)
    * kernel_param_update_rule is a callable that takes the current particles as input and
      outputs the kernel parameter used for the next update step
    * return_log: bool, default False. If True, the per-step log is returned as well.

    OUT:
    * Updated particles x (np array of shape n x d) after L steps of SVGD
    * only if return_log: dictionary of np arrays with one entry per step for
      "kernel_param", "particle_mean" (mean over particles) and "particle_var"
    """
    assert x.ndim == 2

    log = {
        "kernel_param": [],
        "particle_mean": [],
        "particle_var": []
    }

    for _ in range(L):
        kernel_param = kernel_param_update_rule(x)
        x = update(x, logp, stepsize, kernel_param)

        # The log is recorded even when it is not returned, so the cost of this reference
        # implementation stays the same as in the timings recorded in this notebook.
        log["kernel_param"].append(kernel_param)
        log["particle_mean"].append(np.mean(x, axis=0))
        log["particle_var"].append(np.var(x, axis=0))

    log = {name: np.array(values) for name, values in log.items()}

    if return_log:
        return x, log
    return x

In [15]:
# Time the pre-jit reference implementation on the same inputs as the svgd timing above.
%timeit old_svgd(x, logp, stepsize, L, hfun).block_until_ready()

149 ms ± 7.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


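So the old SVGD seems to be roughly 600 times slower (149 ms vs. ~250 µs per run), not counting recompiles. Note that `%timeit` repeats the same call, so any compilation happens only on the first call and is essentially excluded from these numbers. Let's see what happens if we define an `old_kernel_param_search`: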

In [16]:
def get_mse_old(logp, n, stepsize, L, m, q, kernel_param_fun,
                target_mean=2/3, target_second_moment=5, target_cos=None):
    """
    Run the reference `old_svgd` from m seeded initializations and average the squared
    errors of three particle statistics against reference values.

    IN:
    * logp: callable, computes log(p(x)).
    * n: integer, number of particles
    * stepsize: float
    * L: integer, number of SVGD steps
    * m: integer, number of samples (seeds) for averaging MSE
    * q: callable, takes a seed key and n and outputs samples of the initial distribution q0
    * kernel_param_fun: callable, maps the current particles to a kernel parameter
    * target_mean: float, reference value for E[x] (default 2/3)
    * target_second_moment: float, reference value for E[x^2] (default 5)
    * target_cos: callable w -> reference value for E[cos(wx)] (default exp(-w^2 / 2))
      NOTE(review): exp(-w^2 / 2) is E[cos(wx)] for N(0, 1), but the mean and second moment
      of N(0, 1) are 0 and 1, not 2/3 and 5. The defaults appear to mix two different targets.

    OUT:
    * dictionary of MSE values with keys "E[x]", "E[x^2]", "E[cos(wx)]"
    """
    if target_cos is None:
        target_cos = lambda w: np.exp(-w**2 / 2)

    mse1 = []
    mse2 = []
    mse3 = []
    for seed in range(m):
        # Separate keys for the initial particles and the test frequency w:
        # a JAX PRNG key must not be reused for two different draws.
        x_key, w_key = random.split(random.PRNGKey(seed))
        x = q(x_key, n)

        xout = old_svgd(x, logp, stepsize, L, kernel_param_fun)
        mse1.append((np.mean(xout) - target_mean)**2)
        mse2.append((np.mean(xout**2) - target_second_moment)**2)

        w = random.normal(w_key, (1,))
        mse3.append(np.mean(np.cos(w * xout) - target_cos(w))**2)

    return {
        "E[x]": np.mean(np.array(mse1)),
        "E[x^2]": np.mean(np.array(mse2)),
        "E[cos(wx)]": np.mean(np.array(mse3)),
    }

def default_q(key, n):
    """Default initial distribution q0: n particles from N(-10, 1), as an (n, 1) array."""
    return random.normal(key, shape=(n, 1)) - 10

def old_kernel_param_search(logp, n, stepsize, L, m, kernel_param_grid, q=default_q):
    """
    For every kernel parameter h in kernel_param_grid, estimate MSEs with `get_mse_old`
    (the reference `old_svgd` driven by the constant update rule lambda x: h).

    IN:
    * logp: callable, computes log(p(x)).
    * n: integer, number of particles
    * stepsize: float
    * L: integer, number of SVGD steps
    * m: integer, number of samples for computing MSE
    * kernel_param_grid: one-dimensional np array
    * q: callable, takes as argument a seed key and n and outputs samples of the initial distribution q0

    OUT:
    * dictionary with keys "E[x]", "E[x^2]", "E[cos(wx)]"; each value is a list holding
      one MSE per entry of kernel_param_grid, in grid order.
    """
    metrics = ("E[x]", "E[x^2]", "E[cos(wx)]")
    mses = {metric: [] for metric in metrics}

    for h in kernel_param_grid:
        mse = get_mse_old(logp, n, stepsize, L, m, q, lambda x: h)
        # Look results up by key instead of relying on the order of .values().
        for metric in metrics:
            mses[metric].append(mse[metric])

    return mses

In [17]:
# Time the sweep with the pre-jit reference implementation, using the same q0 as above.
start_time = time.perf_counter()
mses_o = old_kernel_param_search(logp, n, stepsize, L, m, kernel_param_grid, q=q)
# JAX dispatches work asynchronously: wait for every MSE value before stopping the clock.
for values in mses_o.values():
    for value in values:
        value.block_until_ready()
elapsed_time_o = time.perf_counter() - start_time
print(elapsed_time_o)

42.18507409095764


In the run recorded above this took 42.2 s, i.e. slightly *longer* than the 34.6 s of the jitted `svgd` version that recompiles at every new kernel param (an earlier note said "around a third", which does not match these numbers). The expectation was that `old_svgd` only recompiles every `m` runs, when a new kernel param is tried. So hypothesis: `elapsed_time` / `elapsed_time_o` is roughly equal to `m`.

In [18]:
# m on one line, then the speed ratio (jitted svgd sweep time / old_svgd sweep time).
print(m, elapsed_time / elapsed_time_o, sep="\n")

10
0.8202781639399668


Damn, hypothesis disconfirmed: the ratio is about 0.82, nowhere near `m` = 10 (random seed `PRNGKey(1)`).

# another `fori_loop` test

In [None]:
def update_fun(i, lis):
    """
    Slide a fixed-length window: drop the first element of `lis` and append `i`.

    Returns a new list instead of mutating `lis` in place, because `lax.fori_loop`
    expects a pure body function of (i, carry).
    """
    return lis[1:] + [i]

lax.fori_loop(0, 10, update_fun, [1, 2, 3])

In [None]:
# `python_fori_loop` is only defined in the `lax.scan` section further down, so calling
# it here fails on a fresh kernel; spell out the same plain-Python loop inline instead.
val = [1, 2, 3]
for i in range(0, 10):
    val = update_fun(i, val)
val

# `lax.scan`

Rough Python equivalent of `lax.scan`:

In [None]:
def scan(f, init, xs, length=None):
    """
    Pure-Python sketch of lax.scan: thread a carry through f over xs and stack the
    per-step outputs. If xs is None, f is called `length` times with None as the input.
    """
    inputs = [None] * length if xs is None else xs
    state = init
    outputs = []
    for item in inputs:
        state, out = f(state, item)
        outputs.append(out)
    return state, np.stack(outputs)

Can we define `fori_loop` in terms of `scan`? Code taken from a github issue:

In [None]:
from jax import lax

In [None]:
def differentiable_fori_loop(lower, upper, body_fun, init_val):
    """
    fori_loop built on lax.scan over np.arange(lower, upper), so it supports
    reverse-mode differentiation (lower and upper must be concrete values).
    """
    def step(carry, i):
        # scan expects (new_carry, per-step output); there is no per-step output here.
        return body_fun(i, carry), ()

    final_carry, _ = lax.scan(step, init_val, np.arange(lower, upper))
    return final_carry

In [None]:
def python_fori_loop(lower, upper, body_fun, init_val):
    """Plain-Python reference semantics of lax.fori_loop: carry = body_fun(i, carry) for i in [lower, upper)."""
    carry = init_val
    for step in range(lower, upper):
        carry = body_fun(step, carry)
    return carry

test:

In [None]:
def f(i, x):
    """Loop body used in the comparisons below: square the carry, then add the index."""
    return x**2 + i

In [None]:
lax.fori_loop(0, 3, f, 1)

In [None]:
python_fori_loop(0, 3, f, 1)

In [None]:
differentiable_fori_loop(0, 3, f, 1)

# fori loop

In [None]:
from jax.lax import fori_loop

In [None]:
a = 1

In [None]:
@jit
def test(L):
    return fori_loop(0, L, lambda i, n: n+1, 1)

In [None]:
test(10)

Compare a normal Python for loop. Here we can't `jit` the function: `range(L)` needs a concrete integer, but under `jit`, `L` is an abstract tracer.

In [None]:
# @jit
def test2(L):
    x = 1 # init_val
    for _ in range(L):
        x = x + 1
    return x

In [None]:
test2(10)

### using `grad` with `fori_loop`

In [None]:
def test3(x):
    """output x^4 using fori loop to compute"""
    return fori_loop(0, 2, lambda i, y: y*y, x)

In [None]:
test3(3.)

In [None]:
grad(test3)(3.)

In [None]:
from jax import jacfwd
jacfwd(test3)(3.) # 4*x^3 = 4 * 27 = 108

# gpu vs cpu

In [None]:
from jax.lib import xla_bridge
# Query the backend once; the value is reused in the timing message further below.
backend = xla_bridge.get_backend().platform
print(backend)

In [None]:
key = random.PRNGKey(0)
d = 10000
# Split once into two fresh keys rather than drawing x from `key` and then splitting
# the already-used key: JAX keys should not be reused.
key_x, key_y = random.split(key)
x = random.normal(key_x, shape=(d, d))
y = random.normal(key_y, shape=(d, d))

In [None]:
print(f"Running on {backend}.")
print()
# Product of the two d x d matrices; block_until_ready() makes %timeit wait for the
# asynchronously dispatched result.
%timeit np.matmul(x, y).block_until_ready()

## `np.sort`

In [None]:
x = np.array([[2,3,4],
              [3,2,1]])
print(x.shape)
# Without an axis argument np.sort sorts along the last axis, i.e. within each row
# (the same as axis=1 for this 2-d array).
np.sort(x)

In [None]:
np.sort(x, axis=0)

In [None]:
np.sort(x, axis=1)

# can `jit` take care of this?

In [None]:
# vmap maps np.vdot over the leading axis of both inputs: entry i of the result is
# np.vdot(x[i], y[i]).
batched_vdot = vmap(np.vdot)

In [None]:
def test(x, y):
    """
    do a lot of useless repetitive stuff. return np.vdot(x, y)
    """
    n = 2 * 10**6
    xtiled = np.tile(x, (n, 1))
    ytiled = np.tile(y, (n, 1))
    out = batched_vdot(xtiled, ytiled)
    return out[0]

jt = jit(test)

In [None]:
x = np.array([1,1,1])
y = np.array([1,2,3])

In [None]:
%timeit test(x, y)

%timeit jt(x, y).block_until_ready()

Yup, it can.

In [None]:
def quicktest(x, y):
    """Cheap version of `test`: tile x and y into only 2 rows, return the first row-wise vdot."""
    reps = 2
    stacked_x, stacked_y = np.tile(x, (reps, 1)), np.tile(y, (reps, 1))
    return batched_vdot(stacked_x, stacked_y)[0]

qjt = jit(quicktest)

In [None]:
%timeit quicktest(x, y)

In [None]:
%timeit qjt(x, y).block_until_ready()

## kernel computation

In [None]:
from utils import single_rbf, ard

# Check that `single_rbf` and `ard` agree on a 30 x 30 grid of 1-d points for 5 bandwidths.
# (Presumably ARD with a scalar bandwidth reduces to the plain RBF kernel; confirm in utils.)
# Loop variables get their own names so the globals x, y and h are not clobbered.
for x_val in np.linspace(-10, 10, 30):
    x_pt = np.array([x_val])
    for y_val in np.linspace(-10, 10, 30):
        y_pt = np.array([y_val])
        for bandwidth in np.linspace(1, 100, 5):
            rbf_val = single_rbf(x_pt, y_pt, bandwidth)
            ard_val = ard(x_pt, y_pt, bandwidth)
            # Compare with a tolerance: two implementations of the same formula can
            # legitimately differ in the last floating-point bits.
            assert np.allclose(rbf_val, ard_val), (x_pt, y_pt, bandwidth, rbf_val, ard_val)

## misc jax:

In [None]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

In [None]:
# multiply matrices
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T).block_until_ready()  # runs on JAX's default backend (the GPU only if one is available)

In [None]:
import numpy as onp  # original CPU-backed NumPy
# x is now a host-side NumPy array, so each timed np.dot call also has to transfer it
# to the JAX device before multiplying.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()

In [None]:
from jax import device_put

x = onp.random.normal(size=(size, size)).astype(onp.float32)
# Transfer x to the device once up front, so the timed calls no longer pay for the copy.
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()

Using `jit`

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled ELU: lmbda * x where x > 0, otherwise lmbda * (alpha * exp(x) - alpha)."""
    negative_part = alpha * np.exp(x) - alpha
    return lmbda * np.where(x > 0, x, negative_part)

# One million standard-normal inputs; time selu without jit, where each jnp operation
# is dispatched on its own.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

In [None]:
# The same function compiled with jit (XLA can fuse the element-wise operations).
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

Take derivatives with `grad`:


In [None]:
@jit
def sum_logistic(x):
    """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over all entries of x."""
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.sum(sigmoid)

x = np.arange(3.)
print(sum_logistic(x))
print(grad(sum_logistic)(x))

In [None]:
x = random.normal(key, (10, 3))
batched_sum = vmap(sum_logistic)
batched_sum(x)

In [None]:
def test(x, y):
    return np.sum(x**2 + y**2)

In [None]:
# Draw x and y from two independent keys: `key + 1` is arithmetic on the raw key data,
# not a supported way to derive a new key, whereas random.split is.
key_x, key_y = random.split(key)
x, y = random.normal(key_x, (10, 3)), random.normal(key_y, (10, 3))
print('single argument:', test(x[0], y[0]), '\n')
print('batch output shape:', vmap(test)(x, y).shape)
print('batch output:', vmap(test)(x, y))

In [None]:
np.append(np.array([1,2,3]), 4)

# pytorch distance matrix

In [None]:
import torch

In [None]:
x = torch.rand(10, 2)
x

In [None]:
x.split(1)

In [None]:
row = x.split(1)[0]
row

In [None]:
r_v = row.expand_as(x)
r_v

In [None]:
sq_dist = torch.sum((r_v - x) ** 2, 1)
print(sq_dist.shape)
sq_dist

In [None]:
sq_dist.view(1, -1).shape

In [None]:
def row_pairwise_distances(x, y=None, dist_mat=None):
    """
    Squared Euclidean distances between the rows of x and the rows of y.

    IN:
    * x: tensor of shape (n, d)
    * y: tensor of shape (k, d); defaults to x
    * dist_mat: optional preallocated (n, k) tensor that is filled in and returned

    OUT:
    * (n, k) tensor with entry [i, j] = sum over l of (x[i, l] - y[j, l]) ** 2
    """
    if y is None:
        y = x
    if dist_mat is None:
        # `Variable` is never imported in this notebook (NameError) and is deprecated;
        # allocate a plain tensor with x's dtype and device instead.
        dist_mat = torch.empty(x.size(0), y.size(0), dtype=x.dtype, device=x.device)

    for i, row in enumerate(x.split(1)):
        r_v = row.expand_as(y)
        dist_mat[i] = torch.sum((r_v - y) ** 2, 1)
    return dist_mat

# random stuff

## first question
Does `jit` cache results if it needs them again? That is, does it skip over repeated computations?

In [None]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [None]:
m = 10**4  # NOTE: reuses the name m (earlier the number of MSE seeds)
def h(x):
    """Return x + (0 + 1 + ... + (m - 1)), adding one term per Python-loop iteration."""
    for step in range(m):
        x += step
    return x

In [None]:
h(2)

In [None]:
def f1(x):
    """
    Sum of h(x) + i for i = 0..9, calling h(x) again on every iteration
    (computation of h is repeated needlessly).
    """
    return sum(h(x) + i for i in range(10))

In [None]:
def f2(x):
    """
    Same sum as f1, but h(x) is computed only once and reused on every iteration.
    """
    hx = h(x)
    return sum(hx + i for i in range(10))

In [None]:
s = 4

In [None]:
assert f1(s) == f2(s)

In [None]:
%timeit f1(s)
%timeit f2(s)
%timeit jit(f1)(s).block_until_ready()
%timeit jit(f2)(s).block_until_ready()

## next question:
does `lax.fori_loop` compile more quickly?

In [None]:
from jax import lax

In [None]:
m = 10**4
s = 4

In [None]:
def h1(x):
    """
    Return x + (0 + 1 + ... + (m - 1)) with a plain Python loop; under jit every
    iteration is traced (unrolled) separately.
    """
    def body(i, val):
        return val + i
    for step in range(m):
        x = body(step, x)
    return x

In [None]:
def h2(x):
    """Same computation as h1, written as a single lax.fori_loop so the body is traced once."""
    step = lambda i, val: val + i
    return lax.fori_loop(0, m, step, init_val=x)

In [None]:
jit(h1)(s)

In [None]:
jit(h2)(s)

Yes, it does!

## Next question:
when we use `lax.fori_loop`, do we get the same speedup for repeated computations?

In [None]:
assert h(s) == h1(s)

In [None]:
# now we use the lax fori loop h1
def f1_lax(x):
    """
    computation of h is repeated needlessly
    """
    out = 0
#     for i in range(10):
#         out += h1(x) + i
        
    out = lax.fori_loop(0, 10, lambda i, val: val + h1(x) + i, init_val=out)
    return out

def f2_lax(x):
    """
    h(x) is computed only once
    """
    out = 0
    hx = h1(x)
#     for i in range(10):
#         out += hx + i
        
    out = lax.fori_loop(0, 10, lambda i, val: val + hx + i, init_val=out)
    return out

In [None]:
assert f1_lax(s) == f1(s)
assert f2_lax(s) == f2(s)
assert f1(s) == f2(s)

In [None]:
# still compiles fast:
jit(f1_lax)(s)

In [None]:
%timeit f1_lax(s)
%timeit f2_lax(s)
%timeit jit(f1_lax)(s).block_until_ready()
%timeit jit(f2_lax)(s).block_until_ready()

we do indeed have a speedup.

## How to plot 3D

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter

In [None]:
# Make data.
X = np.linspace(-5, 5, 50)
Y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(X, Y)  # both shape (50, 50)

Z = X**2 + Y**2

# plot
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments (removed in Matplotlib 3.6);
# add_subplot(projection='3d') is the supported way to create 3D axes.
ax = fig.add_subplot(projection='3d')

surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)