# CMGF-EKF Evaluation for MLP Training

Author: Peter Chang ([@petergchang](https://github.com/petergchang))

## 0. Imports

In [1]:
# Hide the root-logger message
# "WARNING:root:The use of `check_types` is deprecated and does not have any effect."
# Upstream issue: https://github.com/tensorflow/probability/issues/1523
import logging
import warnings


class CheckTypesFilter(logging.Filter):
    """Logging filter that rejects any record whose message mentions ``check_types``."""

    def filter(self, record):
        # Returning False tells the logging machinery to drop the record.
        message = record.getMessage()
        return "check_types" not in message


# Attach the filter to the root logger, which is where the message is emitted.
logger = logging.getLogger()
logger.addFilter(CheckTypesFilter())

# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings('ignore')

In [2]:
# Import the CMGF inference/container helpers from ssm_jax together with
# Flax's linen API. If any of these modules is missing, install ssm_jax and
# flax into the kernel environment with %pip and retry the same imports.
try:
    # NOTE(review): the star imports hide which ssm_jax names this notebook
    # actually relies on; consider importing them explicitly.
    from ssm_jax.cond_moments_gaussian_filter.inference import *
    from ssm_jax.cond_moments_gaussian_filter.containers import *
    import flax.linen as nn
except ModuleNotFoundError:
    print('installing ssm_jax')
    # Neither install is pinned: no git ref is given for ssm-jax and no
    # version specifier is given for flax.
    %pip install -qq git+https://github.com/probml/ssm-jax.git
    %pip install -qq flax
    # Retry the imports now that the packages have been installed.
    from ssm_jax.cond_moments_gaussian_filter.inference import *
    from ssm_jax.cond_moments_gaussian_filter.containers import *
    import flax.linen as nn

installing ssm_jax
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [3]:
from typing import Sequence
from functools import partial

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.cm as cm
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.flatten_util import ravel_pytree

## 1. MLP Definition

In [4]:
class MLP(nn.Module):
    """Stack of dense layers with ReLU after every layer except the last.

    Attributes:
        features: Output width of each dense layer, in order. The final entry
            is the width of the output layer, which has no activation.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths, output_width = self.features[:-1], self.features[-1]

        # Hidden layers: affine map followed by ReLU.
        for width in hidden_widths:
            x = nn.Dense(width)(x)
            x = nn.relu(x)

        # Output layer: affine map only.
        return nn.Dense(output_width)(x)

In [5]:
def get_mlp_flattened_params(model_dims, key=0):
    """Build an ``MLP`` and return its initial parameters as one flat vector.

    Args:
        model_dims: Layer sizes. The first entry is the input dimension; the
            remaining entries are passed to ``MLP`` as its ``features``.
        key: A JAX PRNG key, or an ``int`` seed that is turned into one with
            ``jr.PRNGKey``.

    Returns:
        A tuple ``(model, flat_params, unflatten_fn, apply_fn)``.
        ``flat_params`` and ``unflatten_fn`` come from ``ravel_pytree`` applied
        to the initialized parameters, and ``apply_fn(flat_params, x)``
        evaluates the model on ``jnp.atleast_1d(x)``.
    """
    rng = jr.PRNGKey(key) if isinstance(key, int) else key

    # The leading entry is the input width; the rest describe the layers.
    input_dim = model_dims[0]
    model = MLP(model_dims[1:])

    # Initialize the parameters by tracing the model on a dummy input.
    init_params = model.init(rng, jnp.ones((input_dim,)))
    flat_params, unflatten_fn = ravel_pytree(init_params)

    def apply(flat_params, x, model, unflatten_fn):
        # Rebuild the parameter pytree before evaluating the model.
        params = unflatten_fn(flat_params)
        return model.apply(params, jnp.atleast_1d(x))

    # Bind the model and the unflattener so callers only pass (flat_params, x).
    apply_fn = partial(apply, model=model, unflatten_fn=unflatten_fn)
    return model, flat_params, unflatten_fn, apply_fn

In [66]:
def separate_flat_params(model_dims):
    """Group the indices of a flattened MLP parameter vector by node.

    Each layer after the input owns a contiguous slice of
    ``num_curr + num_prev * num_curr`` entries of the flat vector. Within that
    slice, the entries at offsets ``i, i + num_curr, i + 2 * num_curr, ...``
    are assigned to node ``i`` of the layer.
    NOTE(review): this presumably mirrors a bias-then-kernel layout of the
    flattened Flax parameters -- confirm against the code that produces them.

    Args:
        model_dims: Layer sizes, input dimension first. Must contain at least
            two entries.

    Returns:
        A tuple ``(separate_params, aggregate_fn)``:

        * ``separate_params`` is a list with one integer index array per node
          (layers in order, nodes in order within each layer); each array has
          ``num_prev + 1`` entries.
        * ``aggregate_fn(separate_params)`` takes a list with that same
          per-node structure and concatenates it back into a single flat
          array in the original ordering.

    Raises:
        AssertionError: If ``model_dims`` has fewer than two entries.
    """
    assert len(model_dims) > 1
    separate_params = []
    curr_idx = 0
    for layer in range(1, len(model_dims)):
        # Number of nodes in the previous and current layer
        num_prev, num_curr = model_dims[layer - 1], model_dims[layer]
        num_bias_params = num_curr
        num_weight_params = num_prev * num_curr
        num_params_curr_layer = num_bias_params + num_weight_params

        # Lay the layer's index range out as a (num_prev + 1, num_curr) grid so
        # that column i holds offsets i, i + num_curr, i + 2 * num_curr, ...
        # One reshape replaces per-element indexing of the device array.
        idx_grid = jnp.arange(curr_idx, curr_idx + num_params_curr_layer).reshape(
            num_prev + 1, num_curr
        )

        # Append the index array of each node in the current layer
        separate_params += [idx_grid[:, i] for i in range(num_curr)]

        curr_idx += num_params_curr_layer

    # Function to aggregate a separated params list back into one flat array
    def aggregate_fn(separate_params, model_dims):
        assert len(model_dims) > 1
        aggregate_params_list = []
        curr_idx = 0
        for layer in range(1, len(model_dims)):
            num_curr = model_dims[layer]
            layer_nodes = separate_params[curr_idx:curr_idx + num_curr]
            # Stacking gives shape (num_curr, num_prev + 1); column-major
            # flattening interleaves the nodes again, inverting the split above.
            aggregate_params_list.append(jnp.ravel(jnp.array(layer_nodes), order='F'))
            curr_idx += num_curr

        return jnp.concatenate(aggregate_params_list)

    return separate_params, partial(aggregate_fn, model_dims=model_dims)