In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random, api
import matplotlib.pyplot as plt

from svgd import SVGD, get_bandwidth

In [2]:
from jax.scipy.stats import norm

@jit
def logp(x):
    """Log-density of the standard-normal target, summed over every entry of ``x``.

    Each element of ``x`` is scored under N(0, 1) and the per-element
    log-densities are added together. A scalar, a one-element array, or an
    array of any shape therefore all yield a single 0-d result: the joint
    log-density of i.i.d. standard normals.
    """
    per_entry = norm.logpdf(x, loc=0, scale=1)
    total = np.sum(per_entry)
    return total.squeeze()

In [4]:
n = 100          # number of particles (rows of the initial sample x0)
stepsize = 0.01  # SVGD step size

# Number of iterations, chosen so that n_iter * stepsize is about 1.
L = int(1 / stepsize)

# SVGD sampler targeting `logp`, with an adaptive kernel bandwidth.
svgd_adaptive = SVGD(
    logp,
    n_iter=L,
    adaptive_kernel=True,
    get_bandwidth=get_bandwidth,
)

In [5]:
# Initial particles: n standard-normal draws shifted to be centred at -10,
# well away from the N(0, 1) target defined by `logp`.
key = random.PRNGKey(0)  # fixed seed so the run is reproducible
x0 = -10 + random.normal(key, shape=(n, 1))



In [6]:
# Warm-up run (not timed). If `svgd` is jit-compiled internally (presumably,
# given the timings below -- TODO confirm), this first call absorbs the
# one-off tracing/compilation cost. JAX dispatches work asynchronously, so
# wait for the result; otherwise leftover work leaks into the first timing.
# The trailing `;` stops IPython from displaying the whole particle array.
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0)
xout.block_until_ready();

### Compare `jit`-compiled and non-`jit`-compiled run times

In [6]:
%%timeit
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0)

1min 38s ± 8.7 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0)
xout.block_until_ready()

226 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
with api.disable_jit():
    xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0)

1min 8s ± 7.33 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# Single untimed SVGD run with JIT disabled; its output is left in
# `xout` and `log`.
with api.disable_jit():
    xout, log = svgd_adaptive.svgd(
        x0, stepsize, bandwidth=0
    )