Skip to content

Commit

Permalink
Debug positive definite constraints (#68720)
Browse files Browse the repository at this point in the history
Summary:
While implementing #68644,
during the testing of 'torch.distributions.constraint.positive_definite', I found an error in the code: [location](https://github.com/pytorch/pytorch/blob/c7ecf1498d961415006c3710ac8d99166fe5d634/torch/distributions/constraints.py#L465-L468)
```
class _PositiveDefinite(Constraint):
    """
    Constrain to positive-definite matrices.
    """
    event_dim = 2

    def check(self, value):
        # Assumes that the matrix or batch of matrices in value are symmetric
        # info == 0 means no error, that is, it's SPD
        return torch.linalg.cholesky_ex(value).info.eq(0).unsqueeze(0)
```

The error is caused when I check the positive definiteness of
`torch.cuda.DoubleTensor([[2., 0], [2., 2]])`
But it did not made a problem for
`torch.DoubleTensor([[2., 0], [2., 2]])`

You may easily reproduce the error by following code:

```
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> const = torch.distributions.constraints.positive_definite
>>> const.check(torch.cuda.DoubleTensor([[2., 0], [2., 2]]))
tensor([False], device='cuda:0')
>>> const.check(torch.DoubleTensor([[2., 0], [2., 2]]))
tensor([True])
```
The cause of error can be analyzed more if you give 'check_errors = True' as a additional argument for 'torch.linalg.cholesky_ex'.
It seem that it is caused by the recent changes in 'torch.linalg'.
And, I suggest to modify the '_PositiveDefinite' class by using 'torch.linalg.eig' function like the below:

```
class _PositiveDefinite(Constraint):
    """
    Constrain to positive-definite matrices.
    """
    event_dim = 2

    def check(self, value):
        return (torch.linalg.eig(value)[0].real > 0).all(dim=-1)
```

By using above implementation, I get following result:
```
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> const = torch.distributions.constraints.positive_definite
>>> const.check(torch.cuda.DoubleTensor([[2., 0.], [2., 2.]]))
tensor(True, device='cuda:0')
>>> const.check(torch.DoubleTensor([[2., 0.], [2., 2.]]))
tensor(True)
```

FYI, I do not know what algorithm is used in 'torch.linalg.eig' and 'torch.linalg.cholesky_ex'. As far as I know, they have same time complexity generally, O(n^3). It seems that in case you used special algorithms or finer parallelization, time complexity of Cholesky decomposition may be reduced to approximately O(n^2.5). If there is a reason 'torch.distributions.constraints.positive_definite' used 'torch.linalg.cholesky_ex' rather than 'torch.linalg.eig' previously, I hope to know.

Pull Request resolved: #68720

Reviewed By: samdow

Differential Revision: D32724391

Pulled By: neerajprad

fbshipit-source-id: 32e2a04b2d5b5ddf57a3de50f995131d279ede49
  • Loading branch information
nonconvexopt authored and facebook-github-bot committed Dec 1, 2021
1 parent 8586f37 commit 845a82b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 8 deletions.
19 changes: 19 additions & 0 deletions test/distributions/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
from torch.testing._internal.common_cuda import TEST_CUDA


EXAMPLES = [
(constraints.symmetric, False, [[2., 0], [2., 2]]),
(constraints.positive_definite, False, [[2., 0], [2., 2]]),
(constraints.symmetric, True, [[3., -5], [-5., 3]]),
(constraints.positive_definite, False, [[3., -5], [-5., 3]]),
(constraints.symmetric, True, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.positive_definite, False, [[[1., -2], [-2., 1]], [[2., 3], [3., 2]]]),
(constraints.symmetric, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
(constraints.positive_definite, True, [[[4., 2], [2., 4]], [[3., -1], [-1., 3]]]),
]

CONSTRAINTS = [
(constraints.real,),
(constraints.real_vector,),
Expand Down Expand Up @@ -41,6 +52,14 @@ def build_constraint(constraint_fn, args, is_cuda=False):
t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args))

@pytest.mark.parametrize('constraint_fn, result, value', EXAMPLES)
@pytest.mark.parametrize('is_cuda', [False,
pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
reason='CUDA not found.'))])
def test_constraint(constraint_fn, result, value, is_cuda):
t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
assert constraint_fn.check(t(value)).all() == result


@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
Expand Down
44 changes: 39 additions & 5 deletions torch/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.symmetric``
- ``constraints.stack``
- ``constraints.square``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""

Expand Down Expand Up @@ -53,7 +56,9 @@
'real',
'real_vector',
'simplex',
'square',
'stack',
'symmetric',
'unit_interval',
]

Expand Down Expand Up @@ -456,16 +461,43 @@ def check(self, value):
return _LowerCholesky().check(value) & unit_row_norm


class _PositiveDefinite(Constraint):
class _Square(Constraint):
"""
Constrain to positive-definite matrices.
Constrain to square matrices.
"""
event_dim = 2

def check(self, value):
# Assumes that the matrix or batch of matrices in value are symmetric
# info == 0 means no error, that is, it's SPD
return torch.linalg.cholesky_ex(value).info.eq(0).unsqueeze(0)
return torch.full(
size=value.shape[:-2],
fill_value=(value.shape[-2] == value.shape[-1]),
dtype=torch.bool,
device=value.device
)


class _Symmetric(_Square):
"""
Constrain to Symmetric square matrices.
"""

def check(self, value):
square_check = super().check(value)
if not square_check.all():
return square_check
return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)


class _PositiveDefinite(_Symmetric):
"""
Constrain to positive-definite matrices.
"""

def check(self, value):
sym_check = super().check(value)
if not sym_check.all():
return sym_check
return torch.linalg.cholesky_ex(value).info.eq(0)


class _Cat(Constraint):
Expand Down Expand Up @@ -557,6 +589,8 @@ def check(self, value):
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
4 changes: 1 addition & 3 deletions torch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ def covariance_matrix(self):

@lazy_property
def precision_matrix(self):
identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype)
# TODO: use cholesky_inverse when its batching is supported
return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
self._batch_shape + self._event_shape + self._event_shape)

@property
Expand Down

0 comments on commit 845a82b

Please sign in to comment.