diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index aebc39038368..fe09626e60d8 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -167,6 +167,15 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:

+:hidden:`Kumaraswamy`
+~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torch.distributions.kumaraswamy
+.. autoclass:: Kumaraswamy
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 :hidden:`Laplace`
 ~~~~~~~~~~~~~~~~~~~~~~~

diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 68f5922753a3..d75f21740435 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -44,7 +44,7 @@
     Distribution, Exponential, ExponentialFamily,
     FisherSnedecor, Gamma, Geometric, Gumbel, HalfCauchy, HalfNormal,
-    Independent, Laplace, LogisticNormal,
+    Independent, Kumaraswamy, Laplace, LogisticNormal,
     LogNormal, LowRankMultivariateNormal,
     MixtureSameFamily, Multinomial, MultivariateNormal,
     NegativeBinomial, Normal, OneHotCategorical, Pareto,
@@ -240,6 +240,16 @@ def is_all_nan(tensor):
             'reinterpreted_batch_ndims': 3,
         },
     ]),
+    Example(Kumaraswamy, [
+        {
+            'concentration1': torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
+            'concentration0': torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
+        },
+        {
+            'concentration1': torch.rand(4).uniform_(1, 2).requires_grad_(),
+            'concentration0': torch.rand(4).uniform_(1, 2).requires_grad_(),
+        },
+    ]),
     Example(Laplace, [
         {
             'loc': torch.randn(5, 5, requires_grad=True),
@@ -2249,6 +2259,42 @@ def test_gumbel_sample(self):
                                         scipy.stats.gumbel_r(loc=loc, scale=scale),
                                         'Gumbel(loc={}, scale={})'.format(loc, scale))

+    def test_kumaraswamy_shape(self):
+        concentration1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
+        concentration0 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
+        concentration1_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True)
+        concentration0_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True)
+        self.assertEqual(Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3))
+        self.assertEqual(Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3))
+        self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,))
+        self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(), (1, 1))
+        self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ())
+        self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,))
+
+    # Kumaraswamy distribution is not implemented in SciPy
+    # Hence these tests are explicit
+    def test_kumaraswamy_mean_variance(self):
+        c1_1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
+        c0_1 = torch.tensor(torch.randn(2, 3).abs(), requires_grad=True)
+        c1_2 = torch.tensor(torch.randn(4).abs(), requires_grad=True)
+        c0_2 = torch.tensor(torch.randn(4).abs(), requires_grad=True)
+        cases = [(c1_1, c0_1), (c1_2, c0_2)]
+        for i, (a, b) in enumerate(cases):
+            m = Kumaraswamy(a, b)
+            samples = m.sample((60000, ))
+            expected = samples.mean(0)
+            actual = m.mean
+            error = (expected - actual).abs()
+            max_error = max(error[error == error])
+            self.assertLess(max_error, 0.01,
+                            "Kumaraswamy example {}/{}, incorrect .mean".format(i + 1, len(cases)))
+            expected = samples.var(0)
+            actual = m.variance
+            error = (expected - actual).abs()
+            max_error = max(error[error == error])
+            self.assertLess(max_error, 0.01,
+                            "Kumaraswamy example {}/{}, incorrect .variance".format(i + 1, len(cases)))
+
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_fishersnedecor(self):
         df1 = torch.randn(2, 3).abs().requires_grad_()
@@ -2622,6 +2668,18 @@ def test_valid_parameter_broadcasting(self):
             (1, 2)),
            (Gumbel(loc=torch.tensor([0.]), scale=torch.tensor([[1.]])),
             (1, 1)),
+           (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=1.),
+            (2,)),
+           (Kumaraswamy(concentration1=1, concentration0=torch.tensor([1., 1.])),
+            (2, )),
+           (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([1.])),
+            (2,)),
+           (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([[1.], [1.]])),
+            (2, 2)),
+           (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([[1.]])),
+            (1, 2)),
+           (Kumaraswamy(concentration1=torch.tensor([1.]), concentration0=torch.tensor([[1.]])),
+            (1, 1)),
            (Laplace(loc=torch.tensor([0., 0.]), scale=1),
             (2,)),
            (Laplace(loc=0, scale=torch.tensor([1., 1.])),
@@ -2701,6 +2759,14 @@ def test_invalid_parameter_broadcasting(self):
                'concentration': torch.tensor([0, 0]),
                'rate': torch.tensor([1, 1, 1])
            }),
+           (Kumaraswamy, {
+               'concentration1': torch.tensor([[1, 1]]),
+               'concentration0': torch.tensor([1, 1, 1, 1])
+           }),
+           (Kumaraswamy, {
+               'concentration1': torch.tensor([[[1, 1, 1], [1, 1, 1]]]),
+               'concentration0': torch.tensor([1, 1])
+           }),
            (Laplace, {
                'loc': torch.tensor([0, 0]),
                'scale': torch.tensor([1, 1, 1])
            }),
@@ -3242,6 +3308,15 @@ def test_gumbel_shape_scalar_params(self):
         self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

+    def test_kumaraswamy_shape_scalar_params(self):
+        kumaraswamy = Kumaraswamy(1, 1)
+        self.assertEqual(kumaraswamy._batch_shape, torch.Size())
+        self.assertEqual(kumaraswamy._event_shape, torch.Size())
+        self.assertEqual(kumaraswamy.sample().size(), torch.Size())
+        self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertEqual(kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
     def test_vonmises_shape_tensor_params(self):
         von_mises = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
         self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index ffcf75695d2f..57408f0c03f0 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -91,6 +91,7 @@
 from .half_normal import HalfNormal
 from .independent import Independent
 from .kl import kl_divergence, register_kl
+from .kumaraswamy import Kumaraswamy
 from .laplace import Laplace
 from .log_normal import LogNormal
 from .logistic_normal import LogisticNormal
@@ -132,6 +133,7 @@
     'HalfCauchy',
     'HalfNormal',
     'Independent',
+    'Kumaraswamy',
     'Laplace',
     'LogNormal',
     'LogisticNormal',
diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py
index 5bd3a2d3bd1e..a569af34ebdc 100644
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -5,9 +5,7 @@
 from torch.distributions.uniform import Uniform
 from torch.distributions.transformed_distribution import TransformedDistribution
 from torch.distributions.transforms import AffineTransform, ExpTransform
-from torch.distributions.utils import broadcast_all
-
-euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
+from torch.distributions.utils import broadcast_all, euler_constant


 class Gumbel(TransformedDistribution):
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index fe64ccc56009..ba7ba73d6063 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -31,7 +31,7 @@
 from .poisson import Poisson
 from .transformed_distribution import TransformedDistribution
 from .uniform import Uniform
-from .utils import _sum_rightmost
+from .utils import _sum_rightmost, euler_constant as _euler_gamma

 _KL_REGISTRY = {}  # Source of truth mapping a few general (type, type) pairs to functions.
 _KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {}  # Memoized version mapping many specific (type, type) pairs to functions.
@@ -174,8 +174,6 @@ def kl_divergence(p, q):
 # KL Divergence Implementations
 ################################################################################

-_euler_gamma = 0.57721566490153286060
-
 # Same distributions

diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py
new file mode 100644
index 000000000000..4fb2e177e7be
--- /dev/null
+++ b/torch/distributions/kumaraswamy.py
@@ -0,0 +1,66 @@
+import torch
+from torch.distributions import constraints
+from torch.distributions.uniform import Uniform
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, PowerTransform
+from torch.distributions.utils import broadcast_all, euler_constant
+
+
+def _moments(a, b, n):
+    """
+    Computes nth moment of Kumaraswamy using using torch.lgamma
+    """
+    arg1 = 1 + n / a
+    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
+    return b * torch.exp(log_value)
+
+
+class Kumaraswamy(TransformedDistribution):
+    r"""
+    Samples from a Kumaraswamy distribution.
+
+    Example::
+
+        >>> m = Kumaraswamy(torch.Tensor([1.0]), torch.Tensor([1.0]))
+        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
+        tensor([ 0.1729])
+
+    Args:
+        concentration1 (float or Tensor): 1st concentration parameter of the distribution
+            (often referred to as alpha)
+        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
+            (often referred to as beta)
+    """
+    arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
+    support = constraints.unit_interval
+    has_rsample = True
+
+    def __init__(self, concentration1, concentration0, validate_args=None):
+        self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
+        finfo = torch.finfo(self.concentration0.dtype)
+        base_dist = Uniform(torch.full_like(self.concentration0, 0),
+                            torch.full_like(self.concentration0, 1))
+        transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
+                      AffineTransform(loc=1., scale=-1.),
+                      PowerTransform(exponent=self.concentration1.reciprocal())]
+        super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Kumaraswamy, _instance)
+        new.concentration1 = self.concentration1.expand(batch_shape)
+        new.concentration0 = self.concentration0.expand(batch_shape)
+        return super(Kumaraswamy, self).expand(batch_shape, _instance=new)
+
+    @property
+    def mean(self):
+        return _moments(self.concentration1, self.concentration0, 1)
+
+    @property
+    def variance(self):
+        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(self.mean, 2)
+
+    def entropy(self):
+        t1 = (1 - self.concentration1.reciprocal())
+        t0 = (1 - self.concentration0.reciprocal())
+        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
+        return t0 + t1 * H0 - torch.log(self.concentration1) - torch.log(self.concentration0)
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index 0fd623086562..36ff1f71c35b 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -5,6 +5,9 @@
 from typing import Dict, Any

+euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
+
+
 def broadcast_all(*values):
     r"""
     Given a list of values (possibly containing numbers), returns a list where each
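
A minimal usage sketch of the distribution this patch adds, assuming the patch is applied and Kumaraswamy is importable from torch.distributions; the parameter names come from the class above, while the sample count and tolerance are illustrative and the analytic mean simply restates the _moments helper from the patch:

    import torch
    from torch.distributions import Kumaraswamy

    a, b = torch.tensor(2.0), torch.tensor(5.0)
    m = Kumaraswamy(concentration1=a, concentration0=b)
    # has_rsample=True: draws go through Uniform -> PowerTransform -> AffineTransform -> PowerTransform
    x = m.rsample((100000,))

    # Closed-form first moment, b * B(1 + 1/a, b), computed the same way as _moments() in the patch
    arg1 = 1 + 1 / a
    analytic_mean = b * torch.exp(torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b))

    print(x.mean().item(), analytic_mean.item(), m.mean.item())  # all three should agree to roughly 1e-2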