Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support conversion of Transforms and TransformedDistributions to and from Funsors #365

Merged
merged 46 commits into from
Nov 20, 2020

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Sep 20, 2020

Addresses #309, #386. Blocked by #387, #388, #389.

This PR adds support for converting Transforms and TransformedDistributions to and from Funsors, and for sampling and scoring withfunsor.Distributions with transforms in their value slot.

The latter is achieved with a new term I've called Lebesgue for want of a better idea, which computes log-det-jacobians of lazy expressions substituted into its free variable using funsor.delta.solve. (Update: I'm moving Lebesgue to another PR, since it no longer seems necessary for the basic task of converting TransformedDistributions.)

There are several ways this functionality could have been instantiated. The design choices in this version were made with a view toward locality of changes - the overall distribution API is left unchanged, as is the behavior of normalize and eager, and I've reused funsor.delta.solve even though it should probably be broken up into separate inversion and linearization transformations.

The number of transforms currently supported is very limited, partly because we don't have many transforms wrapped as TransformOps in Funsor and partly because solve cannot handle inverting more complex expressions. I've at least tried to get the two most important higher-order transforms ComposeTransform and InverseTransform working, so that adding more transforms later will be relatively straightforward. For ease of review, I have also chosen to leave support for the JAX backend to a followup PR, though it should simply involve copying the additions to funsor.torch.distributions verbatim.

Tested:

Tasks remaining:

  • Add unit tests for density
  • Add unit tests for cons-hashing
  • Add unit tests for sampling
  • Add unit tests for transform conversion
  • Get generic conversion tests to pass

Out of scope for this PR:

  • Support more transforms
  • Add similar capability to the JAX backend
  • Add more distribution conversion unit tests, especially interactions with Independent - added one Independent case, but conversion failing
  • Add better documentation for working with transformed distributions and Lebesgue - moving Lebesgue to another PR

@eb8680 eb8680 added the WIP label Sep 20, 2020
@eb8680 eb8680 mentioned this pull request Oct 26, 2020
34 tasks
@eb8680 eb8680 added the Blocked Blocked by other issues label Oct 29, 2020
@eb8680 eb8680 removed the Blocked Blocked by other issues label Oct 31, 2020
@eb8680 eb8680 added the Blocked Blocked by other issues label Nov 11, 2020
@eb8680 eb8680 removed the Blocked Blocked by other issues label Nov 11, 2020
funsor/cnf.py Outdated Show resolved Hide resolved
else:
raise NotImplementedError("cannot get raw dist for {}".format(self))
value_name = [name for name, domain in self.value.inputs.items() # TODO is this right?
if domain == self.value.output][0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird. Can you explain what's going on? When is value_name != "value"? The assertion in self.__init__() suggests value_name == "value" IIUC.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value_name can potentially be anything, since self.value is a lazy expression. This logic is meant to solve the problem of identifying the value name when self.value is not a Variable or even has more than one input (which happens when constructing the Funsor version of a TransformedDistribution).

The simplest nontrivial example of the latter case would be an affine or power transform where the parameters are funsor.Tensors with nontrivial .inputs, although these are not handled in this PR since funsor.delta.solve does not yet support inverting such expressions.

The main use for value_name in this PR is in Distribution.unscaled_sample, which needs to know value_name to construct a sample Delta with the correct .inputs.

test/test_distribution_generic.py Show resolved Hide resolved
test/test_distribution_generic.py Outdated Show resolved Hide resolved
test/test_distribution_generic.py Outdated Show resolved Hide resolved
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing nits!

@fritzo fritzo merged commit 668aa70 into master Nov 20, 2020
@fritzo fritzo deleted the lebesgue-2 branch November 20, 2020 20:19
@fritzo fritzo mentioned this pull request Jan 17, 2021
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants