-
Notifications
You must be signed in to change notification settings - Fork 222
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Factored out HMCECS proxies to contrib (#1748)
* sketched ecs_proxies * sketched 1 and 2 order taylor ecs proxy * added hessian approximation * fixed comment * lint * removed jacfwd * updated test case * updated test * removed approx * removed approx from taylor proxy * fixed reference for taylor proxy * fixed lint * fixed to 3.8 syntax * moved `test_block_update_partitioning` to contrib/test_ecs_proxies
- Loading branch information
1 parent
a967f69
commit d6f9897
Showing
5 changed files
with
315 additions
and
269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,275 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from collections import defaultdict, namedtuple | ||
import warnings | ||
|
||
from jax import hessian, jacobian, lax, numpy as jnp, random | ||
from jax.flatten_util import ravel_pytree | ||
|
||
from numpyro.distributions.transforms import biject_to | ||
from numpyro.handlers import block, substitute, trace | ||
|
||
# State carried by the second-order Taylor proxy: per-subsample reference log
# likelihoods plus their gradients and Hessians at the reference parameters.
# NOTE: the typename previously read "TaylorProxyState", which mismatched the
# bound name and broke pickling (pickle resolves the class via __qualname__).
TaylorTwoProxyState = namedtuple(
    "TaylorTwoProxyState",
    "ref_subsample_log_liks,"
    "ref_subsample_log_lik_grads,"
    "ref_subsample_log_lik_hessians",
)

# State carried by the first-order Taylor proxy: reference log likelihoods
# and their gradients only (no Hessians).
TaylorOneProxyState = namedtuple(
    "TaylorOneProxyState",
    "ref_subsample_log_liks," "ref_subsample_log_lik_grads",
)
|
||
|
||
def perturbed_method(subsample_plate_sizes, proxy_fn):
    """Build an estimator of the full-data log likelihood from subsample
    likelihoods, debiased with the control variate provided by ``proxy_fn``.

    :param dict subsample_plate_sizes: map of plate name to ``(size, subsample_size)``.
    :param proxy_fn: callable ``(params, site_names, gibbs_state)`` returning the
        proxy value over the full data and over the current subsample.
    """

    def estimator(likelihoods, params, gibbs_state):
        # Reduce every site's log density over all dims except its subsample
        # dim, accumulating per subsample-plate name.
        site_log_liks = defaultdict(float)
        for dist_fn, value, plate_name, subsample_dim in likelihoods.values():
            site_log_liks[plate_name] += _sum_all_except_at_dim(
                dist_fn.log_prob(value), subsample_dim
            )

        proxy_all, proxy_sub = proxy_fn(params, site_log_liks.keys(), gibbs_state)

        total = 0.0
        # One debiased term per subsample plate.
        for plate_name, sub_log_lik in site_log_liks.items():
            n, m = subsample_plate_sizes[plate_name]
            residual = sub_log_lik - proxy_sub[plate_name]
            # Unbiased estimate: proxy over all data plus scaled mean residual,
            # with a variance penalty on the residuals.
            estimate = proxy_all[plate_name] + n * jnp.mean(residual)
            penalty = n**2 / m * jnp.var(residual)
            total += estimate - 0.5 * penalty
        return total

    return estimator
|
||
|
||
def _sum_all_except_at_dim(x, dim): | ||
x = x.reshape((-1,) + x.shape[dim:]).sum(0) | ||
return x.reshape(x.shape[:1] + (-1,)).sum(-1) | ||
|
||
|
||
def _update_block(rng_key, num_blocks, subsample_idx, plate_size): | ||
size, subsample_size = plate_size | ||
rng_key, subkey, block_key = random.split(rng_key, 3) | ||
block_size = (subsample_size - 1) // num_blocks + 1 | ||
pad = block_size - (subsample_size - 1) % block_size - 1 | ||
|
||
chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks) | ||
new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,)) | ||
subsample_idx_padded = jnp.pad(subsample_idx, (0, pad)) | ||
start = chosen_block * block_size | ||
subsample_idx_padded = lax.dynamic_update_slice_in_dim( | ||
subsample_idx_padded, new_idx, start, 0 | ||
) | ||
return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start | ||
|
||
|
||
def block_update(plate_sizes, num_blocks, rng_key, gibbs_sites, gibbs_state):
    """Gibbs move that resamples one random block of each site's subsample
    indices; the auxiliary ``gibbs_state`` is passed through unchanged."""
    new_idxs = {}
    for site_name, cur_idx in gibbs_sites.items():
        # Thread the PRNG key through the per-site updates; discard the
        # pad/block/start bookkeeping that only the proxy variant needs.
        rng_key, new_idxs[site_name], *_ = _update_block(
            rng_key, num_blocks, cur_idx, plate_sizes[site_name]
        )
    return new_idxs, gibbs_state
|
||
|
||
def _block_update_proxy(num_blocks, rng_key, gibbs_sites, plate_sizes):
    """Like :func:`block_update`, but also collect per-site padding, the
    freshly drawn block, and its offset so cached proxy statistics can be
    patched in place instead of recomputed."""
    new_idxs, paddings, fresh_blocks, offsets = {}, {}, {}, {}
    for site_name, cur_idx in gibbs_sites.items():
        (
            rng_key,
            new_idxs[site_name],
            paddings[site_name],
            fresh_blocks[site_name],
            offsets[site_name],
        ) = _update_block(rng_key, num_blocks, cur_idx, plate_sizes[site_name])
    return new_idxs, paddings, fresh_blocks, offsets
|
||
|
||
def taylor_proxy(reference_params, degree):
    """Control variate for unbiased log likelihood estimation using a Taylor
    expansion around a reference parameter. Suggested for subsampling in [1].

    :param dict reference_params: Model parameterization at MLE or MAP-estimate.
    :param int degree: number of terms in the Taylor expansion, either one or two.
    :raises ValueError: if ``degree`` is not 1 or 2.

    **References:**

    [1] On Markov chain Monte Carlo Methods For Tall Data
        Bardenet., R., Doucet, A., Holmes, C. (2017)
    """
    # Fail fast on an unsupported degree. The original check inside
    # construct_proxy_fn read `elif 1:`, which made the ValueError branch
    # unreachable and silently treated any invalid degree as first-order.
    if degree not in (1, 2):
        raise ValueError("Taylor proxy only defined for first and second degree.")

    def construct_proxy_fn(
        prototype_trace,
        subsample_plate_sizes,
        model,
        model_args,
        model_kwargs,
        num_blocks=1,
    ):
        # Move the reference parameters to the unconstrained space, where the
        # gradients/Hessians of the Taylor expansion are taken.
        ref_params = {
            name: biject_to(prototype_trace[name]["fn"].support).inv(value)
            for name, value in reference_params.items()
        }

        ref_params_flat, unravel_fn = ravel_pytree(ref_params)

        def log_likelihood(params_flat, subsample_indices=None):
            """Per-datum log likelihood of each subsample plate at
            ``params_flat`` (unconstrained, flattened), for the given
            subsample indices or the full data set when ``None``."""
            if subsample_indices is None:
                subsample_indices = {
                    k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
                }
            params = unravel_fn(params_flat)
            with warnings.catch_warnings():
                # Transform back to the constrained space; suppress transform
                # warnings triggered during tracing.
                warnings.simplefilter("ignore")
                params = {
                    name: biject_to(prototype_trace[name]["fn"].support)(value)
                    for name, value in params.items()
                }
                with block(), trace() as tr, substitute(
                    data=subsample_indices
                ), substitute(data=params):
                    model(*model_args, **model_kwargs)

            # Accumulate, per subsample plate, each observed site's log
            # density reduced over every dimension except the plate's.
            log_lik = {}
            for site in tr.values():
                if site["type"] == "sample" and site["is_observed"]:
                    for frame in site["cond_indep_stack"]:
                        if frame.name in log_lik:
                            log_lik[frame.name] += _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim
                            )
                        elif frame.name in subsample_indices:
                            log_lik[frame.name] = _sum_all_except_at_dim(
                                site["fn"].log_prob(site["value"]), frame.dim
                            )
            return log_lik

        def log_likelihood_sum(params_flat, subsample_indices=None):
            # Scalar (summed) log likelihood per plate; differentiated below.
            return {
                k: v.sum()
                for k, v in log_likelihood(params_flat, subsample_indices).items()
            }

        if degree == 2:
            TPState = TaylorTwoProxyState
        elif degree == 1:  # fixed: was `elif 1:`, which swallowed bad degrees
            TPState = TaylorOneProxyState
        else:
            raise ValueError("Taylor proxy only defined for first and second degree.")

        # Reference statistics over the FULL data set, keyed by subsample name.
        ref_sum_log_lik = log_likelihood_sum(ref_params_flat)
        ref_sum_log_lik_grads = jacobian(log_likelihood_sum)(ref_params_flat)

        if degree == 2:
            ref_sum_log_lik_hessians = hessian(log_likelihood_sum)(ref_params_flat)

        def gibbs_init(rng_key, gibbs_sites):
            """Compute per-datum reference statistics for the initial subsample."""
            ref_subsamples_taylor = [
                log_likelihood(ref_params_flat, gibbs_sites),
                jacobian(log_likelihood)(ref_params_flat, gibbs_sites),
            ]

            if degree == 2:
                ref_subsamples_taylor.append(
                    hessian(log_likelihood)(ref_params_flat, gibbs_sites)
                )

            return TPState(*ref_subsamples_taylor)

        def gibbs_update(rng_key, gibbs_sites, gibbs_state):
            """Resample one block of indices per site and refresh the cached
            Taylor statistics only for the replaced block."""
            u_new, pads, new_idxs, starts = _block_update_proxy(
                num_blocks, rng_key, gibbs_sites, subsample_plate_sizes
            )

            new_states = defaultdict(dict)
            # Statistics evaluated only at the freshly drawn block indices.
            new_ref_subsample_taylor = [
                log_likelihood(ref_params_flat, new_idxs),
                jacobian(log_likelihood)(ref_params_flat, new_idxs),
            ]

            if degree == 2:
                new_ref_subsample_taylor.append(
                    hessian(log_likelihood)(ref_params_flat, new_idxs)
                )

            last_ref_subsample_taylor = list(gibbs_state._asdict().values())

            for stat, new_block_values, last_values in zip(
                TPState._fields,
                new_ref_subsample_taylor,
                last_ref_subsample_taylor,
            ):
                for name, subsample_idx in gibbs_sites.items():
                    size, subsample_size = subsample_plate_sizes[name]
                    pad, start = pads[name], starts[name]
                    # Pad along the subsample axis so the new block can be
                    # written with a static-shape slice, then trim back.
                    new_value = jnp.pad(
                        last_values[name],
                        [(0, pad)] + [(0, 0)] * (jnp.ndim(last_values[name]) - 1),
                    )
                    new_value = lax.dynamic_update_slice_in_dim(
                        new_value, new_block_values[name], start, 0
                    )
                    new_states[stat][name] = new_value[:subsample_size]

            gibbs_state = TPState(**new_states)
            return u_new, gibbs_state

        def proxy_fn(params, subsample_lik_sites, gibbs_state):
            """Evaluate the Taylor proxy at ``params`` for the given sites,
            returning (full-data proxy, per-datum subsample proxy) dicts."""
            params_flat, _ = ravel_pytree(params)
            params_diff = params_flat - ref_params_flat

            ref_subsample_log_liks = gibbs_state.ref_subsample_log_liks
            ref_subsample_log_lik_grads = gibbs_state.ref_subsample_log_lik_grads
            if degree == 2:
                ref_subsample_log_lik_hessians = (
                    gibbs_state.ref_subsample_log_lik_hessians
                )

            proxy_sum = defaultdict(float)
            proxy_subsample = defaultdict(float)
            for name in subsample_lik_sites:
                # First-order term, per datum in the subsample.
                proxy_subsample[name] = ref_subsample_log_liks[name] + jnp.dot(
                    ref_subsample_log_lik_grads[name], params_diff
                )
                high_order_terms = 0.0
                if degree == 2:
                    high_order_terms = 0.5 * jnp.dot(
                        jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
                        params_diff,
                    )

                proxy_subsample[name] = proxy_subsample[name] + high_order_terms

                # Same expansion for the full-data sum.
                proxy_sum[name] = ref_sum_log_lik[name] + jnp.dot(
                    ref_sum_log_lik_grads[name], params_diff
                )

                high_order_terms = 0.0
                if degree == 2:
                    high_order_terms = 0.5 * jnp.dot(
                        jnp.dot(ref_sum_log_lik_hessians[name], params_diff),
                        params_diff,
                    )
                proxy_sum[name] = proxy_sum[name] + high_order_terms

            return proxy_sum, proxy_subsample

        return proxy_fn, gibbs_init, gibbs_update

    return construct_proxy_fn
Oops, something went wrong.