# CMGF-EKF Evaluation for MLP Training

Author: Peter Chang ([@petergchang](https://github.com/petergchang))

## 0. Imports

In [1]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Logging filter that drops every record whose message mentions `check_types`."""

    def filter(self, record):
        """Return False (drop) for records containing "check_types", True (keep) otherwise."""
        if "check_types" in record.getMessage():
            return False
        return True


logger.addFilter(CheckTypesFilter())

import warnings
warnings.filterwarnings('ignore')

In [2]:
try:
    from ssm_jax.cond_moments_gaussian_filter.inference import *
    from ssm_jax.cond_moments_gaussian_filter.containers import *
    import flax.linen as nn
except ModuleNotFoundError:
    print('installing ssm_jax')
    %pip install -qq git+https://github.com/probml/ssm-jax.git
    %pip install -qq flax
    from ssm_jax.cond_moments_gaussian_filter.inference import *
    from ssm_jax.cond_moments_gaussian_filter.containers import *
    import flax.linen as nn

installing ssm_jax
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 114 kB 8.4 MB/s 
[K     |████████████████████████████████| 180 kB 53.4 MB/s 
[K     |████████████████████████████████| 85 kB 3.9 MB/s 
[K     |████████████████████████████████| 145 kB 42.1 MB/s 
[K     |████████████████████████████████| 128 kB 46.6 MB/s 
[K     |████████████████████████████████| 217 kB 39.4 MB/s 
[K     |████████████████████████████████| 51 kB 7.3 MB/s 
[?25h  Building wheel for ssm-jax (PEP 517) ... [?25l[?25hdone




In [3]:
from typing import Sequence
from functools import partial

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.cm as cm
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.flatten_util import ravel_pytree
from jax import lax
from jax import vmap

## 1. MLP Definition

In [4]:
class MLP(nn.Module):
    """Multi-layer perceptron: ReLU-activated Dense layers followed by a linear Dense output layer.

    `features` lists the width of every Dense layer in order; the last entry
    is the width of the (un-activated) output layer.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        output_width = self.features[-1]
        for width in hidden_widths:
            pre_activation = nn.Dense(width)(x)
            x = nn.relu(pre_activation)
        return nn.Dense(output_width)(x)

In [5]:
def get_mlp_flattened_params(model_dims, key=0):
    """Build an `MLP` and expose its parameters as one flat vector.

    Args:
        model_dims: layer sizes; the first entry is the input dimension and
            the remaining entries are passed to `MLP` as `features`.
        key: a JAX PRNG key, or an int seed that is turned into one with `jr.PRNGKey`.

    Returns:
        model: the `MLP` instance.
        flat_params: the initial parameters flattened by `ravel_pytree`.
        unflatten_fn: inverse of the flattening, as returned by `ravel_pytree`.
        apply_fn: callable `(flat_params, x)` that evaluates the model on `x`.
    """
    rng = jr.PRNGKey(key) if isinstance(key, int) else key

    # The first entry sizes the dummy input; the rest define the Dense layers.
    model = MLP(model_dims[1:])
    params = model.init(rng, jnp.ones((model_dims[0],)))
    flat_params, unflatten_fn = ravel_pytree(params)

    def apply(flat_params, x, model, unflatten_fn):
        # Rebuild the parameter pytree before calling the flax model.
        param_tree = unflatten_fn(flat_params)
        return model.apply(param_tree, jnp.atleast_1d(x))

    return model, flat_params, unflatten_fn, partial(apply, model=model, unflatten_fn=unflatten_fn)

In [53]:
xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

In [65]:
from collections import namedtuple
testtup = namedtuple("testtup", ['x', 'y', 'n'])
tup_list = [testtup(x=jnp.ones(5)*i, y=jnp.zeros(5) + i, n=i) for i in range(4)]


In [106]:
A = jnp.array([jnp.arange(100000)] * 10)

In [107]:
sum_fn_jax = lambda x: x.sum()

In [108]:
%%timeit
vmap(sum_fn_jax)(A)

1.05 ms ± 260 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [109]:
A = {c: jnp.arange(100000) for c in [str(i) for i in range(10)]}

In [110]:
%%timeit
tree_map(lambda x: x.sum(), A)

240 µs ± 14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [67]:
tree_map(lambda *args: jnp.stack(args), *tup_list)

testtup(x=DeviceArray([[0., 0., 0., 0., 0.],
             [1., 1., 1., 1., 1.],
             [2., 2., 2., 2., 2.],
             [3., 3., 3., 3., 3.]], dtype=float32), y=DeviceArray([[0., 0., 0., 0., 0.],
             [1., 1., 1., 1., 1.],
             [2., 2., 2., 2., 2.],
             [3., 3., 3., 3., 3.]], dtype=float32), n=DeviceArray([0, 1, 2, 3], dtype=int32, weak_type=True))

In [138]:
# Define MLP architecture
input_dim, hidden_dims, output_dim = 2, [4, 4], 1
model_dims = [input_dim, *hidden_dims, output_dim]
model, flat_params, unflatten_fn, apply_fn = get_mlp_flattened_params(model_dims)

In [139]:
params = unflatten_fn(flat_params).unfreeze()

In [140]:
fn1, decoupled_params_dict_fn, *_ = decouple_flat_params(model_dims)

In [141]:
fn1(flat_params)

[DeviceArray([ 0.       , -1.2662824,  0.557601 ], dtype=float32),
 DeviceArray([0.        , 0.6269297 , 0.11622565], dtype=float32),
 DeviceArray([ 0.        ,  0.35720623, -0.27115023], dtype=float32),
 DeviceArray([ 0.        ,  0.04510251, -0.19996592], dtype=float32),
 DeviceArray([ 0.        ,  0.62617886,  0.5014711 , -0.6020613 ,
              -1.0135733 ], dtype=float32),
 DeviceArray([ 0.        ,  0.46124476,  0.5884816 , -0.5901677 ,
              -0.2912583 ], dtype=float32),
 DeviceArray([ 0.        , -0.05645338,  0.4206315 , -0.03776331,
               0.26864907], dtype=float32),
 DeviceArray([ 0.        , -0.03055437, -0.4431366 ,  0.53358877,
               0.12585698], dtype=float32),
 DeviceArray([ 0.        ,  0.12871431,  0.57244253, -0.36538973,
              -0.07911778], dtype=float32)]

In [149]:
A = decoupled_params_dict_fn(flat_params)

In [164]:
B = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])
B

DeviceArray([[[ 1,  2],
              [ 3,  4]],

             [[ 5,  6],
              [ 7,  8]],

             [[ 9, 10],
              [11, 12]]], dtype=int32)

In [165]:
jnp.sum(B, axis=0)

DeviceArray([[15, 18],
             [21, 24]], dtype=int32)

In [157]:
jnp.array(list(tree_map(lambda x: x.sum(), A).values())).sum()

DeviceArray(0.6834502, dtype=float32)

## 2. Decouple Helper Functions

In [137]:
def decouple_flat_params(model_dims):
    """Build helpers that split a flat MLP parameter vector into per-node groups.

    For each layer the flat vector is taken to hold `num_curr` biases followed
    by `num_prev * num_curr` weights, with node `i` of the layer owning the
    layer-local offsets `i, i + num_curr, i + 2 * num_curr, ...`.
    NOTE(review): this presumes the bias-then-row-major-kernel layout that
    `ravel_pytree` is expected to produce for the flax Dense layers used in
    this notebook -- confirm if the model definition changes.

    Args:
        model_dims: sequence of layer sizes `[input_dim, ..., output_dim]`
            (at least two entries).

    Returns:
        decoupled_params_fn: flat params -> list of per-node parameter arrays.
        decoupled_params_dict_fn: flat params -> dict `{node_index: array}`.
        recouple_params_fn: list (or dict keyed `0..n-1`) of per-node arrays
            -> flat parameter vector.
        separate_cov_fn: square matrix -> list of consecutive diagonal blocks,
            one per node, sized like that node's parameter group.
    """
    assert len(model_dims) > 1

    decoupled_params_idx = []
    block_sizes = []  # plain Python ints so that slicing stays static under jit/scan
    layer_start = 0
    for num_prev, num_curr in zip(model_dims[:-1], model_dims[1:]):
        # Each node owns one bias plus `num_prev` incoming weights, spaced `num_curr` apart.
        strides = num_curr * jnp.arange(num_prev + 1)
        decoupled_params_idx += [layer_start + node + strides for node in range(num_curr)]
        block_sizes += [int(num_prev) + 1] * int(num_curr)
        layer_start += num_curr * (num_prev + 1)

    # Start offset of each node's diagonal block (cumulative sum of the sizes).
    block_starts = []
    offset = 0
    for size in block_sizes:
        block_starts.append(offset)
        offset += size

    def decoupled_params_fn(params):
        """Split flat params into a list with one array per node."""
        return [params[idx] for idx in decoupled_params_idx]

    def decoupled_params_dict_fn(params):
        """Split flat params into a dict keyed by node index."""
        return {i: params[idx] for i, idx in enumerate(decoupled_params_idx)}

    def separate_cov_fn(cov):
        """Return the consecutive diagonal blocks of `cov`, one per node."""
        cov = jnp.asarray(cov)
        return [cov[start:start + size, start:start + size]
                for start, size in zip(block_starts, block_sizes)]

    def recouple_params_fn(decoupled_params, model_dims):
        """Inverse of `decoupled_params_fn`: rebuild the flat parameter vector."""
        assert len(model_dims) > 1
        if isinstance(decoupled_params, dict):
            # Accept the output of `decoupled_params_dict_fn` (keys 0..n-1).
            decoupled_params = [decoupled_params[i] for i in range(len(decoupled_params))]
        recoupled_params_list = []
        curr_idx = 0
        for layer in range(1, len(model_dims)):
            # Stack the layer's nodes row-wise; a column-major ravel then restores
            # the "all biases, then weights row by row" layout of the flat vector.
            layer_nodes = decoupled_params[curr_idx:curr_idx + model_dims[layer]]
            recoupled_params_list.append(jnp.ravel(jnp.array(layer_nodes), order='F'))
            curr_idx += model_dims[layer]

        return jnp.concatenate(recoupled_params_list)

    return decoupled_params_fn, decoupled_params_dict_fn, partial(recouple_params_fn, model_dims = model_dims), separate_cov_fn

## 3. Decoupled CMGF-EKF-MLP

In [6]:
import chex

In [11]:
# Helper functions
_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
_process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
_process_input = lambda x, y: jnp.zeros((y,)) if x is None else x

In [7]:
@chex.dataclass
class CMGFParams:
    """Lightweight container for the parameters of a conditional moments Gaussian filter (CMGF).

    The filter and smoother functions in this notebook read these fields by
    name; the comments below describe how they are used there.
    """
    # Mean used to seed the filter's first predicted state.
    initial_mean: chex.Array
    # Covariance used to seed the filter's first predicted state.
    initial_covariance: chex.Array
    # Dynamics function; invoked as f(state, input) by the prediction step.
    dynamics_function: Callable
    # Dynamics noise covariance, added to the predicted covariance; may carry a
    # leading time axis, in which case entry t is used at step t.
    dynamics_covariance: chex.Array
    # Conditional emission mean; invoked as m_Y(state, input) in the update step.
    emission_mean_function: Callable
    # Conditional emission covariance; invoked as Cov_Y(state, input) in the update step.
    emission_cov_function: Callable
    # Gaussian expectation; invoked as g_ev(fn, mean, cov).
    gaussian_expectation: Callable
    # Gaussian cross-covariance; invoked as g_cov(fn_a, fn_b, mean, cov).
    gaussian_cross_covariance: Callable

In [17]:
from jax import jacfwd

# Forward-mode Jacobian of `f` at `x`, promoted to at least two dimensions.
_jacfwd_2d = lambda f, x: jnp.atleast_2d(jacfwd(f)(x))

@chex.dataclass
class DCMGFParams(CMGFParams):
    """CMGF parameters with first-order (EKF-style) Gaussian moments for the decoupled filter.

    `gaussian_expectation` defaults to evaluating the function at the mean, and
    `gaussian_cross_covariance` is always bound to the linearised
    implementation below, regardless of the value passed at construction.

    NOTE(review): `decouple_fn` is stored but not used by this class.
    """
    decouple_fn: Callable = None
    gaussian_expectation: Callable = lambda f, m, P: jnp.atleast_1d(f(m))
    gaussian_cross_covariance: Callable = None

    def __post_init__(self):
        # Route every cross-covariance request through the linearised method.
        self.gaussian_cross_covariance = self._gaussian_cross_covariance

    def _gaussian_cross_covariance(self, f, g, m, P, decoupled_params_idx=None):
        """Linearised cross-covariance between `f(x)` and `g(x)` for x ~ N(m, P).

        Args:
            f, g: functions of the state.
            m: mean at which both functions are linearised.
            P: covariance of the state.
            decoupled_params_idx: optional sequence of index arrays, one per
                parameter group. When omitted the full `J_f P J_g^T` is
                returned, matching the `g_cov(f, g, m, P)` call convention
                used by the filter functions.

        Returns:
            `J_f @ P @ J_g.T`, or -- when `decoupled_params_idx` is given -- a
            dict `{group: J_f[:, idx] @ P[idx][:, idx] @ J_g[:, idx].T}`.
            NOTE(review): the per-group form is inferred from the column
            selection `H[:, idx]` used by the draft decoupled update in this
            notebook; confirm it is the intended decoupled quantity.
        """
        H_f = _jacfwd_2d(f, m)
        H_g = _jacfwd_2d(g, m)
        if decoupled_params_idx is None:
            return H_f @ P @ H_g.T
        # Columns of the Jacobians correspond to state entries, so each group
        # selects columns of J and the matching diagonal block of P.
        return {
            i: H_f[:, idx] @ P[idx][:, idx] @ H_g[:, idx].T
            for i, idx in enumerate(decoupled_params_idx)
        }


In [26]:
A = [[1, 2], [1, 2, 3], [3, 4, 5, 6]]

In [175]:
A = jnp.array([[1, 2, 3], [4, 5, 6]])

In [179]:
A[:,jnp.array([0,1])]

DeviceArray([[1, 2],
             [4, 5]], dtype=int32)

In [169]:
A[1]

DeviceArray([0.        , 0.6269297 , 0.11622565], dtype=float32)

In [35]:
from jax.tree_util import tree_map


In [38]:
list_sum = lambda x: sum(x)

tree_map(list_sum, *A)

ValueError: ignored

In [180]:
AA = {1: 15, 2: 30}
BB = {1: 10, 2: 20}

In [182]:
tree_map(lambda x, b: x + b, AA, BB)

{1: 25, 2: 50}

In [None]:
def _decoupled_condition_on(m, P, decouple_fn, decouple_cov_fn, recouple_fn, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter):
    """Draft of a per-group ("decoupled") measurement update.

    NOTE(review): this draft cannot run as written and is shadowed by a later
    definition of the same name in this notebook:
      * `decoupled_params_idx`, `decoupled_cov_dict_fn` and `recouple_cov_fn`
        are neither parameters nor defined in the visible notebook;
        `decoupled_params_dict_fn` is only a notebook-level global;
      * `_step` uses the outer, full `P` rather than the per-group
        covariances held in the carry;
      * the parameters `decouple_fn`, `decouple_cov_fn`, `g_ev` and `g_cov`
        and the local `identity_fn` are never used.

    Returns:
        The first log-likelihood produced by the scan, and the recoupled
        posterior mean and covariance.
    """
    m_Y = lambda x: y_cond_mean(x, u)
    Cov_Y = lambda x: y_cond_cov(x, u)
    identity_fn = lambda x: x
    
    def _step(carry, _):
        prior_means, prior_covs = carry
        prior_mean = recouple_fn(prior_means)

        yhat = jnp.atleast_1d(m_Y(prior_mean))
        H = _jacfwd_2d(m_Y, prior_mean)
        # One column slice of the Jacobian per parameter group.
        Hs = {i: H[:,idx] for i, idx in enumerate(decoupled_params_idx)}
        # NOTE(review): `tree_map` over a dict returns a dict, so `jnp.sum(..., axis=0)`
        # presumably does not add the per-group terms as intended -- verify.
        S = jnp.atleast_1d(Cov_Y(prior_mean)) + jnp.sum(tree_map(lambda x: x.T @ P @ x, Hs), axis=0)
        log_likelihood = MVN(yhat, S).log_prob(jnp.atleast_1d(y))
        Ks = tree_map(lambda x: jnp.linalg.solve(S, x.T @ P), Hs)
        posterior_means = tree_map(lambda mm, kk: mm + kk, prior_means, tree_map(lambda x: x @ (y - yhat), Ks))
        posterior_covs = tree_map(lambda pp, kk, hh: pp - kk @ hh @ pp, prior_covs, Ks, Hs)
        return (posterior_means, posterior_covs), log_likelihood

    # Iterate re-linearization over posterior mean and covariance
    ms = decoupled_params_dict_fn(m)
    Ps = decoupled_cov_dict_fn(P)
    carry = (ms, Ps)
    (mus_cond, Sigmas_cond), lls = lax.scan(_step, carry, jnp.arange(num_iter))
    mu_cond, Sigma_cond = recouple_fn(mus_cond), recouple_cov_fn(Sigmas_cond)
    return lls[0], mu_cond, Sigma_cond

In [None]:
def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter):
    """Condition the Gaussian N(m, P) on observation `y` via moment matching.

    Args:
        m: prior mean.
        P: prior covariance.
        y_cond_mean: conditional emission mean, called as `y_cond_mean(x, u)`.
        y_cond_cov: conditional emission covariance, called as `y_cond_cov(x, u)`.
        u: input for this step.
        y: observation for this step.
        g_ev: Gaussian expectation, called as `g_ev(fn, mean, cov)`.
        g_cov: Gaussian cross-covariance, called as `g_cov(fn_a, fn_b, mean, cov)`.
        num_iter: number of update iterations run by the scan.

    Returns:
        The log-likelihood from the first iteration, and the conditioned mean
        and covariance after the last one.
    """
    def emission_mean(x):
        return y_cond_mean(x, u)

    def emission_cov(x):
        return y_cond_cov(x, u)

    def identity(x):
        return x

    def _relinearize(carry, _):
        mean, cov = carry
        yhat = g_ev(emission_mean, mean, cov)
        S = g_ev(emission_cov, mean, cov) + g_cov(emission_mean, emission_mean, mean, cov)
        log_likelihood = MVN(yhat, S).log_prob(jnp.atleast_1d(y))
        C = g_cov(identity, emission_mean, mean, cov)
        # Gain K = C S^{-1}, computed with a linear solve rather than an explicit inverse.
        K = jnp.linalg.solve(S, C.T).T
        innovation = y - yhat
        return (mean + K @ innovation, cov - K @ S @ K.T), log_likelihood

    # Each iteration starts from the posterior of the previous one.
    (mu_cond, Sigma_cond), lls = lax.scan(_relinearize, (m, P), jnp.arange(num_iter))
    return lls[0], mu_cond, Sigma_cond

In [None]:
def decoupled_extended_conditional_moments_gaussian_filter(params, emissions, num_iter=1, inputs=None):
    """Run the conditional moments Gaussian filter over a sequence of emissions.

    NOTE(review): despite the name, each step calls the ordinary
    `_condition_on`, not a decoupled update -- confirm this is intended.

    Args:
        params: object exposing the CMGFParams fields.
        emissions: observations, indexed by time step.
        num_iter: number of update iterations per time step.
        inputs: optional inputs, indexed by time step.

    Returns:
        GSSMPosterior with `marginal_loglik`, `filtered_means` and `filtered_covariances`.
    """
    num_timesteps = len(emissions)

    # Give every model function the (state, input) signature, then substitute dummy inputs if none were given.
    f, m_Y, Cov_Y = [
        _process_fn(fn, inputs)
        for fn in (params.dynamics_function, params.emission_mean_function, params.emission_cov_function)
    ]
    inputs = _process_input(inputs, num_timesteps)

    g_ev = params.gaussian_expectation
    g_cov = params.gaussian_cross_covariance

    def _filter_step(carry, t):
        running_ll, pred_mean, pred_cov = carry

        Q = _get_params(params.dynamics_covariance, 2, t)
        u, y = inputs[t], emissions[t]

        # Update with the current emission, then propagate through the dynamics.
        step_ll, filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, m_Y, Cov_Y, u, y, g_ev, g_cov, num_iter)
        next_mean, next_cov, _ = _predict(filtered_mean, filtered_cov, f, Q, u, g_ev, g_cov)

        return (running_ll + step_ll, next_mean, next_cov), (filtered_mean, filtered_cov)

    init = (0.0, params.initial_mean, params.initial_covariance)
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_filter_step, init, jnp.arange(num_timesteps))
    return GSSMPosterior(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)


In [None]:
# Generate spiral dataset
# Adapted from https://gist.github.com/45deg/e731d9e7f478de134def5668324c44c5
def generate_spiral_dataset(key=0, num_per_class=250, zero_var=1., one_var=1., shuffle=True):
    """Generate a two-class, two-arm spiral dataset with standardized inputs.

    Args:
        key: JAX PRNG key, or an int seed converted with `jr.PRNGKey`.
        num_per_class: number of points drawn for each class.
        zero_var: multiplier applied to the standard-normal noise added to the
            class-0 arm (it scales the noise itself, i.e. acts as a standard deviation).
        one_var: same multiplier for the class-1 arm.
        shuffle: when True, apply one random permutation to inputs and labels.

    Returns:
        inputs: array of shape (2 * num_per_class, 2), standardized per column.
        labels: array of shape (2 * num_per_class,) holding 0.0 / 1.0.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key_theta, key_zero, key_one, key_perm = jr.split(key, 4)

    theta = jnp.sqrt(jr.uniform(key_theta, shape=(num_per_class,))) * 2*jnp.pi
    r = 2*theta + jnp.pi

    def spiral_arm(radius):
        # Polar -> Cartesian; one (x, y) row per sampled angle.
        return jnp.array([jnp.cos(theta)*radius, jnp.sin(theta)*radius]).T

    # The two classes are mirror-image arms (radius r and -r) with independent noise.
    zero_input = spiral_arm(r) + zero_var * jr.normal(key_zero, shape=(num_per_class, 2))
    one_input = spiral_arm(-r) + one_var * jr.normal(key_one, shape=(num_per_class, 2))

    # Stack the inputs and standardize each column.
    features = jnp.concatenate([zero_input, one_input])
    features = (features - features.mean(axis=0)) / features.std(axis=0)

    # Binary labels, in the same order as the stacked inputs.
    labels = jnp.concatenate([jnp.zeros(num_per_class), jnp.ones(num_per_class)])

    if shuffle:
        idx = jr.permutation(key_perm, jnp.arange(num_per_class * 2))
        features, labels = features[idx], labels[idx]

    return features, labels

In [None]:
# Generate data
input, output = generate_spiral_dataset()

# Define MLP architecture
input_dim, hidden_dims, output_dim = 2, [2], 1
model_dims = [input_dim, *hidden_dims, output_dim]
_, flat_params, _, apply_fn = get_mlp_flattened_params(model_dims)
# decouple_flat_params returns four helpers (list splitter, dict splitter,
# recoupler, covariance splitter); the dict splitter is not needed here.
decouple_fn, _, recouple_fn, decouple_cov_fn = decouple_flat_params(model_dims)

In [None]:
flat_params.size

9

In [None]:
jnp.arange(81).reshape(9,-1)

DeviceArray([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
             [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
             [18, 19, 20, 21, 22, 23, 24, 25, 26],
             [27, 28, 29, 30, 31, 32, 33, 34, 35],
             [36, 37, 38, 39, 40, 41, 42, 43, 44],
             [45, 46, 47, 48, 49, 50, 51, 52, 53],
             [54, 55, 56, 57, 58, 59, 60, 61, 62],
             [63, 64, 65, 66, 67, 68, 69, 70, 71],
             [72, 73, 74, 75, 76, 77, 78, 79, 80]], dtype=int32)

In [None]:
for A in decouple_cov_fn(jnp.arange(81).reshape(9,-1)):
    print(A, '\n')

[[ 0  1  2]
 [ 9 10 11]
 [18 19 20]] 

[[30 31 32]
 [39 40 41]
 [48 49 50]] 

[[60 61 62]
 [69 70 71]
 [78 79 80]] 



In [None]:
jax.scipy.linalg.block_diag(*decouple_cov_fn(jnp.eye(25))).shape

(9, 9)

In [None]:
from jax.config import config
config.update("jax_disable_jit", True)


from jax import numpy as jnp
from jax import lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from ssm_jax.containers import GSSMPosterior

from jax.scipy.linalg import block_diag


# Helper functions
_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
_process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
_process_input = lambda x, y: jnp.zeros((y,)) if x is None else x


def _predict(m, P, f, Q, u, g_ev, g_cov):
    """Predict next mean and covariance under an additive-noise Gaussian filter

        p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
        where
            mu_pred = gev(f, m, P)
                    = \int f(x_t, u) N(x_t | m, P) dx_t
            Sigma_pred = gev((f - mu_pred)(f - mu_pred)^T, m, P) + Q
                       = \int (f(x_t, u) - mu_pred)(f(x_t, u) - mu_pred)^T
                           N(x_t | m, P)dx_t + Q

    Args:
        m (D_hid,): prior mean.
        P (D_hid,D_hid): prior covariance.
        f (Callable): dynamics function.
        Q (D_hid,D_hid): dynamics covariance matrix.
        u (D_in,): inputs.
        g_ev (Callable): Gaussian expectation value function.
        g_cov (Callable): Gaussian cross covariance function.

    Returns:
        mu_pred (D_hid,): predicted mean.
        Sigma_pred (D_hid,D_hid): predicted covariance.
        cross_pred (D_hid,D_hid): cross covariance term.
    """
    dynamics_fn = lambda x: f(x, u)
    identity_fn = lambda x: x
    mu_pred = g_ev(dynamics_fn, m, P)
    Sigma_pred = g_cov(dynamics_fn, dynamics_fn, m, P) + Q
    cross_pred = g_cov(identity_fn, dynamics_fn, m, P)
    return mu_pred, Sigma_pred, cross_pred


def _decoupled_predict(ms, Ps, f, Q, u, g_ev, g_cov):
    """Apply `_predict` across the leading axis of `ms` and `Ps`, sharing every other argument."""
    # Only the means and covariances are mapped; the remaining five arguments are broadcast.
    batched_predict = vmap(_predict, in_axes=(0, 0) + (None,) * 5)
    return batched_predict(ms, Ps, f, Q, u, g_ev, g_cov)


def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter):
    r"""Condition a Gaussian potential on a new observation with arbitrary
       likelihood with given functions for conditional moments and make a
       Gaussian approximation.
       p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
         propto p(x_t | y_{1:t-1}, u_{1:t-1}) p(y_t | x_t, u_t)
         = N(x_t | m, P) ArbitraryDist(y_t |y_cond_mean(x_t), y_cond_cov(x_t))
         \approx N(x_t | mu_cond, Sigma_cond)
     where
        mu_cond = m + K*(y - yhat)
        yhat = gev(h, m, P)
        S = gev((h - yhat)(h - yhat)^T, m, P) + R
        C = gev((Identity - m)(h - yhat)^T, m, P)
        K = C * S^{-1}
        Sigma_cond = P - K S K'

    Args:
        m (D_hid,): prior mean.
        P (D_hid,D_hid): prior covariance.
        y_cond_mean (Callable): conditional emission mean function.
        y_cond_cov (Callable): conditional emission covariance function.
        u (D_in,): inputs.
        y (D_obs,): observation.
        g_ev (Callable): Gaussian expectation value function.
        g_cov (Callable): Gaussian cross covariance function.
        num_iter (int): number of re-linearizations around posterior for update step.

     Returns:
        log_likelihood (Scalar): prediction log likelihood for observation y
            (taken from the first iteration).
        mu_cond (D_hid,): conditioned mean.
        Sigma_cond (D_hid,D_hid): conditioned covariance.
    """
    # Bind the input so the moment functions see functions of the state only.
    m_Y = lambda x: y_cond_mean(x, u)
    Cov_Y = lambda x: y_cond_cov(x, u)
    identity_fn = lambda x: x

    def _step(carry, _):
        prior_mean, prior_cov = carry
        yhat = g_ev(m_Y, prior_mean, prior_cov)
        S = g_ev(Cov_Y, prior_mean, prior_cov) + g_cov(m_Y, m_Y, prior_mean, prior_cov)
        log_likelihood = MVN(yhat, S).log_prob(jnp.atleast_1d(y))
        C = g_cov(identity_fn, m_Y, prior_mean, prior_cov)
        # K = C S^{-1}, via a linear solve instead of an explicit inverse.
        K = jnp.linalg.solve(S, C.T).T
        posterior_mean = prior_mean + K @ (y - yhat)
        posterior_cov = prior_cov - K @ S @ K.T
        return (posterior_mean, posterior_cov), log_likelihood

    # Iterate re-linearization over posterior mean and covariance
    carry = (m, P)
    (mu_cond, Sigma_cond), lls = lax.scan(_step, carry, jnp.arange(num_iter))
    return lls[0], mu_cond, Sigma_cond


def _decoupled_condition_on(m, P, decouple_fn, decouple_cov_fn, recouple_fn, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter):
    """Work-in-progress decoupled variant of `_condition_on`.

    Splits the prior mean/covariance with `decouple_fn` / `decouple_cov_fn`,
    accumulates the innovation covariance from the per-group cross-covariances
    and updates each group separately.

    NOTE(review): not functional as written --
      * the scan carry starts as the full `(m, P)` but `_step` returns the
        decoupled `(posterior_means, posterior_covs)`, so carry input and
        output do not match;
      * `recouple_fn` is accepted but never used, so nothing converts the
        decoupled results back to a full mean/covariance;
      * `vmap` and `+` / `-` are applied directly to the outputs of
        `decouple_fn` / `decouple_cov_fn`; presumably these are Python lists
        of differently sized arrays -- verify against `decouple_flat_params`;
      * a debugging `print` is still present inside `_step`.
    """
    m_Y = lambda x: y_cond_mean(x, u)
    Cov_Y = lambda x: y_cond_cov(x, u)
    identity_fn = lambda x: x
    
    def _step(carry, _):
        prior_mean, prior_cov = carry
        prior_means = decouple_fn(prior_mean)
        prior_covs = decouple_cov_fn(prior_cov)

        yhat = g_ev(m_Y, prior_mean, prior_cov)
        # Leftover debug output: per-group cross-covariance terms.
        print([g_cov(m_Y, m_Y, prior_means[i], prior_covs[i]) for i, _ in enumerate(prior_means)])
        # Innovation covariance: emission covariance plus the summed per-group terms.
        S = g_ev(Cov_Y, prior_mean, prior_cov) + vmap(g_cov, (None, None, 0, 0))(m_Y, m_Y, prior_means, prior_covs).sum()
        log_likelihood = MVN(yhat, S).log_prob(jnp.atleast_1d(y))
        Cs = vmap(g_cov, (None, None, 0, 0))(identity_fn, m_Y, prior_means, prior_covs)
        # Per-group gain K_i = C_i S^{-1}, sharing the same S across groups.
        compute_K = lambda S, C: jnp.linalg.solve(S, C.T).T
        Ks = vmap(compute_K, (None, 0))(S, Cs)
        posterior_means = prior_means + Ks @ (y - yhat)
        compute_cov_diff = lambda S, K: K @ S @ K.T
        posterior_covs = prior_covs - vmap(compute_cov_diff, (None, 0))(S, Ks)
        return (posterior_means, posterior_covs), log_likelihood

    # Iterate re-linearization over posterior mean and covariance
    carry = (m, P)
    (mu_cond, Sigma_cond), lls = lax.scan(_step, carry, jnp.arange(num_iter))
    return lls[0], mu_cond, Sigma_cond


def statistical_linear_regression(mu, Sigma, m, S, C):
    r"""Return moment-matching affine coefficients and approximation noise variance
    given joint moments.
        g(x) \approx Ax + b + e where e ~ N(0, Omega)
        p(x) = N(x | mu, Sigma)
        m = E[g(x)]
        S = Var[g(x)]
        C = Cov[x, g(x)]

    Args:
        mu (D_hid): prior mean.
        Sigma (D_hid, D_hid): prior covariance.
        m (D_obs): E[g(x)].
        S (D_obs, D_obs): Var[g(x)]
        C (D_hid, D_obs): Cov[x, g(x)]

    Returns:
        A (D_obs, D_hid): linear coefficient, C^T Sigma^{-1}.
        b (D_obs): offset, m - A mu.
        Omega (D_obs, D_obs): residual covariance, S - A Sigma A^T.
    """
    # Solve Sigma^T X = C instead of forming Sigma^{-1}; A is X^T.
    A = jnp.linalg.solve(Sigma.T, C).T
    b = m - A @ mu
    Omega = S - A @ Sigma @ A.T
    return A, b, Omega


def conditional_moments_gaussian_filter(params, emissions, num_iter=1, inputs=None):
    """Filter a sequence of emissions with the (iterated) conditional moments Gaussian filter.

    Args:
        params: a CMGFParams instance (or object with the same fields).
        emissions: observations, indexed by time step.
        num_iter (int): number of linearizations performed inside each update step.
        inputs: optional inputs, indexed by time step.

    Returns:
        GSSMPosterior holding `marginal_loglik`, `filtered_means` and
        `filtered_covariances` (the latter two stacked over time).
    """
    num_timesteps = len(emissions)

    # Give every model function the (state, input) signature, then substitute dummy inputs if none were given.
    f, m_Y, Cov_Y = [
        _process_fn(fn, inputs)
        for fn in (params.dynamics_function, params.emission_mean_function, params.emission_cov_function)
    ]
    inputs = _process_input(inputs, num_timesteps)

    g_ev = params.gaussian_expectation
    g_cov = params.gaussian_cross_covariance

    def _filter_step(carry, t):
        running_ll, pred_mean, pred_cov = carry

        Q = _get_params(params.dynamics_covariance, 2, t)
        u, y = inputs[t], emissions[t]

        # Update with the current emission, then propagate through the dynamics.
        step_ll, filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, m_Y, Cov_Y, u, y, g_ev, g_cov, num_iter)
        next_mean, next_cov, _ = _predict(filtered_mean, filtered_cov, f, Q, u, g_ev, g_cov)

        return (running_ll + step_ll, next_mean, next_cov), (filtered_mean, filtered_cov)

    init = (0.0, params.initial_mean, params.initial_covariance)
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_filter_step, init, jnp.arange(num_timesteps))
    return GSSMPosterior(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)


def decoupled_conditional_moments_gaussian_filter(params, decouple_fn, decouple_cov_fn, recouple_fn, emissions, num_iter=1, inputs=None):
    """Run the conditional moments Gaussian filter with the decoupled update step.

    Identical in structure to `conditional_moments_gaussian_filter`, except
    that each emission is absorbed with `_decoupled_condition_on`; the
    prediction step still uses the ordinary `_predict`.

    Args:
        params: object exposing the CMGFParams fields.
        decouple_fn, decouple_cov_fn, recouple_fn: helpers forwarded unchanged
            to `_decoupled_condition_on`.
        emissions: observations, indexed by time step.
        num_iter: number of update iterations per time step.
        inputs: optional inputs, indexed by time step.

    Returns:
        GSSMPosterior with `marginal_loglik`, `filtered_means` and `filtered_covariances`.
    """
    num_timesteps = len(emissions)

    # Give every model function the (state, input) signature, then substitute dummy inputs if none were given.
    f, m_Y, Cov_Y = [
        _process_fn(fn, inputs)
        for fn in (params.dynamics_function, params.emission_mean_function, params.emission_cov_function)
    ]
    inputs = _process_input(inputs, num_timesteps)

    g_ev = params.gaussian_expectation
    g_cov = params.gaussian_cross_covariance

    def _filter_step(carry, t):
        running_ll, pred_mean, pred_cov = carry

        Q = _get_params(params.dynamics_covariance, 2, t)
        u, y = inputs[t], emissions[t]

        # Decoupled update with the current emission.
        step_ll, filtered_mean, filtered_cov = _decoupled_condition_on(
            pred_mean, pred_cov, decouple_fn, decouple_cov_fn, recouple_fn,
            m_Y, Cov_Y, u, y, g_ev, g_cov, num_iter)

        # Propagate through the dynamics.
        next_mean, next_cov, _ = _predict(filtered_mean, filtered_cov, f, Q, u, g_ev, g_cov)

        return (running_ll + step_ll, next_mean, next_cov), (filtered_mean, filtered_cov)

    init = (0.0, params.initial_mean, params.initial_covariance)
    (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_filter_step, init, jnp.arange(num_timesteps))
    return GSSMPosterior(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)


def iterated_conditional_moments_gaussian_filter(params, emissions, num_iter=2, inputs=None):
    """Run the conditional moments Gaussian filter with several update iterations per step.

    Thin wrapper around `conditional_moments_gaussian_filter` whose only
    difference is the default `num_iter` of 2.

    Args:
        params: a CMGFParams instance (or object with the same fields).
        emissions: observations, indexed by time step.
        num_iter (int): number of linearizations per update step.
        inputs: optional inputs, indexed by time step.

    Returns:
        The GSSMPosterior returned by `conditional_moments_gaussian_filter`.
    """
    return conditional_moments_gaussian_filter(params, emissions, num_iter, inputs)


def conditional_moments_gaussian_smoother(params, emissions, filtered_posterior=None, inputs=None):
    """Run a conditional moments Gaussian smoother.

    Args:
        params: a CMGFParams instance (or object with the same fields)
        emissions: array of observations, indexed by time step.
        filtered_posterior (GSSMPosterior): filtered posterior to use for smoothing.
            If None, the smoother computes the filtered posterior directly.
            Only its first three `to_tuple()` entries (log-likelihood,
            filtered means, filtered covariances) are read.
        inputs: optional array of inputs, indexed by time step.

    Returns:
        nlgssm_posterior: GSSMPosterior instance containing properties of
            filtered and smoothed posterior distributions.
    """
    num_timesteps = len(emissions)

    # Get filtered posterior
    if filtered_posterior is None:
        filtered_posterior = conditional_moments_gaussian_filter(params, emissions, inputs=inputs)
    ll, filtered_means, filtered_covs, *_ = filtered_posterior.to_tuple()

    # Process dynamics function to take in control inputs
    f  = _process_fn(params.dynamics_function, inputs)
    inputs = _process_input(inputs, num_timesteps)

    # Gaussian expectation value function
    g_ev = params.gaussian_expectation
    g_cov = params.gaussian_cross_covariance

    def _step(carry, args):
        # Unpack the inputs
        smoothed_mean_next, smoothed_cov_next = carry
        t, filtered_mean, filtered_cov = args

        # Get parameters and inputs for time index t
        Q = _get_params(params.dynamics_covariance, 2, t)
        u = inputs[t]

        # Prediction step
        pred_mean, pred_cov, pred_cross = _predict(filtered_mean, filtered_cov, f, Q, u, g_ev, g_cov)
        # Smoother gain G = pred_cross pred_cov^{-1}, via a linear solve.
        G = jnp.linalg.solve(pred_cov, pred_cross.T).T

        # Compute smoothed mean and covariance
        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - pred_mean)
        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - pred_cov) @ G.T

        return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)

    # Run the smoother backwards in time, starting from the last filtered estimate
    init_carry = (filtered_means[-1], filtered_covs[-1])
    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
    _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)

    # Reverse the arrays and append the final time step, whose smoothed
    # estimate equals the filtered one. `jnp.vstack` replaces the deprecated
    # alias `jnp.row_stack`.
    smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
    smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
    return GSSMPosterior(
        marginal_loglik=ll,
        filtered_means=filtered_means,
        filtered_covariances=filtered_covs,
        smoothed_means=smoothed_means,
        smoothed_covariances=smoothed_covs,
    )


def iterated_conditional_moments_gaussian_smoother(params, emissions, num_iter=1, inputs=None):
    """Run an iterated conditional moments Gaussian smoother.

    The smoother is called `num_iter` times, each call receiving the posterior
    returned by the previous one (None on the first call, which makes the
    smoother run the filter itself).

    A plain Python loop is used instead of `lax.scan`: the loop state starts
    as None and becomes a posterior object after the first pass, and `scan`
    requires the carry to keep the same structure across iterations.

    NOTE(review): the smoother only reads the filtered moments of the
    posterior it is given, so later iterations appear to recompute the same
    result -- confirm the intended re-linearization.

    Args:
        params: a CMGFParams instance (or object with the same fields)
        emissions: array of observations, indexed by time step.
        num_iter (int): number of smoother passes.
        inputs: optional array of inputs, indexed by time step.

    Returns:
        nlgssm_posterior: GSSMPosterior instance containing properties of
            filtered and smoothed posterior distributions (None if
            `num_iter` is 0).
    """
    smoothed_posterior = None
    for _ in range(num_iter):
        # Feed the previous pass's posterior back into the smoother.
        smoothed_posterior = conditional_moments_gaussian_smoother(params, emissions, smoothed_posterior, inputs)
    return smoothed_posterior


In [None]:
# Generate data
input, output = generate_spiral_dataset()

In [None]:
# Define MLP architecture
input_dim, hidden_dims, output_dim = 2, [2], 1
model_dims = [input_dim, *hidden_dims, output_dim]
_, flat_params, _, apply_fn = get_mlp_flattened_params(model_dims)
# decouple_flat_params returns four helpers (list splitter, dict splitter,
# recoupler, covariance splitter); the dict splitter is not needed here.
decouple_fn, _, recouple_fn, decouple_cov_fn = decouple_flat_params(model_dims)

In [None]:
flat_params

DeviceArray([ 0.        ,  0.        , -0.2788025 , -0.74077964,
             -0.47987294,  0.25528678,  0.        ,  0.94259536,
             -0.05666351], dtype=float32)

In [None]:
decouple_fn(flat_params)

[DeviceArray([ 0.        , -0.2788025 , -0.47987294], dtype=float32),
 DeviceArray([ 0.        , -0.74077964,  0.25528678], dtype=float32),
 DeviceArray([ 0.        ,  0.94259536, -0.05666351], dtype=float32)]

In [None]:
# Some model parameters and helper function
state_dim, emission_dim = flat_params.size, output_dim
# Bernoulli success probability predicted by the MLP with flat weights `w` on input `x`.
sigmoid_fn = lambda w, x: jax.nn.sigmoid(apply_fn(w, x))

# Run CMGF-EKF to train the MLP Classifier
# The weights are the latent state: identity dynamics with a small dynamics
# covariance, and Bernoulli emission moments p and p * (1 - p).
# NOTE(review): this cell currently ends in a TypeError (see the recorded
# output); the decoupled update it calls is still work in progress.
cmgf_ekf_params = EKFParams(
    initial_mean=flat_params,
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=lambda w, x: w,
    dynamics_covariance=jnp.eye(state_dim) * 1e-4,
    emission_mean_function = lambda w, x: sigmoid_fn(w, x),
    emission_cov_function = lambda w, x: sigmoid_fn(w, x) * (1 - sigmoid_fn(w, x))
)
cmgf_ekf_post = decoupled_conditional_moments_gaussian_filter(cmgf_ekf_params, decouple_fn, decouple_cov_fn, recouple_fn, output, inputs=input)

# Extract history of filtered weight values
w_means, w_covs = cmgf_ekf_post.filtered_means, cmgf_ekf_post.filtered_covariances

TypeError: ignored