In [1]:
from jax import value_and_grad
import jax.numpy as jnp
import numpy as np
import optax
import torch
from torch.utils.data import Dataset, DataLoader


from rebayes.extended_kalman_filter.dual_ekf import (
    make_dual_ekf_estimator,
    EKFParams,
)
from rebayes.dual_base import (
    DualRebayesParams, 
    ObsModel, 
    dual_rebayes_optimize_scan,
    form_tril_matrix,
)
from rebayes.utils import datasets, utils

  from .autonotebook import tqdm as notebook_tqdm
2023-03-30 03:47:06.604073: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2023-03-30 03:47:06.604206: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


In [2]:
# Synthetic 1-d regression data. The (500, 200) arguments presumably set the
# train/test sizes — TODO confirm against `load_1d_synthetic_dataset`.
# A second draw with `key=1` supplies the validation split; its second
# return value is discarded.
train, test = datasets.load_1d_synthetic_dataset(500, 200)
val, _ = datasets.load_1d_synthetic_dataset(1_000, key=1)

(X_train, y_train), (X_val, y_val), (X_test, y_test) = train, val, test

In [3]:
class ToyDataset(Dataset):
    """Map-style dataset returning the pair ``(X[idx], y[idx])``.

    The arrays are stored exactly as passed in (no copy or conversion).
    The dataset length is read from ``X.shape[0]``, so ``X`` must expose a
    ``shape`` attribute (e.g. a NumPy array). ``y`` is assumed to share the
    same leading length — this is not checked.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # Number of samples is the size of the features' leading axis.
        num_samples = self.X.shape[0]
        return num_samples

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

# Wrap the validation split (converted to NumPy arrays) for minibatch
# iteration. The shuffle uses an explicitly seeded generator so the
# minibatch order — and hence the sequentially filtered result — is
# reproducible under Restart & Run All.
BATCH_SIZE = 64
SHUFFLE_SEED = 0

dataset = ToyDataset(np.array(X_val), np.array(y_val))
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [4]:
# Layer widths handed to `get_mlp_flattened_params`: presumably input size 1,
# two hidden layers of width 20, and output size 1 — TODO confirm.
# `model` is not used in the cells shown here; `flat_params` and `apply_fn`
# feed the estimator configuration.
model_dims = [1] + [20] * 2 + [1]
model, flat_params, _, apply_fn = utils.get_mlp_flattened_params(model_dims)

In [5]:
# Prior centred on the initial flattened MLP weights. eta0=1.0 is presumably
# the initial prior precision/scale — TODO confirm in `DualRebayesParams`.
params = DualRebayesParams(mu0=flat_params, eta0=1.0)

# The MLP's apply function gives the emission mean. No emission-covariance
# function is supplied; observation noise is presumably handled by the dual
# estimator via `obs_noise_estimator` — TODO confirm.
obs = ObsModel(emission_mean_function=apply_fn, emission_cov_function=None)

# NOTE(review): "fdekf" / "post" are library option strings; their exact
# meaning is defined in rebayes.extended_kalman_filter.dual_ekf.
ekf_params = EKFParams(method="fdekf", obs_noise_estimator="post")

estimator = make_dual_ekf_estimator(params, obs, ekf_params)
    

In [6]:
# Adam optimizer passed to `dual_rebayes_optimize_scan`; it is presumably
# applied to the gradients returned by `grad_callback` — TODO confirm.
LEARNING_RATE = 1e-4
tx = optax.adam(LEARNING_RATE)
def grad_callback(params, bel, pred_obs, t, x, y, pred_bel, params_bel, update_fn, predict_fn):
    """Loss and gradient w.r.t. the observation-noise parameters for one step.

    The parameter vector ``params_bel.params`` is mapped by
    ``form_tril_matrix`` to a lower-triangular factor ``L``. The candidate
    noise covariance ``L @ L.T`` is injected into ``params`` via
    ``params.replace(obs_noise=...)``. The predicted belief ``pred_bel`` is
    then updated with ``(x, y)`` and re-predicted at ``x``.

    ``bel``, ``pred_obs`` and ``t`` are not used here. They are presumably
    part of the callback signature expected by ``dual_rebayes_optimize_scan``
    — TODO confirm.

    Returns:
        ``(loss, grad)`` from ``jax.value_and_grad``. ``loss`` is the sum of
        squared residuals between the re-prediction and ``y``; ``grad`` has
        the same structure as ``params_bel.params``.
    """
    # Observation dimension; a scalar y counts as 1-d.
    n_obs_dims = jnp.atleast_1d(y).shape[0]

    # NOTE(review): the residual is measured on the same (x, y) that was just
    # absorbed by `update_fn`. If a smaller noise covariance pulls the
    # posterior prediction closer to y, this objective may drive the noise
    # towards zero — confirm this is the intended tuning criterion.
    def squared_error(noise_theta, inputs, target):
        chol = form_tril_matrix(noise_theta, n_obs_dims)
        candidate_params = params.replace(obs_noise=chol @ chol.T)
        post_bel = update_fn(candidate_params, pred_bel, inputs, target)
        residual = predict_fn(candidate_params, post_bel, inputs) - target
        return jnp.sum(residual ** 2)

    loss_and_grad = value_and_grad(squared_error)
    return loss_and_grad(params_bel.params, x, y)
    

In [7]:
# Run the dual filter over the shuffled validation minibatches, passing `tx`
# and `grad_callback` for the noise-parameter updates. The result is bound to
# a name so later cells can use it without relying on `_` / `Out[...]`.
# Its structure is defined by `dual_rebayes_optimize_scan` — TODO document.
NUM_EPOCHS = 5
optim_result = dual_rebayes_optimize_scan(
    estimator,
    dataloader,
    num_epochs=NUM_EPOCHS,
    tx=tx,
    grad_callback=grad_callback,
)
optim_result