In [1]:
try:
    import smolgp
except ImportError:
    %pip install -q smolgp

try:
    import tinygp
except ImportError:
    %pip install -q tinygp
    
import jax
key = jax.random.PRNGKey(0)
jax.config.update("jax_enable_x64", True)

(parallel)=

# Parallelized GP solvers on GPU

In {ref}`introssm`, we saw how the traditional Kalman filter and RTS smoother sequentially solve for the conditional filtered and smoothed distributions at each data point. It turns out that these updates can be expressed through associative operators, which lets the Kalman/RTS algorithms be reframed as an [all-prefix-sums](https://en.wikipedia.org/wiki/Prefix_sum) problem. Parallel-scan algorithms solve such problems efficiently, in
\begin{align}
\mathcal{O}(N/T + \log T)
\end{align}
runtime complexity, for $N$ data points and $T$ parallel workers. When $N \gg T$ (the usual scenario), this gives the familiar parallel speedup factor of $T$; when $T \gtrsim N$, the scaling can be as good as $\mathcal{O}(\log N)$.
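
To make the all-prefix-sums idea concrete, here is a minimal sketch using `jax.lax.associative_scan` on a scalar linear recurrence $x_k = a_k x_{k-1} + b_k$. This is a toy stand-in for the filtering recursion, not the `smolgp` implementation; the `combine` helper and the coefficient values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Compose two affine maps x -> a * x + b, applying `left` first.
    # Composition of affine maps is associative, which is what the scan needs.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

# Toy coefficients (assumed for illustration only)
a = jnp.array([0.5, 0.9, 1.1, 0.7, 0.8])
b = jnp.array([1.0, -0.5, 0.3, 0.2, 0.0])
x0 = 2.0

# Parallel scan: all prefix compositions are computed together
A, B = jax.lax.associative_scan(combine, (a, b))
x_parallel = A * x0 + B

# Sequential reference: one step at a time
x = x0
xs = []
for a_k, b_k in zip(a, b):
    x = a_k * x + b_k
    xs.append(x)
x_sequential = jnp.stack(xs)

print(jnp.allclose(x_parallel, x_sequential))
```

Both routes give the same states; the scan only needs the pairwise `combine` to be associative, so the work can be split across workers and merged in a tree.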

[Särkkä and García-Fernández (2021)](https://ieeexplore.ieee.org/document/9013038) introduced this parallel method, which was extended to the case of integrated measurements in [Yaghoobi and Särkkä (2025)](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10804629), though those authors use a different framework to handle the integrations than we do. The `smolgp` method of augmenting the state space (see {ref}`integrated`) instead lets us use an elaboration of the [Särkkä and García-Fernández (2021)](https://ieeexplore.ieee.org/document/9013038) method, which is described in Section 3.2.4 of [Rubenzahl and Hattori et al. (2026)](https://arxiv.org/abs/2601.02527).

:::{admonition} **Running on GPU** 
:class: important
The parallel solvers are only significantly faster than their sequential counterparts when run on a GPU (see {ref}`benchmarks`). Make sure your hardware supports [`jax[cuda]`](https://github.com/google/jax/#installation), which you can install alongside `smolgp` with 
```bash
uv add smolgp[cuda]
```
Use `uv add smolgp[cuda12]` or `uv add smolgp[cuda13]` instead to target a specific CUDA version.
:::
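
Before comparing solvers, it can help to confirm that JAX actually sees an accelerator. The following is a minimal check using only standard JAX calls (nothing here is specific to `smolgp`):

```python
import jax

# Name of the platform JAX will use by default (typically "cpu", "gpu", or "tpu")
backend = jax.default_backend()
# All devices visible on the default backend
devices = jax.devices()

print(f"Default backend: {backend}")
print(f"Devices: {devices}")

if backend != "gpu":
    print("No GPU backend detected; the parallel solvers will run on this backend instead.")
```

If the backend is reported as `cpu` on a machine with a GPU, the CUDA-enabled `jax` build is probably not installed in the active environment.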

In [2]:
import jax.numpy as jnp
from scipy.interpolate import make_smoothing_spline

# True kernel for sampling the underlying process
kernel_tiny = tinygp.kernels.quasisep.SHO(omega=2*jnp.pi/50, quality=5.0, sigma=1.0)

def get_true_process(true_kernel, tmin=0, tmax=1000, dt=1):
    t = jnp.arange(tmin, tmax, dt)
    true_gp = tinygp.GaussianProcess(true_kernel, t)
    # NOTE: gp.sample adds small random noise for numerical stability
    y_sample = true_gp.sample(key=jax.random.PRNGKey(32)) 
    f = make_smoothing_spline(t, y_sample, lam=dt/6)
    return t, f

## True process
t_true, f = get_true_process(kernel_tiny, tmin=0, tmax=1000, dt=1)
y_true = f(t_true)

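## Mock data
# NOTE: `key` is reused for the time stamps and the noise draw below; use
# jax.random.split(key) if statistically independent draws are needed.
t_train = jnp.sort(jax.random.uniform(key, (50,), minval=0, maxval=1000))
yerr = 0.75  # homoscedastic measurement uncertainty
y_train = f(t_train) + yerr * jax.random.normal(key, t_train.shape)
yerr_train = jnp.full_like(t_train, yerr)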

To use the parallel solver, simply build the GP object with `solver=smolgp.solvers.ParallelStateSpaceSolver`:

In [8]:
gp_smol = smolgp.GaussianProcess(
    kernel=smolgp.kernels.SHO(omega=2*jnp.pi/50, quality=5.0, sigma=1.0),
    X=t_train,
    diag=yerr_train**2,
    solver=smolgp.solvers.ParallelStateSpaceSolver,
)
gp_smol.log_probability(y_train)

Array(-74.10304301, dtype=float64)

For comparison, `tinygp` returns the same log-probability for the same kernel and data:

In [7]:
gp_tiny = tinygp.GaussianProcess(kernel=kernel_tiny, X=t_train, diag=yerr_train**2)
gp_tiny.log_probability(y_train)

Array(-74.10304301, dtype=float64)

:::{tip}
For integrated data, instead use `solver=smolgp.solvers.ParallelIntegratedStateSpaceSolver`.
:::