Skip to content

Commit

Permalink
Add gufunc_signature to SymbolicRandomVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 29, 2024
1 parent 2f85e66 commit 6252d2e
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 36 deletions.
3 changes: 2 additions & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class CensoredRV(SymbolicRandomVariable):
"""Censored random variable"""

inline_logprob = True
signature = "(),(),()->()"
ndim_supp = 0
_print_name = ("Censored", "\\operatorname{Censored}")


Expand Down Expand Up @@ -115,7 +117,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
return CensoredRV(
inputs=[dist_, lower_, upper_],
outputs=[censored_rv_],
ndim_supp=0,
)(dist, lower, upper)


Expand Down
140 changes: 114 additions & 26 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blockwise import safe_signature
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.utils import _parse_gufunc_signature
from pytensor.tensor.variable import TensorVariable
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -261,6 +263,12 @@ class SymbolicRandomVariable(OpFromGraph):
(0 for scalar, 1 for vector, ...)
"""

ndims_params: Optional[Sequence[int]] = None
"""Number of core dimensions of the distribution's parameters."""

signature: str = None
"""Numpy-like vectorized signature of the distribution."""

inline_logprob: bool = False
"""Specifies whether the logprob function is derived automatically by introspection
of the inner graph.
Expand All @@ -271,9 +279,25 @@ class SymbolicRandomVariable(OpFromGraph):
_print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
"""Tuple of (name, latex name) used for pretty-printing variables of this type"""

def __init__(self, *args, ndim_supp, **kwargs):
"""Initialitze a SymbolicRandomVariable class."""
self.ndim_supp = ndim_supp
def __init__(
    self,
    *args,
    **kwargs,
):
    """Initialize a SymbolicRandomVariable class.

    The core dimensionality of the op is resolved in this order:

    1. ``self.signature`` if it is already set (not ``None``), otherwise the
       ``signature`` keyword argument. When a signature is available it is
       parsed and used to set both ``ndims_params`` and ``ndim_supp``.
    2. If ``ndim_supp`` is still ``None`` afterwards, the ``ndim_supp`` keyword
       argument is used.

    Raises
    ------
    ValueError
        If neither a signature nor ``ndim_supp`` could be determined.
    """
    # An already-set signature takes precedence over one passed as a keyword.
    if self.signature is None:
        self.signature = kwargs.get("signature", None)

    if self.signature is not None:
        inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature)
        # The length of each parsed entry is used as that argument's number of core dims.
        self.ndims_params = [len(sig) for sig in inputs_sig]
        # With several outputs, the largest core dimensionality is taken as ndim_supp.
        self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)

    # NOTE(review): this read presumably relies on a class-level ``ndim_supp``
    # default being defined outside this method — confirm.
    if self.ndim_supp is None:
        self.ndim_supp = kwargs.get("ndim_supp", None)
        if self.ndim_supp is None:
            raise ValueError("ndim_supp or gufunc_signature must be provided")

    # NOTE(review): ``signature`` / ``ndim_supp`` are read with ``get`` (not ``pop``),
    # so they are still forwarded to the parent constructor below — confirm intended.
    kwargs.setdefault("inline", True)
    super().__init__(*args, **kwargs)

Expand All @@ -286,6 +310,11 @@ def update(self, node: Node):
"""
return {}

def batch_ndim(self, node: Node) -> int:
    """Return how many of the node's output dimensions are batch dimensions.

    The largest ``ndim`` found among the output types of ``node`` is used
    (outputs whose type has no ``ndim`` attribute count as 0), and the
    support dimensionality ``self.ndim_supp`` is subtracted from it.
    """
    output_ndims = [getattr(output.type, "ndim", 0) for output in node.outputs]
    widest_output_ndim = max(output_ndims)
    return widest_output_ndim - self.ndim_supp


class Distribution(metaclass=DistributionMeta):
"""Statistical distribution"""
Expand Down Expand Up @@ -558,23 +587,29 @@ def dist(
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp is None or ndims_params is None:
if signature is None:
ndim_supp = 0
ndims_params = [0] * len(dist_params)
else:
inputs, outputs = _parse_gufunc_signature(signature)
ndim_supp = max(len(out) for out in outputs)
ndims_params = [len(inp) for inp in inputs]

if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)

dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
if ndims_params is None:
ndims_params = [0] * len(dist_params)

if logp is None:
logp = default_not_implemented(class_name, "logp")

Expand Down Expand Up @@ -614,7 +649,7 @@ def rv_op(
random: Optional[Callable],
support_point: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
ndims_params: Sequence[int],
dtype: str,
class_name: str,
**kwargs,
Expand Down Expand Up @@ -702,7 +737,9 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
Expand All @@ -712,14 +749,24 @@ def dist(
if logcdf is None:
logcdf = default_not_implemented(class_name, "logcdf")

if signature is None:
if ndim_supp is None:
ndim_supp = 0
if ndims_params is None:
ndims_params = [0] * len(dist_params)
signature = safe_signature(
core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params],
core_outputs=[pt.tensor(shape=(None,) * ndim_supp)],
)

return super().dist(
dist_params,
class_name=class_name,
logp=logp,
logcdf=logcdf,
dist=dist,
support_point=support_point,
ndim_supp=ndim_supp,
signature=signature,
**kwargs,
)

Expand All @@ -732,7 +779,7 @@ def rv_op(
logcdf: Optional[Callable],
support_point: Optional[Callable],
size=None,
ndim_supp: int,
signature: str,
class_name: str,
):
size = normalize_size_param(size)
Expand All @@ -745,6 +792,10 @@ def rv_op(
dummy_params = [dummy_size_param, *dummy_dist_params]
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

signature = cls._infer_final_signature(
signature, len(dummy_params), len(dummy_updates_dict)
)

rv_type = type(
class_name,
(CustomSymbolicDistRV,),
Expand Down Expand Up @@ -802,7 +853,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
new_rv = new_rv_op(new_size, *dist_params)

Expand All @@ -811,10 +862,30 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
ndim_supp=ndim_supp,
signature=signature,
)
return rv_op(size, *dist_params)

@staticmethod
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
    """Add size and updates to user provided gufunc signature if they are missing.

    Parameters
    ----------
    signature : str
        Gufunc-like signature of the form ``"<inputs>-><outputs>"``. The input
        side may be empty (or whitespace only), e.g. ``"->()"``.
    n_inputs : int
        Number of inputs of the final op, including the size input.
    n_updates : int
        Number of update outputs that precede the main output in the final op.

    Returns
    -------
    str
        The signature, with a leading scalar ``()`` input added when exactly one
        input is missing, and ``n_updates`` leading scalar ``()`` outputs added
        when the signature declares a single output.
    """
    input_sig, output_sig = signature.split("->")
    # Treat a whitespace-only input side as empty. Otherwise the truthiness
    # check below would prepend "()," and produce an invalid "(), ->..." signature.
    if not input_sig.strip():
        input_sig = ""
    # Numpy parser does not accept (constant) functions without inputs like "->()"
    # We work around as this makes sense for distributions like Flat that have no inputs
    if input_sig == "":
        inputs = ()
        _, outputs = _parse_gufunc_signature("()" + signature)
    else:
        inputs, outputs = _parse_gufunc_signature(signature)
    if len(inputs) == n_inputs - 1:
        # Assume size is missing
        input_sig = ("()," if input_sig else "()") + input_sig
    if len(outputs) == 1:
        # Assume updates are missing
        output_sig = "()," * n_updates + output_sig
    return "->".join((input_sig, output_sig))


class CustomDist:
"""A helper class to create custom distributions
Expand All @@ -828,12 +899,12 @@ class CustomDist:
when not provided by the user.
Alternatively, a user can provide a `random` function that returns numerical
draws (e.g., via NumPy routines), and a `logp` function that must return an
Python graph that represents the logp graph when evaluated. This is used for
draws (e.g., via NumPy routines), and a `logp` function that must return a
PyTensor graph that represents the logp graph when evaluated. This is used for
mcmc sampling.
Additionally, a user can provide `logcdf` and `support_point` functions that must return
an PyTensor graph that computes those quantities. These may be used by other PyMC
PyTensor graphs that compute those quantities. These may be used by other PyMC
routines.
Parameters
Expand Down Expand Up @@ -894,14 +965,18 @@ class CustomDist:
distribution parameters, in the same order as they were supplied when the
CustomDist was created. If ``None``, a default ``support_point`` function will be
assigned that will always return 0, or an array of zeros.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndim_supp : Optional[int]
The number of dimensions in the support of the distribution.
Inferred from signature, if provided. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0. This is not needed if an
PyTensor dist function is provided.
parameters. Inferred from signature, if provided. Defaults to assuming
all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
signature : Optional[str]
A numpy vectorize-like signature that indicates the number and core dimensionality
of the input parameters and sample outputs of the CustomDist.
When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
dtype : str
The dtype of the distribution. All draws and observations passed into the
distribution will be cast onto this dtype. This is not needed if a PyTensor
Expand Down Expand Up @@ -939,6 +1014,7 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:
Provide a random function that returns numerical draws. This allows one to use a
CustomDist in prior and posterior predictive sampling.
A gufunc signature was also provided, which may be used by other routines.
.. code-block:: python
Expand All @@ -965,6 +1041,7 @@ def random(
mu,
logp=logp,
random=random,
signature="()->()",
observed=np.random.randn(100, 3),
size=(100, 3),
)
Expand All @@ -973,6 +1050,7 @@ def random(
Provide a dist function that creates a PyTensor graph built from other
PyMC distributions. PyMC can automatically infer that the logp of this
variable corresponds to a shifted Exponential distribution.
A gufunc signature was also provided, which may be used by other routines.
.. code-block:: python
Expand All @@ -994,6 +1072,7 @@ def dist(
lam,
shift,
dist=dist,
signature="(),()->()",
observed=[-1, -1, 0],
)
Expand Down Expand Up @@ -1040,10 +1119,11 @@ def __new__(
random: Optional[Callable] = None,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
# TODO: Deprecate ndim_supp / ndims_params in favor of signature?
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
Expand All @@ -1057,6 +1137,7 @@ def __new__(
)
dist_params = cls.parse_dist_params(dist_params)
cls.check_valid_dist_random(dist, random, dist_params)
moment = kwargs.pop("moment", None)
if moment is not None:
warnings.warn(
"`moment` argument is deprecated. Use `support_point` instead.",
Expand All @@ -1073,6 +1154,8 @@ def __new__(
logcdf=logcdf,
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
Expand All @@ -1086,6 +1169,7 @@ def __new__(
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
Expand All @@ -1099,8 +1183,9 @@ def dist(
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
support_point: Optional[Callable] = None,
ndim_supp: int = 0,
ndim_supp: Optional[int] = None,
ndims_params: Optional[Sequence[int]] = None,
signature: Optional[str] = None,
dtype: str = "floatX",
**kwargs,
):
Expand All @@ -1114,6 +1199,8 @@ def dist(
logcdf=logcdf,
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
**kwargs,
)
else:
Expand All @@ -1125,6 +1212,7 @@ def dist(
support_point=support_point,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
**kwargs,
)
Expand Down
9 changes: 8 additions & 1 deletion pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,17 @@ def rv_op(cls, weights, *components, size=None):
# Output mix_indexes rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
comp_s = ",".join((*s, "w"))
signature = f"(),(w),({comp_s})->({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
signature = f"(),(w),{comps_s}->({s})"
mix_op = MarginalMixtureRV(
inputs=[mix_indexes_rng_, weights_, *components_],
outputs=[mix_indexes_rng_next_, mix_out_],
ndim_supp=components[0].owner.op.ndim_supp,
signature=signature,
)

# Create the actual MarginalMixture variable
Expand Down
Loading

0 comments on commit 6252d2e

Please sign in to comment.