Add support for CorrCholeskyTransform #48041

Closed
wants to merge 13 commits
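
This PR introduces a transform between an unconstrained real vector of length D*(D-1)/2 and the Cholesky factor of a D x D correlation matrix (lower triangular, with positive diagonal and unit-norm rows). A minimal usage sketch, assuming the class is exposed as torch.distributions.transforms.CorrCholeskyTransform with the standard Transform interface (the name follows the PR title; the merged API may differ):

    import torch
    from torch.distributions.transforms import CorrCholeskyTransform

    t = CorrCholeskyTransform()
    x = torch.randn(3)        # unconstrained vector of length D*(D-1)/2, here D = 3
    L = t(x)                  # 3 x 3 lower-triangular factor with unit-norm rows
    corr = L @ L.T            # unit diagonal, i.e. a valid correlation matrix
    ladj = t.log_abs_det_jacobian(x, L)  # used by TransformedDistribution.log_prob

Because each row of L has unit Euclidean norm, L @ L.T has ones on its diagonal, which is exactly the correlation-matrix constraint the transform enforces.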
322 changes: 2 additions & 320 deletions test/distributions/test_distributions.py
@@ -53,15 +53,10 @@
VonMises, Weibull, constraints, kl_divergence)
from torch.distributions.constraint_registry import transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.transforms import (AffineTransform, CatTransform, ExpTransform,
                                            StackTransform, identity_transform)
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.kl import _kl_expfamily_expfamily
from torch.distributions.transforms import (AbsTransform, AffineTransform,
                                            CatTransform, ComposeTransform, ExpTransform,
                                            LowerCholeskyTransform,
                                            PowerTransform, SigmoidTransform,
                                            TanhTransform, SoftmaxTransform,
                                            StickBreakingTransform,
                                            identity_transform, StackTransform)
from torch.distributions.utils import probs_to_logits, lazy_property
from torch.nn.functional import softmax

@@ -4207,319 +4202,6 @@ def test_icdf(self):
self.assertEqual(icdf, scipy_dist.ppf(samples), msg=pytorch_dist)


class TestTransforms(TestCase):
def setUp(self):
super(TestTransforms, self).setUp()
self.transforms = []
for cache_size in [0, 1]:
transforms = [
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
PowerTransform(exponent=2,
cache_size=cache_size),
PowerTransform(exponent=torch.tensor(5.).normal_(),
cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
TanhTransform(cache_size=cache_size),
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(5),
torch.randn(5),
cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
SoftmaxTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
ExpTransform(cache_size=cache_size),
]),
ComposeTransform([
AffineTransform(0, 1, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
AffineTransform(1, -2, cache_size=cache_size),
AffineTransform(torch.randn(4, 5),
torch.randn(4, 5),
cache_size=cache_size),
]),
]
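            # Also exercise the inverse of every transform, plus the identity transform.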
for t in transforms[:]:
transforms.append(t.inv)
transforms.append(identity_transform)
self.transforms += transforms
if cache_size == 0:
self.unique_transforms = transforms[:]

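    # Draw a random tensor that satisfies the transform's input domain so the
    # round-trip and Jacobian checks below start from valid data.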
def _generate_data(self, transform):
domain = transform.domain
codomain = transform.codomain
x = torch.empty(4, 5)
        if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
            return torch.empty(6, 6).normal_()
elif domain is constraints.real:
return x.normal_()
elif domain is constraints.positive:
return x.normal_().exp()
elif domain is constraints.unit_interval:
return x.uniform_()
elif isinstance(domain, constraints.interval):
x = x.uniform_()
x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
return x
elif domain is constraints.simplex:
x = x.normal_().exp()
x /= x.sum(-1, True)
return x
raise ValueError('Unsupported domain: {}'.format(domain))

def test_inv_inv(self):
for t in self.transforms:
self.assertTrue(t.inv.inv is t)

def test_equality(self):
transforms = self.unique_transforms
for x, y in product(transforms, transforms):
if x is y:
self.assertTrue(x == y)
self.assertFalse(x != y)
else:
self.assertFalse(x == y)
self.assertTrue(x != y)

self.assertTrue(identity_transform == identity_transform.inv)
self.assertFalse(identity_transform != identity_transform.inv)

def test_with_cache(self):
for transform in self.transforms:
if transform._cache_size == 0:
transform = transform.with_cache(1)
self.assertTrue(transform._cache_size == 1)

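            # With a cache enabled, calling the transform twice on the same
            # input must return the identical cached output tensor.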
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
except NotImplementedError:
continue
y2 = transform(x)
self.assertTrue(y2 is y)

def test_forward_inverse_cache(self):
for transform in self.transforms:
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
except NotImplementedError:
continue
x2 = transform.inv(y) # should be implemented at least by caching
y2 = transform(x2) # should be implemented at least by caching
if transform.bijective:
# verify function inverse
self.assertEqual(x2, x, msg='\n'.join([
'{} t.inv(t(-)) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
]))
else:
# verify weaker function pseudo-inverse
self.assertEqual(y2, y, msg='\n'.join([
'{} t(t.inv(t(-))) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
'y2 = t(x2) = {}'.format(y2),
]))

def test_forward_inverse_no_cache(self):
for transform in self.transforms:
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
x2 = transform.inv(y.clone()) # bypass cache
y2 = transform(x2)
except NotImplementedError:
continue
if transform.bijective:
# verify function inverse
self.assertEqual(x2, x, msg='\n'.join([
'{} t.inv(t(-)) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
]))
else:
# verify weaker function pseudo-inverse
self.assertEqual(y2, y, msg='\n'.join([
'{} t(t.inv(t(-))) error'.format(transform),
'x = {}'.format(x),
'y = t(x) = {}'.format(y),
'x2 = t.inv(y) = {}'.format(x2),
'y2 = t(x2) = {}'.format(y2),
]))

def test_univariate_forward_jacobian(self):
for transform in self.transforms:
if transform.event_dim > 0:
continue
x = self._generate_data(transform).requires_grad_()
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
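            # For elementwise (event_dim == 0) transforms, the gradient of
            # y.sum() w.r.t. x is the elementwise derivative dy/dx, so
            # log|dy/dx| is the expected log-abs-det-Jacobian.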
expected = torch.abs(grad([y.sum()], [x])[0]).log()
self.assertEqual(actual, expected, msg='\n'.join([
                '{}.log_abs_det_jacobian() disagrees with autograd'.format(transform),
'Expected: {}'.format(expected),
'Actual: {}'.format(actual),
]))

def test_univariate_inverse_jacobian(self):
for transform in self.transforms:
if transform.event_dim > 0:
continue
y = self._generate_data(transform.inv).requires_grad_()
try:
x = transform.inv(y)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
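            # By the inverse function theorem, log|dy/dx| = -log|dx/dy|,
            # hence the negated log-derivative of the inverse below.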
expected = -torch.abs(grad([x.sum()], [y])[0]).log()
self.assertEqual(actual, expected, msg='\n'.join([
'{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform),
'Expected: {}'.format(expected),
'Actual: {}'.format(actual),
]))

def test_jacobian_shape(self):
for transform in self.transforms:
x = self._generate_data(transform)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
except NotImplementedError:
continue
self.assertEqual(actual.shape, x.shape[:x.dim() - transform.event_dim])

def test_transform_shapes(self):
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
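        # event_dim counts the rightmost dimensions a transform treats as a
        # single event: 0 for elementwise, 1 for vector-valued, 2 for
        # matrix-valued transforms; a composition takes the maximum.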

self.assertEqual(transform0.event_dim, 0)
self.assertEqual(transform1.event_dim, 1)
self.assertEqual(transform2.event_dim, 2)
self.assertEqual(ComposeTransform([transform0, transform1]).event_dim, 1)
self.assertEqual(ComposeTransform([transform0, transform2]).event_dim, 2)
self.assertEqual(ComposeTransform([transform1, transform2]).event_dim, 2)

def test_transformed_distribution_shapes(self):
transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
examples = [
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
((3, 4, 4), (), base_dist2),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
]
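        # A transform with event_dim k reinterprets the k rightmost batch
        # dimensions of the base distribution as event dimensions, e.g.
        # LowerCholeskyTransform (k=2) turns batch_shape (4, 4) into
        # event_shape (4, 4).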
for batch_shape, event_shape, dist in examples:
self.assertEqual(dist.batch_shape, batch_shape)
self.assertEqual(dist.event_shape, event_shape)
x = dist.rsample()
try:
dist.log_prob(x) # this should not crash
except NotImplementedError:
continue

def test_jit_fwd(self):
for transform in self.unique_transforms:
x = self._generate_data(transform).requires_grad_()

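            # The transform is captured by closure; torch.jit.trace records
            # the ops executed on this example input.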
def f(x):
return transform(x)

try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
continue

# check on different inputs
x = self._generate_data(transform).requires_grad_()
self.assertEqual(f(x), traced_f(x))

def test_jit_inv(self):
for transform in self.unique_transforms:
y = self._generate_data(transform.inv).requires_grad_()

def f(y):
return transform.inv(y)

try:
traced_f = torch.jit.trace(f, (y,))
except NotImplementedError:
continue

# check on different inputs
y = self._generate_data(transform.inv).requires_grad_()
self.assertEqual(f(y), traced_f(y))

def test_jit_jacobian(self):
for transform in self.unique_transforms:
x = self._generate_data(transform).requires_grad_()

def f(x):
y = transform(x)
return transform.log_abs_det_jacobian(x, y)

try:
traced_f = torch.jit.trace(f, (x,))
except NotImplementedError:
continue

# check on different inputs
x = self._generate_data(transform).requires_grad_()
self.assertEqual(f(x), traced_f(x))


class TestFunctors(TestCase):
def test_cat_transform(self):
x1 = -1 * torch.arange(1, 101, dtype=torch.float).view(-1, 100)