
Optimize Categorical.log_prob() of enumerated support #1831

Merged: 4 commits merged into dev on Apr 19, 2019

Conversation

Projects
None yet
4 participants
@fritzo (Collaborator) commented Apr 19, 2019

This adds an optimization to dist.Categorical that completely avoids tensor ops in x.log_prob(x.enumerate_support()), speeding up inference that uses enumeration. In particular, it replaces a torch.gather() with a cheap reshape-transpose-unsqueeze.
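The idea can be sketched in plain PyTorch (this is a hedged illustration, not the PR's actual code; the shapes are made up): when the value being scored is the enumerated support itself, the log-probs are just the normalized logits rearranged, so no gather is needed.

```python
import torch

# For a Categorical with batched logits of shape (batch, n), the log-prob
# of the enumerated support is the normalized logits, transposed so the
# enumeration dimension comes first.
logits = torch.randn(4, 3)                          # batch=4, n=3 categories
logp = logits - logits.logsumexp(-1, keepdim=True)  # normalize

# Enumerated support: category k broadcast over the batch, shape (n, batch).
support = torch.arange(3).unsqueeze(-1).expand(3, 4)

# Generic path: gather log-probs at each support value.
slow = logp.expand(3, 4, 3).gather(-1, support.unsqueeze(-1)).squeeze(-1)

# Optimized path: a plain transpose, no tensor indexing at all.
fast = logp.t()

assert torch.allclose(slow, fast)
```

The transpose is a view, so the fast path allocates nothing and launches no indexing kernel, which is where the reported speedups come from.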

Performance

  • examples/tabular.py: 139s -> 79s
    (that's overall; torch.gather was taking 70s alone, all of which is eliminated by this PR)
  • examples/hmm.py --model 1: 15.5s -> 13.1s
  • examples/hmm.py --model 4: 21.5s -> 18.8s
  • examples/hmm.py --model 6: 73.9s -> 32.1s
  • examples/hmm.py --model 6 --jit: 63.8s -> 26.4s

Tested

  • refactoring is covered by existing tests, especially test_enum.py
  • tested performance on contrib-treecat branch

@fritzo fritzo requested review from neerajprad and eb8680 Apr 19, 2019

@fritzo fritzo referenced this pull request Apr 19, 2019

Closed

Use Vindex in Categorical.log_prob()? #1830

2 of 3 tasks complete
@eb8680 (Member) left a comment

Wow, those are some huge speedups! I bet the mixed-effect HMM examples would see an even more dramatic performance boost.

@fehiepsi (Collaborator) left a comment

Interesting, this is a very nice observation!

@neerajprad (Member) left a comment

This is a nice optimization (and a clever observation)!

@eb8680 (Member) commented Apr 19, 2019

@neerajprad @fritzo should we consider doing a minor release after merging this?

@neerajprad (Member) commented Apr 19, 2019

> @neerajprad @fritzo should we consider doing a minor release after merging this?

Yes, I'm shooting for sometime early next week, since the MOBB model will be going live as well. Are there any other changes we are looking to put in?

@fritzo (Collaborator, Author) commented Apr 19, 2019

> should we consider doing a minor release after merging this?

Sure. I defer to @neerajprad for timing, since his MOBB support is highest priority.

> Are there any other changes we are looking to put in?

No changes on my end. @neerajprad may want to put in more changes. The feature I plan to release is a minipyro.JitTrace_ELBO, but there is no hurry for that.

@eb8680 (Member) commented Apr 19, 2019

> Are there any other changes we are looking to put in?

Not a release blocker if I don't get around to it, but at some point I would like to update the mixed-effect HMM examples to use Vindex, which should dramatically simplify the gross indexing code there. @fritzo did you want to include an updated version of the enumeration tutorial? I figure if we're releasing Vindex we might as well advertise it a bit.
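For context, the "gross indexing" that Vindex is meant to replace looks roughly like this (a hedged sketch with made-up shapes; Vindex itself lives in pyro.ops.indexing and is not imported here):

```python
import torch

# A batched lookup table and an index per leading-dim cell.
x = torch.randn(2, 3, 5)            # e.g. (batch, time, state)
i = torch.randint(5, (2, 3))        # one state index per (batch, time)

# Manual route: build explicit arange grids for every leading dim.
b = torch.arange(2).view(2, 1)
t = torch.arange(3).view(1, 3)
manual = x[b, t, i]                 # shape (2, 3)

# Equivalent gather-based route, also common in the HMM examples.
via_gather = x.gather(-1, i.unsqueeze(-1)).squeeze(-1)

assert torch.equal(manual, via_gather)
# With Vindex the same lookup would read roughly: Vindex(x)[..., i]
```

Either manual form breaks as soon as enumeration adds extra batch dimensions on the left, which is exactly the bookkeeping Vindex is designed to absorb.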

@fritzo (Collaborator, Author) commented Apr 19, 2019

> did you want to include an updated version of the enumeration tutorial?

Likewise this is not a release blocker. I'll try to get around to it, but @neerajprad don't wait for me to release.

@neerajprad neerajprad merged commit 985f57e into dev Apr 19, 2019

2 checks passed

  • continuous-integration/travis-ci/pr — The Travis CI build passed
  • license/cla — Contributor License Agreement is signed

ahmadsalim added a commit to ahmadsalim/pyro that referenced this pull request May 6, 2019
