Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update distribution shape inference to handle independent dims #402

Merged
merged 18 commits into from
Dec 17, 2020

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Dec 1, 2020

Addresses #386, pyro-ppl/pyro#2702

Consider the following snippet of Pyro code simplified from a test in pyro-ppl/pyro#2702 :

with pyro.plate("data", T, dim=-1):
    y = pyro.sample("y", dist.Normal(torch.ones(3), 1.).to_event(1))
    pyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data)

Attempting to wrap this in a poutine.collapse() context fails when dist.Normal(y, 1) is incorrectly converted to a Funsor. In particular, there is a mismatch between y.output, which is Reals[3] and the .output value Real expected for the loc param of funsor.torch.distributions.Normal given by funsor.distribution.Distribution._infer_value_domain.

In cases like the above, however, it is always possible to infer the correct parameter and value .output shapes generically using the broadcasting logic in the underlying backend distribution. This PR implements this behavior in funsor.distribution.DistributionMeta.__call__ and distribution_to_data.

As a result, it is possible after this PR to represent Independent distributions without an intermediary funsor.Independent by passing a Variable with extra output dimensions as the value:

# diagonal multivariate normal in pyro:
dist.Normal(zeros(3), 1.).to_event(1)

# funsor equivalent, after this PR - note Reals[3] for x
Normal(loc=zeros(3), scale=1., value=Variable("x", Reals[3]))

This should considerably simplify pattern-matching over Independent distributions.

To make this work, converting a funsor.distribution.Distribution to data via funsor.to_data will have to include a step handling nontrivial event shapes by calling .to_event and unsqueezing parameters.

More ambitiously, we could also attempt to handle to_funsor conversion of backend Independent distributions by substituting an appropriately shaped Variable for their value rather than resorting to lazy application of funsor.terms.Independent, which would make pattern-matching for collapseing multivariate distributions much easier, but it's not clear yet whether this can be done generically, especially in the presence of transforms. I have not attempted to do this in this PR.

Tasks:

  • Update funsor.distribution.DistributionMeta
  • Update distribution_to_data
  • Add tests
  • Fix DirichletMultinomial broadcasting errors

@eb8680 eb8680 added the WIP label Dec 1, 2020
# The arguments to _infer_value_domain are the .output shapes of parameters,
# so any extra batch dimensions that aren't part of the instance event_shape
# must be broadcasted output dimensions by construction.
out_shape = instance.batch_shape + instance.event_shape
Copy link
Member Author

@eb8680 eb8680 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change to _infer_value_domain is the conceptual meat of the PR.

params, value = params[:-1], params[-1]
params = params + (Variable("value", value.output),)
instance = reflect(cls, *params)
raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to refactor eager_log_prob to use Distribution._get_raw_dist() to get the new tests to pass.

funsor/distribution.py Outdated Show resolved Hide resolved
domains[k] = domain if domain is not None else to_funsor(v).output

# broadcast individual param domains with Funsor inputs
# this avoids .expand-ing underlying parameter tensors
Copy link
Member

@fehiepsi fehiepsi Dec 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the expected domain of scale for Normal(Reals[2], 1.) and Normal(Reals[2], torch.ones(2))? Currently, domains["scale"] will be Real in both case. The second case will trigger an error at to_funsor(v, output=domains[k]) below.

In either case, I guess we need to rewrite eager_normal or eager_mvn to address Reals[2] loc. Maybe there is some trick to avoid doing so. cc @fritzo

Copy link
Member Author

@eb8680 eb8680 Dec 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the expected domain of scale for Normal(Reals[2], 1.) and Normal(Reals[2], torch.ones(2))?

In the first case, it's Real, and in the second, it's Reals[2]. I guess I should add a second broadcasting condition below to handle the case where the parameter is a raw tensor:

if ops.is_numeric_array(v):  # at this point we know all of v's dims are output dims
    domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]

@fritzo
Copy link
Member

fritzo commented Dec 8, 2020

It looks like you just need to relax precision of test_dirichlet_density() for the jax backend.

@eb8680
Copy link
Member Author

eb8680 commented Dec 8, 2020

@fehiepsi there are a ton of unrelated new failures on the last JAX build, any idea what's going on? Seems like a mix of weird new numerical errors and ValueError: cannot compare objects of type <class 'jax.interpreters.xla._DeviceArray'>.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic to get expected outputs looks reasonable to me, pending further comments by @fritzo.

I'll try to address eager_normal, eager_mvn issues in a follow-up PR. Currently, those eager rules assume Real args but we have Reals or a mix of Real, Reals now.

@eb8680
Copy link
Member Author

eb8680 commented Dec 10, 2020

I'll try to address eager_normal, eager_mvn issues in a follow-up PR. Currently, those eager rules assume Real args but we have Reals or a mix of Real, Reals now.

Hmm, I wonder if we'll have to do this for all the other eager distribution patterns in funsor/distribution.py as well...

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(thanks for your patience @eb8680)

funsor/distribution.py Outdated Show resolved Hide resolved
funsor/distribution.py Show resolved Hide resolved
funsor_event_shape = funsor_dist.value.output.shape

# attempt to generically infer the independent output dimensions
instance = funsor_dist.dist_class(**{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beyond the scope of this PR, I'm concerned with the increasing overhead of shape computations that need to do tensor ops. I like @fehiepsi's recent suggestion of implementing .forward_event_shape() for transforms. I think it would be worthwhile to discuss and think about extensions to the Distribution interface that could replace all this need to create an throw away dummy distributions.

(Indeed in theory an optimizing compiler could remove all this overhead, but in practice our tensor backends either incur super-linear compile time cost, or fail to cover the wide range of probabilistic models we would like to handle. And while these dummy tensor ops are cheap, they add noise to debugging efforts.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree the repeated creation of distribution instances here is not ideal. Perhaps we could add counterparts of some of the shape inference methods from TFP (e.g. event_shape_tensor, param_shapes) upstream in torch.distributions.

funsor/distribution.py Outdated Show resolved Hide resolved
Comment on lines +218 to +223
dist.Independent.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.Independent.rsample = dist.Independent.sample
dist.MaskedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.MaskedDistribution.rsample = dist.MaskedDistribution.sample
dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a TODO pointing to a NumPyro issue to fix this bug, so we can delete this workaround once the bug is fixed? cc @fehiepsi

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neerajprad Should we add those new attributes to NumPyro distributions? We can make default behaviors for them so that there would be only a few changes in the code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. In numpyro, for distributions that have reparametrized samplers available both will be the same so we can just add a default Distribution.rsample method which delegates to sample and throws a NotImplemented error when not available.

funsor/torch/distributions.py Show resolved Hide resolved
@eb8680
Copy link
Member Author

eb8680 commented Dec 17, 2020

@fritzo any other comments? I'd like to bump the Funsor dependency version in Pyro to the latest master to unblock @ordabayevy's PR pyro-ppl/pyro#2716 and we might as well include these changes too.

@fritzo
Copy link
Member

fritzo commented Dec 17, 2020

Thanks for the reminder and thanks for addressing nits! LGTM.

@fritzo fritzo merged commit c685dde into master Dec 17, 2020
@fritzo fritzo deleted the infer-independent-dims branch December 17, 2020 16:33
@eb8680 eb8680 mentioned this pull request Jan 25, 2021
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants