From 6252d2e58dc211c913ee2e652a4058d271d48bbd Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Sun, 18 Feb 2024 20:55:08 +0100
Subject: [PATCH] Add gufunc_signature to SymbolicRandomVariables

---
 pymc/distributions/censored.py           |   3 +-
 pymc/distributions/distribution.py       | 140 ++++++++++++++++++-----
 pymc/distributions/mixture.py            |   9 +-
 pymc/distributions/multivariate.py       |   9 +-
 pymc/distributions/timeseries.py         |  18 ++-
 tests/distributions/test_distribution.py |  28 +++++
 6 files changed, 171 insertions(+), 36 deletions(-)

diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py
index 5cc155d3ff..8b717d3c24 100644
--- a/pymc/distributions/censored.py
+++ b/pymc/distributions/censored.py
@@ -30,6 +30,8 @@ class CensoredRV(SymbolicRandomVariable):
     """Censored random variable"""

     inline_logprob = True
+    signature = "(),(),()->()"
+    ndim_supp = 0
     _print_name = ("Censored", "\\operatorname{Censored}")

@@ -115,7 +117,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):

         return CensoredRV(
             inputs=[dist_, lower_, upper_],
             outputs=[censored_rv_],
-            ndim_supp=0,
         )(dist, lower, upper)

diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 49b2563efb..7035e4e027 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -33,11 +33,13 @@
 from pytensor.graph.utils import MetaType
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import as_tensor_variable
+from pytensor.tensor.blockwise import safe_signature
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.type import RandomGeneratorType, RandomType
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.rewriting.shape import ShapeFeature
+from pytensor.tensor.utils import _parse_gufunc_signature
 from pytensor.tensor.variable import TensorVariable
 from typing_extensions import TypeAlias

@@ -261,6 +263,12 @@ class SymbolicRandomVariable(OpFromGraph):
     (0 for scalar, 1 for vector, ...)
     """

+    ndims_params: Optional[Sequence[int]] = None
+    """Number of core dimensions of the distribution's parameters."""
+
+    signature: Optional[str] = None
+    """Numpy-like vectorized signature of the distribution."""
+
     inline_logprob: bool = False
     """Specifies whether the logprob function is derived automatically by
     introspection of the inner graph.
@@ -271,9 +279,25 @@ class SymbolicRandomVariable(OpFromGraph):
     _print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
     """Tuple of (name, latex name) used for for pretty-printing variables of this type"""

-    def __init__(self, *args, ndim_supp, **kwargs):
-        """Initialitze a SymbolicRandomVariable class."""
-        self.ndim_supp = ndim_supp
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        """Initialize a SymbolicRandomVariable class."""
+        if self.signature is None:
+            self.signature = kwargs.get("signature", None)
+
+        if self.signature is not None:
+            inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature)
+            self.ndims_params = [len(sig) for sig in inputs_sig]
+            self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)
+
+        if self.ndim_supp is None:
+            self.ndim_supp = kwargs.get("ndim_supp", None)
+            if self.ndim_supp is None:
+                raise ValueError("ndim_supp or signature must be provided")
+
         kwargs.setdefault("inline", True)
         super().__init__(*args, **kwargs)

@@ -286,6 +310,11 @@ def update(self, node: Node):
         """
         return {}

+    def batch_ndim(self, node: Node) -> int:
+        """Number of dimensions of the distribution's batch shape."""
+        out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
+        return out_ndim - self.ndim_supp
+

 class Distribution(metaclass=DistributionMeta):
     """Statistical distribution"""
@@ -558,12 +587,22 @@ def dist(
         logcdf: Optional[Callable] = None,
         random: Optional[Callable] = None,
         support_point: Optional[Callable] = None,
-        ndim_supp: int = 0,
+        ndim_supp: Optional[int] = None,
         ndims_params: Optional[Sequence[int]] = None,
+        signature: Optional[str] = None,
         dtype: str = "floatX",
         class_name: str = "CustomDist",
         **kwargs,
     ):
+        if ndim_supp is None or ndims_params is None:
+            if signature is None:
+                ndim_supp = ndim_supp or 0
+                ndims_params = ndims_params or [0] * len(dist_params)
+            else:
+                inputs, outputs = _parse_gufunc_signature(signature)
+                ndim_supp = max(len(out) for out in outputs)
+                ndims_params = [len(inp) for inp in inputs]
+
         if ndim_supp > 0:
             raise NotImplementedError(
                 "CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
@@ -571,10 +610,6 @@ def dist(

         dist_params = [as_tensor_variable(param) for param in dist_params]

-        # Assume scalar ndims_params
-        if ndims_params is None:
-            ndims_params = [0] * len(dist_params)
-
         if logp is None:
             logp = default_not_implemented(class_name, "logp")

@@ -614,7 +649,7 @@ def rv_op(
         random: Optional[Callable],
         support_point: Optional[Callable],
         ndim_supp: int,
-        ndims_params: Optional[Sequence[int]],
+        ndims_params: Sequence[int],
         dtype: str,
         class_name: str,
         **kwargs,
@@ -702,7 +737,9 @@ def dist(
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
         support_point: Optional[Callable] = None,
-        ndim_supp: int = 0,
+        ndim_supp: Optional[int] = None,
+        ndims_params: Optional[Sequence[int]] = None,
+        signature: Optional[str] = None,
         dtype: str = "floatX",
         class_name: str = "CustomDist",
         **kwargs,
@@ -712,6 +749,16 @@ def dist(
         if logcdf is None:
             logcdf = default_not_implemented(class_name, "logcdf")

+        if signature is None:
+            if ndim_supp is None:
+                ndim_supp = 0
+            if ndims_params is None:
+                ndims_params = [0] * len(dist_params)
+            signature = safe_signature(
+                core_inputs=[pt.tensor(shape=(None,) * ndim_param) for ndim_param in ndims_params],
+                core_outputs=[pt.tensor(shape=(None,) * ndim_supp)],
+            )
+
         return super().dist(
             dist_params,
             class_name=class_name,
@@ -719,7 +766,7 @@ def dist(
             logcdf=logcdf,
             dist=dist,
             support_point=support_point,
-            ndim_supp=ndim_supp,
+            signature=signature,
             **kwargs,
         )

@@ -732,7 +779,7 @@ def rv_op(
         logcdf: Optional[Callable],
         support_point: Optional[Callable],
         size=None,
-        ndim_supp: int,
+        signature: str,
         class_name: str,
     ):
         size = normalize_size_param(size)
@@ -745,6 +792,10 @@ def rv_op(
         dummy_params = [dummy_size_param, *dummy_dist_params]
         dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

+        signature = cls._infer_final_signature(
+            signature, len(dummy_params), len(dummy_updates_dict)
+        )
+
         rv_type = type(
             class_name,
             (CustomSymbolicDistRV,),
@@ -802,7 +853,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
             new_rv_op = rv_type(
                 inputs=dummy_params,
                 outputs=[*dummy_updates_dict.values(), dummy_rv],
-                ndim_supp=ndim_supp,
+                signature=signature,
             )
             new_rv = new_rv_op(new_size, *dist_params)

@@ -811,10 +862,30 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
         rv_op = rv_type(
             inputs=dummy_params,
             outputs=[*dummy_updates_dict.values(), dummy_rv],
-            ndim_supp=ndim_supp,
+            signature=signature,
         )
         return rv_op(size, *dist_params)

+    @staticmethod
+    def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
+        """Add size and updates to a user-provided gufunc signature if they are missing."""
+        input_sig, output_sig = signature.split("->")
+        # The numpy parser does not accept (constant) functions without inputs like "->()".
+        # We work around it, as such signatures make sense for distributions like Flat that have no inputs.
+        if input_sig.strip() == "":
+            inputs = ()
+            _, outputs = _parse_gufunc_signature("()" + signature)
+        else:
+            inputs, outputs = _parse_gufunc_signature(signature)
+        if len(inputs) == n_inputs - 1:
+            # Assume size is missing
+            input_sig = ("()," if input_sig else "()") + input_sig
+        if len(outputs) == 1:
+            # Assume updates are missing
+            output_sig = "()," * n_updates + output_sig
+        signature = "->".join((input_sig, output_sig))
+        return signature
+

 class CustomDist:
     """A helper class to create custom distributions
@@ -828,13 +899,13 @@ class CustomDist:
     when not provided by the user.

    Alternatively, a user can provide a `random` function that returns numerical
-    draws (e.g., via NumPy routines), and a `logp` function that must return an
-    Python graph that represents the logp graph when evaluated. This is used for
+    draws (e.g., via NumPy routines), and a `logp` function that must return a
+    PyTensor graph that represents the logp graph when evaluated. This is used for
     mcmc sampling.

     Additionally, a user can provide a `logcdf` and `support_point` functions that must return
-    an PyTensor graph that computes those quantities. These may be used by other PyMC
+    PyTensor graphs that compute those quantities. These may be used by other PyMC
     routines.

     Parameters
     ----------
@@ -894,14 +965,18 @@ class CustomDist:
         distribution parameters, in the same order as they were supplied when the
         CustomDist was created. If ``None``, a default ``support_point`` function
         will be assigned that will always return 0, or an array of zeros.
-    ndim_supp : int
-        The number of dimensions in the support of the distribution. Defaults to assuming
-        a scalar distribution, i.e. ``ndim_supp = 0``.
+    ndim_supp : Optional[int]
+        The number of dimensions in the support of the distribution.
+        Inferred from signature, if provided. Defaults to assuming
+        a scalar distribution, i.e. ``ndim_supp = 0``.
     ndims_params : Optional[Sequence[int]]
         The list of number of dimensions in the support of each of the distribution's
-        parameters. If ``None``, it is assumed that all parameters are scalars, hence
-        the number of dimensions of their support will be 0. This is not needed if an
-        PyTensor dist function is provided.
+        parameters. Inferred from signature, if provided. Defaults to assuming
+        all parameters are scalars, i.e. ``ndims_params=[0, ...]``.
+    signature : Optional[str]
+        A numpy vectorize-like signature that indicates the number and core dimensionality
+        of the input parameters and sample outputs of the CustomDist.
+        When specified, `ndim_supp` and `ndims_params` are not needed. See examples below.
     dtype : str
         The dtype of the distribution. All draws and observations passed into the
         distribution will be cast onto this dtype. This is not needed if an PyTensor
@@ -939,6 +1014,7 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable:

     Provide a random function that return numerical draws. This allows one to use a
     CustomDist in prior and posterior predictive sampling.
+    A gufunc signature was also provided, which may be used by other routines.

     .. code-block:: python

@@ -965,6 +1041,7 @@ def random(
                 mu,
                 logp=logp,
                 random=random,
+                signature="()->()",
                 observed=np.random.randn(100, 3),
                 size=(100, 3),
             )
@@ -973,6 +1050,7 @@ def random(

     Provide a dist function that creates a PyTensor graph built from other PyMC
     distributions. PyMC can automatically infer that the logp of this variable
     corresponds to a shifted Exponential distribution.
+    A gufunc signature was also provided, which may be used by other routines.

     .. code-block:: python

@@ -994,6 +1072,7 @@ def dist(
                 lam,
                 shift,
                 dist=dist,
+                signature="(),()->()",
                 observed=[-1, -1, 0],
             )

@@ -1040,10 +1119,11 @@ def __new__(
         random: Optional[Callable] = None,
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
-        moment: Optional[Callable] = None,
         support_point: Optional[Callable] = None,
-        ndim_supp: int = 0,
+        # TODO: Deprecate ndim_supp / ndims_params in favor of signature?
+        ndim_supp: Optional[int] = None,
         ndims_params: Optional[Sequence[int]] = None,
+        signature: Optional[str] = None,
         dtype: str = "floatX",
         **kwargs,
     ):
@@ -1057,6 +1137,7 @@ def __new__(
         )
         dist_params = cls.parse_dist_params(dist_params)
         cls.check_valid_dist_random(dist, random, dist_params)
+        moment = kwargs.pop("moment", None)
         if moment is not None:
             warnings.warn(
                 "`moment` argument is deprecated. Use `support_point` instead.",
@@ -1073,6 +1154,8 @@ def __new__(
                 logcdf=logcdf,
                 support_point=support_point,
                 ndim_supp=ndim_supp,
+                ndims_params=ndims_params,
+                signature=signature,
                 **kwargs,
             )
         else:
@@ -1086,6 +1169,7 @@ def __new__(
                 support_point=support_point,
                 ndim_supp=ndim_supp,
                 ndims_params=ndims_params,
+                signature=signature,
                 dtype=dtype,
                 **kwargs,
             )
@@ -1099,8 +1183,9 @@ def dist(
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
         support_point: Optional[Callable] = None,
-        ndim_supp: int = 0,
+        ndim_supp: Optional[int] = None,
         ndims_params: Optional[Sequence[int]] = None,
+        signature: Optional[str] = None,
         dtype: str = "floatX",
         **kwargs,
     ):
@@ -1114,6 +1199,8 @@ def dist(
             logcdf=logcdf,
             support_point=support_point,
             ndim_supp=ndim_supp,
+            ndims_params=ndims_params,
+            signature=signature,
             **kwargs,
         )
         else:
@@ -1125,6 +1212,7 @@ def dist(
             support_point=support_point,
             ndim_supp=ndim_supp,
             ndims_params=ndims_params,
+            signature=signature,
             dtype=dtype,
             **kwargs,
         )
diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 46e13c6017..452cb471dd 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -296,10 +296,17 @@ def rv_op(cls, weights, *components, size=None):
         # Output mix_indexes rng update so that it can be updated in place
         mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

+        s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
+        if len(components) == 1:
+            comp_s = ",".join(("w", s)) if s else "w"  # the single stacked component has an extra leading "w" core dim
+            signature = f"(),(w),({comp_s})->({s})"
+        else:
+            comps_s = ",".join(f"({s})" for _ in components)
+            signature = f"(),(w),{comps_s}->({s})"
         mix_op = MarginalMixtureRV(
             inputs=[mix_indexes_rng_, weights_, *components_],
             outputs=[mix_indexes_rng_next_, mix_out_],
-            ndim_supp=components[0].owner.op.ndim_supp,
+            signature=signature,
         )

         # Create the actual MarginalMixture variable
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 8d980251ce..956bca276d 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -1163,6 +1163,8 @@ def rng_fn(self, rng, n, eta, D, size):
 # be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
 class _LKJCholeskyCovRV(SymbolicRandomVariable):
     default_output = 1
+    signature = "(),(),(),(n)->(),(n)"
+    ndim_supp = 1
     _print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")

     def update(self, node):
@@ -1218,7 +1220,6 @@ def rv_op(cls, n, eta, sd_dist, size=None):
         return _LKJCholeskyCovRV(
             inputs=[rng_, n_, eta_, sd_dist_],
             outputs=[next_rng_, lkjcov_],
-            ndim_supp=1,
         )(rng, n, eta, sd_dist)

@@ -2787,10 +2788,12 @@ def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
         for axis in range(n_zerosum_axes):
             zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

+        support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
+        signature = f"({support_str}),(),(s)->({support_str})"
         return ZeroSumNormalRV(
             inputs=[normal_dist_, sigma_, support_shape_],
-            outputs=[zerosum_rv_, support_shape_],
-            ndim_supp=n_zerosum_axes,
+            outputs=[zerosum_rv_],
+            signature=signature,
         )(normal_dist, sigma, support_shape)

diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index c100117555..5ff3948458 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -195,12 +195,17 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):
         # shape = (B, T, S)
         grw_ = pt.concatenate([init_dist_dimswapped_, innovation_dist_dimswapped_], axis=-ndim_supp)
         grw_ = pt.cumsum(grw_, axis=-ndim_supp)
+
+        innov_supp_dims = [f"d{i}" for i in range(dist_ndim_supp)]
+        innov_supp_str = ",".join(innov_supp_dims)
+        out_supp_str = ",".join(["t", *innov_supp_dims])
+        signature = f"({innov_supp_str}),({innov_supp_str}),(s)->({out_supp_str})"
         return RandomWalkRV(
             [init_dist_, innovation_dist_, steps_],
             # We pass steps_ through just so we can keep a reference to it, even though
             # it's no longer needed at this point
-            [grw_, steps_],
-            ndim_supp=ndim_supp,
+            [grw_],
+            signature=signature,
         )(init_dist, innovation_dist, steps)

@@ -417,6 +422,8 @@ def get_dists(
 class AutoRegressiveRV(SymbolicRandomVariable):
     """A placeholder used to specify a log-likelihood for an AR sub-graph."""

+    signature = "(o),(),(o),(s)->(),(t)"
+    ndim_supp = 1
     default_output = 1
     ar_order: int
     constant_term: bool
@@ -655,7 +662,6 @@ def step(*args):
             outputs=[noise_next_rng, ar_],
             ar_order=ar_order,
             constant_term=constant_term,
-            ndim_supp=1,
         )

         ar = ar_op(rhos, sigma, init_dist, steps)
@@ -719,6 +725,8 @@ class GARCH11RV(SymbolicRandomVariable):
     """A placeholder used to specify a GARCH11 graph."""

     default_output = 1
+    signature = "(),(),(),(),(),(s)->(),(t)"
+    ndim_supp = 1
     _print_name = ("GARCH11", "\\operatorname{GARCH11}")

     def update(self, node: Node):
@@ -825,7 +833,6 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
         garch11_op = GARCH11RV(
             inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
             outputs=[noise_next_rng, garch11_],
-            ndim_supp=1,
         )

         garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
@@ -881,6 +888,7 @@ class EulerMaruyamaRV(SymbolicRandomVariable):
     default_output = 1
     dt: float
     sde_fn: Callable
+    ndim_supp = 1
     _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

     def __init__(self, *args, dt, sde_fn, **kwargs):
@@ -1006,7 +1014,7 @@ def step(*prev_args):
             outputs=[noise_next_rng, sde_out_],
             dt=dt,
             sde_fn=sde_fn,
-            ndim_supp=1,
+            signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
         )

         eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index 4e2ca8143f..bb43063be9 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -744,6 +744,34 @@ def inner_dist(size=None):
             pm.logp(pm.LogNormal.dist(), 1.0).eval(),
         )

+    def test_signature(self):
+        def dist(p, size):
+            return -pm.Categorical.dist(p=p, size=size)
+
+        out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()")
+        # Size and updates are added automatically to the signature
+        assert out.owner.op.signature == "(),(p)->(),()"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [0, 1]
+
+        # When recreated internally, the whole signature may already be known
+        out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(),(p)->(),()")
+        assert out.owner.op.signature == "(),(p)->(),()"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [0, 1]
+
+        # A safe signature can be inferred from ndim_supp and ndims_params
+        out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[0, 1])
+        assert out.owner.op.signature == "(),(i10)->(),()"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [0, 1]
+
+        # Otherwise, by default we assume everything is scalar, even though it's wrong in this case
+        out = CustomDist.dist([0.25, 0.75], dist=dist)
+        assert out.owner.op.signature == "(),()->(),()"
+        assert out.owner.op.ndim_supp == 0
+        assert out.owner.op.ndims_params == [0, 0]
+

 class TestSymbolicRandomVariable:
     def test_inline(self):
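
The example below is not part of the patch; it is a minimal usage sketch of the new `signature` argument to `CustomDist`, following the shifted-Exponential example from the docstring above (the parameter values and observed data are illustrative only):

    import pymc as pm

    def dist(lam, shift, size):
        # A shifted Exponential built from existing PyMC distributions
        return pm.Exponential.dist(lam, size=size) + shift

    with pm.Model():
        # Scalar parameters and scalar draws, hence "(),()->()".
        # ndim_supp and ndims_params are inferred from the signature.
        x = pm.CustomDist(
            "x", 1.0, -1.0, dist=dist, signature="(),()->()", observed=[-1, -1, 0]
        )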
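A sketch of how a gufunc signature decomposes into `ndims_params` and `ndim_supp`, mirroring the parsing done in `SymbolicRandomVariable.__init__` above. Note that `_parse_gufunc_signature` is a private PyTensor helper; the import path is the one the patch uses and may change between PyTensor versions:

    from pytensor.tensor.utils import _parse_gufunc_signature

    # The _LKJCholeskyCovRV signature: (rng, n, eta, sd_dist) -> (rng, packed_chol)
    inputs_sig, outputs_sig = _parse_gufunc_signature("(),(),(),(n)->(),(n)")

    ndims_params = [len(sig) for sig in inputs_sig]           # [0, 0, 0, 1]
    ndim_supp = max(len(out_sig) for out_sig in outputs_sig)  # 1

    assert ndims_params == [0, 0, 0, 1]
    assert ndim_supp == 1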
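Finally, a sketch of the `batch_ndim` helper added above: the batch dimensions are whatever output dimensions remain after removing the core (support) dimensions declared by the signature. The shape used here is hypothetical:

    import pymc as pm

    # CensoredRV has signature "(),(),()->()", hence ndim_supp == 0
    x = pm.Censored.dist(pm.Normal.dist(), lower=-1, upper=1, shape=(3,))
    op = x.owner.op
    assert op.ndim_supp == 0
    assert op.batch_ndim(x.owner) == 1  # one output dim minus zero support dims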