-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Open
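Description
Found by @lucianopaz.
The two `CustomDist`s below define the same AR(1) process with `pytensor.scan`. The only difference is that the scan step in `ar_dist2` also returns the innovation `eps_t` as an extra, non-recurrent output. In both cases the value returned to `CustomDist` is the AR sequence itself. Their element-wise log-probabilities should therefore be identical, but the final assertion fails: the last two logp terms of `ar_dist2` differ from those of `ar_dist1`.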
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates
# Number of AR time steps produced by the scan in each custom distribution;
# also used as the length of the observed data below.
steps: int = 4
def ar_dist1(rho, sigma, size):
    """Build an AR(1) sequence of length ``steps`` with ``pytensor.scan``.

    At every step a Normal innovation ``eps_t`` with scale ``sigma`` is drawn
    and the next value is ``x_t = x_{t-1} * rho + eps_t``, starting from a
    scalar zero. Only the recurrent state ``x_t`` is a scan output.

    ``size`` is accepted because ``pm.CustomDist`` passes it, but it is not
    used: the output length is fixed by the module-level ``steps``.
    """

    def _step(prev_x, coef, scale):
        innovation = pm.Normal.dist(sigma=scale)
        next_x = prev_x * coef + innovation
        # Return the default (RNG) updates collected from the step graph
        # alongside the output, so scan can propagate them across iterations.
        return next_x, collect_default_updates([next_x])

    recurrent_state = {"initial": pt.zeros(()), "taps": [-1]}
    ar_sequence, _updates = pytensor.scan(
        fn=_step,
        outputs_info=[recurrent_state],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
    )
    return ar_sequence
def ar_dist2(rho, sigma, size):
    """Build the same AR(1) sequence as ``ar_dist1``, with an extra scan output.

    The step is identical to ``ar_dist1`` (``x_t = x_{t-1} * rho + eps_t``,
    starting from a scalar zero), but it also returns the innovation
    ``eps_t`` as a second, non-recurrent scan output (``outputs_info`` entry
    ``None``). Only the AR sequence is returned, so its logp is expected to
    match ``ar_dist1``'s.

    ``size`` is accepted because ``pm.CustomDist`` passes it, but it is not
    used: the output length is fixed by the module-level ``steps``.

    NOTE(review): in the output recorded at the end of this script, the
    last two logp terms of this distribution differ from ``ar_dist1``'s. The
    recorded values match a Normal logp with mean 0 at every step instead of
    ``rho * x_{t-1}``. That presumably points to PyMC's scan logp inference
    when an extra output is present, not to this function — confirm upstream.
    """

    def _step(prev_x, coef, scale):
        innovation = pm.Normal.dist(sigma=scale)
        next_x = prev_x * coef + innovation
        step_outputs = [next_x, innovation]
        # RNG updates are collected from ``next_x`` only; it already depends
        # on ``innovation``.
        return step_outputs, collect_default_updates([next_x])

    recurrent_state = {"initial": pt.zeros(()), "taps": [-1]}
    scan_outputs, _updates = pytensor.scan(
        fn=_step,
        outputs_info=[recurrent_state, None],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
    )
    # Keep the AR sequence; drop the per-step innovations output.
    ar_sequence = scan_outputs[0]
    return ar_sequence
with pm.Model() as m:
    rho = 0.1
    sigma = 0.1
    observed = np.arange(steps)
    # Register both AR formulations against the same observed data. They
    # describe the same process, so their element-wise logp should agree.
    for dist_name, dist_fn in (("ar_dist1", ar_dist1), ("ar_dist2", ar_dist2)):
        pm.CustomDist(
            dist_name,
            rho,
            sigma,
            dist=dist_fn,
            observed=observed,
        )

# With sum=False the compiled function yields one logp term per variable
# (two here). It is evaluated at an empty point because every variable is
# observed.
logp1, logp2 = m.compile_logp(sum=False)({})
np.testing.assert_allclose(logp1, logp2)
# Output recorded when this reproducer was run:
"""
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 2 / 4 (50%)
Max absolute difference: 58.
Max relative difference: 0.12928641
x: array([ 1.383647, -48.616353, -179.116353, -390.616353])
y: array([ 1.383647, -48.616353, -198.616353, -448.616353])
"""