# 0. Imports and Dataset

In [44]:
try:
    import flax
    import optax
    import ml_collections
    import chex
    import tensorflow
    import tensorflow_datasets
    import rebayes
    import dynamax
    import jax_tqdm
    import bayes_opt
except ModuleNotFoundError:
    %pip install flax
    %pip install optax
    %pip install ml_collections
    %pip install chex
    %pip install tensorflow
    %pip install tensorflow_datasets
    %pip install git+https://github.com/probml/rebayes.git
    %pip install dynamax
    %pip install jax-tqdm
    %pip install bayesian-optimization 
    import flax 
    import optax
    import ml_collections
    import chex
    import tensorflow
    import tensorflow_datasets
    import rebayes
    import dynamax
    import jax_tqdm
    import bayes_opt

In [46]:
import gc
from functools import partial
import time
from typing import Callable, Sequence
import collections.abc
import warnings
import copy
from collections import deque

import ml_collections
import tensorflow_datasets as tfds
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalDiag as MVND
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

from flax import linen as nn
from flax.training import train_state
import jax
from jax import jit, vmap, lax, jacfwd
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
import optax
from sklearn.preprocessing import OneHotEncoder
import numpy as np
from bayes_opt import BayesianOptimization

jax.numpy.set_printoptions(suppress = True, precision=4)
from rebayes.orfit import GeneralizedORFitParams, RebayesORFit
from rebayes.ekf import RebayesEKF
from rebayes.diagonal_inference import DEKFParams


In [3]:
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    train_tvsplit_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[:80%]', batch_size=-1))
    val_tvsplit_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[80%:]', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    for ds in [train_ds, train_tvsplit_ds, val_tvsplit_ds, test_ds]:
        ds['image'] = jnp.float32(ds['image']) / 255.
    return train_ds, test_ds, train_tvsplit_ds, val_tvsplit_ds

In [4]:
def get_config(
    optimizer='sgd',
    learning_rate=0.01, 
    momentum=2e-1,
    init_var=3e-2,
    num_iter=1,
    batch_size=1, 
    num_epochs=1,
    sample_freq=500,
    posterior_predictive_method='mc',
    seed=0
    ):
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()
    config.optimizer = optimizer
    config.learning_rate = learning_rate
    config.momentum = momentum
    config.init_var = init_var
    config.num_iter = num_iter
    config.batch_size = batch_size
    config.num_epochs = num_epochs
    config.sample_freq = sample_freq
    config.posterior_predictive_method = posterior_predictive_method
    config.seed = seed
    return config

In [6]:
train_ds, test_ds, train_tvsplit_ds, val_tvsplit_ds = get_datasets()

In [7]:
X_train, y_train = jnp.array(train_ds['image']), jnp.array(train_ds['label'])
X_test, y_test = jnp.array(test_ds['image']), jnp.array(test_ds['label'])

In [8]:
# Reshape data
X_train = X_train.reshape(-1, 1, 28, 28, 1)
y_train_ohe = jax.nn.one_hot(y_train, 10)

# 1. Models

In [9]:
class CNN(nn.Module):
    """A simple CNN model."""
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        x = x.ravel()
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

In [10]:
cnn = CNN()
key = jr.PRNGKey(0)
params = cnn.init(key, jnp.ones([1, 28, 28, 1]))['params']
flat_params, unflatten_fn = ravel_pytree(params)
print(f'Params size = {flat_params.shape}')
state_dim = flat_params.shape[0]

Params size = (421642,)


In [11]:
apply_fn = lambda w, x: cnn.apply({'params': unflatten_fn(w)}, x).ravel()
emission_mean_function=lambda w, x: jax.nn.softmax(apply_fn(w, x))
def emission_cov_function(w, x):
    ps = emission_mean_function(w, x)
    return jnp.diag(ps) - jnp.outer(ps, ps) + 1e-3 * jnp.eye(len(ps)) # Add diagonal to avoid singularity

In [69]:
def evaluate_neg_log_likelihood(flat_params, unflatten_fn, apply_fn, test_set):
    """ Evaluate negative log likelihood for given parameters and test set
    """
    @jit
    def evaluate_nll(label, image):
        image = image.reshape((1, 28, 28, 1))
        logits = apply_fn(flat_params, image)
        return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=label)
    nlls = vmap(evaluate_nll, (0, 0))(test_set['label'], test_set['image'])
    return nlls.mean()

# 2. Generalized ORFit

In [20]:
from jax.lax import scan
from jax import jacrev, vmap, jit
from tqdm import trange
from jax_tqdm import scan_tqdm


# Helper functions
_jacrev_2d = lambda f, x: jnp.atleast_2d(jacrev(f)(x))
_stable_division = lambda a, b: jnp.where(b.any(), a / b, jnp.zeros(shape=a.shape))
_normalize = lambda v: jnp.where(v.any(), v / jnp.linalg.norm(v), jnp.zeros(shape=v.shape))
_projection_matrix = lambda a: _stable_division(a.reshape(-1, 1) @ a.reshape(1, -1), a.T @ a)
_form_projection_matrix = lambda A: jnp.eye(A.shape[0]) - vmap(_projection_matrix, 1)(A).sum(axis=0)
_project = lambda a, x: _stable_division(a * (a.T @ x), (a.T @ a))
_project_to_columns = lambda A, x: \
    jnp.where(A.any(), vmap(_project, (1, None))(A, x).sum(axis=0), jnp.zeros(shape=x.shape))


def _generalized_orfit_condition_on(m, U, Sigma, eta, y_cond_mean, y_cond_cov, x, y, sv_threshold):
    """Condition step of the ORFit algorithm.
    Args:
        m (D_hid,): Prior mean.
        U (D_hid, D_mem,): Prior basis.
        Sigma (D_mem,): Prior singular values.
        eta (float): Prior precision.
        y_cond_mean (Callable): Conditional emission mean function.
        y_cond_cov (Callable): Conditional emission covariance function.
        x (D_in,): Control input.
        y (D_obs,): Emission.
        sv_threshold (float): Threshold for singular values.
    Returns:
        m_cond (D_hid,): Posterior mean.
        U_cond (D_hid, D_mem,): Posterior basis.
        Sigma_cond (D_mem,): Posterior singular values.
    """
    m_Y = lambda w: y_cond_mean(w, x)
    Cov_Y = lambda w: y_cond_cov(w, x)
    
    yhat = jnp.atleast_1d(m_Y(m))
    R = jnp.atleast_2d(Cov_Y(m))
    # print(f'R: {R}')
    L = jnp.linalg.cholesky(R)
    # print(f'L: {L}')
    A = jnp.linalg.lstsq(L, jnp.eye(L.shape[0]))[0].T
    H = _jacrev_2d(m_Y, m)
    # print(f'A: {A}')
    W_tilde = jnp.hstack([Sigma * U, (H.T @ A).squeeze()])
    S = eta*jnp.eye(W_tilde.shape[1]) + W_tilde.T @ W_tilde
    # print(f'S: {S}')
    K = (H.T @ A) @ A.T - W_tilde @ (jnp.linalg.pinv(S) @ (W_tilde.T @ ((H.T @ A) @ A.T)))
    # print(f'K: {K}')
    # print('\n')

    m_cond = m + K/eta @ (y - yhat)
    U_tilde = (H.T - U @ (U.T @ H.T)) @ A

    def _update_basis(carry, i):
        U, Sigma = carry
        v = U_tilde[:, i]
        u = _normalize(v)
        U_cond = jnp.where(
            Sigma.min() < u @ v, 
            jnp.where(sv_threshold < u @ v, U.at[:, Sigma.argmin()].set(u), U),
            U
        )
        Sigma_cond = jnp.where(
            Sigma.min() < u @ v,
            jnp.where(sv_threshold < u @ v, Sigma.at[Sigma.argmin()].set(u.T @ v), Sigma),
            Sigma,
        )
        return (U_cond, Sigma_cond), (U_cond, Sigma_cond)

    (U_cond, Sigma_cond), _ = scan(_update_basis, (U, Sigma), jnp.arange(U_tilde.shape[1]))

    return m_cond, U_cond, Sigma_cond

In [21]:
@chex.dataclass
class ORFitBel:
    mean: chex.Array
    basis: chex.Array
    sigma: chex.Array

class RebayesORFit:
    def __init__(
        self,
        orfit_params,
    ):
        self.eta = orfit_params.initial_precision
        self.update_fn = _generalized_orfit_condition_on
        self.gamma = orfit_params.dynamics_decay
        self.q = orfit_params.dynamics_noise
        self.f = orfit_params.emission_mean_function
        self.r = orfit_params.emission_cov_function
        self.mu0 = orfit_params.initial_mean
        self.m = orfit_params.memory_size
        self.sv_threshold = orfit_params.sv_threshold
        self.U0 = jnp.zeros((len(self.mu0), self.m))
        self.Sigma0 = jnp.zeros((self.m,))

    def initialize(self):
        return ORFitBel(mean=self.mu0, basis=self.U0, sigma=self.Sigma0)

    # @partial(jit, static_argnums=(0,))
    def update(self, bel, u, y):
        m, U, Sigma = bel.mean, bel.basis, bel.sigma # prior predictive for hidden state
        m_cond, U_cond, Sigma_cond = self.update_fn(m, U, Sigma, self.eta, self.f, self.r, u, y, self.sv_threshold)
        return ORFitBel(mean=m_cond, basis=U_cond, sigma=Sigma_cond)

    def scan(self, X, Y, callback=None):
        num_timesteps = X.shape[0]
        
        @scan_tqdm(num_timesteps)
        def step(bel, t):
            bel = self.update(bel, X[t], Y[t])
            out = None
            if callback is not None:
                out = callback(bel, t, X[t], Y[t])
            return bel, out

        carry = self.initialize()
        bel, outputs = scan(step, carry, jnp.arange(num_timesteps))
        return bel, outputs

In [31]:
orfit_params = GeneralizedORFitParams(
    initial_mean=flat_params,
    initial_precision=1.0,
    dynamics_decay=0.0,
    dynamics_noise=0.0,
    emission_mean_function=emission_mean_function,
    emission_cov_function=emission_cov_function,
    memory_size=20
)
orfit_estimator = RebayesORFit(orfit_params)

In [32]:
orfit_bel, _ = orfit_estimator.scan(X_train[:1000], y_train_ohe[:1000])

  0%|          | 0/1000 [00:00<?, ?it/s]

In [33]:
orfit_mean = orfit_bel.mean

In [36]:
emission_mean_function(orfit_mean, X_test[1].reshape(1,28,28,1))

DeviceArray([0.0351, 0.1237, 0.0676, 0.327 , 0.0599, 0.1037, 0.0531,
             0.0849, 0.0774, 0.0674], dtype=float32)

In [39]:
emission_mean_function(orfit_mean, X_test[3].reshape(1,28,28,1))

DeviceArray([0.0351, 0.1237, 0.0676, 0.327 , 0.0599, 0.1037, 0.0531,
             0.0849, 0.0774, 0.0674], dtype=float32)

In [37]:
y_test[1]

DeviceArray(0, dtype=int32)

# 3. VDEKF

## 3.1 Hyperparameter Tuning

In [71]:
def opt_fn(init_var):
    vdekf_params = DEKFParams(
        initial_mean=flat_params,
        initial_cov_diag=jnp.ones(flat_params.shape[0]) * init_var,
        dynamics_cov_diag=jnp.ones((flat_params.size)) * 1e-4,
        emission_mean_function=emission_mean_function,
        emission_cov_function=emission_cov_function,
    )
    vdekf_estimator = RebayesEKF(vdekf_params, 'vdekf')
    vdekf_bel, _ = vdekf_estimator.scan(X_train[:2000], y_train_ohe[:2000])
    vdekf_mean = vdekf_bel.mean
    log_likelihood = -evaluate_neg_log_likelihood(vdekf_mean, unflatten_fn, apply_fn, test_ds)

    return log_likelihood

In [72]:
vdekf_optimizer = BayesianOptimization(
    f=opt_fn,
    pbounds={'init_var': (1e-4, 1.0)},
)

In [None]:
vdekf_optimizer.maximize(
    init_points=200,
    n_iter=200,
)

|   iter    |  target   | init_var  |
-------------------------------------


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m1        [0m | [0m-2.311   [0m | [0m0.6531   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m2        [0m | [95m-2.311   [0m | [95m0.9175   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m3        [0m | [95m-0.7158  [0m | [95m0.1239   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m4        [0m | [95m-0.5113  [0m | [95m0.438    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m5        [0m | [0m-0.7144  [0m | [0m0.0773   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m6        [0m | [0m-0.5273  [0m | [0m0.7585   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m7        [0m | [0m-2.311   [0m | [0m0.9737   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m8        [0m | [95m-0.4068  [0m | [95m0.3931   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m9        [0m | [0m-2.311   [0m | [0m0.8023   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m10       [0m | [0m-0.5853  [0m | [0m0.01634  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m11       [0m | [0m-2.311   [0m | [0m0.7145   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m12       [0m | [0m-0.4082  [0m | [0m0.6132   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m13       [0m | [95m-0.3474  [0m | [95m0.4579   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m14       [0m | [95m-0.2872  [0m | [95m0.5726   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m15       [0m | [0m-0.7113  [0m | [0m0.1147   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m16       [0m | [0m-0.6389  [0m | [0m0.02802  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m17       [0m | [0m-0.6137  [0m | [0m0.08376  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m18       [0m | [0m-2.311   [0m | [0m0.5689   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m19       [0m | [0m-0.7202  [0m | [0m0.6341   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m20       [0m | [0m-0.5028  [0m | [0m0.2544   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m21       [0m | [0m-0.5368  [0m | [0m0.3846   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m22       [0m | [0m-0.5297  [0m | [0m0.2252   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m23       [0m | [0m-0.596   [0m | [0m0.4931   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m24       [0m | [0m-0.5703  [0m | [0m0.4794   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m25       [0m | [0m-2.311   [0m | [0m0.9165   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m26       [0m | [0m-0.4088  [0m | [0m0.3994   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m27       [0m | [0m-0.4626  [0m | [0m0.2885   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m28       [0m | [0m-2.311   [0m | [0m0.7686   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m29       [0m | [0m-0.5586  [0m | [0m0.475    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m30       [0m | [0m-0.4382  [0m | [0m0.3332   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m31       [0m | [0m-0.3717  [0m | [0m0.2665   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m32       [0m | [0m-2.311   [0m | [0m0.765    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m33       [0m | [0m-0.5353  [0m | [0m0.6852   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m34       [0m | [0m-0.6271  [0m | [0m0.05372  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m35       [0m | [0m-0.5525  [0m | [0m0.6501   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m36       [0m | [0m-0.7075  [0m | [0m0.01283  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m37       [0m | [0m-2.311   [0m | [0m0.684    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m38       [0m | [0m-0.4064  [0m | [0m0.3257   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m39       [0m | [0m-2.311   [0m | [0m0.8926   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m40       [0m | [0m-2.311   [0m | [0m0.9348   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m41       [0m | [0m-0.3132  [0m | [0m0.4312   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m42       [0m | [0m-2.311   [0m | [0m0.9899   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m43       [0m | [0m-0.5444  [0m | [0m0.1859   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m44       [0m | [0m-0.609   [0m | [0m0.03973  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m45       [0m | [0m-2.311   [0m | [0m0.9643   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m46       [0m | [0m-0.5365  [0m | [0m0.07755  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m47       [0m | [0m-0.7079  [0m | [0m0.448    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m48       [0m | [0m-2.311   [0m | [0m0.7866   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m49       [0m | [0m-2.311   [0m | [0m0.9427   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m50       [0m | [0m-0.5634  [0m | [0m0.08625  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m51       [0m | [0m-0.4699  [0m | [0m0.452    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m52       [0m | [0m-0.4638  [0m | [0m0.4724   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m53       [0m | [0m-0.5247  [0m | [0m0.5423   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m54       [0m | [0m-2.311   [0m | [0m0.9164   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m55       [0m | [0m-0.4436  [0m | [0m0.3439   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m56       [0m | [0m-0.4912  [0m | [0m0.2979   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m57       [0m | [0m-2.311   [0m | [0m0.7484   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m58       [0m | [0m-0.6671  [0m | [0m0.03465  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m59       [0m | [0m-0.3755  [0m | [0m0.4763   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m60       [0m | [0m-0.6779  [0m | [0m0.01694  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m61       [0m | [0m-1.485   [0m | [0m0.3959   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m62       [0m | [0m-2.311   [0m | [0m0.6923   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m63       [0m | [0m-0.8058  [0m | [0m0.207    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m64       [0m | [0m-2.311   [0m | [0m0.9619   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m65       [0m | [0m-2.311   [0m | [0m0.9633   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m66       [0m | [0m-0.7642  [0m | [0m0.03841  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m67       [0m | [0m-2.311   [0m | [0m0.5836   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [95m68       [0m | [95m-0.2585  [0m | [95m0.5918   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m69       [0m | [0m-0.6937  [0m | [0m0.03189  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m70       [0m | [0m-0.6917  [0m | [0m0.04094  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m71       [0m | [0m-0.5018  [0m | [0m0.3298   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m72       [0m | [0m-2.311   [0m | [0m0.8601   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m73       [0m | [0m-106.4   [0m | [0m0.6515   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m74       [0m | [0m-0.6577  [0m | [0m0.1171   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m75       [0m | [0m-0.5726  [0m | [0m0.07242  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m76       [0m | [0m-0.4466  [0m | [0m0.418    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m77       [0m | [0m-0.4335  [0m | [0m0.4362   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m78       [0m | [0m-0.5925  [0m | [0m0.17     [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m79       [0m | [0m-2.31    [0m | [0m0.8419   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m80       [0m | [0m-2.311   [0m | [0m0.7501   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m81       [0m | [0m-0.9112  [0m | [0m0.2461   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m82       [0m | [0m-0.3246  [0m | [0m0.3933   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m83       [0m | [0m-0.4385  [0m | [0m0.362    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m84       [0m | [0m-2.311   [0m | [0m0.8356   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m85       [0m | [0m-0.4161  [0m | [0m0.2052   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m86       [0m | [0m-0.4269  [0m | [0m0.4394   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m87       [0m | [0m-2.311   [0m | [0m0.8118   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m88       [0m | [0m-0.6618  [0m | [0m0.1471   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m89       [0m | [0m-0.5499  [0m | [0m0.2327   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m90       [0m | [0m-0.7318  [0m | [0m0.002315 [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m91       [0m | [0m-0.4548  [0m | [0m0.1539   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m92       [0m | [0m-0.4299  [0m | [0m0.3434   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m93       [0m | [0m-0.7021  [0m | [0m0.5887   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m94       [0m | [0m-2.311   [0m | [0m0.8656   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m95       [0m | [0m-0.5281  [0m | [0m0.2744   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m96       [0m | [0m-0.3456  [0m | [0m0.3403   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m97       [0m | [0m-0.7608  [0m | [0m0.342    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m98       [0m | [0m-2.311   [0m | [0m0.802    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m99       [0m | [0m-0.2886  [0m | [0m0.6414   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m100      [0m | [0m-0.4168  [0m | [0m0.4463   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m101      [0m | [0m-0.5572  [0m | [0m0.2302   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m102      [0m | [0m-0.3957  [0m | [0m0.3288   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m103      [0m | [0m-0.4146  [0m | [0m0.4499   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m104      [0m | [0m-2.311   [0m | [0m0.7848   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m105      [0m | [0m-2.311   [0m | [0m0.7086   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m106      [0m | [0m-0.5022  [0m | [0m0.08839  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m107      [0m | [0m-0.4551  [0m | [0m0.4117   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m108      [0m | [0m-2.311   [0m | [0m0.7429   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m109      [0m | [0m-2.311   [0m | [0m0.8911   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m110      [0m | [0m-2.311   [0m | [0m0.6663   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m111      [0m | [0m-0.4317  [0m | [0m0.7465   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m112      [0m | [0m-0.4934  [0m | [0m0.1054   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m113      [0m | [0m-0.4914  [0m | [0m0.1393   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m114      [0m | [0m-0.9026  [0m | [0m0.4351   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m115      [0m | [0m-2.311   [0m | [0m0.7557   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m116      [0m | [0m-2.311   [0m | [0m0.6004   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m117      [0m | [0m-0.6004  [0m | [0m0.01998  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m118      [0m | [0m-2.311   [0m | [0m0.5882   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m119      [0m | [0m-2.311   [0m | [0m0.9985   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m120      [0m | [0m-0.38    [0m | [0m0.1407   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m121      [0m | [0m-0.6397  [0m | [0m0.2317   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m122      [0m | [0m-0.3177  [0m | [0m0.4704   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m123      [0m | [0m-0.4689  [0m | [0m0.1877   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m124      [0m | [0m-2.311   [0m | [0m0.8748   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m125      [0m | [0m-0.3916  [0m | [0m0.5868   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m126      [0m | [0m-0.5443  [0m | [0m0.09291  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m127      [0m | [0m-0.5148  [0m | [0m0.1324   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m128      [0m | [0m-0.3959  [0m | [0m0.3594   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m129      [0m | [0m-0.6655  [0m | [0m0.07736  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m130      [0m | [0m-0.5391  [0m | [0m0.1686   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m131      [0m | [0m-0.6904  [0m | [0m0.2964   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m132      [0m | [0m-0.3973  [0m | [0m0.4553   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m133      [0m | [0m-0.4667  [0m | [0m0.2784   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m134      [0m | [0m-2.311   [0m | [0m0.7725   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m135      [0m | [0m-0.4457  [0m | [0m0.5087   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m136      [0m | [0m-2.311   [0m | [0m0.9568   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m137      [0m | [0m-0.3682  [0m | [0m0.5816   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m138      [0m | [0m-2.311   [0m | [0m0.9194   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m139      [0m | [0m-0.502   [0m | [0m0.4157   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m140      [0m | [0m-0.5227  [0m | [0m0.1161   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m141      [0m | [0m-0.3915  [0m | [0m0.4743   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m142      [0m | [0m-0.4174  [0m | [0m0.3393   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m143      [0m | [0m-2.311   [0m | [0m0.8325   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m144      [0m | [0m-0.4346  [0m | [0m0.4113   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m145      [0m | [0m-0.2987  [0m | [0m0.3523   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m146      [0m | [0m-0.4126  [0m | [0m0.558    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m147      [0m | [0m-0.7751  [0m | [0m0.1265   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m148      [0m | [0m-2.311   [0m | [0m0.5648   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m149      [0m | [0m-0.3911  [0m | [0m0.2369   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m150      [0m | [0m-2.311   [0m | [0m0.6038   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m151      [0m | [0m-0.5097  [0m | [0m0.1478   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m152      [0m | [0m-0.3606  [0m | [0m0.6398   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m153      [0m | [0m-2.311   [0m | [0m0.6695   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m154      [0m | [0m-2.311   [0m | [0m0.8151   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m155      [0m | [0m-2.311   [0m | [0m0.9668   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m156      [0m | [0m-0.5579  [0m | [0m0.07907  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m157      [0m | [0m-0.5346  [0m | [0m0.442    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m158      [0m | [0m-0.386   [0m | [0m0.506    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m159      [0m | [0m-2.311   [0m | [0m0.9012   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m160      [0m | [0m-0.4635  [0m | [0m0.1817   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m161      [0m | [0m-2.311   [0m | [0m0.8581   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m162      [0m | [0m-0.31    [0m | [0m0.4558   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m163      [0m | [0m-0.643   [0m | [0m0.5628   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m164      [0m | [0m-2.311   [0m | [0m0.9929   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m165      [0m | [0m-2.311   [0m | [0m0.6043   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m166      [0m | [0m-2.311   [0m | [0m0.9546   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m167      [0m | [0m-0.5819  [0m | [0m0.5907   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m168      [0m | [0m-2.311   [0m | [0m0.9171   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m169      [0m | [0m-0.3869  [0m | [0m0.1964   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m170      [0m | [0m-2.311   [0m | [0m0.7795   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m171      [0m | [0m-0.3106  [0m | [0m0.5834   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m172      [0m | [0m-2.311   [0m | [0m0.8674   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m173      [0m | [0m-0.5784  [0m | [0m0.3866   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m174      [0m | [0m-0.7072  [0m | [0m0.336    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m175      [0m | [0m-0.4274  [0m | [0m0.2623   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m176      [0m | [0m-2.311   [0m | [0m0.8237   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m177      [0m | [0m-2.311   [0m | [0m0.6703   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m178      [0m | [0m-2.311   [0m | [0m0.8657   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m179      [0m | [0m-0.6924  [0m | [0m0.1283   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m180      [0m | [0m-0.6002  [0m | [0m0.4356   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m181      [0m | [0m-2.311   [0m | [0m0.7644   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m182      [0m | [0m-0.4746  [0m | [0m0.5003   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m183      [0m | [0m-2.311   [0m | [0m0.8488   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m184      [0m | [0m-0.4068  [0m | [0m0.6104   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m185      [0m | [0m-2.311   [0m | [0m0.7691   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m186      [0m | [0m-0.5249  [0m | [0m0.1941   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m187      [0m | [0m-2.311   [0m | [0m0.8584   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m188      [0m | [0m-2.311   [0m | [0m0.8768   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m189      [0m | [0m-0.5641  [0m | [0m0.4288   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m190      [0m | [0m-0.6347  [0m | [0m0.13     [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m191      [0m | [0m-0.4214  [0m | [0m0.523    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m192      [0m | [0m-2.311   [0m | [0m0.8362   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m193      [0m | [0m-2.311   [0m | [0m0.9299   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m194      [0m | [0m-0.5633  [0m | [0m0.1553   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m195      [0m | [0m-0.59    [0m | [0m0.3325   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m196      [0m | [0m-2.311   [0m | [0m0.7914   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m197      [0m | [0m-2.311   [0m | [0m0.6332   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m198      [0m | [0m-2.311   [0m | [0m0.8179   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m199      [0m | [0m-0.4373  [0m | [0m0.5657   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m200      [0m | [0m-2.311   [0m | [0m0.9187   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m201      [0m | [0m-0.4077  [0m | [0m0.6492   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m202      [0m | [0m-0.5329  [0m | [0m0.6542   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m203      [0m | [0m-0.4731  [0m | [0m0.2643   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m204      [0m | [0m-0.4246  [0m | [0m0.2764   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m205      [0m | [0m-0.5603  [0m | [0m0.4442   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m206      [0m | [0m-0.6349  [0m | [0m0.235    [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m207      [0m | [0m-0.5962  [0m | [0m0.1836   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m208      [0m | [0m-0.7091  [0m | [0m0.02993  [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m209      [0m | [0m-0.2825  [0m | [0m0.5744   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m210      [0m | [0m-0.701   [0m | [0m0.3505   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m211      [0m | [0m-0.3807  [0m | [0m0.4329   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m212      [0m | [0m-0.2884  [0m | [0m0.4138   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m213      [0m | [0m-2.311   [0m | [0m0.5762   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m214      [0m | [0m-0.5014  [0m | [0m0.3541   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m215      [0m | [0m-2.311   [0m | [0m0.5598   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m216      [0m | [0m-0.3195  [0m | [0m0.4597   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m217      [0m | [0m-0.3226  [0m | [0m0.4615   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m218      [0m | [0m-0.3343  [0m | [0m0.4686   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m219      [0m | [0m-2.311   [0m | [0m0.6432   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

| [0m220      [0m | [0m-0.4956  [0m | [0m0.5562   [0m |


  0%|          | 0/2000 [00:00<?, ?it/s]

# 4. SGD

## 4.1 Hyperparameter Tuning

In [None]:
@jit
def compute_loss_and_updates_sgd(state, batch_images, batch_labels):
    def loss_fn(params):
        logits = state.apply_fn(params, batch_images)
        loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch_labels))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    return grads

In [None]:
lr, momentum = 0.1, 0.1
tx = optax.sgd(lr, momentum)
opt_state = train_state.TrainState.create(apply_fn=apply_fn, params=flat_params, tx=tx)

In [None]:
def train_model(state, X_train, y_train):
    @jit
    def apply_updates(state, updates):
        return state.apply_gradients(grads=updates)

    for i in trange(len(X_train)):
        curr_image = X_train[i].reshape(1,28,28,1)
        curr_label = y_train[i]

        updates = compute_loss_and_updates_sgd(state, curr_image, curr_label)
        state = apply_updates(state, updates)
    
    return state

In [None]:
def sgd_opt_fn(lr, momentum):
    tx = optax.sgd(lr, momentum)
    opt_state = train_state.TrainState.create(apply_fn=apply_fn, params=flat_params, tx=tx)
    result_state = train_model(opt_state, X_train, y_train)
    mean = result_state.params
    log_likelihood = -evaluate_neg_log_likelihood(mean, unflatten_fn, apply_fn, test_ds)

    return log_likelihood

In [None]:
sgd_optimizer = BayesianOptimization(
    f=sgd_opt_fn,
    pbounds={'lr': (1e-4, 1.0), 'momentum': (1e-4, 1.0)},
)

In [None]:
sgd_optimizer.maximize(
    init_points=100,
    n_iter=100,
)