TraceMarkovEnum_ELBO - a trace implementation of ELBO-based SVI that supports model-side enumeration of Markov variables #2716
The first changed file is the contrib.funsor ELBO implementation:
@@ -18,8 +18,11 @@ def terms_from_trace(tr):
     # data structure containing densities, measures, scales, and identification
     # of free variables as either product (plate) variables or sum (measure) variables
     terms = {"log_factors": [], "log_measures": [], "scale": to_funsor(1.),
-             "plate_vars": frozenset(), "measure_vars": frozenset()}
+             "plate_vars": frozenset(), "measure_vars": frozenset(), "plate_to_step": dict()}
     for name, node in tr.nodes.items():
+        # add markov dimensions to the plate_to_step dictionary
+        if node["type"] == "markov_chain":
+            terms["plate_to_step"][node["name"]] = node["value"]
         if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample":
             continue
         # grab plate dimensions from the cond_indep_stack
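For intuition (this is not part of the diff), the plate_to_step bookkeeping built by terms_from_trace ends up looking roughly like the sketch below; the dimension names and the exact step encoding are illustrative assumptions rather than values taken from the PR.

# Hypothetical shape of terms["plate_to_step"] after terms_from_trace runs.
# Ordinary plates are recorded with an empty step; a vectorized markov dimension
# maps to whatever step value was stored at its "markov_chain" site (sketched here
# as a set of (init, prev, curr) variable-name tuples).
plate_to_step = {
    "sequences": {},                                   # plain plate, no Markov structure
    "time": frozenset({("x_0", "x_prev", "x_curr")}),  # markov dimension with its step
}
# The guide-side check further down relies on exactly this shape:
# any(plate_to_step.values()) is truthy iff some dimension carries Markov structure.
assert any(plate_to_step.values())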
@@ -39,9 +42,72 @@ def terms_from_trace(tr):
         # grab the log-density, found at all sites except those that are not replayed
         if node["is_observed"] or not node.get("replay_skipped", False):
             terms["log_factors"].append(node["funsor"]["log_prob"])
+    # add plate dimensions to the plate_to_step dictionary
+    terms["plate_to_step"].update({plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]})
     return terms


+@copy_docs_from(_OrigTraceEnum_ELBO)
+class TraceMarkovEnum_ELBO(ELBO):
+
+    def differentiable_loss(self, model, guide, *args, **kwargs):
+
+        # get batched, enumerated, to_funsor-ed traces from the guide and model
+        with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
+                enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None):
+            guide_tr = trace(guide).get_trace(*args, **kwargs)
+            model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
+
+        # extract from traces all metadata that we will need to compute the elbo
+        guide_terms = terms_from_trace(guide_tr)
+        model_terms = terms_from_trace(model_tr)
+
+        # guide side enumeration is not supported
+        if any(guide_terms["plate_to_step"].values()):
+            raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration")
+
+        # build up a lazy expression for the elbo
+        with funsor.interpreter.interpretation(funsor.terms.lazy):
+            # identify and contract out auxiliary variables in the model with partial_sum_product
+            contracted_factors, uncontracted_factors = [], []
+            for f in model_terms["log_factors"]:
+                if model_terms["measure_vars"].intersection(f.inputs):
+                    contracted_factors.append(f)
+                else:
+                    uncontracted_factors.append(f)
+            # incorporate the effects of subsampling and handlers.scale through a common scale factor
+            markov_dims = frozenset({
+                plate for plate, step in model_terms["plate_to_step"].items() if step})
+            contracted_costs = [model_terms["scale"] * f for f in funsor.sum_product.modified_partial_sum_product(
+                funsor.ops.logaddexp, funsor.ops.add,
+                model_terms["log_measures"] + contracted_factors,
+                plate_to_step=model_terms["plate_to_step"],
+                eliminate=model_terms["measure_vars"] | markov_dims
+            )]
+
+            costs = contracted_costs + uncontracted_factors  # model costs: logp
+            costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq
+
+            # finally, integrate out guide variables in the elbo and all plates
+            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
+            elbo = to_funsor(0, output=funsor.Real)
+            for cost in costs:
+                # compute the marginal logq in the guide corresponding to this cost term
+                log_prob = funsor.sum_product.sum_product(
+                    funsor.ops.logaddexp, funsor.ops.add,
+                    guide_terms["log_measures"],
+                    plates=plate_vars,
+                    eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs)
+                )
+                # compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q
+                elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs))
+                elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
Review comment: Note these two lines will not be correct for markov plates; you should try to write a guide enumeration test where this causes the sequential and vectorized markov ELBO values to disagree. I believe some of the test cases I suggested would be sufficient. To fix this, you will need to rewrite these two lines as another …; note that funsor.Integrate(log_prob, integrand, reduced_vars) is equivalent to (log_prob.exp() * integrand).reduce(ops.add, reduced_vars).

Reply: Ok. I haven't checked the math yet, so there might be some more errors. Thanks for explaining this.

Follow-up: After taking a closer look, it seems like …

Follow-up: In other words, the …
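The equivalence the reviewer cites can be sanity-checked on small eager funsors. The snippet below is purely illustrative (toy names and shapes, not part of the PR):

from collections import OrderedDict
import torch
import funsor
from funsor.testing import assert_close

funsor.set_backend("torch")

# a toy log-measure and integrand over a single discrete variable "x"
log_prob = funsor.Tensor(torch.randn(3), OrderedDict(x=funsor.Bint[3]))
integrand = funsor.Tensor(torch.randn(3), OrderedDict(x=funsor.Bint[3]))

lhs = funsor.Integrate(log_prob, integrand, frozenset(["x"]))
rhs = (log_prob.exp() * integrand).reduce(funsor.ops.add, frozenset(["x"]))
assert_close(lhs, rhs)  # Integrate is the exp-weighted sum over the reduced variables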
+
+        # evaluate the elbo, using memoize to share tensor computation where possible
+        with funsor.memoize.memoize():
+            return -to_data(funsor.optimizer.apply_optimizer(elbo))
+
+
 @copy_docs_from(_OrigTraceEnum_ELBO)
 class TraceEnum_ELBO(ELBO):
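For orientation, here is a minimal usage sketch of the new loss under the contrib.funsor pyroapi backend. It mirrors the HMM-style test models added below; every name in it is illustrative rather than taken from the PR.

import torch
from torch.distributions import constraints
from pyroapi import pyro_backend, pyro, infer, distributions as dist


def model(data, vectorized):
    # an HMM whose latent chain x is enumerated out in the model, either
    # sequentially (pyro.markov) or in parallel (pyro.vectorized_markov)
    trans = pyro.param("trans", torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                       constraint=constraints.simplex)
    emit = pyro.param("emit", torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                      constraint=constraints.simplex)
    x = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data)) if vectorized
                   else pyro.markov(range(len(data))))
    for i in markov_loop:
        probs = torch.tensor([0.5, 0.5]) if x is None else trans[x]
        x = pyro.sample("x_{}".format(i), dist.Categorical(probs),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(i), dist.Categorical(emit[x]), obs=data[i])


def guide(data, vectorized):
    pass  # empty guide: all latent sites are enumerated in the model


data = torch.ones(5)
with pyro_backend("contrib.funsor"):
    # the vectorized markov ELBO and the sequential reference should agree
    vectorized_loss = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4).loss_and_grads(model, guide, data, True)
    sequential_loss = infer.TraceEnum_ELBO(max_plate_nesting=4).loss_and_grads(model, guide, data, False)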
The second changed file is the vectorized markov test module:
@@ -11,10 +11,11 @@
 # put all funsor-related imports here, so test collection works without funsor
 try:
     import funsor
+    from funsor.testing import assert_close
     import pyro.contrib.funsor
     from pyroapi import distributions as dist
     funsor.set_backend("torch")
-    from pyroapi import handlers, pyro, pyro_backend
+    from pyroapi import handlers, pyro, pyro_backend, infer
 except ImportError:
     pytestmark = pytest.mark.skip(reason="funsor is not installed")

@@ -280,25 +281,24 @@ def model_7(data, history, vectorized):
     (model_7, torch.ones((5, 4), dtype=torch.long), "wxy", 1),
     (model_7, torch.ones((50, 4), dtype=torch.long), "wxy", 1),
 ])
-def test_vectorized_markov(model, data, var, history, use_replay):
+def test_enumeration(model, data, var, history, use_replay):

-    with pyro_backend("contrib.funsor"), \
-            handlers.enum():
-        # sequential trace
-        trace = handlers.trace(model).get_trace(data, history, False)
+    with pyro_backend("contrib.funsor"):
+        with handlers.enum():
+            # sequential trace
+            trace = handlers.trace(model).get_trace(data, history, False)
+            # vectorized trace
+            vectorized_trace = handlers.trace(model).get_trace(data, history, True)
+            if use_replay:
+                vectorized_trace = handlers.trace(
+                    handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True)

         # sequential factors
         factors = list()
         for i in range(data.shape[-2]):
             for v in var:
                 factors.append(trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])

-        # vectorized trace
-        vectorized_trace = handlers.trace(model).get_trace(data, history, True)
-        if use_replay:
-            vectorized_trace = handlers.trace(
-                handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True)
-
         # vectorized factors
         vectorized_factors = list()
         for i in range(history):
@@ -315,7 +315,7 @@ def test_vectorized_markov(model, data, var, history, use_replay):

         # assert correct factors
         for f1, f2 in zip(factors, vectorized_factors):
-            funsor.testing.assert_close(f2, f1.align(tuple(f2.inputs)))
+            assert_close(f2, f1.align(tuple(f2.inputs)))

         # assert correct step
         actual_step = vectorized_trace.nodes["time"]["value"]
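As an aside, the f1.align(tuple(f2.inputs)) call above only reorders a funsor's named inputs so the two factors can be compared; a tiny illustration with made-up names (not part of the PR):

from collections import OrderedDict
import torch
import funsor
from funsor.testing import assert_close

funsor.set_backend("torch")

f = funsor.Tensor(torch.randn(2, 3), OrderedDict(a=funsor.Bint[2], b=funsor.Bint[3]))
g = f.align(("b", "a"))               # same values, inputs reordered to ("b", "a")
assert tuple(g.inputs) == ("b", "a")
assert_close(g.align(("a", "b")), f)  # aligning back recovers the original layout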
@@ -382,12 +382,18 @@ def model_8(weeks_data, days_data, history, vectorized):
     (model_8, torch.ones(3), torch.zeros(9), "xy", "wz", 1),
     (model_8, torch.ones(30), torch.zeros(50), "xy", "wz", 1),
 ])
-def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):
+def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):

-    with pyro_backend("contrib.funsor"), \
-            handlers.enum():
-        # sequential factors
-        trace = handlers.trace(model).get_trace(weeks_data, days_data, history, False)
+    with pyro_backend("contrib.funsor"):
+        with handlers.enum():
+            # sequential factors
+            trace = handlers.trace(model).get_trace(weeks_data, days_data, history, False)
+
+            # vectorized trace
+            vectorized_trace = handlers.trace(model).get_trace(weeks_data, days_data, history, True)
+            if use_replay:
+                vectorized_trace = handlers.trace(
+                    handlers.replay(model, trace=vectorized_trace)).get_trace(weeks_data, days_data, history, True)

         factors = list()
         # sequential weeks factors
@@ -399,12 +405,6 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):
             for v in vars2:
                 factors.append(trace.nodes["{}_{}".format(v, j)]["funsor"]["log_prob"])

-        # vectorized trace
-        vectorized_trace = handlers.trace(model).get_trace(weeks_data, days_data, history, True)
-        if use_replay:
-            vectorized_trace = handlers.trace(
-                handlers.replay(model, trace=vectorized_trace)).get_trace(weeks_data, days_data, history, True)
-
         vectorized_factors = list()
         # vectorized weeks factors
         for i in range(history):
@@ -435,7 +435,7 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):

         # assert correct factors
         for f1, f2 in zip(factors, vectorized_factors):
-            funsor.testing.assert_close(f2, f1.align(tuple(f2.inputs)))
+            assert_close(f2, f1.align(tuple(f2.inputs)))

         # assert correct step
@@ -457,3 +457,122 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his

         assert actual_weeks_step == expected_weeks_step
         assert actual_days_step == expected_days_step
+
+
+def guide_empty(data, history, vectorized):
+    pass
+
+
+@pytest.mark.parametrize("model,guide,data,history", [
+    (model_0, guide_empty, torch.rand(3, 5, 4), 1),
+    (model_1, guide_empty, torch.rand(5, 4), 1),
+    (model_2, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
+    (model_3, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
+    (model_4, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
+    (model_5, guide_empty, torch.ones((5, 4), dtype=torch.long), 2),
+    (model_6, guide_empty, torch.rand(5, 4), 1),
+    (model_6, guide_empty, torch.rand(100, 4), 1),
+    (model_7, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
+    (model_7, guide_empty, torch.ones((50, 4), dtype=torch.long), 1),
+])
+def test_model_enumerated_elbo(model, guide, data, history):
+    pyro.clear_param_store()
+
+    with pyro_backend("contrib.funsor"):
+        if history > 1:
+            pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1")
+
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
+        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
+        expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters())
+
+        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
+        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True)
+        actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters())
+
+        assert_close(actual_loss, expected_loss)
Review comment: These tests look good, but can you also check gradients?

expected_loss = ...
params = tuple(pyro.get_param_store().values())
expected_grads = torch.autograd.grad(expected_loss, params)
vectorized_elbo = ...
actual_loss = ...
actual_grads = torch.autograd.grad(actual_loss, params)
assert_close(actual_loss, expected_loss)
for actual_grad, expected_grad in zip(actual_grads, expected_grads):
    assert_close(actual_grad, expected_grad)

Reply: Using …

Reply: Do I have to add …

Review comment: You could add …
+
+        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
+            assert_close(actual_grad, expected_grad)
+
+
+def guide_empty_multi(weeks_data, days_data, history, vectorized):
+    pass
+
+
+@pytest.mark.parametrize("model,guide,weeks_data,days_data,history", [
+    (model_8, guide_empty_multi, torch.ones(3), torch.zeros(9), 1),
+    (model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1),
+])
+def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, history):
+    pyro.clear_param_store()
+
+    with pyro_backend("contrib.funsor"):
+
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
+        expected_loss = elbo.loss_and_grads(model, guide, weeks_data, days_data, history, False)
+        expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters())
+
+        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
+        actual_loss = vectorized_elbo.loss_and_grads(model, guide, weeks_data, days_data, history, True)
+        actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters())
+
+        assert_close(actual_loss, expected_loss)
Review comment: ditto: check gradients
+
+        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
+            assert_close(actual_grad, expected_grad)
+
+
+def model_10(data, vectorized):
+    init_probs = torch.tensor([0.5, 0.5])
+    transition_probs = pyro.param("transition_probs",
+                                  torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
+                                  constraint=constraints.simplex)
+    emission_probs = pyro.param("emission_probs",
+                                torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
+                                constraint=constraints.simplex)
+    x = None
+    markov_loop = \
+        pyro.vectorized_markov(name="time", size=len(data)) if vectorized \
+        else pyro.markov(range(len(data)))
+    for i in markov_loop:
+        probs = init_probs if x is None else transition_probs[x]
+        x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
+        pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=data[i])
+
+
+def guide_10(data, vectorized):
+    init_probs = torch.tensor([0.5, 0.5])
+    transition_probs = pyro.param("transition_probs",
+                                  torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
+                                  constraint=constraints.simplex)
+    x = None
+    markov_loop = \
+        pyro.vectorized_markov(name="time", size=len(data)) if vectorized \
+        else pyro.markov(range(len(data)))
+    for i in markov_loop:
+        probs = init_probs if x is None else transition_probs[x]
+        x = pyro.sample("x_{}".format(i), dist.Categorical(probs),
+                        infer={"enumerate": "parallel"})
+
+
+@pytest.mark.parametrize("model,guide,data,", [
+    (model_10, guide_10, torch.ones(5)),
+])
+def test_guide_enumerated_elbo(model, guide, data):
+    pyro.clear_param_store()
+
+    with pyro_backend("contrib.funsor"):
+        with pytest.raises(
+                NotImplementedError,
+                match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"):
+
+            elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
+            expected_loss = elbo.loss_and_grads(model, guide, data, False)
+            expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters())
+
+            vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
+            actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, True)
+            actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters())
+
+            assert_close(actual_loss, expected_loss)
Review comment: ditto: check gradients (even though this test is currently xfailed)
for actual_grad, expected_grad in zip(actual_grads, expected_grads): | ||
assert_close(actual_grad, expected_grad) |
Review comment: In a future PR, once we're confident that TraceMarkovEnum_ELBO is correct, we can just merge it with TraceEnum_ELBO instead of keeping the implementations separate.