# Subspace EKF

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.flatten_util import ravel_pytree
from rebayes.low_rank_filter import subspace_filter
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

In [34]:
from typing import Callable

In [41]:
class NNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = nn.elu(x)
        x = nn.Dense(3)(x)
        x = nn.elu(x)
        x = nn.Dense(1)(x)
        return x

In [42]:
key = jax.random.PRNGKey(314)

In [43]:
model = NNet()
Xinit = jnp.ones((10, 2))
dim_subspace = 10
params_init = model.init(key, Xinit)
params_flat, rfn = ravel_pytree(params_init)
dim_full = len(params_flat)

pmatrix = jax.random.bernoulli(key, p=0.5, shape=(dim_full, dim_subspace))
params_latent = jax.random.normal(key, shape=(dim_subspace,))

In [44]:
params_latent

Array([-1.4314531 , -1.4140283 ,  0.5512039 , -0.33057716, -0.39858606,
        0.6754447 , -0.08324226, -0.3476744 ,  0.34246936, -0.03261461],      dtype=float32)

In [45]:
NNet().apply(params_init, Xinit)

Array([[0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289],
       [0.01260289]], dtype=float32)

In [56]:
class SubNNet(nn.Module):
    # Why do we need Callable? fails if it's not there
    init_proj: Callable = nn.initializers.normal()

    # How to get the dim_subspace and dim_full?
    def setup(self):
        self.subspace = self.param(
            "subspace",
            self.init_proj,
            (dim_subspace,)
        )
        self.projection = self.param(
            "projection",
            self.init_proj,
            (dim_full, dim_subspace)
        )

    @nn.compact
    def __call__(self, x):
        params = self.projection @ self.subspace
        params = rfn(params)
        return NNet().apply(params, x)

model = SubNNet()
pinit = model.init(key, Xinit)

In [61]:
jax.tree_map(jnp.shape, pinit)

FrozenDict({
    params: {
        projection: (25, 10),
        subspace: (10,),
    },
})

In [58]:
pinit

FrozenDict({
    params: {
        subspace: Array([-0.00447499,  0.00400571,  0.00299776,  0.00182526,  0.00267615,
               -0.01588027, -0.02566168,  0.00466987, -0.00023074,  0.00239409],      dtype=float32),
        projection: Array([[-3.35986656e-03,  8.87939893e-03,  2.63753789e-03,
                -9.52443632e-04, -2.63619353e-03,  1.00701936e-02,
                 1.81923248e-02, -2.12386553e-03, -7.71253696e-03,
                -7.08877854e-03],
               [-7.91336969e-03, -3.24458559e-03,  1.06810564e-02,
                -1.96658017e-04, -6.87617576e-03,  3.18945415e-04,
                -6.20286632e-03,  6.22718781e-03, -1.12819381e-03,
                -2.13826890e-03],
               [ 1.43241156e-02, -7.11677223e-03,  4.49566497e-03,
                -6.44596992e-03, -1.64688937e-03, -1.20240133e-02,
                 4.03506635e-03,  1.77159607e-02,  6.19832147e-03,
                 1.88459537e-03],
               [-8.38529412e-03, -6.91581890e-03,  5.47279092e-0

In [57]:
model.apply(pinit, Xinit)

Array([[7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05],
       [7.723302e-05]], dtype=float32)

In [52]:
jax.tree_map(jnp.shape, params_init)

FrozenDict({
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (3,),
            kernel: (3, 3),
        },
        Dense_2: {
            bias: (1,),
            kernel: (3, 1),
        },
    },
})

## Subspace agent

In [256]:
import optax

In [22]:
@subspace_filter.subcify
class NNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = nn.elu(x)
        x = nn.Dense(3)(x)
        x = nn.elu(x)
        x = nn.Dense(1)(x)
        return x

In [23]:
key = jax.random.PRNGKey(314)
model = NNet(dim_subspace=10)
Xinit = jnp.ones((10, 2))
params_init = model.init(key, Xinit, 1)
# params_init

jax.tree_map(jnp.shape, params_init)

FrozenDict({
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (3,),
            kernel: (3, 3),
        },
        Dense_2: {
            bias: (1,),
            kernel: (3, 1),
        },
        projection_matrix: (10, 10),
    },
})

In [24]:
model.apply(params_init, Xinit, 1)

1.0

## Random ideas

In [244]:
import flax

In [237]:
class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias

In [246]:
model = BiasAdderWithRunningMean()

In [249]:
variables = model.init(key, Xinit)
variables

FrozenDict({
    batch_stats: {
        mean: Array([0., 0.], dtype=float32),
    },
    params: {
        bias: Array([0., 0.], dtype=float32),
    },
})

In [250]:
state, params = flax.core.pop(variables, 'params')
state

FrozenDict({
    batch_stats: {
        mean: Array([0., 0.], dtype=float32),
    },
})

In [251]:
params

FrozenDict({
    bias: Array([0., 0.], dtype=float32),
})