Implement Kumaraswamy Distribution (#48285)
Summary:
This PR implements the Kumaraswamy distribution.
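For context, a minimal usage sketch of the new distribution (parameter names and methods as introduced in this diff; the argument values here are arbitrary):

    import torch
    from torch.distributions import Kumaraswamy

    # Kumaraswamy(a, b) has support (0, 1); both concentration parameters must be positive.
    m = Kumaraswamy(concentration1=torch.tensor([2.0]), concentration0=torch.tensor([5.0]))
    x = m.rsample((3,))       # reparameterized sampling (has_rsample=True)
    print(m.log_prob(x))      # log density via the TransformedDistribution machinery
    print(m.mean, m.variance, m.entropy())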

cc: fritzo alicanb sdaulton

Pull Request resolved: #48285

Reviewed By: ejguan

Differential Revision: D25221015

Pulled By: ezyang

fbshipit-source-id: e621b25a9c75671bdfc94af145a4d9de2f07231e
vishwakftw authored and facebook-github-bot committed Dec 2, 2020
1 parent 9c6979a commit 47db191
Showing 7 changed files with 158 additions and 7 deletions.
9 changes: 9 additions & 0 deletions docs/source/distributions.rst
@@ -167,6 +167,15 @@ Probability distributions - torch.distributions
    :undoc-members:
    :show-inheritance:

:hidden:`Kumaraswamy`
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.kumaraswamy
.. autoclass:: Kumaraswamy
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Laplace`
~~~~~~~~~~~~~~~~~~~~~~~

77 changes: 76 additions & 1 deletion test/distributions/test_distributions.py
@@ -44,7 +44,7 @@
Distribution, Exponential, ExponentialFamily,
FisherSnedecor, Gamma, Geometric, Gumbel,
HalfCauchy, HalfNormal,
-Independent, Laplace, LogisticNormal,
+Independent, Kumaraswamy, Laplace, LogisticNormal,
LogNormal, LowRankMultivariateNormal,
MixtureSameFamily, Multinomial, MultivariateNormal,
NegativeBinomial, Normal, OneHotCategorical, Pareto,
@@ -240,6 +240,16 @@ def is_all_nan(tensor):
            'reinterpreted_batch_ndims': 3,
        },
    ]),
    Example(Kumaraswamy, [
        {
            'concentration1': torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
            'concentration0': torch.empty(2, 3).uniform_(1, 2).requires_grad_(),
        },
        {
            'concentration1': torch.empty(4).uniform_(1, 2).requires_grad_(),
            'concentration0': torch.empty(4).uniform_(1, 2).requires_grad_(),
        },
    ]),
    Example(Laplace, [
        {
            'loc': torch.randn(5, 5, requires_grad=True),
@@ -2249,6 +2259,42 @@ def test_gumbel_sample(self):
scipy.stats.gumbel_r(loc=loc, scale=scale),
'Gumbel(loc={}, scale={})'.format(loc, scale))

    def test_kumaraswamy_shape(self):
        concentration1 = torch.randn(2, 3).abs().requires_grad_()
        concentration0 = torch.randn(2, 3).abs().requires_grad_()
        concentration1_1d = torch.randn(1).abs().requires_grad_()
        concentration0_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Kumaraswamy(concentration1, concentration0).sample().size(), (2, 3))
        self.assertEqual(Kumaraswamy(concentration1, concentration0).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample().size(), (1,))
        self.assertEqual(Kumaraswamy(concentration1_1d, concentration0_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Kumaraswamy(1.0, 1.0).sample().size(), ())
        self.assertEqual(Kumaraswamy(1.0, 1.0).sample((1,)).size(), (1,))

    # Kumaraswamy distribution is not implemented in SciPy,
    # hence these explicit mean/variance checks against Monte Carlo estimates
    def test_kumaraswamy_mean_variance(self):
        c1_1 = torch.randn(2, 3).abs().requires_grad_()
        c0_1 = torch.randn(2, 3).abs().requires_grad_()
        c1_2 = torch.randn(4).abs().requires_grad_()
        c0_2 = torch.randn(4).abs().requires_grad_()
        cases = [(c1_1, c0_1), (c1_2, c0_2)]
        for i, (a, b) in enumerate(cases):
            m = Kumaraswamy(a, b)
            samples = m.sample((60000,))
            expected = samples.mean(0)
            actual = m.mean
            error = (expected - actual).abs()
            max_error = max(error[error == error])  # error == error drops NaN entries
            self.assertLess(max_error, 0.01,
                            "Kumaraswamy example {}/{}, incorrect .mean".format(i + 1, len(cases)))
            expected = samples.var(0)
            actual = m.variance
            error = (expected - actual).abs()
            max_error = max(error[error == error])  # error == error drops NaN entries
            self.assertLess(max_error, 0.01,
                            "Kumaraswamy example {}/{}, incorrect .variance".format(i + 1, len(cases)))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor(self):
        df1 = torch.randn(2, 3).abs().requires_grad_()
@@ -2622,6 +2668,18 @@ def test_valid_parameter_broadcasting(self):
             (1, 2)),
            (Gumbel(loc=torch.tensor([0.]), scale=torch.tensor([[1.]])),
             (1, 1)),
            (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=1.),
             (2,)),
            (Kumaraswamy(concentration1=1, concentration0=torch.tensor([1., 1.])),
             (2,)),
            (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([1.])),
             (2,)),
            (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([[1.], [1.]])),
             (2, 2)),
            (Kumaraswamy(concentration1=torch.tensor([1., 1.]), concentration0=torch.tensor([[1.]])),
             (1, 2)),
            (Kumaraswamy(concentration1=torch.tensor([1.]), concentration0=torch.tensor([[1.]])),
             (1, 1)),
            (Laplace(loc=torch.tensor([0., 0.]), scale=1),
             (2,)),
            (Laplace(loc=0, scale=torch.tensor([1., 1.])),
@@ -2701,6 +2759,14 @@ def test_invalid_parameter_broadcasting(self):
                'concentration': torch.tensor([0, 0]),
                'rate': torch.tensor([1, 1, 1])
            }),
            (Kumaraswamy, {
                'concentration1': torch.tensor([[1, 1]]),
                'concentration0': torch.tensor([1, 1, 1, 1])
            }),
            (Kumaraswamy, {
                'concentration1': torch.tensor([[[1, 1, 1], [1, 1, 1]]]),
                'concentration0': torch.tensor([1, 1])
            }),
            (Laplace, {
                'loc': torch.tensor([0, 0]),
                'scale': torch.tensor([1, 1, 1])
@@ -3242,6 +3308,15 @@ def test_gumbel_shape_scalar_params(self):
        self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_kumaraswamy_shape_scalar_params(self):
        kumaraswamy = Kumaraswamy(1, 1)
        self.assertEqual(kumaraswamy._batch_shape, torch.Size())
        self.assertEqual(kumaraswamy._event_shape, torch.Size())
        self.assertEqual(kumaraswamy.sample().size(), torch.Size())
        self.assertEqual(kumaraswamy.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(kumaraswamy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(kumaraswamy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_vonmises_shape_tensor_params(self):
        von_mises = VonMises(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self.assertEqual(von_mises._batch_shape, torch.Size((2,)))
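Side note: since SciPy has no Kumaraswamy implementation (hence the explicit mean/variance tests above), the closed-form moments can also be cross-checked outside the test suite using only Python's math.lgamma. A sketch mirroring the _moments helper added below; raw_moment is a hypothetical name for illustration:

    import math
    import torch
    from torch.distributions import Kumaraswamy

    def raw_moment(a, b, n):
        # E[X^n] = b * B(1 + n/a, b), evaluated in log-space for numerical stability
        arg = 1.0 + n / a
        return b * math.exp(math.lgamma(arg) + math.lgamma(b) - math.lgamma(arg + b))

    a, b = 2.0, 5.0
    m = Kumaraswamy(torch.tensor(a), torch.tensor(b))
    assert abs(m.mean.item() - raw_moment(a, b, 1)) < 1e-5
    assert abs(m.variance.item() - (raw_moment(a, b, 2) - raw_moment(a, b, 1) ** 2)) < 1e-5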
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -91,6 +91,7 @@
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
@@ -132,6 +133,7 @@
    'HalfCauchy',
    'HalfNormal',
    'Independent',
    'Kumaraswamy',
    'Laplace',
    'LogNormal',
    'LogisticNormal',
4 changes: 1 addition & 3 deletions torch/distributions/gumbel.py
@@ -5,9 +5,7 @@
from torch.distributions.uniform import Uniform
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
-from torch.distributions.utils import broadcast_all
-
-euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
+from torch.distributions.utils import broadcast_all, euler_constant


class Gumbel(TransformedDistribution):
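Context for this refactor: Gumbel already needs the Euler-Mascheroni constant for its closed-form moments and entropy, so centralizing it in torch.distributions.utils lets both distributions share one definition. For X ~ Gumbel(mu, beta):

    \mathbb{E}[X] = \mu + \beta\gamma, \qquad h(X) = \ln \beta + \gamma + 1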
4 changes: 1 addition & 3 deletions torch/distributions/kl.py
@@ -31,7 +31,7 @@
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
-from .utils import _sum_rightmost
+from .utils import _sum_rightmost, euler_constant as _euler_gamma

_KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {} # Memoized version mapping many specific (type, type) pairs to functions.
@@ -174,8 +174,6 @@ def kl_divergence(p, q):
# KL Divergence Implementations
################################################################################

-_euler_gamma = 0.57721566490153286060
-
# Same distributions


66 changes: 66 additions & 0 deletions torch/distributions/kumaraswamy.py
@@ -0,0 +1,66 @@
import torch
from torch.distributions import constraints
from torch.distributions.uniform import Uniform
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, PowerTransform
from torch.distributions.utils import broadcast_all, euler_constant


def _moments(a, b, n):
    """
    Computes the nth moment of the Kumaraswamy distribution using torch.lgamma
    """
    arg1 = 1 + n / a
    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
    return b * torch.exp(log_value)


class Kumaraswamy(TransformedDistribution):
    r"""
    Samples from a Kumaraswamy distribution.

    Example::

        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
        tensor([ 0.1729])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """
    arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
        finfo = torch.finfo(self.concentration0.dtype)
        base_dist = Uniform(torch.full_like(self.concentration0, 0),
                            torch.full_like(self.concentration0, 1))
        transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
                      AffineTransform(loc=1., scale=-1.),
                      PowerTransform(exponent=self.concentration1.reciprocal())]
        super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super(Kumaraswamy, self).expand(batch_shape, _instance=new)

    @property
    def mean(self):
        return _moments(self.concentration1, self.concentration0, 1)

    @property
    def variance(self):
        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(self.mean, 2)

    def entropy(self):
        t1 = (1 - self.concentration1.reciprocal())
        t0 = (1 - self.concentration0.reciprocal())
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return t0 + t1 * H0 - torch.log(self.concentration1) - torch.log(self.concentration0)
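Two standard Kumaraswamy facts behind this file, stated here for reviewers. The transform chain is inverse-CDF sampling (the CDF is F(x) = 1 - (1 - x^a)^b with a = concentration1, b = concentration0), and _moments evaluates the raw-moment formula in log-space:

    X = \bigl(1 - U^{1/b}\bigr)^{1/a}, \quad U \sim \mathrm{Uniform}(0, 1)
    \quad \text{(valid since } 1 - U \overset{d}{=} U\text{)}

    \mathbb{E}[X^{n}] = b \, B\!\left(1 + \tfrac{n}{a},\, b\right)
                      = b \, \frac{\Gamma(1 + n/a)\,\Gamma(b)}{\Gamma(1 + n/a + b)}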
3 changes: 3 additions & 0 deletions torch/distributions/utils.py
@@ -5,6 +5,9 @@
from typing import Dict, Any


euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant


def broadcast_all(*values):
r"""
Given a list of values (possibly containing numbers), returns a list where each
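For reference, the constant centralized here is the Euler-Mascheroni constant; Kumaraswamy.entropy uses it through the harmonic-number identity H_b = \psi(b + 1) + \gamma:

    \gamma = \lim_{n \to \infty}\left(\sum_{k=1}^{n} \frac{1}{k} - \ln n\right) \approx 0.5772156649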
