Skip to content

Commit

Permalink
Implement uncensor and forecast_timeseries model transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 20, 2023
1 parent 103025b commit 96fbc30
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 0 deletions.
175 changes: 175 additions & 0 deletions pymc_experimental/model_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from pymc import DiracDelta
from pymc.distributions.censored import CensoredRV
from pymc.distributions.timeseries import AR, AutoRegressiveRV
from pymc.model import Model
from pytensor import shared
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.basic import get_var_by_name
from pytensor.graph.rewriting.basic import in2out

from pymc_experimental.utils.model_fgraph import (
ModelObservedRV,
ModelValuedVar,
fgraph_from_model,
model_free_rv,
model_from_fgraph,
model_named,
)

__all__ = (
"forecast_timeseries",
"uncensor",
)


@node_rewriter(tracks=[ModelValuedVar])
def uncensor_node_rewrite(fgraph, node):
"""Rewrite that replaces censored variables by uncensored ones"""

(
censored_rv,
value,
*dims,
) = node.inputs
if not isinstance(censored_rv.owner.op, CensoredRV):
return

model_rv = node.outputs[0]
base_rv = censored_rv.owner.inputs[0]
uncensored_rv = node.op.make_node(base_rv, value, *dims).default_output()
uncensored_rv.name = f"{model_rv.name}_uncensored"
return [uncensored_rv]


uncensor_rewrite = in2out(uncensor_node_rewrite)


def uncensor(model: Model) -> Model:
"""Replace censored variables in the model by uncensored equivalent.
Replaced variables have the same name as original ones with an additional "_uncensored" suffix.
.. code-block:: python
import arviz as az
import pymc as pm
from pymc_experimental.model_transform import uncensor
with pm.Model() as model:
x = pm.Normal("x")
dist_raw = pm.Normal.dist(x)
y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[-1, 0.5, 1, 1, 1])
idata = pm.sample()
with uncensor(model):
idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_uncensored"])
az.summary(idata_pp)
"""
fg = fgraph_from_model(model)

(_, nodes_changed, *_) = uncensor_rewrite.apply(fg)
if not nodes_changed:
raise RuntimeError("No censored variables were replaced by uncensored counterparts")

return model_from_fgraph(fg)


@node_rewriter(tracks=[ModelValuedVar])
def forecast_timeseries_node_rewrite(fgraph: FunctionGraph, node):
"""Rewrite that replaces timeseries variables by new ones starting at the last timepoint(s)."""

(
timeseries_rv,
value,
*dims,
) = node.inputs
if not isinstance(timeseries_rv.owner.op, AutoRegressiveRV):
return

forecast_steps = get_var_by_name(fgraph.inputs, "forecast_steps_")
if len(forecast_steps) != 1:
return False

forecast_steps = forecast_steps[0]

op = timeseries_rv.owner.op
model_rv = node.outputs[0]

# We cannot reference the variable we are planning to replace
# Or it will introduce circularities in the graph
# FIXME: This special logic shouldn't be needed for ObservedRVs
# but PyMC does not allow one to not resample observed.
# We hack around by conditioning on the value variable directly,
# even though that should not be part of the generative graph...
if isinstance(node.op, ModelObservedRV):
init_dist = DiracDelta.dist(value[..., -op.ar_order :])
else:
cloned_model_rv = model_rv.owner.clone().default_output()
fgraph.add_output(cloned_model_rv, import_missing=True)
init_dist = DiracDelta.dist(cloned_model_rv[..., -op.ar_order :])

if isinstance(timeseries_rv.owner.op, AutoRegressiveRV):
rhos, sigma, *_ = timeseries_rv.owner.inputs
new_timeseries_rv = AR.rv_op(
rhos=rhos,
sigma=sigma,
init_dist=init_dist,
steps=forecast_steps,
ar_order=op.ar_order,
constant_term=op.constant_term,
)

new_name = f"{model_rv.name}_forecast"
new_value = new_timeseries_rv.type(name=new_name)
new_timeseries_rv = model_free_rv(new_timeseries_rv, new_value, transform=None)
new_timeseries_rv.name = new_name

# Import new variables into fgraph (value and RNG)
fgraph.import_var(new_timeseries_rv, import_missing=True)

return [new_timeseries_rv]


forecast_timeseries_rewrite = in2out(forecast_timeseries_node_rewrite, ignore_newtrees=True)


def forecast_timeseries(model: Model, forecast_steps: int) -> Model:
"""Replace timeseries variables in the model by forecast that start at the last value.
Replaced variables have the same name as original ones with an additional "_forecast" suffix.
The function will fail if any variables with fixed static shape depend on the timeseries being replaced,
and forecast_steps differs from the original timeseries steps.
.. code-block:: python
import pymc as pm
from pymc_experimental.model_transform import forecast_timeseries
with pm.Model() as model:
rho = pm.Normal("rho")
sigma = pm.HalfNormal("sigma")
init_dist = pm.Normal.dist()
y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=np.zeros(100,))
idata = pm.sample()
forecast_model = forecast_timeseries(mode, forecast_steps=20)
with forecast_model:
idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_forecast"])
az.summary(idata_pp)
"""

fg = fgraph_from_model(model)

forecast_steps_sh = shared(forecast_steps, name="forecast_steps_")
forecast_steps_sh = model_named(forecast_steps_sh)
fg.add_output(forecast_steps_sh, import_missing=True)

(_, nodes_changed, *_) = forecast_timeseries_rewrite.apply(fg)
if not nodes_changed:
raise RuntimeError("No timeseries were replaced by forecast counterparts")

res = model_from_fgraph(fg)
return res
90 changes: 90 additions & 0 deletions pymc_experimental/tests/test_model_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import arviz as az
import numpy as np
import pymc as pm
import pytest

from pymc_experimental.model_transform import forecast_timeseries, uncensor


@pytest.mark.parametrize(
"transform, kwargs",
[
(uncensor, dict()),
(forecast_timeseries, dict(forecast_steps=20)),
],
)
def test_transform_error(transform, kwargs):
"""Test informative error is raised when the transform is not applicable to a model."""
with pm.Model() as model:
x = pm.Normal("x")
y = pm.Normal("y", x, observed=[0, 5, 10])

with pytest.raises(RuntimeError, match="No .* were replaced by .* counterparts"):
transform(model, **kwargs)


def test_uncensor():
with pm.Model() as model:
x = pm.Normal("x")
dist_raw = pm.Normal.dist(x)
y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[0, 5, 10])
det = pm.Deterministic("det", y * 2)

idata = az.from_dict({"x": np.zeros((2, 500))})

with uncensor(model):
pp = pm.sample_posterior_predictive(
idata,
var_names=["y_uncensored", "det"],
random_seed=18,
).posterior_predictive

assert np.any(np.abs(pp["y_uncensored"]) > 1)
np.testing.assert_allclose(pp["y_uncensored"] * 2, pp["det"])


@pytest.mark.parametrize("observed", (True, False))
@pytest.mark.parametrize("ar_order", (1, 2))
def test_forecast_timeseries_ar(observed, ar_order):
data_steps = 3
data = np.hstack((np.zeros(ar_order), (np.arange(data_steps) + 1) * 100.0))
with pm.Model() as model:
rho = pm.Normal("rho", shape=(ar_order,))
sigma = pm.HalfNormal("sigma")
init_dist = pm.Normal.dist(0, 1e-3)
y = pm.AR(
"y",
init_dist=init_dist,
rho=rho,
sigma=sigma,
observed=data if observed else None,
steps=data_steps,
)
det = pm.Deterministic("det", y * 2)

draws = (2, 50)
# These rhos mean that all steps will be data[-1] for ar_order > 1
idata_dict = {
"rho": np.full((*draws, ar_order), (0.1,) + (0,) * (ar_order - 1)),
"sigma": np.full(draws, 1e-5),
}
if observed:
idata = az.from_dict(idata_dict, observed_data={"y": data})
else:
idata_dict["y"] = np.full((*draws, len(data)), data)
idata = az.from_dict(idata_dict)

forecast_steps = 5
with forecast_timeseries(model, forecast_steps=forecast_steps):
pp = pm.sample_posterior_predictive(
idata,
var_names=["y_forecast", "det"],
random_seed=50,
).posterior_predictive

expected = data[-1] / np.logspace(0, forecast_steps, forecast_steps + 1)
expected = np.hstack((data[-ar_order:-1], expected))
np.testing.assert_allclose(
pp["y_forecast"].values, np.full((*draws, forecast_steps + ar_order), expected), rtol=0.01
)
np.testing.assert_allclose(pp["y_forecast"] * 2, pp["det"])

0 comments on commit 96fbc30

Please sign in to comment.