In [1]:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc_extras.statespace.filters import StandardFilter
from tests.statespace.utilities.test_helpers import make_test_inputs
from pytensor.graph.replace import vectorize_graph
from importlib import reload
import pymc_extras.statespace.filters.distributions as pmss_dist
from pymc_extras.statespace.filters.distributions import SequenceMvNormal
import pymc as pm

In [2]:
seed = sum(map(ord, "batched-kf"))
rng = np.random.default_rng(seed)

In [3]:
def create_batch_inputs(batch_size, p=1, m=5, r=1, n=10, rng=rng):
    """
    Create batched inputs for testing.

    Parameters
    ----------
    batch_size : int
        Number of batches to create
    p : int
        First dimension parameter
    m : int
        Second dimension parameter
    r : int
        Third dimension parameter
    n : int
        Fourth dimension parameter
    rng : numpy.random.Generator
        Random number generator

    Returns
    -------
    list
        List of stacked inputs for each batch
    """
    # Create individual inputs for each batch
    np_batch_inputs = []
    for i in range(batch_size):
        inputs = make_test_inputs(p, m, r, n, rng)
        np_batch_inputs.append(inputs)

    return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)]

In [4]:
# Create batch inputs with batch size 3
np_batch_inputs = create_batch_inputs(3)
np_batch_inputs[0].shape

(3, 10, 1)

In [5]:
p, m, r, n = 1, 5, 1, 10
inputs = [pt.as_tensor(x).type() for x in make_test_inputs(p, m, r, n, rng)]

In [6]:
kf = StandardFilter()
kf_outputs = kf.build_graph(*inputs)

In [7]:
batched_inputs = [pt.tensor(shape=(None, *x.type.shape)) for x in inputs]
vec_subs = dict(zip(inputs, batched_inputs))
bacthed_kf_outputs = vectorize_graph(kf_outputs, vec_subs)

In [8]:
kf_outputs

[filtered_states,
 predicted_states,
 observed_states,
 filtered_covariances,
 predicted_covariances,
 observed_covariances,
 loglike_obs]

In [9]:
mu = bacthed_kf_outputs[1]
cov = bacthed_kf_outputs[4]
logp = bacthed_kf_outputs[-1]

In [10]:
mu.type.shape

(None, 10, 5)

In [11]:
pmss_dist = reload(pmss_dist)

In [12]:
mv_outputs = pmss_dist.SequenceMvNormal.dist(mus=mu, covs=cov, logp=logp)

mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)
mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)
mvn_seq.type.shape: (None, None, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)
mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)
mvn_seq.type.shape: (None, None, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)


In [13]:
np_batch_inputs = create_batch_inputs(3)

In [14]:
np_batch_inputs[0] = rng.normal(size=(3, 10, 1))

In [15]:
f_test = pytensor.function(batched_inputs, mv_outputs)
f_test(*np_batch_inputs).shape

(3, 10, 5)

In [16]:
f_mv = pytensor.function(batched_inputs, pm.logp(mv_outputs, batched_inputs[0]))

(None, 10, 1) (None, 10, 5) (None, 10, 5, 5)


In [17]:
f_mv(*np_batch_inputs).shape

(3, 10)

In [18]:
f = pytensor.function(batched_inputs, bacthed_kf_outputs)

In [19]:
for s in [1, 3, 10]:
    np_batch_inputs = create_batch_inputs(s)
    %timeit outputs = f(*np_batch_inputs)

675 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.64 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
5.28 ms ± 424 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother

In [7]:
def build_fk(data, a0, P0, c, d, T, Z, R, H, Q):
    kf = StandardFilter()
    kf_outputs = kf.build_graph(data, a0, P0, c, d, T, Z, R, H, Q)

    ks = KalmanSmoother()
    ks_outputs = ks.build_graph(T, R, Q, kf_outputs[0], kf_outputs[3])

    return (*kf_outputs, *ks_outputs)

In [15]:
signature = "(t, o), (s), (s, s), (s), (o), (s, s), (o, s), (s, p), (o, o), (p, p) -> (t, s), (t, s), (t, o), (t, s, s), (t, s, s), (t, o, o), (t), (t, s), (t, s, s)"

In [17]:
pt.vectorize(build_fk, signature=signature)(*[pt.as_tensor(x) for x in np_batch_inputs])[
    0
].eval().shape

Join [id A]
 ├─ 0 [id B]
 ├─ Subtensor{::step} [id C]
 │  ├─ Subtensor{start:} [id D]
 │  │  ├─ Scan{kalman_smoother, while_loop=False, inplace=none}.0 [id E]
 │  │  │  ├─ Minimum [id F]
 │  │  │  │  ├─ Subtensor{i} [id G]
 │  │  │  │  │  ├─ Shape [id H]
 │  │  │  │  │  │  └─ Subtensor{::step} [id I]
 │  │  │  │  │  │     ├─ Subtensor{start:} [id J]
 │  │  │  │  │  │     │  ├─ Subtensor{:stop} [id K]
 │  │  │  │  │  │     │  │  ├─ SpecifyShape [id L] 'filtered_states'
 │  │  │  │  │  │     │  │  │  ├─ Scan{forward_kalman_pass, while_loop=False, inplace=none}.2 [id M]
 │  │  │  │  │  │     │  │  │  │  ├─ Subtensor{i} [id N]
 │  │  │  │  │  │     │  │  │  │  │  ├─ Shape [id O]
 │  │  │  │  │  │     │  │  │  │  │  │  └─ Subtensor{start:} [id P]
 │  │  │  │  │  │     │  │  │  │  │  │     ├─ <Matrix(float64, shape=(10, 1))> [id Q]
 │  │  │  │  │  │     │  │  │  │  │  │     └─ 0 [id R]
 │  │  │  │  │  │     │  │  │  │  │  └─ 0 [id S]
 │  │  │  │  │  │     │  │  │  │  ├─ Subtensor{:stop} [id 

(3, 10, 5)

In [15]:
def make_signature(inputs, outputs):
    states = "s"
    obs = "p"
    exog = "r"
    time = "t"

    matrix_to_shape = {
        "data": (time, obs),
        "a0": (states,),
        "P0": (states, states),
        "c": (states,),
        "d": (obs,),
        "T": (states, states),
        "Z": (obs, states),
        "R": (states, exog),
        "H": (obs, obs),
        "Q": (exog, exog),
        "filtered_states": (time, states),
        "filtered_covariances": (time, states, states),
        "predicted_states": (time, states),
        "predicted_covariances": (time, states, states),
        "observed_states": (time, obs),
        "observed_covariances": (time, obs, obs),
        "smoothed_states": (time, states),
        "smoothed_covariances": (time, states, states),
        "loglike_obs": (time,),
    }
    input_shapes = []
    output_shapes = []

    for matrix in inputs:
        name = matrix.name
        input_shapes.append(matrix_to_shape[name])

    for matrix in outputs:
        print(matrix, matrix.name)
        name = matrix.name
        output_shapes.append(matrix_to_shape[name])

    input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes])
    output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes])

    return f"{input_signature} -> {output_signature}"

In [9]:
floatX = "float64"
data = pt.tensor(name="data", dtype=floatX, shape=(None, None))
a0 = pt.vector(name="a0", dtype=floatX)
P0 = pt.matrix(name="P0", dtype=floatX)
c = pt.vector(name="c", dtype=floatX)
d = pt.vector(name="d", dtype=floatX)
Q = pt.tensor(name="Q", dtype=floatX, shape=(None, None, None))
H = pt.tensor(name="H", dtype=floatX, shape=(None, None, None))
T = pt.tensor(name="T", dtype=floatX, shape=(None, None, None))
R = pt.tensor(name="R", dtype=floatX, shape=(None, None, None))
Z = pt.tensor(name="Z", dtype=floatX, shape=(None, None, None))

inputs = [data, a0, P0, c, d, T, Z, R, H, Q]

In [10]:
outputs = build_fk(*inputs)

In [16]:
make_signature(inputs, outputs)

filtered_states filtered_states
predicted_states predicted_states
observed_states observed_states
filtered_covariances filtered_covariances
predicted_covariances predicted_covariances
observed_covariances observed_covariances
loglike_obs loglike_obs
smoothed_states smoothed_states
smoothed_covariances smoothed_covariances


'(t,p),(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r) -> (t,s),(t,s),(t,p),(t,s,s),(t,s,s),(t,p,p),(t),(t,s),(t,s,s)'

In [None]:
signature = "(t, o), (s), (s, s), (s), (o), (s, s), (o, s), (s, p), (o, o), (p, p) -> (t, s), (t, s), (t, o), (t, s, s), (t, s, s), (t, o, o), (t), (t, s), (t, s, s)"

In [18]:
pt.vectorize(build_fk, signature=make_signature(inputs, outputs))(
    *[pt.as_tensor(x) for x in np_batch_inputs]
)[0].eval().shape

filtered_states filtered_states
predicted_states predicted_states
observed_states observed_states
filtered_covariances filtered_covariances
predicted_covariances predicted_covariances
observed_covariances observed_covariances
loglike_obs loglike_obs
smoothed_states smoothed_states
smoothed_covariances smoothed_covariances


(3, 10, 5)

In [19]:
kf = StandardFilter()
ks = KalmanSmoother()

In [20]:
kf_outputs = kf.build_graph(*inputs)
kf_signature = make_signature(inputs, kf_outputs)

filtered_states filtered_states
predicted_states predicted_states
observed_states observed_states
filtered_covariances filtered_covariances
predicted_covariances predicted_covariances
observed_covariances observed_covariances
loglike_obs loglike_obs


In [21]:
make_batched_kf = pt.vectorize(kf.build_graph, signature=kf_signature)
ks_inputs = [T, R, Q, kf_outputs[0], kf_outputs[3]]
ks_outputs = ks.build_graph(*ks_inputs)

In [22]:
ks_signature = make_signature(ks_inputs, ks_outputs)
make_batched_ks = pt.vectorize(ks.build_graph, signature=ks_signature)

smoothed_states smoothed_states
smoothed_covariances smoothed_covariances


In [25]:
batched_kf_outputs = make_batched_kf(*[pt.as_tensor(x) for x in np_batch_inputs])

In [26]:
data, a0, P0, c, d, T, Z, R, H, Q = np_batch_inputs

In [30]:
batched_ks_outputs = make_batched_ks(
    *[pt.as_tensor_variable(x) for x in [T, R, Q, batched_kf_outputs[0], batched_kf_outputs[3]]]
)

In [31]:
batched_ks_outputs[0].eval().shape

(3, 10, 5)

# Test example: French Presidents' Approval

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import pymc as pm
import pymc_extras.statespace as pmss
import pytensor.tensor as pt
import xarray as xr

In [3]:
data = pd.read_csv("popularite.csv", index_col=0).reset_index().rename(columns={"index": "date"})
data = data[data.president.isin(["chirac2", "sarkozy", "hollande", "macron", "macron2"])]
data["date"] = pd.to_datetime(data["date"])
data[["day", "month", "year"]] = data["date"].apply(lambda x: pd.Series([x.day, x.month, x.year]))
data

Unnamed: 0,date,president,sondage,samplesize,method,approve_pr,disapprove_pr,day,month,year
862,2002-05-15,chirac2,Ifop,924,phone,51.0,44.0,15,5,2002
863,2002-05-20,chirac2,Kantar,972,face to face,50.0,48.0,20,5,2002
864,2002-05-23,chirac2,BVA,1054,phone,52.0,37.0,23,5,2002
865,2002-05-26,chirac2,Ipsos,907,phone,48.0,48.0,26,5,2002
866,2002-06-16,chirac2,Ifop,974,phone,49.0,43.0,16,6,2002
...,...,...,...,...,...,...,...,...,...,...
2030,2022-05-21,macron2,Ifop,1946,phone&internet,41.0,58.0,21,5,2022
2031,2022-05-25,macron2,Odoxa,1005,internet,44.0,56.0,25,5,2022
2032,2022-05-26,macron2,Harris,1002,internet,49.0,51.0,26,5,2022
2033,2022-05-30,macron2,Kantar,1000,internet,37.0,56.0,30,5,2022


## Just for testing: average over months

In [4]:
agg = data.groupby(["president", "year", "month"]).mean(numeric_only=True).reset_index()
agg["date"] = pd.to_datetime(agg[["year", "month"]].assign(DAY=1))
agg["month_id"] = agg.groupby(["president"]).cumcount().to_numpy()
agg

Unnamed: 0,president,year,month,samplesize,approve_pr,disapprove_pr,day,date,month_id
0,chirac2,2002,5,964.250000,50.250000,44.250000,21.000000,2002-05-01,0
1,chirac2,2002,6,970.000000,50.500000,42.500000,20.000000,2002-06-01,1
2,chirac2,2002,7,947.333333,53.333333,40.666667,21.000000,2002-07-01,2
3,chirac2,2002,8,1028.000000,52.000000,41.666667,20.333333,2002-08-01,3
4,chirac2,2002,9,1017.500000,52.500000,42.000000,21.250000,2002-09-01,4
...,...,...,...,...,...,...,...,...,...
238,sarkozy,2011,12,990.666667,33.000000,65.000000,20.333333,2011-12-01,55
239,sarkozy,2012,1,1033.500000,31.750000,66.500000,19.750000,2012-01-01,56
240,sarkozy,2012,2,1000.500000,31.750000,66.500000,18.250000,2012-02-01,57
241,sarkozy,2012,3,1022.500000,35.000000,63.000000,20.250000,2012-03-01,58


In [5]:
presidents = agg.president.unique()

In [6]:
mod = pmss.structural.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += pmss.structural.AutoregressiveComponent(order=1)
mod += pmss.structural.MeasurementError(name="obs")

In [21]:
ss_mod = pmss.BayesianSARIMA(order=(3, 0, 0), batch_coords={"president": presidents})

[autoreload of cutils_ext failed: Traceback (most recent call last):
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/site-packages/IPython/extensions/autoreload.py", line 283, in check
    superreload(m, reload, self.old_objects)
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/site-packages/IPython/extensions/autoreload.py", line 483, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/importlib/__init__.py", line 130, in reload
    raise ModuleNotFoundError(f"spec not found for the module {name!r}", name=name)
ModuleNotFoundError: spec not found for the module 'cutils_ext'
]


In [7]:
ss_mod = mod.build(
    name="president",
    batch_coords={"president": presidents},  # this is gonna be leftmost dimension
)

In [8]:
ss_array = (
    agg.set_index(["president", "month_id"])["approve_pr"].unstack("month_id").to_numpy()[..., None]
)  # (president, timesteps, obs_dim)

In [9]:
initial_trend_dims, sigma_trend_dims, ar_param_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

In [16]:
coords

{'trend_state': ['level', 'trend'],
 'trend_shock': ['trend'],
 'ar_lag': [1],
 'state': ['level', 'trend', 'L1.data'],
 'state_aux': ['level', 'trend', 'L1.data'],
 'observed_state': ['president'],
 'observed_state_aux': ['president'],
 'shock': ['trend', 'AutoRegressive_innovation'],
 'shock_aux': ['trend', 'AutoRegressive_innovation']}

In [None]:
with pm.Model(coords=coords | ss_mod.batch_coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5, dims="president")
    P0 = pm.Deterministic(
        "P0", pt.eye(ss_mod.k_states)[None] * P0_diag[..., None, None], dims=("president", *P0_dims)
    )

    initial_trend = pm.Normal("initial_trend", dims=("president", *initial_trend_dims))
    ar_params = pm.Beta("ar_params", alpha=3, beta=3, dims=("president", *ar_param_dims))

    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=("president", *sigma_trend_dims))
    sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, dims="president")
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, dims="president")

    ss_mod.build_statespace_graph(ss_array, mode="JAX")
    # idata = pm.sample_prior_predictive()
model_1.to_graphviz()

  raise ValueError(


mus_.type.shape: (None, None, 1), covs_.type.shape: (None, None, 1, 1)
mus.type.shape: (None, None, 1), covs.type.shape: (None, None, 1, 1)
mvn_seq.type.shape: (None, None, 1)
mvn_seq.type.shape: (None, None, 1)
mvn_seq.type.shape: (None, None, 1)
mvn_seq.type.shape: (None, None, 1)
mus_.type.shape: (None, None, 1), covs_.type.shape: (None, None, 1, 1)
mus.type.shape: (None, None, 1), covs.type.shape: (None, None, 1, 1)
mvn_seq.type.shape: (None, None, 1)
mvn_seq.type.shape: (None, None, 1)
mvn_seq.type.shape: (None, None, 1)
mvn_seq.type.shape: (None, None, 1)


ValueError: shape mismatch: value array of shape (5,) could not be broadcast to indexing result of shape (5,1)
Apply node that caused the error: AdvancedSetSubtensor(Alloc.0, Pow.0, SliceConstant{None, None, None}, 0, [0], [0])
Toposort index: 4
Inputs types: [TensorType(float64, shape=(None, 1, 1, 1)), TensorType(float64, shape=(None,)), <pytensor.tensor.type_other.SliceType object at 0x17ee98740>, TensorType(int64, shape=()), TensorType(int64, shape=(1,)), TensorType(int64, shape=(1,))]
Inputs shapes: [(5, 1, 1, 1), (5,), 'No shapes', (), (1,), (1,)]
Inputs strides: [(8, 8, 8, 8), (8,), 'No strides', (), (8,), (8,)]
Inputs values: [array([[[[0.]]],


       [[[0.]]],


       [[[0.]]],


       [[[0.]]],


       [[[0.]]]]), array([0.08827229, 0.05365894, 1.10456341, 0.32252835, 0.00666647]), slice(None, None, None), array(0), array([0]), array([0])]
Outputs clients: [[Squeeze{axis=1}(AdvancedSetSubtensor.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3607, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3667, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/004t4lgj4_1bc491p3qg41tc0000gp/T/ipykernel_87868/2589321262.py", line 17, in <module>
    ss_mod.build_statespace_graph(ss_array, mode="JAX")
  File "<string>", line 49, in build_statespace_graph
  File "/Users/aandorra/open-source/pymc-extras/pymc_extras/statespace/core/statespace.py", line 734, in _insert_random_variables
    self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict)
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/site-packages/pytensor/graph/replace.py", line 301, in vectorize_graph
    vect_node = vectorize_node(node, *vect_inputs)
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/site-packages/pytensor/graph/replace.py", line 217, in vectorize_node
    return _vectorize_node(op, node, *batched_inputs)
  File "/Users/aandorra/miniforge3/envs/pymc-extras-test/lib/python3.12/functools.py", line 912, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

## Real use case

In [None]:
data["month_id"] = np.hstack(
    [
        pd.Categorical(data[data.president == president]["date"].dt.to_period("M")).codes
        for president in data.president.unique()
    ]
)
months = np.arange(max(data["month_id"]) + 1)

In [83]:
data

Unnamed: 0,date,president,sondage,samplesize,method,approve_pr,disapprove_pr,day,month,year,month_id
862,2002-05-15,chirac2,Ifop,924,phone,51.0,44.0,15,5,2002,0
863,2002-05-20,chirac2,Kantar,972,face to face,50.0,48.0,20,5,2002,0
864,2002-05-23,chirac2,BVA,1054,phone,52.0,37.0,23,5,2002,0
865,2002-05-26,chirac2,Ipsos,907,phone,48.0,48.0,26,5,2002,0
866,2002-06-16,chirac2,Ifop,974,phone,49.0,43.0,16,6,2002,1
...,...,...,...,...,...,...,...,...,...,...,...
2030,2022-05-21,macron2,Ifop,1946,phone&internet,41.0,58.0,21,5,2022,1
2031,2022-05-25,macron2,Odoxa,1005,internet,44.0,56.0,25,5,2022,1
2032,2022-05-26,macron2,Harris,1002,internet,49.0,51.0,26,5,2022,1
2033,2022-05-30,macron2,Kantar,1000,internet,37.0,56.0,30,5,2022,1


In [None]:
COORDS = {
    "month": months,
    # each observation is uniquely identified by (pollster, field_date):
    "observation": data.set_index(["sondage", "date"]).index,
}

In [None]:
mod = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += st.AutoregressiveComponent(order=1)
mod += st.MeasurementError(name="obs")

In [None]:
ss_mod = mod.build(
    name="nile",
    batch_coords={"president": presidents},  # this is gonna be leftmost dimension
)

In [None]:
initial_trend_dims, sigma_trend_dims, ar_param_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

In [None]:
with pm.Model(coords=coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5, dims="president")
    P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states) * P0_diag, dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=initial_trend_dims)
    ar_params = pm.Beta("ar_params", alpha=3, beta=3, dims=ar_param_dims)

    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=sigma_trend_dims)
    sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, dims="president")
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, dims="president")

    ss_mod.build_statespace_graph(nile, mode="JAX")
    idata = pm.sample(**sampler_kwargs())