Skip to content

Scan logprob gives incorrect results when unvalued stochastic outputs are returned #6909

@ricardoV94

Description

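Found by @lucianopaz. Both custom distributions below define the same AR(1) process. The only difference is that the scan step in `ar_dist2` also returns the innovation `eps_t` as an extra, unvalued output. Their elementwise logps should therefore be identical, but they differ: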

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import collect_default_updates

# Number of scan iterations, i.e. the length of each AR series; it is also
# used to build the observed data in the model below.
steps = 4

def ar_dist1(rho, sigma, size):
    """Build an AR(1) series of length ``steps`` with ``pytensor.scan``.

    Each step draws a Normal innovation and adds it to ``rho`` times the
    previous state. Only the AR state is returned from the step, so the
    innovation stays internal to the scan.

    ``size`` is accepted to match the ``pm.CustomDist`` ``dist`` signature
    but is not used; the length is fixed by the module-level ``steps``.
    """

    def step(prev, coef, scale):
        # Fresh innovation for this iteration.
        innovation = pm.Normal.dist(sigma=scale)
        current = prev * coef + innovation
        # The collected RNG updates are handed back to scan, presumably so
        # the random state advances between iterations.
        return current, collect_default_updates([current])

    initial_state = {"initial": pt.zeros(()), "taps": [-1]}
    trajectory, _updates = pytensor.scan(
        fn=step,
        outputs_info=[initial_state],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
    )
    return trajectory


def ar_dist2(rho, sigma, size):
    """Build the same AR(1) series as ``ar_dist1``, with an extra scan output.

    The step returns the AR state and also the Normal innovation itself.
    The innovation has no value attached (its ``outputs_info`` entry is
    ``None``) and is discarded by the caller. Only the AR state is returned.
    Returning that extra unvalued stochastic output is what makes the logp
    diverge from ``ar_dist1`` in this report.

    ``size`` is accepted to match the ``pm.CustomDist`` ``dist`` signature
    but is not used; the length is fixed by the module-level ``steps``.
    """

    def step(prev, coef, scale):
        innovation = pm.Normal.dist(sigma=scale)
        current = prev * coef + innovation
        # Second output: the raw innovation, returned but never valued.
        return [current, innovation], collect_default_updates([current])

    outputs_info = [{"initial": pt.zeros(()), "taps": [-1]}, None]
    (trajectory, _innovations), _updates = pytensor.scan(
        fn=step,
        outputs_info=outputs_info,
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
    )
    return trajectory


# Two observed CustomDists that should have identical logps: ar_dist1 and
# ar_dist2 build the same AR process and differ only in the extra unvalued
# `eps_t` output returned by ar_dist2's scan step.
with pm.Model() as m:
    rho = 0.1
    sigma = 0.1
    # Both variables get the same observations, so their elementwise logps
    # can be compared directly.
    observed = np.arange(steps)

    pm.CustomDist(
        "ar_dist1",
        rho,
        sigma,
        dist=ar_dist1,
        observed=observed,
    )

    pm.CustomDist(
        "ar_dist2",
        rho,
        sigma,
        dist=ar_dist2,
        observed=observed,
    )

# Unsummed logp terms (one per variable), evaluated at an empty point because
# both variables are observed.
logp1, logp2 = m.compile_logp(sum=False)({})
# Expected to pass. Per this report it fails: the first two elements match
# and the later ones diverge, as recorded in the string literal below.
np.testing.assert_allclose(logp1, logp2)
"""
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 2 / 4 (50%)
Max absolute difference: 58.
Max relative difference: 0.12928641
 x: array([   1.383647,  -48.616353, -179.116353, -390.616353])
 y: array([   1.383647,  -48.616353, -198.616353, -448.616353])
"""

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions