[Reinstate] Wishart distribution (#70377)
Summary:
Implements #68050.
Reopens the previously merged and then reverted PR #68588, developed together with neerajprad.
cc neerajprad

Sorry for the confusion.

TODO:

- [x] Unit Test
- [x] Documentation
- [x] Change the constraint on matrix parameters to 'torch.distributions.constraints.symmetric' once that constraint is reviewed and merged (Debug positive definite constraints #68720)

Pull Request resolved: #70377

Reviewed By: mikaylagawarecki

Differential Revision: D33355132

Pulled By: neerajprad

fbshipit-source-id: e968c0d9a3061fb2855564b96074235e46a57b6c
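
For reference, a minimal usage sketch of the `torch.distributions.Wishart` API added here (not part of the diff; it mirrors the parameterizations exercised by the new tests below, and variable names are illustrative):

```python
import torch
from torch.distributions import Wishart

# 2x2 scale matrix and df > p - 1 = 1; the three parameterizations below describe the same distribution.
cov = torch.tensor([[2.0, 0.3], [0.3, 0.25]])
df = torch.tensor([4.0])

d_cov = Wishart(df, covariance_matrix=cov)
d_prec = Wishart(df, precision_matrix=cov.inverse())
d_tril = Wishart(df, scale_tril=torch.linalg.cholesky(cov))

x = d_cov.rsample((10,))                  # reparameterized samples, shape (10, 1, 2, 2)
log_p = d_cov.log_prob(x)                 # log densities, shape (10, 1)
mean, var = d_cov.mean, d_cov.variance    # mean equals df * cov, broadcast over the batch dim
```

`log_prob`, `mean`, and `variance` broadcast over the batch dimension introduced by `df`, as checked by the shape tests in this PR.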
nonconvexopt authored and facebook-github-bot committed Dec 30, 2021
1 parent 14d3d29 commit bc40fb5
Showing 4 changed files with 506 additions and 5 deletions.
9 changes: 9 additions & 0 deletions docs/source/distributions.rst
@@ -356,6 +356,15 @@ Probability distributions - torch.distributions
:undoc-members:
:show-inheritance:

:hidden:`Wishart`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.wishart
.. autoclass:: Wishart
:members:
:undoc-members:
:show-inheritance:

`KL Divergence`
~~~~~~~~~~~~~~~~~~~~~~~

219 changes: 214 additions & 5 deletions test/distributions/test_distributions.py
@@ -61,7 +61,7 @@
OneHotCategorical, OneHotCategoricalStraightThrough,
Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
StudentT, TransformedDistribution, Uniform,
VonMises, Weibull, constraints, kl_divergence)
VonMises, Weibull, Wishart, constraints, kl_divergence)
from torch.distributions.constraint_registry import transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
@@ -473,6 +473,32 @@ def is_all_nan(tensor):
'concentration': torch.randn(1).abs().requires_grad_()
}
]),
Example(Wishart, [
{
'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
'df': torch.tensor([4.], requires_grad=True),
},
{
'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
[0.1, 0.25, 0.0],
[0.0, 0.0, 0.3]], requires_grad=True),
'df': torch.tensor([2.5, 3], requires_grad=True),
},
{
'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
[[2.0, 0.0], [0.3, 0.25]],
[[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
'df': torch.tensor([5., 3.5, 2], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
'df': torch.tensor([2.0]),
},
{
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
'df': 2.0,
},
]),
Example(MixtureSameFamily, [
{
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
@@ -740,6 +766,20 @@ def is_all_nan(tensor):
'concentration': torch.tensor([-1.0], requires_grad=True)
}
]),
Example(Wishart, [
{
'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True),
'df': torch.tensor([1.5], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
'df': torch.tensor([3.], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
'df': 3.,
},
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.5], requires_grad=True)},
@@ -792,10 +832,10 @@ def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=Fal
ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
if multivariate:
# Project onto a random axis.
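# (the axis now spans the full event shape, so matrix-valued samples, e.g. Wishart draws, also reduce to scalars)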
axis = np.random.normal(size=torch_samples.shape[-1])
axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
axis /= np.linalg.norm(axis)
torch_samples = np.dot(torch_samples, axis)
ref_samples = np.dot(ref_samples, axis)
torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
if circular:
samples = [(np.cos(x), v) for (x, v) in samples]
@@ -2168,6 +2208,148 @@ def test_multivariate_normal_moments(self):
empirical_var = samples.var(0)
self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)

# We apply the same tests used for the MultivariateNormal distribution to the Wishart distribution
def test_wishart_shape(self):
df = (torch.rand(5, requires_grad=True) + 1) * 10
df_no_batch = (torch.rand([], requires_grad=True) + 1) * 10
df_multi_batch = (torch.rand(6, 5, requires_grad=True) + 1) * 10
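# df stays well above p - 1 (= 2 for the 3x3 covariance matrices below), as the Wishart density requires df > p - 1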

# construct PSD covariance
tmp = torch.randn(3, 10)
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
prec = cov.inverse().requires_grad_()
scale_tril = torch.linalg.cholesky(cov).requires_grad_()

# construct batch of PSD covariances
tmp = torch.randn(6, 5, 3, 10)
cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
prec_batched = cov_batched.inverse()
scale_tril_batched = torch.linalg.cholesky(cov_batched)

# ensure that sample, batch, event shapes all handled correctly
self.assertEqual(Wishart(df, cov).sample().size(), (5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (3, 3))
self.assertEqual(Wishart(df_multi_batch, cov).sample().size(), (6, 5, 3, 3))
self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, 3, 3))
self.assertEqual(Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3, 3))
self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, 3, 3))
self.assertEqual(Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df, precision_matrix=prec).sample((2, 7)).size(), (2, 7, 5, 3, 3))
self.assertEqual(Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3, 3))
self.assertEqual(Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))

# check gradients
# (adapted from the corresponding multivariate_normal gradcheck tests)
def wishart_log_prob_gradcheck(df=None, covariance=None, precision=None, scale_tril=None):
wishart_samples = Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()

def gradcheck_func(samples, nu, sigma, prec, scale_tril):
if sigma is not None:
sigma = 0.5 * (sigma + sigma.mT) # Ensure symmetry of covariance
if prec is not None:
prec = 0.5 * (prec + prec.mT) # Ensure symmetry of precision
if scale_tril is not None:
scale_tril = scale_tril.tril()
return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)
gradcheck(gradcheck_func, (wishart_samples, df, covariance, precision, scale_tril), raise_exception=True)

wishart_log_prob_gradcheck(df, cov)
wishart_log_prob_gradcheck(df_multi_batch, cov)
wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
wishart_log_prob_gradcheck(df, None, prec)
wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
wishart_log_prob_gradcheck(df, None, None, scale_tril)
wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)

def test_wishart_stable_with_precision_matrix(self):
x = torch.randn(10)
P = torch.exp(-(x - x.unsqueeze(-1)) ** 2) # RBF kernel
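# the kernel matrix can be nearly singular; constructing the distribution should not raise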
Wishart(torch.tensor(10), precision_matrix=P)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_wishart_log_prob(self):
df = (torch.rand([], requires_grad=True) + 1) * 10
tmp = torch.randn(3, 10)
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
prec = cov.inverse().requires_grad_()
scale_tril = torch.linalg.cholesky(cov).requires_grad_()

# check that log_prob values match scipy's logpdf, and that the covariance_matrix,
# precision_matrix, and scale_tril parameterizations are equivalent
dist1 = Wishart(df, cov)
dist2 = Wishart(df, precision_matrix=prec)
dist3 = Wishart(df, scale_tril=scale_tril)
ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())

x = dist1.sample((10,))
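# scipy stacks the quantile matrices along the last axis, hence the transpose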
expected = ref_dist.logpdf(x.transpose(0, 2).numpy())

self.assertEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
self.assertEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
self.assertEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)

# Double-check that batched versions behave the same as unbatched
df = (torch.rand(5, requires_grad=True) + 1) * 3
tmp = torch.randn(5, 3, 10)
cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

dist_batched = Wishart(df, cov)
dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]

x = dist_batched.sample((10,))
batched_prob = dist_batched.log_prob(x)
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()

self.assertEqual(batched_prob.shape, unbatched_prob.shape)
self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_wishart_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
df = (torch.rand([], requires_grad=True) + 1) * 3
tmp = torch.randn(3, 10)
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
prec = cov.inverse().requires_grad_()
scale_tril = torch.linalg.cholesky(cov).requires_grad_()

self._check_sampler_sampler(Wishart(df, cov),
scipy.stats.wishart(df.item(), cov.detach().numpy()),
'Wishart(df={}, covariance_matrix={})'.format(df, cov),
multivariate=True)
self._check_sampler_sampler(Wishart(df, precision_matrix=prec),
scipy.stats.wishart(df.item(), cov.detach().numpy()),
'Wishart(df={}, precision_matrix={})'.format(df, prec),
multivariate=True)
self._check_sampler_sampler(Wishart(df, scale_tril=scale_tril),
scipy.stats.wishart(df.item(), cov.detach().numpy()),
'Wishart(df={}, scale_tril={})'.format(df, scale_tril),
multivariate=True)

def test_wishart_properties(self):
df = (torch.rand([]) + 1) * 5
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
m = Wishart(df=df, scale_tril=scale_tril)
self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0]))
self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))

def test_wishart_moments(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
df = (torch.rand([]) + 1) * 3
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(3, 3))
d = Wishart(df=df, scale_tril=scale_tril)
samples = d.rsample((100000,))
empirical_mean = samples.mean(0)
self.assertEqual(d.mean, empirical_mean, atol=5, rtol=0)
empirical_var = samples.var(0)
self.assertEqual(d.variance, empirical_var, atol=5, rtol=0)

def test_exponential(self):
rate = torch.randn(5, 5).abs().requires_grad_()
rate_1d = torch.randn(1).abs().requires_grad_()
@@ -3487,6 +3669,23 @@ def test_weibull_scale_scalar_params(self):
self.assertEqual(weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

def test_wishart_shape_scalar_params(self):
wishart = Wishart(torch.tensor(1), torch.tensor([[1.]]))
self.assertEqual(wishart._batch_shape, torch.Size())
self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)

def test_wishart_shape_tensor_params(self):
wishart = Wishart(torch.tensor([1., 1.]), torch.tensor([[[1.]], [[1.]]]))
self.assertEqual(wishart._batch_shape, torch.Size((2,)))
self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))

def test_normal_shape_scalar_params(self):
normal = Normal(0, 1)
self.assertEqual(normal._batch_shape, torch.Size())
@@ -4305,6 +4504,8 @@ def setUp(self):
positive_var2 = torch.randn(20).exp()
random_var = torch.randn(20)
simplex_tensor = softmax(torch.randn(20), dim=-1)
cov_tensor = torch.randn(20, 20)
cov_tensor = cov_tensor @ cov_tensor.mT
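# the product with its own transpose makes cov_tensor symmetric positive semi-definite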
self.distribution_pairs = [
(
Bernoulli(simplex_tensor),
@@ -4375,6 +4576,10 @@ def setUp(self):
MultivariateNormal(random_var, torch.diag(positive_var2)),
scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
),
(
MultivariateNormal(random_var, cov_tensor),
scipy.stats.multivariate_normal(random_var, cov_tensor)
),
(
Normal(random_var, positive_var2),
scipy.stats.norm(random_var, positive_var2)
@@ -4406,7 +4611,11 @@ def setUp(self):
(
Weibull(positive_var[0], positive_var2[0]), # scipy var for Weibull only supports scalars
scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
)
),
(
Wishart(20 + positive_var[0], cov_tensor),  # offset keeps df > p - 1 = 19; scipy's Wishart only supports scalar df
scipy.stats.wishart(20 + positive_var[0].item(), cov_tensor),
),
]

def test_mean(self):
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -113,6 +113,7 @@
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart
from . import transforms

__all__ = [
@@ -155,6 +156,7 @@
'Uniform',
'VonMises',
'Weibull',
'Wishart',
'TransformedDistribution',
'biject_to',
'kl_divergence',
