# 0. Imports

In [None]:
try:
    from dynamax.generalized_gaussian_ssm.inference import *
    from dynamax.generalized_gaussian_ssm.models import ParamsGGSSM
    import chex
    import flax
    import ml_collections
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -qq git+https://github.com/probml/dynamax
    print('installing chex')
    %pip install chex
    print('installing flax')
    %pip install flax
    print('installing ml_collections')
    %pip install ml_collections
    from dynamax.generalized_gaussian_ssm.inference import *
    from dynamax.generalized_gaussian_ssm.models import ParamsGGSSM
    import chex
    import flax
    import ml_collections

In [None]:
from typing import Sequence
from functools import partial
from collections import deque
from pathlib import Path

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import lax, jit, vmap, pmap
from jax.tree_util import tree_map, tree_reduce
import flax.linen as nn
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalDiag as MVND
import chex
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from jax.flatten_util import ravel_pytree
import optax
from tqdm import tqdm
from tensorflow_probability.substrates import jax as tfp
import pandas as pd

from sklearn import datasets
from sklearn import preprocessing
from dynamax.generalized_gaussian_ssm.dekf.diagonal_inference import (
    _jacrev_2d,
    DEKFParams,
    _stationary_dynamics_diagonal_predict,
    _fully_decoupled_ekf_condition_on,
    stationary_dynamics_fully_decoupled_conditional_moments_gaussian_filter,
    _variational_diagonal_ekf_condition_on,
    stationary_dynamics_variational_diagonal_extended_kalman_filter
)
from dynamax.generalized_gaussian_ssm.dekf.utils import (
    MLP,
    get_mlp_flattened_params,
    loss_optax
)

# 1. Dataset Generation

In [None]:
def generate_input_grid(input):
    """Build a regular 2d grid covering the range of the given inputs.

    Args:
        input (DeviceArray): Array whose column-wise minimum and maximum
            (taken along axis 0) fix the extent of the grid. The result of
            the reduction is unpacked into exactly two values (x and y).
    Returns:
        input_grid: Array whose last axis (size 2) holds the (x, y)
            coordinates of every grid node.
    """
    margin, spacing = 0.1, 0.1

    # Pad the bounding box of the data by `margin` on every side.
    x_lo, y_lo = input.min(axis=0) - margin
    x_hi, y_hi = input.max(axis=0) + margin

    # One tick vector per axis, then the cartesian product of both.
    x_ticks = jnp.mgrid[x_lo:x_hi:spacing]
    y_ticks = jnp.mgrid[y_lo:y_hi:spacing]
    xx, yy = jnp.meshgrid(x_ticks, y_ticks)

    return jnp.stack([xx, yy], axis=2)

In [None]:
def posterior_predictive_grid_ekf(grid, mean, cov, apply, binary=False, ekf_type='fcekf', 
                                  post_pred_type='mc', num_samples=100,key=0):
    """Compute the sampled posterior predictive for each point in grid.

    Parameters are sampled from a Gaussian posterior and the per-sample
    predictions on the grid are averaged.

    Args:
        grid (DeviceArray): Grid on which to predict; its last axis is the
            2d input fed to `apply`.
        mean (DeviceArray): Posterior mean of the (flattened) parameters.
        cov (DeviceArray): Posterior covariance of the parameters. A full
            covariance matrix when `ekf_type == 'fcekf'`; otherwise it is
            treated as the vector of diagonal variances.
        apply (Callable): Function `(params, x) -> prediction`.
        binary (bool, optional): If True, round every per-sample prediction
            to the nearest integer before averaging. Defaults to False.
        ekf_type (str, optional): 'fcekf' selects a full-covariance Gaussian;
            any other value selects a diagonal Gaussian. Defaults to 'fcekf'.
        post_pred_type (str, optional): 'mc' evaluates `apply` at every
            sampled parameter; any other value uses a first-order expansion
            of `apply` around `mean`. Defaults to 'mc'.
        num_samples (int, optional): Number of parameter samples. Defaults to 100.
        key (int or PRNGKey, optional): Seed or key used for sampling. Defaults to 0.
    Returns:
        DeviceArray: Predictions on the grid averaged over the samples.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    if ekf_type == 'fcekf':
        mvn = MVN(loc=mean, covariance_matrix=cov)
    else:
        # `scale_diag` is the diagonal of the scale (standard deviation)
        # matrix, whereas `cov` stores variances, hence the square root.
        mvn = MVND(loc=mean, scale_diag=jnp.sqrt(cov))
    # Sample parameters
    sampled_params = mvn.sample(seed=key, sample_shape=num_samples)

    def _predict_on_grid(point_fn):
        # Apply a single-point predictor to every node of the grid.
        Z = jnp.vectorize(point_fn, signature='(2)->(3)')(grid)
        return jnp.rint(Z) if binary else Z

    def posterior_predictive_mc(param):
        # Plug the sampled parameters directly into the model.
        return _predict_on_grid(lambda x: apply(param, x))

    def posterior_predictive_immer(param):
        # Linearize the model around the posterior mean and evaluate the
        # linear approximation at the sampled parameters.
        def inferred_fn(x):
            apply_fn = lambda p: apply(p, x)
            H = _jacrev_2d(apply_fn, mean)
            return apply_fn(mean) + H @ (param - mean)
        return _predict_on_grid(inferred_fn)

    if post_pred_type == 'mc':
        Zs = vmap(posterior_predictive_mc)(sampled_params)
    else:
        Zs = vmap(posterior_predictive_immer)(sampled_params)

    return Zs.mean(axis=0)

In [None]:
def posterior_predictive_grid(grid, mean, apply, binary=False):
    """Evaluate the plug-in prediction at every point of a grid.

    Args:
        grid (DeviceArray): Grid on which to predict; its last axis is the
            input passed to `apply`.
        mean (DeviceArray): Parameters at which `apply` is evaluated.
        apply (Callable): Function `(params, x) -> prediction`.
        binary (bool, optional): If True, round the predictions to the
            nearest integer. Defaults to False.
    Returns:
        DeviceArray: Prediction for every grid point.
    """
    def predict_point(x):
        return apply(mean, x)

    predictions = jnp.vectorize(predict_point, signature='(2)->(3)')(grid)
    return jnp.rint(predictions) if binary else predictions

In [None]:
def plot_posterior_predictive(ax, X, Y, title, Xspace=None, Zspace=None, cmap=cm.rainbow):
    """Draw the data points and, optionally, a filled contour of the predictive.

    Args:
        ax (axis): Matplotlib axis to draw on.
        X (DeviceArray): Input array; its transpose is unpacked into the
            scatter coordinates.
        Y (iterable): Labels; truthy entries are drawn red, falsy ones blue.
        title (str): Title for the plot.
        Xspace (DeviceArray, optional): Input grid to predict on. Defaults to None.
        Zspace (DeviceArray, optional): Predictions on the input grid. Defaults to None.
        cmap (optional): Matplotlib colormap for the contour. Defaults to `cm.rainbow`.
    Returns:
        The axis that was drawn on.
    """
    # The contour is only drawn when both the grid and its predictions are given.
    if not (Xspace is None or Zspace is None):
        surface = Zspace.T[0]
        ax.contourf(*Xspace.T, surface, cmap=cmap, levels=50)
        ax.axis('off')

    colors = []
    for label in Y:
        colors.append('red' if label else 'blue')

    ax.scatter(*X.T, c=colors, edgecolors='black', s=50)
    ax.set_title(title)
    plt.tight_layout()
    return ax


In [None]:
def generate_spiral_dataset(num_per_class=500, validation_split=0.2, test_split=0.3, zero_var=1., one_var=1., shuffle=True, key=0):
    """Generate balanced, standardized 2d "spiral" binary classification dataset.
    Code adapted from https://gist.github.com/45deg/e731d9e7f478de134def5668324c44c5
    Args:
        num_per_class (int, optional): Number of points to generate per class. Defaults to 500.
        validation_split (float, optional): Fraction of all `2 * num_per_class`
            samples assigned to the validation set. Defaults to 0.2.
        test_split (float, optional): Fraction of all samples assigned to the
            test set. Defaults to 0.3.
        zero_var (float, optional): Multiplier applied to the standard-normal
            noise added to inputs with label '0'. Defaults to 1.
        one_var (float, optional): Multiplier applied to the standard-normal
            noise added to inputs with label '1'. Defaults to 1.
        shuffle (bool, optional): Flag to determine whether to shuffle the dataset
            before splitting. Defaults to True.
        key (int or PRNGKey, optional): Initial PRNG seed for jax.random. Defaults to 0.
    Returns:
        X_train, X_val, X_test, y_train, y_val, y_test: Inputs and binary
            outputs of the train, validation and test splits (in this order
            along the sample axis).
    Raises:
        ValueError: If a split fraction is negative or the two fractions sum
            to more than 1.
    """
    if validation_split < 0 or test_split < 0 or validation_split + test_split > 1:
        raise ValueError(
            'validation_split and test_split must be non-negative and sum to at most 1, '
            f'got validation_split={validation_split}, test_split={test_split}'
        )

    if isinstance(key, int):
        key = jr.PRNGKey(key)
    key1, key2, key3, key4 = jr.split(key, 4)

    theta = jnp.sqrt(jr.uniform(key1, shape=(num_per_class,))) * 2*jnp.pi
    r = 2*theta + jnp.pi
    generate_data = lambda theta, r: jnp.array([jnp.cos(theta)*r, jnp.sin(theta)*r]).T

    # Input data for output zero
    zero_input = generate_data(theta, r) + zero_var * jr.normal(key2, shape=(num_per_class, 2))

    # Input data for output one
    one_input = generate_data(theta, -r) + one_var * jr.normal(key3, shape=(num_per_class, 2))

    # Stack the inputs and standardize
    input = jnp.concatenate([zero_input, one_input])
    input = (input - input.mean(axis=0)) / input.std(axis=0)

    # Generate binary output
    output = jnp.concatenate([jnp.zeros(num_per_class), jnp.ones(num_per_class)])

    if shuffle:
        idx = jr.permutation(key4, jnp.arange(num_per_class * 2))
        input, output = input[idx], output[idx]

    # Cumulative fractions marking the end of the train and validation parts.
    # Rounding removes floating point noise from the subtractions; with the
    # default splits this reproduces the boundaries
    # (num_per_class, int(1.4 * num_per_class)).
    train_end = round(1 - validation_split - test_split, 9)
    val_end = round(1 - test_split, 9)
    val_index = int(train_end * 2 * num_per_class)
    test_index = int(val_end * 2 * num_per_class)

    X_train, X_val, X_test = input[:val_index], input[val_index:test_index], input[test_index:]
    y_train, y_val, y_test = output[:val_index], output[val_index:test_index], output[test_index:]

    return X_train, X_val, X_test, y_train, y_val, y_test

In [None]:
# Draw the spiral dataset used throughout the notebook and split it into
# train / validation / test parts.
num_per_class = 1000
(X_train, X_val, X_test,
 y_train, y_val, y_test) = generate_spiral_dataset(num_per_class=num_per_class)

In [None]:
# Report the shape of the train, validation and test inputs (in this order).
for split in (X_train, X_val, X_test):
    print(split.shape)

In [None]:
# Candidate initial variances: 9, 8, ..., 1 scaled by 1e-1, 1e-2 and 1e-3,
# flattened into a single vector of 27 values.
init_vars = jnp.array([jnp.arange(9, 0, -1) * dec for dec in [1e-1, 1e-2, 1e-3]]).ravel()
# Layer sizes of every MLP architecture that is compared, keyed by its label.
model_dim_grid = {'MLP1': [2, 1], 
                  'MLP2': [2, 100, 1], 
                  'MLP3': [2, 30, 30, 1], 
                  'MLP4': [2, 20, 20, 20, 1]}
# The same candidate values are used for SGD learning rates (lrs) and momentums (ms).
lrs = ms = jnp.array([1e-1, 5e-2, 1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 1e-5])

# 2. Evaluation Functions

In [None]:
def reg_loss_fn(params, x, y, prior_dist, apply_fn, lamb):
    """Sigmoid cross-entropy loss plus a weighted negative log-prior.

    Args:
        params: Model parameters; passed to `prior_dist.log_prob` and `apply_fn`.
        x: Model input forwarded to `apply_fn`.
        y: Target label(s); promoted to at least 1-d.
        prior_dist: Object exposing `log_prob(params)`.
        apply_fn (Callable): Maps `(params, x)` to logits.
        lamb (float): Weight of the negative log-prior term.
    Returns:
        Mean sigmoid binary cross-entropy plus `lamb` times the negative
        log-prior of `params`.
    """
    logits = apply_fn(params, x)
    targets = jnp.atleast_1d(y)
    data_term = optax.sigmoid_binary_cross_entropy(logits, targets).mean()

    prior_term = -prior_dist.log_prob(params)

    return data_term + lamb * prior_term

In [None]:
def fit_optax(params, optimizer, input, output, loss_fn, num_epochs, return_history=False,
              polyak_averaging=False, polyak_window_len=20):
    """Fit parameters with an optax optimizer, one sample per update step.

    Args:
        params: Initial parameters (any pytree accepted by `ravel_pytree`).
        optimizer: Optax optimizer providing `init` and `update`.
        input: Iterable of inputs; iterated in lockstep with `output`.
        output: Iterable of targets.
        loss_fn (Callable): `loss_fn(params, x, y)` returning a scalar loss.
        num_epochs (int): Number of passes over `(input, output)`.
        return_history (bool, optional): If True, return the parameters
            recorded after every update step instead of only the final ones.
            Defaults to False.
        polyak_averaging (bool, optional): If True, the reported parameters
            are the mean of the most recent iterates. The optimizer itself
            always continues from its own, un-averaged iterate. Defaults to False.
        polyak_window_len (int, optional): Maximum number of iterates that
            are averaged. Defaults to 20.
    Returns:
        `jnp.array` of the recorded parameters if `return_history` is True
        (this stacks the history, so the parameters are presumably flat
        arrays in that case); otherwise the final (optionally averaged)
        parameters.
    Raises:
        ValueError: If Polyak averaging is requested with a window shorter than 1.
    """
    if polyak_averaging and polyak_window_len < 1:
        raise ValueError(f'polyak_window_len must be at least 1, got {polyak_window_len}')

    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state, x, y):
        loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    _, unflatten_fn = ravel_pytree(params)

    # A bounded deque drops the oldest iterate automatically once it is full.
    polyak_window = deque(maxlen=polyak_window_len) if polyak_averaging else None
    params_history = [] if return_history else None

    def window_average():
        # Mean of the flattened iterates in the window, restored to the
        # structure of `params`.
        return unflatten_fn(jnp.stack(tuple(polyak_window)).mean(axis=0))

    for _ in range(num_epochs):
        for x, y in zip(input, output):
            params, opt_state, _ = step(params, opt_state, x, y)
            if polyak_averaging:
                flattened_params, _ = ravel_pytree(params)
                polyak_window.append(flattened_params)
            if return_history:
                params_history.append(window_average() if polyak_averaging else params)

    if return_history:
        return jnp.array(params_history)
    # The window is empty when no update step was taken (e.g. empty data).
    if polyak_averaging and polyak_window:
        return window_average()
    return params

In [None]:
@partial(jit, static_argnums=(2, 5, 6))
def evaluate_posterior(params, cov, apply_fn, X_test, y_test, 
                       posterior_predictive_estimation='map',
                       cov_type='full', num_samples=50, key=0):
    """Evaluate negative log likelihood and calibration error on a test set.

    Args:
        params: Parameters (posterior mean / point estimate) of the model.
        cov: Posterior covariance. A full matrix for `cov_type='full'`, the
            vector of diagonal variances for `cov_type='diagonal'`. Not used
            when `posterior_predictive_estimation='map'`.
        apply_fn (Callable): Maps `(params, x)` to logits.
        X_test: Test inputs, one sample per leading index.
        y_test: Test labels.
        posterior_predictive_estimation (str, optional): 'map' evaluates the
            NLL at `params`; 'mc' averages the NLL over sampled parameters;
            'immer' does the same with logits linearized around `params`.
            Defaults to 'map'.
        cov_type (str, optional): 'full' or 'diagonal'. Defaults to 'full'.
        num_samples (int, optional): Number of parameter samples. Defaults to 50.
        key (int or PRNGKey, optional): Seed or key for sampling. Defaults to 0.
    Returns:
        dict: `{'nll': ..., 'ece': ...}`. The ECE is always computed from
        `params` itself, independent of the estimation type.
    Raises:
        ValueError: For an unknown `cov_type` or estimation type (only
            checked when the estimation type is not 'map').
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    @jit
    def evaluate_nll(params, X, y):
        logits = apply_fn(params, X)
        return optax.sigmoid_binary_cross_entropy(logits, y)

    def evaluate_ece(params, X, y):
        eps = 1e-3
        sigmoid_fn = lambda w, x: jnp.clip(jax.nn.sigmoid(apply_fn(w, x)), eps, 1-eps) # Clip to prevent divergence
        pred = vmap(sigmoid_fn, (None, 0))(params, X)
        # Two-column class probabilities; their log is passed as logits.
        pred = jnp.concatenate((1-pred, pred), axis=1)
        return tfp.stats.expected_calibration_error(20, logits=jnp.log(pred), labels_true=y.astype(int))

    @jit
    def evaluate_linearized_nll(params_map, params, X, y):
        # First-order expansion of the logits around `params_map`.
        logit_fn = lambda p: apply_fn(p, X)
        H = _jacrev_2d(logit_fn, params_map)
        logits = jnp.clip(logit_fn(params_map) + H @ (params - params_map), -10, 10)
        return optax.sigmoid_binary_cross_entropy(logits, y)

    ece = evaluate_ece(params, X_test, y_test)

    if posterior_predictive_estimation == 'map':
        nlls = vmap(evaluate_nll, (None, 0, 0))(params, X_test, y_test)
        result = {'nll': nlls.mean(), 'ece': ece}
        return result

    if cov_type == 'full':
        mvn = MVN(loc=params, covariance_matrix=cov)
    elif cov_type == 'diagonal':
        # `scale_diag` expects standard deviations, while `cov` holds the
        # diagonal variances, hence the square root.
        mvn = MVND(loc=params, scale_diag=jnp.sqrt(cov))
    else:
        raise ValueError(f"Unknown cov_type {cov_type!r}; expected 'full' or 'diagonal'.")
    sampled_params = mvn.sample(seed=key, sample_shape=num_samples)

    if posterior_predictive_estimation == 'mc':
        evaluate_average_nll = lambda p: vmap(evaluate_nll, (None, 0, 0))(p, X_test, y_test)
    elif posterior_predictive_estimation == 'immer':
        evaluate_average_nll = lambda p: vmap(evaluate_linearized_nll, (None, None, 0, 0))(params, p, X_test, y_test)
    else:
        raise ValueError(
            f"Unknown posterior_predictive_estimation {posterior_predictive_estimation!r}; "
            "expected 'map', 'mc' or 'immer'."
        )

    nlls = vmap(evaluate_average_nll)(sampled_params)
    result = {'nll': nlls.mean(), 'ece': ece}
    return result

In [None]:
def apply_hyperparam(init_var, model_dim, train_len):
    """Build filter parameters and the SGD loss for one hyperparameter setting.

    Args:
        init_var: Initial variance; scales the initial covariances and the
            covariance of the prior used in the loss.
        model_dim: MLP layer dimensions forwarded to `get_mlp_flattened_params`.
        train_len: Number of training samples; the prior term of the loss is
            weighted by `1 / train_len`.
    Returns:
        fcekf_params, dekf_params, flat_params, loss_fn, apply_fn
    """
    # Generate MLP model with specified dimensions
    _, flat_params, _, apply_fn = get_mlp_flattened_params(model_dim)
    state_dim = flat_params.size
    clip_eps = 1e-4

    def emission_mean(w, x):
        # Clipped away from 0 and 1 to prevent divergence.
        return jnp.clip(jax.nn.sigmoid(apply_fn(w, x)), clip_eps, 1 - clip_eps)

    def emission_cov(w, x):
        # Bernoulli variance p * (1 - p) of the clipped probability.
        return emission_mean(w, x) * (1 - emission_mean(w, x))

    # Full-covariance EKF: stationary dynamics with zero process noise.
    fcekf_params = ParamsGGSSM(
        initial_mean=flat_params,
        initial_covariance=jnp.eye(state_dim) * init_var,
        dynamics_function=lambda w, _: w,
        dynamics_covariance=jnp.eye(state_dim) * 0.,
        emission_mean_function=emission_mean,
        emission_cov_function=emission_cov,
    )

    # Diagonal EKFs: same setting, with diagonal covariances.
    dekf_params = DEKFParams(
        initial_mean=flat_params,
        initial_cov_diag=jnp.ones((state_dim,)) * init_var,
        dynamics_cov_diag=jnp.ones((state_dim,)) * 0.,
        emission_mean_function=emission_mean,
        emission_cov_function=emission_cov,
    )

    # Regularized loss with a zero-mean isotropic Gaussian prior.
    prior_dist = MVN(
        loc=jnp.zeros(flat_params.shape),
        covariance_matrix=jnp.eye(state_dim) * init_var,
    )
    loss_fn = partial(reg_loss_fn, prior_dist=prior_dist, apply_fn=apply_fn, lamb=1/train_len)

    return fcekf_params, dekf_params, flat_params, loss_fn, apply_fn

In [None]:
def evaluate_sgd(learning_rate, momentum, model_dim, X_train, y_train, X_test, y_test, regularize=False, init_var=None):
    """Train an MLP with SGD for one epoch and score its late-training NLL.

    Args:
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        model_dim: MLP layer dimensions forwarded to `get_mlp_flattened_params`.
        X_train, y_train: Training inputs and labels.
        X_test, y_test: Inputs and labels on which the NLL is evaluated.
        regularize (bool, optional): If True, train with `reg_loss_fn` and a
            zero-mean Gaussian prior of variance `init_var`. Defaults to False.
        init_var (optional): Prior variance; only used when `regularize` is True.
    Returns:
        Mean NLL over the last `int(len(X_test) / 4)` recorded parameters.
    """
    _, flat_params, _, apply_fn = get_mlp_flattened_params(model_dim)
    optimizer = optax.sgd(learning_rate=learning_rate, momentum=momentum)

    if regularize:
        prior_dist = MVN(
            loc=jnp.zeros(flat_params.shape),
            covariance_matrix=jnp.eye(flat_params.size) * init_var,
        )
        loss_fn = partial(reg_loss_fn, prior_dist=prior_dist, apply_fn=apply_fn, lamb=1/len(y_train))
    else:
        def bernoulli_nll(y, yhat):
            return -(y * jnp.log(yhat) + (1-y) * jnp.log(1 - yhat))

        def predict_proba(w, x):
            return jax.nn.sigmoid(apply_fn(w, x))

        loss_fn = partial(loss_optax, loss_fn=bernoulli_nll, apply_fn=predict_proba)

    param_history = fit_optax(flat_params, optimizer, X_train, y_train, loss_fn, num_epochs=1, return_history=True)

    def nll_at(p):
        return evaluate_posterior(p, None, apply_fn, X_test, y_test)['nll']

    nll_history = vmap(nll_at)(param_history)

    # Take mean of nll over the final 'avg_window' parameters
    avg_window = int(len(X_test) / 4)
    return nll_history[-avg_window:].mean()

# 3. SGD Tuning

In [None]:
def grid_search_sgd(learning_rates, momentums, init_var, model_dims, X_train, y_train, X_test, y_test, regularize=False):
    """Pick the best SGD learning rate and momentum for every model.

    Args:
        learning_rates: Candidate learning rates.
        momentums: Candidate momentums.
        init_var: Prior variance forwarded to `evaluate_sgd`.
        model_dims (dict): Maps a model label to its layer dimensions.
        X_train, y_train: Training inputs and labels.
        X_test, y_test: Inputs and labels used to score each candidate.
        regularize (bool, optional): Forwarded to `evaluate_sgd`. Defaults to False.
    Returns:
        Result of `tree_map` over the model dictionary, holding the
        `(learning_rate, momentum)` pair with the lowest score per model.
    """
    def _best_hyperparams(mdim):
        dims = list(mdim)

        def _score(lr, m):
            return evaluate_sgd(lr, m, dims, X_train, y_train, X_test, y_test, regularize, init_var)

        # Inner vmap sweeps the momentums, outer vmap the learning rates,
        # so `scores[i, j]` belongs to learning rate i and momentum j.
        sweep_momentum = vmap(_score, (None, 0))
        sweep_all = vmap(sweep_momentum, (0, None))
        scores = sweep_all(learning_rates, momentums)

        best_lr, best_m = jnp.unravel_index(scores.argmin(), scores.shape)
        return learning_rates[best_lr], momentums[best_m]

    dims_as_arrays = {}
    for label, dims in model_dims.items():
        dims_as_arrays[label] = jnp.array(dims)

    return tree_map(_best_hyperparams, dims_as_arrays)

In [None]:
# Tune SGD on the validation split for every initial variance. The result is
# keyed first by model label, then by the variance formatted to 3 decimals.
gridsearch_result = {}
for init_var in tqdm(init_vars):
    tmp_result = grid_search_sgd(lrs, ms, init_var, model_dim_grid, X_train, y_train, X_val, y_val, regularize=True)
    var_label = f'{init_var:.3f}'
    for key, val in tmp_result.items():
        lr, m = float(val[0]), float(val[1])
        gridsearch_result.setdefault(key, {})[var_label] = (lr, m)

# 4. Model Comparison

In [None]:
def generate_model_comparison_table(init_var_grid, model_dim_grid, X_train, y_train, X_test, y_test, sgd_gridsearch_result):
    """Evaluate the three EKF variants and SGD for every model / initial variance.

    Args:
        init_var_grid: Iterable of initial variances.
        model_dim_grid (dict): Maps a model label to its layer dimensions.
        X_train, y_train: Training inputs and labels.
        X_test, y_test: Inputs and labels used for evaluation.
        sgd_gridsearch_result (dict): Tuned `(learning_rate, momentum)` pairs,
            indexed by model label and then by `f'{init_var:.3f}'`.
    Returns:
        dict: One dictionary per method ('fcekf', 'fdekf', 'vdekf', 'sgd'),
        each keyed by `'<model>_<init_var>_<estimator>_<metric>'`.
    """
    results = {'fcekf': dict(), 'fdekf': dict(), 'vdekf': dict(), 'sgd': dict()}
    estimators = ('map', 'mc', 'immer')
    metrics = ('nll', 'ece')

    for model_dim_type, model_dim in model_dim_grid.items():
        pbar = tqdm(init_var_grid)
        for init_var in pbar:
            pbar.set_description(f'model_dim={model_dim_type}, init_var={init_var:.3f}')
            curr_index = f'{model_dim_type}_{init_var}'
            fcekf_params, dekf_params, flat_params, loss_fn, apply_fn = apply_hyperparam(init_var, model_dim, len(y_train))

            # Run the three filters over the training sequence.
            posts = {
                'fcekf': conditional_moments_gaussian_filter(fcekf_params, EKFIntegrals(), y_train, inputs=X_train),
                'fdekf': stationary_dynamics_fully_decoupled_conditional_moments_gaussian_filter(dekf_params, y_train, inputs=X_train),
                'vdekf': stationary_dynamics_variational_diagonal_extended_kalman_filter(dekf_params, y_train, inputs=X_train),
            }

            for post_type, post in posts.items():
                # Only the full-covariance EKF carries a full covariance matrix.
                cov_type = 'full' if post_type == 'fcekf' else 'diagonal'
                for est_type in estimators:
                    evaluate_step = lambda p, c: evaluate_posterior(p, c, apply_fn, X_test, y_test, est_type, cov_type)
                    scores = vmap(evaluate_step, (0, 0))(post.filtered_means, post.filtered_covariances)
                    for eval_type in metrics:
                        results[post_type][curr_index+'_'+est_type+'_'+eval_type] = scores[eval_type]

            # SGD Optimizer
            tuned = sgd_gridsearch_result[model_dim_type][f'{init_var:.3f}']
            sgd_optimizer = optax.sgd(learning_rate=tuned[0], momentum=tuned[1])

            sgd_post = fit_optax(flat_params, sgd_optimizer, X_train, y_train, loss_fn, num_epochs=1, return_history=True)
            evaluate_sgd = lambda p: evaluate_posterior(p, None, apply_fn, X_test, y_test, 'map')
            sgd_scores = vmap(evaluate_sgd)(sgd_post)
            # SGD only has a plug-in estimate; it is stored under every
            # estimator key.
            for result_type in estimators:
                for eval_type in metrics:
                    results['sgd'][curr_index+'_'+result_type+'_'+eval_type] = sgd_scores[eval_type]
    return results

In [None]:
# Fit every model on the training split and evaluate it on the test split,
# using the SGD hyperparameters tuned above.
result = generate_model_comparison_table(
    init_var_grid=init_vars,
    model_dim_grid=model_dim_grid,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    sgd_gridsearch_result=gridsearch_result,
)

In [None]:
# Write the NLL / ECE table of every method to its own .csv file in the
# current working directory.
for model_type in ('fcekf', 'fdekf', 'vdekf', 'sgd'):
    filepath = Path.cwd() / f'nll_results_{model_type}.csv'
    filepath.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame.from_dict(result[model_type], orient='index')
    df.to_csv(filepath)

In [None]:
# Read the stored result tables back, one DataFrame per method.
dfs = {
    name: pd.read_csv(Path.cwd() / f'nll_results_{name}.csv', index_col=0)
    for name in ('fcekf', 'fdekf', 'vdekf', 'sgd')
}
fcekf_df = dfs['fcekf']
fdekf_df = dfs['fdekf']
vdekf_df = dfs['vdekf']
sgd_df = dfs['sgd']

In [None]:
# NLL curves of every EKF variant (plugin, MC and linearized predictive)
# together with SGD, one figure per (initial variance, model) pair.
train_steps = jnp.arange(len(X_train))

predictive_dict = {'map': 'plugin', 'mc': 'MC', 'immer': 'linearized'}

# `savefig` does not create missing directories, so make sure the target exists.
fig_dir = Path(Path.cwd(), 'nll_all_models')
fig_dir.mkdir(parents=True, exist_ok=True)

for init_var in init_vars:
    for model_dim_type, model_dim in model_dim_grid.items():
        fig, ax = plt.subplots()
        for predictive_type in ['map', 'mc', 'immer']:
            row = f'{model_dim_type}_{init_var}_{predictive_type}_nll'
            for ekf_label, ekf_df in (('FCEKF', fcekf_df), ('FDEKF', fdekf_df), ('VDEKF', vdekf_df)):
                ax.plot(train_steps, ekf_df.loc[row], label=f'{ekf_label}-{predictive_dict[predictive_type]}')
        ax.plot(train_steps, sgd_df.loc[f'{model_dim_type}_{init_var}_map_nll'], label='SGD')
        ax.set_title(f'NLL comparison for Initial Var={init_var:.3f}, Model Dim=[{",".join(map(str, model_dim))}]')
        ax.legend();
        fig.savefig(Path(fig_dir, f'{model_dim_type}_{init_var:.3f}.png'))

In [None]:
# NLL curves of every EKF variant (plugin and linearized predictive only)
# together with SGD, one figure per (initial variance, model) pair.
train_steps = jnp.arange(len(X_train))

predictive_dict = {'map': 'plugin', 'mc': 'MC', 'immer': 'linearized'}

# `savefig` does not create missing directories, so make sure the target exists.
fig_dir = Path(Path.cwd(), 'nll_plugin_linearized')
fig_dir.mkdir(parents=True, exist_ok=True)

for init_var in init_vars:
    for model_dim_type, model_dim in model_dim_grid.items():
        fig, ax = plt.subplots()
        for predictive_type in ['map', 'immer']:
            row = f'{model_dim_type}_{init_var}_{predictive_type}_nll'
            for ekf_label, ekf_df in (('FCEKF', fcekf_df), ('FDEKF', fdekf_df), ('VDEKF', vdekf_df)):
                ax.plot(train_steps, ekf_df.loc[row], label=f'{ekf_label}-{predictive_dict[predictive_type]}')
        ax.plot(train_steps, sgd_df.loc[f'{model_dim_type}_{init_var}_map_nll'], label='SGD')
        ax.set_title(f'NLL comparison for Initial Var.={init_var:.3f}, Model Dim. =[{",".join(map(str, model_dim))}]')
        ax.legend();
        fig.savefig(Path(fig_dir, f'{model_dim_type}_{init_var:.3f}.png'))

In [None]:
# NLL curves of the EKF variants only (plugin and linearized predictive);
# SGD is left out of these figures.
train_steps = jnp.arange(len(X_train))

predictive_dict = {'map': 'plugin', 'mc': 'MC', 'immer': 'linearized'}

# `savefig` does not create missing directories, so make sure the target exists.
fig_dir = Path(Path.cwd(), 'nll_ekfs')
fig_dir.mkdir(parents=True, exist_ok=True)

for init_var in init_vars:
    for model_dim_type, model_dim in model_dim_grid.items():
        fig, ax = plt.subplots()
        for predictive_type in ['map', 'immer']:
            row = f'{model_dim_type}_{init_var}_{predictive_type}_nll'
            for ekf_label, ekf_df in (('FCEKF', fcekf_df), ('FDEKF', fdekf_df), ('VDEKF', vdekf_df)):
                ax.plot(train_steps, ekf_df.loc[row], label=f'{ekf_label}-{predictive_dict[predictive_type]}')
        ax.set_title(f'NLL comparison for Initial Var.={init_var:.3f}, Model Dim. =[{",".join(map(str, model_dim))}]')
        ax.legend();
        fig.savefig(Path(fig_dir, f'{model_dim_type}_{init_var:.3f}.png'))

In [None]:
# ECE curves of every EKF variant (plugin and linearized predictive)
# together with SGD, one figure per (initial variance, model) pair.
train_steps = jnp.arange(len(X_train))

# `savefig` does not create missing directories, so make sure the target exists.
fig_dir = Path(Path.cwd(), 'ece_all_models')
fig_dir.mkdir(parents=True, exist_ok=True)

for init_var in init_vars:
    for model_dim_type, model_dim in model_dim_grid.items():
        fig, ax = plt.subplots()
        for predictive_type in ['map', 'immer']:
            row = f'{model_dim_type}_{init_var}_{predictive_type}_ece'
            for ekf_label, ekf_df in (('fcekf', fcekf_df), ('fdekf', fdekf_df), ('vdekf', vdekf_df)):
                ax.plot(train_steps, ekf_df.loc[row], label=f'{ekf_label}_{predictive_type}')
        ax.plot(train_steps, sgd_df.loc[f'{model_dim_type}_{init_var}_map_ece'], label='sgd')
        ax.set_title(f'ECE comparison for init_var={init_var:.3f}, model_dim=[{",".join(map(str, model_dim))}]')
        ax.legend();
        fig.savefig(Path(fig_dir, f'{model_dim_type}_{init_var:.3f}.png'))

# 5. Heat Maps

In [None]:
# Grid over the (padded) range of the training inputs, used for the heat maps.
input_grid = generate_input_grid(input=X_train)

In [None]:
# Heat maps of the predictive distribution on the input grid for a selection
# of models and initial variances: one figure per EKF variant and predictive
# type, plus one figure for SGD.
selected_mlps = ['MLP2', 'MLP3', 'MLP4']
selected_init_vars = [0.1, 0.01, 0.001]
X, y = X_train, y_train
for mlp in selected_mlps:
    model_dim = model_dim_grid[mlp]
    for init_var in selected_init_vars:
        curr_init_var = f'{init_var:.3f}'
        fcekf_params, dekf_params, flat_params, loss_fn, apply_fn = apply_hyperparam(init_var, model_dim, len(X))

        # EKF Models: run every filter once and keep the final mean and covariance.
        fcekf_post = conditional_moments_gaussian_filter(fcekf_params, EKFIntegrals(), y, inputs=X)
        fcekf_weight, fcekf_cov = fcekf_post.filtered_means[-1], fcekf_post.filtered_covariances[-1]
        fdekf_post = stationary_dynamics_fully_decoupled_conditional_moments_gaussian_filter(dekf_params, y, inputs=X)
        fdekf_weight, fdekf_cov = fdekf_post.filtered_means[-1], fdekf_post.filtered_covariances[-1]
        vdekf_post = stationary_dynamics_variational_diagonal_extended_kalman_filter(dekf_params, y, inputs=X)
        vdekf_weight, vdekf_cov = vdekf_post.filtered_means[-1], vdekf_post.filtered_covariances[-1]

        # SGD with the tuned learning rate and momentum; keep the final parameters.
        lr, m = gridsearch_result[mlp][curr_init_var]
        sgd_optimizer = optax.sgd(learning_rate=lr, momentum=m)
        sgd_weight = fit_optax(flat_params, sgd_optimizer, X, y, loss_fn, 1, True)[-1]

        # Clipped sigmoid of the model output, used for every heat map below.
        eps = 1e-3
        sigmoid_fn = lambda w, x: jnp.clip(jax.nn.sigmoid(apply_fn(w, x)), eps, 1-eps)

        # Plot
        ekf_params = {'fcekf': (fcekf_weight, fcekf_cov), 
                      'fdekf': (fdekf_weight, fdekf_cov),
                      'vdekf': (vdekf_weight, vdekf_cov)}

        for ekf_type, params in ekf_params.items():
            weight, cov = params
            Z_plugin = posterior_predictive_grid(input_grid, weight, sigmoid_fn, binary=False)
            Z_mc = posterior_predictive_grid_ekf(input_grid, weight, cov, sigmoid_fn, binary=False, ekf_type=ekf_type)
            Z_linearized = posterior_predictive_grid_ekf(input_grid, weight, cov, sigmoid_fn, binary=False, ekf_type=ekf_type, post_pred_type='immer')
            for Z_type, Z in {'Plugin': Z_plugin, 'MC': Z_mc, 'Linearized': Z_linearized}.items():
                fig, ax = plt.subplots(figsize=(6, 5))
                title=f'{ekf_type.upper()} {Z_type}: Initial Var={curr_init_var}, Model Dim=[{",".join(map(str, model_dim))}]'
                plot_posterior_predictive(ax, X, y, title, input_grid, Z);
                fig.savefig(f'{ekf_type}_{Z_type}_{mlp}_{init_var:.3f}.png', bbox_inches = 'tight')

        Z_sgd = posterior_predictive_grid(input_grid, sgd_weight, sigmoid_fn, binary=False)
        fig, ax = plt.subplots(figsize=(6, 5))
        title=f'SGD: Initial Var={curr_init_var}, Model Dim=[{",".join(map(str, model_dim))}]'
        plot_posterior_predictive(ax, X, y, title, input_grid, Z_sgd);
        fig.savefig(f'sgd_{mlp}_{init_var:.3f}.png', bbox_inches = 'tight')
