diff --git a/test/distributions/test_constraints.py b/test/distributions/test_constraints.py index b4f75fb58de8..d4dd9239920d 100644 --- a/test/distributions/test_constraints.py +++ b/test/distributions/test_constraints.py @@ -27,6 +27,7 @@ (constraints.half_open_interval, -2, -1), (constraints.half_open_interval, 1, 2), (constraints.simplex,), + (constraints.corr_cholesky,), (constraints.lower_cholesky,), ] @@ -49,7 +50,11 @@ def test_biject_to(constraint_fn, args, is_cuda): except NotImplementedError: pytest.skip('`biject_to` not implemented.') assert t.bijective, "biject_to({}) is not bijective".format(constraint) - x = torch.randn(5, 5, dtype=torch.double) + if constraint_fn is constraints.corr_cholesky: + # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) + x = torch.randn(6, 6, dtype=torch.double) + else: + x = torch.randn(5, 5, dtype=torch.double) if is_cuda: x = x.cuda() y = t(x) @@ -62,7 +67,7 @@ def test_biject_to(constraint_fn, args, is_cuda): assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint) j = t.log_abs_det_jacobian(x, y) - assert j.shape == x.shape[:x.dim() - t.event_dim] + assert j.shape == x.shape[:x.dim() - t.input_event_dim] @pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS]) @@ -72,7 +77,11 @@ def test_biject_to(constraint_fn, args, is_cuda): def test_transform_to(constraint_fn, args, is_cuda): constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda) t = transform_to(constraint) - x = torch.randn(5, 5, dtype=torch.double) + if constraint_fn is constraints.corr_cholesky: + # (D * (D-1)) / 2 (where D = 4) = 6 (size of last dim) + x = torch.randn(6, 6, dtype=torch.double) + else: + x = torch.randn(5, 5, dtype=torch.double) if is_cuda: x = x.cuda() y = t(x) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index b4b8b6e81462..abba69eb472f 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -56,13 +56,8 @@ from torch.distributions.constraints import Constraint, is_dependent from torch.distributions.dirichlet import _Dirichlet_backward from torch.distributions.kl import _kl_expfamily_expfamily -from torch.distributions.transforms import (AbsTransform, AffineTransform, - CatTransform, ComposeTransform, ExpTransform, - LowerCholeskyTransform, - PowerTransform, SigmoidTransform, - TanhTransform, SoftmaxTransform, - StickBreakingTransform, - identity_transform, StackTransform) +from torch.distributions.transforms import (AffineTransform, CatTransform, ExpTransform, + StackTransform, identity_transform) from torch.distributions.utils import probs_to_logits, lazy_property from torch.nn.functional import softmax @@ -4300,319 +4295,6 @@ def test_icdf(self): self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist) -class TestTransforms(TestCase): - def setUp(self): - super(TestTransforms, self).setUp() - self.transforms = [] - transforms_by_cache_size = {} - for cache_size in [0, 1]: - transforms = [ - AbsTransform(cache_size=cache_size), - ExpTransform(cache_size=cache_size), - PowerTransform(exponent=2, - cache_size=cache_size), - PowerTransform(exponent=torch.tensor(5.).normal_(), - cache_size=cache_size), - SigmoidTransform(cache_size=cache_size), - TanhTransform(cache_size=cache_size), - AffineTransform(0, 1, cache_size=cache_size), - AffineTransform(1, -2, cache_size=cache_size), - AffineTransform(torch.randn(5), - torch.randn(5), - cache_size=cache_size), - AffineTransform(torch.randn(4, 5), 
- torch.randn(4, 5), - cache_size=cache_size), - SoftmaxTransform(cache_size=cache_size), - StickBreakingTransform(cache_size=cache_size), - LowerCholeskyTransform(cache_size=cache_size), - ComposeTransform([ - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - ]), - ComposeTransform([ - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - ExpTransform(cache_size=cache_size), - ]), - ComposeTransform([ - AffineTransform(0, 1, cache_size=cache_size), - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - AffineTransform(1, -2, cache_size=cache_size), - AffineTransform(torch.randn(4, 5), - torch.randn(4, 5), - cache_size=cache_size), - ]), - ] - for t in transforms[:]: - transforms.append(t.inv) - transforms.append(identity_transform) - self.transforms += transforms - if cache_size == 0: - self.unique_transforms = transforms[:] - - def _generate_data(self, transform): - domain = transform.domain - codomain = transform.codomain - x = torch.empty(4, 5) - if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky: - x = torch.empty(6, 6) - x = x.normal_() - return x - elif domain is constraints.real: - return x.normal_() - elif domain is constraints.positive: - return x.normal_().exp() - elif domain is constraints.unit_interval: - return x.uniform_() - elif isinstance(domain, constraints.interval): - x = x.uniform_() - x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound) - return x - elif domain is constraints.simplex: - x = x.normal_().exp() - x /= x.sum(-1, True) - return x - raise ValueError('Unsupported domain: {}'.format(domain)) - - def test_inv_inv(self): - for t in self.transforms: - self.assertTrue(t.inv.inv is t) - - def test_equality(self): - transforms = self.unique_transforms - for x, y in product(transforms, transforms): - if x is y: - self.assertTrue(x == y) - self.assertFalse(x != y) - else: - self.assertFalse(x == y) - self.assertTrue(x != y) - - self.assertTrue(identity_transform == identity_transform.inv) - self.assertFalse(identity_transform != identity_transform.inv) - - def test_with_cache(self): - for transform in self.transforms: - if transform._cache_size == 0: - transform = transform.with_cache(1) - self.assertTrue(transform._cache_size == 1) - - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - except NotImplementedError: - continue - y2 = transform(x) - self.assertTrue(y2 is y) - - def test_forward_inverse_cache(self): - for transform in self.transforms: - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - except NotImplementedError: - continue - x2 = transform.inv(y) # should be implemented at least by caching - y2 = transform(x2) # should be implemented at least by caching - if transform.bijective: - # verify function inverse - self.assertEqual(x2, x, msg='\n'.join([ - '{} t.inv(t(-)) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - ])) - else: - # verify weaker function pseudo-inverse - self.assertEqual(y2, y, msg='\n'.join([ - '{} t(t.inv(t(-))) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - 'y2 = t(x2) = {}'.format(y2), - ])) - - def test_forward_inverse_no_cache(self): - for transform in self.transforms: - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - x2 = transform.inv(y.clone()) # bypass cache - y2 
= transform(x2) - except NotImplementedError: - continue - if transform.bijective: - # verify function inverse - self.assertEqual(x2, x, msg='\n'.join([ - '{} t.inv(t(-)) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - ])) - else: - # verify weaker function pseudo-inverse - self.assertEqual(y2, y, msg='\n'.join([ - '{} t(t.inv(t(-))) error'.format(transform), - 'x = {}'.format(x), - 'y = t(x) = {}'.format(y), - 'x2 = t.inv(y) = {}'.format(x2), - 'y2 = t(x2) = {}'.format(y2), - ])) - - def test_univariate_forward_jacobian(self): - for transform in self.transforms: - if transform.event_dim > 0: - continue - x = self._generate_data(transform).requires_grad_() - try: - y = transform(x) - actual = transform.log_abs_det_jacobian(x, y) - except NotImplementedError: - continue - expected = torch.abs(grad([y.sum()], [x])[0]).log() - self.assertEqual(actual, expected, msg='\n'.join([ - 'Bad {}.log_abs_det_jacobian() disagrees with ()'.format(transform), - 'Expected: {}'.format(expected), - 'Actual: {}'.format(actual), - ])) - - def test_univariate_inverse_jacobian(self): - for transform in self.transforms: - if transform.event_dim > 0: - continue - y = self._generate_data(transform.inv).requires_grad_() - try: - x = transform.inv(y) - actual = transform.log_abs_det_jacobian(x, y) - except NotImplementedError: - continue - expected = -torch.abs(grad([x.sum()], [y])[0]).log() - self.assertEqual(actual, expected, msg='\n'.join([ - '{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform), - 'Expected: {}'.format(expected), - 'Actual: {}'.format(actual), - ])) - - def test_jacobian_shape(self): - for transform in self.transforms: - x = self._generate_data(transform) - try: - y = transform(x) - actual = transform.log_abs_det_jacobian(x, y) - except NotImplementedError: - continue - self.assertEqual(actual.shape, x.shape[:x.dim() - transform.event_dim]) - - def test_transform_shapes(self): - transform0 = ExpTransform() - transform1 = SoftmaxTransform() - transform2 = LowerCholeskyTransform() - - self.assertEqual(transform0.event_dim, 0) - self.assertEqual(transform1.event_dim, 1) - self.assertEqual(transform2.event_dim, 2) - self.assertEqual(ComposeTransform([transform0, transform1]).event_dim, 1) - self.assertEqual(ComposeTransform([transform0, transform2]).event_dim, 2) - self.assertEqual(ComposeTransform([transform1, transform2]).event_dim, 2) - - def test_transformed_distribution_shapes(self): - transform0 = ExpTransform() - transform1 = SoftmaxTransform() - transform2 = LowerCholeskyTransform() - base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4)) - base_dist1 = Dirichlet(torch.ones(4, 4)) - base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4)) - examples = [ - ((4, 4), (), base_dist0), - ((4,), (4,), base_dist1), - ((4, 4), (), TransformedDistribution(base_dist0, [transform0])), - ((4,), (4,), TransformedDistribution(base_dist0, [transform1])), - ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])), - ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])), - ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform0])), - ((4,), (4,), 
TransformedDistribution(base_dist1, [transform1])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform2])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])), - ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])), - ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])), - ((3, 4, 4), (), base_dist2), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])), - ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])), - ] - for batch_shape, event_shape, dist in examples: - self.assertEqual(dist.batch_shape, batch_shape) - self.assertEqual(dist.event_shape, event_shape) - x = dist.rsample() - try: - dist.log_prob(x) # this should not crash - except NotImplementedError: - continue - - def test_jit_fwd(self): - for transform in self.unique_transforms: - x = self._generate_data(transform).requires_grad_() - - def f(x): - return transform(x) - - try: - traced_f = torch.jit.trace(f, (x,)) - except NotImplementedError: - continue - - # check on different inputs - x = self._generate_data(transform).requires_grad_() - self.assertEqual(f(x), traced_f(x)) - - def test_jit_inv(self): - for transform in self.unique_transforms: - y = self._generate_data(transform.inv).requires_grad_() - - def f(y): - return transform.inv(y) - - try: - traced_f = torch.jit.trace(f, (y,)) - except NotImplementedError: - continue - - # check on different inputs - y = self._generate_data(transform.inv).requires_grad_() - self.assertEqual(f(y), traced_f(y)) - - def test_jit_jacobian(self): - for transform in self.unique_transforms: - x = self._generate_data(transform).requires_grad_() - - def f(x): - y = transform(x) - return transform.log_abs_det_jacobian(x, y) - - try: - traced_f = torch.jit.trace(f, (x,)) - except NotImplementedError: - continue - - # check on different inputs - x = self._generate_data(transform).requires_grad_() - self.assertEqual(f(x), traced_f(x)) - - class TestFunctors(TestCase): def test_cat_transform(self): x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100) diff --git a/test/distributions/test_transforms.py b/test/distributions/test_transforms.py new file mode 100644 index 000000000000..b5e9144f0bd8 --- /dev/null +++ b/test/distributions/test_transforms.py @@ -0,0 +1,365 @@ +from numbers import Number + +import pytest + +import torch +from torch.autograd.functional import jacobian +from torch.distributions import Dirichlet, Normal, TransformedDistribution, constraints +from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform, + CorrCholeskyTransform, ExpTransform, + LowerCholeskyTransform, PowerTransform, + SigmoidTransform, TanhTransform, SoftmaxTransform, + StickBreakingTransform, identity_transform, Transform, + _InverseTransform) +from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix + + +def get_transforms(cache_size): + transforms = [ + AbsTransform(cache_size=cache_size), + ExpTransform(cache_size=cache_size), + PowerTransform(exponent=2, 
+ cache_size=cache_size), + PowerTransform(exponent=torch.tensor(5.).normal_(), + cache_size=cache_size), + SigmoidTransform(cache_size=cache_size), + TanhTransform(cache_size=cache_size), + AffineTransform(0, 1, cache_size=cache_size), + AffineTransform(1, -2, cache_size=cache_size), + AffineTransform(torch.randn(5), + torch.randn(5), + cache_size=cache_size), + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + SoftmaxTransform(cache_size=cache_size), + StickBreakingTransform(cache_size=cache_size), + LowerCholeskyTransform(cache_size=cache_size), + CorrCholeskyTransform(cache_size=cache_size), + ComposeTransform([ + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + ]), + ComposeTransform([ + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + ExpTransform(cache_size=cache_size), + ]), + ComposeTransform([ + AffineTransform(0, 1, cache_size=cache_size), + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + AffineTransform(1, -2, cache_size=cache_size), + AffineTransform(torch.randn(4, 5), + torch.randn(4, 5), + cache_size=cache_size), + ]), + ] + transforms += [t.inv for t in transforms] + return transforms + + +def reshape_transform(transform, shape): + # Needed to squash batch dims for testing jacobian + if isinstance(transform, AffineTransform): + if isinstance(transform.loc, Number): + return transform + try: + return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size) + except RuntimeError: + return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size) + if isinstance(transform, ComposeTransform): + reshaped_parts = [] + for p in transform.parts: + reshaped_parts.append(reshape_transform(p, shape)) + return ComposeTransform(reshaped_parts, cache_size=transform._cache_size) + if isinstance(transform.inv, AffineTransform): + return reshape_transform(transform.inv, shape).inv + if isinstance(transform.inv, ComposeTransform): + return reshape_transform(transform.inv, shape).inv + return transform + + +# Generate pytest ids +def transform_id(x): + assert isinstance(x, Transform) + name = f'Inv({type(x._inv).__name__})' if isinstance(x, _InverseTransform) else f'{type(x).__name__}' + return f'{name}(cache_size={x._cache_size})' + + +def generate_data(transform): + torch.manual_seed(1) + domain = transform.domain + codomain = transform.codomain + x = torch.empty(4, 5) + if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky: + x = torch.empty(6, 6) + x = x.normal_() + return x + elif domain is constraints.real: + return x.normal_() + elif domain is constraints.real_vector: + # For corr_cholesky the last dim in the vector + # must be of size (dim * dim) // 2 + x = torch.empty(3, 6) + x = x.normal_() + return x + elif domain is constraints.positive: + return x.normal_().exp() + elif domain is constraints.unit_interval: + return x.uniform_() + elif isinstance(domain, constraints.interval): + x = x.uniform_() + x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound) + return x + elif domain is constraints.simplex: + x = x.normal_().exp() + x /= x.sum(-1, True) + return x + elif domain is constraints.corr_cholesky: + x = torch.empty(4, 5, 5) + x = x.normal_().tril() + x /= x.norm(dim=-1, keepdim=True) + x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs()) + return x + raise ValueError('Unsupported domain: 
{}'.format(domain)) + + +TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1) +TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0) +ALL_TRANSFORMS = TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform] + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +def test_inv_inv(transform, ids=transform_id): + assert transform.inv.inv is transform + + +@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_equality(x, y): + if x is y: + assert x == y + else: + assert x != y + assert identity_transform == identity_transform.inv + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +def test_with_cache(transform): + if transform._cache_size == 0: + transform = transform.with_cache(1) + assert transform._cache_size == 1 + x = generate_data(transform).requires_grad_() + try: + y = transform(x) + except NotImplementedError: + pytest.skip('Not implemented.') + y2 = transform(x) + assert y2 is y + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +@pytest.mark.parametrize('test_cached', [True, False]) +def test_forward_inverse(transform, test_cached): + x = generate_data(transform).requires_grad_() + try: + y = transform(x) + except NotImplementedError: + pytest.skip('Not implemented.') + if test_cached: + x2 = transform.inv(y) # should be implemented at least by caching + else: + try: + x2 = transform.inv(y.clone()) # bypass cache + except NotImplementedError: + pytest.skip('Not implemented.') + y2 = transform(x2) + if transform.bijective: + # verify function inverse + assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([ + '{} t.inv(t(-)) error'.format(transform), + 'x = {}'.format(x), + 'y = t(x) = {}'.format(y), + 'x2 = t.inv(y) = {}'.format(x2), + ]) + else: + # verify weaker function pseudo-inverse + assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([ + '{} t(t.inv(t(-))) error'.format(transform), + 'x = {}'.format(x), + 'y = t(x) = {}'.format(y), + 'x2 = t.inv(y) = {}'.format(x2), + 'y2 = t(x2) = {}'.format(y2), + ]) + + +def test_compose_transform_shapes(): + transform0 = ExpTransform() + transform1 = SoftmaxTransform() + transform2 = LowerCholeskyTransform() + + assert transform0.event_dim == 0 + assert transform1.event_dim == 1 + assert transform2.event_dim == 2 + assert ComposeTransform([transform0, transform1]).event_dim == 1 + assert ComposeTransform([transform0, transform2]).event_dim == 2 + assert ComposeTransform([transform1, transform2]).event_dim == 2 + + +transform0 = ExpTransform() +transform1 = SoftmaxTransform() +transform2 = LowerCholeskyTransform() +base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4)) +base_dist1 = Dirichlet(torch.ones(4, 4)) +base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4)) + + +@pytest.mark.parametrize('batch_shape, event_shape, dist', [ + ((4, 4), (), base_dist0), + ((4,), (4,), base_dist1), + ((4, 4), (), TransformedDistribution(base_dist0, [transform0])), + ((4,), (4,), TransformedDistribution(base_dist0, [transform1])), + ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])), + ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform2, 
transform0])), + ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform0])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform1])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform2])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])), + ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])), + ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])), + ((3, 4, 4), (), base_dist2), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])), + ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])), +]) +def test_transformed_distribution_shapes(batch_shape, event_shape, dist): + assert dist.batch_shape == batch_shape + assert dist.event_shape == event_shape + x = dist.rsample() + try: + dist.log_prob(x) # this should not crash + except NotImplementedError: + pytest.skip('Not implemented.') + + +@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_jit_fwd(transform): + x = generate_data(transform).requires_grad_() + + def f(x): + return transform(x) + + try: + traced_f = torch.jit.trace(f, (x,)) + except NotImplementedError: + pytest.skip('Not implemented.') + + # check on different inputs + x = generate_data(transform).requires_grad_() + assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True) + + +@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_jit_inv(transform): + y = generate_data(transform.inv).requires_grad_() + + def f(y): + return transform.inv(y) + + try: + traced_f = torch.jit.trace(f, (y,)) + except NotImplementedError: + pytest.skip('Not implemented.') + + # check on different inputs + y = generate_data(transform.inv).requires_grad_() + assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True) + + +@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id) +def test_jit_jacobian(transform): + x = generate_data(transform).requires_grad_() + + def f(x): + y = transform(x) + return transform.log_abs_det_jacobian(x, y) + + try: + traced_f = torch.jit.trace(f, (x,)) + except NotImplementedError: + pytest.skip('Not implemented.') + + # check on different inputs + x = generate_data(transform).requires_grad_() + assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True) + + +@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id) +def test_jacobian(transform): + x = generate_data(transform) + try: + y = transform(x) + actual = transform.log_abs_det_jacobian(x, y) + except NotImplementedError: + pytest.skip('Not implemented.') + # Test shape + target_shape = x.shape[:x.dim() - transform.input_event_dim] + assert actual.shape == target_shape + + # Expand if required + transform = reshape_transform(transform, x.shape) + ndims = len(x.shape) + event_dim = ndims - transform.input_event_dim + x_ = x.view((-1,) + x.shape[event_dim:]) + n = x_.shape[0] + # Reshape to 
squash batch dims to a single batch dim + transform = reshape_transform(transform, x_.shape) + + # 1. Transforms with 0 off-diagonal elements + if transform.input_event_dim == 0: + jac = jacobian(transform, x_) + # assert off-diagonal elements are zero + assert torch.allclose(jac, jac.diagonal().diag_embed()) + expected = jac.diagonal().abs().log().reshape(x.shape) + # 2. Transforms with non-0 off-diagonal elements + else: + if isinstance(transform, CorrCholeskyTransform): + jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_) + elif isinstance(transform.inv, CorrCholeskyTransform): + jac = jacobian(lambda x: transform(vec_to_tril_matrix(x, diag=-1)), + tril_matrix_to_vec(x_, diag=-1)) + elif isinstance(transform, StickBreakingTransform): + jac = jacobian(lambda x: transform(x)[..., :-1], x_) + else: + jac = jacobian(transform, x_) + + # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims) + # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims) + # after reshaping the event dims (see above) to give a batched square matrix whose determinant + # can be computed. + gather_idx_shape = list(jac.shape) + gather_idx_shape[-2] = 1 + gather_idxs = torch.arange(n).reshape((n,) + (1,) * (len(jac.shape) - 1)).expand(gather_idx_shape) + jac = jac.gather(-2, gather_idxs).squeeze(-2) + out_ndims = jac.shape[-2] + jac = jac[..., :out_ndims] # Remove extra zero-valued dims (for inverse stick-breaking). + expected = torch.slogdet(jac).logabsdet + + assert torch.allclose(actual, expected, atol=1e-5) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/test/distributions/test_utils.py b/test/distributions/test_utils.py new file mode 100644 index 000000000000..b58cfe39fc1c --- /dev/null +++ b/test/distributions/test_utils.py @@ -0,0 +1,24 @@ +import pytest + +import torch +from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix + + +@pytest.mark.parametrize('shape', [ + (2, 2), + (3, 3), + (2, 4, 4), + (2, 2, 4, 4), +]) +def test_tril_matrix_to_vec(shape): + mat = torch.randn(shape) + n = mat.shape[-1] + for diag in range(-n + 1, n): + actual = mat.tril(diag) + vec = tril_matrix_to_vec(actual, diag) + tril_mat = vec_to_tril_matrix(vec, diag) + assert torch.allclose(tril_mat, actual) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/test/run_test.py b/test/run_test.py index 2bf1353ecd34..070b6103ab54 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -3,7 +3,6 @@ import argparse import copy from datetime import datetime -import importlib import modulefinder import os import shutil @@ -152,6 +151,9 @@ 'distributed/_pipeline/sync/test_stream', 'distributed/_pipeline/sync/test_transparency', 'distributed/_pipeline/sync/test_worker', + 'distributions/test_constraints', + 'distributions/test_transforms', + 'distributions/test_utils', ] WINDOWS_BLOCKLIST = [ @@ -188,11 +190,6 @@ 'test_cuda_primary_ctx', ] + [test for test in TESTS if test.startswith('distributed/')] -# These tests use some specific pytest feature like parameterized testing or -# fixtures that cannot be run by unittest -PYTEST_TESTS = [ - 'distributions/test_constraints' -] # These tests are slow enough that it's worth calculating whether the patch # touched any related files first. 
@@ -647,9 +644,6 @@ def get_selected_tests(options):
         options.exclude.extend(JIT_EXECUTOR_TESTS)
 
     selected_tests = exclude_tests(options.exclude, selected_tests)
-    # exclude PYTEST_TESTS if pytest not installed.
-    if importlib.util.find_spec('pytest') is None:
-        selected_tests = exclude_tests(PYTEST_TESTS, selected_tests, 'PyTest not found.')
 
     if sys.platform == 'win32' and not options.ignore_win_blocklist:
         target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH')
diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py
index 6587631c4cfe..4675b8ceaca8 100644
--- a/torch/distributions/constraint_registry.py
+++ b/torch/distributions/constraint_registry.py
@@ -215,6 +215,12 @@ def _transform_to_lower_cholesky(constraint):
     return transforms.LowerCholeskyTransform()
 
 
+@biject_to.register(constraints.corr_cholesky)
+@transform_to.register(constraints.corr_cholesky)
+def _transform_to_corr_cholesky(constraint):
+    return transforms.CorrCholeskyTransform()
+
+
 @biject_to.register(constraints.cat)
 def _biject_to_cat(constraint):
     return transforms.CatTransform([biject_to(c)
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index 7bcbc586434d..630c192ffed0 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -26,6 +26,7 @@
     'Constraint',
     'boolean',
     'cat',
+    'corr_cholesky',
     'dependent',
     'dependent_property',
     'greater_than',
@@ -275,6 +276,18 @@ def check(self, value):
         return lower_triangular & positive_diagonal
 
 
+class _CorrCholesky(Constraint):
+    """
+    Constrain to lower-triangular square matrices with positive diagonals and each
+    row vector being of unit length.
+    """
+    def check(self, value):
+        tol = torch.finfo(value.dtype).eps * value.size(-1) * 10  # 10 is an adjustable fudge factor
+        row_norm = torch.linalg.norm(value.detach(), dim=-1)
+        unit_row_norm = (row_norm - 1.).abs().le(tol).all(dim=-1)
+        return _LowerCholesky().check(value) & unit_row_norm
+
+
 class _PositiveDefinite(Constraint):
     """
     Constrain to positive-definite matrices.
@@ -360,6 +373,7 @@ def check(self, value):
 simplex = _Simplex()
 lower_triangular = _LowerTriangular()
 lower_cholesky = _LowerCholesky()
+corr_cholesky = _CorrCholesky()
 positive_definite = _PositiveDefinite()
 cat = _Cat
 stack = _Stack
diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py
index f4de4b15b0bb..a0412d52df0d 100644
--- a/torch/distributions/transforms.py
+++ b/torch/distributions/transforms.py
@@ -6,7 +6,8 @@
 import torch.nn.functional as F
 from torch.distributions import constraints
 from torch.distributions.utils import (_sum_rightmost, broadcast_all,
-                                       lazy_property)
+                                       lazy_property, tril_matrix_to_vec,
+                                       vec_to_tril_matrix)
 from torch.nn.functional import pad
 from torch.nn.functional import softplus
 from typing import List
@@ -16,6 +17,7 @@
     'AffineTransform',
     'CatTransform',
     'ComposeTransform',
+    'CorrCholeskyTransform',
     'ExpTransform',
     'LowerCholeskyTransform',
     'PowerTransform',
@@ -92,6 +94,14 @@ def __init__(self, cache_size=0):
             raise ValueError('cache_size must be 0 or 1')
         super(Transform, self).__init__()
 
+    @property
+    def input_event_dim(self):
+        return self.event_dim
+
+    @property
+    def output_event_dim(self):
+        return self.event_dim
+
     @property
     def inv(self):
         """
@@ -195,6 +205,16 @@ def codomain(self):
         assert self._inv is not None
         return self._inv.domain
 
+    @property
+    def input_event_dim(self):
+        assert self._inv is not None
+        return self._inv.output_event_dim
+
+    @property
+    def output_event_dim(self):
+        assert self._inv is not None
+        return self._inv.input_event_dim
+
     @property
     def bijective(self):
         assert self._inv is not None
@@ -535,6 +555,74 @@ def log_abs_det_jacobian(self, x, y):
         return result.expand(shape)
 
 
+class CorrCholeskyTransform(Transform):
+    r"""
+    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
+    Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower
+    triangular matrix with positive diagonals and unit Euclidean norm for each row.
+    The transform proceeds as follows:
+
+    1. First we convert :math:`x` into a lower triangular matrix in row order.
+    2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
+       class :class:`StickBreakingTransform` to transform :math:`X_i` into a
+       unit Euclidean length vector using the following steps:
+        - Scales into the interval :math:`(-1, 1)`: :math:`r_i = \tanh(X_i)`.
+        - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
+        - Applies :math:`s_i = StickBreakingTransform(z_i)`.
+        - Transforms back into the signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
+    """
+    domain = constraints.real_vector
+    codomain = constraints.corr_cholesky
+    input_event_dim = 1
+    output_event_dim = 2
+    bijective = True
+
+    @property
+    def event_dim(self):
+        raise ValueError("Please use `.input_event_dim` or `.output_event_dim` instead.")
+
+    def _call(self, x):
+        x = torch.tanh(x)
+        eps = torch.finfo(x.dtype).eps
+        x = x.clamp(min=-1 + eps, max=1 - eps)
+        r = vec_to_tril_matrix(x, diag=-1)
+        # apply stick-breaking on the squared values
+        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
+        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
+        z = r ** 2
+        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
+        # Diagonal elements must be 1.
+        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
+        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
+        return y
+
+    def _inverse(self, y):
+        # inverse stick-breaking
+        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
+        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
+        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
+        y_vec = tril_matrix_to_vec(y, diag=-1)
+        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
+        t = y_vec / (y_cumsum_vec).sqrt()
+        # inverse of tanh
+        x = ((1 + t) / (1 - t)).log() / 2
+        return x
+
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
+        # Because domain and codomain are two spaces with different dimensions, the determinant of
+        # the Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
+        # flattened lower triangular part of `y`.
+
+        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
+        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
+        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
+        # also works for 2 x 2 matrix
+        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
+        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
+        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.)).sum(dim=-1)
+        return stick_breaking_logdet + tanh_logdet
+
+
 class SoftmaxTransform(Transform):
     r"""
     Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index 36ff1f71c35b..05500f22c344 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -108,3 +108,36 @@ def __get__(self, instance, obj_type=None):
         value = self.wrapped(instance)
         setattr(instance, self.wrapped.__name__, value)
         return value
+
+
+def tril_matrix_to_vec(mat, diag=0):
+    r"""
+    Convert a `D x D` matrix or a batch of matrices into a (batched) vector
+    which consists of the lower triangular elements of the matrix in row order.
+    """
+    n = mat.shape[-1]
+    if not torch._C._get_tracing_state() and (diag <= -n or diag >= n):
+        raise ValueError(f'diag ({diag}) provided is outside [{-n+1}, {n-1}].')
+    arange = torch.arange(n, device=mat.device)
+    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
+    vec = mat[..., tril_mask]
+    return vec
+
+
+def vec_to_tril_matrix(vec, diag=0):
+    r"""
+    Convert a vector or a batch of vectors into a batched `D x D`
+    lower triangular matrix containing elements from the vector in row order.
+    """
+    # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
+    n = (-(1 + 2 * diag) + ((1 + 2 * diag)**2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1))**0.5) / 2
+    eps = torch.finfo(vec.dtype).eps
+    if not torch._C._get_tracing_state() and (round(n) - n > eps):
+        raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
+                         'the lower triangular part of a square D x D matrix.')
+    n = torch.round(n).long() if isinstance(n, torch.Tensor) else round(n)
+    mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
+    arange = torch.arange(n, device=vec.device)
+    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
+    mat[..., tril_mask] = vec
+    return mat
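For orientation, a minimal round-trip sketch of the two new helpers in torch.distributions.utils (assuming the patch above is applied; the 4 x 4 shape and the diag offsets mirror test_utils.py):

import torch
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix

mat = torch.randn(2, 4, 4)                      # a batch of two 4 x 4 matrices
for diag in (-1, 0):
    vec = tril_matrix_to_vec(mat, diag=diag)    # (2, 6) for diag=-1, (2, 10) for diag=0
    back = vec_to_tril_matrix(vec, diag=diag)   # zeros above the chosen diagonal
    assert torch.allclose(back, mat.tril(diag))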
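And a minimal usage sketch of CorrCholeskyTransform itself (again assuming the patch is applied; D = 4 matches the size used in test_constraints.py, so the unconstrained vector has length D*(D-1)/2 = 6):

import torch
from torch.distributions import constraints, transform_to
from torch.distributions.transforms import CorrCholeskyTransform

D = 4
t = CorrCholeskyTransform()
x = torch.randn(3, D * (D - 1) // 2, dtype=torch.double)   # batch of 3 unconstrained vectors
L = t(x)                                                    # (3, 4, 4) Cholesky factors with unit-norm rows
assert constraints.corr_cholesky.check(L).all()
assert torch.allclose(t.inv(L), x, atol=1e-4)               # bijective up to numerical error
corr = L @ L.transpose(-1, -2)                              # the corresponding correlation matrices
assert torch.allclose(corr.diagonal(dim1=-2, dim2=-1), torch.ones(3, D, dtype=torch.double))
log_det = t.log_abs_det_jacobian(x, L)                      # one value per batch element: shape (3,)
assert transform_to(constraints.corr_cholesky)(x).shape == L.shape  # registry dispatch added above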