From e71d1cb86fe1589720395664f5d3557f291810b1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Jun 2024 15:06:46 +0200 Subject: [PATCH] Make DiracDelta a SymbolicRV --- pymc/distributions/distribution.py | 32 ++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 79ae04ac11..e5c2e68684 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -1048,7 +1048,11 @@ def change_custom_dist_size(op, rv, new_size, expand): return new_rv - rngs, rngs_updates = zip(*dummy_updates_dict.items()) + if dummy_updates_dict: + rngs, rngs_updates = zip(*dummy_updates_dict.items()) + else: + rngs, rngs_updates = (), () + inputs = [*dummy_params, *rngs] outputs = [dummy_rv, *rngs_updates] signature = cls._infer_final_signature( @@ -1497,19 +1501,26 @@ def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False ) -class DiracDeltaRV(RandomVariable): +class DiracDeltaRV(SymbolicRandomVariable): name = "diracdelta" - signature = "()->()" + extended_signature = "[size],()->()" _print_name = ("DiracDelta", "\\operatorname{DiracDelta}") + def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool: + # Because the distribution does not have RNGs we have to prevent constant-folding + return False + @classmethod - def rng_fn(cls, rng, c, size=None): - if size is None: - return c.copy() - return np.full(size, c) + def rv_op(cls, c, *, size=None, rng=None): + size = normalize_size_param(size) + c = pt.as_tensor(c) + if rv_size_is_none(size): + out = c.copy() + else: + out = pt.full(size, c) -diracdelta = DiracDeltaRV() + return cls(inputs=[size, c], outputs=[out])(size, c) class DiracDelta(Discrete): @@ -1524,14 +1535,15 @@ class DiracDelta(Discrete): that use DiracDelta, such as Mixtures. """ - rv_op = diracdelta + rv_type = DiracDeltaRV + rv_op = DiracDeltaRV.rv_op @classmethod def dist(cls, c, *args, **kwargs): c = pt.as_tensor_variable(c) if c.dtype in continuous_types: c = floatX(c) - return super().dist([c], dtype=c.dtype, **kwargs) + return super().dist([c], **kwargs) def support_point(rv, size, c): if not rv_size_is_none(size):