Add batching rule for some custom primitives (#178)
* add custom prims batch rule

* rename num_warmup_steps to num_warmup

* fix lint

* adjust local variable name
fehiepsi authored and neerajprad committed May 31, 2019
1 parent c1cd447 commit 401c565
Showing 3 changed files with 80 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/ucbadmit.py
@@ -71,7 +71,7 @@ def run_inference(dept, male, applications, admit, rng, args):
     init_params, potential_fn, constrain_fn = initialize_model(
         rng, glmm, dept, male, applications, admit)
     init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
-    hmc_state = init_kernel(init_params, args.num_warmup_steps)
+    hmc_state = init_kernel(init_params, args.num_warmup)
     hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
                               transform=lambda hmc_state: constrain_fn(hmc_state.z))
     return hmc_states
@@ -109,7 +109,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
     parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int)
-    parser.add_argument('--num-warmup-steps', nargs='?', default=500, type=int)
+    parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
     parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
     args = parser.parse_args()
     main(args)
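
After this rename, the example would be invoked along these lines (a sketch; flag names and defaults are taken from the diff above):

    python examples/ucbadmit.py --num-samples 2000 --num-warmup 500 --device cpu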
42 changes: 42 additions & 0 deletions numpyro/distributions/util.py
@@ -318,7 +318,27 @@ def xlogy(x, y):
     return lax._safe_mul(x, np.log(y))


+def _xlogy_batching_rule(batched_args, batch_dims):
+    x, y = batched_args
+    bx, by = batch_dims
+    # promote shapes
+    sx, sy = np.shape(x), np.shape(y)
+    nx = len(sx) + int(bx is None)
+    ny = len(sy) + int(by is None)
+    nd = max(nx, ny)
+    x = np.reshape(x, (1,) * (nd - len(sx)) + sx)
+    y = np.reshape(y, (1,) * (nd - len(sy)) + sy)
+    # correct bx, by due to promoting
+    bx = bx + nd - len(sx) if bx is not None else nd - len(sx) - 1
+    by = by + nd - len(sy) if by is not None else nd - len(sy) - 1
+    # move bx, by to front
+    x = batching.move_dim_to_front(x, bx)
+    y = batching.move_dim_to_front(y, by)
+    return xlogy(x, y), 0
+
+
 ad.defjvp(xlogy.primitive, _xlogy_jvp_lhs, _xlogy_jvp_rhs)
+batching.primitive_batchers[xlogy.primitive] = _xlogy_batching_rule


 def _xlog1py_jvp_lhs(g, x, y):
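
A worked trace of the shape promotion in _xlogy_batching_rule, with hypothetical inputs (x batched along dim 0, y unbatched):

    # x: shape (3,), bx = 0;  y: shape (), by = None
    # nx = 1 + 0 = 1, ny = 0 + 1 = 1, nd = 1
    # x keeps shape (3,); y is reshaped to (1,)
    # bx stays 0; by = nd - len(sy) - 1 = 0, i.e. a size-1 batch axis that broadcasts
    # move_dim_to_front is then a no-op for both, and xlogy((3,), (1,)) broadcasts
    # to a (3,) result batched along dimension 0, matching the returned (out, 0)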
@@ -337,13 +357,33 @@ def _xlog1py_jvp_rhs(g, x, y):
     return g * lax._safe_mul(x, np.reciprocal(1 + y))


+def _xlog1py_batching_rule(batched_args, batch_dims):
+    x, y = batched_args
+    bx, by = batch_dims
+    # promote shapes
+    sx, sy = np.shape(x), np.shape(y)
+    nx = len(sx) + int(bx is None)
+    ny = len(sy) + int(by is None)
+    nd = max(nx, ny)
+    x = np.reshape(x, (1,) * (nd - len(sx)) + sx)
+    y = np.reshape(y, (1,) * (nd - len(sy)) + sy)
+    # correct bx, by due to promoting
+    bx = bx + nd - len(sx) if bx is not None else nd - len(sx) - 1
+    by = by + nd - len(sy) if by is not None else nd - len(sy) - 1
+    # move bx, by to front
+    x = batching.move_dim_to_front(x, bx)
+    y = batching.move_dim_to_front(y, by)
+    return xlog1py(x, y), 0
+
+
 @custom_transforms
 def xlog1py(x, y):
     x, y = _promote_args_like(osp_special.xlog1py, x, y)
     return lax._safe_mul(x, np.log1p(y))


 ad.defjvp(xlog1py.primitive, _xlog1py_jvp_lhs, _xlog1py_jvp_rhs)
+batching.primitive_batchers[xlog1py.primitive] = _xlog1py_batching_rule


 def entr(p):
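
A minimal sketch of what these registrations enable (assuming xlogy and xlog1py are importable from numpyro.distributions.util, as the tests below do; at this commit np aliases jax.numpy):

    from jax import vmap
    import jax.numpy as np
    from numpyro.distributions.util import xlog1py, xlogy

    bx = np.array([1., 2., 3.])
    # vmap over the first argument while the second stays an unbatched scalar;
    # without the batching rules above, vmap of these custom primitives would fail
    print(vmap(lambda x: xlogy(x, 2.))(bx))    # == xlogy(bx, 2.) elementwise
    print(vmap(lambda x: xlog1py(x, 2.))(bx))  # == xlog1py(bx, 2.) elementwise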
@@ -373,6 +413,7 @@ def cumsum(x):


 ad.defjvp(cumsum.primitive, lambda g, x: np.cumsum(g, axis=-1))
+batching.defvectorized(cumsum.primitive)


 @custom_transforms
@@ -383,6 +424,7 @@ def cumprod(x):
 # XXX this implementation does not address the case x=0, hence the result in that case will be nan
 # Ref: https://stackoverflow.com/questions/40916955/how-to-compute-gradient-of-cumprod-safely
 ad.defjvp2(cumprod.primitive, lambda g, ans, x: np.cumsum(g / x, axis=-1) * ans)
+batching.defvectorized(cumprod.primitive)


 def promote_shapes(*args, shape=()):
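
batching.defvectorized suffices for cumsum and cumprod because they act along the trailing axis independently of any leading batch dimensions, so vmap over a stacked input is just the primitive applied to the stacked array. A sketch under the same import assumptions as above:

    from jax import vmap
    import jax.numpy as np
    from numpyro.distributions.util import cumprod, cumsum

    bx = np.arange(1., 7.).reshape(3, 2)
    # each row gets its own cumulative sum / product along axis -1
    print(vmap(cumsum)(bx))
    print(vmap(cumprod)(bx))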
36 changes: 36 additions & 0 deletions test/test_distributions_util.py
@@ -170,6 +170,42 @@ def test_standard_gamma_batch():
         assert_allclose(samples[i], standard_gamma(rngs[i], alphas[i]))


+@pytest.mark.parametrize('prim', [
+    xlogy,
+    xlog1py,
+])
+def test_binop_batch_rule(prim):
+    bx = np.array([1., 2., 3.])
+    by = np.array([2., 3., 4.])
+    x = np.array(1.)
+    y = np.array(2.)
+
+    actual_bx_by = vmap(lambda x, y: prim(x, y))(bx, by)
+    for i in range(3):
+        assert_allclose(actual_bx_by[i], prim(bx[i], by[i]))
+
+    actual_x_by = vmap(lambda y: prim(x, y))(by)
+    for i in range(3):
+        assert_allclose(actual_x_by[i], prim(x, by[i]))
+
+    actual_bx_y = vmap(lambda x: prim(x, y))(bx)
+    for i in range(3):
+        assert_allclose(actual_bx_y[i], prim(bx[i], y))
+
+
+@pytest.mark.parametrize('prim', [
+    cumsum,
+    cumprod,
+])
+def test_unop_batch_rule(prim):
+    rng = random.PRNGKey(0)
+    bx = random.normal(rng, (3, 5))
+
+    actual = vmap(prim)(bx)
+    for i in range(3):
+        assert_allclose(actual[i], prim(bx[i]))
+
+
 @pytest.mark.parametrize('p, shape', [
     (np.array([0.1, 0.9]), ()),
     (np.array([0.2, 0.8]), (2,)),
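
The new tests can be run in the usual way; one plausible invocation (path as shown above, with -k filtering on the shared name suffix):

    pytest test/test_distributions_util.py -k batch_rule -v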
