
Full-featured distribution wrappers #386

Open
20 of 34 tasks
eb8680 opened this issue Oct 26, 2020 · 2 comments
Labels: enhancement (New feature or request) · help wanted (Extra attention is needed) · testing

Comments

eb8680 commented Oct 26, 2020

The biggest single issue blocking wider use of Funsor in Pyro and NumPyro right now is the incomplete coverage of distributions.

At a high level, the goal is to be able to perform all distribution operations that appear in any Pyro or NumPyro model (e.g. sampling and scoring) on Funsors directly, where distribution funsors are obtained by using funsor.to_funsor to automatically convert distributions to Funsors up front, and funsor.to_data is used to automatically convert the final results back to raw PyTorch/JAX objects. Wherever possible, the Funsor wrappers should also avoid the need for user-facing higher-order distributions, such as Independent or TransformedDistribution, in favor of idiomatic Funsor operations or broadcasting semantics.
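As an illustration of that workflow, here is a minimal sketch on the torch backend. It only assumes the funsor.to_funsor / funsor.to_data entry points named above; exact call signatures (for example the output domain argument) may differ across funsor versions.

```python
import torch
import pyro.distributions as dist

import funsor
funsor.set_backend("torch")

# Convert a backend distribution to a funsor: the result is a lazy log-density
# funsor with a free "value" variable.
d = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
d_funsor = funsor.to_funsor(d, output=funsor.Real)

# Score a point by substituting a value funsor, then convert the result back
# to a raw torch tensor with to_data.
x = funsor.to_funsor(torch.tensor(0.5), output=funsor.Real)
log_prob = funsor.to_data(d_funsor(value=x))
```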

We have gotten pretty far with the generic wrappers in funsor.distribution and funsor.{jax,torch}.distributions, but finishing the job and achieving full coverage of pyro.distributions remains a challenge because of the size of the distribution API and the sheer number of distributions, the many small impedance mismatches and legacy design choices in PyTorch (e.g. the data type of Bernoulli), and the difficulty of programmatic access and automation (e.g. there is no generic tool for constructing random valid instances of a distribution given a batch_shape).
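As an example of that missing automation, a hypothetical helper along the following lines could build random valid instances from PyTorch's constraint registry; the name random_dist is made up here, and parameters with non-trivial event shapes (e.g. Dirichlet concentrations or covariance matrices) would need extra handling.

```python
import torch
from torch.distributions import Gamma, transform_to


def random_dist(dist_cls, batch_shape):
    """Hypothetical sketch: build a random, valid instance of dist_cls with the
    given batch_shape by sampling unconstrained parameters and mapping them
    through the constraint registry. Only handles scalar-event parameters."""
    params = {}
    for name, constraint in dist_cls.arg_constraints.items():
        unconstrained = torch.randn(batch_shape)
        params[name] = transform_to(constraint)(unconstrained)
    return dist_cls(**params)


d = random_dist(Gamma, batch_shape=(2, 3))
assert d.batch_shape == (2, 3)
```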

I've tried to collect the remaining Funsor-specific tasks in this issue so that we can better measure progress toward this goal. We may also need to do additional work upstream in Pyro, NumPyro or PyTorch distributions.

Transforms and TransformedDistributions (some design discussion in #309):

Other basic distribution modifiers:

Masking:

Direct TFP distribution wrappers:

  • Direct wrapping of TFP distributions in the JAX backend - this would probably involve a new subclass TFPDistribution(funsor.distribution.Distribution) with TFP-specific implementations of _infer_value_domain and _infer_param_domain (the TFP-side metadata involved is sketched after this list)
  • Direct TFP Bijector wrappers
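The snippet below is not a wrapper implementation; it only illustrates, on the JAX substrate of TFP, the shape and dtype metadata that TFP-specific _infer_value_domain / _infer_param_domain hooks would need to map onto funsor domains.

```python
from tensorflow_probability.substrates import jax as tfp
import jax.numpy as jnp

tfd = tfp.distributions
d = tfd.Dirichlet(concentration=jnp.ones(3))

# Metadata a TFPDistribution wrapper would translate into funsor domains:
print(d.event_shape)  # e.g. (3,) -> the "value" domain is a 3-vector of reals
print(d.dtype)        # e.g. float32
```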

Atomic distribution computations beyond sampling and scoring implemented in the backend libraries:

Test harnesses for distribution wrappers (testing correctness of underlying distribution functionality here is out of scope - we are mostly interested in ensuring that results are converted to Funsors correctly):
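For concreteness, here is a minimal sketch of the kind of conversion check such a harness might run on the torch backend, comparing the backend's own log_prob with the score obtained by converting to a funsor and back; the test name and exact funsor call signatures are illustrative, not the final harness design.

```python
import pytest
import torch
import pyro.distributions as dist

import funsor
funsor.set_backend("torch")


@pytest.mark.parametrize("value", [-1.0, 0.0, 2.5])
def test_normal_scoring_conversion(value):
    d = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
    expected = d.log_prob(torch.tensor(value))

    # Round-trip through funsor: convert the distribution, score, convert back.
    d_funsor = funsor.to_funsor(d, output=funsor.Real)
    x = funsor.to_funsor(torch.tensor(value), output=funsor.Real)
    actual = funsor.to_data(d_funsor(value=x))

    assert torch.allclose(actual, expected)
```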

Miscellaneous:

  • Find a workaround for int/float casting issues in Bernoulli (discussed in Casting between real and bint #348; the dtype mismatch is illustrated in the sketch after this list)
  • Generic support for sampling via Funsor.sample() in funsor.pyro.FunsorDistribution - done properly, this may eliminate the need for first-class conjugate distribution implementations in backends, e.g. DirichletMultinomial
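As a reminder of what the Bernoulli casting issue looks like on the PyTorch side (see #348), the snippet below shows the float-valued samples and scores that a wrapper has to reconcile with funsor's integer-valued treatment of a two-valued variable.

```python
import torch

d = torch.distributions.Bernoulli(probs=torch.tensor(0.3))
x = d.sample()
print(x, x.dtype)     # e.g. tensor(0.) torch.float32, although the support is {0, 1}
print(d.log_prob(x))  # scoring likewise expects float-valued 0./1. inputs
```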

Lower priority, possibly unnecessary:

eb8680 commented Nov 27, 2020

Conversion of TransformedDistributions to and from funsors on the JAX backend (can copy #365)

Note this will be a bit trickier than expected, because in NumPyro Transform.inv is just a regular method rather than a property returning an _InverseTransform, a class NumPyro does not implement.
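To illustrate the mismatch (as of the NumPyro versions current when this comment was written; later releases may differ): PyTorch's Transform.inv is a lazy transform object that can itself be converted, while NumPyro's inv maps values directly.

```python
import torch
import torch.distributions.transforms as torch_transforms

import jax.numpy as jnp
import numpyro.distributions.transforms as np_transforms

# PyTorch: .inv is an _InverseTransform object; it can be passed around and
# converted before ever being applied to a value.
t = torch_transforms.ExpTransform()
inv_t = t.inv                     # a Transform, not a number
print(inv_t(torch.tensor(2.0)))   # log(2)

# NumPyro (at the time of this comment): .inv is just a method on the forward
# transform, so there is no standalone inverse-transform object to wrap.
t = np_transforms.ExpTransform()
print(t.inv(jnp.array(2.0)))      # log(2)
```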

fehiepsi commented

I will make a PR for that. :)
