Add conditional inverse and compose TransformModules (#3185)
* Add conditional compose and inverse transformmodules

* private name

* add tests

* fix test

* lint

* rename

* add docstring with example and flesh out test

* nit

* fix doctest

* address comment
eb8680 committed Mar 4, 2023
1 parent 04fc486 commit 9afb089
Showing 3 changed files with 145 additions and 3 deletions.
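
At a glance: conditional transforms gain an .inv property, and the new ConditionalComposeTransformModule composes conditional and unconditional transforms before any context is supplied. A minimal sketch of the resulting workflow, assembled from the docstring and tests below (dimensions and values are illustrative, not part of the commit):

import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

# Unconditional parts are wrapped in ConstantConditionalTransform automatically.
stack = dist.conditional.ConditionalComposeTransformModule(
    [
        T.AffineTransform(1.0, 2.0),
        T.conditional_spline(3, 2, [8]),
    ],
    cache_size=1,
)

context = torch.rand(10, 2)
flow = stack.condition(context)  # an ordinary ComposeTransformModule

x = torch.rand(10, 3)
y = flow(x)

# New in this commit: .inv is available on the conditional transform itself,
# before conditioning on a context.
x2 = stack.inv.condition(context)(y)
assert torch.allclose(x, x2, atol=1e-5)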
65 changes: 65 additions & 0 deletions pyro/distributions/conditional.py
@@ -7,6 +7,7 @@
import torch.nn

from .torch import TransformedDistribution
from .torch_transform import ComposeTransformModule


class ConditionalDistribution(ABC):
@@ -36,6 +37,70 @@ def __init__(self, *args, **kwargs):
    def __hash__(self):
        return super().__hash__()

    @property
    def inv(self) -> "ConditionalTransformModule":
        return _ConditionalInverseTransformModule(self)


class _ConditionalInverseTransformModule(ConditionalTransformModule):
    def __init__(self, transform: ConditionalTransform):
        super().__init__()
        self._transform = transform

    @property
    def inv(self) -> ConditionalTransform:
        return self._transform

    def condition(self, context: torch.Tensor):
        return self._transform.condition(context).inv


class ConditionalComposeTransformModule(
    ConditionalTransformModule, torch.nn.ModuleList
):
    """
    Conditional analogue of :class:`~pyro.distributions.torch_transform.ComposeTransformModule`.
    Useful as a base class for specifying complicated conditional distributions::

        >>> class ConditionalFlowStack(dist.conditional.ConditionalComposeTransformModule):
        ...     def __init__(self, input_dim, context_dim, hidden_dims, num_flows):
        ...         super().__init__([
        ...             dist.transforms.conditional_planar(input_dim, context_dim, hidden_dims)
        ...             for _ in range(num_flows)
        ...         ], cache_size=1)

        >>> cond_dist = dist.conditional.ConditionalTransformedDistribution(
        ...     dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1),
        ...     [ConditionalFlowStack(3, 2, [8, 8], num_flows=4).inv]
        ... )

        >>> context = torch.rand(10, 2)
        >>> data = torch.rand(10, 3)
        >>> nll = -cond_dist.condition(context).log_prob(data)
    """

    def __init__(self, transforms, cache_size: int = 0):
        self.transforms = [
            ConstantConditionalTransform(t)
            if not isinstance(t, ConditionalTransform)
            else t
            for t in transforms
        ]
        super().__init__()
        if cache_size not in {0, 1}:
            raise ValueError("cache_size must be 0 or 1")
        self._cache_size = cache_size
        # register nn.Module parts so their parameters are visible for storage
        for t in transforms:
            if isinstance(t, torch.nn.Module):
                self.append(t)

    def condition(self, context: torch.Tensor) -> ComposeTransformModule:
        return ComposeTransformModule(
            [t.condition(context) for t in self.transforms]
        ).with_cache(self._cache_size)


class ConstantConditionalDistribution(ConditionalDistribution):
    def __init__(self, base_dist):
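A small sketch of the inverse wrapper's behavior, mirroring the test added in this commit (T.conditional_spline is an existing Pyro helper; values are illustrative): conditioning commutes with inversion, and double inversion returns the original object.

import torch
import pyro.distributions.transforms as T

cond_transform = T.conditional_spline(3, 2, [6])
context = torch.rand(2)
noise = torch.rand(3)

# Conditioning and inversion commute.
assert torch.allclose(
    cond_transform.inv.condition(context)(noise),
    cond_transform.condition(context).inv(noise),
)

# Double inversion unwraps back to the original conditional transform.
assert cond_transform.inv.inv is cond_transform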
12 changes: 9 additions & 3 deletions pyro/distributions/torch_transform.py
@@ -25,10 +25,16 @@ class ComposeTransformModule(torch.distributions.ComposeTransform, torch.nn.Modu
     store when used in :class:`~pyro.nn.module.PyroModule` instances.
     """
 
-    def __init__(self, parts):
-        super().__init__(parts)
+    def __init__(self, parts, cache_size=0):
+        super().__init__(parts, cache_size=cache_size)
         for part in parts:
-            self.append(part)
+            if isinstance(part, torch.nn.Module):
+                self.append(part)
 
     def __hash__(self):
         return super(torch.nn.Module, self).__hash__()
+
+    def with_cache(self, cache_size=1):
+        if cache_size == self._cache_size:
+            return self
+        return ComposeTransformModule(self.parts, cache_size=cache_size)
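
The with_cache hook mirrors torch.distributions.Transform.with_cache: it is a no-op at the current cache size and otherwise rebuilds the composition with caching parts. A brief sketch using the same transforms as the tests that follow (the choice of parts is illustrative):

import torch
import pyro.distributions.transforms as T

# The new isinstance guard lets non-module parts like SoftplusTransform
# be composed alongside nn.Module transforms such as Spline.
compose = T.ComposeTransformModule([T.Spline(3), T.SoftplusTransform()])

cached = compose.with_cache()         # rebuilds with cache_size=1
assert cached is not compose
assert cached.with_cache() is cached  # no-op at the same cache size

x = torch.rand(10, 3)
y = cached(x)
# With cache_size=1 each part remembers its most recent (input, output)
# pair, so inverting the value just computed can reuse the cache.
x2 = cached.inv(y)
assert torch.allclose(x, x2, atol=1e-5)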
71 changes: 71 additions & 0 deletions tests/distributions/test_transforms.py
@@ -470,3 +470,74 @@ def test_lower_cholesky_transform(transform, batch_shape, dim):
    y2 = transform(x2)
    assert y2.shape == shape
    assert_close(y, y2)


@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 7)])
@pytest.mark.parametrize("input_dim", [2, 3, 5])
@pytest.mark.parametrize("context_dim", [2, 3, 5])
def test_inverse_conditional_transform_module(batch_shape, input_dim, context_dim):
cond_transform = T.conditional_spline(input_dim, context_dim, [6])

noise = torch.rand(batch_shape + (input_dim,))
context = torch.rand(batch_shape + (context_dim,))

assert_close(
cond_transform.inv.condition(context)(noise),
cond_transform.condition(context).inv(noise),
)

assert cond_transform.inv.inv is cond_transform
assert_close(
cond_transform.inv.condition(context).inv(noise),
cond_transform.condition(context).inv.inv(noise),
)


@pytest.mark.parametrize("batch_shape", [(), (7,), (6, 7)])
@pytest.mark.parametrize("input_dim", [2, 3, 5])
@pytest.mark.parametrize("context_dim", [2, 3, 5])
@pytest.mark.parametrize("cache_size", [0, 1])
def test_conditional_compose_transform_module(
batch_shape, input_dim, context_dim, cache_size
):
conditional_transforms = [
T.AffineTransform(1.0, 2.0),
T.Spline(input_dim),
T.conditional_spline(input_dim, context_dim, [5]),
T.SoftplusTransform(),
T.conditional_spline(input_dim, context_dim, [6]),
]
cond_transform = dist.conditional.ConditionalComposeTransformModule(
conditional_transforms, cache_size=cache_size
)

base_dist = dist.Normal(0, 1).expand(batch_shape + (input_dim,)).to_event(1)
cond_dist = dist.ConditionalTransformedDistribution(base_dist, [cond_transform])

context = torch.rand(batch_shape + (context_dim,))
d = cond_dist.condition(context)
transform = d.transforms[0]
assert isinstance(transform, T.ComposeTransformModule)

data = d.rsample()
assert data.shape == batch_shape + (input_dim,)
assert d.log_prob(data).shape == batch_shape

actual_params = set(cond_transform.parameters())
expected_params = set(
torch.nn.ModuleList(
[t for t in conditional_transforms if isinstance(t, torch.nn.Module)]
).parameters()
)
assert set() != actual_params == expected_params

noise = base_dist.rsample()
expected = noise
for t in conditional_transforms:
expected = (t.condition(context) if hasattr(t, "condition") else t)(expected)

actual = transform(noise)
assert_close(actual, expected)

assert_close(cond_transform.inv.condition(context)(actual), noise)
assert_close(cond_transform.condition(context).inv(expected), noise)
