Add LKJCholesky distribution #48798
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master   #48798      +/-   ##
==========================================
+ Coverage   80.74%   80.80%   +0.05%
==========================================
  Files        1862     1866       +4
  Lines      200435   201080     +645
==========================================
+ Hits       161840   162473     +633
- Misses      38595    38607      +12
💊 CI failures summary and remediations

As of commit ee563ae (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages:
- pytorch_linux_xenial_py3_6_gcc5_4_test (1/1) — Step: "Run tests"
I haven't checked the math, mostly minor comments.
Thanks for reviewing, @vishwakftw! @fehiepsi has agreed to review the math since this is basically a port of his implementation in numpyro.
LGTM! It is a bit tricky to review the math of this distribution even though we have a version to compare against. :)) The tests look good to me.
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Thanks for reviewing, @fehiepsi, @vishwakftw!
w = torch.sqrt(y) * u_hypersphere
# Fill diagonal elements; clamp for numerical stability
eps = torch.finfo(w.dtype).tiny
diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
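A minimal sketch of why the clamp floor matters: when a sampled row's squared norm rounds to exactly 1, the unclamped diagonal entry underflows to 0 and its log becomes -inf, whereas clamping at `torch.finfo(dtype).tiny` keeps the log-density finite. The toy row `w` below is a hypothetical degenerate case, not taken from the PR.

```python
import torch

# Toy row of an almost-degenerate sample: sum(w**2) is exactly 1.0,
# so 1 - sum(w**2) is 0 and log of the diagonal element would be -inf.
w = torch.tensor([1.0, 0.0], dtype=torch.float32)

eps = torch.finfo(w.dtype).tiny             # smallest positive normal float32
raw = 1 - torch.sum(w**2, dim=-1)           # 0.0 in this degenerate case
diag = torch.clamp(raw, min=eps).sqrt()     # strictly positive after clamping

print(torch.log(raw.sqrt()))                # -inf without the clamp
print(torch.log(diag))                      # finite (large negative) with it
```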
@fehiepsi - I changed min to torch.finfo(w.dtype).tiny because otherwise I was seeing samples whose log density would evaluate to inf (especially when the concentration parameter is low).
@fehiepsi - I was running NUTS using this and made two small changes. Everything else should be the same.
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@neerajprad merged this pull request in dee82ef.
As a follow-up to #48041, this adds the LKJCholesky distribution, which samples the Cholesky factor of positive definite correlation matrices. This also relaxes the check on tril_matrix_to_vec so that it works for 2x2 matrices with diag=-2.

cc @fehiepsi
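A short usage sketch of the distribution this PR adds, assuming the torch.distributions.LKJCholesky API as merged; the dim=3 and concentration=2.0 values are arbitrary. A sample is a lower-triangular Cholesky factor L such that L @ L.T is a correlation matrix with unit diagonal.

```python
import torch

# Sample Cholesky factors of 3x3 correlation matrices; concentration > 1
# favors matrices closer to the identity (values here are arbitrary).
lkj = torch.distributions.LKJCholesky(dim=3, concentration=2.0)
L = lkj.sample()

corr = L @ L.transpose(-1, -2)   # reconstruct the correlation matrix
print(torch.diagonal(corr))      # unit diagonal (up to float rounding)
print(lkj.log_prob(L))           # finite log-density over the Cholesky factor
```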