In [14]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random, api
import matplotlib.pyplot as plt
import time

from svgd import SVGD, get_bandwidth

In [2]:
from jax.scipy.stats import norm

@jit
def logp(x):
    """
    Log-density of the standard normal N(0, 1), summed over every entry of x.

    IN: single scalar np array x. alternatively, a one-element array [x] works too
    OUT: scalar logp(x)
    """
    # evaluate the per-entry log-density, then collapse everything into one scalar
    return norm.logpdf(x, loc=0, scale=1).sum().squeeze()

In [5]:
# Experiment settings: n sizes the first axis of the initial sample x0, and the
# iteration count L is chosen so that L * stepsize == 1.
n = 100
stepsize = 0.01
L = int(1 / stepsize)

# SVGD instance for the target logp, capped at L iterations, with the adaptive
# kernel switched on and get_bandwidth supplied as the bandwidth callback.
svgd_adaptive = SVGD(
    logp,
    n_iter_max=L,
    adaptive_kernel=True,
    get_bandwidth=get_bandwidth,
)

In [6]:
# Initial sample: n standard-normal draws (shape (n, 1)) from a fixed PRNG seed,
# shifted by -10 so they start far from the target's mean of 0.
key = random.PRNGKey(0)
x0 = random.normal(key, shape=(n, 1)) - 10



In [8]:
# Initial run; the timing cells below repeat this exact call.
# NOTE(review): presumably this also acts as a warm-up so that any jit compilation
# happens before the timed runs — confirm.
# NOTE(review): bandwidth=0 is passed although the sampler was built with
# adaptive_kernel=True; presumably the value is ignored — verify in svgd.
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0, n_iter=L)

### Compare `jit`-compiled and non-`jit` execution times

In [16]:
%%time
# NOTE(review): there is no block_until_ready() here. JAX dispatches work
# asynchronously, so this cell likely times only the dispatch rather than the
# computation (compare with the blocked run in the next cell) — confirm.
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0, n_iter=L)

CPU times: user 0 ns, sys: 5.95 ms, total: 5.95 ms
Wall time: 3.12 ms


In [22]:
%%time
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0, n_iter=L)
# Wait for the result to be ready so the timer covers the computation itself,
# not just the asynchronous dispatch.
xout = xout.block_until_ready()

CPU times: user 1.14 s, sys: 7.81 ms, total: 1.15 s
Wall time: 591 ms


In [27]:
%%timeit
with api.disable_jit():
    xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0, n_iter=L)

1.59 s ± 94.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [24]:
# Manual wall-clock timing of the run with jit disabled.
# perf_counter() is the clock intended for measuring short intervals; time.time()
# can jump if the system clock is adjusted.
st = time.perf_counter()
with api.disable_jit():
    xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0, n_iter=L)
# Wait for the result before stopping the clock, matching the jitted timing cell.
xout.block_until_ready()
end = time.perf_counter()
print(end - st)

0.8886773586273193


In [26]:
# Manual wall-clock timing of the default (jit-enabled) run.
# perf_counter() is the clock intended for measuring short intervals; time.time()
# can jump if the system clock is adjusted.
st = time.perf_counter()
xout, log = svgd_adaptive.svgd(x0, stepsize, bandwidth=0, n_iter=L)
# Wait for the asynchronously dispatched result before stopping the clock.
xout.block_until_ready()
end = time.perf_counter()
print(end - st)

0.5922408103942871
