Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TraceMarkovEnum_ELBO - a trace implementation of ELBO-based SVI that supports model side enumeration of Markov variables #2716

Merged
merged 9 commits into from
Dec 21, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyro/contrib/funsor/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
from .elbo import ELBO # noqa: F401
from .trace_elbo import JitTrace_ELBO, Trace_ELBO # noqa: F401
from .tracetmc_elbo import JitTraceTMC_ELBO, TraceTMC_ELBO # noqa: F401
from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO # noqa: F401
from .traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO, TraceMarkovEnum_ELBO # noqa: F401
68 changes: 67 additions & 1 deletion pyro/contrib/funsor/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ def terms_from_trace(tr):
# data structure containing densities, measures, scales, and identification
# of free variables as either product (plate) variables or sum (measure) variables
terms = {"log_factors": [], "log_measures": [], "scale": to_funsor(1.),
"plate_vars": frozenset(), "measure_vars": frozenset()}
"plate_vars": frozenset(), "measure_vars": frozenset(), "plate_to_step": dict()}
for name, node in tr.nodes.items():
# add markov dimensions to the plate_to_step dictionary
if node["type"] == "markov_chain":
terms["plate_to_step"][node["name"]] = node["value"]
if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample":
continue
# grab plate dimensions from the cond_indep_stack
Expand All @@ -39,9 +42,72 @@ def terms_from_trace(tr):
# grab the log-density, found at all sites except those that are not replayed
if node["is_observed"] or not node.get("replay_skipped", False):
terms["log_factors"].append(node["funsor"]["log_prob"])
# add plate dimensions to the plate_to_step dictionary
terms["plate_to_step"].update({plate: terms["plate_to_step"].get(plate, {}) for plate in terms["plate_vars"]})
return terms


@copy_docs_from(_OrigTraceEnum_ELBO)
class TraceMarkovEnum_ELBO(ELBO):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a future PR, once we're confident that TraceMarkovEnum_ELBO is correct, we can just merge it with TraceEnum_ELBO instead of keeping the implementations separate.


def differentiable_loss(self, model, guide, *args, **kwargs):

# get batched, enumerated, to_funsor-ed traces from the guide and model
with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None):
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

# extract from traces all metadata that we will need to compute the elbo
guide_terms = terms_from_trace(guide_tr)
model_terms = terms_from_trace(model_tr)

# guide side enumeration is not supported
if any(guide_terms["plate_to_step"].values()):
raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side enumeration")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise NotImplementedError

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: change error message to "TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"


# build up a lazy expression for the elbo
with funsor.interpreter.interpretation(funsor.terms.lazy):
# identify and contract out auxiliary variables in the model with partial_sum_product
contracted_factors, uncontracted_factors = [], []
for f in model_terms["log_factors"]:
if model_terms["measure_vars"].intersection(f.inputs):
contracted_factors.append(f)
else:
uncontracted_factors.append(f)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
markov_dims = frozenset({
plate for plate, step in model_terms["plate_to_step"].items() if step})
contracted_costs = [model_terms["scale"] * f for f in funsor.sum_product.modified_partial_sum_product(
funsor.ops.logaddexp, funsor.ops.add,
model_terms["log_measures"] + contracted_factors,
plate_to_step=model_terms["plate_to_step"],
eliminate=model_terms["measure_vars"] | markov_dims
)]

costs = contracted_costs + uncontracted_factors # model costs: logp
costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq

# finally, integrate out guide variables in the elbo and all plates
plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
elbo = to_funsor(0, output=funsor.Real)
for cost in costs:
# compute the marginal logq in the guide corresponding to this cost term
log_prob = funsor.sum_product.sum_product(
funsor.ops.logaddexp, funsor.ops.add,
guide_terms["log_measures"],
plates=plate_vars,
eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs)
)
# compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q
elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs))
elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
Copy link
Member

@eb8680 eb8680 Dec 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note these two lines will not be correct for markov plates; you should try to write a guide enumeration test where this causes the sequential and vectorized markov ELBO values to disagree. I believe some of the test cases I suggested would be sufficient.

To fix this, you will need to rewrite these two lines as another modified_partial_sum_product computation. Note that

funsor.Integrate(log_prob, integrand, reduced_vars)

is equivalent to

(log_prob.exp() * integrand).reduce(ops.add, reduced_vars)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note these two lines will not be correct for markov plates

Ok. I haven't checked the math yet. So there might be some more errors.

is equivalent to

Thanks for explaining this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After taking a closer look it seems like modified_partial_sum_product might not work for guide enumerated models. The reason is that cost (integrand) logq(z) of transition factors has to be split into the sum factors logq(z_0) + logq(z_1|z_0) + ... before integrating w.r.t. q(z_0)q(z_1|z_0) .... I can explain this in more details at the Zoom meeting.

Copy link
Member Author

@ordabayevy ordabayevy Dec 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words the prod operation for the integrand is ops.add and for the log_prob.exp() is ops.mul.


# evaluate the elbo, using memoize to share tensor computation where possible
with funsor.memoize.memoize():
return -to_data(funsor.optimizer.apply_optimizer(elbo))


@copy_docs_from(_OrigTraceEnum_ELBO)
class TraceEnum_ELBO(ELBO):

Expand Down
156 changes: 130 additions & 26 deletions tests/contrib/funsor/test_vectorized_markov.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# put all funsor-related imports here, so test collection works without funsor
try:
import funsor
from funsor.testing import assert_close
import pyro.contrib.funsor
from pyroapi import distributions as dist
funsor.set_backend("torch")
from pyroapi import handlers, pyro, pyro_backend
from pyroapi import handlers, pyro, pyro_backend, infer
except ImportError:
pytestmark = pytest.mark.skip(reason="funsor is not installed")

Expand Down Expand Up @@ -280,25 +281,24 @@ def model_7(data, history, vectorized):
(model_7, torch.ones((5, 4), dtype=torch.long), "wxy", 1),
(model_7, torch.ones((50, 4), dtype=torch.long), "wxy", 1),
])
def test_vectorized_markov(model, data, var, history, use_replay):

with pyro_backend("contrib.funsor"), \
handlers.enum():
# sequential trace
trace = handlers.trace(model).get_trace(data, history, False)
def test_enumeration(model, data, var, history, use_replay):

with pyro_backend("contrib.funsor"):
with handlers.enum():
# sequential trace
trace = handlers.trace(model).get_trace(data, history, False)
# vectorized trace
vectorized_trace = handlers.trace(model).get_trace(data, history, True)
if use_replay:
vectorized_trace = handlers.trace(
handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True)

# sequential factors
factors = list()
for i in range(data.shape[-2]):
for v in var:
factors.append(trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])

# vectorized trace
vectorized_trace = handlers.trace(model).get_trace(data, history, True)
if use_replay:
vectorized_trace = handlers.trace(
handlers.replay(model, trace=vectorized_trace)).get_trace(data, history, True)

# vectorized factors
vectorized_factors = list()
for i in range(history):
Expand All @@ -315,7 +315,7 @@ def test_vectorized_markov(model, data, var, history, use_replay):

# assert correct factors
for f1, f2 in zip(factors, vectorized_factors):
funsor.testing.assert_close(f2, f1.align(tuple(f2.inputs)))
assert_close(f2, f1.align(tuple(f2.inputs)))

# assert correct step
actual_step = vectorized_trace.nodes["time"]["value"]
Expand Down Expand Up @@ -382,12 +382,18 @@ def model_8(weeks_data, days_data, history, vectorized):
(model_8, torch.ones(3), torch.zeros(9), "xy", "wz", 1),
(model_8, torch.ones(30), torch.zeros(50), "xy", "wz", 1),
])
def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):

with pyro_backend("contrib.funsor"), \
handlers.enum():
# sequential factors
trace = handlers.trace(model).get_trace(weeks_data, days_data, history, False)
with pyro_backend("contrib.funsor"):
with handlers.enum():
# sequential factors
trace = handlers.trace(model).get_trace(weeks_data, days_data, history, False)

# vectorized trace
vectorized_trace = handlers.trace(model).get_trace(weeks_data, days_data, history, True)
if use_replay:
vectorized_trace = handlers.trace(
handlers.replay(model, trace=vectorized_trace)).get_trace(weeks_data, days_data, history, True)

factors = list()
# sequential weeks factors
Expand All @@ -399,12 +405,6 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his
for v in vars2:
factors.append(trace.nodes["{}_{}".format(v, j)]["funsor"]["log_prob"])

# vectorized trace
vectorized_trace = handlers.trace(model).get_trace(weeks_data, days_data, history, True)
if use_replay:
vectorized_trace = handlers.trace(
handlers.replay(model, trace=vectorized_trace)).get_trace(weeks_data, days_data, history, True)

vectorized_factors = list()
# vectorized weeks factors
for i in range(history):
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his

# assert correct factors
for f1, f2 in zip(factors, vectorized_factors):
funsor.testing.assert_close(f2, f1.align(tuple(f2.inputs)))
assert_close(f2, f1.align(tuple(f2.inputs)))

# assert correct step

Expand All @@ -457,3 +457,107 @@ def test_vectorized_markov_multi(model, weeks_data, days_data, vars1, vars2, his

assert actual_weeks_step == expected_weeks_step
assert actual_days_step == expected_days_step


def guide_empty(data, history, vectorized):
pass


@pytest.mark.parametrize("model,guide,data,history", [
(model_0, guide_empty, torch.rand(3, 5, 4), 1),
(model_1, guide_empty, torch.rand(5, 4), 1),
(model_2, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
(model_3, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
(model_4, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
(model_5, guide_empty, torch.ones((5, 4), dtype=torch.long), 2),
(model_6, guide_empty, torch.rand(5, 4), 1),
(model_6, guide_empty, torch.rand(100, 4), 1),
(model_7, guide_empty, torch.ones((5, 4), dtype=torch.long), 1),
(model_7, guide_empty, torch.ones((50, 4), dtype=torch.long), 1),
])
def test_model_enumerated_elbo(model, guide, data, history):

with pyro_backend("contrib.funsor"):
if history > 1:
pytest.xfail(reason="TraceMarkovEnum_ELBO does not yet support history > 1")

elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
expected_loss = elbo.differentiable_loss(model, guide, data, history, False)

vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True)

assert_close(actual_loss, expected_loss)
Copy link
Member

@eb8680 eb8680 Dec 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests look good, but can you also check gradients?

expected_loss = ...
params = tuple(pyro.get_param_store().values())
expected_grads = torch.autograd.grad(expected_loss, params)

vectorized_elbo = ...
actual_loss = ...
actual_grads = torch.autograd.grad(actual_loss, params)

assert_close(actual_loss, expected_loss)
for actual_grad, expected_grad in zip(actual_grads, expected_grads):
    assert_close(actual_grad, expected_grad)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using torch.autograd.grad gives RuntimeError:

>       return Variable._execution_engine.run_backward(
            outputs, grad_outputs_, retain_graph, create_graph,
            inputs, allow_unused)
E       RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Do I have to add clear_param_store() in the models?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could add clear_param_store() to the beginning of the test body.



def guide_empty_multi(weeks_data, days_data, history, vectorized):
pass


@pytest.mark.parametrize("model,guide,weeks_data,days_data,history", [
(model_8, guide_empty_multi, torch.ones(3), torch.zeros(9), 1),
(model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1),
])
def test_model_enumerated_elbo_multi(model, guide, weeks_data, days_data, history):

with pyro_backend("contrib.funsor"):

elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
expected_loss = elbo.differentiable_loss(model, guide, weeks_data, days_data, history, False)

vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, weeks_data, days_data, history, True)

assert_close(actual_loss, expected_loss)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: check gradients



def model_10(data, vectorized):
init_probs = torch.tensor([0.5, 0.5])
transition_probs = pyro.param("transition_probs",
torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
constraint=constraints.simplex)
emission_probs = pyro.param("emission_probs",
torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
constraint=constraints.simplex)
x = None
markov_loop = \
pyro.vectorized_markov(name="time", size=len(data)) if vectorized \
else pyro.markov(range(len(data)))
for i in markov_loop:
probs = init_probs if x is None else transition_probs[x]
x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=data[i])


def guide_10(data, vectorized):
init_probs = torch.tensor([0.5, 0.5])
transition_probs = pyro.param("transition_probs",
torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
constraint=constraints.simplex)
x = None
markov_loop = \
pyro.vectorized_markov(name="time", size=len(data)) if vectorized \
else pyro.markov(range(len(data)))
for i in markov_loop:
probs = init_probs if x is None else transition_probs[x]
x = pyro.sample("x_{}".format(i), dist.Categorical(probs),
infer={"enumerate": "parallel"})


@pytest.mark.parametrize("model,guide,data,", [
(model_10, guide_10, torch.ones(5)),
])
def test_guide_enumerated_elbo(model, guide, data):

with pyro_backend("contrib.funsor"):
with pytest.raises(
NotImplementedError,
match="TraceMarkovEnum_ELBO does not yet support guide side enumeration"):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xfail test for guide enumerated models.


elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
expected_loss = elbo.differentiable_loss(model, guide, data, False)

vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, True)

assert_close(actual_loss, expected_loss)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: check gradients (even though this test is currently xfailed)