# Imports

In [1]:
import numpy as np
import jax.numpy as jnp
from numba import jit
import numba
import pytensor
import pytensor.tensor as pt
import timeit
import jax
import math
from jax.scipy.special import gammaln
from functools import partial

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

In [2]:
print("pytensor version:", pytensor.__version__)
print("jax version:", jax.__version__)
print("numba version:", numba.__version__)

pytensor version: 0+untagged.31335.g4312d8c
jax version: 0.7.2
numba version: 0.62.1


In [3]:
class Benchmarker:
    """
    Benchmark a set of functions by timing execution and summarizing statistics.

    Parameters
    ----------
    functions : list of callables
        List of callables to benchmark.
    names : list of str, optional
        Names corresponding to each function. Default is ['func_0', 'func_1', ...].
    number : int or None, optional
        Number of loops per timing. If None, auto-calibrated via Timer.autorange().
        Default is None.
    repeat : int, optional
        Number of repeats for timing. Default is 7.
    target_time : float, optional
        Target duration in seconds for auto-calibration. Default is 0.2.

    Attributes
    ----------
    results : dict
        Mapping from function names to a dict with keys:
        - 'raw_us': numpy.ndarray of raw timings in microseconds
        - 'loops': number of loops used per timing

    Methods
    -------
    run()
        Auto-calibrate (if needed) and run timings for all functions.
    summary(unit='us') -> pandas.DataFrame
        Return a summary DataFrame with statistics converted to the given unit.
    raw(name=None) -> dict or numpy.ndarray
        Return raw timing data in microseconds for a specific function or all.
    _convert_times(times, unit) -> numpy.ndarray
        Convert an array of times from microseconds to the specified unit.
    """

    def __init__(
        self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2
    ):
        self.functions = functions
        self.names = names or [f"func_{i}" for i in range(len(functions))]
        self.number = number
        self.min_rounds = min_rounds
        self.max_time = max_time
        self.target_time = target_time
        self.results = {}

    def run(self, inputs: dict[str, dict]):
        """
        Auto-calibrate loop count and sample rounds if needed, then time each function.
        """
        for name, func in zip(self.names, self.functions):
            for input_name, kwargs in inputs.items():
                timer = timeit.Timer(partial(func, **kwargs))

                # Calibrate loops
                if self.number is None:
                    loops, calib_time = timer.autorange()
                else:
                    loops = self.number
                    calib_time = timer.timeit(number=loops)

                # Determine rounds based on max_time and min_rounds
                if self.max_time is not None:
                    rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))
                else:
                    rounds = self.min_rounds

                raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))

                # Convert to microseconds per single execution
                raw_us = raw_round_times / loops * 1e6

                self.results[(name, input_name)] = {
                    "raw_us": raw_us,
                    "loops": loops,
                    "rounds": rounds,
                }

    def summary(self, unit="us"):
        """
        Summarize benchmark statistics in a DataFrame.

        Parameters
        ----------
        unit : {'us', 'ms', 'ns', 's'}, optional
            Unit for output times. 'us' means microseconds, 'ms' milliseconds,
            'ns' nanoseconds, 's' seconds. Default is 'us'.

        Returns
        -------
        pandas.DataFrame
            Summary with columns:
            Name, Loops, Min, Max, Mean, StdDev, Median, IQR (all in given unit),
            OPS (Kops/unit), Samples.
        """
        records = []
        indexes = []
        for name, data in self.results.items():
            raw_us = data["raw_us"]
            # Convert to target unit
            times = self._convert_times(raw_us, unit)
            if isinstance(name, tuple) and len(name) > 1:
                indexes.append(name)
            elif isinstance(name, tuple) and len(name) == 1:
                indexes.append(name[0])
            else:
                indexes.append(name)

            stats = {
                "Loops": data["loops"],
                f"Min ({unit})": np.min(times),
                f"Max ({unit})": np.max(times),
                f"Mean ({unit})": np.mean(times),
                f"StdDev ({unit})": np.std(times),
                f"Median ({unit})": np.median(times),
                f"IQR ({unit})": np.percentile(times, 75) - np.percentile(times, 25),
                "OPS (Kops/s)": 1e3 / (np.mean(raw_us)),
                "Samples": len(raw_us),
            }
            records.append(stats)

        if all(isinstance(idx, tuple) for idx in indexes):
            index = pd.MultiIndex.from_tuples(indexes)
        else:
            index = pd.Index(indexes)
        return pd.DataFrame(records, index=index)

    def raw(self, name=None):
        """
        Get raw timing data in microseconds.

        Parameters
        ----------
        name : str, optional
            If given, returns the raw_us array for that function. Otherwise returns
            a dict of all raw results.

        Returns
        -------
        numpy.ndarray or dict
        """
        if name:
            return self.results.get(name, {}).get("raw_us")
        return {n: d["raw_us"] for n, d in self.results.items()}

    def _convert_times(self, times, unit):
        """
        Convert an array of times from microseconds to the specified unit.

        Parameters
        ----------
        times : array-like
            Times in microseconds.
        unit : {'us', 'ms', 'ns', 's'}
            Target unit: 'us' microseconds, 'ms' milliseconds,
            'ns' nanoseconds, 's' seconds.

        Returns
        -------
        numpy.ndarray
            Converted times.

        Raises
        ------
        ValueError
            If `unit` is not one of the supported options.
        """
        unit = unit.lower()
        if unit == "us":
            factor = 1.0
        elif unit == "ms":
            factor = 1e-3
        elif unit == "ns":
            factor = 1e3
        elif unit == "s":
            factor = 1e-6
        else:
            raise ValueError(f"Unsupported unit: {unit}")
        return times * factor

In [4]:
# Set Pytensor to use float32
pytensor.config.floatX = "float32"

# Introduction

# Baby Steps

## Fibonacci Algorithm

In [5]:
# Pytensor creates functions itself
n_symbolic = pt.iscalar("n")

def step(a, b):
    return a + b, a

(outputs_a, outputs_b), _ = pytensor.scan(
    fn=step,
    outputs_info=[pt.constant(1.0), pt.constant(1.0)],
    n_steps=n_symbolic
)

# compile function returning final a
fibonacci_pytensor = pytensor.function([n_symbolic], outputs_a[-1], trust_input=True)
fibonacci_pytensor_numba = pytensor.function([n_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)


In [6]:
@jit(nopython=True)
def fibonacci_numba(n):
    a = np.ones(1, dtype=np.int32)
    b = np.ones(1, dtype=np.int32)
    for _ in range(n):
        a[0], b[0] = a[0] + b[0], a[0]
    return a[0]

In [7]:
# This is faster than running a scan or a fori_loop
@partial(jax.jit, static_argnums=0)
def fibonacci_jax(n):
    a, b = jnp.array(1, dtype=np.int32), jnp.array(1, dtype=np.int32)
    for _ in range(n):
        a, b = a + b, a
    return a

In [8]:
fibonacci_bench = Benchmarker(
    functions=[fibonacci_pytensor, fibonacci_numba, jax.block_until_ready(fibonacci_jax), fibonacci_pytensor_numba], 
    names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],
    number=10
)

In [9]:
fibonacci_bench.run(
    inputs={
        "fibonacci_inputs": {"n": 100_000},
    }
)
fibonacci_bench.summary()

Unnamed: 0,Unnamed: 1,Loops,Min (us),Max (us),Mean (us),StdDev (us),Median (us),IQR (us),OPS (Kops/s),Samples
fibonacci_pytensor,fibonacci_inputs,10,64518.025001,67663.2042,66403.93584,1071.000294,66790.325,910.641799,0.015059,5
fibonacci_numba,fibonacci_inputs,10,75.4917,104.4291,93.351733,6.430028,93.51875,4.254199,10.712174,12
fibonacci_jax,fibonacci_inputs,10,2.1125,7.441601,3.20332,2.11956,2.125,0.112499,312.176109,5
fibonacci_pytensor_numba,fibonacci_inputs,10,3325.6,3470.2958,3371.58414,50.820262,3355.9916,15.641699,0.296596,5


## Element-wise multiplication Algorithm

In [10]:
a_symbolic = pt.vector("a", dtype="int32")
b_symbolic = pt.vector("b", dtype="int32")

def step(a_element, b_element):
        return a_element * b_element
    
c, _ = pytensor.scan(
    fn=step,
    sequences=[a_symbolic, b_symbolic]
)

# compile function returning final a
elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True)

elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode="NUMBA", trust_input=True)

In [11]:
@jit(nopython=True)
def elementwise_multiply_numba(a, b):
    n = a.shape[0]
    c = np.empty(n, dtype=a.dtype)
    for i in range(n):
        c[i] = a[i] * b[i]
    return c

In [12]:
@jax.jit
def elementwise_multiply_jax(a, b):
    n = a.shape[0]
    c = jnp.empty(n, dtype=a.dtype)
    for i in range(n):
        c = c.at[i].set(a[i] * b[i])
    return c

In [13]:
a = np.random.normal(0, 1, (100)).astype(np.int32)
b = np.random.normal(0, 1, (100)).astype(np.int32)

In [14]:
elem_mult_bench = Benchmarker(
    functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, jax.block_until_ready(elementwise_multiply_jax), elementwise_multiply_pytensor_numba], 
    names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],
    number=10
)

In [15]:
elem_mult_bench.run(
    inputs={
        "elem_mult_inputs": {"a": a, "b": b},
    }
)
elem_mult_bench.summary()

Unnamed: 0,Unnamed: 1,Loops,Min (us),Max (us),Mean (us),StdDev (us),Median (us),IQR (us),OPS (Kops/s),Samples
elementwise_multiply_pytensor,elem_mult_inputs,10,3.095801,18.9167,3.748584,0.513067,3.7084,0.241699,266.767398,11841
elementwise_multiply_numba,elem_mult_inputs,10,0.341701,0.7792,0.386048,0.085013,0.370899,0.029099,2590.353075,23
elementwise_multiply_jax,elem_mult_inputs,10,7.995899,9.987499,8.917886,0.568675,8.8959,0.456251,112.134202,7
elementwise_multiply_pytensor_numba,elem_mult_inputs,10,3.6,8.941699,4.379155,1.635324,3.6666,0.291699,228.354537,9


# Changepoint Detection Algorithms

## Cumulative Sum (CUSUM) Algorithm

In [16]:
@jit(nopython=True)
def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):
    """
    Two-sided CUSUM with adaptive exponential moving average baseline.
    
    Parameters
    ----------
    x: np.ndarray
        input signal
    alpha: float
        EMA smoothing factor (0 < alpha <= 1)
    k: float
        slack to avoid small changes triggering alarms
    h: float
        threshold for raising an alarm
        
    Returns
    -------
    s_pos: np.ndarray
        upper CUSUM stats
    s_neg: np.ndarray
        lower CUSUM stats
    mu_t: np.ndarray
        evolving baseline estimate
    alarms_pos: np.ndarray
        alarms for upward changes
    alarms_neg: np.ndarray
        alarms for downward changes
    """
    n = x.shape[0]

    s_pos = np.zeros(n, dtype=np.float64)
    s_neg = np.zeros(n, dtype=np.float64)
    mu_t  = np.zeros(n, dtype=np.float64)
    alarms_pos = np.zeros(n, dtype=np.bool_)
    alarms_neg = np.zeros(n, dtype=np.bool_)

    # Initialization
    mu_t[0] = x[0]

    for i in range(1, n):
        # Update baseline (EMA)
        mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]

        # Update CUSUM stats
        s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)
        s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)

        # Alarms
        alarms_pos[i] = s_pos[i] > h
        alarms_neg[i] = s_neg[i] > h

    return s_pos, s_neg, mu_t, alarms_pos, alarms_neg

In [17]:
@jax.jit
def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):
    """
    Two-sided CUSUM with adaptive exponential moving average baseline.
    
    Parameters
    ----------
    x: jnp.ndarray
        input signal
    alpha: float
        EMA smoothing factor (0 < alpha <= 1)
    k: float
        slack to avoid small changes triggering alarms
    h: float
        threshold for raising an alarm
        
    Returns
    -------
    s_pos: jnp.ndarray
        upper CUSUM stats
    s_neg: jnp.ndarray
        lower CUSUM stats
    mu_t: jnp.ndarray
        evolving baseline estimate
    alarms_pos: jnp.ndarray
        alarms for upward changes
    alarms_neg: jnp.ndarray
        alarms for downward changes
    """
    def body(carry, x_t):
        s_pos_prev, s_neg_prev, mu_prev = carry
        
        # Update EMA baseline
        mu_t = alpha * x_t + (1 - alpha) * mu_prev
        
        # Update CUSUMs using updated baseline
        s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)
        s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)
        
        new_carry = (s_pos, s_neg, mu_t)
        output = (s_pos, s_neg, mu_t)
        return new_carry, output

    # Initialize: CUSUMs at 0, initial mean = first sample
    s0 = (0.0, 0.0, x[0])
    _, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)
    
    # Thresholding
    alarms_pos = s_pos_vals > h
    alarms_neg = s_neg_vals > h

    return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg


In [18]:
x_symbolic = pt.vector("x")
alpha_symbolic = pt.scalar("alpha")
k_symbolic = pt.scalar("k")
h_symbolic = pt.scalar("h")

def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):
    # Update EMA baseline
    mu_t = alpha * x_t + (1 - alpha) * mu_prev
    
    # Update CUSUMs using updated baseline
    s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)
    s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)
    
    return s_pos, s_neg, mu_t


(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(
    fn=step,
    outputs_info=[pt.constant(0.), pt.constant(0.), x_symbolic[0]],
    non_sequences=[alpha_symbolic, k_symbolic],
    sequences=[x_symbolic]
)

# Thresholding
alarms_pos = s_pos_vals > h_symbolic
alarms_neg = s_neg_vals > h_symbolic

cusum_adaptive_pytensor = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], trust_input=True)

cusum_adaptive_pytensor_numba = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode="NUMBA", trust_input=True)

In [19]:
cusum_adaptive_pytensor_jax = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode="JAX", trust_input=True)

In [20]:
xs1 = np.random.normal(80, 20, size=(50))
xs2 = np.random.normal(50, 20, size=(50))
xs = np.concat((xs1, xs2))
xs = xs.astype(np.float32)
xs = xs.astype(np.float32)
xs_std = (xs - xs.mean()) / xs.std()

In [21]:
cusum_bench = Benchmarker(
    functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, jax.block_until_ready(cusum_adaptive_jax), cusum_adaptive_pytensor_numba, jax.block_until_ready(cusum_adaptive_pytensor_jax)], 
    names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],
    number=10
)

In [22]:
cusum_bench.run(
    inputs={
        "cusum_inputs": {"x": xs, "alpha": 0.1, "k": 0.5, "h": 3.5},
    }
)
cusum_bench.summary()

Unnamed: 0,Unnamed: 1,Loops,Min (us),Max (us),Mean (us),StdDev (us),Median (us),IQR (us),OPS (Kops/s),Samples
cusum_adaptive_pytensor,cusum_inputs,10,120.6334,169.616699,135.634676,7.004116,134.1084,7.029201,7.372746,761
cusum_adaptive_numba,cusum_inputs,10,1.774999,3.3958,2.035714,0.556027,1.8042,0.05,491.228092,7
cusum_adaptive_jax,cusum_inputs,10,11.058401,13.825,12.22415,0.747579,12.2646,1.062525,81.805279,34
cusum_adaptive_pytensor_numba,cusum_inputs,10,25.525001,32.5542,27.4075,2.626003,26.5125,1.2958,36.486363,5
cusum_adaptive_pytensor_jax,cusum_inputs,10,19.275,57.133401,22.210748,6.481453,21.162501,1.7313,45.023246,31


In [23]:
outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)

In [24]:
fig = go.Figure()
fig.add_traces(
    [
        go.Scatter(
            x = np.arange(len(xs)),
            y = xs_std,
            name="series"
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[0],
            name="cum. positive devs."
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[1],
            name="cum. negative devs."
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[2],
            name="Exp. Mean"
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[3].astype(np.float16),
            name="positive alarms"
        ),
        go.Scatter(
            x = np.arange(len(xs)),
            y = outputs[4].astype(np.float16),
            name="negative alarms"
        ),
        
    ]
)
fig.update_layout(
    title = dict(
        text = "CUSUM Change Point Detection Algorithm"
    ),
    xaxis=dict(
        title = "Time Index"
    ),
    yaxis=dict(
        title = "Standardized Series Scaled"
    ),
    legend=dict(
        yanchor="top",
        y=1.1,
        xanchor="left",
        x=0,
        orientation="h"
    ),
    template="plotly_dark"
)

## Pruned Exact Linear Time (PELT) Algorithm

In [25]:
@jit(nopython=True)
def segment_cost_numba(S1, S2, i, j):
    """Cost of segment x[i:j], SSE around mean"""
    n = j - i
    sum_x = S1[j] - S1[i]
    sum_x2 = S2[j] - S2[i]
    if n > 0:
        return sum_x2 - (sum_x ** 2) / n
    else:
        return np.inf

@jit(nopython=True)
def pelt_numba(x, beta=10.0):
    """
    Pruned Exact Linear Time algorithm for change point detection

    Parameters
    ----------
    x: np.ndarray
        The timeseries signal
    beta: float
        Penalty of segmenting the series

    Returns
    -------
    C: np.ndarray
        The best costs up to segment t
    last_change: np.ndarray
        The last change point up to segment t
    """
    n = len(x)

    # cumulative sums for cost
    S1 = np.empty(n+1, dtype=np.float64)
    S2 = np.empty(n+1, dtype=np.float64)
    S1[0], S2[0] = 0.0, 0.0
    for i in range(1, n+1):
        S1[i] = S1[i-1] + x[i-1]
        S2[i] = S2[i-1] + x[i-1]**2

    # DP arrays
    C = np.full((n+1,), np.inf)
    C[0] = -beta
    last_change = np.full((n+1,), -1)
    min_size = 3

    for t in range(1, n+1):
        costs = np.full(n, np.inf)
        for s in range(n):
            if s < t and (t - s) >= min_size:
                costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta
        best_s = np.argmin(costs)
        C[t] = costs[best_s]
        last_change[t] = best_s

    return C, last_change

In [26]:
def segment_cost_jax(S1, S2, i, j):
    """Cost of segment x[i:j], SSE around mean"""
    n = j - i
    sum_x = S1[j] - S1[i]
    sum_x2 = S2[j] - S2[i]
    return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)


@jax.jit
def pelt_jax(x, beta=10.0):
    """
    Pruned Exact Linear Time algorithm for change point detection

    Parameters
    ----------
    x: np.ndarray
        The timeseries signal
    beta: float
        Penalty of segmenting the series

    Returns
    -------
    C: jnp.ndarray
        The best costs up to segment t
    last_change: jnp.ndarray
        The last change point up to segment t
    """
    n = len(x)

    # cumulative sums for cost
    S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])
    S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])

    # DP arrays
    C = jnp.full((n+1,), jnp.inf)
    C = C.at[0].set(-beta)
    last_change = jnp.full((n+1,), -1)
    min_size = 3

    s_all = jnp.arange(n)   # all possible candidates

    def body(t, carry):
        C, last_change = carry

        # Compute cost for all s < t, mask invalid
        # valid = s_all < t & ((t - s_all) >= min_size)
        
        valid = (s_all < t) & ((t - s_all) >= min_size)
        costs = jnp.where(
            valid,
            C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,
            jnp.inf
        )

        best_s = jnp.argmin(costs)
        C = C.at[t].set(costs[best_s])
        last_change = last_change.at[t].set(best_s)
        return C, last_change

    C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))
    return C, last_change

In [27]:
def segment_cost_pytensor(S1, S2, i, j):
    """Cost of segment x[i:j], SSE around mean"""
    n = j - i
    sum_x = S1[j] - S1[i]
    sum_x2 = S2[j] - S2[i]
    return pt.switch(
        pt.gt(n, 0),
        sum_x2 - (sum_x ** 2) / n,
        np.inf
    )


In [28]:
x_symbolic = pt.vector("x")
beta_symbolic = pt.scalar("beta")
n = x_symbolic.shape[0]

# cumulative sums for cost
S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])
S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])

# DP arrays
C_init = pt.alloc(np.inf, n+1)
C_init = pt.set_subtensor(C_init[0], -beta_symbolic)
last_change_init = pt.alloc(-1, n+1)

s_all = pt.arange(n)   # candidate change points
min_size = 3

def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):
    # valid = (s_all < t) & ((t - s_all) >= min_size)
    valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, min_size))

    # compute costs for all candidates
    costs, _ = pytensor.scan(
        fn=lambda s: pt.switch(
            valid[s],
            C_prev[s] + segment_cost_pytensor(S1, S2, s, t) + beta_symbolic,
            np.inf
        ),
        sequences=[pt.arange(n)]
    )
    costs = costs.flatten()

    best_s = pt.argmin(costs, axis=0)
    C_new = pt.set_subtensor(C_prev[t], costs[best_s])
    last_change_new = pt.set_subtensor(last_change_prev[t], best_s)

    return C_new, last_change_new

(C_vals, last_change_vals), _ = pytensor.scan(
    fn=step,
    sequences=[pt.arange(1, n+1)],
    outputs_info=[C_init, last_change_init],
    non_sequences=[S1, S2, beta_symbolic, s_all]
)

pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)
pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode="NUMBA", trust_input=True)

In [29]:
pelt_bench = Benchmarker(
    functions=[pelt_pytensor, pelt_numba, jax.block_until_ready(pelt_jax), pelt_pytensor_numba], 
    names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],
    number=10
)

In [30]:
pelt_bench.run(
    inputs={
        "pelt_inputs": {"x": xs_std, "beta": 2. * np.log(len(xs_std))},
    }
)
pelt_bench.summary()

Unnamed: 0,Unnamed: 1,Loops,Min (us),Max (us),Mean (us),StdDev (us),Median (us),IQR (us),OPS (Kops/s),Samples
pelt_pytensor,pelt_inputs,10,11427.6291,12441.8708,12163.819422,287.094492,12253.829101,176.5875,0.082211,9
pelt_numba,pelt_inputs,10,19.295899,22.0125,20.3492,0.991247,20.3209,1.316701,49.141982,5
pelt_jax,pelt_inputs,10,15.066699,65.6042,50.399558,17.811394,56.879199,5.2688,19.841444,19
pelt_pytensor_numba,pelt_inputs,10,2507.9375,2881.2875,2627.49916,130.434474,2581.904099,26.666699,0.38059,5


In [31]:
outputs = pelt_numba(xs_std, 2. * np.log(len(xs_std)))

In [32]:
def plot_pelt_diagnostics(x, cps, C):
    """
    Diagnostic plots for PELT changepoint detection.
    
    Args:
        x: 1D array, original time series
        C: 1D array, cumulative DP cost from pelt()
        cps: list of changepoint indices (sorted ascending)
    """
    n = len(x)
    cps_full = [0] + cps + [n]

    # Segment means, std, SSE
    segment_means = []
    segment_stds = []
    segment_costs = []
    for start, end in zip(cps_full[:-1], cps_full[1:]):
        seg = x[start:end]
        mean = np.mean(seg)
        std = np.std(seg)
        cost = np.sum((seg - mean)**2)
        segment_means.append(mean)
        segment_stds.append(std)
        segment_costs.append(cost)

    # Step function for segment mean
    mean_step = np.zeros(n)
    for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):
        mean_step[start:end] = segment_means[i]

    # Step function for segment std
    std_step = np.zeros(n)
    for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):
        std_step[start:end] = segment_stds[i]

    if len(x) < 20:
        title1 = "<span style='color: red;'>Warning</span>: Sample size is small - Detected Changepoints"
    else:
        title1 = "Detected Changepoints"

    fig = make_subplots(
        rows=4, 
        cols=1,
        subplot_titles=(title1, "Average Shifts", "Variability Shifts", "Cumulative Cost")
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = x,
            line_color="royalblue",
            name = "Actuals",
            mode="lines",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Actual</b>: %{y}"
        ),
        row=1, col=1
    )

    for cp in cps:
        fig.add_vline(x=cp, line_dash='dash', line_color="red", row=1, col=1)

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = x,
            name = "Actuals",
            mode="lines",
            line_color="rgba(105, 105, 105, 0.25)",
            showlegend=False,
            hoverinfo="skip"
        ),
        row=2, col=1
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = mean_step,
            name = "Average",
            line_color="royalblue",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Average</b>: %{y}"
        ),
        row=2, col=1
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = std_step,
            name = "Standard Deviation",
            line_color="royalblue",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Standard Deviation</b>: %{y}"
        ),
        row=3, col=1
    )

    fig.add_trace(
        go.Scatter(
            x = np.arange(len(x)),
            y = C,
            name = "Cumulative Cost",
            line_color="royalblue",
            showlegend=False,
            hovertemplate="<b>Time Point</b>: %{x}<br><b>Cost</b>: %{y}"
        ),
        row=4, col=1
    )

    for cp in cps:
        fig.add_vline(x=cp, line_dash='dash', line_color="red", row=4, col=1)

    return fig.update_layout(height=1000, width=1200, template="plotly_dark")


In [33]:
def get_changepoints(last_change, n):
    """
    Backtrack changepoints from last_change array.
    
    Args:
        last_change: array from pelt()
        n: length of input series

    Returns:
        list of changepoint indices (sorted ascending)
    """
    cps = []
    t = n
    while t > 0:
        s = int(last_change[t])
        if s <= 0:
            break
        cps.append(s)
        t = s
    return list(reversed(cps))

In [34]:
cps = get_changepoints(outputs[1], n=len(xs_std))

In [35]:
plot_pelt_diagnostics(xs, cps, outputs[0])

# Kalman Filter Algorithms

## Linear Gaussian Kalman Filter

In [36]:
@jit(nopython=True)
def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):
    """
    This implementation of the Kalman filter is Atrocious and in standard Python would be a 
    BIG NO-NO. That being said this version SIGNIFICANTLY reduces Numba Compilation time. 
    
    Linear Gaussian Kalman filter algorithm

    Parameters
    ----------
    z: np.ndarray
        shape (T, m)  - observations
    F: np.ndarray
        state transition matrix - shape (n, n)
    H: np.ndarray
        observation/design matrix - shape (m, n)
    Q: np.ndarray
        process noise covariance - shape (n, n)
    R: np.ndarray
        observation noise covariance - shape (m, m)
    x0: np.ndarray
        initial state mean - shape (n,)
    P0: np.ndarray
        initial state covariance - shape (n, n)

    Returns
    -------
    xs: np.ndarray
        shape (T, n)   - filtered state means
    Ps: np.ndarray
        shape (T, n, n) - filtered state covariances
    """
    T = z.shape[0]
    m = z.shape[1]
    n = x0.shape[0]

    xs = np.empty((T, n), dtype=np.float32)
    Ps = np.empty((T, n, n), dtype=np.float32)

    # local working arrays
    x = np.empty(n, dtype=np.float32)
    for i in range(n):
        x[i] = x0[i]
    P = np.empty((n, n), dtype=np.float32)
    for i in range(n):
        for j in range(n):
            P[i, j] = P0[i, j]

    # temporary matrices/vectors
    x_pred = np.empty((T, n), dtype=np.float32)
    P_pred = np.empty((T, n, n), dtype=np.float32)
    y = np.empty(m, dtype=np.float32)
    S = np.empty((m, m), dtype=np.float32)
    K = np.empty((n, m), dtype=np.float32)
    I_n = np.eye(n, dtype=np.float32)

    for t in range(T):
        # === Predict ===
        # x_pred = F @ x
        for i in range(n):
            s = 0.0
            for j in range(n):
                s += F[i, j] * x[j]
            x_pred[t, i] = s

        # P_pred = F @ P @ F.T + Q
        # temp = F @ P
        temp = np.empty((n, n), dtype=np.float32)
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += F[i, k] * P[k, j]
                temp[i, j] = s
        # P_pred = temp @ F.T
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += temp[i, k] * F[j, k]   # F.T[k, j] = F[j, k]
                P_pred[t, i, j] = s + Q[i, j]

        # === Update ===
        # y = z[t] - H @ x_pred
        for i in range(m):
            s = 0.0
            for j in range(n):
                s += H[i, j] * x_pred[t, j]
            y[i] = z[t, i] - s

        # S = H @ P_pred @ H.T + R
        # temp2 = H @ P_pred
        temp2 = np.empty((m, n), dtype=np.float32)
        for i in range(m):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += H[i, k] * P_pred[t, k, j]
                temp2[i, j] = s
        # S = temp2 @ H.T
        for i in range(m):
            for j in range(m):
                s = 0.0
                for k in range(n):
                    s += temp2[i, k] * H[j, k]  # H.T[k,j] = H[j,k]
                S[i, j] = s + R[i, j]

        # K = P_pred @ H.T @ inv(S)
        # first compute P_pred @ H.T  -> (n, m)
        P_Ht = np.empty((n, m), dtype=np.float32)
        for i in range(n):
            for j in range(m):
                s = 0.0
                for k in range(n):
                    s += P_pred[t, i, k] * H[j, k]  # H.T[k,j] = H[j,k]
                P_Ht[i, j] = s

        # invert S
        S_inv = np.linalg.inv(S)

        # K = P_Ht @ S_inv  (n,m) @ (m,m) -> (n,m)
        for i in range(n):
            for j in range(m):
                s = 0.0
                for k in range(m):
                    s += P_Ht[i, k] * S_inv[k, j]
                K[i, j] = s

        # x = x_pred + K @ y
        for i in range(n):
            s = 0.0
            for j in range(m):
                s += K[i, j] * y[j]
            x[i] = x_pred[t, i] + s

        # P = (I - K H) P_pred
        # compute (I - K H)
        KH = np.empty((n, n), dtype=np.float32)
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(m):
                    s += K[i, k] * H[k, j]
                KH[i, j] = s

        I_minus_KH = np.empty((n, n), dtype=np.float32)
        for i in range(n):
            for j in range(n):
                I_minus_KH[i, j] = I_n[i, j] - KH[i, j]

        # P = I_minus_KH @ P_pred
        for i in range(n):
            for j in range(n):
                s = 0.0
                for k in range(n):
                    s += I_minus_KH[i, k] * P_pred[t, k, j]
                P[i, j] = s

        # store results
        for i in range(n):
            xs[t, i] = x[i]
        for i in range(n):
            for j in range(n):
                Ps[t, i, j] = P[i, j]

    return xs, Ps, x_pred, P_pred


In [37]:
@jit(nopython=True)
def kalman_filter_numba(z, F, H, Q, R, x0, P0):
    """
    Linear Gaussian Kalman filter algorithm

    Parameters
    ----------
    z: np.ndarray
        shape (T, m)  - observations
    F: np.ndarray
        state transition matrix - shape (n, n)
    H: np.ndarray
        observation/design matrix - shape (m, n)
    Q: np.ndarray
        process noise covariance - shape (n, n)
    R: np.ndarray
        observation noise covariance - shape (m, m)
    x0: np.ndarray
        initial state mean - shape (n,)
    P0: np.ndarray
        initial state covariance - shape (n, n)

    Returns
    -------
    xs: np.ndarray
        shape (T, n)   - filtered state means
    Ps: np.ndarray
        shape (T, n, n) - filtered state covariances
    """
    T, m = z.shape
    n = x0.shape[0]

    xs = np.zeros((T, n), dtype=np.float32)
    Ps = np.zeros((T, n, n), dtype=np.float32)

    x_pred = np.zeros((T, n), dtype=np.float32)
    P_pred = np.zeros((T, n, n), dtype=np.float32)

    x = x0.copy()
    P = P0.copy()

    I = np.eye(n, dtype=np.float32)

    for t in range(T):
        # --- Predict ---
        x_pred[t] = F @ x
        P_pred[t] = F @ P @ F.T + Q

        # --- Update ---
        y = z[t] - H @ x_pred[t]
        S = H @ P_pred[t] @ H.T + R
        K = P_pred[t] @ H.T @ np.linalg.inv(S)

        x = x_pred[t] + K @ y
        P = (I - K @ H) @ P_pred[t]

        xs[t] = x
        Ps[t] = P

    return xs, Ps, x_pred, P_pred

In [38]:
@jax.jit
def kalman_filter_jax(z, F, H, Q, R, x0, P0):
    """
    Linear Gaussian Kalman filter algorithm

    Parameters
    ----------
    z: np.ndarray
        shape (T, m)  - observations
    F: np.ndarray
        state transition matrix - shape (n, n)
    H: np.ndarray
        observation/design matrix - shape (m, n)
    Q: np.ndarray
        process noise covariance - shape (n, n)
    R: np.ndarray
        observation noise covariance - shape (m, m)
    x0: np.ndarray
        initial state mean - shape (n,)
    P0: np.ndarray
        initial state covariance - shape (n, n)

    Returns
    -------
    xs: jnp.ndarray
        shape (T, n)   - filtered state means
    Ps: jnp.ndarray
        shape (T, n, n) - filtered state covariances
    """

    n = x0.shape[0]
    I = jnp.eye(n)
    X_pred_init = jnp.zeros((1,))
    P_pred_init = jnp.zeros((1, 1,))

    def step(carry, z_t):
        x, P, _, _ = carry

        # --- Predict ---
        x_pred = F @ x
        P_pred = F @ P @ F.T + Q

        # --- Update ---
        y = z_t - H @ x_pred
        S = H @ P_pred @ H.T + R
        K = P_pred @ H.T @ jnp.linalg.inv(S)

        x_new = x_pred + K @ y
        P_new = (I - K @ H) @ P_pred

        return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)

    # run scan
    (_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)

    return xs, Ps, x_pred, P_pred

In [39]:
z_symbolic = pt.matrix("z")
F_symbolic = pt.matrix("F")
H_symbolic = pt.matrix("H")
Q_symbolic = pt.matrix("Q")
R_symbolic = pt.matrix("R")
x0_symbolic = pt.vector("x0")
P0_symbolic = pt.matrix("P0")

n = x0_symbolic.shape[0]
I = pt.eye(n)
X_pred_init = pt.zeros_like(x0_symbolic)
P_pred_init = pt.zeros_like(P0_symbolic)

def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):

    # --- Predict ---
    x_pred = F_symbolic @ x
    P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic

    # --- Update ---
    y = z_t - H_symbolic @ x_pred
    S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic
    K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)

    x_new = x_pred + K @ y
    P_new = (I - K @ H_symbolic) @ P_pred

    return x_new, P_new, x_pred, P_pred

# run scan
(xs, Ps, x_pred, P_pred), _ = pytensor.scan(
    fn=step,
    outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],
    sequences=[z_symbolic],
    non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I]
)

kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)

kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode="NUMBA", trust_input=True)
kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode="JAX", trust_input=True)

In [40]:
T = 500
F = np.array([[1.0]]).astype(np.float32)
H = np.array([[1.0]]).astype(np.float32)
Q = np.array([[0.01]]).astype(np.float32)
R = np.array([[0.1]]).astype(np.float32)
x0 = np.array([0.0]).astype(np.float32)
P0 = np.array([[1.0]]).astype(np.float32)

true = 1.0
z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)

In [41]:
kalman_filter_bench = Benchmarker(
    functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, jax.block_until_ready(kalman_filter_jax), kalman_filter_pytensor_numba, jax.block_until_ready(kalman_filter_pytensor_jax)], 
    names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],
    number=10
)

In [42]:
kalman_filter_bench.run(
    inputs={
        "kalman_filter_inputs": {"z": z, "F": F, "H": H, "Q": Q, "R": R, "x0": x0, "P0": P0},
    }
)
kalman_filter_bench.summary()


[1m[1mCannot cache compiled function "scan" as it uses dynamic globals (such as ctypes pointers and large global arrays)[0m[0m



Unnamed: 0,Unnamed: 1,Loops,Min (us),Max (us),Mean (us),StdDev (us),Median (us),IQR (us),OPS (Kops/s),Samples
kalman_filter_pytensor,kalman_filter_inputs,10,6024.866599,6579.8458,6250.349981,153.071106,6202.1771,232.759375,0.159991,16
atrocious_kalman_filter_numba,kalman_filter_inputs,10,321.066599,339.3833,326.6883,6.500478,324.787499,1.745901,3.061022,5
kalman_filter_numba,kalman_filter_inputs,10,661.0458,699.4208,679.645,15.146345,678.958399,28.6166,1.471356,5
kalman_filter_jax,kalman_filter_inputs,10,311.85,404.729199,349.466676,30.904895,345.983301,53.454102,2.861503,17
kalman_filter_pytensor_numba,kalman_filter_inputs,10,849.5916,914.7167,886.17166,22.54198,890.6167,26.891699,1.12845,5
kalman_filter_pytensor_jax,kalman_filter_inputs,10,338.3166,393.6875,360.706558,15.525106,359.574999,25.83125,2.772337,19


In [43]:
xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)

In [44]:
def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):
    T = z.shape[0]
    m = H.shape[0]
    mean = np.zeros((T, m))
    lower = np.zeros((T, m))
    upper = np.zeros((T, m))
    outside = np.zeros(T, dtype=np.bool_)

    for t in range(T):
        mean[t] = H @ x_pred[t]
        S = H @ P_pred[t] @ H.T + R
        std = np.sqrt(np.diag(S))
        lower[t] = mean[t] - zscore * std
        upper[t] = mean[t] + zscore * std

        # check coverage of actual obs
        outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))

    coverage = 1 - outside.mean()
    return mean, lower, upper, coverage


In [45]:
mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)

In [46]:
coverage

np.float64(0.916)

In [47]:
fig= go.Figure()
fig.add_traces(
    [
        go.Scatter(
            x = np.arange(T),
            y = z.ravel(),
            mode="markers",
            marker_color = "royalblue",
            name = "actuals"
        ),
        go.Scatter(
            x = np.arange(T),
            y = xs.ravel(),
            mode = "lines",
            marker_color = "orange",
            name = "filtered mean"
        ),
        go.Scatter(
                name="", 
                x=np.arange(T), 
                y=upper.ravel(), 
                mode="lines", 
                marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="95% CI",
                showlegend=False
            ),
            go.Scatter(
                name="95% CI", 
                x=np.arange(T), 
                y=lower.ravel(), 
                mode="lines", marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="95% CI", 
                fill='tonexty', 
                fillcolor='rgba(235, 140, 52, 0.2)'
            ),

    ]
)
fig.update_layout(
    xaxis=dict(
        title = "Time Index",
    ),
    yaxis=dict(
        title = "y"
    ),
    template = "plotly_dark"
)

## Non-linear Kalman Filter

In [48]:
@jit(nopython=True)
def loglik_poisson_numba(s, y):
    """Poisson Log Likelihood"""
    mu = np.exp(s)
    return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln

@jit(nopython=True)
def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):
    """
    1D particle filter.
    
    Parameters
    ----------
    A: float
        State transition
    Q: float
        Process covariance
    x0_mean: float
        Prior mean for the latent state
    x0_std: float
        Prior standard deviation 
    ys: np.ndarray
        observations
    N: int
        number of particles
    seed: int
        rng seed for reproducibility

    Returns
    -------
    filtered_means: np.ndarray
        The filtered mean for the latent state 
    filtered_vars: np.ndarray
        The filtered variance for the latent state
    pred_means: np.ndarray
        observation predicted mean 
    """
    np.random.seed(seed)
    T = ys.shape[0]
    particles = np.random.normal(x0_mean, x0_std, size=N)
    weights = np.ones(N) / N

    filtered_means = np.zeros(T)
    filtered_vars = np.zeros(T)
    pred_means = np.zeros(T)

    for t in range(T):
        y = ys[t]

        # propagate (vectorized)
        particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)

        # update weights
        logw = np.zeros(N)
        for i in range(N):
            logw[i] = loglik_poisson_numba(particles[i], y)
        logw = logw - np.max(logw)
        weights *= np.exp(logw)
        weights /= np.sum(weights) + 1e-12

        # filtered moments
        mean_t = np.sum(weights * particles)
        var_t = np.sum(weights * (particles - mean_t) ** 2)

        # predictive mean
        pred_mean = np.sum(weights * np.exp(particles))

        filtered_means[t] = mean_t
        filtered_vars[t] = var_t
        pred_means[t] = pred_mean

        # resample (multinomial resampling) because numba doesn't support np.random.choice
        cumulative_sum = np.cumsum(weights)
        cumulative_sum[-1] = 1.0  # guard against rounding error
        indices = np.searchsorted(cumulative_sum, np.random.rand(N))

        particles = particles[indices]
        weights = np.ones(N) / N

    return filtered_means, filtered_vars, pred_means

In [49]:
# Had to fix the loglikelihood and key to use benchmarker as is
def loglik_poisson_jax(s, y):
    """Poisson Log Likelihood"""
    mu = jnp.exp(s)
    return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)


@partial(jax.jit, static_argnums=5)
def particle_filter_1d_predict_jax(
    A, Q, x0_mean, x0_std, ys, N=1000,
):
    """
    1D particle filter.
    
    Parameters
    ----------
    A: float
        State transition
    Q: float
        Process covariance
    x0_mean: float
        Prior mean for the latent state
    x0_std: float
        Prior standard deviation 
    ys: np.ndarray
        observations
    loglik_fn: function
        The log likelihood function
    key: 
        JAX prng key
    N: int
        number of particles

    Returns
    -------
    filtered_means: jnp.ndarray
        The filtered mean for the latent state 
    filtered_vars: jnp.ndarray
        The filtered variance for the latent state
    pred_means: jnp.ndarray
        observation predicted mean 
    """
    key = jax.random.PRNGKey(0)
    T = ys.shape[0]
    particles = jax.random.normal(key, (N,)) * x0_std + x0_mean # init particles from gaussian priors
    weights = jnp.ones(N) / N # particle weights, all particles equally likely prior

    def body_fun(carry, t):
        particles, weights, key = carry
        y = ys[t]

        # propagate
        key, subkey = jax.random.split(key)
        particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model

        # update weights
        logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel
        logw = logw - jnp.max(logw) # avoid overflow
        weights = weights * jnp.exp(logw) # old weights times the likelihood
        weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1

        # filtered moments
        mean_t = jnp.sum(weights * particles) # posterior mean of latent state
        var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state

        # predictive mean
        pred_mean = jnp.sum(weights * jnp.exp(particles))

        # resample to prevent dominant particles
        key, subkey = jax.random.split(key)
        indices = jax.random.choice(subkey, N, p=weights, shape=(N,))
        particles = particles[indices]
        weights = jnp.ones(N) / N

        carry = (particles, weights, key)
        out = (mean_t, var_t, pred_mean)
        return carry, out

    _, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))
    return outputs


In [50]:
from pytensor.tensor.random.utils import RandomStream

# Random stream for PyTensor
srng = RandomStream(seed=42)

# Poisson log-likelihood
def loglik_poisson_pytensor(s, y):
    mu = pt.exp(s)
    return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)


In [51]:
ys_symbolic = pt.vector("ys")
x0_mean_symbolic = pt.scalar("x0_mean")
x0_std_symbolic = pt.scalar("x0_std")
A_symbolic = pt.scalar("A")
Q_symbolic = pt.scalar("Q")
N_symbolic = pt.scalar("N", dtype='int64')

# Initialize particles and weights
particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic
weights_init = pt.ones((N_symbolic,)) / N_symbolic 

# Step function for scan
def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):
    # Propagate particles
    particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)

    # Update weights
    # logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])
    logw = loglik_poisson_pytensor(particles_prop, y_t)
    logw_stable = logw - pt.max(logw)
    w_unnorm = weights_prev * pt.exp(logw_stable)
    w = w_unnorm / (pt.sum(w_unnorm) + 1e-12) 

    # Filtered moments
    mean_t = pt.sum(w * particles_prop)
    var_t = pt.sum(w * (particles_prop - mean_t) ** 2)
    pred_mean = pt.sum(w * pt.exp(particles_prop))

    # Resample particles
    idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w) 
    particles_resampled = particles_prop[idx]
    weights_resampled = pt.ones((N_symbolic,)) / N_symbolic

    # Return flat tuple
    return particles_resampled, weights_resampled, mean_t, var_t, pred_mean

# first two are recurrent, rest are collected
outputs_info = [
    particles_init,
    weights_init,
    None,
    None,
    None
]

(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(
    fn=step,
    sequences=[ys_symbolic],
    outputs_info=outputs_info,
    non_sequences=[A_symbolic, Q_symbolic]
)

particle_filter_1d_predict_pytensor = pytensor.function(
    [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],
    [means_seq, vars_seq, preds_seq],
    updates=updates,
    no_default_updates=True,
    trust_input=True
)

particle_filter_1d_predict_pytensor_numba = pytensor.function(
    [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],
    [means_seq, vars_seq, preds_seq],
    updates=updates,
    no_default_updates=True,
    mode="NUMBA", 
    trust_input=True
)

In [52]:
key = jax.random.PRNGKey(0)
T = 300
A = 0.95
Q = 0.05
rng = np.random.RandomState(1)

target_mean = 10.0
latent_var = Q / (1 - A**2)
x0_mean = np.log(target_mean) - 0.5 * latent_var
x0_std = 1.0

# Simulate latent
x = np.zeros(T)
x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean
for t in range(1, T):
    x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)

ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)

In [53]:
nonlinear_kalman_filter_bench = Benchmarker(
    functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, jax.block_until_ready(particle_filter_1d_predict_jax), particle_filter_1d_predict_pytensor_numba,], 
    names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],
    number=5 # This takes a while to run reducing number of loops
)

In [54]:
nonlinear_kalman_filter_bench.run(
    inputs={
        "kalman_filter_inputs": {"A": A, "Q": Q, "x0_mean": x0_mean, "x0_std": x0_std, "ys": ys, "N": 2000},
    }
)
nonlinear_kalman_filter_bench.summary()

Unnamed: 0,Unnamed: 1,Loops,Min (us),Max (us),Mean (us),StdDev (us),Median (us),IQR (us),OPS (Kops/s),Samples
particle_filter_1d_predict_pytensor,kalman_filter_inputs,5,776326.149999,792344.283199,782741.538279,6040.856511,779343.408198,8662.6,0.001278,5
particle_filter_1d_predict_numba,kalman_filter_inputs,5,51164.908399,51900.6332,51432.376601,264.829443,51439.416601,276.491602,0.019443,5
particle_filter_1d_predict_jax,kalman_filter_inputs,5,17.358401,33723.75,7730.253572,12709.927696,19.724999,10139.645799,0.129362,7
particle_filter_1d_predict_pytensor_numba,kalman_filter_inputs,5,724575.624999,739611.624999,731249.66,5239.730301,731423.700001,6487.666798,0.001368,5


Slightly different estimates because I couldn't reproduce 1:1 

In [55]:
filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(
    A, Q, x0_mean, x0_std, ys, N=2000, seed=2
)

In [56]:
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=("Observation Predictions", "Latent State Estimation"),
    vertical_spacing=0.07,
    shared_xaxes=True
)

fig.add_traces(
    [
        go.Scatter(
            x = np.arange(T),
            y = ys,
            mode = "markers",
            marker_color = "cornflowerblue",
            name = "actuals"
        ),
        go.Scatter(
            x = np.arange(T),
            y = pred_means,
            mode = "lines",
            marker_color = "#eb8c34",
            name = "predicted mean"
        ),
        go.Scatter(
                name="", 
                x=np.arange(T), 
                y=pred_means + 2*jnp.sqrt(pred_means), 
                mode="lines", 
                marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="predicted mean 95% CI",
                showlegend=False
            ),
            go.Scatter(
                name="predicted mean 95% CI", 
                x=np.arange(T), 
                y=pred_means - 2*jnp.sqrt(pred_means), 
                mode="lines", marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="predicted mean 95% CI", 
                fill='tonexty', 
                fillcolor='rgba(235, 140, 52, 0.2)'
            ),
    ],
    rows=1, cols=1
)

fig.add_traces(
    [
        go.Scatter(
            x = np.arange(T),
            y = x,
            mode = "lines",
            marker_color = "cornflowerblue",
            name = "true latent state"
        ),
        go.Scatter(
            x = np.arange(T),
            y = filtered_means,
            mode = "lines",
            marker_color = "#eb8c34",
            name = "filtered state mean"
        ),
        go.Scatter(
                name="", 
                x=np.arange(T), 
                y=filtered_means + 2*jnp.sqrt(filtered_vars), 
                mode="lines", 
                marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="filtered state mean 95% CI",
                showlegend=False
            ),
            go.Scatter(
                name="filtered state mean 95% CI", 
                x=np.arange(T), 
                y=filtered_means - 2*jnp.sqrt(filtered_vars), 
                mode="lines", marker=dict(color="#eb8c34"), 
                line=dict(width=0), 
                legendgroup="filtered state mean 95% CI", 
                fill='tonexty', 
                fillcolor='rgba(235, 140, 52, 0.2)'
            ),
    ],
    rows=2, cols=1
)

for i, yaxis in enumerate(fig.select_yaxes(), 1):
    legend_name = f"legend{i}"
    fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor="top")}, showlegend=True)
    fig.update_traces(row=i, legend=legend_name)

fig.update_layout(height=1000, width=1200, template="plotly_dark")

fig.update_layout(
    legend1=dict(
        yanchor="top",
        y=1.0,
        xanchor="left",
        x=0,
        orientation="h"
    ),
    legend2=dict(
        yanchor="top",
        y=.465,
        xanchor="left",
        x=0,
        orientation="h"
    ),
    )