# (Coreset) Variational Continual Learning
## Rotating MNIST

In [100]:
from typing import Callable

import distrax
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from dynamax.utils import datasets

In [5]:
from rebayes import variational_continual_learning as vcl

In [2]:
# Enable the autoreload extension; mode 2 reloads all modules before
# each cell is executed.
%load_ext autoreload
%autoreload 2

# Render inline figures using the "retina" figure format.
%config InlineBackend.figure_format = "retina"

## Load dataset

In [101]:
# Seed NumPy's global RNG and cap the number of training samples used.
np.random.seed(314)
num_train = 10_000

# Rotated MNIST restricted to the digit 2; each split is a (features, target) pair.
train, test = datasets.load_rotated_mnist(target_digit=2)
(X_train, y_train), (X_test, y_test) = train, test

# Move the training split to JAX arrays.
X_train, y_train = jnp.array(X_train), jnp.array(y_train)

# Keep the first `num_train` samples, then order them by increasing target.
X, y = X_train[:num_train], y_train[:num_train]
ix_sort = jnp.argsort(y)
X, y = X[ix_sort], y[ix_sort]

## Setup

In [86]:
class bMLP(nn.Module):
    """
    Three-layer MLP built from `vcl.BatchDense` layers (100 -> 10 -> n_out).

    The input is copied along a new leading axis of size 2 before the
    first layer, so every layer carries two parameter sets.
    Presumably one set holds the variational means (μ) and the other the
    scale parameters (ρ) -- TODO confirm against `vcl.BatchDense`.

    Attributes
    ----------
    n_out: number of output units of the last layer.
    activation: non-linearity applied after the first two layers.
    """
    n_out: int = 2
    # Annotated so that Flax registers it as a configurable module field.
    activation: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        # μ(i) and ρ(i) params
        x = einops.repeat(x, "... -> c ...", c=2)
        x = vcl.BatchDense(100)(x)
        x = self.activation(x)
        x = vcl.BatchDense(10)(x)
        x = self.activation(x)
        x = vcl.BatchDense(self.n_out)(x)
        return x

In [93]:
# `dim_features` is otherwise only assigned in a later cell; define it here
# so the notebook also runs top-to-bottom on a fresh kernel.
dim_features = 28 ** 2

model = bMLP()
key = jax.random.PRNGKey(314)
# Dummy single-row batch, used only to initialise the parameters.
batch = jnp.ones((1, dim_features))
params_init = model.init(key, batch)

# Display the shape of every parameter leaf.
# `jax.tree_map` is a deprecated alias; use the `tree_util` function instead.
jax.tree_util.tree_map(jnp.shape, params_init)

FrozenDict({
    params: {
        VmapDense_0: {
            bias: (2, 100),
            kernel: (2, 784, 100),
        },
        VmapDense_1: {
            bias: (2, 10),
            kernel: (2, 100, 10),
        },
        VmapDense_2: {
            bias: (2, 2),
            kernel: (2, 10, 2),
        },
    },
})

In [96]:
# Problem dimensions and training hyper-parameters.
dim_features, dim_output = 28 * 28, 1
buffer_size, learning_rate = 100, 1e-3

# Initial VCL state: model, Adam optimiser, buffer settings and prior.
state_init = vcl.VCLState.create(
    apply_fn=model.apply,
    params=params_init,
    tx=optax.adam(learning_rate),
    dim_features=dim_features,
    dim_output=dim_output,
    buffer_size=buffer_size,
    prior_mean=0.0,
    prior_std=0.5,
)

In [103]:
def cost_fn(params, state, X, y):
    """
    Single-sample cost: log q(θ) - log p(θ) - log p(y | X, θ), with θ drawn
    via `state.sample_params`.

    Parameters
    ----------
    params: variational parameters.
        NOTE(review): this argument is overwritten by the sampled
        parameters below and is never read -- confirm this is intended.
    state: object exposing `prior_std`, `reconstruct_fn`, `sample_params`,
        `apply_fn`, `mean` and `logvar`.
    X: model inputs.
    y: observed targets, compared against the flattened model output.

    Returns
    -------
    Scalar cost (variational log-prob minus prior and observation log-probs).

    TODO:
    Add more general way to compute observation-model log-probability
    """
    scale_obs = 1.0
    scale_prior = state.prior_std
    reconstruct_fn = state.reconstruct_fn

    # Sampled params
    # NOTE(review): `key` is read from the enclosing (notebook) scope, so the
    # same random key is reused on every call.
    params = state.sample_params(key, state, reconstruct_fn)
    params_flat = vcl.get_leaves(params)

    # Prior log probability (use initialised vals for mean?)
    logp_prior = distrax.Normal(loc=0.0, scale=scale_prior).log_prob(params_flat).sum()

    # Observation log-probability
    mu_obs = state.apply_fn(params, X).ravel()
    logp_obs = distrax.Normal(loc=mu_obs, scale=scale_obs).log_prob(y).sum()

    # Variational log-probability; std is exp(logvar / 2)
    logp_variational = jax.tree_util.tree_map(
        lambda mean, logvar, x: distrax.Normal(loc=mean, scale=jnp.exp(logvar / 2)).log_prob(x),
        state.mean, state.logvar, params
    )
    # Use the same `vcl.get_leaves` helper as for `params_flat` above.
    logp_variational = vcl.get_leaves(logp_variational).sum()

    return logp_variational - logp_prior - logp_obs

In [None]:
# Evaluate the cost once on the test split using the initial state.
cost_fn(state_init.params, state_init, X_test, y_test)