
TraceMarkovEnum_ELBO - a trace implementation of ELBO-based SVI that supports model side enumeration of Markov variables #2716

Merged 9 commits into pyro-ppl:dev on Dec 21, 2020

Conversation

@ordabayevy (Member) commented Dec 17, 2020

WIP. This first commit has a draft implementation of TraceMarkovEnum_ELBO based on modified_partial_sum_product. Tests compare ELBOs of pyro.markov models computed with TraceEnum_ELBO against ELBOs of the corresponding pyro.vectorized_markov models computed with TraceMarkovEnum_ELBO. Note: all models have an empty guide.

Note

TraceMarkovEnum_ELBO only supports model-side Markov enumeration. All discrete variables with the same or higher ordinal must also be enumerated in the model.

@ordabayevy (Member Author)

All current tests have empty guides, so I'll add some tests where Markov variables are enumerated in the guide. @eb8680 do you have any suggestions on what other kinds of tests need to be added?

@eb8680 (Member) commented Dec 17, 2020

> do you have any suggestions on what kind of other tests need to be added?

You've hit the edge of Pyro's enumeration functionality :) You could port the model-guide pairs in test_elbo_hmm_in_model, test_elbo_hmm_in_guide, test_hmm_enumerate_model_and_guide, and test_elbo_dbn_growth from tests/infer/test_enum.py and compare the ELBO values for markov and vectorized_markov as in your updated test_vectorized_markov.
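
(For reference, such a comparison test might look like the following sketch; the model/guide arguments, max_plate_nesting value, and assert_close helper are assumptions mirroring the snippets reviewed later in this thread.)

# sequential pyro.markov model as the reference
elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
expected_loss = elbo.differentiable_loss(model, guide, data, history, False)

# vectorized pyro.vectorized_markov model under test
vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True)

assert_close(actual_loss, expected_loss)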

log_prob = reduce(funsor.ops.add, factors, funsor.Number(funsor.ops.UNITS[funsor.ops.add]))
# compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q
elbo_term = funsor.Integrate(log_prob, cost, guide_terms["measure_vars"] & frozenset(cost.inputs))
elbo += elbo_term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
@eb8680 (Member) Dec 17, 2020

Note these two lines will not be correct for markov plates; you should try to write a guide enumeration test where this causes the sequential and vectorized markov ELBO values to disagree. I believe some of the test cases I suggested would be sufficient.

To fix this, you will need to rewrite these two lines as another modified_partial_sum_product computation. Note that

funsor.Integrate(log_prob, integrand, reduced_vars)

is equivalent to

(log_prob.exp() * integrand).reduce(ops.add, reduced_vars)
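
(A minimal sketch checking this equivalence on a toy factor; the variable name, sizes, and backend choice are made up, and the syntax assumes a recent funsor where domains are written Bint[n].)

import torch
from collections import OrderedDict

import funsor
from funsor import ops

funsor.set_backend("torch")

# a single discrete variable "x" with 3 values
inputs = OrderedDict(x=funsor.Bint[3])
log_prob = funsor.Tensor(torch.randn(3), inputs)
integrand = funsor.Tensor(torch.randn(3), inputs)

lhs = funsor.Integrate(log_prob, integrand, frozenset(["x"]))
rhs = (log_prob.exp() * integrand).reduce(ops.add, frozenset(["x"]))
assert torch.allclose(lhs.data, rhs.data)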

@ordabayevy (Member Author)

> Note these two lines will not be correct for markov plates

Ok. I haven't checked the math yet, so there might be some more errors.

> is equivalent to

Thanks for explaining this.

@ordabayevy (Member Author)

After taking a closer look, it seems like modified_partial_sum_product might not work for guide-enumerated models. The reason is that the cost (integrand) logq(z) of the transition factors has to be split into the sum of factors logq(z_0) + logq(z_1|z_0) + ... before integrating with respect to q(z_0)q(z_1|z_0).... I can explain this in more detail at the Zoom meeting.

@ordabayevy (Member Author) Dec 18, 2020

In other words, the product operation for the integrand is ops.add, while for log_prob.exp() it is ops.mul.
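
(A toy torch check of this point, with made-up names and sizes: for two independent sites the joint measure is a product while the joint cost is a sum, so the expectation splits into a sum of per-site expectations.)

import torch

q0 = torch.softmax(torch.randn(2), -1)  # normalized measure for z0
q1 = torch.softmax(torch.randn(2), -1)  # normalized measure for z1
cost0, cost1 = torch.randn(2), torch.randn(2)

joint_measure = q0[:, None] * q1[None, :]     # measures combine with ops.mul
joint_cost = cost0[:, None] + cost1[None, :]  # costs combine with ops.add
lhs = (joint_measure * joint_cost).sum()      # E_{q0 q1}[cost0(z0) + cost1(z1)]
rhs = (q0 * cost0).sum() + (q1 * cost1).sum()
assert torch.allclose(lhs, rhs)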

return terms


@copy_docs_from(_OrigTraceEnum_ELBO)
class TraceMarkovEnum_ELBO(ELBO):
@eb8680 (Member)

In a future PR, once we're confident that TraceMarkovEnum_ELBO is correct, we can just merge it with TraceEnum_ELBO instead of keeping the implementations separate.

@eb8680 (Member) commented Dec 17, 2020

We'll need to update the Funsor dependency to the latest master before this can be merged.

@eb8680 (Member) commented Dec 17, 2020

@ordabayevy tests should pass after you merge the latest dev branch

@ordabayevy (Member Author) left a comment

I also tried to add a model with mixed enumeration, similar to the model I'm using for my project: m is enumerated in the guide and the Markov variable theta is enumerated in the model (see model_9 and guide_9 below). But then I realized that modified_partial_sum_product has to eliminate all variables with the same or higher ordinal before eliminating Markov variables. So it seems like guide-side enumeration is needed for my model after all.

# imports assumed from the surrounding test module
import torch
from torch.distributions import constraints

from pyro.ops.indexing import Vindex
from pyroapi import distributions as dist
from pyroapi import pyro


def model_9(data, history, vectorized):
    theta_dim, m_dim = 3, 2
    theta_init = pyro.param("theta_init", lambda: torch.rand(theta_dim), constraint=constraints.simplex)
    theta_trans = pyro.param("theta_trans", lambda: torch.rand((theta_dim, theta_dim)), constraint=constraints.simplex)
    m_prior = pyro.param("m_prior", lambda: torch.rand((theta_dim, m_dim)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(m_dim))

    with pyro.plate("targets", size=data.shape[-2], dim=-2) as targets:
        targets = targets[:, None]
        theta_prev = None
        markov_loop = \
            pyro.vectorized_markov(name="frames", size=data.shape[-1], dim=-1, history=history) if vectorized \
            else pyro.markov(range(data.shape[-1]), history=history)
        for i in markov_loop:
            theta_curr = pyro.sample(
                "theta_{}".format(i), dist.Categorical(
                    theta_init if isinstance(i, int) and i < 1 else theta_trans[theta_prev]),
                infer={"enumerate": "parallel"})
            m = pyro.sample("m_{}".format(i), dist.Categorical(m_prior[theta_curr]))
            pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., m], 1.), obs=Vindex(data)[targets, i])
            theta_prev = theta_curr


def guide_9(data, history, vectorized):
    m_dim = 2
    m_probs = pyro.param("m_probs",
                         lambda: torch.rand((data.shape[-2], data.shape[-1], m_dim)),
                         constraint=constraints.simplex)

    with pyro.plate("targets", size=data.shape[-2], dim=-2) as targets:
        targets = targets[:, None]
        markov_loop = \
            pyro.vectorized_markov(name="frames", size=data.shape[-1], dim=-1, history=history) if vectorized \
            else pyro.markov(range(data.shape[-1]), history=history)
        for i in markov_loop:
            pyro.sample("m_{}".format(i), dist.Categorical(Vindex(m_probs)[targets, i]),
                        infer={"enumerate": "parallel"})

with pyro_backend("contrib.funsor"):
    with pytest.raises(
            NotImplementedError,
            match="TraceMarkovEnum_ELBO does not yet support guide side enumeration"):
@ordabayevy (Member Author)

An xfail test for guide-enumerated models.


# guide side enumeration is not supported
if any(guide_terms["plate_to_step"].values()):
    raise NotImplementedError("TraceMarkovEnum_ELBO does not yet support guide side enumeration")
@ordabayevy (Member Author)

raise NotImplementedError

@eb8680 (Member)

nit: change error message to "TraceMarkovEnum_ELBO does not yet support guide side Markov enumeration"

vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, history, True)

assert_close(actual_loss, expected_loss)
@eb8680 (Member) Dec 20, 2020

These tests look good, but can you also check gradients?

expected_loss = ...
params = tuple(pyro.get_param_store().values())
expected_grads = torch.autograd.grad(expected_loss, params)

vectorized_elbo = ...
actual_loss = ...
actual_grads = torch.autograd.grad(actual_loss, params)

assert_close(actual_loss, expected_loss)
for actual_grad, expected_grad in zip(actual_grads, expected_grads):
    assert_close(actual_grad, expected_grad)

@ordabayevy (Member Author)

Using torch.autograd.grad gives a RuntimeError:

>       return Variable._execution_engine.run_backward(
            outputs, grad_outputs_, retain_graph, create_graph,
            inputs, allow_unused)
E       RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Do I have to add clear_param_store() in the models?

@eb8680 (Member)

You could add clear_param_store() to the beginning of the test body.
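
(For example, with a hypothetical test name and the body elided:)

def test_markov_elbo_gradients(model, guide, data, history):
    # clear stale params so torch.autograd.grad only differentiates
    # tensors actually used by this test's model and guide
    pyro.clear_param_store()
    ...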

vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, weeks_data, days_data, history, True)

assert_close(actual_loss, expected_loss)
@eb8680 (Member)

ditto: check gradients

vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
actual_loss = vectorized_elbo.differentiable_loss(model, guide, data, True)

assert_close(actual_loss, expected_loss)
@eb8680 (Member)

ditto: check gradients (even though this test is currently xfailed)

@eb8680 (Member) commented Dec 20, 2020

> it seems like guide side enumeration is needed for my model after all

Ok, I'll try to get the parallel ELBO computations working after this PR is merged.

@ordabayevy (Member Author) commented Dec 20, 2020

Also, the model_9/guide_9 pair above runs without any errors but computes the wrong ELBO. That error is harder to catch, and I'm not sure whether the check should happen in modified_partial_sum_product or in TraceMarkovEnum_ELBO. Or we can leave it up to the user not to write such models.

@ordabayevy changed the title to "WIP TraceMarkovEnum_ELBO (TraceEnum_ELBO with modified_partial_sum_product)" on Dec 20, 2020
@ordabayevy changed the title to "TraceMarkovEnum_ELBO - a trace implementation of ELBO-based SVI that supports model side enumeration of Markov variables" on Dec 20, 2020
@eb8680 (Member) commented Dec 21, 2020

> model_9, guide_9 pair above runs without any errors but computes wrong ELBO. It's harder to catch the error

What do you think the ideal behavior here should be? We have a bit of validation logic in pyro.infer.TraceEnum_ELBO that forbids related behavior (no guide enumeration can be in a more local plate context than any downstream model enumeration), but it's not immediately obvious to me how or whether to extend that logic to this case.

@ordabayevy (Member Author) commented Dec 21, 2020

Probably either by extending/altering that validation logic or by controlling the cond_indep_stack. I believe this issue arises in the first place because we label m as conditionally independent across the frames dimension when that's not entirely true, just a matter of convenience. So I would prefer the latter approach, where m wouldn't have frames in its cond_indep_stack.

Edit: at least after integrating out theta in the model, m is not conditionally independent across the frames dimension.

@eb8680 (Member) left a comment

Code and tests for model enumeration look good to me. We can defer all guide enumeration issues to future PRs.

@eb8680 eb8680 merged commit f9ad37b into pyro-ppl:dev Dec 21, 2020
@ordabayevy (Member Author)

Hi @eb8680. I was also trying to get guide-side enumeration working. It seems to work for model_9 and model_10 here, with one problem. The strategy is to compute the forward sequential sum terms alpha(z_t) = sum_{z_0,...,z_{t-1}} [q(z_0) ... q(z_t|z_{t-1})] and the backward sequential sum terms beta(z_t) = sum_{z_{t+1},...,z_T} [q(z_{t+1}|z_t) ... q(z_T|z_{T-1})] in parallel, and then integrate the cost terms in parallel as well: sum_{z_{t-1},z_t} [alpha(z_{t-1}) * cost(z_t|z_{t-1}) * beta(z_t)]. However, I only have naive implementations of the forward and backward term calculations (naive_prefix_sum and naive_suffix_sum). The forward terms are the terms in Eq. 5 of the Temporal Parallelization of Bayesian Smoothers paper and require the down-sweep calculations of the algorithm in Figure 1. That's where I got stuck, because I couldn't figure out how to do it with funsor Tensors.
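
(For concreteness, a plain-torch version of the naive forward computation described above might look like the sketch below; the function name matches the one mentioned, but the tensor layout is an assumption, not the PR's code.)

import torch

def naive_prefix_sum(init, trans):
    # init: (K,) unnormalized q(z_0)
    # trans: (T, K, K) with trans[t, i, j] = q(z_{t+1} = j | z_t = i)
    alphas = [init]
    for t in range(trans.shape[0]):
        # alpha(z_{t+1}) = sum_{z_t} alpha(z_t) * q(z_{t+1} | z_t)
        alphas.append(alphas[-1] @ trans[t])
    return torch.stack(alphas)  # alphas[t] = alpha(z_t)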

@eb8680 (Member) commented Dec 28, 2020

@ordabayevy I'm not sure I follow, but I believe the necessary parallel forward-backward algorithms are already implemented in funsor.adjoint. The main challenge is applying them in a way that addresses #2724.

@ordabayevy (Member Author) commented Dec 28, 2020

@eb8680 I looked up what adjoint operators are. Now it all makes sense (if I understood it correctly)! I'll try to update compute_expectation to handle Markov variables too.

@ordabayevy (Member Author)

@eb8680 is there a reference or an explanation of what funsor.adjoint does? I'm still having a hard time understanding what it does and how to use it.

@eb8680 (Member) commented Dec 29, 2020

> is there a reference or explanation what funsor.adjoint does?

Well, the implementation is currently a bit cluttered, but it computes derivatives of abstract sum-product expressions (including any expression computed with modified_partial_sum_product) with respect to input factors using reverse-mode automatic differentiation. The API isn't all that different from torch.autograd.grad, except that it currently requires specifying a sum and product operation in addition to a root expression and input factors; there are some usage examples in the tests.

As described in this tutorial paper, it turns out that these derivatives correspond to marginal distributions when the sum and product operations are logaddexp and add and the input factors are unnormalized log-probability tensors, and that computing multiple marginal distributions in this way is highly efficient because it maximizes the amount of intermediate memory and computation shared between marginals.
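
(A minimal sketch of that usage pattern; the factor shapes and variable names are made up, and the AdjointTape/sum_product calls assume funsor's API of this era, as used in pyro.contrib.funsor's TraceEnum_ELBO.)

from collections import OrderedDict

import torch
import funsor
from funsor import ops
from funsor.adjoint import AdjointTape
from funsor.sum_product import sum_product

funsor.set_backend("torch")

# unnormalized log-probability factors of a small chain x0 - x1 - x2
f01 = funsor.Tensor(torch.randn(2, 2), OrderedDict(x0=funsor.Bint[2], x1=funsor.Bint[2]))
f12 = funsor.Tensor(torch.randn(2, 2), OrderedDict(x1=funsor.Bint[2], x2=funsor.Bint[2]))
factors = [f01, f12]

with AdjointTape() as tape:
    # log partition function: eliminate all variables with (logaddexp, add)
    logZ = sum_product(ops.logaddexp, ops.add, factors,
                       eliminate=frozenset(["x0", "x1", "x2"]),
                       plates=frozenset())

# derivatives of logZ w.r.t. each input factor correspond to (log-)marginals
marginals = tape.adjoint(ops.logaddexp, ops.add, logZ, factors)
for factor, marginal in marginals.items():
    print(tuple(marginal.inputs))  # same variables as the corresponding factor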
