In [None]:
# Reload edited project modules automatically before each cell runs, so
# library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render inline figures at high ("retina") resolution.
%config InlineBackend.figure_format = "retina"

# 1. Regression (Linear Model)

## 1.0. Dataset

In [None]:
from functools import partial

from bayes_opt import BayesianOptimization
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp

from rebayes.extended_kalman_filter.ekf import RebayesEKF, RebayesOCLEKF
import rebayes.utils.callbacks as callbacks
import rebayes.utils.models as models

tfd = tfp.distributions
MVN = tfd.MultivariateNormalTriL
MVD = tfd.MultivariateNormalDiag

In [None]:
# Piecewise-constant signal: the level of each segment, and the time indices
# at which the signal switches to the next level.
MEAN_VALS = [1.3, 1.0, 1.3, 0.95, 0.6, 0.25, 0.8, 0.5,]
CHANGE_POINTS = [451, 709, 958, 1547, 2147, 2769, 2957,]

def generate_time_series_data(
    key=0,
    n_points=3058,
    change_points=CHANGE_POINTS,
    means=MEAN_VALS,
    variance=0.01
):
    """Sample a noisy piecewise-constant time series.

    Args:
        key: integer seed or an existing JAX PRNG key.
        n_points: total number of time steps.
        change_points: increasing indices at which a new segment starts.
        means: level of each segment; needs one entry per segment, i.e.
            ``len(change_points) + 1`` values.
        variance: variance of the additive Gaussian noise.

    Returns:
        ``(xs, ys)`` where ``xs`` is ``arange(n_points)`` and ``ys`` is the
        noisy signal, both of length ``n_points``.
    """
    rng_key = jr.PRNGKey(key) if isinstance(key, int) else key
    time_steps = jnp.arange(n_points)

    # Segment i spans [boundaries[i], boundaries[i + 1]).
    boundaries = jnp.array([0, *change_points, n_points])
    segment_lengths = jnp.diff(boundaries)
    signal = jnp.concatenate(
        [means[i]*jnp.ones(length) for i, length in enumerate(segment_lengths)]
    )
    observations = signal + jnp.sqrt(variance) * jr.normal(rng_key, signal.shape)

    return time_steps, observations

In [None]:
# Default settings: 3058 points, seed 0, noise variance 0.01.
xs, ys = generate_time_series_data()

In [None]:
# Raw observations against time.
fig, ax = plt.subplots(figsize=(12, 3))
ax.scatter(xs, ys, s=0.5, color="black")
ax.set(xlabel="Time", ylabel="Output", yticks=[0, 1]);

## 1.1. Kalman Filter (gamma = 1, Q = 0)

In [None]:
# Static-parameter Kalman filter: identity dynamics and zero process noise,
# so the single scalar weight is modelled as constant over the whole series.
ekf = RebayesEKF(
    dynamics_weights_or_function=1.0,
    dynamics_covariance=0.,
    emission_mean_function=lambda w, x: w,  # the observation mean is the weight itself
    emission_cov_function=lambda w, x: 0.05,
    method="fcekf",
)
init_mean, init_cov = jnp.array([0.0]), jnp.array([0.01])
ekf_kwargs = {
    "agent": ekf,
}

def ekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Per-step diagnostics for the static filter.

    Returns the predicted observation mean, the predicted observation
    covariance and the log-probability of ``y``, all evaluated under
    ``bel_pred`` so that the three quantities describe the same belief.
    """
    agent = kwargs["agent"]
    y_mean = agent.predict_obs(bel_pred, x)
    y_cov = agent.predict_obs_cov(bel_pred, x)
    # Score y under the predictive belief: the result is plotted as a
    # predictive density, and `bel` is presumably the belief after the update
    # with (x, y) — confirm against Rebayes.scan.
    log_prob = agent.evaluate_log_prob(bel_pred, x, y)

    return y_mean, y_cov, log_prob

ekf_bel, ekf_outputs = ekf.scan(
    init_mean, init_cov, xs, ys, callback=ekf_callback, **ekf_kwargs
)

In [None]:
# Flatten every per-step output to a 1-D array.
ekf_mean, ekf_var, ekf_log_prob = (out.ravel() for out in ekf_outputs)

# Replace the per-step log-probabilities by their running mean.
ekf_log_prob = jnp.cumsum(ekf_log_prob) / jnp.arange(1, ekf_log_prob.size + 1)

In [None]:
# Data, predicted mean, and a +/- 2 standard deviation band.
fig, ax = plt.subplots(figsize=(12, 3))
ax.scatter(xs, ys, s=0.5, color="black")
ax.plot(xs, ekf_mean, color="darkorange", linewidth=1.5)
ax.plot(xs, ekf_mean + 2*jnp.sqrt(ekf_var), color="darkorange", linewidth=1.0, linestyle="--")
ax.plot(xs, ekf_mean - 2*jnp.sqrt(ekf_var), color="darkorange", linewidth=1.0, linestyle="--")
ax.set(xlabel="Time", ylabel="Output", ylim=(-0.5, 2), yticks=[0, 1]);

In [None]:
# Running average of the per-step log-probability.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(xs, ekf_log_prob, color="darkorange", linewidth=1.5)
ax.set(ylim=(-4, 2), xlabel="Time", ylabel="Log Predictive density");


## 1.2. Kalman Filter (gamma = 1, Q = 0.01)

In [None]:
# Same scalar model as the static filter, but with process noise Q = 0.01
# (the value named in the section title) so the weight can drift.
# The emission functions and initial belief are spelled out here; the previous
# version read them from `model_dict`, which is not defined at notebook scope.
nsekf = RebayesEKF(
    dynamics_weights_or_function=1.0,
    dynamics_covariance=0.01,
    emission_mean_function=lambda w, x: w,  # the observation mean is the weight itself
    emission_cov_function=lambda w, x: 0.05,
    method="fcekf",
)
init_mean, init_cov = jnp.array([0.0]), jnp.array([0.01])
nsekf_kwargs = {
    "agent": nsekf,
}

def nsekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Per-step diagnostics for the filter with process noise.

    Returns the predicted observation mean, the predicted observation
    covariance and the log-probability of ``y``, all evaluated under
    ``bel_pred`` so that the three quantities describe the same belief.
    """
    agent = kwargs["agent"]
    y_mean = agent.predict_obs(bel_pred, x)
    y_cov = agent.predict_obs_cov(bel_pred, x)
    # Score y under the predictive belief: the result is plotted as a
    # predictive density, and `bel` is presumably the belief after the update
    # with (x, y) — confirm against Rebayes.scan.
    log_prob = agent.evaluate_log_prob(bel_pred, x, y)

    return y_mean, y_cov, log_prob

nsekf_bel, nsekf_outputs = nsekf.scan(
    init_mean, init_cov, xs, ys, callback=nsekf_callback, **nsekf_kwargs
)

In [None]:
# Flatten every per-step output to a 1-D array.
nsekf_mean, nsekf_var, nsekf_log_prob = (out.ravel() for out in nsekf_outputs)

# Replace the per-step log-probabilities by their running mean.
nsekf_log_prob = jnp.cumsum(nsekf_log_prob) / jnp.arange(1, nsekf_log_prob.size + 1)

In [None]:
# Data, predicted mean, and a +/- 2 standard deviation band.
fig, ax = plt.subplots(figsize=(12, 3))
ax.scatter(xs, ys, s=0.5, color="black")
ax.plot(xs, nsekf_mean, color="darkorange", linewidth=1.5)
ax.plot(xs, nsekf_mean + 2*jnp.sqrt(nsekf_var), color="darkorange", linewidth=1.0, linestyle="--")
ax.plot(xs, nsekf_mean - 2*jnp.sqrt(nsekf_var), color="darkorange", linewidth=1.0, linestyle="--")
ax.set(xlabel="Time", ylabel="Output", ylim=(-0.5, 2), yticks=[0, 1]);

In [None]:
# Running average of the per-step log-probability.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(xs, nsekf_log_prob, color="darkorange", linewidth=1.5)
ax.set(ylim=(-4, 2), xlabel="Time", ylabel="Log Predictive density");


## 1.3. Adaptive Kalman Filter

In [None]:
# Adaptive filter on the same scalar model: process noise 0.01, a decay
# parameter starting at 0 (so exp(-delta) starts at 1), learning rate 1.
oclekf = RebayesOCLEKF(
    dynamics_decay_delta=0.0,
    dynamics_covariance=0.01,
    emission_mean_function=lambda w, x: w,  # the observation mean is the weight itself
    emission_cov_function=lambda w, x: 0.05,
    method="fcekf",
    decay_dynamics_weight=False,
    learning_rate=1.0,
)
init_mean, init_cov = jnp.array([0.0]), jnp.array([0.01])
oclekf_kwargs = {
    "agent": oclekf,
}

def oclekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Per-step diagnostics for the adaptive filter.

    Returns the predicted observation mean and covariance, ``exp(-delta)``
    read from the current belief, and the log-probability of ``y``. The mean,
    covariance and log-probability are all evaluated under ``bel_pred``.
    """
    agent = kwargs["agent"]
    y_mean = agent.predict_obs(bel_pred, x)
    y_cov = agent.predict_obs_cov(bel_pred, x)
    # Score y under the predictive belief: the result is plotted as a
    # predictive density, and `bel` is presumably the belief after the update
    # with (x, y) — confirm against Rebayes.scan.
    log_prob = agent.evaluate_log_prob(bel_pred, x, y)
    gamma_sq = jnp.exp(-bel.dynamics_decay_delta)

    return y_mean, y_cov, gamma_sq, log_prob

oclekf_bel, oclekf_outputs = oclekf.scan(
    init_mean, init_cov, xs, ys, callback=oclekf_callback, **oclekf_kwargs
)

In [None]:
# Flatten every per-step output to a 1-D array.
oclekf_mean, oclekf_var, oclekf_gammasq, oclekf_log_prob = (
    out.ravel() for out in oclekf_outputs
)

# Replace the per-step log-probabilities by their running mean.
oclekf_log_prob = jnp.cumsum(oclekf_log_prob) / jnp.arange(1, oclekf_log_prob.size + 1)

In [None]:
# Data, predicted mean, and a +/- 2 standard deviation band.
fig, ax = plt.subplots(figsize=(12, 3))
ax.scatter(xs, ys, s=0.5, color="black")
ax.plot(xs, oclekf_mean, color="darkorange", linewidth=1.5)
ax.plot(xs, oclekf_mean + 2*jnp.sqrt(oclekf_var), color="darkorange", linewidth=1.0, linestyle="--")
ax.plot(xs, oclekf_mean - 2*jnp.sqrt(oclekf_var), color="darkorange", linewidth=1.0, linestyle="--")
ax.set(xlabel="Time", ylabel="Output", ylim=(-0.5, 2), yticks=[0, 1]);

In [None]:
# Trajectory of exp(-delta) recorded by the callback; axes are labelled so the
# figure can be read on its own.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(xs, oclekf_gammasq, color="green", linewidth=1.5)
ax.set_xlabel("Time")
ax.set_ylabel("exp(-delta)");

In [None]:
# Running average of the per-step log-probability.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(xs, oclekf_log_prob, color="darkorange", linewidth=1.5)
ax.set(ylim=(-4, 2), xlabel="Time", ylabel="Log Predictive density");


# 2. Classification (Stationary)

## 2.0. Dataset

In [None]:
def generate_spiral_dataset(
    num_per_class=2000,
    zero_var=1.,
    one_var=1.,
    shuffle=True,
    key=0,
):
    """Generate a two-class spiral dataset and split it into train/val/test.

    Both classes share the same angles; class one uses the negated radius and
    its own random centre, so the two arms wind in opposite directions.
    Inputs are returned unstandardised.

    Args:
        num_per_class: number of points per class.
        zero_var: scale multiplying the Gaussian noise added to class 0.
        one_var: scale multiplying the Gaussian noise added to class 1.
        shuffle: whether to permute the rows before splitting.
        key: integer seed or an existing JAX PRNG key.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``. Of the
        ``2 * num_per_class`` rows, the first ``num_per_class`` are train, the
        rows up to ``int(1.4 * num_per_class)`` are validation, and the rest
        are test.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    (angle_key, zero_center_key, zero_noise_key,
     one_center_key, one_noise_key, shuffle_key) = jr.split(key, 6)

    theta = jnp.sqrt(jr.uniform(angle_key, shape=(num_per_class,))) * 2*jnp.pi
    r = 2*theta + jnp.pi

    def spiral_arm(radius, center_key):
        # Each arm is centred on its own random offset.
        cx, cy = 20*jr.normal(center_key, shape=(2,))
        return jnp.array([cx+jnp.cos(theta)*radius, cy+jnp.sin(theta)*radius]).T

    noise_shape = (num_per_class, 2)
    zero_points = spiral_arm(r, zero_center_key) + zero_var * jr.normal(zero_noise_key, shape=noise_shape)
    one_points = spiral_arm(-r, one_center_key) + one_var * jr.normal(one_noise_key, shape=noise_shape)

    # Class-0 rows first, then class-1 rows, with matching binary labels.
    features = jnp.concatenate([zero_points, one_points])
    labels = jnp.concatenate([jnp.zeros(num_per_class), jnp.ones(num_per_class)])

    if shuffle:
        order = jr.permutation(shuffle_key, jnp.arange(num_per_class * 2))
        features, labels = features[order], labels[order]

    val_start, test_start = num_per_class, int(1.4 * num_per_class)
    train_rows = slice(None, val_start)
    val_rows = slice(val_start, test_start)
    test_rows = slice(test_start, None)

    return (
        features[train_rows], features[val_rows], features[test_rows],
        labels[train_rows], labels[val_rows], labels[test_rows],
    )

In [None]:
# Default settings: 2000 points per class, shuffled, seed 0.
X_train, X_val, X_test, y_train, y_val, y_test = generate_spiral_dataset()

In [None]:
# Training inputs coloured by class; the legend shows the labels that were
# previously passed to plot() but never displayed.
fig, ax = plt.subplots()
ax.plot(X_train[y_train==0, 0], X_train[y_train==0, 1], 'o', label='0')
ax.plot(X_train[y_train==1, 0], X_train[y_train==1, 1], 'o', label='1')
ax.legend(title="Class");

In [None]:
# Hidden-layer widths of the MLP classifier. The value matches the widths the
# model was already built with; the variable used to hold an unused [10, 10].
mlp_features = [20, 20]
model_init_fn = partial(models.initialize_classification_mlp,
                        hidden_dims=mlp_features, input_dim=(2,),
                        output_dim=1)
model_params = model_init_fn(jr.PRNGKey(0))

# Scan callback used during hyperparameter search: evaluates with the binary
# log-likelihood loss on the data passed as X_test / y_test.
cb = partial(callbacks.cb_eval, evaluate_fn=partial(callbacks.evaluate_function,
                                                    loss_fn=callbacks.ll_binary))

## 2.1. Kalman Filter

In [None]:
def bbf_ekf(log_init_cov):
    """Bayesian-optimisation objective for the static EKF classifier.

    Runs the filter over the training stream starting from
    ``exp(log_init_cov)`` as initial covariance, with the ``cb`` callback
    receiving the validation split.

    Returns:
        The mean of the values returned by the scan callback, or ``-1e8``
        when that mean is NaN.
    """
    init_cov = jnp.exp(log_init_cov).item()
    agent = RebayesEKF(
        dynamics_weights_or_function=1.0,
        dynamics_covariance=0.0,
        emission_mean_function=model_params["emission_mean_function"],
        emission_cov_function=model_params["emission_cov_function"],
        emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
        method="fcekf"
    )
    scan_kwargs = {
        "agent": agent,
        "X_test": X_val,
        "y_test": y_val,
        "apply_fn": model_params["apply_fn"],
    }
    _, metrics = agent.scan(
        model_params["flat_params"], init_cov, X_train, y_train,
        callback=cb, **scan_kwargs,
    )
    score = jnp.array(list(metrics.values())).mean()

    # A diverged run must not be selected by the optimizer.
    return -1e8 if jnp.isnan(score) else score

In [None]:
# Search range for the log initial covariance.
bounds = {"log_init_cov": (-10.0, 2.0)}
# Number of initial random probes, then of optimizer-guided iterations.
n_explore, n_exploit = 10, 15

ekf_optimizer = BayesianOptimization(
    f=bbf_ekf,
    pbounds=bounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)
ekf_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

In [None]:
# Rebuild the static EKF with the best initial covariance found by the search
# and run it over the test stream.
hparams = ekf_optimizer.max["params"]
init_cov = jnp.exp(hparams["log_init_cov"]).item()

ekf = RebayesEKF(
    dynamics_weights_or_function=1.0,
    dynamics_covariance=0.,
    emission_mean_function=model_params["emission_mean_function"],
    emission_cov_function=model_params["emission_cov_function"],
    # NOTE(review): this uses logits=mean while the callback below treats the
    # same emission mean as a probability — confirm which one is intended.
    emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
    method="fcekf",
)
fcekf_val_cb_kwargs = {
    "agent": ekf, "emission_fn": model_params["emission_mean_function"]
}

def stationary_ekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Return the prediction made from ``bel_pred`` and the Bernoulli log-probability of ``y``."""
    emission_fn = kwargs["emission_fn"]
    y_mean = emission_fn(bel_pred.mean, x)
    log_prob = tfd.Bernoulli(probs=y_mean).log_prob(y)

    return y_mean, log_prob

ekf_bel, ekf_outputs = ekf.scan(
    model_params["flat_params"], init_cov, X_test, y_test,
    callback=stationary_ekf_callback, **fcekf_val_cb_kwargs
)

In [None]:
# Flatten both per-step outputs to 1-D arrays.
ekf_mean, ekf_log_prob = (out.ravel() for out in ekf_outputs)

# Replace the per-step log-probabilities by their running mean.
ekf_log_prob = jnp.cumsum(ekf_log_prob) / jnp.arange(1, ekf_log_prob.size + 1)

In [None]:
# Running average of the per-step log-probability, labelled like the
# regression plots above.
fig, ax = plt.subplots()
ax.plot(ekf_log_prob)
ax.set_xlabel("Time")
ax.set_ylabel("Log Predictive density");

## 2.2. Adaptive Kalman Filter

In [None]:
# Scan callback used during hyperparameter search: evaluates with the binary
# log-likelihood loss on the data passed as X_test / y_test.
cb = partial(
    callbacks.cb_eval,
    evaluate_fn=partial(callbacks.evaluate_function, loss_fn=callbacks.ll_binary),
)

def bbf_oclekf(log_init_cov, log_lr):
    """Bayesian-optimisation objective for the adaptive EKF classifier.

    ``exp(log_init_cov)`` is used both as the initial covariance of the scan
    and as the dynamics covariance of the agent; ``exp(log_lr)`` is the
    agent's learning rate.

    Returns:
        The mean of the values returned by the scan callback, or ``-1e8``
        when that mean is NaN.
    """
    init_cov = jnp.exp(log_init_cov).item()
    learning_rate = jnp.exp(log_lr).item()
    agent = RebayesOCLEKF(
        dynamics_decay_delta=0.0,
        dynamics_covariance=init_cov,
        emission_mean_function=model_params["emission_mean_function"],
        emission_cov_function=model_params["emission_cov_function"],
        emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
        method="fcekf",
        decay_dynamics_weight=True,
        learning_rate=learning_rate,
    )
    scan_kwargs = {
        "agent": agent,
        "X_test": X_val,
        "y_test": y_val,
        "apply_fn": model_params["apply_fn"],
    }
    _, metrics = agent.scan(
        model_params["flat_params"], init_cov, X_train, y_train,
        callback=cb, **scan_kwargs,
    )
    score = jnp.array(list(metrics.values())).mean()

    # A diverged run must not be selected by the optimizer.
    return -1e8 if jnp.isnan(score) else score

In [None]:
# Search ranges. The log learning rate has identical lower and upper bounds,
# which pins the learning rate to exp(0) = 1.
bounds = {
    "log_init_cov": (-15.0, 2.0),
    "log_lr": (-0.0, 0.0),
}
# Number of initial random probes, then of optimizer-guided iterations.
n_explore, n_exploit = 10, 15

oclekf_optimizer = BayesianOptimization(
    f=bbf_oclekf,
    pbounds=bounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)
oclekf_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

In [None]:
# Rebuild the adaptive EKF with the best hyperparameters and run it over the
# test stream. The agent mirrors the one scored inside bbf_oclekf (dynamics
# covariance tied to init_cov, decay_dynamics_weight=True), so the tuned
# values are evaluated on the model they were tuned for.
hparams = oclekf_optimizer.max["params"]
init_cov = jnp.exp(hparams["log_init_cov"]).item()
learning_rate = jnp.exp(hparams["log_lr"]).item()

oclekf = RebayesOCLEKF(
    dynamics_decay_delta=0.0,
    dynamics_covariance=init_cov,
    emission_mean_function=model_params["emission_mean_function"],
    emission_cov_function=model_params["emission_cov_function"],
    # NOTE(review): this uses logits=mean while the callback below treats the
    # same emission mean as a probability — confirm which one is intended.
    emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
    method="fcekf",
    decay_dynamics_weight=True,
    learning_rate=learning_rate,
)
# Pass the adaptive agent itself (previously the stale `ekf` from 2.1).
fcekf_val_cb_kwargs = {
    "agent": oclekf, "emission_fn": model_params["emission_mean_function"]
}

def stationary_oclekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Return the prediction from ``bel_pred``, ``exp(-delta)`` and the log-probability of ``y``."""
    emission_fn = kwargs["emission_fn"]
    y_mean = emission_fn(bel_pred.mean, x)
    log_prob = tfd.Bernoulli(probs=y_mean).log_prob(y)
    gamma_sq = jnp.exp(-bel.dynamics_decay_delta)

    return y_mean, gamma_sq, log_prob

oclekf_bel, oclekf_outputs = oclekf.scan(
    model_params["flat_params"], init_cov, X_test, y_test,
    callback=stationary_oclekf_callback, **fcekf_val_cb_kwargs
)

In [None]:
# Flatten every per-step output to a 1-D array.
oclekf_mean, oclekf_gamma_sq, oclekf_log_prob = (
    out.ravel() for out in oclekf_outputs
)

# Replace the per-step log-probabilities by their running mean.
oclekf_log_prob = jnp.cumsum(oclekf_log_prob) / jnp.arange(1, oclekf_log_prob.size + 1)

In [None]:
# Running average of the per-step log-probability, labelled like the
# regression plots above.
fig, ax = plt.subplots()
ax.plot(oclekf_log_prob)
ax.set_xlabel("Time")
ax.set_ylabel("Log Predictive density");

In [None]:
# Trajectory of exp(-delta) recorded by the callback; axes are labelled so the
# figure can be read on its own.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(oclekf_gamma_sq, color="green", linewidth=1.5)
ax.set_xlabel("Time")
ax.set_ylabel("exp(-delta)");

# 3. Classification (Non-Stationary)

## 3.0. Dataset

In [None]:
def generate_nonstationary_spiral_dataset(
    key=0,
    num_per_class=200,
    num_classes=10,
):
    """Concatenate several independently drawn spiral datasets.

    ``num_classes`` is the number of spiral datasets drawn; each one comes
    from ``generate_spiral_dataset(num_per_class, key=...)`` with its own key,
    and the corresponding splits are concatenated in draw order.

    Args:
        key: integer seed or an existing JAX PRNG key.
        num_per_class: points per class in every individual spiral dataset.
        num_classes: how many spiral datasets to draw and concatenate.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    _, chain_key = jr.split(key)

    # One list per returned split, in the order generate_spiral_dataset uses.
    collected = [[] for _ in range(6)]
    for _ in range(num_classes):
        task_key, chain_key = jr.split(chain_key)
        task_splits = generate_spiral_dataset(num_per_class, key=task_key)
        for bucket, part in zip(collected, task_splits):
            bucket.append(part)

    X_train, X_val, X_test, y_train, y_val, y_test = (
        jnp.concatenate(bucket) for bucket in collected
    )

    return X_train, X_val, X_test, y_train, y_val, y_test

In [None]:
# Default settings: seed 0, 10 spiral datasets of 200 points per class each.
# NOTE: this rebinds the X_*/y_* names used by the stationary section above.
X_train, X_val, X_test, y_train, y_val, y_test = generate_nonstationary_spiral_dataset()

In [None]:
# Training inputs coloured by class; the legend shows the labels that were
# previously passed to plot() but never displayed.
fig, ax = plt.subplots()
ax.plot(X_train[y_train==0, 0], X_train[y_train==0, 1], 'o', label='0')
ax.plot(X_train[y_train==1, 0], X_train[y_train==1, 1], 'o', label='1')
ax.legend(title="Class");

## 3.1. Kalman Filter (Q = 0.0)

In [None]:
# Scan callback for the non-stationary objectives: the project's `cb_osa`
# with the binary log-likelihood, reported under the "log_likelihood" key.
nonstationary_cb = partial(
    callbacks.cb_osa,
    evaluate_fn=callbacks.ll_binary,
    label="log_likelihood",
)

In [None]:
def bbf_ekf(log_init_cov):
    """Bayesian-optimisation objective for the static EKF (non-stationary data).

    Runs the filter over the training stream starting from
    ``exp(log_init_cov)`` as initial covariance, with ``nonstationary_cb`` as
    the scan callback.

    Returns:
        The mean of the values returned by the scan callback, or ``-1e8``
        when that mean is NaN.
    """
    init_cov = jnp.exp(log_init_cov).item()
    agent = RebayesEKF(
        dynamics_weights_or_function=1.0,
        dynamics_covariance=0.0,
        emission_mean_function=model_params["emission_mean_function"],
        emission_cov_function=model_params["emission_cov_function"],
        emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
        method="fcekf"
    )
    scan_kwargs = {
        "agent": agent,
        "X_test": X_val,
        "y_test": y_val,
        "apply_fn": model_params["apply_fn"],
    }
    _, metrics = agent.scan(
        model_params["flat_params"], init_cov, X_train, y_train,
        callback=nonstationary_cb, **scan_kwargs,
    )
    score = jnp.array(list(metrics.values())).mean()

    # A diverged run must not be selected by the optimizer.
    return -1e8 if jnp.isnan(score) else score

In [None]:
# Search range for the log initial covariance.
bounds = {"log_init_cov": (-10.0, 2.0)}
# Number of initial random probes, then of optimizer-guided iterations.
n_explore, n_exploit = 10, 15

ekf_optimizer = BayesianOptimization(
    f=bbf_ekf,
    pbounds=bounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)
ekf_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

In [None]:
# Rebuild the static EKF with the best initial covariance found by the search
# and run it over the non-stationary test stream.
hparams = ekf_optimizer.max["params"]
init_cov = jnp.exp(hparams["log_init_cov"]).item()

ekf = RebayesEKF(
    dynamics_weights_or_function=1.0,
    dynamics_covariance=0.,
    emission_mean_function=model_params["emission_mean_function"],
    emission_cov_function=model_params["emission_cov_function"],
    # NOTE(review): this uses logits=mean while the callback below treats the
    # same emission mean as a probability — confirm which one is intended.
    emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
    method="fcekf",
)
fcekf_val_cb_kwargs = {
    "agent": ekf, "emission_fn": model_params["emission_mean_function"]
}

def nonstationary_ekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Return the prediction made from ``bel_pred`` and the Bernoulli log-probability of ``y``."""
    emission_fn = kwargs["emission_fn"]
    y_mean = emission_fn(bel_pred.mean, x)
    log_prob = tfd.Bernoulli(probs=y_mean).log_prob(y)

    return y_mean, log_prob

ekf_bel, ekf_outputs = ekf.scan(
    model_params["flat_params"], init_cov, X_test, y_test,
    callback=nonstationary_ekf_callback, **fcekf_val_cb_kwargs
)

In [None]:
# Flatten both per-step outputs to 1-D arrays.
ekf_mean, ekf_log_prob = (out.ravel() for out in ekf_outputs)

# Replace the per-step log-probabilities by their running mean.
ekf_log_prob = jnp.cumsum(ekf_log_prob) / jnp.arange(1, ekf_log_prob.size + 1)

In [None]:
# Running average of the per-step log-probability, labelled like the
# regression plots above.
fig, ax = plt.subplots()
ax.plot(ekf_log_prob)
ax.set_xlabel("Time")
ax.set_ylabel("Log Predictive density");

## 3.2. Kalman Filter (tuned Q > 0)

In [None]:
def bbf_noisy_ekf(log_init_cov, log_dynamics_cov):
    """Bayesian-optimisation objective for the EKF with process noise.

    ``exp(log_init_cov)`` is the initial covariance of the scan and
    ``exp(log_dynamics_cov)`` the dynamics covariance of the agent.

    Returns:
        The mean of the values returned by the scan callback, or ``-1e8``
        when that mean is NaN.
    """
    init_cov = jnp.exp(log_init_cov).item()
    dynamics_cov = jnp.exp(log_dynamics_cov).item()
    agent = RebayesEKF(
        dynamics_weights_or_function=1.0,
        dynamics_covariance=dynamics_cov,
        emission_mean_function=model_params["emission_mean_function"],
        emission_cov_function=model_params["emission_cov_function"],
        emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
        method="fcekf"
    )
    scan_kwargs = {
        "agent": agent,
        "X_test": X_val,
        "y_test": y_val,
        "apply_fn": model_params["apply_fn"],
    }
    _, metrics = agent.scan(
        model_params["flat_params"], init_cov, X_train, y_train,
        callback=nonstationary_cb, **scan_kwargs,
    )
    score = jnp.array(list(metrics.values())).mean()

    # A diverged run must not be selected by the optimizer.
    return -1e8 if jnp.isnan(score) else score

In [None]:
# Both log-covariances are searched over the same range.
bounds = {
    "log_init_cov": (-20.0, 2.0),
    "log_dynamics_cov": (-20.0, 2.0),
}
# Number of initial random probes, then of optimizer-guided iterations.
n_explore, n_exploit = 10, 15

noisy_ekf_optimizer = BayesianOptimization(
    f=bbf_noisy_ekf,
    pbounds=bounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)
noisy_ekf_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

In [None]:
# Rebuild the EKF with the best initial and dynamics covariances found by the
# search and run it over the non-stationary test stream.
hparams = noisy_ekf_optimizer.max["params"]
init_cov = jnp.exp(hparams["log_init_cov"]).item()
dynamic_cov = jnp.exp(hparams["log_dynamics_cov"]).item()

noisy_ekf = RebayesEKF(
    dynamics_weights_or_function=1.0,
    dynamics_covariance=dynamic_cov,
    emission_mean_function=model_params["emission_mean_function"],
    emission_cov_function=model_params["emission_cov_function"],
    # NOTE(review): this uses logits=mean while the callback below treats the
    # same emission mean as a probability — confirm which one is intended.
    emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
    method="fcekf",
)
fcekf_val_cb_kwargs = {
    "agent": noisy_ekf, "emission_fn": model_params["emission_mean_function"]
}

def noisy_ekf_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Return the prediction made from ``bel_pred`` and the Bernoulli log-probability of ``y``."""
    emission_fn = kwargs["emission_fn"]
    y_mean = emission_fn(bel_pred.mean, x)
    log_prob = tfd.Bernoulli(probs=y_mean).log_prob(y)

    return y_mean, log_prob

noisy_ekf_bel, noisy_ekf_outputs = noisy_ekf.scan(
    model_params["flat_params"], init_cov, X_test, y_test,
    callback=noisy_ekf_callback, **fcekf_val_cb_kwargs
)

In [None]:
# Flatten both per-step outputs to 1-D arrays.
noisy_ekf_mean, noisy_ekf_log_prob = (out.ravel() for out in noisy_ekf_outputs)

# Replace the per-step log-probabilities by their running mean.
noisy_ekf_log_prob = (
    jnp.cumsum(noisy_ekf_log_prob) / jnp.arange(1, noisy_ekf_log_prob.size + 1)
)

In [None]:
# Running average of the per-step log-probability, labelled like the
# regression plots above.
fig, ax = plt.subplots()
ax.plot(noisy_ekf_log_prob)
ax.set_xlabel("Time")
ax.set_ylabel("Log Predictive density");

## 3.3. Adaptive Kalman Filter

In [None]:
def bbf_oclekf(log_init_cov, log_lr, log_dynamics_cov):
    init_cov = jnp.exp(log_init_cov).item()
    dynamics_cov = jnp.exp(log_dynamics_cov).item()
    learning_rate = jnp.exp(log_lr).item()
    ekf = RebayesOCLEKF(
        dynamics_decay_delta=0.0,
        dynamics_covariance=init_cov,
        emission_mean_function=model_params["emission_mean_function"],
        emission_cov_function=model_params["emission_cov_function"],
        emission_dist=lambda mean, cov: tfd.Bernoulli(probs=mean),
        method="fcekf",
        decay_dynamics_weight=True,
        learning_rate=learning_rate,
        gamma_ub=None,
    )
    fcekf_val_cb_kwargs = {
        "agent": ekf, "X_test": X_val, "y_test": y_val, "apply_fn": model_params["apply_fn"],
    }
    _, metric = ekf.scan(model_params["flat_params"], init_cov, X_train, y_train,
                         callback=nonstationary_cb, **fcekf_val_cb_kwargs)
    metric = jnp.array(list(metric.values())).mean()
    if jnp.isnan(metric):
        metric = -1e8
    
    return metric

In [None]:
import jax

bounds={
    "log_init_cov": (-10.0, 2.0),
    "log_lr": (0.0, 0.0),
    "log_dynamics_cov": (-20.0, 2.0),
}
n_explore, n_exploit = 10, 15

oclekf_optimizer = BayesianOptimization(
    f=bbf_oclekf,
    pbounds=bounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)

oclekf_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

In [None]:
hparams = oclekf_optimizer.max["params"]
init_cov = jnp.exp(hparams["log_init_cov"]).item()
learning_rate = jnp.exp(hparams["log_lr"]).item()

oclekf = RebayesOCLEKF(
    dynamics_decay_delta=0.0,
    dynamics_covariance=init_cov,
    emission_mean_function=model_params["emission_mean_function"],
    emission_cov_function=model_params["emission_cov_function"],
    emission_dist=lambda mean, cov: tfd.Bernoulli(logits=mean),
    method="fcekf",
    decay_dynamics_weight=True,
    learning_rate=learning_rate,
    gamma_ub=0.90,
)
fcekf_val_cb_kwargs = {
    "agent": oclekf, "emission_fn": model_params["emission_mean_function"]
}

def callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    emission_fn = kwargs["emission_fn"]
    y_mean = emission_fn(bel_pred.mean, x)
    log_prob = tfd.Bernoulli(probs=y_mean).log_prob(y)
    gamma_sq = jnp.exp(-bel.dynamics_decay_delta)
    
    return y_mean, gamma_sq, log_prob


# with jax.disable_jit():
oclekf_bel, oclekf_outputs = oclekf.scan(
    model_params["flat_params"], init_cov, X_test, y_test, callback=callback, 
    **fcekf_val_cb_kwargs
)

In [None]:
# Flatten every per-step output to a 1-D array.
oclekf_mean, oclekf_gamma_sq, oclekf_log_prob = (
    out.ravel() for out in oclekf_outputs
)

# Replace the per-step log-probabilities by their running mean.
oclekf_log_prob = jnp.cumsum(oclekf_log_prob) / jnp.arange(1, oclekf_log_prob.size + 1)

In [None]:
# Smallest exp(-delta) value recorded over the test stream.
oclekf_gamma_sq.min()

In [None]:
# Running average of the per-step log-probability, labelled like the
# regression plots above.
fig, ax = plt.subplots()
ax.plot(oclekf_log_prob)
ax.set_xlabel("Time")
ax.set_ylabel("Log Predictive density");

In [None]:
# exp(-delta) over the test stream, with a dashed line at every task boundary.
# With the default dataset settings each task contributes 120 test points:
# 2 * 200 rows minus the first int(1.4 * 200) = 280 used for train/val.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(oclekf_gamma_sq, color="green", linewidth=1.5)
for i in range(0, oclekf_gamma_sq.size, 120):
    ax.axvline(i, color="black", linewidth=0.5, linestyle="--")

In [None]:
# exp(-delta) over the test stream, with a dashed line at every task boundary.
# With the default dataset settings each task contributes 120 test points:
# 2 * 200 rows minus the first int(1.4 * 200) = 280 used for train/val.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(oclekf_gamma_sq, color="green", linewidth=1.5)
for i in range(0, oclekf_gamma_sq.size, 120):
    ax.axvline(i, color="black", linewidth=0.5, linestyle="--")

In [None]:
# exp(-delta) over the test stream, with a dashed line at every task boundary.
# With the default dataset settings each task contributes 120 test points:
# 2 * 200 rows minus the first int(1.4 * 200) = 280 used for train/val.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(oclekf_gamma_sq, color="green", linewidth=1.5)
for i in range(0, oclekf_gamma_sq.size, 120):
    ax.axvline(i, color="black", linewidth=0.5, linestyle="--")

## 3.4. Adaptive LoFi

In [None]:
from rebayes.low_rank_filter.lofi import RebayesOCLLoFiDiagonal

In [None]:
def bbf_ocllofi(log_init_cov, log_lr):
    """Bayesian-optimisation objective for the adaptive LoFi classifier.

    ``exp(log_init_cov)`` is used both as the initial covariance of the scan
    and as the dynamics covariance of the agent; ``exp(log_lr)`` is the
    agent's learning rate.

    Returns:
        The mean of the values returned by the scan callback, or ``-1e8``
        when that mean is NaN.
    """
    init_cov = jnp.exp(log_init_cov).item()
    learning_rate = jnp.exp(log_lr).item()
    agent = RebayesOCLLoFiDiagonal(
        dynamics_weights=1.0,
        dynamics_covariance=init_cov,
        emission_mean_function=model_params["emission_mean_function"],
        emission_cov_function=model_params["emission_cov_function"],
        emission_dist=lambda mean, cov: tfd.Bernoulli(probs=mean),
        decay_dynamics_weight=True,
        learning_rate=learning_rate,
        gamma_ub=None,
    )
    scan_kwargs = {
        "agent": agent,
        "X_test": X_val,
        "y_test": y_val,
        "apply_fn": model_params["apply_fn"],
    }
    _, metrics = agent.scan(
        model_params["flat_params"], init_cov, X_train, y_train,
        callback=nonstationary_cb, **scan_kwargs,
    )
    score = jnp.array(list(metrics.values())).mean()

    # A diverged run must not be selected by the optimizer.
    return -1e8 if jnp.isnan(score) else score

In [None]:
import jax

# Search ranges. The log learning rate has identical lower and upper bounds,
# which pins the learning rate to exp(0) = 1.
bounds = {
    "log_init_cov": (-10.0, 2.0),
    "log_lr": (0.0, 0.0),
}
# Number of initial random probes, then of optimizer-guided iterations.
n_explore, n_exploit = 10, 15

ocllofi_optimizer = BayesianOptimization(
    f=bbf_ocllofi,
    pbounds=bounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)
ocllofi_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

In [None]:
# Rebuild the adaptive LoFi agent from the best hyperparameters, mirroring the
# agent scored inside bbf_ocllofi, and run it over the test stream.
hparams = ocllofi_optimizer.max["params"]
init_cov = jnp.exp(hparams["log_init_cov"]).item()
learning_rate = jnp.exp(hparams["log_lr"]).item()

ocllofi = RebayesOCLLoFiDiagonal(
    dynamics_weights=1.0,
    dynamics_covariance=init_cov,
    emission_mean_function=model_params["emission_mean_function"],
    emission_cov_function=model_params["emission_cov_function"],
    # probs=mean matches both the tuning objective (bbf_ocllofi) and the
    # callback below, which treat the emission mean as a probability.
    emission_dist=lambda mean, cov: tfd.Bernoulli(probs=mean),
    decay_dynamics_weight=True,
    learning_rate=learning_rate,
    gamma_ub=None,
)
fcekf_val_cb_kwargs = {
    "agent": ocllofi, "emission_fn": model_params["emission_mean_function"]
}

def ocllofi_callback(bel, pred_obs, t, x, y, bel_pred, **kwargs):
    """Return the prediction from ``bel_pred``, ``bel.gamma**2`` and the log-probability of ``y``."""
    emission_fn = kwargs["emission_fn"]
    y_mean = emission_fn(bel_pred.mean, x)
    log_prob = tfd.Bernoulli(probs=y_mean).log_prob(y)
    gamma_sq = bel.gamma**2

    return y_mean, gamma_sq, log_prob


# with jax.disable_jit():
ocllofi_bel, ocllofi_outputs = ocllofi.scan(
    model_params["flat_params"], init_cov, X_test, y_test,
    callback=ocllofi_callback, **fcekf_val_cb_kwargs
)

In [None]:
# Flatten every per-step output to a 1-D array.
ocllofi_mean, ocllofi_gamma_sq, ocllofi_log_prob = (
    out.ravel() for out in ocllofi_outputs
)

# Replace the per-step log-probabilities by their running mean.
ocllofi_log_prob = jnp.cumsum(ocllofi_log_prob) / jnp.arange(1, ocllofi_log_prob.size + 1)

In [None]:
# Running average of the per-step log-probability, labelled like the
# regression plots above.
fig, ax = plt.subplots()
ax.plot(ocllofi_log_prob)
ax.set_xlabel("Time")
ax.set_ylabel("Log Predictive density");

In [None]:
# gamma**2 over the test stream, with a dashed line at every task boundary.
# With the default dataset settings each task contributes 120 test points:
# 2 * 200 rows minus the first int(1.4 * 200) = 280 used for train/val.
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(ocllofi_gamma_sq, color="green", linewidth=1.5)
for i in range(0, ocllofi_gamma_sq.size, 120):
    ax.axvline(i, color="black", linewidth=0.5, linestyle="--")

# 4. Regression (Non-Stationary)

In [None]:
import argparse
from functools import partial
import json
import os
from typing import Callable
from pathlib import Path
import pickle

import jax.random as jr
from jax.tree_util import tree_map

import demos.collas.datasets.dataloaders as dataloaders
from rebayes.low_rank_filter import lofi
import rebayes.utils.models as models
import rebayes.utils.callbacks as callbacks
import demos.collas.hparam_tune as hparam_tune
import demos.collas.train_utils as train_utils

import demos.collas.run_regression_experiments as reg_exp

In [115]:
# Pick the regression benchmark by name from the project's dataloader registry.
problem_str = "permuted"
dataset = dataloaders.reg_datasets[problem_str]
# For "permuted" the registry entry is called first to obtain the mapping
# that is unpacked below.
if problem_str == "permuted":
    dataset = dataset()
# NOTE(review): relies on the mapping holding exactly two values, in
# (load function, kwargs) order; `kwargs` is not used again at notebook scope.
dataset_load_fn, kwargs = dataset.values()

# Base data handed to the loader: MNIST restricted to target digit 2.
# n=None presumably means "use every sample" — TODO confirm.
ntrain = None
base_dataset = dataloaders.load_target_digit_dataset(
    target_digit=2, dataset_type="mnist", n=ntrain,
)
dataset_load_fn = partial(dataset_load_fn, dataset=base_dataset)

# Build the splits; later cells read the "train" and "val" entries.
dataset = dataset_load_fn()

In [None]:
# Settings forwarded to the project's metric factory.
obs_scale = 15.0
nll_method = "nll"
temp = 1.0
aleatoric = 1.0

eval_metric = reg_exp._eval_metric(
    obs_scale, problem_str, nll_method, temp, aleatoric
)

In [None]:
# Regression MLP initialiser with the "mnist" input/output sizes and an
# emission covariance of obs_scale**2 fixed in advance.
# NOTE: this rebinds `model_init_fn`, which the classification sections
# defined as a different initialiser.
input_dim, output_dim = reg_exp._compute_io_dims("mnist")

model_init_fn = partial(
    models.initialize_regression_mlp,
    input_dim=input_dim,
    output_dim=output_dim,
    emission_cov=obs_scale**2,
)

In [None]:
def bbf_ocl_lofi(
    log_init_cov,
    log_learning_rate,
    # Specify before running
    train,
    test,
    callback,
    memory_size,
    decay_dynamics_weight,
    callback_at_end=True,
    n_seeds=5,
    **kwargs,
):
    """Black-box objective for Bayesian optimisation of the adaptive LoFi agent.

    Args:
        log_init_cov: log of the value used both as the initial covariance of
            the scan and as the agent's dynamics covariance.
        log_learning_rate: log of the agent's learning rate.
        train: sequence whose first element is the training inputs and whose
            last element is the training targets.
        test: sequence with the same layout, used for evaluation.
        callback: evaluation callback returning a mapping of metric values.
        memory_size: ``memory_size`` passed to the agent.
        decay_dynamics_weight: ``decay_dynamics_weight`` passed to the agent.
        callback_at_end: if true, run the scan without a callback and call
            ``callback`` once on the final belief; otherwise pass ``callback``
            to the scan and average its outputs.
        n_seeds: number of model initialisations (keys ``0 .. n_seeds - 1``)
            to average over.
        **kwargs: extra entries forwarded to ``callback``.

    Returns:
        The mean metric over seeds, or ``-1e8`` if that mean is NaN or
        infinite.
    """
    X_train, *_, y_train = train
    X_test, *_, y_test = test

    initial_covariance = jnp.exp(log_init_cov).item()
    learning_rate = jnp.exp(log_learning_rate).item()

    # The agent and the callback use the functions of the key-0 model; only
    # the initial flat parameters change from seed to seed.
    reference_model = model_init_fn(key=0)

    def gaussian_emission(mean, cov):
        return tfd.Normal(loc=mean, scale=jnp.sqrt(cov))

    estimator = lofi.RebayesOCLLoFiDiagonal(
        dynamics_weights=1.0,
        dynamics_covariance=initial_covariance,
        emission_mean_function=reference_model["emission_mean_function"],
        emission_cov_function=reference_model["emission_cov_function"],
        emission_dist=gaussian_emission,
        decay_dynamics_weight=decay_dynamics_weight,
        memory_size=memory_size,
        inflation="hybrid",
        learning_rate=learning_rate,
    )

    test_cb_kwargs = {
        "agent": estimator,
        "X_test": X_test,
        "y_test": y_test,
        "apply_fn": reference_model["apply_fn"],
        "key": jr.PRNGKey(0),
        **kwargs,
    }

    def score_seed(seed):
        flat_params = model_init_fn(key=seed)["flat_params"]
        if callback_at_end:
            bel, _ = estimator.scan(flat_params, initial_covariance,
                                    X_train, y_train, progress_bar=False)
            return jnp.array(list(callback(bel, **test_cb_kwargs).values()))
        _, per_step = estimator.scan(flat_params, initial_covariance,
                                     X_train, y_train, progress_bar=False,
                                     callback=callback, **test_cb_kwargs)
        return jnp.array(list(per_step.values())).mean()

    result = jnp.array([score_seed(seed) for seed in range(n_seeds)]).mean()

    # A diverged run must not be selected by the optimizer.
    if not jnp.isfinite(result):
        result = -1e8

    return result

In [116]:
# Fix the data splits and the non-searched settings so the optimizer only
# sees the two log-scale hyperparameters.
# NOTE(review): this rebinds the name `bbf_ocl_lofi`, so re-running the cell
# wraps the previous partial again instead of the original function.
bbf_ocl_lofi = partial(
    bbf_ocl_lofi,
    train=dataset["train"],
    test=dataset["val"],
    callback=eval_metric["val"],
    memory_size=10,
    decay_dynamics_weight=False,
    callback_at_end=False,
    n_seeds=2,
)

In [117]:
# Both log-scale hyperparameters are searched over the same range.
pbounds = dict.fromkeys(("log_init_cov", "log_learning_rate"), (-10.0, 2.0))

In [118]:
# Number of initial random probes, then of optimizer-guided iterations.
n_explore, n_exploit = 10, 15

ocllofi_optimizer = BayesianOptimization(
    f=bbf_ocl_lofi,
    pbounds=pbounds,
    random_state=0,
    verbose=2,
    allow_duplicate_points=True,
)
ocllofi_optimizer.maximize(init_points=n_explore, n_iter=n_exploit)

|   iter    |  target   | log_in... | log_le... |
-------------------------------------------------
