TraceMarkovEnum_ELBO - a trace implementation of ELBO-based SVI that supports model side enumeration of Markov variables #2716
Conversation
All current tests have empty guides, so I'll add some tests where Markov variables are enumerated in the guide. @eb8680 do you have any suggestions on what kind of other tests need to be added?
You've hit the edge of Pyro's enumeration functionality :) You could port the model-guide pairs in
```python
log_prob = reduce(funsor.ops.add, factors, funsor.Number(funsor.ops.UNITS[funsor.ops.add]))
# compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q
elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs))
elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
```
Note these two lines will not be correct for markov plates; you should try to write a guide enumeration test where this causes the sequential and vectorized markov ELBO values to disagree. I believe some of the test cases I suggested would be sufficient.

To fix this, you will need to rewrite these two lines as another `modified_partial_sum_product` computation. Note that

```python
funsor.Integrate(log_prob, integrand, reduced_vars)
```

is equivalent to

```python
(log_prob.exp() * integrand).reduce(ops.add, reduced_vars)
```
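For concreteness, here is a small sketch checking that identity with funsor; the variable names and random inputs are my own, and the exact `reduced_vars` convention (strings vs. `Variable`s) may differ slightly across funsor versions:

```python
from collections import OrderedDict

import funsor
import funsor.ops as ops
from funsor.testing import assert_close, random_tensor

funsor.set_backend("torch")

inputs = OrderedDict(z=funsor.Bint[3])
log_prob = random_tensor(inputs)   # stands in for the guide's marginal log-density over z
integrand = random_tensor(inputs)  # stands in for the cost term

lhs = funsor.Integrate(log_prob, integrand, frozenset(["z"]))
rhs = (log_prob.exp() * integrand).reduce(ops.add, frozenset(["z"]))
assert_close(lhs, rhs)  # both reduce to the same scalar funsor
```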
> Note these two lines will not be correct for markov plates

Ok. I haven't checked the math yet, so there might be some more errors.

> is equivalent to

Thanks for explaining this.
After taking a closer look, it seems like `modified_partial_sum_product` might not work for guide enumerated models. The reason is that the cost (integrand) `log q(z)` of the transition factors has to be split into the sum of factors `log q(z_0) + log q(z_1 | z_0) + ...` before integrating w.r.t. `q(z_0) q(z_1 | z_0) ...`. I can explain this in more detail at the Zoom meeting.
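To make this concrete, the decomposition I have in mind, written out for a two-step chain (my notation, not code from the PR):

```latex
% Expected cost term for a guide-enumerated two-step Markov chain: the integrand
% must be split per transition factor before integrating against the measure.
\mathbb{E}_{q}\bigl[\log q(z)\bigr]
  = \sum_{z_0,\, z_1} q(z_0)\, q(z_1 \mid z_0)
    \bigl[\log q(z_0) + \log q(z_1 \mid z_0)\bigr]
```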
In other words, the `prod` operation for the integrand is `ops.add`, while for `log_prob.exp()` it is `ops.mul`.
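As a sanity check of that distinction, a toy numeric example (plain Python, probabilities made up by me): the measure factors accumulate multiplicatively, the cost factors accumulate additively, and the integral then sums measure times cost.

```python
import math

# Hypothetical two-state, two-step chain; the probabilities are made up.
q_z0 = {0: 0.6, 1: 0.4}
q_z1_given_z0 = {(0, 0): 0.7, (0, 1): 0.3, (1, 0): 0.2, (1, 1): 0.8}

expected_cost = 0.0
for z0, p0 in q_z0.items():
    for (z0_, z1), p1 in q_z1_given_z0.items():
        if z0_ != z0:
            continue
        measure = p0 * p1                   # measure factors combine with ops.mul
        cost = math.log(p0) + math.log(p1)  # cost factors combine with ops.add
        expected_cost += measure * cost     # Integrate: sum of measure * cost
print(expected_cost)  # E_q[log q(z_0, z_1)] for this toy chain
```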
```python
    return terms


@copy_docs_from(_OrigTraceEnum_ELBO)
class TraceMarkovEnum_ELBO(ELBO):
```
In a future PR, once we're confident that `TraceMarkovEnum_ELBO` is correct, we can just merge it with `TraceEnum_ELBO` instead of keeping the implementations separate.
We'll need to update the Funsor dependency to the latest

@ordabayevy tests should pass after you merge the latest
I also tried to add a model with mixed enumeration, similar to the model I'm using for my project: `m` is enumerated in the guide and the `theta` Markov variable is enumerated in the model. But then I realized that `modified_partial_sum_product` has to eliminate all variables at the same and higher ordinals before eliminating Markov variables. So it seems like guide side enumeration is needed for my model after all.
```python
def model_9(data, history, vectorized):
    theta_dim, m_dim = 3, 2
    theta_init = pyro.param("theta_init", lambda: torch.rand(theta_dim),
                            constraint=constraints.simplex)
    theta_trans = pyro.param("theta_trans", lambda: torch.rand((theta_dim, theta_dim)),
                             constraint=constraints.simplex)
    m_prior = pyro.param("m_prior", lambda: torch.rand((theta_dim, m_dim)),
                         constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(m_dim))

    with pyro.plate("targets", size=data.shape[-2], dim=-2) as targets:
        targets = targets[:, None]
        theta_prev = None
        markov_loop = \
            pyro.vectorized_markov(name="frames", size=data.shape[-1], dim=-1, history=history) if vectorized \
            else pyro.markov(range(data.shape[-1]), history=history)
        for i in markov_loop:
            # theta is a Markov variable enumerated in the model
            theta_curr = pyro.sample(
                "theta_{}".format(i), dist.Categorical(
                    theta_init if isinstance(i, int) and i < 1 else theta_trans[theta_prev]),
                infer={"enumerate": "parallel"})
            # m is left to be enumerated in the guide
            m = pyro.sample("m_{}".format(i), dist.Categorical(m_prior[theta_curr]))
            pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., m], 1.),
                        obs=Vindex(data)[targets, i])
            theta_prev = theta_curr


def guide_9(data, history, vectorized):
    m_dim = 2
    m_probs = pyro.param("m_probs",
                         lambda: torch.rand((data.shape[-2], data.shape[-1], m_dim)),
                         constraint=constraints.simplex)

    with pyro.plate("targets", size=data.shape[-2], dim=-2) as targets:
        targets = targets[:, None]
        markov_loop = \
            pyro.vectorized_markov(name="frames", size=data.shape[-1], dim=-1, history=history) if vectorized \
            else pyro.markov(range(data.shape[-1]), history=history)
        for i in markov_loop:
            pyro.sample("m_{}".format(i), dist.Categorical(Vindex(m_probs)[targets, i]),
                        infer={"enumerate": "parallel"})
```
```python
    with pyro_backend("contrib.funsor"):
        with pytest.raises(
                NotImplementedError,
                match="TraceMarkovEnum_ELBO does not yet support guide side enumeration"):
```
`xfail` test for guide enumerated models.
```python
    # guide side enumeration is not supported
    if any(guide_terms["plate_to_step"].values()):
        raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side enumeration")
```
`raise NotImplementedError`
nit: change error message to "TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"
```python
    vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
    actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True)

    assert_close(actual_loss, expected_loss)
```
These tests look good, but can you also check gradients?
```python
expected_loss = ...
params = tuple(pyro.get_param_store().values())
expected_grads = torch.autograd.grad(expected_loss, params)

vectorized_elbo = ...
actual_loss = ...
actual_grads = torch.autograd.grad(actual_loss, params)

assert_close(actual_loss, expected_loss)
for actual_grad, expected_grad in zip(actual_grads, expected_grads):
    assert_close(actual_grad, expected_grad)
```
Using `torch.autograd.grad` gives a `RuntimeError`:

```
>       return Variable._execution_engine.run_backward(
            outputs, grad_outputs_, retain_graph, create_graph,
            inputs, allow_unused)
E       RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
```

Do I have to add `clear_param_store()` in the models?
You could add `clear_param_store()` to the beginning of the test body.
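To spell that out, a sketch of a gradient-checking test body under that suggestion, filling in the `...` placeholders with the `differentiable_loss` calls quoted elsewhere in this thread (the function name, `max_plate_nesting=4`, and the `tests.common` import path are assumptions, not the PR's exact code):

```python
import torch
from pyroapi import infer, pyro, pyro_backend
from tests.common import assert_close  # Pyro's test helper (assumed import path)


def check_grads(model, guide, data, history):
    with pyro_backend("contrib.funsor"):
        # Clearing the store first means every parameter below was created by this
        # model/guide pair, so none is "unused in the graph" for torch.autograd.grad.
        pyro.clear_param_store()

        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.differentiable_loss(model, guide, data, history, False)
        params = tuple(pyro.get_param_store().values())
        expected_grads = torch.autograd.grad(expected_loss, params)

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True)
        actual_grads = torch.autograd.grad(actual_loss, params)

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
```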
```python
    vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
    actual_loss = vectorized_elbo.differentiable_loss(model, guide, weeks_data, days_data, history, True)

    assert_close(actual_loss, expected_loss)
```
ditto: check gradients
```python
    vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
    actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, True)

    assert_close(actual_loss, expected_loss)
```
ditto: check gradients (even though this test is currently xfailed)
Ok, I'll try to get the parallel ELBO computations working after this PR is merged.

Also

What do you think the ideal behavior here should be? We have a bit of validation logic in

Probably either by extending/altering that validation logic or controlling
Code and tests for model enumeration look good to me. We can defer all guide enumeration issues to future PRs.
Hi @eb8680. I was also trying to get the guide side enumeration working. It seems to work for

@ordabayevy I'm not sure I follow, but I believe the necessary parallel forward-backward algorithms are already implemented in

@eb8680 I looked up what adjoint operators are.

@eb8680 is there a reference or explanation of what
Well, the implementation is currently a bit cluttered, but it computes derivatives of abstract sum-product expressions (including any expression computed with

As described in this tutorial paper, it turns out that these derivatives correspond to marginal distributions when the sum and product operations are
WIP. This first commit has a draft implementation of `TraceMarkovEnum_ELBO` with `modified_partial_sum_product`. Tests compare ELBOs for `pyro.markov` models computed using `TraceEnum_ELBO` and `pyro.vectorized_markov` models computed using `TraceMarkovEnum_ELBO`. Note: all models have an empty guide.

Note: `TraceMarkovEnum_ELBO` only supports model side Markov enumeration. All discrete variables at the same and higher ordinals must also be enumerated in the model.
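For illustration, a hedged sketch of the comparison these tests perform, reusing `model_9` from earlier in this thread with an empty guide. The data shape, `max_plate_nesting=4`, and the `tests.common` import path are my assumptions, and `model_9` is assumed to also mark its `m_{}` sites with `infer={"enumerate": "parallel"}`, since all discrete sites must be enumerated in the model when the guide is empty:

```python
import torch
from pyroapi import infer, pyro, pyro_backend
from tests.common import assert_close  # Pyro's test helper (assumed import path)


def empty_guide(data, history, vectorized):
    pass


data = torch.rand(5, 10)  # (targets, frames), matching model_9's indexing
with pyro_backend("contrib.funsor"):
    pyro.clear_param_store()
    # Sequential pyro.markov execution, scored with the existing TraceEnum_ELBO.
    expected_loss = infer.TraceEnum_ELBO(max_plate_nesting=4).differentiable_loss(
        model_9, empty_guide, data, 1, False)
    # pyro.vectorized_markov execution, scored with the new TraceMarkovEnum_ELBO.
    actual_loss = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4).differentiable_loss(
        model_9, empty_guide, data, 1, True)
    assert_close(actual_loss, expected_loss)
```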