Add support for CorrCholeskyTransform #48041
Conversation
torch/distributions/transforms.py
Outdated
codomain = constraints.corr_cholesky
bijective = True
# Note that since we add a dim in _call this should actually be 2 for the inverse transform.
event_dim = 1
@fehiepsi, @fritzo - Note that the forward transform adds an additional dim, so event_dim=1 is defined w.r.t. the value passed to the forward transform. The inverse transform's event_dim is consequently incorrect. I wasn't sure what to do. One option would be to modify the inverse transform to have event_dim=2. Any other suggestions?
I am not sure what we can do here... Either leaving this 1 or 2 is fine I guess.
Good point, I believe we did not anticipate transforms that change event dim? How about we add new properties to the base Transform class and then override these in _InverseTransform and your new transform?

class Transform:
    @property
    def input_event_dim(self):
        return self.event_dim

    @property
    def output_event_dim(self):
        return self.event_dim

class _InverseTransform(Transform):
    @property
    def input_event_dim(self):
        return self._inv.output_event_dim

    @property
    def output_event_dim(self):
        return self._inv.input_event_dim

class CorrCholeskyTransform(Transform):
    input_event_dim = 1
    output_event_dim = 2

    @property
    def event_dim(self):
        raise ValueError("Please use .input_event_dim or .output_event_dim instead")
cc @stefanwebb
That sounds like a good idea. I'll make the change!
Done. Thanks for the suggestion!
The math looks good to me. I just have some minor comments. Besides event_dim=1, which needs some discussion, I am a bit worried about the gradient of the transform. Is there any test for it (e.g. testing the agreement of the log_det implementation with the gradient of t(x), where t is a transform)? The reason is that in an old PyTorch version, gradients might not propagate properly when taking z.sqrt() for a lower triangular matrix z. It might have been fixed now but is worth testing. :)
torch/distributions/transforms.py
Outdated
z = r ** 2
z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt()
# Diagonal elements must be 1.
r.diagonal(dim1=-1).copy_(r.new_ones(r.shape[:-2] + (n,)))
I think you can use fill_diagonal here.
Thanks, I didn't know about this function!
It seems like this won't work for batched tensors.
torch/distributions/transforms.py
Outdated
eps = torch.finfo(x.dtype).eps
x = x.clamp(min=-1 + eps, max=1 - eps)
n = (1 + math.sqrt(1 + 8 * x.shape[-1])) / 2
if round(n) - n > eps:
This is numerically fine even for a very large n (say, 10e6).
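For reference, the dimension-recovery step in the snippet above can be sketched in plain Python (`tril_size_to_dim` is a hypothetical helper name; the PR's actual code works on `x.shape[-1]` with a dtype-dependent eps):

```python
import math

def tril_size_to_dim(k):
    """Recover the matrix dimension D from the length k = D * (D - 1) / 2
    of a strictly-lower-triangular vector; reject non-triangular lengths."""
    d = (1 + math.sqrt(1 + 8 * k)) / 2
    if abs(round(d) - d) > 1e-9:
        raise ValueError(f"{k} is not of the form D * (D - 1) / 2")
    return round(d)

assert tril_size_to_dim(1) == 2        # a 2 x 2 matrix has one strict-lower entry
assert tril_size_to_dim(19900) == 200  # the 200 x 200 case discussed later in the thread
```

In double precision the square root stays accurate enough that the rounding check remains robust even for very large k, which matches the observation above.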
We don't have any gradient tests, but I will add them in this PR.
Can you point me to this issue? I think I'm facing this right now.
@neerajprad It is this issue. Back then, I resolved it with:
z_cumprod_sqrt = z_cumprod.new_zeros(z_cumprod.shape)
z_cumprod_sqrt[tril_idx] = z_cumprod[tril_idx].sqrt()
@fehiepsi - Thanks for the push to add tests for the Jacobian (I caught a bug too). I realized that we had tests for the Jacobian of univariate transforms, but I have extended those to support transforms like CorrCholesky and StickBreaking. I have also abstracted the conversion of a vector to a lower triangular matrix into separate utility functions (as you did in NumPyro) to make it easier to test.
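As an illustration of what those utilities do, here is a hypothetical pure-Python analogue restricted to the strict lower triangle (the diag=-1 case used by CorrCholesky); the real versions are batched torch ops and infer the matrix size from the vector length, whereas this sketch takes n explicitly:

```python
def vec_to_tril_matrix(vec, n):
    """Unpack vec (length n*(n-1)/2, row-major order) into the strictly
    lower triangle of an n x n matrix, with zeros elsewhere."""
    assert len(vec) == n * (n - 1) // 2
    mat = [[0.0] * n for _ in range(n)]
    it = iter(vec)
    for i in range(n):
        for j in range(i):
            mat[i][j] = next(it)
    return mat

def tril_matrix_to_vec(mat):
    """Flatten the strictly lower triangle back into a vector (row-major)."""
    return [row[j] for i, row in enumerate(mat) for j in range(i)]

v = [1.0, 2.0, 3.0]
m = vec_to_tril_matrix(v, 3)
assert m == [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 3.0, 0.0]]
assert tril_matrix_to_vec(m) == v  # exact round trip
```

The round-trip property is exactly what the new tests exercise: packing and unpacking must be inverses for every valid vector length.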
LGTM overall! I just have a few comments. You might also want to test the utilities tril_matrix_to_vec and vec_to_tril_matrix with different diag values.
torch/distributions/constraints.py
Outdated
""" | ||
def check(self, value): | ||
row_norm = torch.norm(value, dim=-1) | ||
unit_row_norm = (row_norm <= 1. & row_norm >= 1e-6).all(dim=-1) |
I guess you meant 1 - 1e-6.
nit: torch.norm is deprecated: https://pytorch.org/docs/stable/generated/torch.norm.html
Generally looks good to me, but I defer math checking to @fehiepsi 😄
Thanks for refactoring the test files, they look great!
torch/distributions/constraints.py
Outdated
""" | ||
def check(self, value): | ||
row_norm = torch.linalg.norm(value, dim=-1) | ||
unit_row_norm = ((row_norm <= 1.) & (row_norm >= (1 - 1e-6))).all(dim=-1) |
I would expect this to fail for float32 and matrices of dim > 100. If you want to be safer you could try something like:
tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an arbitrary fudge factor
row_norm = torch.linalg.norm(value.detach(), dim=-1)
return (row_norm - 1).abs().le(tol).all(dim=-1)
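A back-of-envelope check shows why a fixed 1e-6 cutoff is too tight for float32 once rows have on the order of 100 elements. This is plain Python with an illustrative helper name (`row_norm_tol`); the constant equals torch.finfo(torch.float32).eps:

```python
# float32 machine epsilon: torch.finfo(torch.float32).eps == 2 ** -23
EPS32 = 2.0 ** -23

def row_norm_tol(n, fudge=10):
    # Summing n squared terms loses roughly n * eps of relative accuracy,
    # hence the dimension-dependent tolerance suggested above.
    return EPS32 * n * fudge

# The expected rounding error alone already exceeds a fixed 1e-6
# cutoff for rows of ~100 elements in float32.
assert EPS32 * 100 > 1e-6

# The tolerance scales linearly with the row length.
assert row_norm_tol(200) == 2 * row_norm_tol(100)
```

Detaching the value before computing the norm, as in the suggestion above, also keeps the validation check out of the autograd graph.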
Sound advice! Thanks, will make the change.
Done. One thing that I realized was that as the size of the vector in real space increases, the transformed matrix returned by CorrCholeskyTransform will have 0 diagonal entries and will violate the lower cholesky constraint. I am not entirely sure what to do about it since numerically this is bound to happen. cc @fehiepsi.
>>> y = CorrCholeskyTransform()(torch.randn(19900)) # 200 x 200 matrix
>>> y # last rows all have 0 diagonal entries
tensor([[ 1.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.6596, 0.7516, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.3552, 0.4238, 0.8332, ..., 0.0000, 0.0000, 0.0000],
...,
[ 0.8225, -0.5501, -0.1322, ..., 0.0000, 0.0000, 0.0000],
[ 0.5494, -0.6964, -0.3553, ..., 0.0000, 0.0000, 0.0000],
[ 0.1776, -0.4688, 0.7778, ..., -0.0000, 0.0000, 0.0000]])
>>> lower_cholesky.check(y)
tensor(False)
>>> y @ y.T # return a random correlation matrix
tensor([[ 1.0000, -0.6596, 0.3552, ..., 0.8225, 0.5494, 0.1776],
[-0.6596, 1.0000, 0.0842, ..., -0.9560, -0.8859, -0.4695],
[ 0.3552, 0.0842, 1.0000, ..., -0.0512, -0.3961, 0.5125],
...,
[ 0.8225, -0.9560, -0.0512, ..., 1.0000, 0.8932, 0.3093],
[ 0.5494, -0.8859, -0.3961, ..., 0.8932, 1.0000, 0.1874],
[ 0.1776, -0.4695, 0.5125, ..., 0.3093, 0.1874, 1.0000]])
In any case, I think we should relax the constraint to >= 0 for diagonal entries to allow for correlation matrices that are positive semi-definite. The LKJ distribution is defined over the space of positive definite correlation matrices, so this constraint is fine. This issue could come up during initialization in HMC when we have large correlation matrices in the model, and is just something to be aware of. In the above snippet, if we keep values close to 0 by narrowing the variance, e.g. torch.randn(19900) * 0.01, then the issue can be averted, but the default initialization of sampling between (-2, 2) could fail validation checks. @fehiepsi - I am leaving it as is for now, but we can possibly address this later (by applying some multiplication factor that is event-size dependent) if it turns out to be an issue. WDYT?
I think we can replace

z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt()

with

z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)

to improve precision, but I am not sure if it will be stable enough. We might need to use float64 for a moderate dimension.
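The two expressions are algebraically identical for non-negative factors, so the rewrite changes only the numerics: taking the square root first keeps every intermediate value farther from the underflow threshold. A plain-Python illustration (the float64 scale here is chosen to force the underflow; float32 hits the same effect at much milder magnitudes):

```python
import math

a = b = 1e-200  # two tiny factors whose product underflows in float64

sqrt_of_prod = math.sqrt(a * b)             # a * b underflows to 0.0 first
prod_of_sqrt = math.sqrt(a) * math.sqrt(b)  # each sqrt stays representable

assert sqrt_of_prod == 0.0  # the information is already lost
assert prod_of_sqrt > 0.0   # sqrt-then-multiply survives
```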
It seems to really help for matrices of up to 200 dimensions (diag entries can be as small as 1e-38). For large matrices, using float64 would be a better idea. :D
Good idea, that should help. I'll make the change. If we have issues with larger correlation matrices, we can discuss work-arounds.
LGTM, thanks @neerajprad!
@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@neerajprad merged this pull request in 5489a98.
Summary: This adds a transform to convert a real vector of (D * (D-1))/2 dimension into the cholesky factor of a D x D correlation matrix. This follows the implementation in [NumPyro](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py) by fehiepsi. This is needed for the LKJDistribution which will be added in a subsequent PR. Also in line with the ongoing effort to refactor distributions tests, this moves the transforms tests into their own file that uses pytest with parametrized fixtures.

For review: fehiepsi - could you help review the math? fritzo - do you have any suggestions for what to do about the event dimension (more details are in the comment below)? ezyang - could you review the changes in `run_test.py`? Instead of a separate `PYTEST_TESTS`, I have clubbed these tests in `USE_PYTEST_LIST` to avoid duplicate logic. The only difference is that we no longer check if pytest is not installed and exclude the tests in the list. I figured that if existing tests are already using pytest, this should not matter.

TODOs (probably not all can be satisfied at the same time):
- [x] Use operations that are JIT friendly, i.e. the transform works with different sized input under JIT.
- [x] Resolve test failures - currently `arange(scalar_tensor)` fails on certain backends but this is needed for JIT. Maybe we should only support same sized tensor under JIT?
- [x] Add tests to check that the transform gives correct gradients and is in agreement with the `log_det_jacobian`.
- [x] Add `input_event_dim` and `output_event_dim` to `CorrCholeskyTransform`.

Pull Request resolved: pytorch#48041
Reviewed By: zhangguanheng66
Differential Revision: D25262505
Pulled By: neerajprad
fbshipit-source-id: 5a57e1c19d8230b53592437590b9169bdf2f71e9
Summary: As a follow up to #48041, this adds the `LKJCholesky` distribution that samples the Cholesky factor of positive definite correlation matrices. This also relaxes the check on `tril_matrix_to_vec` so that it works for 2x2 matrices with `diag=-2`. cc @fehiepsi

Pull Request resolved: #48798
Reviewed By: zhangguanheng66
Differential Revision: D25364635
Pulled By: neerajprad
fbshipit-source-id: 4abf8d83086b0ad45c5096760114a2c57e555602