Fix hang in VonMises rejection sampling for small values of concentration #114498

Closed

julian-urban wants to merge 5 commits into pytorch:main from julian-urban:my-nightly-branch

Conversation

@julian-urban (Contributor) commented Nov 24, 2023

Fixes #88443

Forces the internal dtype of torch.distributions.von_mises.VonMises to be torch.double and mirrors NumPy's second-order Taylor expansion for concentration < 1e-5. Samples and log probs are returned with the dtype of the loc argument.

In principle one could also use masking in the rejection sampler to return uniformly distributed numbers for concentration < 1e-8, as NumPy does. This may be slightly more efficient, but it isn't required to fix the hang.
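The failure mode can be seen without torch. Below is a pure-Python sketch (an illustration, not the PR's code: f32 rounding is simulated with struct, and try_accept, f32, and the iteration cap are hypothetical names added here). At concentration 1e-6, single precision cancels tau - sqrt(2*tau) to exactly zero, so the proposal parameter of the Best–Fisher rejection sampler blows up to infinity; the acceptance test then compares against NaN on every iteration, which is the hang. Double precision, or the Taylor form 1/kappa + kappa used by this patch, avoids it.

```python
import math
import random
import struct

def f32(x: float) -> float:
    """Round a Python double to the nearest IEEE float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

kappa = 1e-6

# Proposal parameter of the Best-Fisher von Mises sampler.
# In float32, 1 + 4*kappa**2 rounds to exactly 1, so tau == 2 and
# tau - sqrt(2*tau) cancels to exactly 0: rho becomes 0 and
# proposal_r = (1 + rho**2) / (2 * rho) overflows to infinity.
tau32 = f32(1.0 + f32(math.sqrt(f32(1.0 + f32(4.0 * f32(kappa * kappa))))))
rho32 = f32(f32(tau32 - f32(math.sqrt(f32(2.0 * tau32)))) / f32(2.0 * kappa))

# In float64 the same formula still resolves rho ~ kappa/2, and for
# kappa < 1e-5 the patch switches to the Taylor form 1/kappa + kappa.
tau = 1.0 + math.sqrt(1.0 + 4.0 * kappa**2)
rho = (tau - math.sqrt(2.0 * tau)) / (2.0 * kappa)
r_exact = (1.0 + rho**2) / (2.0 * rho)
r_taylor = 1.0 / kappa + kappa

def try_accept(r, kappa, rng, max_iter=1000):
    """One acceptance loop, capped so the 'hang' terminates here.
    With r = inf (the float32 outcome; torch's float division by zero
    yields inf rather than raising), f becomes NaN, every comparison
    with NaN is False, and no proposal is ever accepted."""
    for i in range(max_iter):
        u1, u2 = rng.random(), rng.random()
        y = math.cos(math.pi * u1)
        f = (1.0 + r * y) / (r + y)   # NaN when r is infinite
        c = kappa * (r - f)
        if c * (2.0 - c) - u2 > 0 or (c > 0 and math.log(c / u2) + 1.0 - c >= 0):
            return i  # accepted after i rejections
    return None  # never accepted within the cap: the hang
```

With a healthy proposal_r around 1e6, the acceptance probability per draw is close to 1, so the loop exits almost immediately; with the degenerate infinite value it never exits.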

cc @fritzo @neerajprad @alicanb @nikitaved

@pytorch-bot bot commented Nov 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114498

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8d068d3 with merge base b27565a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@julian-urban (Contributor, Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the topic: not user facing label Nov 24, 2023
@mikaylagawarecki added the module: distributions (Related to torch.distributions) label Nov 29, 2023
@mikaylagawarecki added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Nov 29, 2023
@fritzo (Collaborator) left a comment

Thanks for fixing this! Could we move your logic into three @lazy_property attributes, to keep the other code paths entirely single precision?

@julian-urban (Contributor, Author)

@fritzo Thanks for picking it up; please check my latest commit. Following your suggestion, I implemented the _loc and _concentration attributes in double precision as lazy properties, and likewise _proposal_r, also replacing the if/else with torch.where so that it works with tensor-valued loc and concentration. The hard-coded double-precision code is now confined to calls of sample().

I'm not sure if I was supposed to resolve the above conversation already, this is my first PR and I'm not familiar with the proper etiquette. Apologies.
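For readers unfamiliar with the pattern being discussed, @lazy_property is a descriptor that runs its method once on first access and caches the result on the instance, so the double-precision conversions are paid only when sampling actually needs them. A minimal stand-alone sketch (VonMisesSketch and _calls are hypothetical names for illustration; the real decorator lives in torch.distributions.utils and plain float() stands in for the .to(torch.double) promotion):

```python
class lazy_property:
    """Sketch of the descriptor pattern behind
    torch.distributions.utils.lazy_property: the wrapped method runs
    once on first access, and its result is cached on the instance."""

    def __init__(self, fget):
        self.fget = fget
        self.name = fget.__name__

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        value = self.fget(obj)
        # Shadow the (non-data) descriptor with the computed value,
        # so later accesses read the instance dict and skip __get__.
        obj.__dict__[self.name] = value
        return value

class VonMisesSketch:
    """Hypothetical stand-in for the distribution class; _calls just
    counts how often the cached property body executes."""

    def __init__(self, concentration):
        self.concentration = concentration
        self._calls = 0

    @lazy_property
    def _concentration(self):
        self._calls += 1
        # In the real fix this is a promotion to torch.double;
        # float() stands in for that dtype conversion here.
        return float(self.concentration)

d = VonMisesSketch(1e-6)
_ = d._concentration
_ = d._concentration  # cached: the property body does not run again
```

In the PR itself, _loc, _concentration, and _proposal_r are such lazy properties, with torch.where selecting the Taylor branch elementwise so tensor-valued concentration works.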

@fritzo (Collaborator) left a comment

Looks great, @julian-urban! It might be worth adding a line comment in .sample() explaining why sampling is performed in double precision, but feel free to leave it as is.

@fritzo (Collaborator) commented Nov 30, 2023

Looks like a legit lint error:

  Warning (UFMT) format
    You can run `lintrunner -a` to apply this patch.

    167  167 |         """
    168  168 |         shape = self._extended_shape(sample_shape)
    169  169 |         x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
    169      |-        return _rejection_sample(self._loc, self._concentration, self._proposal_r, x).to(self.loc.dtype)
         170 |+        return _rejection_sample(
         171 |+            self._loc, self._concentration, self._proposal_r, x
         172 |+        ).to(self.loc.dtype)
    171  173 | 
    172  174 |     def expand(self, batch_shape):
    173  175 |         try:

@julian-urban (Contributor, Author)

@fritzo whoopsie. The test passes now and I added a clarifying comment to sample().

@fritzo (Collaborator) commented Dec 2, 2023

@ezyang could you please merge this?

@fritzo (Collaborator) commented Dec 4, 2023

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 4, 2023
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
Fix hang in VonMises rejection sampling for small values of concentration (pytorch#114498)


Pull Request resolved: pytorch#114498
Approved by: https://github.com/fritzo
@julian-urban julian-urban deleted the my-nightly-branch branch December 31, 2023 12:51

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: distributions (Related to torch.distributions), open source, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Hang: sampling VonMises distribution gets stuck in rejection sampling for small kappa

5 participants