Fix Dirichlet.log_prob() when x=0 and alpha=1 (#103605)
Conversation
Change sounds good!
Just a small question on the test.
x = torch.tensor([0, 1])
actual_log_prob = dist.log_prob(x)
expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)
Do you really need to override the tolerances here? I would expect the defaults to work fine.
I just copied the whole assertion line from the existing Dirichlet.log_prob() test. I don't know whether there's a good reason these tolerances were chosen in the first place; maybe some test environment uses low-precision floats. In any case, my goal was to be consistent. Also, this test only really needs to distinguish NaN from non-NaN, so the tolerance doesn't matter for that.
OK, let's keep it as is in the name of consistency then. I agree that it checks what we want here anyway.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Dirichlet.log_prob() incorrectly returns NaN when x contains a zero in a position where the corresponding concentration parameter alpha is 1. This corresponds to the case where one of the terms in the log-density has the form 0 * log(0). This PR adopts the same approach that scipy.stats.dirichlet uses to avoid this behavior, namely computing xlogy(alpha - 1, x) instead of (alpha - 1) * log(x). It also adds a test case comparing the PyTorch and SciPy implementations on this specific input.
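The xlogy trick can be sketched outside PyTorch with NumPy and SciPy. This is a minimal illustration of the fix, not the PR's actual code; the alpha and x values below are made up for the example:

```python
import numpy as np
from scipy.special import xlogy

# Example concentration parameters and a boundary point of the simplex:
# alpha[0] == 1 and x[0] == 0, the case this PR fixes.
alpha = np.array([1.0, 2.0])
x = np.array([0.0, 1.0])

# Naive term: (alpha - 1) * log(x) produces 0 * log(0) = 0 * (-inf) = nan.
with np.errstate(divide="ignore", invalid="ignore"):
    naive = ((alpha - 1.0) * np.log(x)).sum()

# xlogy(a, b) is defined to return 0 whenever a == 0, even if b == 0,
# so the 0 * log(0) term contributes 0 instead of nan.
fixed = xlogy(alpha - 1.0, x).sum()

print(naive)  # nan
print(fixed)  # 0.0
```

PyTorch exposes the same primitive as torch.xlogy (also torch.special.xlogy), which is what makes this a small change in the log_prob() computation.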