
Implement faster Multinomial JAX dispatch #868

@ricardoV94

Description


We currently default to the numpyro implementation when numpyro is installed, but that implementation is very wasteful for large `n` (the number of trials): it draws `n` Categorical samples and sums them into per-category counts, so its cost grows linearly with `n`.

https://github.com/pyro-ppl/numpyro/blob/5af9ebda72bd7aeb08c61e4248ecd0d982473224/numpyro/distributions/util.py#L238
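For comparison, here is a minimal JAX sketch of the Categorical-then-count idea described above. It is illustrative only and does not reproduce numpyro's exact code. The function name `multinomial_via_categorical` is hypothetical. Note that `n` has to be a static Python int here because it fixes the shape of the intermediate array.

```python
import jax
import jax.numpy as jnp


def multinomial_via_categorical(key, n, p):
    """Multinomial(n, p) by drawing n Categorical samples and counting them.

    Hypothetical helper for illustration; `n` must be a static Python int
    and `p` a 1-D probability vector.
    """
    p = jnp.asarray(p, dtype=jnp.float32)
    # O(n) work and memory: one category index is materialised per trial.
    idx = jax.random.categorical(key, jnp.log(p), shape=(n,))
    # Count how many of the n trials landed in each of the k categories.
    return jnp.bincount(idx, length=p.shape[0])
```

Both the intermediate `idx` array and the sampling work scale with `n`. That scaling is the problem the issue describes.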

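We should probably switch to sequential conditional Binomial sampling, which needs only one Binomial draw per category (k - 1 draws for k categories) regardless of `n`: https://en.wikipedia.org/wiki/Multinomial_distribution#Algorithm:_Sequential_conditional_binomial_sampling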
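A minimal sketch of sequential conditional binomial sampling in JAX follows. It assumes a JAX version that provides `jax.random.binomial`, a 1-D probability vector `p` that sums to 1, and at least two categories. The function name `multinomial_sequential` is hypothetical and is not PyTensor's dispatch API. Category `i` receives `Binomial(remaining_trials, p[i] / sum(p[i:]))`, and the last category takes whatever trials remain.

```python
import jax
import jax.numpy as jnp


def multinomial_sequential(key, n, p):
    """Draw one Multinomial(n, p) sample via sequential conditional binomials.

    Hypothetical sketch: assumes `p` is a 1-D probability vector summing to 1
    with at least two entries, and `n` is a non-negative integer scalar.
    """
    p = jnp.asarray(p, dtype=jnp.float32)
    k = p.shape[0]
    # Probability mass of categories i, i+1, ..., k-1 (reverse cumulative sum).
    tail = jnp.cumsum(p[::-1])[::-1]
    # P(category i | trial not in categories 0..i-1); guard zero tails and
    # clip away floating-point overshoot above 1.
    safe_tail = jnp.where(tail > 0, tail, 1.0)
    cond_p = jnp.clip(jnp.where(tail > 0, p / safe_tail, 0.0), 0.0, 1.0)
    keys = jax.random.split(key, k - 1)

    def step(remaining, inputs):
        subkey, q = inputs
        # One Binomial draw per category; the draw never exceeds `remaining`.
        x = jax.random.binomial(subkey, remaining, q, dtype=jnp.float32)
        return remaining - x, x

    # Counts are carried as float32, which is exact for integers up to 2**24.
    init = jnp.asarray(n, dtype=jnp.float32)
    remaining, counts = jax.lax.scan(step, init, (keys, cond_p[:-1]))
    # The last category receives all trials not assigned so far.
    return jnp.concatenate([counts, remaining[None]]).astype(jnp.int32)
```

The work here is k - 1 Binomial draws, independent of `n`. Because `n` only enters as a value and not as a shape, it can also be a traced argument under `jax.jit`.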
