Skip to content

Commit

Permalink
Make DiracDelta a SymbolicRV
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 21, 2024
1 parent 0d6caed commit a6fe6eb
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,11 @@ def change_custom_dist_size(op, rv, new_size, expand):

return new_rv

rngs, rngs_updates = zip(*dummy_updates_dict.items())
if dummy_updates_dict:
rngs, rngs_updates = zip(*dummy_updates_dict.items())
else:
rngs, rngs_updates = (), ()

inputs = [*dummy_params, *rngs]
outputs = [dummy_rv, *rngs_updates]
signature = cls._infer_final_signature(
Expand Down Expand Up @@ -1497,19 +1501,26 @@ def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False
)


class DiracDeltaRV(RandomVariable):
class DiracDeltaRV(SymbolicRandomVariable):
name = "diracdelta"
signature = "()->()"
extended_signature = "[size],()->()"
_print_name = ("DiracDelta", "\\operatorname{DiracDelta}")

def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
# Because the distribution does not have RNGs we have to prevent constant-folding
return False

@classmethod
def rng_fn(cls, rng, c, size=None):
if size is None:
return c.copy()
return np.full(size, c)
def rv_op(cls, c, *, size=None, rng=None):
size = normalize_size_param(size)
c = pt.as_tensor(c)

if rv_size_is_none(size):
out = c.copy()
else:
out = pt.full(size, c)

diracdelta = DiracDeltaRV()
return cls(inputs=[size, c], outputs=[out])(size, c)


class DiracDelta(Discrete):
Expand All @@ -1524,14 +1535,15 @@ class DiracDelta(Discrete):
that use DiracDelta, such as Mixtures.
"""

rv_op = diracdelta
rv_type = DiracDeltaRV
rv_op = DiracDeltaRV.rv_op

@classmethod
def dist(cls, c, *args, **kwargs):
c = pt.as_tensor_variable(c)
if c.dtype in continuous_types:
c = floatX(c)
return super().dist([c], dtype=c.dtype, **kwargs)
return super().dist([c], **kwargs)

def support_point(rv, size, c):
if not rv_size_is_none(size):
Expand Down

0 comments on commit a6fe6eb

Please sign in to comment.