forked from pytorch/pytorch
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for CorrCholeskyTransform (pytorch#48041)
Summary: This adds a transform to convert a real vector of (D * (D-1))/2 dimension into the cholesky factor of a D x D correlation matrix. This follows the implementation in [NumPyro](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py) by fehiepsi. This is needed for the LKJDistribution which will be added in a subsequent PR. Also in line with the ongoing effort to refactor distributions test, this moves the transforms test into its own file that uses pytest with parametrized fixtures. For review: fehiepsi - could you help review the math? fritzo - do you have any suggestions for what to do about the event dimension (more details are in the comment below)? ezyang - could you review the changes in `run_test.py`? Instead of a separate `PYTEST_TESTS`, I have clubbed these tests in `USE_PYTEST_LIST` to avoid duplicate logic. The only difference is that we do not anymore check if pytest is not installed and exclude the tests in the list. I figured that if existing tests are already using pytest, this should not matter. TODOs (probably not all can be satisfied at the same time): - [x] Use operations that are JIT friendly, i.e. the transform works with different sized input under JIT. - [x] Resolve test failures - currently `arange(scalar_tensor)` fails on certain backends but this is needed for JIT. Maybe we should only support same sized tensor under JIT? - [x] Add tests to check that the transform gives correct gradients and is in agreement with the `log_det_jacobian`. - [x] Add `input_event_dim` and `output_event_dim` to `CorrCholeskyTransform`. Pull Request resolved: pytorch#48041 Reviewed By: zhangguanheng66 Differential Revision: D25262505 Pulled By: neerajprad fbshipit-source-id: 5a57e1c19d8230b53592437590b9169bdf2f71e9
- Loading branch information
Showing
9 changed files
with
548 additions
and
333 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.