Fix Dirichlet.log_prob() when x=0 and alpha=1 #103605

Closed
wants to merge 1 commit into from

kalekundert
Contributor

Dirichlet.log_prob() incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$. The Dirichlet PDF is given by:
$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as 0 * log(0).
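
Taking the logarithm of the PDF makes the problem explicit. This is a short derivation added for illustration; the notation $p(x \mid \alpha)$ for the density is introduced here, and the other symbols are the same as above:
$$\log p(x \mid \alpha) = \sum_{i=1}^{K} (\alpha_i - 1) \log x_i - \log B(\alpha)$$
For $x_i = 0$ and $\alpha_i = 1$, the $i$-th summand is evaluated as $0 \cdot \log 0 = 0 \cdot (-\infty)$, which is NaN in IEEE 754 floating-point arithmetic, whereas the intended value is $\log(0^0) = \log 1 = 0$.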

This PR implements the same algorithm that scipy.stats.dirichlet uses to avoid this behavior, namely xlogy(alpha - 1, x) instead of (alpha - 1) * log(x). It also adds a test case comparing the pytorch and scipy implementations for this specific case.
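
A minimal sketch of the difference, written in plain Python so it runs without PyTorch or SciPy. The helper names `xlogy`, `naive_term`, and `dirichlet_log_prob` are hypothetical stand-ins written for this illustration. They are not the code from this PR, and this `xlogy` only mimics the convention that the result is 0 whenever the first argument is 0.

```python
import math


def xlogy(x: float, y: float) -> float:
    """Illustrative x * log(y) that returns 0 when x == 0 (unless y is NaN)."""
    if math.isnan(y):
        return math.nan
    if x == 0.0:
        return 0.0
    if y == 0.0:
        # log(0) is -inf, so the sign of the result follows the sign of x.
        return -math.inf if x > 0 else math.inf
    return x * math.log(y)


def naive_term(alpha_i: float, x_i: float) -> float:
    """(alpha - 1) * log(x), emulating tensor libraries where log(0) is -inf."""
    log_x = math.log(x_i) if x_i > 0 else -math.inf
    return (alpha_i - 1.0) * log_x


def dirichlet_log_prob(x, alpha) -> float:
    """Dirichlet log-density with each summand computed via xlogy."""
    # -log B(alpha) = lgamma(sum(alpha)) - sum(lgamma(alpha_i))
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return sum(xlogy(a - 1.0, xi) for a, xi in zip(alpha, x)) + log_norm


print(naive_term(1.0, 0.0))  # 0.0 * -inf evaluates to nan
print(xlogy(0.0, 0.0))       # 0.0
print(dirichlet_log_prob([0.0, 1.0], [1.0, 2.0]))
```

For $\alpha = (1, 2)$ and $x = (0, 1)$ the density works out by hand to $2 x_2 = 2$, so the last line should print a value equal to $\log 2$ up to rounding.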

@pytorch-bot

pytorch-bot bot commented Jun 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103605

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e001407:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Jun 14, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: kalekundert / name: Kale Kundert (e001407)

Collaborator

@albanD albanD left a comment

Change sounds good!
Just a small question on the test.

x = torch.tensor([0, 1])
actual_log_prob = dist.log_prob(x)
expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)
Collaborator

Do you really need to override the tolerances here? I would expect the defaults to work fine.

Contributor Author

I just copied the whole assertion line from the existing Dirichlet.log_prob() test. I don't know if there's a good reason why these tolerances were chosen in the first place, though. Maybe there's a test environment that uses low-precision floats or something? In any case, my goal was to be consistent. Also, this test only really needs to distinguish between NaN and not NaN, so for that the tolerance doesn't matter.
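
To illustrate the last point: a tolerance-based comparison against a finite expected value fails for NaN regardless of the tolerance chosen. The `is_close` helper below is a hypothetical, simplified closeness check written for this sketch. It is an assumption about the general shape of such checks, not the implementation behind `assertEqual` in the PyTorch test suite.

```python
import math


def is_close(actual: float, expected: float, atol: float, rtol: float) -> bool:
    """Simplified check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


expected = math.log(2.0)  # finite reference value, used only for illustration

for atol in (1e-8, 1e-3, 1e6):
    # Any comparison involving NaN is False, so the check fails at every tolerance.
    print(atol, is_close(math.nan, expected, atol=atol, rtol=0.0))
    # A finite value that matches the reference passes at every tolerance.
    print(atol, is_close(expected, expected, atol=atol, rtol=0.0))
```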

@albanD albanD added release notes: python_frontend release notes category topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jun 15, 2023
Collaborator

@albanD albanD left a comment

Ok let's keep it as is in the name of consistency then. I agree that it does check what we want here anyways.

@albanD
Collaborator

albanD commented Jun 15, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 15, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@kalekundert kalekundert deleted the fix-dirichlet-xlogy branch June 15, 2023 19:24
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: python_frontend release notes category topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
4 participants