# Subspace EKF

In [9]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.flatten_util import ravel_pytree
from rebayes.low_rank_filter import subspace_filter
from functools import partial



In [42]:
import tensorflow_datasets as tfds

def process_dataset(Xtr, Ytr, Xval, Yval, Xte, Yte, shuffle=False, oh_train=True, key=0):
    if isinstance(key, int):
        key = jax.random.PRNGKey(key)
        
    # Reshape data
    Xtr = Xtr.reshape(-1, 1, 28, 28, 1)
    if oh_train:
        Ytr = jax.nn.one_hot(Ytr, 10) # one-hot encode labels
    
    # Shuffle data
    if shuffle:
        idx = jax.random.permutation(key, jnp.arange(len(Xtr)))
        Xtr, Ytr = Xtr[idx], Ytr[idx]
    
    Xtr, Ytr, Xval, Yval, Xte, Yte = (jnp.array(data) for data in [Xtr, Ytr, Xval, Yval, Xte, Yte])
    
    dataset = {
        'train':[jnp.array(d) for d in (Xtr, Ytr)],
        'val': [jnp.array(d) for d in (Xval, Yval)],
        'test': [jnp.array(d) for d in (Xte, Yte)],
    }
    
    return dataset


def load_mnist_dataset(fashion=False, n_train=None, n_val=None, n_test=None):
    """Load MNIST train and test datasets into memory."""
    dataset='mnist'
    if fashion:
        dataset='fashion_mnist'
    ds_builder = tfds.builder(dataset)
    ds_builder.download_and_prepare()
    
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[10%:]', batch_size=-1))
    val_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[:10%]', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    
    # Normalize pixel values
    for ds in [train_ds, val_ds, test_ds]:
        ds['image'] = np.float32(ds['image']) / 255.
    
    n_train = min(n_train, len(train_ds['image'])) if n_train else len(train_ds['image'])
    n_val = min(n_val, len(val_ds['image'])) if n_val else len(val_ds['image'])
    n_test = min(n_test, len(test_ds['image'])) if n_test else len(test_ds['image'])
    
    X_train, y_train = (jnp.array(train_ds[key][:n_train]) for key in ['image', 'label'])
    X_val, y_val = (jnp.array(val_ds[key][:n_val]) for key in ['image', 'label'])
    X_test, y_test = (jnp.array(test_ds[key][:n_test]) for key in ['image', 'label'])
    
    dataset = process_dataset(X_train, y_train, X_val, y_val, X_test, y_test, shuffle=True)
        
    return dataset

In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
from typing import Callable

In [30]:
class NNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10, name="gets_cov")(x)
        x = nn.elu(x)
        x = nn.Dense(10)(x)
        x = nn.elu(x)
        x = nn.Dense(1)(x)
        return x

In [31]:
key = jax.random.PRNGKey(314)

model = NNet()
n_features = 2
Xinit = jnp.ones((10, n_features))
dim_subspace = 10
params_init = model.init(key, Xinit)
params_flat, rfn = ravel_pytree(params_init)
dim_full = len(params_flat)

pmatrix = jax.random.bernoulli(key, p=0.5, shape=(dim_full, dim_subspace))
params_latent = jax.random.normal(key, shape=(dim_subspace,))

In [32]:
params_latent

Array([-1.4314531 , -1.4140283 ,  0.5512039 , -0.33057716, -0.39858606,
        0.6754447 , -0.08324226, -0.3476744 ,  0.34246936, -0.03261461],      dtype=float32)

In [33]:
NNet().apply(params_init, Xinit).ravel()

Array([2.5442393, 2.5442393, 2.5442393, 2.5442393, 2.5442393, 2.5442393,
       2.5442393, 2.5442393, 2.5442393, 2.5442393], dtype=float32)

## Subspace agent

See [this issue](https://github.com/deepmind/optax/discussions/167) for implementation

In [34]:
import optax

In [35]:
def subcify(cls):
    class SubspaceModule(nn.Module):
        dim_in: int
        dim_subspace: int
        init_normal: Callable = nn.initializers.normal()
        init_proj: Callable = nn.initializers.normal()


        def init(self, rngs, *args, **kwargs):
            # TODO: Add case 
            r1, r2 = jax.random.split(rngs, 2)
            rngs_dict = {"params": r1, "fixed": r2}
            
            return nn.Module.init(self, rngs_dict, *args, **kwargs)

        def setup(self):

            key_dummy = jax.random.PRNGKey(0)
            params = cls().init(key_dummy, jnp.ones((1, self.dim_in)))
            params_all, reconstruct_fn = ravel_pytree(params)
            
            self.dim_full = len(params_all)
            self.reconstruct_fn = reconstruct_fn
            
            self.subspace = self.param(
                "subspace",
                self.init_proj,
                (self.dim_subspace,)
            )

            shape = (self.dim_full, self.dim_subspace)
            init_fn = lambda shape: self.init_proj(self.make_rng("fixed"), shape)
            self.projection = self.variable("fixed", "P", init_fn, shape).value

            shape = (self.dim_full,)
            init_fn = lambda shape: self.init_proj(self.make_rng("fixed"), shape)
            self.bias = self.variable("fixed", "b", init_fn, shape).value

        @nn.compact
        def __call__(self, x):
            params = self.projection @ self.subspace  + self.bias
            params = self.reconstruct_fn(params)
            return cls().apply(params, x)
    
    return SubspaceModule

In [36]:
@subcify
class SubNNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.elu(x)
        x = nn.Dense(10)(x)
        x = nn.elu(x)
        x = nn.Dense(1)(x)
        return x

In [37]:
key = jax.random.PRNGKey(314)
dim_in = 2
Xinit = jnp.ones((1, dim_in))
model = SubNNet(dim_in, dim_subspace=5)
params_init = model.init(key, Xinit)
# params_init

jax.tree_map(jnp.shape, params_init)

FrozenDict({
    fixed: {
        P: (151, 5),
        b: (151,),
    },
    params: {
        subspace: (5,),
    },
})

In [38]:
from flax.core import frozen_dict
from flax.training.train_state import TrainState

In [39]:
# https://github.com/deepmind/optax/discussions/167
def create_mask():
    # mask = {'params': {'fixed': 'fixed', 'params': 'params'}}
    mask = {"params": "params", "fixed": "fixed"}
    return frozen_dict.freeze(mask)

def zero_grads():
    def init_fn(_): 
        return ()
    def update_fn(updates, state, params=None):
        return jax.tree_map(jnp.zeros_like, updates), ()
    return optax.GradientTransformation(init_fn, update_fn)

tx = optax.multi_transform({'fixed': zero_grads(), 'params': optax.adam(learning_rate=1.0)},
                           create_mask())


In [40]:
state = TrainState.create(apply_fn=model.apply,
                                      params=params_init,
                                      tx=tx)

In [43]:
dataset = load_mnist_dataset()