Skip to content

Commit

Permalink
Jittable transforms (#1575)
Browse files Browse the repository at this point in the history
* [WIP] jittable transforms

* add licence to new test file

* turn BijectorConstraint into pytree

* test flattening/unflattening of parametrized constraints

* cosmetic edits

* fix typo

* implement tree_flatten/unflatten for transforms

* attempt to avoid confusing black

* add (un)flattening meths for BijectorTransform

* fixup! implement tree_flatten/unflatten for transforms

* test vmapping over transforms/constraints

* Make constraints `__eq__` checks robust to arbitrary inputs

* make transforms equality check robust to arbitrary inputs

* test constraints and transforms equality checks
  • Loading branch information
pierreglaser committed May 31, 2023
1 parent e230805 commit eab63ed
Show file tree
Hide file tree
Showing 6 changed files with 656 additions and 23 deletions.
14 changes: 14 additions & 0 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def __call__(self, x):
def codomain(self):
return _get_codomain(self.bijector)

def tree_flatten(self):
return self.bijector, ()

@classmethod
def tree_unflatten(cls, _, bijector):
return cls(bijector)


class BijectorTransform(Transform):
"""
Expand Down Expand Up @@ -106,6 +113,13 @@ def inverse_shape(self, shape):
batch_shape = shape[: len(shape) - len(out_event_shape)]
return batch_shape + in_shape

def tree_flatten(self):
return self.bijector, ()

@classmethod
def tree_unflatten(cls, _, bijector):
return cls(bijector)


@biject_to.register(BijectorConstraint)
def _transform_to_bijector_constraint(constraint):
Expand Down
125 changes: 114 additions & 11 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
import numpy as np

import jax.numpy
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Constraint(object):
Expand All @@ -75,6 +77,10 @@ class Constraint(object):
is_discrete = False
event_dim = 0

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

def __call__(self, x):
raise NotImplementedError

Expand All @@ -94,8 +100,24 @@ def feasible_like(self, prototype):
"""
raise NotImplementedError

@classmethod
def tree_unflatten(cls, aux_data, params):
params_keys, aux_data = aux_data
self = cls.__new__(cls)
for k, v in zip(params_keys, params):
setattr(self, k, v)

for k, v in aux_data.items():
setattr(self, k, v)
return self


class ParameterFreeConstraint(Constraint):
def tree_flatten(self):
return (), ((), dict())


class _SingletonConstraint(Constraint):
class _SingletonConstraint(ParameterFreeConstraint):
"""
A constraint type which has only one canonical instance, like constraints.real,
and unlike constraints.interval.
Expand Down Expand Up @@ -202,8 +224,23 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

def __eq__(self, other):
return (
type(self) is type(other)
and self._is_discrete == other._is_discrete
and self._event_dim == other._event_dim
)

def tree_flatten(self):
return (), (
(),
dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim),
)


class dependent_property(property, _Dependent):
# XXX: this should not need to be pytree-able since it simply wraps a method
# and thus is automatically present once the method's object is created
def __init__(
self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
):
Expand Down Expand Up @@ -243,8 +280,16 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.lower_bound,), (("lower_bound",), dict())

def __eq__(self, other):
if not isinstance(other, _GreaterThan):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound)

class _Positive(_GreaterThan, _SingletonConstraint):

class _Positive(_SingletonConstraint, _GreaterThan):
def __init__(self):
super().__init__(0.0)

Expand Down Expand Up @@ -301,6 +346,20 @@ def __repr__(self):
def feasible_like(self, prototype):
return self.base_constraint.feasible_like(prototype)

def tree_flatten(self):
return (self.base_constraint,), (
("base_constraint",),
{"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims},
)

def __eq__(self, other):
if not isinstance(other, _IndependentConstraint):
return False

return (self.base_constraint == other.base_constraint) & (
self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims
)


class _RealVector(_IndependentConstraint, _SingletonConstraint):
def __init__(self):
Expand All @@ -327,6 +386,14 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.upper_bound,), (("upper_bound",), dict())

def __eq__(self, other):
if not isinstance(other, _LessThan):
return False
return jnp.array_equal(self.upper_bound, other.upper_bound)


class _IntegerInterval(Constraint):
is_discrete = True
Expand All @@ -348,6 +415,20 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.lower_bound, self.upper_bound), (
("lower_bound", "upper_bound"),
dict(),
)

def __eq__(self, other):
if not isinstance(other, _IntegerInterval):
return False

return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal(
self.upper_bound, other.upper_bound
)


class _IntegerGreaterThan(Constraint):
is_discrete = True
Expand All @@ -366,13 +447,21 @@ def __repr__(self):
def feasible_like(self, prototype):
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))

def tree_flatten(self):
return (self.lower_bound,), (("lower_bound",), dict())

class _IntegerPositive(_IntegerGreaterThan, _SingletonConstraint):
def __eq__(self, other):
if not isinstance(other, _IntegerGreaterThan):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound)


class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan):
def __init__(self):
super().__init__(1)


class _IntegerNonnegative(_IntegerGreaterThan, _SingletonConstraint):
class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan):
def __init__(self):
super().__init__(0)

Expand All @@ -398,19 +487,25 @@ def feasible_like(self, prototype):
)

def __eq__(self, other):
return (
isinstance(other, _Interval)
and self.lower_bound == other.lower_bound
and self.upper_bound == other.upper_bound
if not isinstance(other, _Interval):
return False
return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal(
self.upper_bound, other.upper_bound
)

def tree_flatten(self):
return (self.lower_bound, self.upper_bound), (
("lower_bound", "upper_bound"),
dict(),
)

class _Circular(_Interval, _SingletonConstraint):

class _Circular(_SingletonConstraint, _Interval):
def __init__(self):
super().__init__(-math.pi, math.pi)


class _UnitInterval(_Interval, _SingletonConstraint):
class _UnitInterval(_SingletonConstraint, _Interval):
def __init__(self):
super().__init__(0.0, 1.0)

Expand Down Expand Up @@ -462,6 +557,14 @@ def feasible_like(self, prototype):
value = jax.numpy.pad(jax.numpy.expand_dims(self.upper_bound, -1), pad_width)
return jax.numpy.broadcast_to(value, prototype.shape)

def tree_flatten(self):
return (self.upper_bound,), (("upper_bound",), dict())

def __eq__(self, other):
if not isinstance(other, _Multinomial):
return False
return jnp.array_equal(self.upper_bound, other.upper_bound)


class _L1Ball(_SingletonConstraint):
"""
Expand Down Expand Up @@ -546,7 +649,7 @@ def feasible_like(self, prototype):
return jax.numpy.full_like(prototype, 1 / prototype.shape[-1])


class _SoftplusPositive(_GreaterThan, _SingletonConstraint):
class _SoftplusPositive(_SingletonConstraint, _GreaterThan):
def __init__(self):
super().__init__(lower_bound=0.0)

Expand Down
24 changes: 24 additions & 0 deletions numpyro/distributions/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,21 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
log_scale = intermediates
return log_scale.sum(-1)

def tree_flatten(self):
return (self.log_scale_min_clip, self.log_scale_max_clip), (
("log_scale_min_clip", "log_scale_max_clip"),
{"arn": self.arn},
)

def __eq__(self, other):
if not isinstance(other, InverseAutoregressiveTransform):
return False
return (
(self.arn is other.arn)
& jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip)
& jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip)
)


class BlockNeuralAutoregressiveTransform(Transform):
"""
Expand Down Expand Up @@ -139,3 +154,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
else:
logdet = intermediates
return logdet.sum(-1)

def tree_flatten(self):
return (), ((), {"bn_arn": self.bn_arn})

def __eq__(self, other):
return (
isinstance(other, BlockNeuralAutoregressiveTransform)
and self.bn_arn is other.bn_arn
)

0 comments on commit eab63ed

Please sign in to comment.