# LoFi v.s. subspaceEKF

In [115]:
import jax
import einops
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable
from rebayes.extended_kalman_filter import ekf
from rebayes.low_rank_filter.subspace_filter import subcify
from rebayes.low_rank_filter import lofi
from rebayes.datasets.datasets import load_mnist

In [28]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
res = load_mnist()
res = jax.tree_map(jnp.array, res)
(X_train, y_train), (X_test, y_test) = res
X_train = einops.rearrange(X_train, "obs width height -> obs (width height)")
X_test = einops.rearrange(X_test, "obs width height -> obs (width height)")

y_train_ohe = jax.nn.one_hot(y_train, 10)
y_test_ohe = jax.nn.one_hot(y_test, 10)

In [86]:
class NNet(nn.Module):
    activation: Callable = nn.elu
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = self.activation(x)
        x = nn.Dense(300)(x)
        x = self.activation(x)
        x = nn.Dense(100)(x)
        x = self.activation(x)
        x = nn.Dense(10)(x)
        x = nn.softmax(x)
        return x

In [87]:
key = jax.random.PRNGKey(314)
model = NNet()
params_init = model.init(key, X_train[:10])
params_init_flat, _ = ravel_pytree(params_init)

## LoFi agent

In [107]:
agent, recfn = lofi.init_classification_agent(
    model,
    X_train[:10],
    dynamics_weights=1.0,
    dynamics_covariance=1e-7,
    memory_size=20,
)

In [108]:
bel, outputs = agent.scan(
    initial_mean=params_init_flat,
    initial_covariance=0.05,
    X=X_train[:5000],
    Y=y_train[:5000],
    progress_bar=False,
)

bel = jax.block_until_ready(bel)

In [109]:
yhat_test = agent.predict_obs(bel, X_test).argmax(axis=1)
(yhat_test == y_test).mean()

Array(0.098, dtype=float32)

## Subspace agent

In [127]:
SubNNet = subcify(NNet)
dim_in = 28 ** 2

model_sub = SubNNet(dim_in, dim_subspace=20)

params_init = model_sub.init(key, X_train[:10])
params_init_flat, _ = ravel_pytree(params_init)

In [131]:
pfixed, psubspace_init = params_init["fixed"], params_init["params"]
psubspace_init_flat, recfn = ravel_pytree(psubspace_init)

In [129]:
def applyfn(psubspace, X):
    pfull = {
        "fixed": pfixed,
        "params": psubspace
    }
    return model.apply(pfull, X)

In [136]:
dynamics_weights=1.0,
dynamics_covariance=1e-7,
method="fcekf"


In [137]:
def apply_fn(flat_params, x):
    return apply_fn(recfn(flat_params), x)

def emission_cov_fn(flat_params, x):
    p = apply_fn(flat_params, x)
    return jnp.diag(p) - jnp.outer(p, p) + eps * jnp.eye(len(p))

agent = ekf.RebayesEKF(
    dynamics_weights_or_function=dynamics_weights,
    dynamics_covariance=dynamics_covariance,
    emission_mean_function=apply_fn,
    emission_cov_function=emission_cov_fn,
    adaptive_emission_cov=False,
    dynamics_covariance_inflation_factor=0.0,
    emission_dist=lambda mean, cov: tfd.Normal(loc=mean, scale=jnp.sqrt(cov)),
    method=method,
)

In [140]:
bel, outputs = agent.scan(
    initial_mean=psubspace_init_flat,
    initial_covariance=0.05,
    X=X_train[:5000],
    Y=y_train[:5000],
    progress_bar=False,
)

bel = jax.block_until_ready(bel)

AttributeError: 'tuple' object has no attribute 'ndim'

In [109]:
yhat_test = agent.predict_obs(bel, X_test).argmax(axis=1)
(yhat_test == y_test).mean()

Array(0.098, dtype=float32)