From bb6afb21796cf2a735d413f56c9c8a3b9d86fea8 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 17 Dec 2020 00:38:47 -0500 Subject: [PATCH 1/8] draft implementation --- pyro/contrib/funsor/infer/__init__.py | 2 +- pyro/contrib/funsor/infer/traceenum_elbo.py | 64 ++++++++++++++++++- .../contrib/funsor/test_vectorized_markov.py | 17 ++++- 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/pyro/contrib/funsor/infer/__init__.py b/pyro/contrib/funsor/infer/__init__.py index 475c572235..4525e2cef5 100644 --- a/pyro/contrib/funsor/infer/__init__.py +++ b/pyro/contrib/funsor/infer/__init__.py @@ -6,4 +6,4 @@ from .elbo import ELBO # noqa: F401 from .trace_elbo import JitTrace_ELBO, Trace_ELBO # noqa: F401 from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO # noqa: F401 -from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO # noqa: F401 +from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO # noqa: F401 diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index c11f8b897a..fb98c2e077 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -18,8 +18,10 @@ def terms_from_trace(tr): # data structure containing densities, measures, scales, and identification # of free variables as either product (plate) variables or sum (measure) variables terms = {"log_factors": [], "log_measures": [], "scale": to_funsor(1.), - "plate_vars": frozenset(), "measure_vars": frozenset()} + "plate_vars": frozenset(), "measure_vars": frozenset(), "plate_to_step": dict()} for name, node in tr.nodes.items(): + if node["type"] == "markov_chain": + terms["plate_to_step"][node["name"]] = node["value"] if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample": continue # grab plate dimensions from the cond_indep_stack @@ -39,9 +41,69 @@ def terms_from_trace(tr): # grab the log-density, found at all sites except those that are not replayed if node["is_observed"] or not node.get("replay_skipped", False): terms["log_factors"].append(node["funsor"]["log_prob"]) + for p in terms["plate_vars"]: + terms["plate_to_step"][p] = terms["plate_to_step"].get(p, {}) return terms +@copy_docs_from(_OrigTraceEnum_ELBO) +class TraceMarkovEnum_ELBO(ELBO): + + def differentiable_loss(self, model, guide, *args, **kwargs): + + # get batched, enumerated, to_funsor-ed traces from the guide and model + with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \ + enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None): + guide_tr = trace(guide).get_trace(*args, **kwargs) + model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) + + # extract from traces all metadata that we will need to compute the elbo + guide_terms = terms_from_trace(guide_tr) + model_terms = terms_from_trace(model_tr) + + # build up a lazy expression for the elbo + with funsor.interpreter.interpretation(funsor.terms.lazy): + # identify and contract out auxiliary variables in the model with partial_sum_product + contracted_factors, uncontracted_factors = [], [] + for f in model_terms["log_factors"]: + if model_terms["measure_vars"].intersection(f.inputs): + contracted_factors.append(f) + else: + uncontracted_factors.append(f) + # incorporate the effects of subsampling and handlers.scale through a common scale factor + contracted_costs = [model_terms["scale"] * f for f in funsor.sum_product.modified_partial_sum_product( + funsor.ops.logaddexp, funsor.ops.add, + model_terms["log_measures"] + contracted_factors, + plate_to_step=model_terms["plate_to_step"], + eliminate=model_terms["measure_vars"] | frozenset({ + p for p, s in model_terms["plate_to_step"].items() if s}) + )] + + costs = contracted_costs + uncontracted_factors # model costs: logp + costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq + + # finally, integrate out guide variables in the elbo and all plates + # dict1 | dict2 in python 3.9 + plate_to_step = {**guide_terms["plate_to_step"], **model_terms["plate_to_step"]} + plate_vars = frozenset(plate_to_step.keys()) + elbo = to_funsor(0, output=funsor.Real) + for cost in costs: + # compute the marginal logq in the guide corresponding to this cost term + log_prob = funsor.sum_product.modified_sum_product( + funsor.ops.logaddexp, funsor.ops.add, + guide_terms["log_measures"], + plate_to_step=plate_to_step, + eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs) + ) + # compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q + elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs)) + elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) + + # evaluate the elbo, using memoize to share tensor computation where possible + with funsor.memoize.memoize(): + return -to_data(funsor.optimizer.apply_optimizer(elbo)) + + @copy_docs_from(_OrigTraceEnum_ELBO) class TraceEnum_ELBO(ELBO): diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 8726978f8b..7431d13617 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -11,10 +11,11 @@ # put all funsor-related imports here, so test collection works without funsor try: import funsor + from funsor.testing import assert_close import pyro.contrib.funsor from pyroapi import distributions as dist funsor.set_backend("torch") - from pyroapi import handlers, pyro, pyro_backend + from pyroapi import handlers, pyro, pyro_backend, infer except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") @@ -46,6 +47,10 @@ def model_0(data, history, vectorized): x_prev = x_curr +def guide(data, history, vectorized): + pass + + # x[t-1] --> x[t] --> x[t+1] # | | | # V V V @@ -327,6 +332,14 @@ def test_vectorized_markov(model, data, var, history, use_replay): expected_step |= frozenset({v_step}) assert actual_step == expected_step + with pyro_backend("contrib.funsor"): + if model != model_5: + elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) + expected_loss = elbo.differentiable_loss(model, guide, data, history, False) + vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) + actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) + assert_close(actual_loss, expected_loss) + # x[i-1] --> x[i] --> x[i+1] # | | | @@ -435,7 +448,7 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his # assert correct factors for f1, f2 in zip(factors, vectorized_factors): - funsor.testing.assert_close(f2, f1.align(tuple(f2.inputs))) + assert_close(f2, f1.align(tuple(f2.inputs))) # assert correct step From e87b7549c8a4ede9f624953a3c188f11ab1b9f96 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 17 Dec 2020 01:04:48 -0500 Subject: [PATCH 2/8] inline modified_sum_product --- pyro/contrib/funsor/infer/traceenum_elbo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index fb98c2e077..4938e297ed 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +from functools import reduce import funsor @@ -89,12 +90,13 @@ def differentiable_loss(self, model, guide, *args, **kwargs): elbo = to_funsor(0, output=funsor.Real) for cost in costs: # compute the marginal logq in the guide corresponding to this cost term - log_prob = funsor.sum_product.modified_sum_product( + factors = funsor.sum_product.modified_partial_sum_product( funsor.ops.logaddexp, funsor.ops.add, guide_terms["log_measures"], plate_to_step=plate_to_step, eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs) ) + log_prob = reduce(funsor.ops.add, factors, funsor.Number(funsor.ops.UNITS[funsor.ops.add])) # compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs)) elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) From 9ed513f077f750d44f20dc9c503febcd6ced2a90 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 17 Dec 2020 13:03:32 -0500 Subject: [PATCH 3/8] xfail --- tests/contrib/funsor/test_vectorized_markov.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 7431d13617..2a8af28d35 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -320,7 +320,7 @@ def test_vectorized_markov(model, data, var, history, use_replay): # assert correct factors for f1, f2 in zip(factors, vectorized_factors): - funsor.testing.assert_close(f2, f1.align(tuple(f2.inputs))) + assert_close(f2, f1.align(tuple(f2.inputs))) # assert correct step actual_step = vectorized_trace.nodes["time"]["value"] @@ -333,12 +333,14 @@ def test_vectorized_markov(model, data, var, history, use_replay): assert actual_step == expected_step with pyro_backend("contrib.funsor"): - if model != model_5: - elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) - expected_loss = elbo.differentiable_loss(model, guide, data, history, False) - vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) - actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) - assert_close(actual_loss, expected_loss) + if history > 1: + pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1") + + elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) + expected_loss = elbo.differentiable_loss(model, guide, data, history, False) + vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) + actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) + assert_close(actual_loss, expected_loss) # x[i-1] --> x[i] --> x[i+1] From eaed1b96999f36248f0fa872fabbf66664e10129 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 17 Dec 2020 17:09:11 -0500 Subject: [PATCH 4/8] use partial_unroll needs fix --- .../contrib/funsor/test_vectorized_markov.py | 71 ++++++++----------- 1 file changed, 29 insertions(+), 42 deletions(-) diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 2a8af28d35..de5f1c0566 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -16,6 +16,7 @@ from pyroapi import distributions as dist funsor.set_backend("torch") from pyroapi import handlers, pyro, pyro_backend, infer + from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") @@ -272,6 +273,7 @@ def model_7(data, history, vectorized): x_prev, w_prev = x_curr, w_curr + @pytest.mark.parametrize("use_replay", [True, False]) @pytest.mark.parametrize("model,data,var,history", [ (model_0, torch.rand(3, 5, 4), "xy", 1), @@ -287,59 +289,44 @@ def model_7(data, history, vectorized): ]) def test_vectorized_markov(model, data, var, history, use_replay): - with pyro_backend("contrib.funsor"), \ - handlers.enum(): - # sequential trace - trace = handlers.trace(model).get_trace(data, history, False) - - # sequential factors - factors = list() - for i in range(data.shape[-2]): - for v in var: - factors.append(trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) - - # vectorized trace - vectorized_trace = handlers.trace(model).get_trace(data, history, True) - if use_replay: - vectorized_trace = handlers.trace( - handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True) - - # vectorized factors - vectorized_factors = list() - for i in range(history): - for v in var: - vectorized_factors.append(vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) - for i in range(history, data.shape[-2]): - for v in var: - vectorized_factors.append( - vectorized_trace.nodes["{}_{}".format(v, slice(history, data.shape[-2]))]["funsor"]["log_prob"] - (**{"time": i-history}, - **{"{}_{}".format(k, slice(history-j, data.shape[-2]-j)): "{}_{}".format(k, i-j) - for j in range(history+1) for k in var}) - ) - + with pyro_backend("contrib.funsor"): + with handlers.enum(): + # sequential trace + trace = handlers.trace(model).get_trace(data, history, False) + # vectorized trace + vectorized_trace = handlers.trace(model).get_trace(data, history, True) + if use_replay: + vectorized_trace = handlers.trace( + handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True) + + terms = terms_from_trace(trace) + vectorized_terms = terms_from_trace(vectorized_trace) + + factors = terms["log_factors"] + vectorized_factors = vectorized_terms["log_factors"] + time_vars = frozenset({key for key, value in vectorized_terms["plate_to_step"].items() if value}) + markov_vars = set.union(*(set(chain) for time in time_vars + for chain in vectorized_terms["plate_to_step"][time])) + vectorized_factors, _, _ = funsor.sum_product.partial_unroll( + vectorized_factors, + eliminate=time_vars | markov_vars, + plate_to_step=vectorized_terms["plate_to_step"]) + + factor_names = [set(f.inputs) for f in factors] + vectorized_factors.sort(key=lambda x: factor_names.index(set(x.inputs))) # assert correct factors for f1, f2 in zip(factors, vectorized_factors): assert_close(f2, f1.align(tuple(f2.inputs))) - # assert correct step - actual_step = vectorized_trace.nodes["time"]["value"] - # expected step: assume that all but the last var is markov - expected_step = frozenset() - for v in var[:-1]: - v_step = tuple("{}_{}".format(v, i) for i in range(history)) \ - + tuple("{}_{}".format(v, slice(j, data.shape[-2]-history+j)) for j in range(history+1)) - expected_step |= frozenset({v_step}) - assert actual_step == expected_step - - with pyro_backend("contrib.funsor"): if history > 1: pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1") elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) expected_loss = elbo.differentiable_loss(model, guide, data, history, False) + vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) + assert_close(actual_loss, expected_loss) From 9beb8374ca62ccbfbb44775e0494d685e8147dce Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 20 Dec 2020 01:10:23 -0500 Subject: [PATCH 5/8] guide enumeration NotImplementedError --- pyro/contrib/funsor/infer/traceenum_elbo.py | 21 +- .../contrib/funsor/test_vectorized_markov.py | 188 ++++++++++++++---- 2 files changed, 156 insertions(+), 53 deletions(-) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index 4938e297ed..b9a5ed844b 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib -from functools import reduce import funsor @@ -22,6 +21,7 @@ def terms_from_trace(tr): "plate_vars": frozenset(), "measure_vars": frozenset(), "plate_to_step": dict()} for name, node in tr.nodes.items(): if node["type"] == "markov_chain": + # add markov dimensions to the plate_to_step dictionary terms["plate_to_step"][node["name"]] = node["value"] if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample": continue @@ -31,7 +31,8 @@ def terms_from_trace(tr): if node["funsor"].get("log_measure", None) is not None: terms["log_measures"].append(node["funsor"]["log_measure"]) # sum (measure) variables: the fresh non-plate variables at a site - terms["measure_vars"] |= (frozenset(node["funsor"]["value"].inputs) | {name}) - terms["plate_vars"] + # terms["measure_vars"] |= (frozenset(node["funsor"]["value"].inputs) | {name}) - terms["plate_vars"] + terms["measure_vars"] |= (frozenset(node["funsor"]["log_prob"].inputs) | {name}) - terms["plate_vars"] # grab the scale, assuming a common subsampling scale if node.get("replay_active", False) and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] and \ float(to_data(node["funsor"]["scale"])) != 1.: @@ -42,8 +43,8 @@ def terms_from_trace(tr): # grab the log-density, found at all sites except those that are not replayed if node["is_observed"] or not node.get("replay_skipped", False): terms["log_factors"].append(node["funsor"]["log_prob"]) - for p in terms["plate_vars"]: - terms["plate_to_step"][p] = terms["plate_to_step"].get(p, {}) + # add plate dimenstions to the plate_to_step dictionary + terms["plate_to_step"].update({plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]}) return terms @@ -62,6 +63,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs): guide_terms = terms_from_trace(guide_tr) model_terms = terms_from_trace(model_tr) + if any(guide_terms["plate_to_step"].values()): + raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side enumeration") + # build up a lazy expression for the elbo with funsor.interpreter.interpretation(funsor.terms.lazy): # identify and contract out auxiliary variables in the model with partial_sum_product @@ -84,19 +88,16 @@ def differentiable_loss(self, model, guide, *args, **kwargs): costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq # finally, integrate out guide variables in the elbo and all plates - # dict1 | dict2 in python 3.9 - plate_to_step = {**guide_terms["plate_to_step"], **model_terms["plate_to_step"]} - plate_vars = frozenset(plate_to_step.keys()) + plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"] elbo = to_funsor(0, output=funsor.Real) for cost in costs: # compute the marginal logq in the guide corresponding to this cost term - factors = funsor.sum_product.modified_partial_sum_product( + log_prob = funsor.sum_product.sum_product( funsor.ops.logaddexp, funsor.ops.add, guide_terms["log_measures"], - plate_to_step=plate_to_step, + plates=plate_vars, eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs) ) - log_prob = reduce(funsor.ops.add, factors, funsor.Number(funsor.ops.UNITS[funsor.ops.add])) # compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs)) elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs)) diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index de5f1c0566..0fa8523a4d 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -16,7 +16,6 @@ from pyroapi import distributions as dist funsor.set_backend("torch") from pyroapi import handlers, pyro, pyro_backend, infer - from pyro.contrib.funsor.infer.traceenum_elbo import terms_from_trace except ImportError: pytestmark = pytest.mark.skip(reason="funsor is not installed") @@ -48,10 +47,6 @@ def model_0(data, history, vectorized): x_prev = x_curr -def guide(data, history, vectorized): - pass - - # x[t-1] --> x[t] --> x[t+1] # | | | # V V V @@ -273,7 +268,6 @@ def model_7(data, history, vectorized): x_prev, w_prev = x_curr, w_curr - @pytest.mark.parametrize("use_replay", [True, False]) @pytest.mark.parametrize("model,data,var,history", [ (model_0, torch.rand(3, 5, 4), "xy", 1), @@ -287,7 +281,7 @@ def model_7(data, history, vectorized): (model_7, torch.ones((5, 4), dtype=torch.long), "wxy", 1), (model_7, torch.ones((50, 4), dtype=torch.long), "wxy", 1), ]) -def test_vectorized_markov(model, data, var, history, use_replay): +def test_enumeration(model, data, var, history, use_replay): with pyro_backend("contrib.funsor"): with handlers.enum(): @@ -299,35 +293,39 @@ def test_vectorized_markov(model, data, var, history, use_replay): vectorized_trace = handlers.trace( handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True) - terms = terms_from_trace(trace) - vectorized_terms = terms_from_trace(vectorized_trace) - - factors = terms["log_factors"] - vectorized_factors = vectorized_terms["log_factors"] - time_vars = frozenset({key for key, value in vectorized_terms["plate_to_step"].items() if value}) - markov_vars = set.union(*(set(chain) for time in time_vars - for chain in vectorized_terms["plate_to_step"][time])) - vectorized_factors, _, _ = funsor.sum_product.partial_unroll( - vectorized_factors, - eliminate=time_vars | markov_vars, - plate_to_step=vectorized_terms["plate_to_step"]) - - factor_names = [set(f.inputs) for f in factors] - vectorized_factors.sort(key=lambda x: factor_names.index(set(x.inputs))) + # sequential factors + factors = list() + for i in range(data.shape[-2]): + for v in var: + factors.append(trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) + + # vectorized factors + vectorized_factors = list() + for i in range(history): + for v in var: + vectorized_factors.append(vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"]) + for i in range(history, data.shape[-2]): + for v in var: + vectorized_factors.append( + vectorized_trace.nodes["{}_{}".format(v, slice(history, data.shape[-2]))]["funsor"]["log_prob"] + (**{"time": i-history}, + **{"{}_{}".format(k, slice(history-j, data.shape[-2]-j)): "{}_{}".format(k, i-j) + for j in range(history+1) for k in var}) + ) + # assert correct factors for f1, f2 in zip(factors, vectorized_factors): assert_close(f2, f1.align(tuple(f2.inputs))) - if history > 1: - pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1") - - elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) - expected_loss = elbo.differentiable_loss(model, guide, data, history, False) - - vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) - actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) - - assert_close(actual_loss, expected_loss) + # assert correct step + actual_step = vectorized_trace.nodes["time"]["value"] + # expected step: assume that all but the last var is markov + expected_step = frozenset() + for v in var[:-1]: + v_step = tuple("{}_{}".format(v, i) for i in range(history)) \ + + tuple("{}_{}".format(v, slice(j, data.shape[-2]-history+j)) for j in range(history+1)) + expected_step |= frozenset({v_step}) + assert actual_step == expected_step # x[i-1] --> x[i] --> x[i+1] @@ -384,12 +382,18 @@ def model_8(weeks_data, days_data, history, vectorized): (model_8, torch.ones(3), torch.zeros(9), "xy", "wz", 1), (model_8, torch.ones(30), torch.zeros(50), "xy", "wz", 1), ]) -def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay): +def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay): - with pyro_backend("contrib.funsor"), \ - handlers.enum(): - # sequential factors - trace = handlers.trace(model).get_trace(weeks_data, days_data, history, False) + with pyro_backend("contrib.funsor"): + with handlers.enum(): + # sequential factors + trace = handlers.trace(model).get_trace(weeks_data, days_data, history, False) + + # vectorized trace + vectorized_trace = handlers.trace(model).get_trace(weeks_data, days_data, history, True) + if use_replay: + vectorized_trace = handlers.trace( + handlers.replay(model, trace=vectorized_trace)).get_trace(weeks_data, days_data, history, True) factors = list() # sequential weeks factors @@ -401,12 +405,6 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his for v in vars2: factors.append(trace.nodes["{}_{}".format(v, j)]["funsor"]["log_prob"]) - # vectorized trace - vectorized_trace = handlers.trace(model).get_trace(weeks_data, days_data, history, True) - if use_replay: - vectorized_trace = handlers.trace( - handlers.replay(model, trace=vectorized_trace)).get_trace(weeks_data, days_data, history, True) - vectorized_factors = list() # vectorized weeks factors for i in range(history): @@ -459,3 +457,107 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his assert actual_weeks_step == expected_weeks_step assert actual_days_step == expected_days_step + + +def guide_empty(data, history, vectorized): + pass + + +@pytest.mark.parametrize("model,guide,data,history", [ + (model_0, guide_empty, torch.rand(3, 5, 4), 1), + (model_1, guide_empty, torch.rand(5, 4), 1), + (model_2, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_3, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_4, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_5, guide_empty, torch.ones((5, 4), dtype=torch.long), 2), + (model_6, guide_empty, torch.rand(5, 4), 1), + (model_6, guide_empty, torch.rand(100, 4), 1), + (model_7, guide_empty, torch.ones((5, 4), dtype=torch.long), 1), + (model_7, guide_empty, torch.ones((50, 4), dtype=torch.long), 1), +]) +def test_model_enumerated_elbo(model, guide, data, history): + + with pyro_backend("contrib.funsor"): + if history > 1: + pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1") + + elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) + expected_loss = elbo.differentiable_loss(model, guide, data, history, False) + + vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) + actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) + + assert_close(actual_loss, expected_loss) + + +def guide_empyt_multi(weeks_data, days_data, history, vectorized): + pass + + +@pytest.mark.parametrize("model,guide,weeks_data,days_data,history", [ + (model_8, guide_empyt_multi, torch.ones(3), torch.zeros(9), 1), + (model_8, guide_empyt_multi, torch.ones(30), torch.zeros(50), 1), +]) +def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, history): + + with pyro_backend("contrib.funsor"): + + elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) + expected_loss = elbo.differentiable_loss(model, guide, weeks_data, days_data, history, False) + + vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) + actual_loss = vectorized_elbo.differentiable_loss(model, guide, weeks_data, days_data, history, True) + + assert_close(actual_loss, expected_loss) + + +def model_10(data, vectorized): + init_probs = torch.tensor([0.5, 0.5]) + transition_probs = pyro.param("transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex) + emission_probs = pyro.param("emission_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex) + x = None + markov_loop = \ + pyro.vectorized_markov(name="time", size=len(data)) if vectorized \ + else pyro.markov(range(len(data))) + for i in markov_loop: + probs = init_probs if x is None else transition_probs[x] + x = pyro.sample("x_{}".format(i), dist.Categorical(probs)) + pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=data[i]) + + +def guide_10(data, vectorized): + init_probs = torch.tensor([0.5, 0.5]) + transition_probs = pyro.param("transition_probs", + torch.tensor([[0.75, 0.25], [0.25, 0.75]]), + constraint=constraints.simplex) + x = None + markov_loop = \ + pyro.vectorized_markov(name="time", size=len(data)) if vectorized \ + else pyro.markov(range(len(data))) + for i in markov_loop: + probs = init_probs if x is None else transition_probs[x] + x = pyro.sample("x_{}".format(i), dist.Categorical(probs), + infer={"enumerate": "parallel"}) + + +@pytest.mark.parametrize("model,guide,data,", [ + (model_10, guide_10, torch.ones(5)), +]) +def test_guide_enumerated_elbo_guide(model, guide, data): + + with pyro_backend("contrib.funsor"): + with pytest.raises( + NotImplementedError, + match="TraceMarkovEnum_ELBO does not yet support guide side enumeration"): + + elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) + expected_loss = elbo.differentiable_loss(model, guide, data, False) + + vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) + actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, True) + + assert_close(actual_loss, expected_loss) From e237ebbc0e88145405205aff998f1f9af59fb2e9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 20 Dec 2020 01:34:06 -0500 Subject: [PATCH 6/8] fix typos --- pyro/contrib/funsor/infer/traceenum_elbo.py | 13 +++++++------ tests/contrib/funsor/test_vectorized_markov.py | 8 ++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index b9a5ed844b..a75d5021f7 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -20,8 +20,8 @@ def terms_from_trace(tr): terms = {"log_factors": [], "log_measures": [], "scale": to_funsor(1.), "plate_vars": frozenset(), "measure_vars": frozenset(), "plate_to_step": dict()} for name, node in tr.nodes.items(): + # add markov dimensions to the plate_to_step dictionary if node["type"] == "markov_chain": - # add markov dimensions to the plate_to_step dictionary terms["plate_to_step"][node["name"]] = node["value"] if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample": continue @@ -31,8 +31,7 @@ def terms_from_trace(tr): if node["funsor"].get("log_measure", None) is not None: terms["log_measures"].append(node["funsor"]["log_measure"]) # sum (measure) variables: the fresh non-plate variables at a site - # terms["measure_vars"] |= (frozenset(node["funsor"]["value"].inputs) | {name}) - terms["plate_vars"] - terms["measure_vars"] |= (frozenset(node["funsor"]["log_prob"].inputs) | {name}) - terms["plate_vars"] + terms["measure_vars"] |= (frozenset(node["funsor"]["value"].inputs) | {name}) - terms["plate_vars"] # grab the scale, assuming a common subsampling scale if node.get("replay_active", False) and set(node["funsor"]["log_prob"].inputs) & terms["measure_vars"] and \ float(to_data(node["funsor"]["scale"])) != 1.: @@ -43,7 +42,7 @@ def terms_from_trace(tr): # grab the log-density, found at all sites except those that are not replayed if node["is_observed"] or not node.get("replay_skipped", False): terms["log_factors"].append(node["funsor"]["log_prob"]) - # add plate dimenstions to the plate_to_step dictionary + # add plate dimensions to the plate_to_step dictionary terms["plate_to_step"].update({plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]}) return terms @@ -63,6 +62,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): guide_terms = terms_from_trace(guide_tr) model_terms = terms_from_trace(model_tr) + # guide side enumeration is not supported if any(guide_terms["plate_to_step"].values()): raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side enumeration") @@ -76,12 +76,13 @@ def differentiable_loss(self, model, guide, *args, **kwargs): else: uncontracted_factors.append(f) # incorporate the effects of subsampling and handlers.scale through a common scale factor + markov_dims = frozenset({ + plate for plate, step in model_terms["plate_to_step"].items() if step}) contracted_costs = [model_terms["scale"] * f for f in funsor.sum_product.modified_partial_sum_product( funsor.ops.logaddexp, funsor.ops.add, model_terms["log_measures"] + contracted_factors, plate_to_step=model_terms["plate_to_step"], - eliminate=model_terms["measure_vars"] | frozenset({ - p for p, s in model_terms["plate_to_step"].items() if s}) + eliminate=model_terms["measure_vars"] | markov_dims )] costs = contracted_costs + uncontracted_factors # model costs: logp diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 0fa8523a4d..4640188d43 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -490,13 +490,13 @@ def test_model_enumerated_elbo(model, guide, data, history): assert_close(actual_loss, expected_loss) -def guide_empyt_multi(weeks_data, days_data, history, vectorized): +def guide_empty_multi(weeks_data, days_data, history, vectorized): pass @pytest.mark.parametrize("model,guide,weeks_data,days_data,history", [ - (model_8, guide_empyt_multi, torch.ones(3), torch.zeros(9), 1), - (model_8, guide_empyt_multi, torch.ones(30), torch.zeros(50), 1), + (model_8, guide_empty_multi, torch.ones(3), torch.zeros(9), 1), + (model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1), ]) def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, history): @@ -547,7 +547,7 @@ def guide_10(data, vectorized): @pytest.mark.parametrize("model,guide,data,", [ (model_10, guide_10, torch.ones(5)), ]) -def test_guide_enumerated_elbo_guide(model, guide, data): +def test_guide_enumerated_elbo(model, guide, data): with pyro_backend("contrib.funsor"): with pytest.raises( From 1bd8f2fdb08cceb858f74b270e33a553ae323f1d Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 20 Dec 2020 16:47:28 -0500 Subject: [PATCH 7/8] address comments --- pyro/contrib/funsor/infer/traceenum_elbo.py | 2 +- .../contrib/funsor/test_vectorized_markov.py | 26 ++++++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/pyro/contrib/funsor/infer/traceenum_elbo.py b/pyro/contrib/funsor/infer/traceenum_elbo.py index a75d5021f7..d725da2c50 100644 --- a/pyro/contrib/funsor/infer/traceenum_elbo.py +++ b/pyro/contrib/funsor/infer/traceenum_elbo.py @@ -64,7 +64,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs): # guide side enumeration is not supported if any(guide_terms["plate_to_step"].values()): - raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side enumeration") + raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration") # build up a lazy expression for the elbo with funsor.interpreter.interpretation(funsor.terms.lazy): diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 4640188d43..434e8482ca 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -482,12 +482,16 @@ def test_model_enumerated_elbo(model, guide, data, history): pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1") elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) - expected_loss = elbo.differentiable_loss(model, guide, data, history, False) + expected_loss = elbo.loss_and_grads(model, guide, data, history, False) + expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) - actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True) + actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, history, True) + actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) assert_close(actual_loss, expected_loss) + for actual_grad, expected_grad in zip(actual_grads, expected_grads): + assert_close(actual_grad, expected_grad) def guide_empty_multi(weeks_data, days_data, history, vectorized): @@ -503,12 +507,16 @@ def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, histor with pyro_backend("contrib.funsor"): elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) - expected_loss = elbo.differentiable_loss(model, guide, weeks_data, days_data, history, False) + expected_loss = elbo.loss_and_grads(model, guide, weeks_data, days_data, history, False) + expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) - actual_loss = vectorized_elbo.differentiable_loss(model, guide, weeks_data, days_data, history, True) + actual_loss = vectorized_elbo.loss_and_grads(model, guide, weeks_data, days_data, history, True) + actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) assert_close(actual_loss, expected_loss) + for actual_grad, expected_grad in zip(actual_grads, expected_grads): + assert_close(actual_grad, expected_grad) def model_10(data, vectorized): @@ -552,12 +560,16 @@ def test_guide_enumerated_elbo(model, guide, data): with pyro_backend("contrib.funsor"): with pytest.raises( NotImplementedError, - match="TraceMarkovEnum_ELBO does not yet support guide side enumeration"): + match="TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"): elbo = infer.TraceEnum_ELBO(max_plate_nesting=4) - expected_loss = elbo.differentiable_loss(model, guide, data, False) + expected_loss = elbo.loss_and_grads(model, guide, data, False) + expected_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4) - actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, True) + actual_loss = vectorized_elbo.loss_and_grads(model, guide, data, True) + actual_grads = (value.grad for name, value in pyro.get_param_store().named_parameters()) assert_close(actual_loss, expected_loss) + for actual_grad, expected_grad in zip(actual_grads, expected_grads): + assert_close(actual_grad, expected_grad) From add3bc257685f0588a9c5260361680da2b3ff24e Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 20 Dec 2020 20:34:49 -0500 Subject: [PATCH 8/8] add clear_param_store --- tests/contrib/funsor/test_vectorized_markov.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/contrib/funsor/test_vectorized_markov.py b/tests/contrib/funsor/test_vectorized_markov.py index 434e8482ca..ac34469f26 100644 --- a/tests/contrib/funsor/test_vectorized_markov.py +++ b/tests/contrib/funsor/test_vectorized_markov.py @@ -476,6 +476,7 @@ def guide_empty(data, history, vectorized): (model_7, guide_empty, torch.ones((50, 4), dtype=torch.long), 1), ]) def test_model_enumerated_elbo(model, guide, data, history): + pyro.clear_param_store() with pyro_backend("contrib.funsor"): if history > 1: @@ -503,6 +504,7 @@ def guide_empty_multi(weeks_data, days_data, history, vectorized): (model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1), ]) def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, history): + pyro.clear_param_store() with pyro_backend("contrib.funsor"): @@ -556,6 +558,7 @@ def guide_10(data, vectorized): (model_10, guide_10, torch.ones(5)), ]) def test_guide_enumerated_elbo(model, guide, data): + pyro.clear_param_store() with pyro_backend("contrib.funsor"): with pytest.raises(