Fix issues with the new version of JAX (#202)
* adapt to the new custom transforms API

* fix the cumprod issue

* drop the progbar=None case

* fix a bug in cumsum

* clean up unnecessary batching rules
fehiepsi committed Jun 12, 2019
1 parent b6bc4be commit ea87411
Showing 4 changed files with 30 additions and 87 deletions.
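The first and last commit-message bullets track a change in JAX's custom-derivative API. For orientation, here is a small before/after sketch of that migration; the function f and its derivative rule are made up for illustration, and custom_transforms/defjvp are the mid-2019 JAX APIs this commit targets (since superseded by jax.custom_jvp), so this is a sketch rather than current usage.

# Illustrative sketch of the API migration performed in this commit (mid-2019 JAX APIs).
import jax.numpy as np
from jax import custom_transforms, defjvp

@custom_transforms
def f(x):
    return np.sin(x) * x

# Old style (removed in this commit for xlogy, xlog1py, cumsum, ...): register a JVP
# on the underlying primitive and hand-write a batching rule so vmap works:
#   ad.defjvp(f.primitive, lambda g, x: g * (np.cos(x) * x + np.sin(x)))
#   batching.primitive_batchers[f.primitive] = _f_batching_rule
#
# New style: jax.defjvp acts on the custom_transforms function itself, the rule also
# receives the primal output `ans`, and vmap works without a hand-written rule.
defjvp(f, lambda g, ans, x: g * (np.cos(x) * x + np.sin(x)))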
96 changes: 25 additions & 71 deletions numpyro/distributions/util.py
@@ -6,8 +6,8 @@
 import scipy.special as osp_special
 
 import jax.numpy as np
-from jax import canonicalize_dtype, custom_transforms, device_get, jit, lax, random, vmap
-from jax.interpreters import ad, batching
+from jax import canonicalize_dtype, core, custom_transforms, defjvp, device_get, jit, lax, random, vmap
+from jax.interpreters import ad, batching, partial_eval, xla
 from jax.lib import xla_bridge
 from jax.numpy.lax_numpy import _promote_args_like
 from jax.scipy.linalg import solve_triangular
@@ -56,10 +56,7 @@ def _next_kxv(kxv):
 
 # TODO: use upstream implementation when available because it is 2x faster
 def _standard_gamma_impl(key, alpha):
-    if key.ndim > 1:
-        keys = vmap(lambda k: random.split(k, np.size(alpha[0])))(key)
-    else:
-        keys = random.split(key, alpha.size)
+    keys = random.split(key, alpha.size)
     alphas = np.reshape(alpha, -1)
     keys = np.reshape(keys, (-1, 2))
     samples = vmap(_standard_gamma_one)(keys, alphas)
@@ -174,24 +171,13 @@ def _standard_gamma_grad(sample, alpha):
     return grads.reshape(alpha.shape)
 
 
-def _standard_gamma_batching_rule(batched_args, batch_dims):
-    x, y = batched_args
-    bx, by = batch_dims
-    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
-                if i is not None)
-    x = batching.bdim_at_front(x, bx, size, force_broadcast=True)
-    y = batching.bdim_at_front(y, by, size, force_broadcast=True)
-    return _standard_gamma_p(x, y), 0
-
-
 @custom_transforms
 def _standard_gamma_p(key, alpha):
     return _standard_gamma_impl(key, alpha)
 
 
-ad.defjvp2(_standard_gamma_p.primitive, None,
-           lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))
-batching.primitive_batchers[_standard_gamma_p.primitive] = _standard_gamma_batching_rule
+defjvp(_standard_gamma_p, None,
+       lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))
 
 
 @partial(jit, static_argnums=(2, 3))
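In the new API, the None passed to defjvp(_standard_gamma_p, None, ...) marks the PRNG key slot as non-differentiable, so gradients flow only through alpha via _standard_gamma_grad. A hedged usage sketch under the same mid-2019 JAX/NumPyro assumptions; the alpha values and the direct call to the private _standard_gamma_p are purely for illustration.

# Sketch: differentiate a gamma sample w.r.t. its concentration, not the PRNG key.
import jax.numpy as np
from jax import grad, random
from numpyro.distributions.util import _standard_gamma_p

key = random.PRNGKey(0)
alpha = np.array([0.5, 1.0, 2.0])
# d(sample)/d(alpha) goes through the defjvp rule above; the key slot carries no tangent.
print(grad(lambda a: np.sum(_standard_gamma_p(key, a)))(alpha).shape)  # (3,)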
@@ -307,15 +293,15 @@ def multinomial(key, p, n, shape=()):
     return _multinomial(key, p, n, shape)
 
 
-def _xlogy_jvp_lhs(g, x, y):
+def _xlogy_jvp_lhs(g, ans, x, y):
     shape = lax.broadcast_shapes(np.shape(g), np.shape(y))
     g = np.broadcast_to(g, shape)
     y = np.broadcast_to(y, shape)
     g, y = _promote_args_like(osp_special.xlogy, g, y)
     return lax._safe_mul(g, np.log(y))
 
 
-def _xlogy_jvp_rhs(g, x, y):
+def _xlogy_jvp_rhs(g, ans, x, y):
     shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
     g = np.broadcast_to(g, shape)
     x = np.broadcast_to(x, shape)
@@ -329,72 +315,32 @@ def xlogy(x, y):
     return lax._safe_mul(x, np.log(y))
 
 
-def _xlogy_batching_rule(batched_args, batch_dims):
-    x, y = batched_args
-    bx, by = batch_dims
-    # promote shapes
-    sx, sy = np.shape(x), np.shape(y)
-    nx = len(sx) + int(bx is None)
-    ny = len(sy) + int(by is None)
-    nd = max(nx, ny)
-    x = np.reshape(x, (1,) * (nd - len(sx)) + sx)
-    y = np.reshape(y, (1,) * (nd - len(sy)) + sy)
-    # correct bx, by due to promoting
-    bx = bx + nd - len(sx) if bx is not None else nd - len(sx) - 1
-    by = by + nd - len(sy) if by is not None else nd - len(sy) - 1
-    # move bx, by to front
-    x = batching.move_dim_to_front(x, bx)
-    y = batching.move_dim_to_front(y, by)
-    return xlogy(x, y), 0
-
-
-ad.defjvp(xlogy.primitive, _xlogy_jvp_lhs, _xlogy_jvp_rhs)
-batching.primitive_batchers[xlogy.primitive] = _xlogy_batching_rule
+defjvp(xlogy, _xlogy_jvp_lhs, _xlogy_jvp_rhs)
 
 
-def _xlog1py_jvp_lhs(g, x, y):
+def _xlog1py_jvp_lhs(g, ans, x, y):
     shape = lax.broadcast_shapes(np.shape(g), np.shape(y))
     g = np.broadcast_to(g, shape)
     y = np.broadcast_to(y, shape)
     g, y = _promote_args_like(osp_special.xlog1py, g, y)
     return lax._safe_mul(g, np.log1p(y))
 
 
-def _xlog1py_jvp_rhs(g, x, y):
+def _xlog1py_jvp_rhs(g, ans, x, y):
     shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
     g = np.broadcast_to(g, shape)
     x = np.broadcast_to(x, shape)
     x, y = _promote_args_like(osp_special.xlog1py, x, y)
     return g * lax._safe_mul(x, np.reciprocal(1 + y))
 
 
-def _xlog1py_batching_rule(batched_args, batch_dims):
-    x, y = batched_args
-    bx, by = batch_dims
-    # promote shapes
-    sx, sy = np.shape(x), np.shape(y)
-    nx = len(sx) + int(bx is None)
-    ny = len(sy) + int(by is None)
-    nd = max(nx, ny)
-    x = np.reshape(x, (1,) * (nd - len(sx)) + sx)
-    y = np.reshape(y, (1,) * (nd - len(sy)) + sy)
-    # correct bx, by due to promoting
-    bx = bx + nd - len(sx) if bx is not None else nd - len(sx) - 1
-    by = by + nd - len(sy) if by is not None else nd - len(sy) - 1
-    # move bx, by to front
-    x = batching.move_dim_to_front(x, bx)
-    y = batching.move_dim_to_front(y, by)
-    return xlog1py(x, y), 0
-
-
 @custom_transforms
 def xlog1py(x, y):
     x, y = _promote_args_like(osp_special.xlog1py, x, y)
     return lax._safe_mul(x, np.log1p(y))
 
 
-ad.defjvp(xlog1py.primitive, _xlog1py_jvp_lhs, _xlog1py_jvp_rhs)
-batching.primitive_batchers[xlog1py.primitive] = _xlog1py_batching_rule
+defjvp(xlog1py, _xlog1py_jvp_lhs, _xlog1py_jvp_rhs)
 
 
 def cholesky_inverse(matrix):
@@ -433,19 +379,27 @@ def cumsum(x):
     return np.cumsum(x, axis=-1)
 
 
-ad.defjvp(cumsum.primitive, lambda g, x: np.cumsum(g, axis=-1))
-batching.defvectorized(cumsum.primitive)
+defjvp(cumsum, lambda g, ans, x: np.cumsum(g, axis=-1))
 
 
-@custom_transforms
-def cumprod(x):
+# XXX work around the issue: batching rule for 'reduce_window' not implemented
+# when using @custom_transforms decorator
+def _cumprod_impl(x):
     return np.cumprod(x, axis=-1)
 
 
+cumprod_p = core.Primitive('cumprod')
+cumprod_p.def_impl(_cumprod_impl)
+cumprod_p.def_abstract_eval(partial(partial_eval.abstract_eval_fun, _cumprod_impl))
+xla.translations[cumprod_p] = partial(xla.lower_fun, _cumprod_impl)
 # XXX this implementation does not address the case x=0, hence the result in that case will be nan
 # Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
-ad.defjvp2(cumprod.primitive, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
-batching.defvectorized(cumprod.primitive)
+ad.defjvp2(cumprod_p, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
+batching.defvectorized(cumprod_p)
+
+
+def cumprod(x):
+    return cumprod_p.bind(x)
 
 
 def promote_shapes(*args, shape=()):
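The cumprod hunk above replaces the @custom_transforms wrapper with a hand-rolled core.Primitive because vmap over the custom_transforms version hit the missing 'reduce_window' batching rule noted in the XXX comment. A small usage sketch follows, under the same mid-2019 JAX assumptions; the array values are arbitrary.

# Usage sketch for the cumprod primitive defined above (mid-2019 JAX APIs).
import jax.numpy as np
from jax import grad, vmap
from numpyro.distributions.util import cumprod

x = np.array([1.0, 2.0, 3.0])
print(cumprod(x))                            # [1., 2., 6.]
print(grad(lambda x: cumprod(x).sum())(x))   # uses the np.cumsum(g / x) * ans rule; nan if any x == 0
print(vmap(cumprod)(np.stack([x, x + 1.])))  # batched via batching.defvectorized(cumprod_p)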
9 changes: 2 additions & 7 deletions numpyro/mcmc.py
@@ -241,13 +241,8 @@ def init_kernel(init_params,
         if run_warmup and num_warmup > 0:
             # JIT if progress bar updates not required
             if not progbar:
-                # TODO: keep jit version and remove non-jit version for the next jax release
-                if progbar is None:  # NB: if progbar=None, we jit fori_loop
-                    hmc_state = jit(fori_loop, static_argnums=(2,))(
-                        0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
-                else:
-                    hmc_state = fori_loop(
-                        0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
+                hmc_state = jit(fori_loop, static_argnums=(2,))(
+                    0, num_warmup, lambda *args: sample_kernel(args[1]), hmc_state)
             else:
                 with tqdm.trange(num_warmup, desc='warmup') as t:
                     for i in t:
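With the progbar=None special case gone, the no-progress-bar path always runs the warmup loop as jit(fori_loop, static_argnums=(2,)). Index 2 is body_fun, the third positional argument of fori_loop, which must be marked static because a Python callable cannot be passed as a traced array value. A self-contained sketch of the pattern follows; it calls lax.fori_loop directly and uses an illustrative body function, whereas mcmc.py imports its own fori_loop helper.

# Minimal sketch of the jit(fori_loop, static_argnums=(2,)) pattern used above.
import jax.numpy as np
from jax import jit, lax

def body_fun(i, state):
    return state + i  # illustrative body: accumulate the loop index

result = jit(lax.fori_loop, static_argnums=(2,))(0, 10, body_fun, np.zeros(()))
print(result)  # 0 + 1 + ... + 9 = 45.0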
8 changes: 2 additions & 6 deletions numpyro/util.py
Expand Up @@ -140,12 +140,8 @@ def _body_fn(i, vals):
collection = ops.index_update(collection, i, ravel_fn(val))
return val, collection

# TODO: keep jit version and remove non-jit version for the next jax release
if progbar is None: # NB: if progbar=None, we jit fori_loop
_, collection = jit(fori_loop, static_argnums=(2,))(0, upper, _body_fn,
(init_val, collection))
else:
_, collection = fori_loop(0, upper, _body_fn, (init_val, collection))
_, collection = jit(fori_loop, static_argnums=(2,))(0, upper, _body_fn,
(init_val, collection))
else:
diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
progbar_desc = progbar_opts.pop('progbar_desc', '')
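The same jit(fori_loop, static_argnums=(2,)) change is applied to fori_collect here, where _body_fn writes each iteration's ravelled value into a preallocated buffer with ops.index_update. A self-contained sketch of that collect pattern follows; the shapes and the body computation are illustrative, and ops.index_update is the mid-2019 JAX API (later superseded by x.at[i].set(v)).

# Sketch of the collect-inside-fori_loop pattern that this hunk now always jits.
import jax.numpy as np
from jax import jit, lax, ops

upper = 5
collection = np.zeros((upper, 3))

def _body_fn(i, vals):
    val, collection = vals
    val = val + 1.0                                    # stand-in for body_fun(val)
    collection = ops.index_update(collection, i, val)  # store iteration i's value
    return val, collection

_, collection = jit(lax.fori_loop, static_argnums=(2,))(
    0, upper, _body_fn, (np.zeros(3), collection))
print(collection)  # rows [1., 1., 1.] through [5., 5., 5.]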
4 changes: 1 addition & 3 deletions test/test_mcmc.py
@@ -125,9 +125,7 @@ def model(data):
     true_probs = np.array([0.1, 0.6, 0.3])
     data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
     init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
-    # TODO: having progbar=None just to test if that value works,
-    # change it to False when jit fori_loop issue is resolved
-    samples = mcmc(warmup_steps, num_samples, init_params, constrain_fn=constrain_fn, progbar=None,
+    samples = mcmc(warmup_steps, num_samples, init_params, constrain_fn=constrain_fn, progbar=False,
                    print_summary=False, potential_fn=potential_fn, algo=algo, trajectory_length=1.,
                    dense_mass=dense_mass)
     assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)
