Add LKJCholesky distribution #48798

Closed · wants to merge 7 commits

Conversation

@neerajprad (Contributor) commented Dec 3, 2020

As a follow-up to #48041, this adds the LKJCholesky distribution, which samples the Cholesky factor of positive definite correlation matrices.
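For reference, here is a minimal usage sketch of the new distribution (values are illustrative; it assumes the constructor signature `LKJCholesky(dim, concentration)` introduced by this PR):

```python
import torch
from torch.distributions import LKJCholesky

# Illustrative values: sample the Cholesky factor L of a 3x3 correlation
# matrix with concentration 2.0, then reconstruct the correlation matrix.
lkj = LKJCholesky(dim=3, concentration=2.0)
L = lkj.sample()          # lower triangular with positive diagonal
corr = L @ L.T            # positive definite, unit diagonal
print(lkj.log_prob(L))    # log density of the sampled factor
```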

This also relaxes the check on tril_matrix_to_vec so that it works for 2x2 matrices with diag=-2.

cc. @fehiepsi

@neerajprad added the module: distributions (Related to torch.distributions) label on Dec 3, 2020
@codecov (bot) commented Dec 3, 2020

Codecov Report

Merging #48798 (c2b4608) into master (befab0d) will increase coverage by 0.05%.
The diff coverage is 98.11%.

@@            Coverage Diff             @@
##           master   #48798      +/-   ##
==========================================
+ Coverage   80.74%   80.80%   +0.05%     
==========================================
  Files        1862     1866       +4     
  Lines      200435   201080     +645     
==========================================
+ Hits       161840   162473     +633     
- Misses      38595    38607      +12     

@dr-ci (bot) commented Dec 4, 2020

💊 CI failures summary and remediations

As of commit ee563ae (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/1), step "Run tests":

Dec 07 23:06:52 [E request_callback_no_python.cpp:636] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Dec 07 23:06:52 At:
Dec 07 23:06:52   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(120): serialize
Dec 07 23:06:52   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(172): serialize
(the same error is logged twice more)
Dec 07 23:06:52 [W tensorpipe_agent.cpp:547] RPC agent for worker0 encountered error when reading incoming request from worker2: EOF: end of file (this is expected to happen during shutdown) 
Dec 07 23:06:53 ok (1.740s) 
Dec 07 23:06:54   test_return_future_remote (__main__.TensorPipeRpcTestWithSpawn) ... [W tensorpipe_agent.cpp:547] RPC agent for worker1 encountered error when reading incoming request from worker0: EOF: end of file (this is expected to happen during shutdown) 
Dec 07 23:06:54 [W tensorpipe_agent.cpp:547] RPC agent for worker3 encountered error when reading incoming request from worker2: EOF: end of file (this is expected to happen during shutdown) 
Dec 07 23:06:54 [W tensorpipe_agent.cpp:547] RPC agent for worker0 encountered error when reading incoming request from worker2: EOF: end of file (this is expected to happen during shutdown) 


@vishwakftw (Contributor) left a comment

I haven't checked the math; mostly minor comments.

Resolved review threads on torch/distributions/lkj_cholesky.py and test/distributions/test_distributions.py.
@neerajprad (Author) commented

> I haven't checked the math; mostly minor comments.

Thanks for reviewing, @vishwakftw! @fehiepsi has agreed to review the math since this is basically a port of his implementation in numpyro.

@fehiepsi (Contributor) left a comment

LGTM! It is a bit tricky to review the math of this distribution, even though we have a version to compare against. :)) The test looks good to me.

@facebook-github-bot (Contributor) left a comment

@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@neerajprad (Author) commented

Thanks for reviewing, @fehiepsi, @vishwakftw!

w = torch.sqrt(y) * u_hypersphere
# Fill diagonal elements; clamp for numerical stability
eps = torch.finfo(w.dtype).tiny
diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
@neerajprad (Author) commented on this snippet

@fehiepsi - I changed min to torch.finfo(w.dtype).tiny because otherwise I was seeing samples whose log density would evaluate to inf (especially when the concentration parameter is low).
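To illustrate the failure mode being guarded against (a minimal sketch with hypothetical values, not taken from the test suite): when a sampled row lies exactly on the unit sphere, the unclamped diagonal element becomes zero and its log diverges.

```python
import torch

# Hypothetical row vector whose squared norm is exactly 1: the residual
# 1 - sum(w**2) is 0, so the diagonal element sqrt(0) == 0 and log_prob,
# which takes the log of the diagonal elements, diverges.
w = torch.tensor([1.0, 0.0])
raw = 1 - torch.sum(w**2, dim=-1)
print(raw.sqrt().log())                        # tensor(-inf)

# Clamping to the smallest positive normal float keeps the log finite.
eps = torch.finfo(w.dtype).tiny
print(torch.clamp(raw, min=eps).sqrt().log())  # large negative, but finite
```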

@neerajprad (Author) commented

@fehiepsi - I was running NUTS using this and made two small changes; everything else should be the same.

  • Relaxed the check on vec_to_tril_matrix so that diag can be in [-n, n) instead of (-n, n). This is needed so that we can use LKJ with dim=2 (see the sketch below).
  • Added a small fudge factor to the diagonal elements so that we don't return samples whose log density evaluates to inf.
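A short sketch of what the relaxed bound permits, using tril_matrix_to_vec as described in the PR summary (it assumes the helpers live in torch.distributions.utils; values are illustrative):

```python
import torch
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix

mat = torch.tensor([[1.0, 0.0],
                    [0.5, 1.0]])

# diag=-1 extracts the strictly lower-triangular part, as before.
vec = tril_matrix_to_vec(mat, diag=-1)    # tensor([0.5000])

# With the relaxed check, diag=-2 is now valid for a 2x2 matrix and
# selects no elements, instead of raising a ValueError.
empty = tril_matrix_to_vec(mat, diag=-2)  # tensor([]) with shape (0,)

# Round-trip back to a matrix (zeros on and above the main diagonal).
print(vec_to_tril_matrix(vec, diag=-1))   # tensor([[0.0, 0.0], [0.5, 0.0]])
```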

@facebook-github-bot (Contributor) left a comment

@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@neerajprad merged this pull request in dee82ef.

Labels: cla signed · Merged · module: distributions (Related to torch.distributions)

4 participants