In [3]:
import numpy as np
from jaxtyping import Float, Array
from typing import Callable, NamedTuple, Union, Tuple, Any
from functools import partial
import chex
import optax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import lax, jacfwd, vmap, grad, jit
from jax.tree_util import tree_map, tree_reduce
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN

import matplotlib.pyplot as plt
import matplotlib.cm as cm

from dataclasses import dataclass
from itertools import cycle





In [2]:
import torch
print(torch.__version__)
import torch.func


2.0.0


In [4]:
def allclose(u, v):
    # we cast to numpy so we can compare pytorch and jax
    return np.allclose(np.array(u), np.array(v), atol=1e-3)

# Dynamax version

In [5]:
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

In [29]:

def make_linreg_data_small():
    n_obs = 21
    x = jnp.linspace(0, 20, n_obs)
    X = x[:, None] # reshape to (T,1)
    y = jnp.array(
        [2.486, -0.303, -4.053, -4.336, -6.174, -5.604, -3.507, -2.326, -4.638, -0.233, -1.986, 1.028, -2.264,
        -0.451, 1.167, 6.652, 4.145, 5.268, 6.34, 9.626, 14.784])
    Y = y[:, None] # reshape to (T,1)
    return X, Y

def make_linreg_data(N, D):
    n_obs = N
    key = jr.PRNGKey(0)
    keys = jr.split(key, 3)
    X = jr.normal(keys[0], (N, D))
    w = jr.normal(keys[1], (D, 1))
    y = X @ w + 0.1*jr.normal(keys[2], (N, 1))
    return X, y

def make_linreg_prior(D):
    obs_var = 0.1
    mu0 = jnp.zeros(D)
    Sigma0 = jnp.eye(D) * 1
    return (obs_var, mu0, Sigma0)

def batch_bayes(X,Y):
    N, D = X.shape
    X1 = jnp.column_stack((jnp.ones(N), X))  # Include column of 1s
    y = Y[:,0] # extract column vector
    (obs_var, mu0, Sigma0) = make_linreg_prior(D+1)
    posterior_prec = jnp.linalg.inv(Sigma0) + X1.T @ X1 / obs_var
    cov_batch = jnp.linalg.inv(posterior_prec)
    b = jnp.linalg.inv(Sigma0) @ mu0 + X1.T @ y / obs_var
    mu_batch = jnp.linalg.solve(posterior_prec, b)
    return mu_batch, cov_batch



In [49]:
X, Y = make_linreg_data(N=200, D=50)
N, D = X.shape
X1 = jnp.column_stack((jnp.ones(N), X))  # Include column of 1s
(obs_var, mu0, Sigma0) = make_linreg_prior(D+1)
nfeatures = X1.shape[1]
F = jnp.eye(nfeatures) # dynamics = I
Q = jnp.zeros((nfeatures, nfeatures))  # No parameter drift.
R = jnp.ones((1, 1)) * obs_var





In [50]:
# dynamax linear gaussian ssm: stores posterior covariance at each time step
lgssm = LinearGaussianSSM(state_dim = nfeatures, emission_dim = 1, input_dim = 0)

params, _ = lgssm.initialize(
    initial_mean=mu0,
    initial_covariance=Sigma0,
    dynamics_weights=F,
    dynamics_covariance=Q,
    emission_weights=X1[:, None, :], # (t, 1, D) where D = num input features
    emission_covariance=R,
    )
lgssm_posterior = lgssm.filter(params, Y) 


mu_kf = lgssm_posterior.filtered_means[-1]
cov_kf = lgssm_posterior.filtered_covariances[-1]
mu_batch, cov_batch = batch_bayes(X,Y)

mu_kf = lgssm_posterior.filtered_means[-1]
cov_kf = lgssm_posterior.filtered_covariances[-1]
mu_batch, cov_batch = batch_bayes(X,Y)
assert allclose(mu_batch, mu_kf)
assert allclose(cov_batch, cov_kf)

# JAX version

In [54]:
def predict(m, S, F, Q):
    mu_pred = F @ m 
    Sigma_pred = F @ S @ F.T + Q
    return mu_pred, Sigma_pred

def psd_solve(A,b):
    A = A + 1e-6
    return jnp.linalg.solve(A,b)

def condition_on(m, P, H, R, y):
    S = R + H @ P @ H.T
    K = psd_solve(S, H @ P).T
    Sigma_cond = P - K @ S @ K.T
    mu_cond = m + K @ (y - H @ m)
    return mu_cond, Sigma_cond


def kf(params, emissions, return_covs=False):
    F, Q, R = params['F'], params['Q'], params['R']
    def step(carry, t):
        ll, pred_mean, pred_cov = carry
        H = params['Ht'][t]
        y = emissions[t]
        ll += MVN(H @ pred_mean, H @ pred_cov @ H.T + R).log_prob(y)
        filtered_mean, filtered_cov = condition_on(pred_mean, pred_cov, H, R, y)
        pred_mean, pred_cov = predict(filtered_mean, filtered_cov, F, Q)
        carry = (ll, pred_mean, pred_cov)
        if return_covs:
            return carry, (filtered_mean, filtered_cov)
        else:
            return carry, (filtered_mean, None)     
    
    num_timesteps = len(emissions)
    carry = (0.0, params['mu0'], params['Sigma0'])
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(step, carry, jnp.arange(num_timesteps))
    return ll, filtered_means, filtered_covs





In [55]:

Ht = X1[:, None, :] # (T,D) -> (T,1,D), H[t]'z = (b w)' (1 x)
param_dict = {'mu0': mu0, 'Sigma0': Sigma0, 'F': F, 'Q': Q, 'R': R, 'Ht': Ht}

return_covs = True
ll, kf_means, kf_covs = kf(param_dict, Y, return_covs) 
print(kf_means.shape, kf_covs)

# compare to dynamax

assert allclose(ll, lgssm_posterior.marginal_loglik)
assert allclose(kf_means, lgssm_posterior.filtered_means)
if return_covs:
    assert allclose(kf_covs, lgssm_posterior.filtered_covariances)


(200, 51) None
CPU times: user 311 ms, sys: 4.42 ms, total: 316 ms
Wall time: 309 ms


# Torch version

In [57]:
def predict_pt(m, S, F, Q):
    mu_pred = F @ m 
    Sigma_pred = F @ S @ F.T + Q
    return mu_pred, Sigma_pred


def psd_solve_pt(A,b):
    A = A + 1e-6
    return torch.linalg.solve(A,b)

def condition_on_pt(m, P, H, R, y):
    S = R + H @ P @ H.T
    K = psd_solve_pt(S, H @ P).T
    Sigma_cond = P - K @ S @ K.T
    mu_cond = m + K @ (y - H @ m)
    return mu_cond, Sigma_cond

def kf_pt(params, emissions, return_covs):
    F, Q, R = params['F'], params['Q'], params['R']
    def step(carry, t):
        ll, pred_mean, pred_cov = carry
        H = params['Ht'][t]
        y = emissions[t]
        #ll += MVN(H @ pred_mean, H @ pred_cov @ H.T + R).log_prob(y)
        filtered_mean, filtered_cov = condition_on_pt(pred_mean, pred_cov, H, R, y)
        pred_mean, pred_cov = predict_pt(filtered_mean, filtered_cov, F, Q)
        carry = (ll, pred_mean, pred_cov)
        if return_covs:
            return carry, (filtered_mean, filtered_cov)
        else:
            return carry, filtered_mean
    
    num_timesteps = len(emissions)
    D = len(params['mu0'])
    filtered_means = torch.zeros((num_timesteps, D))
    if return_covs:
        filtered_covs = torch.zeros((num_timesteps, D, D))
    else:
        filtered_covs = None
    ll = 0
    carry = (ll, params['mu0'], params['Sigma0'])
    for t in range(num_timesteps):
        if return_covs:
            carry, (filtered_means[t], filtered_covs[t]) = step(carry, t)
        else:
            carry, filtered_means[t] = step(carry, t)
    return ll, filtered_means, filtered_covs

In [66]:
def kf_pt_jit(params, emissions, return_covs):
    F, Q, R = params['F'], params['Q'], params['R']
    def step(carry, t):
        ll, pred_mean, pred_cov = carry
        H = params['Ht'][t]
        y = emissions[t]
        #ll += MVN(H @ pred_mean, H @ pred_cov @ H.T + R).log_prob(y)
        filtered_mean, filtered_cov = condition_on_pt(pred_mean, pred_cov, H, R, y)
        pred_mean, pred_cov = predict_pt(filtered_mean, filtered_cov, F, Q)
        carry = (ll, pred_mean, pred_cov)
        if return_covs:
            return carry, (filtered_mean, filtered_cov)
        else:
            return carry, filtered_mean
    
    step_jit = torch.compile(step)
    num_timesteps = len(emissions)
    D = len(params['mu0'])
    filtered_means = torch.zeros((num_timesteps, D))
    if return_covs:
        filtered_covs = torch.zeros((num_timesteps, D, D))
    else:
        filtered_covs = None
    ll = 0
    carry = (ll, params['mu0'], params['Sigma0'])
    for t in range(num_timesteps):
        if return_covs:
            carry, (filtered_means[t], filtered_covs[t]) = step_jit(carry, t)
        else:
            carry, filtered_means[t] = step_jit(carry, t)
    return ll, filtered_means, filtered_covs

In [58]:


F_pt = torch.tensor(np.array(F))
Q_pt = torch.tensor(np.array(Q))
R_pt = torch.tensor(np.array(R))
Ht_pt = torch.tensor(np.array(Ht))
mu0_pt = torch.tensor(np.array(mu0))
Sigma0_pt = torch.tensor(np.array(Sigma0))
param_dict_pt = {'mu0': mu0_pt, 'Sigma0': Sigma0_pt, 'F': F_pt, 'Q': Q_pt, 'R': R_pt, 'Ht': Ht_pt}
Y_pt = torch.tensor(np.array(Y))

return_covs = True
ll_pt, kf_means_pt, kf_covs_pt = kf_pt(param_dict_pt, Y_pt, return_covs) 
print(kf_means_pt.shape, kf_covs_pt)
assert(allclose(kf_means, kf_means_pt))
if return_covs:
    assert(allclose(kf_covs, kf_covs_pt))


torch.Size([200, 51]) None
CPU times: user 29.1 ms, sys: 2.26 ms, total: 31.3 ms
Wall time: 30.4 ms


# Timing comparison

In [59]:
X, Y = make_linreg_data(N=500, D=1000) # make larger dataset
N, D = X.shape
X1 = jnp.column_stack((jnp.ones(N), X))  # Include column of 1s
(obs_var, mu0, Sigma0) = make_linreg_prior(D+1)
nfeatures = X1.shape[1]
F = jnp.eye(nfeatures) # dynamics = I
Q = jnp.zeros((nfeatures, nfeatures))  # No parameter drift.
R = jnp.ones((1, 1)) * obs_var



In [61]:
%%time
Ht = X1[:, None, :] # (T,D) -> (T,1,D), H[t]'z = (b w)' (1 x)
param_dict = {'mu0': mu0, 'Sigma0': Sigma0, 'F': F, 'Q': Q, 'R': R, 'Ht': Ht}

return_covs = False
ll, kf_means, kf_covs = kf(param_dict, Y, return_covs) 



CPU times: user 49.7 s, sys: 1.03 s, total: 50.8 s
Wall time: 4.83 s


In [62]:
%%time

F_pt = torch.tensor(np.array(F))
Q_pt = torch.tensor(np.array(Q))
R_pt = torch.tensor(np.array(R))
Ht_pt = torch.tensor(np.array(Ht))
mu0_pt = torch.tensor(np.array(mu0))
Sigma0_pt = torch.tensor(np.array(Sigma0))
param_dict_pt = {'mu0': mu0_pt, 'Sigma0': Sigma0_pt, 'F': F_pt, 'Q': Q_pt, 'R': R_pt, 'Ht': Ht_pt}
Y_pt = torch.tensor(np.array(Y))

ll_pt, kf_means_pt, kf_covs_pt = kf_pt(param_dict_pt, Y_pt, return_covs) 

assert(allclose(kf_means, kf_means_pt))
if return_covs:
    assert(allclose(kf_covs, kf_covs_pt))

CPU times: user 22.4 s, sys: 830 ms, total: 23.2 s
Wall time: 20.9 s


In [63]:
#kf_pt_jit = torch.compile(kf_pt)

In [67]:
%%time

ll_pt, kf_means_pt, kf_covs_pt = kf_pt_jit(param_dict_pt, Y_pt, return_covs) 

assert(allclose(kf_means, kf_means_pt))
if return_covs:
    assert(allclose(kf_covs, kf_covs_pt))

BackendCompilerFailed: debug_wrapper raised InvalidCxxCompiler: No working C++ compiler found in torch._inductor.config.cpp.cxx: (None, 'g++')

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
