In [1]:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc_extras.statespace.filters import StandardFilter
from tests.statespace.utilities.test_helpers import make_test_inputs
from pytensor.graph.replace import vectorize_graph
from importlib import reload
import pymc_extras.statespace.filters.distributions as pmss_dist
from pymc_extras.statespace.filters.distributions import SequenceMvNormal
import pymc as pm

In [2]:
# Derive a reproducible integer seed from a descriptive string, then create the
# single NumPy Generator shared by every random draw in this notebook.
seed = sum(ord(ch) for ch in "batched-kf")
rng = np.random.default_rng(seed)

In [3]:
def create_batch_inputs(batch_size, p=1, m=5, r=1, n=10, rng=rng):
    """
    Create a batch of random state-space test inputs.

    Calls ``make_test_inputs(p, m, r, n, rng)`` once per batch member and stacks
    the corresponding arrays along a new leading batch axis.

    Parameters
    ----------
    batch_size : int
        Number of independent input sets to draw. Must be at least 1.
    p : int
        Forwarded as the first argument of ``make_test_inputs``. In this
        notebook it matches the trailing dimension of the first returned array
        (the observations), e.g. ``(batch, n, p) = (3, 10, 1)``.
    m : int
        Forwarded as the second argument. It matches the state dimension of the
        filter outputs, e.g. predicted states of shape ``(batch, n, m)``.
    r : int
        Forwarded as the third argument. Presumably the number of shocks /
        innovations -- TODO confirm against ``make_test_inputs``.
    n : int
        Forwarded as the fourth argument. It matches the time axis of the data
        and of the filter outputs (length 10 with the defaults).
    rng : numpy.random.Generator
        Source of randomness. The default is the notebook-wide ``rng`` as bound
        when this function is defined, so successive calls continue one stream.

    Returns
    -------
    list of numpy.ndarray
        One array per input returned by ``make_test_inputs``, each with shape
        ``(batch_size, *core_shape)``.

    Raises
    ------
    ValueError
        If ``batch_size`` is less than 1 (there is nothing to stack).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # One full set of inputs per batch member. The same rng is passed each
    # time, so (presumably) every member receives different random draws.
    per_batch = [make_test_inputs(p, m, r, n, rng) for _ in range(batch_size)]

    # zip(*...) regroups the per-member tuples into one tuple per input.
    # Each group is then stacked along a new leading batch axis.
    return [np.stack(arrays, axis=0) for arrays in zip(*per_batch)]

In [4]:
# Create batch inputs with batch size 3
# Draw three input sets and confirm the new leading batch axis on the first array.
np_batch_inputs = create_batch_inputs(batch_size=3)
np_batch_inputs[0].shape

(3, 10, 1)

In [5]:
# Core (unbatched) problem sizes; these equal create_batch_inputs' defaults.
p, m, r, n = 1, 5, 1, 10
# Fresh symbolic variables sharing each concrete test array's type (dtype and
# static shape). make_test_inputs is given the shared rng here, so this cell
# also consumes draws from the notebook's random stream.
inputs = [pt.as_tensor(arr).type() for arr in make_test_inputs(p, m, r, n, rng)]

In [6]:
# Build the symbolic Kalman-filter graph for a single (unbatched) set of inputs,
# keeping a handle to the filter object as `kf`.
kf_outputs = (kf := StandardFilter()).build_graph(*inputs)

In [7]:
# Batched counterparts of the core inputs: prepend a dynamic (None) batch axis.
# The dtype is copied from each core input explicitly. Without it, pt.tensor
# falls back to pytensor.config.floatX, which may not match the core graph.
batched_inputs = [
    pt.tensor(dtype=x.type.dtype, shape=(None, *x.type.shape)) for x in inputs
]
vec_subs = dict(zip(inputs, batched_inputs))

# Rewrite the whole filter graph so that it operates over the new batch axis.
# NOTE: "bacthed" is misspelled but kept, because later cells use this name.
bacthed_kf_outputs = vectorize_graph(kf_outputs, vec_subs)

In [8]:
# Display the filter outputs. Later cells select entries from the batched
# version of this list by position, so this printed order is the reference.
kf_outputs

[filtered_states,
 predicted_states,
 observed_states,
 filtered_covariances,
 predicted_covariances,
 observed_covariances,
 loglike_obs]

In [9]:
# Positions refer to the output order displayed for `kf_outputs`:
#   1 -> predicted_states, 4 -> predicted_covariances, -1 -> loglike_obs
# NOTE(review): these are the predicted *hidden* states/covariances (last dim
# m=5). The value later passed to pm.logp is the data input (last dim p=1).
# observed_states / observed_covariances (positions 2 / 5) may be what an
# observation distribution should use -- confirm the intent.
mu, cov, logp = (bacthed_kf_outputs[idx] for idx in (1, 4, -1))

In [10]:
# Static shape of the batched mean: the batch axis is dynamic (None), followed
# by the core (n, m) = (10, 5) dimensions.
mu.type.shape

(None, 10, 5)

In [20]:
# Development convenience: re-execute the distributions module so edits to its
# source take effect without restarting the kernel. The `SequenceMvNormal` name
# imported at the top was bound before the reload and still refers to the old
# class, so the next cell accesses it through `pmss_dist` instead.
pmss_dist = reload(pmss_dist)

In [21]:
# Wrap the batched filter outputs in a SequenceMvNormal. This line prints
# nothing itself; the shape lines shown under this cell are emitted from inside
# the (reloaded) distributions module.
mv_outputs = pmss_dist.SequenceMvNormal.dist(logp=logp, covs=cov, mus=mu)

mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)
mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)
mvn_seq.type.shape: (None, None, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)
mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)
mvn_seq.type.shape: (None, None, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)
mvn_seq.type.shape: (None, 10, 5)


In [22]:
# Draw a fresh batch of three input sets. The shared rng has advanced since the
# earlier batch, so these arrays differ from the first ones.
np_batch_inputs = create_batch_inputs(batch_size=3)

In [23]:
# Replace the first input (the array later used as the pm.logp value) with
# standard-normal draws. The shape comes from the array being replaced rather
# than a hard-coded (3, 10, 1), so the cell keeps working if the batch size or
# core dimensions change.
# NOTE(review): the reason for discarding the generated data isn't stated
# here -- worth documenting.
np_batch_inputs[0] = rng.normal(size=np_batch_inputs[0].shape)

In [24]:
# Compile the batched SequenceMvNormal output and check its concrete shape.
f_test = pytensor.function(inputs=batched_inputs, outputs=mv_outputs)
f_test(*np_batch_inputs).shape

(3, 10, 5)

In [25]:
# Compile the batched log-probability, evaluated at the symbolic first input.
# The shape line shown under this cell comes from code inside the library
# call, since this cell prints nothing itself.
f_mv = pytensor.function(
    inputs=batched_inputs,
    outputs=pm.logp(mv_outputs, value=batched_inputs[0]),
)

(None, 10, 1) (None, 10, 5) (None, 10, 5, 5)


In [26]:
# Evaluate the batched logp. The recorded result has one value per
# (batch, time) pair; the trailing event axis does not appear in the output.
f_mv(*np_batch_inputs).shape

(3, 10)

In [27]:
# Compile every output of the vectorized filter for the timing runs below.
f = pytensor.function(inputs=batched_inputs, outputs=bacthed_kf_outputs)

In [28]:
for s in [1, 3, 10]:
    np_batch_inputs = create_batch_inputs(s)
    %timeit outputs = f(*np_batch_inputs)

633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
