
Add support for CorrCholeskyTransform #48041

Closed · wants to merge 13 commits

Conversation

@neerajprad (Contributor) commented Nov 16, 2020

This adds a transform to convert a real vector of dimension D * (D-1) / 2 into the Cholesky factor of a D x D correlation matrix. This follows the implementation in NumPyro by @fehiepsi. It is needed for the LKJDistribution, which will be added in a subsequent PR.
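For illustration, a minimal usage sketch (assuming the transform is exposed as torch.distributions.transforms.CorrCholeskyTransform, as proposed in this PR):

import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()
x = torch.randn(6)          # D * (D - 1) / 2 = 6  =>  D = 4
L = t(x)                    # 4 x 4 lower-triangular Cholesky factor with unit-norm rows
C = L @ L.T                 # a correlation matrix: unit diagonal, entries in [-1, 1]
print(L.shape, torch.allclose(C.diagonal(), torch.ones(4)))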

Also, in line with the ongoing effort to refactor the distributions tests, this moves the transforms tests into their own file, which uses pytest with parametrized fixtures.

For review:
@fehiepsi - could you help review the math?
@fritzo - do you have any suggestions for what to do about the event dimension (more details are in the comment below)?
@ezyang - could you review the changes in run_test.py? Instead of a separate PYTEST_TESTS, I have folded these tests into USE_PYTEST_LIST to avoid duplicating logic. The only difference is that we no longer check whether pytest is installed and exclude these tests if it isn't. I figured that since the existing tests in the list already use pytest, this should not matter.

TODOs (probably not all can be satisfied at the same time):

  • Use operations that are JIT friendly, i.e. the transform works with different-sized inputs under JIT.
  • Resolve test failures - currently arange(scalar_tensor) fails on certain backends, but this is needed for JIT. Maybe we should only support same-sized tensors under JIT?
  • Add tests to check that the transform gives correct gradients and is in agreement with the log_det_jacobian.
  • Add input_event_dim and output_event_dim to CorrCholeskyTransform.

codomain = constraints.corr_cholesky
bijective = True
# Note that since we add a dim in _call this should actually be 2 for the inverse transform.
event_dim = 1
@neerajprad (Contributor, Author):

@fehiepsi, @fritzo - Note that the forward transform adds an additional dim, so event_dim=1 is defined w.r.t. the value passed to the forward transform. The inverse transform's event_dim is consequently incorrect. I wasn't sure what to do. One option would be to modify the inverse transform so that it has event_dim=2. Any other suggestions?

Contributor:

I am not sure what we can do here... Either leaving this 1 or 2 is fine I guess.

@fritzo (Collaborator) commented Nov 17, 2020:

Good point; I don't think we anticipated transforms that change the event dim. How about we add new properties to the base Transform class and then override them in _InverseTransform and your new transform?

class Transform:
    @property
    def input_event_dim(self):
        return self.event_dim

    @property
    def output_event_dim(self):
        return self.event_dim

class _InverseTransform(Transform):
    @property
    def input_event_dim(self):
        return self._inv.output_event_dim

    @property
    def output_event_dim(self):
        return self._inv.input_event_dim

class CorrCholeskyTransform(Transform):
    input_event_dim = 1
    output_event_dim = 2

    @property
    def event_dim(self):
        raise ValueError("Please use .input_event_dim or .output_event_dim instead")

cc @stefanwebb

@neerajprad (Contributor, Author):

That sounds like a good idea. I'll make the change!

@neerajprad (Contributor, Author):

Done. Thanks for the suggestion!


dr-ci bot commented Nov 16, 2020

💊 CI failures summary and remediations

As of commit 06efaf9 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

Since your merge base is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD


@fehiepsi (Contributor) left a comment:

The math looks good to me. I just have some minor comments. Besides event_dim=1, which needs some discussion, I am a bit worried about the gradient of the transform. Is there any test for it (e.g. testing the agreement between the log_det implementation and the gradient of t(x), where t is a transform)? The reason is that in an old PyTorch version, gradients might not propagate properly when taking z.sqrt() for a lower triangular matrix z. It might have been fixed by now, but it is worth testing. :)

z = r ** 2
z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt()
# Diagonal elements must be 1.
r.diagonal(dim1=-1).copy_(r.new_ones(r.shape[:-2] + (n,)))
Contributor:

I think you can use fill_diagonal here.

@neerajprad (Contributor, Author):

Thanks, I didn't know about this function!

@neerajprad (Contributor, Author):

It seems like this won't work for batched tensors.
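For context, a quick sketch of the limitation (fill_diagonal_ supports tensors with more than two dims only when all dims have equal length) and of the batched alternative of writing through a diagonal view, which is what the code above does:

import torch

r = torch.zeros(3, 4, 4)                 # a batch of three 4 x 4 matrices
try:
    r.fill_diagonal_(1.)                 # fails: with more than 2 dims, all dims must have equal length
except RuntimeError as err:
    print(err)

r.diagonal(dim1=-2, dim2=-1).fill_(1.)   # batched alternative: fill through the diagonal view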

eps = torch.finfo(x.dtype).eps
x = x.clamp(min=-1 + eps, max=1 - eps)
n = (1 + math.sqrt(1 + 8 * x.shape[-1])) / 2
if round(n) - n > eps:
Contributor:

This is numerically fine even for a very large n (say, 10e6).
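A quick check of that claim (recovering D from the vector length D * (D - 1) / 2 stays exact in double precision even for D around 1e7):

import math

D = 10 ** 7                              # a very large matrix dimension
length = D * (D - 1) // 2                # corresponding unconstrained vector length
n = (1 + math.sqrt(1 + 8 * length)) / 2  # invert: 1 + 8 * length is the perfect square (2 * D - 1) ** 2
print(n, round(n) - n)                   # 10000000.0 0.0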

@neerajprad (Contributor, Author):

> I am a bit worried about the gradient of the transform. Is there any test for it (e.g. testing the agreement between the log_det implementation and the gradient of t(x), where t is a transform)?

We don't have any gradient tests, but I will add them in this PR.

> The reason is that in an old PyTorch version, gradients might not propagate properly when taking z.sqrt() for a lower triangular matrix z. It might have been fixed by now, but it is worth testing. :)

Can you point me to this issue? I think I'm facing it right now.

@neerajprad added the module: distributions (Related to torch.distributions) label on Nov 17, 2020.
@fehiepsi (Contributor) commented Nov 17, 2020:

@neerajprad It is this issue. Back then, I resolved it with

    z_cumprod_sqrt = z_cumprod.new_zeros(z_cumprod.shape)
    z_cumprod_sqrt[tril_idx] = z_cumprod[tril_idx].sqrt()

@ezyang (Contributor) commented Nov 17, 2020:

run_tests.py changes look fine. You're testing more now, not less, right? Very simple.

@neerajprad (Contributor, Author):

> We don't have any gradient tests, but I will add them in this PR.

@fehiepsi - Thanks for the push to add tests for the Jacobian (I caught a bug too). I realized that we already had Jacobian tests for univariate transforms, and I have extended those to support transforms like CorrCholesky and StickBreaking. I have also abstracted the conversion between a vector and a lower triangular matrix into separate utility functions (as you did in NumPyro) to make them easier to test.
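A rough sketch of the kind of Jacobian-agreement check described above (not necessarily the exact test added in the PR); it vectorizes the output via its strictly lower-triangular entries, which carry the transform's free degrees of freedom:

import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()
x = torch.randn(6)                                  # D = 4, so the input has D * (D - 1) / 2 = 6 entries
y = t(x)

# Full Jacobian of the 4 x 4 output w.r.t. the 6 inputs, shape (4, 4, 6).
jac = torch.autograd.functional.jacobian(t, x)
# Keep only the strictly lower-triangular output entries -> a square (6, 6) Jacobian.
tril_idx = torch.tril_indices(4, 4, offset=-1)
jac_tril = jac[tril_idx[0], tril_idx[1], :]
log_det_autograd = torch.slogdet(jac_tril).logabsdet

print(torch.allclose(t.log_abs_det_jacobian(x, y), log_det_autograd, atol=1e-4))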

@fehiepsi (Contributor) left a comment:

LGTM overall! I just have a few more comments. You might also want to test the utilities tril_matrix_to_vec and vec_to_tril_matrix with different diag values.
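For example, a round-trip check along those lines could look like this (assuming the utilities live in torch.distributions.utils and accept a diag offset, as in this PR):

import torch
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix

for diag in (0, -1):
    mat = torch.randn(5, 5).tril(diag)                     # lower-triangular w.r.t. the given offset
    vec = tril_matrix_to_vec(mat, diag=diag)               # flatten the kept entries into a vector
    assert torch.equal(vec_to_tril_matrix(vec, diag=diag), mat)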

(Resolved review threads on test/distributions/test_distributions.py and test/distributions/test_transforms.py.)
"""
def check(self, value):
row_norm = torch.norm(value, dim=-1)
unit_row_norm = (row_norm <= 1. & row_norm >= 1e-6).all(dim=-1)
Contributor:

I guess you meant 1 - 1e-6.

nit: torch.norm is deprecated https://pytorch.org/docs/stable/generated/torch.norm.html

@fritzo (Collaborator) left a comment:

Generally looks good to me, but I defer math checking to @fehiepsi 😄
Thanks for refactoring the test files, they look great!

"""
def check(self, value):
row_norm = torch.linalg.norm(value, dim=-1)
unit_row_norm = ((row_norm <= 1.) & (row_norm >= (1 - 1e-6))).all(dim=-1)
Collaborator:

I would expect this to fail for float32 and matrices of dim > 100. If you want to be safer you could try something like:

tol = torch.finfo(value.dtype).eps * value.size(-1) * 10  # 10 is an arbitrary fudge factor
row_norm = torch.linalg.norm(value.detach(), dim=-1)
return (row_norm - 1).abs().le(tol).all(dim=-1)

@neerajprad (Contributor, Author):

Sound advice! Thanks, will make the change.

@neerajprad (Contributor, Author) commented Dec 1, 2020:

Done. One thing I realized is that as the size of the vector in real space increases, the transformed matrix returned by CorrCholeskyTransform will have 0 diagonal entries and will violate the lower_cholesky constraint. I am not entirely sure what to do about it, since numerically this is bound to happen. cc @fehiepsi.

>>> y = CorrCholeskyTransform()(torch.randn(19900))  # 200 x 200 matrix
>>> y  # last rows all have 0 diagonal entries

tensor([[ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.6596,  0.7516,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3552,  0.4238,  0.8332,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.8225, -0.5501, -0.1322,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.5494, -0.6964, -0.3553,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1776, -0.4688,  0.7778,  ..., -0.0000,  0.0000,  0.0000]])
>>> lower_cholesky.check(y)

tensor(False)

>>> y @ y.T  # return a random correlation matrix

tensor([[ 1.0000, -0.6596,  0.3552,  ...,  0.8225,  0.5494,  0.1776],
        [-0.6596,  1.0000,  0.0842,  ..., -0.9560, -0.8859, -0.4695],
        [ 0.3552,  0.0842,  1.0000,  ..., -0.0512, -0.3961,  0.5125],
        ...,
        [ 0.8225, -0.9560, -0.0512,  ...,  1.0000,  0.8932,  0.3093],
        [ 0.5494, -0.8859, -0.3961,  ...,  0.8932,  1.0000,  0.1874],
        [ 0.1776, -0.4695,  0.5125,  ...,  0.3093,  0.1874,  1.0000]])

@neerajprad (Contributor, Author) commented Dec 1, 2020:

In any case, I think we should relax the constraint to >= 0 for the diagonal entries to allow for correlation matrices that are positive semi-definite. The LKJ distribution is defined over the space of positive definite correlation matrices, so this constraint is fine. This issue could come up during initialization in HMC when we have large correlation matrices in the model, and is just something to be aware of. In the above snippet, if we keep values close to 0 by narrowing the variance, e.g. torch.randn(19900) * 0.01, the issue can be averted, but the default initialization of sampling between (-2, 2) could fail validation checks. @fehiepsi - I am leaving it as is for now, but we can possibly address this later (by applying some multiplication factor that is event-size dependent) if it turns out to be an issue. WDYT?

Contributor:

I think we can replace

z1m_cumprod_sqrt = (1 - z).cumprod(-1).sqrt()

by

z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)

to improve precision, but I am not sure if it will be stable enough. We might need to use float64 for moderate dimensions.
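A rough illustration of the difference (with made-up values): taking the square root before the cumulative product keeps the running product larger, so it underflows much later in float32.

import torch

z = torch.full((250,), 0.5)              # illustrative squared partial correlations
a = (1 - z).cumprod(-1).sqrt()           # current form: the cumprod underflows to 0 near index 150
b = (1 - z).sqrt().cumprod(-1)           # suggested form: still positive at the same index
print(a[200].item(), b[200].item())      # 0.0 vs. a small positive number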

Contributor:

It seems to really help for matrices of up to 200 dimensions (diag entries can be as small as 1e-38). For large matrices, using float64 would be a better idea. :D

@neerajprad (Contributor, Author):

Good idea, that should help. I'll make the change. If we have issues with larger correlation matrices, we can discuss work-arounds.

@neerajprad (Contributor, Author):

Thanks for the very helpful reviews, @fehiepsi, @fritzo. I'm noticing a few build failures and will address all your comments alongside that shortly.

@fehiepsi (Contributor) left a comment:

LGTM, thanks @neerajprad!

@facebook-github-bot (Contributor) left a comment:

@neerajprad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@neerajprad merged this pull request in 5489a98.

shaibagon pushed a commit to shaibagon/pytorch that referenced this pull request Dec 3, 2020
Summary:
This adds a transform to convert a real vector of dimension D * (D-1) / 2 into the Cholesky factor of a D x D correlation matrix. This follows the implementation in [NumPyro](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py) by fehiepsi. It is needed for the LKJDistribution, which will be added in a subsequent PR.

Also in line with the ongoing effort to refactor distributions test, this moves the transforms test into its own file that uses pytest with parametrized fixtures.

For review:
 fehiepsi - could you help review the math?
 fritzo - do you have any suggestions for what to do about the event dimension (more details are in the comment below)?
 ezyang - could you review the changes in `run_test.py`? Instead of a separate `PYTEST_TESTS`, I have folded these tests into `USE_PYTEST_LIST` to avoid duplicating logic. The only difference is that we no longer check whether pytest is installed and exclude these tests if it isn't. I figured that since the existing tests in the list already use pytest, this should not matter.

TODOs (probably not all can be satisfied at the same time):
 - [x] Use operations that are JIT friendly, i.e. the transform works with different sized input under JIT.
 - [x] Resolve test failures - currently `arange(scalar_tensor)` fails on certain backends but this is needed for JIT. Maybe we should only support same sized tensor under JIT?
 - [x] Add tests to check that the transform gives correct gradients and is in agreement with the `log_det_jacobian`.
 - [x] Add `input_event_dim` and `output_event_dim` to `CorrCholeskyTransform`.

Pull Request resolved: pytorch#48041

Reviewed By: zhangguanheng66

Differential Revision: D25262505

Pulled By: neerajprad

fbshipit-source-id: 5a57e1c19d8230b53592437590b9169bdf2f71e9
facebook-github-bot pushed a commit that referenced this pull request Dec 8, 2020
Summary:
As a follow-up to #48041, this adds the `LKJCholesky` distribution that samples the Cholesky factor of positive definite correlation matrices.

This also relaxes the check on `tril_matrix_to_vec` so that it works for 2x2 matrices with `diag=-2`.

cc. fehiepsi

Pull Request resolved: #48798

Reviewed By: zhangguanheng66

Differential Revision: D25364635

Pulled By: neerajprad

fbshipit-source-id: 4abf8d83086b0ad45c5096760114a2c57e555602
Labels: cla signed, Merged, module: distributions (Related to torch.distributions)
5 participants