# Finding slowdowns in pypomp

Kevin here. At the moment, `pypomp` is slower than the code it was based on. I had a few hunches as to where the slowdowns were, and I think I've found them. To begin, here are the package imports.

In [1]:
import time
import pypomp
import unittest
import tracemalloc
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
import jax.scipy.special

from tqdm import tqdm
from pypomp.pfilter import _pfilter_internal

from functools import partial
import os
import csv
import jax
import jax.numpy as jnp
import pandas as pd
from pypomp.pomp_class import Pomp
from pypomp.model_struct import RInit, RProc, DMeas, RMeas, _time_interp
import jax.scipy.special as jspecial
import numpy as np

Run this cell if you're on CPU. I don't run it: this notebook was run on a computer with an Intel i9-13900K CPU, 128GB of RAM, and an NVIDIA RTX 3090 GPU.

In [2]:
# Set JAX's "jax_platform_name" option to "cpu" for the rest of the session.
# Skip this cell to keep whatever platform JAX selects by default.
jax.config.update("jax_platform_name", "cpu")

## `pypomp`
How slow is `pypomp` at the moment? We can see this below. A particle filtering operation with $J=10000$ particles takes about 7-8 seconds to compile, and post-compilation runs in about 1.44-1.45 seconds. `diffPomp` did the same in approximately 190ms post-compilation.

In [3]:
import time

# Baseline benchmark: ten particle-filter runs on the packaged Dacca model,
# each with 10000 particles and its own PRNG key.  The first entry of the
# resulting list also contains the one-off compilation cost.
d = pypomp.dacca()

elapses2 = []
for i in range(10):
    start = time.perf_counter()
    d.pfilter(10000, key=jax.random.PRNGKey(42 + i))
    end = time.perf_counter()
    elapsed = end - start
    elapses2.append(elapsed)
elapses2

[7.1690989409107715,
 1.442341866204515,
 1.4409919779282063,
 1.4435567690525204,
 1.4464609760325402,
 1.4510853479150683,
 1.442436127923429,
 1.4402216428425163,
 1.439497048035264,
 1.4424295420758426]

## Slowdowns in the Euler integration loop

Unfortunately, the slowdown seems to be in the Euler integration loop. I will illustrate this below. But first, we need to copy a lot of code from the `dacca.py` file (I aim to keep this as self-contained as possible).

In [4]:

def get_thetas(theta):
    """Unpack the flat parameter vector ``theta`` onto the natural scale.

    Positions 0-4, 7 and 8 are exponentiated, position 5 is passed through
    the logistic function, position 6 is divided by 100, and the slices
    ``9:15`` and ``15:21`` are returned untouched.  Two fixed constants
    (``k = 3`` and ``delta = 0.02``) are appended.

    Returns:
        The 13-tuple ``(gamma, m, rho, epsilon, omega, c, beta_trend, sigma,
        tau, bs, omegas, k, delta)``.
    """

    def _exp_at(position):
        return jnp.exp(theta[position])

    # gamma, m, rho, epsilon, omega occupy the first five slots.
    leading_rates = tuple(_exp_at(position) for position in range(5))
    clinical_fraction = jspecial.expit(theta[5])
    trend_slope = theta[6] / 100
    process_noise_sd = _exp_at(7)
    measurement_sd = _exp_at(8)
    seasonal_bs = theta[9:15]
    seasonal_omegas = theta[15:21]

    return (
        *leading_rates,
        clinical_fraction,
        trend_slope,
        process_noise_sd,
        measurement_sd,
        seasonal_bs,
        seasonal_omegas,
        3,  # k
        0.02,  # delta
    )


def transform_thetas(
    gamma, m, rho, epsilon, omega, c, beta_trend, sigma, tau, bs, omegas
):
    """Pack natural-scale parameters into one flat, transformed vector.

    ``gamma``, ``m``, ``rho``, ``epsilon``, ``omega``, ``sigma`` and ``tau``
    are log-transformed, ``c`` is logit-transformed and ``beta_trend`` is
    multiplied by 100.  ``bs`` and ``omegas`` are appended unchanged, in that
    order, after the nine scalar entries.
    """
    logged_rates = [jnp.log(value) for value in (gamma, m, rho, epsilon, omega)]
    scalar_entries = jnp.array(
        logged_rates
        + [
            jspecial.logit(c),
            beta_trend * 100,
            jnp.log(sigma),
            jnp.log(tau),
        ]
    )
    return jnp.concatenate([scalar_entries, bs, omegas])


# --- Natural-scale parameter values -----------------------------------------
# These are converted to the transformed scale by transform_thetas() below.
gamma = 20.8  # recovery rate
epsilon = 19.1  # rate of waning of immunity for severe infections
rho = 0  # rate of waning of immunity for inapparent infections
delta = 0.02  # baseline mortality rate
m = 0.06  # cholera mortality rate
c = jnp.array(1)  # fraction of infections that lead to severe infection
beta_trend = -0.00498  # slope of secular trend in transmission
bs = jnp.array([0.747, 6.38, -3.44, 4.23, 3.33, 4.55])  # seasonal transmission rates
sigma = 3.13  # 3.13 # 0.77 # environmental noise intensity
tau = 0.23  # measurement error s.d.
omega = jnp.exp(-4.5)
omegas = jnp.log(
    jnp.array([0.184, 0.0786, 0.0584, 0.00917, 0.000208, 0.0124])
)  # seasonal environmental reservoir parameters

# --- Input files -------------------------------------------------------------
# NOTE: this path is relative to the notebook's working directory.
data_dir = "../pypomp/pypomp/data/dacca"

dacca_path = os.path.join(data_dir, "dacca.csv")
covars_path = os.path.join(data_dir, "covars.csv")
covart_path = os.path.join(data_dir, "covart.csv")

# Observations: the first row is skipped (presumably a header); column 1 is
# used as the time index and column 2 as the "deaths" value.
with open(dacca_path, "r") as f:
    reader = csv.reader(f)
    next(reader)
    data = [(float(row[1]), float(row[2])) for row in reader]
    times, values = zip(*data)
    ys = pd.DataFrame(values, index=pd.Index(times), columns=pd.Index(["deaths"]))

# Covariate times: column 1 of every row after the first.
with open(covart_path, "r") as f:
    reader = csv.reader(f)
    next(reader)
    covart_index = [float(row[1]) for row in reader]
    covart_index = jnp.array(covart_index)

# Covariate values: every column except column 0, indexed by the covariate
# times read above.
with open(covars_path, "r") as f:
    reader = csv.reader(f)
    next(reader)
    covars_data = [[float(value) for value in row[1:]] for row in reader]
    covars = pd.DataFrame(covars_data, index=np.array(covart_index))

key = jax.random.key(111)
# Parameter names in the same order as the vector built by transform_thetas():
# nine scalars, then b1..b6, then omega1..omega6.
theta_names = (
    [
        "gamma",
        "m",
        "rho",
        "epsilon",
        "omega",
        "c",
        "beta_trend",
        "sigma",
        "tau",
    ]
    + [f"b{i}" for i in range(1, 7)]
    + [f"omega{i}" for i in range(1, 7)]
)
# Name -> transformed value.  With rho = 0 and c = 1 the log/logit entries
# for those two parameters are infinite.
theta = dict(
    zip(
        theta_names,
        transform_thetas(
            gamma, m, rho, epsilon, omega, c, beta_trend, sigma, tau, bs, omegas
        ).tolist(),
    )
)


#@partial(RInit, t0=1891.0)
def rinit(theta_, key, covars, t0=None):
    """Build the initial 8-component state vector.

    The population size is read from ``covars[2]`` and multiplied by fixed
    initial fractions.  ``theta_``, ``key`` and ``t0`` are accepted but not
    used.

    Returns:
        ``jnp.array([S, I, Y, Mn, R1, R2, R3, count])`` where ``Mn`` and
        ``count`` start at 0.
    """
    population = covars[2]
    # Initial fraction of the population in each compartment.
    initial_fraction = {
        "S": 0.621,
        "I": 0.378,
        "Y": 0,
        "R1": 0.000843,
        "R2": 0.000972,
        "R3": 1.16e-07,
    }
    sizes = {
        compartment: population * fraction
        for compartment, fraction in initial_fraction.items()
    }
    return jnp.array(
        [
            sizes["S"],
            sizes["I"],
            sizes["Y"],
            0,  # Mn
            sizes["R1"],
            sizes["R2"],
            sizes["R3"],
            0,  # count
        ]
    )


#@partial(RProc, dt=1 / 240, step_type="euler", accumvars=(3,))
def rproc(X_, theta_, key, covars, t, dt):
    """Advance the state by a single Euler step of length ``dt``.

    Args:
        X_: state vector laid out as ``[S, I, Y, deaths, *pts, count]``.
        theta_: transformed parameter vector, decoded with ``get_thetas``.
        key: PRNG key; one normal draw is taken from its first split.
        covars: covariate row ``[trend, dpopdt, pop, *seas]``.
        t: current time (not used in the computation).
        dt: step length.

    Returns:
        The updated state in the same layout.  All components except
        ``count`` are clipped below at 0, and ``count`` is increased by one
        if any of them was negative before clipping.
    """
    susceptible, infected, inapparent, cum_deaths = X_[0], X_[1], X_[2], X_[3]
    recovered_stages = X_[4:-1]
    violation_count = X_[-1]

    trend, dpopdt, pop = covars[0], covars[1], covars[2]
    seas = covars[3:]

    (
        gamma,
        deltaI,
        _rho,
        eps,
        _omega,
        _clin,
        beta_trend,
        sd_beta,
        _tau,
        bs,
        omegas,
        _nrstage,
        delta,
    ) = get_thetas(theta_)
    # The decoded values of these three are deliberately overridden.
    nrstage = 3
    clin = 1  # HARDCODED SEIR
    rho = 0  # HARDCODED INAPPARENT INFECTIONS

    stage_rate = eps * nrstage

    # Transmission rate and reservoir term for the current covariate row.
    beta = jnp.exp(beta_trend * trend + jnp.dot(bs, seas))
    omega = jnp.exp(jnp.dot(omegas, seas))

    # Brownian increment with standard deviation sqrt(dt).
    noise_key, _ = jax.random.split(key)
    dw = jax.random.normal(noise_key) * jnp.sqrt(dt)

    # Flows between compartments.
    eff_infected = infected / pop
    births = dpopdt + delta * pop
    infected_deaths = delta * infected
    disease_deaths = deltaI * infected
    inapparent_deaths = delta * inapparent
    wanings = rho * inapparent
    recovered_deaths = recovered_stages * delta
    # passages[0] leaves I; passages[1:] leave each recovered stage.
    passages = (
        jnp.zeros(nrstage + 1)
        .at[0]
        .set(gamma * infected)
        .at[1:]
        .set(recovered_stages * stage_rate)
    )
    infections = (omega + (beta + sd_beta * dw / dt) * eff_infected) * susceptible
    susceptible_deaths = delta * susceptible

    # Euler updates.
    susceptible = susceptible + (
        births - infections - susceptible_deaths + passages[nrstage] + wanings
    ) * dt
    infected = infected + (
        clin * infections - disease_deaths - infected_deaths - passages[0]
    ) * dt
    inapparent = inapparent + (
        (1 - clin) * infections - inapparent_deaths - wanings
    ) * dt
    recovered_stages = recovered_stages + (
        passages[:-1] - passages[1:] - recovered_deaths
    ) * dt
    cum_deaths = cum_deaths + disease_deaths * dt

    # Record whether anything went negative, then clip at zero.
    advanced = jnp.hstack(
        [jnp.array([susceptible, infected, inapparent, cum_deaths]), recovered_stages]
    )
    violation_count = violation_count + jnp.any(advanced < 0)

    return jnp.hstack([jnp.clip(advanced, 0), jnp.array([violation_count])])


def dmeas_helper(y, deaths, v, tol, ltol):
    """Return the floored Gaussian log-density of ``y``.

    The normal log-pdf with mean ``deaths`` and scale ``v + tol`` is combined
    with ``ltol`` through ``logaddexp``, i.e. ``log(pdf + exp(ltol))``.
    """
    gaussian_logpdf = jax.scipy.stats.norm.logpdf(y, loc=deaths, scale=v + tol)
    return jnp.logaddexp(gaussian_logpdf, ltol)


def dmeas_helper_tol(y, deaths, v, tol, ltol):
    """Return ``ltol`` as a length-1 array; the other arguments are ignored.

    Takes the same arguments as ``dmeas_helper`` so both can be used as the
    two branches of one ``jax.lax.cond`` call.
    """
    floor_value = jnp.array([ltol])
    return floor_value


#@DMeas
def dmeas(Y_, X_, theta_, covars=None, t=None):
    """Log measurement density of ``Y_`` given the state ``X_``.

    The accumulated deaths ``X_[3]`` are the mean and ``tau * deaths`` is the
    scale of a Gaussian whose log-density is floored at ``log(1e-18)``.  When
    that scale is not finite, or the counter ``X_[-1]`` is positive, the
    floor itself is returned.  ``covars`` and ``t`` are not used.
    """
    tol = 1.0e-18
    ltol = jnp.log(tol)

    deaths, count = X_[3], X_[-1]
    tau = get_thetas(theta_)[8]
    v = tau * deaths

    # Fall back to the floor when the scale is not finite or the counter in
    # the last state slot is positive.
    use_floor = jnp.logical_or(jnp.logical_not(jnp.isfinite(v)), count > 0)
    return jax.lax.cond(
        use_floor,
        dmeas_helper_tol,
        dmeas_helper,
        Y_,
        deaths,
        v,
        tol,
        ltol,
    )


#@partial(RMeas, ydim=1)
def rmeas(X_, theta_, key, covars=None, t=None):
    """Draw one simulated observation for the state ``X_``.

    Returns a normal draw with mean ``X_[3]`` (accumulated deaths) and
    standard deviation ``tau * X_[3]``.  ``covars`` and ``t`` are not used.
    """
    deaths = X_[3]
    tau = get_thetas(theta_)[8]
    noise_scale = tau * deaths
    standard_draw = jax.random.normal(key)
    return standard_draw * noise_scale + deaths


def dacca():
    """Build a ``Pomp`` object from the model pieces defined in this notebook.

    Uses the module-level ``rinit``, ``rproc``, ``dmeas``, ``rmeas``, ``ys``,
    ``theta`` and ``covars``.
    """
    return Pomp(
        rinit=rinit,
        rproc=rproc,
        dmeas=dmeas,
        rmeas=rmeas,
        ys=ys,
        theta=theta,
        covars=covars,
    )


## Taking steps in `rproc()` instead of Euler timesteps outside of it

Here is an implementation of `rproc()` that looks a lot like what was in `diffPomp`. I jury-rig this by reusing the `count` variable as an index, looping through the covariate table accordingly with 20 iterations per `rprocess()` call. This brings the Euler logic into `rproc()` itself. 

In [5]:
def rproc_step(X_, theta_, key, ctimes, covariates, t1, t2):
    """Advance the state through 20 Euler substeps inside a single call.

    This variant carries the Euler loop itself instead of being called once
    per substep.  The last state slot (``count``) doubles as a cursor into
    the covariate table: substep ``i`` reads row ``count + i`` of
    ``covariates``.

    Args:
        X_: state vector laid out as ``[S, I, Y, deaths, *pts, count]``.
        theta_: transformed parameter vector, decoded with ``get_thetas``.
        key: PRNG key; it is split once per substep.
        ctimes: covariate times (accepted for interface compatibility, unused).
        covariates: covariate table, one row per Euler substep.
        t1, t2: interval endpoints (accepted for interface compatibility, unused).

    Returns:
        The updated state in the same layout.

    Notes:
        * The step length is read from the module-level ``dt``, which must be
          defined before this function is called.
        * ``count`` is increased by one per substep and by one more whenever
          a state component goes negative, so it is simultaneously a row
          cursor and a violation counter.
          NOTE(review): every violation therefore shifts the cursor by one
          row on the next call, and ``count`` is always positive after the
          first call — check how downstream code interprets ``count > 0``.
    """
    S = X_[0]
    I = X_[1]
    Y = X_[2]
    deaths = X_[3]
    pts = X_[4:-1]
    count = X_[-1]

    (
        gamma,
        deltaI,
        rho,
        eps,
        omega,
        clin,
        beta_trend,
        sd_beta,
        tau,
        bs,
        omegas,
        nrstage,
        delta,
    ) = get_thetas(theta_)
    nrstage = 3
    clin = 1  # HARDCODED SEIR
    rho = 0  # HARDCODED INAPPARENT INFECTIONS
    std = jnp.sqrt(dt)  # `dt` is the module-level Euler step length

    neps = eps * nrstage  # rate
    passages = jnp.zeros(nrstage + 1)

    # Covariate row at which this call starts.
    t = count.astype(int)

    for i in range(20):
        # Advance through the covariate table with the substep.  Indexing
        # with the bare `t` would reuse the same row for all 20 substeps.
        covars = covariates[t + i]

        trend = covars[0]
        dpopdt = covars[1]
        pop = covars[2]
        seas = covars[3:]
        # Get current time step values
        beta = jnp.exp(beta_trend * trend + jnp.dot(bs, seas))
        omega = jnp.exp(jnp.dot(omegas, seas))

        subkey, key = jax.random.split(key)
        dw = jax.random.normal(subkey) * std

        effI = I / pop
        births = dpopdt + delta * pop
        passages = passages.at[0].set(gamma * I)
        ideaths = delta * I
        disease = deltaI * I
        ydeaths = delta * Y
        wanings = rho * Y

        rdeaths = pts * delta
        passages = passages.at[1:].set(pts * neps)

        infections = (omega + (beta + sd_beta * dw / dt) * effI) * S
        sdeaths = delta * S

        S += (births - infections - sdeaths + passages[nrstage] + wanings) * dt
        I += (clin * infections - disease - ideaths - passages[0]) * dt
        Y += ((1 - clin) * infections - ydeaths - wanings) * dt

        pts = pts + (passages[:-1] - passages[1:] - rdeaths) * dt

        deaths = deaths + disease * dt

        # Flag a negative state before clipping it away.
        count = count + jnp.any(jnp.hstack([jnp.array([S, I, Y, deaths]), pts]) < 0)

        S = jnp.clip(S, 0)
        I = jnp.clip(I, 0)
        Y = jnp.clip(Y, 0)
        pts = jnp.clip(pts, 0)
        deaths = jnp.clip(deaths, 0)

        # One covariate row consumed.
        count += 1

    return jnp.hstack([jnp.array([S, I, Y, deaths]), pts, jnp.array([count])])



In [6]:
# Convert the pandas / Python inputs into JAX arrays for _pfilter_internal().
covars_arr = jnp.array(covars)
ctimes_arr = jnp.array(covars.index)
ys_arr = jnp.array(ys)
times_arr = jnp.array(times)
theta_arr = jnp.array(list(theta.values()))

# Settings that mirror the values in the commented-out decorators above the
# model functions (t0 for rinit; dt, step_type and accumvars for rproc; ydim
# for rmeas).  rproc_step() reads `dt` from this module-level name.
t0 = 1891.0
dt=1 / 240
step_type="euler"
accumvars=(3,)
ydim = 1
# Batch dmeas over axis 0 of its second argument (the states).
dmeasure = jax.vmap(dmeas, (None, 0, None, None, None))
# Batch rproc_step over axis 0 of the states and of the keys; everything
# else is shared.
rprocess_hacked = jax.vmap(rproc_step, (0, None, 0, None, None, None, None))

# Batch rinit over axis 0 of its second argument (the keys).
rinitializer = jax.vmap(rinit, (None, 0, None, None))

This yields the following speedup. Each post-compilation iteration only takes 169ms now! This is faster than `diffPomp` was. 

In [7]:
import time

# Time ten particle-filter runs that use the in-rproc Euler loop
# (`rprocess_hacked`).  The first entry also contains the compilation cost.
elapses_rproc_hacked = []
logliks = []
for i in range(10):
    start = time.perf_counter()
    loglik = _pfilter_internal(
        theta_arr, t0, times_arr, ys_arr, 10000,
        rinitializer, rprocess_hacked, dmeasure,
        ctimes_arr, covars_arr, -1,
        jax.random.PRNGKey(42 + i),
    )
    # JAX dispatches work asynchronously: wait for the result before stopping
    # the clock, otherwise only the dispatch may be timed.
    jax.block_until_ready(loglik)
    end = time.perf_counter()
    logliks.append(loglik)
    elapses_rproc_hacked.append(end - start)
elapses_rproc_hacked

[11.551992549095303,
 0.17272041691467166,
 0.1708134498912841,
 0.16952022584155202,
 0.16943312995135784,
 0.1699791399296373,
 0.16962700197473168,
 0.16964941378682852,
 0.16998983779922128,
 0.16927395202219486]

## Can we keep the Euler logic?

This raises the question of whether we can keep the Euler logic. Fortunately, the answer is yes. I had a hunch that a lot of the slowdown might be due to the interpolation of covariates. This is because `_interp_covars()` (line 179 of `internal_functions.py`) calls `jnp.searchsorted()` at every Euler timestep. `jnp.searchsorted()` is implemented via binary search, with a per-call time complexity of $O(\log n)$, where $n$ is the number of Euler timesteps in total. Below, I remove it by writing `covars_t = covars[(t/dt).astype(int)]` instead of `covars_t = _interp_covars()`.

In [8]:
from typing import Callable

def _time_interp_hacked(
    rproc: Callable,
    step_type: str,
    dt: float | None,
    accumvars: tuple[int, ...] | None,
) -> Callable:
    def _interp_helper(
        i: int,
        inputs: tuple[jax.Array, jax.Array, jax.Array, float],
        ctimes: jax.Array,
        covars: jax.Array,
        dt: float,
    ) -> tuple[jax.Array, jax.Array, jax.Array, float]:
        X_, theta_, key, t = inputs
        covars_t = covars[(t/dt).astype(int)]
        X_ = rproc(X_, theta_, key, covars_t, t, dt)
        t = t + dt
        return (X_, theta_, key, t)

    def _num_onestep_steps(t1: float, t2: float, dt: float) -> tuple[int, float]:
        return 1, t2 - t1

    def _num_euler_steps(
        t1: float, t2: float, dt: float
    ) -> tuple[jax.Array, jax.Array]:
        tol = jnp.sqrt(jnp.finfo(float).eps)

        nstep = jnp.ceil((t2 - t1) / dt / (1 + tol)).astype(int)
        dt2 = (t2 - t1) / nstep

        check1 = t1 + dt >= t2
        nstep = jnp.where(check1, 1, nstep)
        dt2 = jnp.where(check1, t2 - t1, dt2)

        check2 = t1 >= t2
        nstep = jnp.where(check2, 0, nstep)
        dt2 = jnp.where(check2, 0.0, dt2)

        return nstep, dt2

    num_step_func = None
    match step_type:
        case "onestep":
            num_step_func = _num_onestep_steps
        case "euler":
            num_step_func = _num_euler_steps
    if num_step_func is None:
        raise ValueError("step_type must be either 'onestep' or 'euler'")

    def _rproc_interp(
        X_: jax.Array,
        theta_: jax.Array,
        key: jax.Array,
        ctimes: jax.Array,
        covars: jax.Array,
        t1: float,
        t2: float,
        dt: float | None,
        accumvars: tuple[int, ...] | None,
        num_step_func: Callable,
    ) -> jax.Array:
        X_ = jnp.where(accumvars is not None, X_.at[:, accumvars].set(0), X_)
        nstep, dt2 = num_step_func(t1, t2, dt=dt)
        interp_helper2 = partial(_interp_helper, ctimes=ctimes, covars=covars, dt=dt2)

        X_, theta_, key, t = jax.lax.fori_loop(
            lower=0,
            upper=nstep,
            body_fun=interp_helper2,
            init_val=(X_, theta_, key, t1),
        )
        return X_

    return partial(
        _rproc_interp, dt=dt, accumvars=accumvars, num_step_func=num_step_func
    )


# Euler stepping via fori_loop, with direct covariate indexing.
rprocess_interp_hacked_loop = _time_interp_hacked(
    jax.vmap(rproc, (0, None, 0, None, None, None)),
    step_type=step_type,
    dt=dt,
    accumvars=accumvars,
)

import time

# Time ten particle-filter runs.  The first entry also contains the
# compilation cost.
elapses_interp_hacked_loop = []
logliks = []
for i in range(10):
    start = time.perf_counter()
    loglik = _pfilter_internal(
        theta_arr, t0, times_arr, ys_arr, 10000,
        rinitializer, rprocess_interp_hacked_loop, dmeasure,
        ctimes_arr, covars_arr, -1,
        jax.random.PRNGKey(42 + i),
    )
    # JAX dispatches work asynchronously: wait for the result before stopping
    # the clock, otherwise only the dispatch may be timed.
    jax.block_until_ready(loglik)
    end = time.perf_counter()
    logliks.append(loglik)
    elapses_interp_hacked_loop.append(end - start)
elapses_interp_hacked_loop

[6.132768925977871,
 0.3786546681076288,
 0.37912258016876876,
 0.37852500984445214,
 0.3786205898504704,
 0.37903198995627463,
 0.37919161398895085,
 0.3786817258223891,
 0.37845401000231504,
 0.38014819705858827]

This brings each post-compilation call down to about 380ms. Not bad! This takes care of the lion's share (about one second) of the slowdown.

## Removing the call to `fori_loop()`

It is known that `jax.lax.fori_loop` leads to faster compile times, which can be advantageous for large loops, but it also leads to slower post-compilation execution. Furthermore, it compiles the function, and it is generally best not to compile anything except at the highest level. So, when there are only 20 iterations, it can be advantageous not to compile the loop and to write it as a plain Python loop instead.

In [9]:
from typing import Callable

def _time_interp_hacked_loop(
    rproc: Callable,
    step_type: str,
    dt: float | None,
    accumvars: tuple[int, ...] | None,
) -> Callable:
    def _interp_helper(
        i: int,
        inputs: tuple[jax.Array, jax.Array, jax.Array, float],
        ctimes: jax.Array,
        covars: jax.Array,
        dt: float,
    ) -> tuple[jax.Array, jax.Array, jax.Array, float]:
        X_, theta_, key, t = inputs
        covars_t = covars[(t/dt).astype(int)]
        X_ = rproc(X_, theta_, key, covars_t, t, dt)
        t = t + dt
        return (X_, theta_, key, t)

    def _num_onestep_steps(t1: float, t2: float, dt: float) -> tuple[int, float]:
        return 1, t2 - t1

    def _num_euler_steps(
        t1: float, t2: float, dt: float
    ) -> tuple[jax.Array, jax.Array]:
        tol = jnp.sqrt(jnp.finfo(float).eps)

        nstep = jnp.ceil((t2 - t1) / dt / (1 + tol)).astype(int)
        dt2 = (t2 - t1) / nstep

        check1 = t1 + dt >= t2
        nstep = jnp.where(check1, 1, nstep)
        dt2 = jnp.where(check1, t2 - t1, dt2)

        check2 = t1 >= t2
        nstep = jnp.where(check2, 0, nstep)
        dt2 = jnp.where(check2, 0.0, dt2)

        return nstep, dt2

    num_step_func = None
    match step_type:
        case "onestep":
            num_step_func = _num_onestep_steps
        case "euler":
            num_step_func = _num_euler_steps
    if num_step_func is None:
        raise ValueError("step_type must be either 'onestep' or 'euler'")

    def _rproc_interp(
        X_: jax.Array,
        theta_: jax.Array,
        key: jax.Array,
        ctimes: jax.Array,
        covars: jax.Array,
        t1: float,
        t2: float,
        dt: float | None,
        accumvars: tuple[int, ...] | None,
        num_step_func: Callable,
    ) -> jax.Array:
        X_ = jnp.where(accumvars is not None, X_.at[:, accumvars].set(0), X_)
        nstep, dt2 = num_step_func(t1, t2, dt=dt)
        interp_helper2 = partial(_interp_helper, ctimes=ctimes, covars=covars, dt=dt2)

        t = t1
        for i in range(20):
            X_, theta_, key, t = interp_helper2(i, [X_, theta_, key, t])
        # X_, theta_, key, t = jax.lax.fori_loop(
        #     lower=0,
        #     upper=nstep,
        #     body_fun=interp_helper2,
        #     init_val=(X_, theta_, key, t1),
        # )
        return X_

    return partial(
        _rproc_interp, dt=dt, accumvars=accumvars, num_step_func=num_step_func
    )



# Euler stepping with a plain Python loop in place of fori_loop.
rprocess_interp_hacked_loop = _time_interp_hacked_loop(
    jax.vmap(rproc, (0, None, 0, None, None, None)),
    step_type=step_type,
    dt=dt,
    accumvars=accumvars,
)

import time

# Time ten particle-filter runs.  The first entry also contains the
# compilation cost.
elapses_interp_hacked_loop = []
logliks = []
for i in range(10):
    start = time.perf_counter()
    loglik = _pfilter_internal(
        theta_arr, t0, times_arr, ys_arr, 10000,
        rinitializer, rprocess_interp_hacked_loop, dmeasure,
        ctimes_arr, covars_arr, -1,
        jax.random.PRNGKey(42 + i),
    )
    # JAX dispatches work asynchronously: wait for the result before stopping
    # the clock, otherwise only the dispatch may be timed.
    jax.block_until_ready(loglik)
    end = time.perf_counter()
    logliks.append(loglik)
    elapses_interp_hacked_loop.append(end - start)
elapses_interp_hacked_loop

[9.565887465141714,
 0.1390279410406947,
 0.1389743739273399,
 0.13877295493148267,
 0.13896874198690057,
 0.1390167858917266,
 0.13913850905373693,
 0.1389609829057008,
 0.13993217796087265,
 0.13947042194195092]

Now a call to `_pfilter_internal()` takes only about 139ms. Removing `fori_loop` led to a further speedup of roughly 240ms. This is now even faster than `diffPomp` was.