In [1]:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc_extras.statespace.filters import StandardFilter
from tests.statespace.utilities.test_helpers import make_test_inputs
from pytensor.graph.replace import vectorize_graph
from importlib import reload
import pymc_extras.statespace.filters.distributions as pmss_dist
from pymc_extras.statespace.filters.distributions import SequenceMvNormal
import pymc as pm

In [2]:
seed = sum(map(ord, "batched-kf"))
rng = np.random.default_rng(seed)

In [3]:
def create_batch_inputs(batch_size, p=1, m=5, r=1, n=10, rng=rng):
    """
    Create batched inputs for testing.

    Parameters
    ----------
    batch_size : int
        Number of batches to create
    p : int
        First dimension parameter
    m : int
        Second dimension parameter
    r : int
        Third dimension parameter
    n : int
        Fourth dimension parameter
    rng : numpy.random.Generator
        Random number generator

    Returns
    -------
    list
        List of stacked inputs for each batch
    """
    # Create individual inputs for each batch
    np_batch_inputs = []
    for i in range(batch_size):
        inputs = make_test_inputs(p, m, r, n, rng)
        np_batch_inputs.append(inputs)

    return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)]

In [4]:
# Create batch inputs with batch size 3
np_batch_inputs = create_batch_inputs(3)
np_batch_inputs[0].shape

(3, 10, 1)

In [5]:
p, m, r, n = 1, 5, 1, 10
inputs = [pt.as_tensor(x).type() for x in make_test_inputs(p, m, r, n, rng)]

In [6]:
kf = StandardFilter()
kf_outputs = kf.build_graph(*inputs)

In [7]:
batched_inputs = [pt.tensor(shape=(None, *x.type.shape)) for x in inputs]
vec_subs = dict(zip(inputs, batched_inputs))
bacthed_kf_outputs = vectorize_graph(kf_outputs, vec_subs)

In [8]:
kf_outputs

[filtered_states,
 predicted_states,
 observed_states,
 filtered_covariances,
 predicted_covariances,
 observed_covariances,
 loglike_obs]

In [9]:
mu = bacthed_kf_outputs[1]
cov = bacthed_kf_outputs[4]
logp = bacthed_kf_outputs[-1]

In [10]:
mu.type.shape

(None, 10, 5)

In [11]:
pmss_dist = reload(pmss_dist)

In [12]:
mv_outputs = pmss_dist.SequenceMvNormal.dist(mus=mu, covs=cov, logp=logp)

mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)
mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)
mvn_seq.type.shape: (None, None, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)
mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)
mvn_seq.type.shape: (None, None, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)


In [13]:
np_batch_inputs = create_batch_inputs(3)

In [14]:
np_batch_inputs[0] = rng.normal(size=(3, 10, 1))

In [15]:
f_test = pytensor.function(batched_inputs, mv_outputs)
f_test(*np_batch_inputs).shape

(3, 10, 5)

In [16]:
f_mv = pytensor.function(batched_inputs, pm.logp(mv_outputs, batched_inputs[0]))

(None, 10, 1) (None, 10, 5) (None, 10, 5, 5)


In [17]:
f_mv(*np_batch_inputs).shape

(3, 10)

In [18]:
f = pytensor.function(batched_inputs, bacthed_kf_outputs)

In [19]:
for s in [1, 3, 10]:
    np_batch_inputs = create_batch_inputs(s)
    %timeit outputs = f(*np_batch_inputs)

675 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.64 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
5.28 ms ± 424 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother

In [7]:
def build_fk(data, a0, P0, c, d, T, Z, R, H, Q):
    kf = StandardFilter()
    kf_outputs = kf.build_graph(data, a0, P0, c, d, T, Z, R, H, Q)

    ks = KalmanSmoother()
    ks_outputs = ks.build_graph(T, R, Q, kf_outputs[0], kf_outputs[3])

    return (*kf_outputs, *ks_outputs)

In [15]:
signature = "(t, o), (s), (s, s), (s), (o), (s, s), (o, s), (s, p), (o, o), (p, p) -> (t, s), (t, s), (t, o), (t, s, s), (t, s, s), (t, o, o), (t), (t, s), (t, s, s)"

In [17]:
pt.vectorize(build_fk, signature=signature)(*[pt.as_tensor(x) for x in np_batch_inputs])[
    0
].eval().shape

Join [id A]
 ├─ 0 [id B]
 ├─ Subtensor{::step} [id C]
 │  ├─ Subtensor{start:} [id D]
 │  │  ├─ Scan{kalman_smoother, while_loop=False, inplace=none}.0 [id E]
 │  │  │  ├─ Minimum [id F]
 │  │  │  │  ├─ Subtensor{i} [id G]
 │  │  │  │  │  ├─ Shape [id H]
 │  │  │  │  │  │  └─ Subtensor{::step} [id I]
 │  │  │  │  │  │     ├─ Subtensor{start:} [id J]
 │  │  │  │  │  │     │  ├─ Subtensor{:stop} [id K]
 │  │  │  │  │  │     │  │  ├─ SpecifyShape [id L] 'filtered_states'
 │  │  │  │  │  │     │  │  │  ├─ Scan{forward_kalman_pass, while_loop=False, inplace=none}.2 [id M]
 │  │  │  │  │  │     │  │  │  │  ├─ Subtensor{i} [id N]
 │  │  │  │  │  │     │  │  │  │  │  ├─ Shape [id O]
 │  │  │  │  │  │     │  │  │  │  │  │  └─ Subtensor{start:} [id P]
 │  │  │  │  │  │     │  │  │  │  │  │     ├─ <Matrix(float64, shape=(10, 1))> [id Q]
 │  │  │  │  │  │     │  │  │  │  │  │     └─ 0 [id R]
 │  │  │  │  │  │     │  │  │  │  │  └─ 0 [id S]
 │  │  │  │  │  │     │  │  │  │  ├─ Subtensor{:stop} [id 

(3, 10, 5)

In [15]:
def make_signature(inputs, outputs):
    states = "s"
    obs = "p"
    exog = "r"
    time = "t"

    matrix_to_shape = {
        "data": (time, obs),
        "a0": (states,),
        "P0": (states, states),
        "c": (states,),
        "d": (obs,),
        "T": (states, states),
        "Z": (obs, states),
        "R": (states, exog),
        "H": (obs, obs),
        "Q": (exog, exog),
        "filtered_states": (time, states),
        "filtered_covariances": (time, states, states),
        "predicted_states": (time, states),
        "predicted_covariances": (time, states, states),
        "observed_states": (time, obs),
        "observed_covariances": (time, obs, obs),
        "smoothed_states": (time, states),
        "smoothed_covariances": (time, states, states),
        "loglike_obs": (time,),
    }
    input_shapes = []
    output_shapes = []

    for matrix in inputs:
        name = matrix.name
        input_shapes.append(matrix_to_shape[name])

    for matrix in outputs:
        print(matrix, matrix.name)
        name = matrix.name
        output_shapes.append(matrix_to_shape[name])

    input_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in input_shapes])
    output_signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in output_shapes])

    return f"{input_signature} -> {output_signature}"

In [9]:
floatX = "float64"
data = pt.tensor(name="data", dtype=floatX, shape=(None, None))
a0 = pt.vector(name="a0", dtype=floatX)
P0 = pt.matrix(name="P0", dtype=floatX)
c = pt.vector(name="c", dtype=floatX)
d = pt.vector(name="d", dtype=floatX)
Q = pt.tensor(name="Q", dtype=floatX, shape=(None, None, None))
H = pt.tensor(name="H", dtype=floatX, shape=(None, None, None))
T = pt.tensor(name="T", dtype=floatX, shape=(None, None, None))
R = pt.tensor(name="R", dtype=floatX, shape=(None, None, None))
Z = pt.tensor(name="Z", dtype=floatX, shape=(None, None, None))

inputs = [data, a0, P0, c, d, T, Z, R, H, Q]

In [10]:
outputs = build_fk(*inputs)

In [16]:
make_signature(inputs, outputs)

filtered_states filtered_states
predicted_states predicted_states
observed_states observed_states
filtered_covariances filtered_covariances
predicted_covariances predicted_covariances
observed_covariances observed_covariances
loglike_obs loglike_obs
smoothed_states smoothed_states
smoothed_covariances smoothed_covariances


'(t,p),(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r) -> (t,s),(t,s),(t,p),(t,s,s),(t,s,s),(t,p,p),(t),(t,s),(t,s,s)'

In [None]:
signature = "(t, o), (s), (s, s), (s), (o), (s, s), (o, s), (s, p), (o, o), (p, p) -> (t, s), (t, s), (t, o), (t, s, s), (t, s, s), (t, o, o), (t), (t, s), (t, s, s)"

In [18]:
pt.vectorize(build_fk, signature=make_signature(inputs, outputs))(
    *[pt.as_tensor(x) for x in np_batch_inputs]
)[0].eval().shape

filtered_states filtered_states
predicted_states predicted_states
observed_states observed_states
filtered_covariances filtered_covariances
predicted_covariances predicted_covariances
observed_covariances observed_covariances
loglike_obs loglike_obs
smoothed_states smoothed_states
smoothed_covariances smoothed_covariances


(3, 10, 5)

In [19]:
kf = StandardFilter()
ks = KalmanSmoother()

In [20]:
kf_outputs = kf.build_graph(*inputs)
kf_signature = make_signature(inputs, kf_outputs)

filtered_states filtered_states
predicted_states predicted_states
observed_states observed_states
filtered_covariances filtered_covariances
predicted_covariances predicted_covariances
observed_covariances observed_covariances
loglike_obs loglike_obs


In [21]:
make_batched_kf = pt.vectorize(kf.build_graph, signature=kf_signature)
ks_inputs = [T, R, Q, kf_outputs[0], kf_outputs[3]]
ks_outputs = ks.build_graph(*ks_inputs)

In [22]:
ks_signature = make_signature(ks_inputs, ks_outputs)
make_batched_ks = pt.vectorize(ks.build_graph, signature=ks_signature)

smoothed_states smoothed_states
smoothed_covariances smoothed_covariances


In [25]:
batched_kf_outputs = make_batched_kf(*[pt.as_tensor(x) for x in np_batch_inputs])

In [26]:
data, a0, P0, c, d, T, Z, R, H, Q = np_batch_inputs

In [30]:
batched_ks_outputs = make_batched_ks(
    *[pt.as_tensor_variable(x) for x in [T, R, Q, batched_kf_outputs[0], batched_kf_outputs[3]]]
)

In [31]:
batched_ks_outputs[0].eval().shape

(3, 10, 5)