diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 64def0c2..a4f5ffa4 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -335,6 +335,7 @@ class VariableFactory(Protocol): class PowerSumDistribution: """Create a distribution that is the sum of powers of a base distribution.""" + def __init__(self, distribution: VariableFactory, n: int): self.distribution = distribution self.n = n @@ -345,7 +346,12 @@ def dims(self): def create_variable(self, name: str) -> "TensorVariable": raw = self.distribution.create_variable(f"{name}_raw") - return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,) + return pm.Deterministic( + name, + pt.sum([raw**n for n in range(1, self.n + 1)], axis=0), + dims=self.dims, + ) + cubic = PowerSumDistribution(Prior("Normal"), n=3) samples = sample_prior(cubic) @@ -533,8 +539,10 @@ class Prior: from pymc_extras.prior import register_tensor_transform + def custom_transform(x): - return x ** 2 + return x**2 + register_tensor_transform("square", custom_transform)