Add support for CorrCholeskyTransform (pytorch#48041)
Summary:
This adds a transform to convert a real vector of dimension D * (D - 1) / 2 into the Cholesky factor of a D x D correlation matrix. This follows the implementation in [NumPyro](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py) by fehiepsi. This is needed for the `LKJDistribution`, which will be added in a subsequent PR.
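
As a quick illustration of the mapping (a minimal usage sketch, not part of the PR's diff; the D = 4 sizing mirrors the updated constraint tests below):

```python
import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()
# For D = 4 the unconstrained input has D * (D - 1) / 2 = 6 entries.
x = torch.randn(6, dtype=torch.double)
L = t(x)          # 4 x 4 lower-triangular Cholesky factor with unit-norm rows
corr = L @ L.T    # correlation matrix: unit diagonal, off-diagonals in [-1, 1]
assert torch.allclose(corr.diagonal(), torch.ones(4, dtype=torch.double))
```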

Also, in line with the ongoing effort to refactor the distributions tests, this moves the transforms tests into their own file that uses pytest with parametrized fixtures.
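
For context, the new file uses pytest roughly along the lines of the sketch below (illustrative only; the transform list, test name, and tolerances are placeholders rather than the actual contents of the new test file):

```python
import pytest
import torch
from torch.distributions.transforms import ExpTransform, SigmoidTransform

# Hypothetical list of transforms to exercise; the real file builds a much larger one.
TRANSFORMS = [ExpTransform(), SigmoidTransform()]

@pytest.mark.parametrize('transform', TRANSFORMS, ids=lambda t: type(t).__name__)
def test_forward_inverse(transform):
    x = torch.randn(4, 5, dtype=torch.double)
    y = transform(x)
    x2 = transform.inv(y.clone())  # clone to bypass the transform's cache
    assert torch.allclose(x, x2)
```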

For review:
 fehiepsi - could you help review the math?
 fritzo - do you have any suggestions for what to do about the event dimension (more details are in the comment below)?
 ezyang - could you review the changes in `run_test.py`? Instead of keeping a separate `PYTEST_TESTS` list, I have folded these tests into `USE_PYTEST_LIST` to avoid duplicating logic. The only difference is that we no longer check whether pytest is installed and exclude the listed tests when it is missing. I figured that since existing tests already use pytest, this should not matter.

TODOs (probably not all can be satisfied at the same time):
 - [x] Use operations that are JIT friendly, i.e. the transform works with differently sized inputs under JIT.
 - [x] Resolve test failures - currently `arange(scalar_tensor)` fails on certain backends, but this is needed for JIT. Maybe we should only support same-sized tensors under JIT?
 - [x] Add tests to check that the transform gives correct gradients and is in agreement with `log_abs_det_jacobian` (see the sketch after this list).
 - [x] Add `input_event_dim` and `output_event_dim` to `CorrCholeskyTransform`.
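
A minimal sketch of what the gradient and shape checks can look like for the new transform (assumptions: this is not the PR's actual test code; it relies only on `torch.autograd.gradcheck` and the `input_event_dim` attribute added above):

```python
import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()
x = torch.randn(6, dtype=torch.double, requires_grad=True)

# Analytical gradients of the forward map agree with finite differences.
assert torch.autograd.gradcheck(lambda v: t(v), (x,))

# log_abs_det_jacobian has the expected batch shape (a scalar here,
# since the single event dimension of x is consumed).
y = t(x)
j = t.log_abs_det_jacobian(x, y)
assert j.shape == x.shape[:x.dim() - t.input_event_dim]
```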

Pull Request resolved: pytorch#48041

Reviewed By: zhangguanheng66

Differential Revision: D25262505

Pulled By: neerajprad

fbshipit-source-id: 5a57e1c19d8230b53592437590b9169bdf2f71e9
neerajprad authored and shaibagon committed Dec 3, 2020
1 parent b726b7b commit 7d365ba
Showing 9 changed files with 548 additions and 333 deletions.
15 changes: 12 additions & 3 deletions test/distributions/test_constraints.py
@@ -27,6 +27,7 @@
(constraints.half_open_interval, -2, -1),
(constraints.half_open_interval, 1, 2),
(constraints.simplex,),
(constraints.corr_cholesky,),
(constraints.lower_cholesky,),
]

@@ -49,7 +50,11 @@ def test_biject_to(constraint_fn, args, is_cuda):
except NotImplementedError:
pytest.skip('`biject_to` not implemented.')
assert t.bijective, "biject_to({}) is not bijective".format(constraint)
x = torch.randn(5, 5, dtype=torch.double)
if constraint_fn is constraints.corr_cholesky:
# (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
x = torch.randn(6, 6, dtype=torch.double)
else:
x = torch.randn(5, 5, dtype=torch.double)
if is_cuda:
x = x.cuda()
y = t(x)
@@ -62,7 +67,7 @@ def test_biject_to(constraint_fn, args, is_cuda):
assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)

j = t.log_abs_det_jacobian(x, y)
assert j.shape == x.shape[:x.dim() - t.event_dim]
assert j.shape == x.shape[:x.dim() - t.input_event_dim]


@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
@@ -72,7 +77,11 @@ def test_biject_to(constraint_fn, args, is_cuda):
def test_transform_to(constraint_fn, args, is_cuda):
constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
t = transform_to(constraint)
x = torch.randn(5, 5, dtype=torch.double)
if constraint_fn is constraints.corr_cholesky:
# (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim)
x = torch.randn(6, 6, dtype=torch.double)
else:
x = torch.randn(5, 5, dtype=torch.double)
if is_cuda:
x = x.cuda()
y = t(x)
322 changes: 2 additions & 320 deletions test/distributions/test_distributions.py
@@ -56,13 +56,8 @@
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.kl import _kl_expfamily_expfamily
from torch.distributions.transforms import (AbsTransform, AffineTransform,
CatTransform, ComposeTransform, ExpTransform,
LowerCholeskyTransform,
PowerTransform, SigmoidTransform,
TanhTransform, SoftmaxTransform,
StickBreakingTransform,
identity_transform, StackTransform)
from torch.distributions.transforms import (AffineTransform, CatTransform, ExpTransform,
StackTransform, identity_transform)
from torch.distributions.utils import probs_to_logits, lazy_property
from torch.nn.functional import softmax

@@ -4300,319 +4295,6 @@ def test_icdf(self):
self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)


class TestTransforms(TestCase):
def setUp(self):
super(TestTransforms, self).setUp()
self.transforms = []
transforms_by_cache_size = {}
for cache_size in [0, 1]:
transforms = [
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
PowerTransform(exponent=2,
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
TanhTransform(cache_size=cache_size),
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
SoftmaxTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
ExpTransform(cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
]
for t in transforms[:]:
transforms.append(t.inv)
transforms.append(identity_transform)
self.transforms += transforms
if cache_size == 0:
self.unique_transforms = transforms[:]

def _generate_data(self, transform):
domain = transform.domain
codomain = transform.codomain
x = torch.empty(4, 5)
if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
x = torch.empty(6, 6)
x = x.normal_()
return x
elif domain is constraints.real:
return x.normal_()
elif domain is constraints.positive:
return x.normal_().exp()
elif domain is constraints.unit_interval:
return x.uniform_()
elif isinstance(domain, constraints.interval):
x = x.uniform_()
x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
return x
elif domain is constraints.simplex:
x = x.normal_().exp()
x /= x.sum(-1, True)
return x
raise ValueError('Unsupported domain: {}'.format(domain))

def test_inv_inv(self):
for t in self.transforms:
self.assertTrue(t.inv.inv is t)

def test_equality(self):
transforms = self.unique_transforms
for x, y in product(transforms, transforms):
if x is y:
self.assertTrue(x == y)
self.assertFalse(x != y)
else:
self.assertFalse(x == y)
self.assertTrue(x != y)

self.assertTrue(identity_transform == identity_transform.inv)
self.assertFalse(identity_transform != identity_transform.inv)

def test_with_cache(self):
for transform in self.transforms:
if transform._cache_size == 0:
transform = transform.with_cache(1)
self.assertTrue(transform._cache_size == 1)

x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
except NotImplementedError:
continue
y2 = transform(x)
self.assertTrue(y2 is y)

def test_forward_inverse_cache(self):
for transform in self.transforms:
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
except NotImplementedError:
continue
x2 = transform.inv(y) # should be implemented at least by caching
y2 = transform(x2) # should be implemented at least by caching
if transform.bijective:
# verify function inverse
self.assertEqual(x2, x, msg='\n'.join([
'{} t.inv(t(-)) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
]))
else:
# verify weaker function pseudo-inverse
self.assertEqual(y2, y, msg='\n'.join([
'{} t(t.inv(t(-))) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
'y2 = t(x2) = {}'.format(y2),
]))

def test_forward_inverse_no_cache(self):
for transform in self.transforms:
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
x2 = transform.inv(y.clone()) # bypass cache
y2 = transform(x2)
except NotImplementedError:
continue
if transform.bijective:
# verify function inverse
self.assertEqual(x2, x, msg='\n'.join([
'{} t.inv(t(-)) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
]))
else:
# verify weaker function pseudo-inverse
self.assertEqual(y2, y, msg='\n'.join([
'{} t(t.inv(t(-))) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
'y2 = t(x2) = {}'.format(y2),
]))

def test_univariate_forward_jacobian(self):
for transform in self.transforms:
if transform.event_dim > 0:
continue
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
expected = torch.abs(grad([y.sum()], [x])[0]).log()
self.assertEqual(actual, expected, msg='\n'.join([
'Bad {}.log_abs_det_jacobian() disagrees with ()'.format(transform),
'Expected: {}'.format(expected),
'Actual: {}'.format(actual),
]))

def test_univariate_inverse_jacobian(self):
for transform in self.transforms:
if transform.event_dim > 0:
continue
y = self._generate_data(transform.inv).requires_grad_()
try:
x = transform.inv(y)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
expected = -torch.abs(grad([x.sum()], [y])[0]).log()
self.assertEqual(actual, expected, msg='\n'.join([
'{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform),
'Expected: {}'.format(expected),
'Actual: {}'.format(actual),
]))

def test_jacobian_shape(self):
for transform in self.transforms:
x = self._generate_data(transform)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
self.assertEqual(actual.shape, x.shape[:x.dim() - transform.event_dim])

def test_transform_shapes(self):
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()

self.assertEqual(transform0.event_dim, 0)
self.assertEqual(transform1.event_dim, 1)
self.assertEqual(transform2.event_dim, 2)
self.assertEqual(ComposeTransform([transform0, transform1]).event_dim, 1)
self.assertEqual(ComposeTransform([transform0, transform2]).event_dim, 2)
self.assertEqual(ComposeTransform([transform1, transform2]).event_dim, 2)

def test_transformed_distribution_shapes(self):
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
examples = [
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
((3, 4, 4), (), base_dist2),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
]
for batch_shape, event_shape, dist in examples:
self.assertEqual(dist.batch_shape, batch_shape)
self.assertEqual(dist.event_shape, event_shape)
x = dist.rsample()
try:
dist.log_prob(x) # this should not crash
except NotImplementedError:
continue

def test_jit_fwd(self):
for transform in self.unique_transforms:
x = self._generate_data(transform).requires_grad_()

def f(x):
return transform(x)

try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
continue

# check on different inputs
x = self._generate_data(transform).requires_grad_()
self.assertEqual(f(x), traced_f(x))

def test_jit_inv(self):
for transform in self.unique_transforms:
y = self._generate_data(transform.inv).requires_grad_()

def f(y):
return transform.inv(y)

try:
traced_f = torch.jit.trace(f, (y,))
except NotImplementedError:
continue

# check on different inputs
y = self._generate_data(transform.inv).requires_grad_()
self.assertEqual(f(y), traced_f(y))

def test_jit_jacobian(self):
for transform in self.unique_transforms:
x = self._generate_data(transform).requires_grad_()

def f(x):
y = transform(x)
return transform.log_abs_det_jacobian(x, y)

try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
continue

# check on different inputs
x = self._generate_data(transform).requires_grad_()
self.assertEqual(f(x), traced_f(x))


class TestFunctors(TestCase):
def test_cat_transform(self):
x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)
