Skip to content

Commit

Permalink
Implement backwards-comparible shape and Ellipsis-enabled dims
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>
  • Loading branch information
michaelosthege and twiecki committed May 14, 2021
1 parent aebc7e2 commit 2f124df
Show file tree
Hide file tree
Showing 6 changed files with 507 additions and 68 deletions.
4 changes: 4 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

### New Features
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4696](https://github.com/pymc-devs/pymc3/pull/4696)):
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Numeric entries in `shape` restrict the model variable to the exact length and re-sizing is no longer possible.
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects. An `Ellipsis` (`...`) in the last position of `dims` can be used as short-hand notation for implied dimensions.
- The `size` kwarg behaves like it does in Aesara/NumPy. For univariate RVs it is the same as `shape`, but for multivariate RVs it depends on how the RV implements broadcasting to dimensionality greater than `RVOp.ndim_supp`.
- ...

### Maintenance
Expand Down
281 changes: 253 additions & 28 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,19 @@

from abc import ABCMeta
from copy import copy
from typing import TYPE_CHECKING
from typing import Any, Optional, Sequence, Tuple, Union

import aesara
import aesara.tensor as at
import dill

from aesara.graph.basic import Variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import SpecifyShape, specify_shape

from pymc3.aesaraf import change_rv_size, pandas_to_array
from pymc3.distributions import _logcdf, _logp

if TYPE_CHECKING:
from typing import Optional, Callable

import aesara
import aesara.graph.basic
import aesara.tensor as at

from pymc3.exceptions import ShapeError, ShapeWarning
from pymc3.util import UNSET, get_repr_for_variable
from pymc3.vartypes import string_types

Expand All @@ -52,6 +50,10 @@

PLATFORM = sys.platform

Shape = Union[int, Sequence[Union[str, type(Ellipsis)]], Variable]
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
Size = Union[int, Tuple[int, ...]]


class _Unpickling:
pass
Expand Down Expand Up @@ -115,13 +117,111 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
return new_cls


def _valid_ellipsis_position(items: Union[None, Shape, Dims, Size]) -> bool:
if items is not None and not isinstance(items, Variable) and Ellipsis in items:
if any(i == Ellipsis for i in items[:-1]):
return False
return True


def _validate_shape_dims_size(
shape: Any = None, dims: Any = None, size: Any = None
) -> Tuple[Optional[Shape], Optional[Dims], Optional[Size]]:
# Raise on unsupported parametrization
if shape is not None and dims is not None:
raise ValueError(f"Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!")
if dims is not None and size is not None:
raise ValueError(f"Passing both `dims` ({dims}) and `size` ({size}) is not supported!")
if shape is not None and size is not None:
raise ValueError(f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!")

# Raise on invalid types
if not isinstance(shape, (type(None), int, list, tuple, Variable)):
raise ValueError("The `shape` parameter must be an int, list or tuple.")
if not isinstance(dims, (type(None), str, list, tuple)):
raise ValueError("The `dims` parameter must be a str, list or tuple.")
if not isinstance(size, (type(None), int, list, tuple)):
raise ValueError("The `size` parameter must be an int, list or tuple.")

# Auto-convert non-tupled parameters
if isinstance(shape, int):
shape = (shape,)
if isinstance(dims, str):
dims = (dims,)
if isinstance(size, int):
size = (size,)

# Convert to actual tuples
if not isinstance(shape, (type(None), tuple, Variable)):
shape = tuple(shape)
if not isinstance(dims, (type(None), tuple)):
dims = tuple(dims)
if not isinstance(size, (type(None), tuple)):
size = tuple(size)

if not _valid_ellipsis_position(shape):
raise ValueError(
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
)
if not _valid_ellipsis_position(dims):
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
if size is not None and Ellipsis in size:
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
return shape, dims, size


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""

rv_class = None
rv_op = None

def __new__(cls, name, *args, **kwargs):
def __new__(
cls,
name: str,
*args,
rng=None,
dims: Optional[Dims] = None,
testval=None,
observed=None,
total_size=None,
transform=UNSET,
**kwargs,
) -> RandomVariable:
"""Adds a RandomVariable corresponding to a PyMC3 distribution to the current model.
Note that all remaining kwargs must be compatible with ``.dist()``
Parameters
----------
cls : type
A PyMC3 distribution.
name : str
Name for the new model variable.
rng : optional
Random number generator to use with the RandomVariable.
dims : tuple, optional
A tuple of dimension names known to the model.
testval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
observed : optional
Observed data to be passed when registering the random variable in the model.
See ``Model.register_rv``.
total_size : float, optional
See ``Model.register_rv``.
transform : optional
See ``Model.register_rv``.
**kwargs
Keyword arguments that will be forwarded to ``.dist()``.
Most prominently: ``shape`` and ``size``
Returns
-------
rv : RandomVariable
The created RV, registered in the Model.
"""

try:
from pymc3.model import Model

Expand All @@ -134,40 +234,165 @@ def __new__(cls, name, *args, **kwargs):
"for a standalone distribution."
)

rng = kwargs.pop("rng", None)

if rng is None:
rng = model.default_rng

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

data = kwargs.pop("observed", None)
if rng is None:
rng = model.default_rng

total_size = kwargs.pop("total_size", None)
_, dims, _ = _validate_shape_dims_size(dims=dims)
resize = None

# Create the RV without specifying testval, because the testval may have a shape
# that only matches after replicating with a size implied by dims (see below).
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
n_implied = rv_out.ndim

# The `.dist()` can wrap automatically with a SpecifyShape Op which brings informative
# error messages earlier in model construction.
# Here, however, the underyling RV must be used - a new SpecifyShape Op can be added at the end.
assert_shape = None
if isinstance(rv_out.owner.op, SpecifyShape):
rv_out, assert_shape = rv_out.owner.inputs

# `dims` are only available with this API, because `.dist()` can be used
# without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
if Ellipsis in dims:
# Auto-complete the dims tuple to the full length
dims = (*dims[:-1], *[None] * rv_out.ndim)

n_resize = len(dims) - n_implied

# All resize dims must be known already (numerically or symbolically).
unknown_resize_dims = set(dims[:n_resize]) - set(model.dim_lengths)
if unknown_resize_dims:
raise KeyError(
f"Dimensions {unknown_resize_dims} are unknown to the model and cannot be used to specify a `size`."
)

dims = kwargs.pop("dims", None)
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
resize = tuple(model.dim_lengths[dname] for dname in dims[:n_resize])
elif observed is not None:
if not hasattr(observed, "shape"):
observed = pandas_to_array(observed)
n_resize = observed.ndim - n_implied
resize = tuple(observed.shape[d] for d in range(n_resize))

if resize:
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = change_rv_size(rv_var=rv_out, new_size=resize, expand=True)

if dims is not None:
# Now that we have a handle on the output RV, we can register named implied dimensions that
# were not yet known to the model, such that they can be used for size further downstream.
for di, dname in enumerate(dims[n_resize:]):
if not dname in model.dim_lengths:
model.add_coord(dname, values=None, length=rv_out.shape[n_resize + di])

if "shape" in kwargs:
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")
if testval is not None:
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = testval

transform = kwargs.pop("transform", UNSET)
rv_registered = model.register_rv(
rv_out, name, observed, total_size, dims=dims, transform=transform
)

rv_out = cls.dist(*args, rng=rng, **kwargs)
# Wrapping in specify_shape now does not break transforms:
if assert_shape is not None:
rv_registered = specify_shape(rv_registered, assert_shape)

return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
return rv_registered

@classmethod
def dist(cls, dist_params, **kwargs):
def dist(
cls,
dist_params,
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
testval=None,
**kwargs,
) -> RandomVariable:
"""Creates a RandomVariable corresponding to the `cls` distribution.
testval = kwargs.pop("testval", None)
Parameters
----------
dist_params
shape : tuple, optional
A tuple of sizes for each dimension of the new RV.
size : int, tuple, Variable, optional
A scalar or tuple for replicating the RV in addition
to its implied shape/dimensionality.
testval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
rv_var = cls.rv_op(*dist_params, **kwargs)
Returns
-------
rv : RandomVariable
The created RV.
"""
if "dims" in kwargs:
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")

shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
ndim_supp = cls.rv_op.ndim_supp

if shape is not None:
ndim_expected = len(tuple(shape))
batch_shape = tuple(shape)[: ndim_expected - ndim_supp]
elif size is not None:
ndim_expected = ndim_supp + len(tuple(size))
batch_shape = size
else:
ndim_expected = None
batch_shape = None

# Create the RV with a `size` right away.
# This is not necessarily the final result.
rv_intermediate = cls.rv_op(*dist_params, size=batch_shape, **kwargs)
ndim_actual = rv_intermediate.ndim

if shape is not None:
ndim_batch = len(tuple(shape)) - ndim_supp
elif size is not None:
ndim_batch = len(tuple(size))
else:
ndim_batch = ndim_actual - ndim_supp
needs_resize = shape is not None and size is None and ndim_actual != ndim_expected

# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
if needs_resize:
# There are shape dimensions that go beyond what's implied by the RV parameters.
# Recreate the RV without passing `size`, this time creating batch dimensions with `change_rv_size`.
rv_out = change_rv_size(
rv_var=cls.rv_op(*dist_params, size=None, **kwargs),
new_size=shape[:-ndim_batch] if ndim_batch > 0 else (),
expand=True,
)
if not rv_out.ndim == ndim_expected:
raise ShapeError(
f"Resized RV does not have the expected dimensionality. ",
f"This indicates a severe problem. Please open an issue.",
actual=rv_out.ndim,
expected=ndim_batch + ndim_supp,
)
else:
rv_out = rv_intermediate

# Warn about the edge cases where the RV Op creates more dimensions than
# it should based on `size` and `RVOp.ndim_supp`.
if size is not None and ndim_actual != ndim_expected:
warnings.warn(
f"You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional.",
ShapeWarning,
)

if testval is not None:
rv_var.tag.test_value = testval
rv_out.tag.test_value = testval

return rv_var
return rv_out

def _distr_parameters_for_repr(self):
"""Return the names of the parameters for this distribution (e.g. "mu"
Expand Down
Loading

0 comments on commit 2f124df

Please sign in to comment.