# Subspace EKF

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.flatten_util import ravel_pytree
from rebayes.low_rank_filter import subspace_filter
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

In [110]:
from typing import Callable

In [111]:
class NNet(nn.Module):
    """MLP with two 10-unit ELU hidden layers and a single output unit."""

    @nn.compact
    def __call__(self, x):
        # Layers are created in the order they are applied; only the first
        # one carries an explicit name ("gets_cov").
        first_hidden = nn.elu(nn.Dense(10, name="gets_cov")(x))
        second_hidden = nn.elu(nn.Dense(10)(first_hidden))
        return nn.Dense(1)(second_hidden)

In [112]:
# Fixed seed used by every random draw in this cell.
key = jax.random.PRNGKey(314)

model = NNet()
n_features = 2
# Batch of ones, used only to initialise the network parameters.
Xinit = jnp.ones((10, n_features))
dim_subspace = 10
params_init = model.init(key, Xinit)
# Flatten the parameter pytree into a single vector; `rfn` rebuilds the
# pytree from such a vector.
params_flat, rfn = ravel_pytree(params_init)
dim_full = len(params_flat)

# NOTE(review): `key` is reused unchanged for model.init, bernoulli and
# normal. JAX recommends deriving a fresh key per draw with
# jax.random.split; changing this would alter the recorded outputs below.
pmatrix = jax.random.bernoulli(key, p=0.5, shape=(dim_full, dim_subspace))
params_latent = jax.random.normal(key, shape=(dim_subspace,))

In [114]:
# Display the sampled latent vector (length `dim_subspace`).
params_latent

Array([-1.4314531 , -1.4140283 ,  0.5512039 , -0.33057716, -0.39858606,
        0.6754447 , -0.08324226, -0.3476744 ,  0.34246936, -0.03261461],      dtype=float32)

In [115]:
# Forward pass with the initial parameters on the all-ones batch, flattened
# to one dimension for display.
NNet().apply(params_init, Xinit).ravel()

Array([2.5442393, 2.5442393, 2.5442393, 2.5442393, 2.5442393, 2.5442393,
       2.5442393, 2.5442393, 2.5442393, 2.5442393], dtype=float32)

In [52]:
def get_fparams(cls, X):
    """Return the number of scalar parameters of a module class.

    A throwaway instance of ``cls`` is initialised on ``X`` and its
    parameter pytree is flattened; only the length of the flat vector is
    returned.

    Parameters
    ----------
    cls
        Module class. It must be constructible from the keyword arguments
        ``parent`` and ``name`` alone.
    X
        Example input forwarded to ``init``.

    Returns
    -------
    int
        Length of the flattened parameter vector.
    """
    # `parent` is passed by keyword so the call does not depend on the
    # positional order of the module's constructor arguments.
    dummy = cls(parent=None, name="d")
    # Only the size of the result is used, so a fixed local key is used
    # instead of relying on a global `key`.
    dummy_params = dummy.init(jax.random.PRNGKey(0), X)
    flat_params, _ = ravel_pytree(dummy_params)
    return len(flat_params)

In [None]:
class SubNNet(nn.Module):
    """`NNet` whose flat parameter vector is ``projection @ subspace``.

    Both ``subspace`` (shape ``(dim_subspace,)``) and ``projection`` (shape
    ``(dim_full, dim_subspace)``) are registered with ``self.param``. The
    sizes ``dim_subspace`` and ``dim_full``, the un-flattening function
    ``rfn`` and the wrapped ``NNet`` class are read from the notebook's
    global scope.

    No attribute is assigned outside ``setup``: flax modules are frozen once
    constructed, so such assignments raise an error.
    """
    # The annotation is what turns this attribute into a dataclass field.
    # NOTE(review): presumably an un-annotated initializer would stay a plain
    # class attribute and be bound as a method when read through `self`,
    # which would explain why it "fails if it's not there" -- confirm.
    init_proj: Callable = nn.initializers.normal()

    def setup(self):
        self.subspace = self.param(
            "subspace",
            self.init_proj,
            (dim_subspace,),
        )
        self.projection = self.param(
            "projection",
            self.init_proj,
            (dim_full, dim_subspace),
        )

    # NOTE(review): `compact` is kept because `NNet()` is constructed inside
    # the method; flax appears to reject that in a non-compact method.
    @nn.compact
    def __call__(self, x):
        # Map the low-dimensional vector to the full flat parameter vector,
        # rebuild the parameter pytree and evaluate the wrapped network.
        flat_params = self.projection @ self.subspace
        return NNet().apply(rfn(flat_params), x)

# Instantiate the subspace model and initialise its variables on `Xinit`.
model = SubNNet()
pinit = model.init(key, Xinit)

## Subspace agent

In [118]:
import optax

In [152]:
def subcify(cls):
    """Class decorator: re-parameterise a flax module in a linear subspace.

    The returned module stores a vector ``subspace`` of length
    ``dim_subspace`` in the ``"params"`` collection and a matrix ``P`` of
    shape ``(dim_full, dim_subspace)`` in the ``"fixed"`` collection, where
    ``dim_full`` is the number of scalar parameters of ``cls``. On each call
    the flat vector ``P @ subspace`` is un-flattened into the parameter
    pytree of ``cls`` and ``cls`` is applied with it.

    Parameters
    ----------
    cls
        Flax module class that can be constructed without arguments.

    Returns
    -------
    The ``SubspaceModule`` class wrapping ``cls``.
    """
    class SubspaceModule(nn.Module):
        dim_in: int  # width of the dummy input used to initialise `cls`
        dim_subspace: int  # length of the `subspace` vector
        # NOTE(review): `init_normal` is currently unused; `subspace` is
        # initialised with `init_proj`. Confirm which one was intended.
        init_normal: Callable = nn.initializers.normal()
        init_proj: Callable = nn.initializers.normal()

        def init(self, rngs, *args, **kwargs):
            """Initialise the variables of the module.

            ``rngs`` is either a single PRNG key, which is split into the
            ``"params"`` and ``"fixed"`` streams, or a dict of keys, which
            is forwarded unchanged.
            """
            if isinstance(rngs, dict):
                rngs_dict = rngs
            else:
                key_params, key_fixed = jax.random.split(rngs, 2)
                rngs_dict = {"params": key_params, "fixed": key_fixed}

            return nn.Module.init(self, rngs_dict, *args, **kwargs)

        def setup(self):
            # Initialise the wrapped module once on a dummy input to learn
            # the size of its flat parameter vector and how to un-flatten it.
            key_dummy = jax.random.PRNGKey(0)
            params = cls().init(key_dummy, jnp.ones((1, self.dim_in)))
            params_all, reconstruct_fn = ravel_pytree(params)

            self.dim_full = len(params_all)
            self.reconstruct_fn = reconstruct_fn

            self.subspace = self.param(
                "subspace",
                self.init_proj,
                (self.dim_subspace,),
            )

            shape = (self.dim_full, self.dim_subspace)

            def init_fn(shape):
                # Draw the projection from the dedicated "fixed" rng stream.
                return self.init_proj(self.make_rng("fixed"), shape)

            # The projection lives in its own "fixed" collection, separate
            # from "params".
            self.projection = self.variable("fixed", "P", init_fn, shape).value

        @nn.compact
        def __call__(self, x):
            # Full flat parameters -> parameter pytree -> wrapped forward pass.
            params = self.projection @ self.subspace
            params = self.reconstruct_fn(params)
            return cls().apply(params, x)

    return SubspaceModule

In [153]:
@subcify
class SubNNet(nn.Module):
    """MLP with two 10-unit ELU hidden layers and one output unit."""

    @nn.compact
    def __call__(self, x):
        hidden = x
        # Two identical hidden blocks: Dense(10) followed by ELU.
        for width in (10, 10):
            hidden = nn.elu(nn.Dense(width)(hidden))
        return nn.Dense(1)(hidden)

In [158]:
# Build the subspace-wrapped network and initialise it on a batch of ones.
key = jax.random.PRNGKey(314)
dim_in = 2
model = SubNNet(dim_in, dim_subspace=5)
Xinit = jnp.ones((10, dim_in))
params_init = model.init(key, Xinit)

# Show the shape of every initialised variable. `jax.tree_util.tree_map` is
# used because the top-level `jax.tree_map` alias is deprecated.
jax.tree_util.tree_map(jnp.shape, params_init)

FrozenDict({
    fixed: {
        P: (151, 5),
    },
    params: {
        subspace: (5,),
    },
})

In [159]:
# Forward pass of the subspace model with its initial variables on `Xinit`.
model.apply(params_init, Xinit)

Array([[0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768],
       [0.0002768]], dtype=float32)

## Yet another test

In [None]:
class SubspaceBiasMLP(nn.Module):
    """Subspace-parameterised base model followed by a dense output layer.

    The flat parameter vector of the base model is computed as
    ``projection_matrix @ z + b``, un-flattened, and used to apply the base
    model; the result is then passed through a ``Dense(n_out)`` layer.
    """
    n_hidden: int  # forwarded as the only argument to `base_model`
    n_in: int  # width of the dummy input used to initialise the base model
    n_out: int  # number of features of the final Dense layer
    dim_subspace: int  # length of the vector `z`
    # Offset added to `projection_matrix @ z` in `__call__`.
    # NOTE(review): `ArrayDevice` is not defined in the visible code --
    # confirm where it comes from.
    b: ArrayDevice
    init_normal: Callable = nn.initializers.normal()  # initialiser of `z`
    init_proj: Callable = nn.initializers.normal()  # initialiser of `P`
    # NOTE(review): annotated as `nn.Module` but the value is called like a
    # class/factory in `setup`; `FeatureMLP` is not defined in the visible
    # code -- confirm.
    base_model: nn.Module = FeatureMLP
    
    def setup(self):
        self.mlp = self.base_model(self.n_hidden)
        
        # Initialise the base model on a dummy input to obtain the size of
        # its flat parameter vector and the function that un-flattens it.
        key_dummy = jax.random.PRNGKey(0)
        params = self.mlp.init(key_dummy, jnp.ones((1, self.n_in)))
        params_all, reconstruct_fn = ravel_pytree(params)
        
        self.dim_full = len(params_all)
        self.reconstruct_fn = reconstruct_fn
        
        # Low-dimensional vector, stored in the "params" collection.
        self.z = self.param("z", self.init_normal, (self.dim_subspace,))
        self.last_layer = nn.Dense(self.n_out, name="last_layer")
        
        # Projection matrix, stored in the separate "fixed" collection and
        # drawn from the "fixed" rng stream.
        shape = (self.dim_full, self.dim_subspace)
        init_fn = lambda shape: self.init_proj(self.make_rng("fixed"), shape)
        self.projection_matrix = self.variable("fixed", "P", init_fn, shape).value

        
    @nn.compact
    def __call__(self, x):
        # Affine map from the subspace to the full flat parameter vector.
        params = self.projection_matrix @ self.z + self.b
        params = self.reconstruct_fn(params)
        x = self.mlp.apply(params, x)
        y = self.last_layer(x)
        
        return y
