Implement PositivePowerTransform (#2904)
* Add PositivePowerTransform

* Fix docs build

* Address review comment
fritzo committed Jul 20, 2021
1 parent 9f67c43 commit 09e4401
Showing 5 changed files with 77 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -196,7 +196,7 @@
     "funsor": ("http://funsor.pyro.ai/en/stable/", None),
     "opt_einsum": ("https://optimized-einsum.readthedocs.io/en/stable/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
-    "Bio": ("https://biopython.readthedocs.io/en/latest/", None),
+    "Bio": ("https://biopython.org/docs/latest/api/", None),
     "horovod": ("https://horovod.readthedocs.io/en/stable/", None),
 }

7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -508,6 +508,13 @@ Permute
     :undoc-members:
     :show-inheritance:
 
+PositivePowerTransform
+----------------------
+.. autoclass:: pyro.distributions.transforms.PositivePowerTransform
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 SoftplusLowerCholeskyTransform
 ------------------------------
 .. autoclass:: pyro.distributions.transforms.SoftplusLowerCholeskyTransform
2 changes: 2 additions & 0 deletions pyro/distributions/transforms/__init__.py
@@ -67,6 +67,7 @@
 from .permute import Permute, permute
 from .planar import ConditionalPlanar, Planar, conditional_planar, planar
 from .polynomial import Polynomial, polynomial
+from .power import PositivePowerTransform
 from .radial import ConditionalRadial, Radial, conditional_radial, radial
 from .softplus import SoftplusLowerCholeskyTransform, SoftplusTransform
 from .spline import ConditionalSpline, Spline, conditional_spline, spline
@@ -180,6 +181,7 @@ def iterated(repeats, base_fn, *args, **kwargs):
     "Permute",
     "Planar",
     "Polynomial",
+    "PositivePowerTransform",
     "Radial",
     "SoftplusLowerCholeskyTransform",
     "SoftplusTransform",
60 changes: 60 additions & 0 deletions pyro/distributions/transforms/power.py
@@ -0,0 +1,60 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import Distribution, constraints
from torch.distributions.transforms import Transform


class PositivePowerTransform(Transform):
    r"""
    Transform via the mapping
    :math:`y=\operatorname{sign}(x)|x|^{\text{exponent}}`.

    Whereas :class:`~torch.distributions.transforms.PowerTransform` allows
    arbitrary ``exponent`` and restricts domain and codomain to positive
    values, this class restricts ``exponent > 0`` and allows real domain
    and codomain.

    .. warning:: The Jacobian is typically zero or infinite at the origin.
    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __init__(self, exponent, *, cache_size=0, validate_args=None):
        super().__init__(cache_size=cache_size)
        if isinstance(exponent, int):
            exponent = float(exponent)
        exponent = torch.as_tensor(exponent)
        self.exponent = exponent
        if validate_args is None:
            validate_args = Distribution._validate_args
        if validate_args:
            if not exponent.gt(0).all():
                raise ValueError(f"Expected exponent > 0 but got: {exponent}")

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return PositivePowerTransform(self.exponent, cache_size=cache_size)

    def __eq__(self, other):
        if not isinstance(other, PositivePowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.abs().pow(self.exponent) * x.sign()

    def _inverse(self, y):
        return y.abs().pow(self.exponent.reciprocal()) * y.sign()

    def log_abs_det_jacobian(self, x, y):
        # |dy/dx| = exponent * |x|**(exponent - 1), and y/x = |x|**(exponent - 1)
        # is nonnegative, so the log-Jacobian is log(exponent) + log(y/x).
        return self.exponent.log() + (y / x).log()

    def forward_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
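For orientation, here is a minimal usage sketch (editorial, not part of the commit) showing the sign-preserving forward map, the inverse round-trip, and use inside a TransformedDistribution. The exponent value and the variable names (t, x, y, base, flow) are illustrative.

import torch
import pyro.distributions as dist
from pyro.distributions.transforms import PositivePowerTransform

t = PositivePowerTransform(2.0)
x = torch.tensor([-2.0, -0.5, 0.5, 2.0])
y = t(x)                            # sign(x) * |x|**2 -> [-4.0, -0.25, 0.25, 4.0]
assert torch.allclose(t.inv(y), x)  # inverse round-trips away from the origin

# The real domain/codomain lets the transform push a Normal base forward,
# which PowerTransform (positive domain only) could not do here.
base = dist.Normal(torch.zeros(4), torch.ones(4))
flow = dist.TransformedDistribution(base, [t])
log_prob = flow.log_prob(y)         # uses log_abs_det_jacobian internally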
10 changes: 7 additions & 3 deletions tests/distributions/test_transforms.py
@@ -195,7 +195,7 @@ def _test(
         if event_dim > 1:
             transform = Flatten(transform, event_shape)
         self._test_jacobian(reduce(operator.mul, event_shape, 1), transform)
-        if autodiff:
+        if isinstance(transform, dist.TransformModule) and autodiff:
             # If the function doesn't have an explicit inverse, then use the forward op for autodiff
             self._test_autodiff(
                 reduce(operator.mul, event_shape, 1), transform, inverse=not inverse
@@ -390,10 +390,14 @@ def test_sylvester(self):
         self._test(T.sylvester, inverse=False)
 
     def test_normalize_transform(self):
-        self._test(lambda p: T.Normalize(p=p), autodiff=False)
+        self._test(lambda p: T.Normalize(p=p))
 
     def test_softplus(self):
-        self._test(lambda _: T.SoftplusTransform(), autodiff=False)
+        self._test(lambda _: T.SoftplusTransform())
+
+    def test_positive_power(self):
+        for p in [0.3, 1.0, 3.0]:
+            self._test(lambda _: T.PositivePowerTransform(p), event_dim=0)
 
 
 @pytest.mark.parametrize("batch_shape", [(), (7,), (6, 5)])
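As a standalone sanity check (editorial, not part of the test suite), the analytic log-Jacobian can be compared against autograd away from the origin; the exponent, sample points, and tolerance below are arbitrary choices.

import torch
from pyro.distributions.transforms import PositivePowerTransform

t = PositivePowerTransform(0.3)
x = torch.tensor([-1.5, -0.2, 0.4, 2.0], requires_grad=True)
y = t(x)
# The map is elementwise, so the gradient of y.sum() gives each dy_i/dx_i.
(dydx,) = torch.autograd.grad(y.sum(), x)
assert torch.allclose(t.log_abs_det_jacobian(x, y), dydx.abs().log(), atol=1e-5)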
