In [1]:
import jax

jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp
import numpy as np
import scipy.stats
from matplotlib.figure import Figure
from tqdm.notebook import trange

%matplotlib inline

In [2]:
import sys
import os

# Put the local ./lib directory first on the module search path.
# Presumably this is where probit_network and random_matrix live — confirm.
lib_path = os.path.join(os.path.curdir, 'lib')
sys.path.insert(0, lib_path)

In [None]:
import importlib

import probit_network
import random_matrix

# Reload both modules so source edits made since the first import are picked up
# without restarting the kernel. The from-imports come after the reload so the
# names bind to the reloaded objects.
importlib.reload(random_matrix)
importlib.reload(probit_network)
from probit_network import ProbitLinear, ProbitLinearNetwork
from random_matrix import RandomGaussian, RandomOrthogonalProjection, ZeroMatrix

Roadmap
- orthogonal initialization
- inputs
- layer noise
- train

In [None]:
# Base PRNG key for the single-layer test below (layer construction and Monte Carlo check).
key = jax.random.PRNGKey(123)

# single layer test

In [None]:
# Single layer built with positional sizes 3 and 1 (in/out, matching the
# in_size/out_size/key keyword order used elsewhere in this notebook).
# NOTE(review): the numeric argument of RandomOrthogonalProjection is presumably
# a scale for the generated matrix — confirm in random_matrix.
f = ProbitLinear(
    3,
    1,
    key,
    A=RandomOrthogonalProjection(4.0),
    b=RandomOrthogonalProjection(0.0),
    C=RandomOrthogonalProjection(1.0),
    d=RandomOrthogonalProjection(1.0),
)
# Wrap the single layer in a network container.
network = ProbitLinearNetwork(f)

In [None]:
# Test input distribution: all-ones mean and identity covariance, sized to the layer input.
x = jnp.ones(f.in_size)
Σ = jnp.eye(x.shape[0])

In [None]:
# Monte Carlo reference for the layer's output mean and covariance.
# NOTE(review): 1_000_000 is presumably the number of samples — confirm in probit_network.
mean, cov = f._mc_mean_cov(x, Σ, key, 1_000_000)
mean

In [None]:
# Compare the three output-mean estimates: linearised ("ekf"), analytic, and Monte Carlo.
print("ekf mean", f._propagate_mean_lin(x, Σ))
print("analytic mean", f._propagate_mean(x, Σ))
print("monte carlo mean", mean)

In [None]:
# Eigenvalues of the analytic output covariance (a quick look at whether it is positive semi-definite).
np.linalg.eigvalsh(f._propagate_cov(x, Σ))

In [None]:
# Compare the three output-covariance estimates: linearised ("ekf"), analytic, and Monte Carlo.
print("ekf cov", f._propagate_cov_lin(x, Σ))
print("analytic cov", f._propagate_cov(x, Σ))
print("monte carlo covariance", cov)

# Small-variance eigenvalue rectification

In [None]:
# Only key1 is consumed in this cell; key2..key4 and hidden_size are assigned
# but not used here.
key, key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(12), 5)
input_size = 2
output_size = 2
hidden_size = 5
# Single-layer network; C and d are not passed, so ProbitLinear's own defaults apply.
network = ProbitLinearNetwork(
    ProbitLinear(
        input_size,
        output_size,
        key1,
        A=RandomOrthogonalProjection(),
        b=RandomOrthogonalProjection(),
    ),
)
# Zero-mean, identity-covariance input distribution.
μ = jnp.zeros(input_size)
Σ = jnp.eye(input_size)

In [None]:
# Propagate a small-variance input (covariance 0.01 * I) and inspect the
# eigenvalues of the resulting output covariance.
mean, cov = network.propagate_mean_cov(μ, Σ * 1e-2)
mean, jnp.linalg.eigvalsh(cov)

# MLP UQ test

In [None]:
key, key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(12), 5)
input_size = 2
output_size = 1
hidden_size = 70

# Four layers (input -> hidden -> hidden -> hidden -> output) that differ only
# in their sizes and PRNG key; every layer gets its own initialiser objects.
layers = [
    ProbitLinear(
        n_in,
        n_out,
        layer_key,
        A=RandomOrthogonalProjection(4.0),
        b=RandomOrthogonalProjection(1.0),
        C=RandomOrthogonalProjection(1.0),
        d=ZeroMatrix(),
    )
    for n_in, n_out, layer_key in (
        (input_size, hidden_size, key1),
        (hidden_size, hidden_size, key2),
        (hidden_size, hidden_size, key3),
        (hidden_size, output_size, key4),
    )
]

In [None]:
# Zero-mean, identity-covariance input distribution for the deep network.
μ = jnp.zeros(input_size)
Σ = jnp.eye(input_size)

In [None]:
# Chain the layers into one network; this rebinds the global `network` that plot_mc reads.
network = ProbitLinearNetwork(*layers)

In [None]:
# Output mean and covariance for the unit-covariance input.
network.propagate_mean_cov(μ, Σ)

In [None]:
# Linearised propagation with a small input covariance (1e-3 * I), for comparison with the next cell.
network.propagate_mean_cov_lin(μ, Σ * 1e-3)

In [None]:
# Same small input covariance (1e-3 * I) through the non-linearised propagation.
network.propagate_mean_cov(μ, Σ * 1e-3)

In [None]:
def plot_mc(ax, μ, Σ):
    """Plot sampled output densities against the Gaussian approximations.

    Draws 100,000 inputs from N(μ, Σ), pushes them through the global
    ``network`` and plots, on ``ax``: a KDE of the outputs, a Gaussian fitted
    to the sample mean/std, and the Gaussians given by the unscented,
    linearised and ``propagate_mean_cov`` approximations.

    Reads the module-level ``network`` and ``key``. The reshape(-1) calls
    assume a scalar network output.

    Parameters:
        ax: Matplotlib axes to draw on.
        μ: Mean of the input distribution.
        Σ: Covariance of the input distribution.
    """
    # (legend label, (mean, covariance), extra line kwargs), in plotting order.
    approximations = [
        ("unscented approximation", network.propagate_mean_cov_unscented(μ, Σ), {}),
        ("linear approximation", network.propagate_mean_cov_lin(μ, Σ), {"linestyle": "--"}),
        ("my approximation", network.propagate_mean_cov(μ, Σ), {"linestyle": "--"}),
    ]

    samples_in = jax.random.multivariate_normal(
        key=key, mean=μ, cov=Σ, shape=100_000
    )
    samples_out = jax.vmap(network)(samples_in)
    grid = np.linspace(np.min(samples_out), np.max(samples_out), 2000)

    # Empirical density of the sampled outputs.
    kde = scipy.stats.gaussian_kde(samples_out.reshape(-1))
    ax.plot(grid, kde(grid), label="empirical KDE")

    # Gaussian with the sample mean and standard deviation.
    fitted = scipy.stats.norm.pdf(
        grid, loc=samples_out.mean(), scale=samples_out.std()
    )
    ax.plot(grid, fitted, label="pseudo-true Gaussian fit")

    for label, (mean_out, cov_out), line_kwargs in approximations:
        density = scipy.stats.norm.pdf(
            grid, loc=mean_out.reshape(-1), scale=cov_out.reshape(-1) ** 0.5
        )
        ax.plot(grid, density, label=label, **line_kwargs)
    ax.legend()

In [None]:
# One panel per input covariance scale, comparing the sampled output density
# with the Gaussian approximations drawn by plot_mc.
fig = Figure(figsize=(8, 8), dpi=600, constrained_layout=1)
for panel, (scale, scale_label) in enumerate(
    ((1e-2, "0.01"), (1, "1"), (1e2, "100")), start=1
):
    ax = fig.add_subplot(3, 1, panel)
    ax.set_title(f"Covariance scale {scale_label}")
    plot_mc(ax, μ, Σ * scale)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/deep-mlp.pdf")
fig

# Kalman filter

In [None]:
import equinox

In [None]:
@equinox.filter_jit
def schur_complement(A, B, C, x, y, method=1):
    """Return a numerically stable(ish) evaluation of

        x + B C^(-1) y,
        A - B C^(-1) B^T,

    using a Cholesky factorisation of C instead of an explicit inverse.

    Parameters:
        A: Matrix the Schur complement is subtracted from.
        B: Cross term; B.T must be solvable against the Cholesky factor of C.
        C: Matrix to factorise; assumed symmetric positive definite.
        x: Vector the correction is added to.
        y: Vector multiplied by B C^(-1).
        method: Algorithm selector; only ``1`` (Cholesky) is implemented.

    Returns:
        Tuple ``(x + B C^(-1) y, A - B C^(-1) B^T)``.

    Raises:
        ValueError: If ``method`` is not a supported value.
    """
    if method != 1:
        # Previously an unknown method fell through and returned None.
        raise ValueError(f"unsupported method: {method!r}")

    # cholesky defaults to the upper factor: C = U^T U.
    U = jax.scipy.linalg.cholesky(C)
    # B_tilde = B U^(-1): solve U^T X = B^T, then transpose.
    B_tilde = jax.scipy.linalg.solve_triangular(U, B.T, trans=1, lower=False).T
    # C^(-1) = U^(-1) U^(-T), so B C^(-1) y = B_tilde (U^(-T) y); this solve
    # therefore needs trans=1 as well (solving U z = y would give U^(-1) y).
    y_tilde = jax.scipy.linalg.solve_triangular(U, y, trans=1, lower=False)
    return (
        x + B_tilde @ y_tilde,
        A - B_tilde.dot(B_tilde.T),
    )

## random MLP ensemble

In [None]:
import numpy as np


def ccf(*eigenvalues):
    """
    Generate the controllable canonical form (companion matrix) A
    for a system with the given eigenvalues.

    The result has ones on the superdiagonal and the negated coefficients of
    the characteristic polynomial in the last row.

    Parameters:
        *eigenvalues: Variable length argument list of eigenvalues (lambda1, lambda2, ..., lambdan)

    Returns:
        A: The controllable canonical form matrix (n x n), as a jax array

    Raises:
        ValueError: If no eigenvalues are given.
    """
    if len(eigenvalues) == 0:
        # np.poly returns a bare scalar for empty input, which used to fail
        # below with an opaque indexing error.
        raise ValueError("ccf() requires at least one eigenvalue")

    # Collect the positional arguments into one array
    eigenvalues = np.array(eigenvalues)

    # Compute the coefficients of the characteristic polynomial
    # The polynomial has the form: s^n + a_{n-1}s^{n-1} + ... + a_1 s + a_0
    poly_coeffs = np.poly(eigenvalues)  # Gives [1, a_{n-1}, ..., a_0]

    # Remove the leading 1 (coefficient of s^n)
    a_coeffs = -poly_coeffs[1:]  # Negative signs for canonical form

    n = len(a_coeffs)

    # Ones on the superdiagonal, zeros elsewhere (the last row is still all zero)
    A = np.eye(n, k=1)

    # Last row is [-a_0, -a_1, ..., -a_{n-1}]
    A[-1, :] = a_coeffs[::-1]

    return jnp.array(A)

In [None]:
# key1 is split off but not used in this cell.
key, key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(21), 5)

# State and observation dimensions.
n_x = 4
n_y = 2

# n_hidden_dynamics is assigned but not used in this cell.
n_hidden_dynamics = 5
n_hidden_output = 500

# Dynamics model: one layer whose A is the companion matrix with eigenvalues
# 0.99, 0.99, 0.5 and -0.3.
# NOTE(review): with C and d set to ZeroMatrix this presumably reduces to a
# purely linear/affine map — confirm in probit_network.
F = ProbitLinearNetwork(
    ProbitLinear(
        in_size=n_x,
        out_size=n_x,
        key=key2,
        A=ccf(0.99, 0.99, 0.5, -0.3),
        b=RandomGaussian(0.0),
        C=ZeroMatrix(),
        d=ZeroMatrix(),
    ),
)
# Observation model: n_x -> n_hidden_output -> n_y, built with create_probit.
H = ProbitLinearNetwork(
    ProbitLinear.create_probit(
        in_size=n_x,
        out_size=n_hidden_output,
        key=key3,
        A=RandomOrthogonalProjection(1e1),
        b=RandomOrthogonalProjection(1e1),
    ),
    ProbitLinear.create_probit(
        in_size=n_hidden_output,
        out_size=n_y,
        key=key4,
        A=RandomOrthogonalProjection(1e1),
        b=ZeroMatrix(),
    ),
)
# NOTE(review): presumably maps x to the stacked vector (x, H(x)); the filter
# loop writes its output into an (n_x + n_y)-sized slot — confirm.
H_aug = H.augment_with_identity()

In [None]:
# Smoke test: evaluate F at the origin. The value bound to x here is not used
# further; the simulation cell later rebinds x to a fresh array.
x = F(np.zeros(n_x))

In [None]:
# Number of simulated time steps.
sim_horizon = 200
# Process-noise (Q) and measurement-noise (R) covariances.
Q = jnp.eye(n_x) * 1e-2
R = jnp.eye(n_y) * 1e-2

In [None]:
# Independent PRNG keys for the process noise (η) and the measurement noise (ϵ).
η_key, ϵ_key = jax.random.split(key, 2)

In [None]:
# Process noise: sim_horizon zero-mean draws with covariance Q.
η = jax.random.multivariate_normal(
    mean=jnp.zeros(n_x), cov=Q, key=η_key, shape=sim_horizon
)

In [None]:
# Measurement noise: sim_horizon zero-mean draws with covariance R.
ϵ = jax.random.multivariate_normal(
    mean=jnp.zeros(n_y), cov=R, key=ϵ_key, shape=sim_horizon
)

In [None]:
# Simulate the state-space model. Row 0 of every array is left at zero: the
# initial state is the origin and no observation is generated for step 0.
x = np.zeros((sim_horizon, n_x))
y = np.zeros((sim_horizon, n_y))
y_noiseless = np.zeros((sim_horizon, n_y))
for i in range(1, sim_horizon):
    # State transition through F plus process noise.
    x[i, :] = F(x[i - 1, :]) + η[i - 1]
    # Observation through H, stored without and with measurement noise.
    y_noiseless[i, :] = H(x[i, :])
    y[i, :] = y_noiseless[i, :] + ϵ[i]

In [None]:
# First state component (top) and first output component (bottom), noisy and noiseless.
fig = Figure()
ax = fig.add_subplot(211)
ax.plot(x[:, 0], label="x_0")
ax.legend()
ax = fig.add_subplot(212)
# Column 0 is labelled y_0 to match the zero-based x_0 label above.
ax.plot(y[:, 0], label="y_0")
ax.plot(y_noiseless[:, 0], label="y_0 (noiseless)")
ax.legend()
fig

In [None]:
# Index helpers into the joint (state, output) vector: the first n_x entries
# are states, the rest are outputs.
STATES = slice(None, n_x)
OUTPUTS = slice(n_x, None)
JOINT = slice(None, None)

# Per-step storage: joint mean, joint predicted covariance, and the
# state-only covariance written by the correction step.
joint_prediction = np.zeros((sim_horizon, n_x + n_y))
P_pred = np.zeros((sim_horizon, n_x + n_y, n_x + n_y))
P_post = np.zeros((sim_horizon, n_x, n_x))

# Initial condition: zero state mean with covariance Q.
joint_prediction[0, STATES] = 0
P_pred[0, STATES, STATES] = Q * 1
P_post[0, STATES, STATES] = Q * 1

In [None]:
def psd_error(A):
    """Return the norm of the difference between ``A`` and ``rectify_eigenvalues(A)``.

    NOTE(review): ``rectify_eigenvalues`` is not defined or imported in the
    visible part of this notebook — confirm where it comes from.
    """
    rectified = rectify_eigenvalues(A)
    return jnp.linalg.norm(A - rectified)

In [None]:
for i in trange(1, sim_horizon):
    # predict x: push the previous state estimate through the dynamics F
    joint_prediction[i, STATES], P_pred[i, STATES, STATES] = F.propagate_mean_cov(
        joint_prediction[i - 1, STATES], P_post[i - 1, STATES, STATES]
    )
    P_pred[i, STATES, STATES] += Q
    # predict y: fill the whole joint (state, output) mean and covariance
    joint_prediction[i, JOINT], P_pred[i, JOINT, JOINT] = H_aug.propagate_mean_cov(
        joint_prediction[i, STATES], P_pred[i, STATES, STATES]
    )
    P_pred[i, OUTPUTS, OUTPUTS] += R
    # correct x: condition the state on the observed output; this overwrites
    # the predicted state mean in joint_prediction with the corrected one
    joint_prediction[i, STATES], P_post[i, STATES, STATES] = schur_complement(
        P_pred[i, STATES, STATES],
        P_pred[i, STATES, OUTPUTS],
        P_pred[i, OUTPUTS, OUTPUTS],
        joint_prediction[i, STATES],
        y[i, :] - joint_prediction[i, OUTPUTS],
    )
    if np.any(np.isnan(joint_prediction[i])):
        # Fail with a descriptive numerical error rather than a fake
        # KeyboardInterrupt, so the failing step shows up in the traceback.
        raise FloatingPointError(f"NaN in the joint prediction at time step {i}")

In [None]:
# The state block of joint_prediction holds the corrected means (the filter
# loop overwrites the prediction during the correction step), so its
# uncertainty band must come from the corrected covariance P_post, not from
# the prior P_pred.
x_pred = joint_prediction[:, STATES]
x_pred_std = jax.vmap(jnp.diag)(P_post) ** 0.5
# The output block is never corrected, so prior mean and prior covariance match.
y_pred = joint_prediction[:, OUTPUTS]
y_pred_std = jax.vmap(jnp.diag)(P_pred[:, OUTPUTS, OUTPUTS]) ** 0.5

In [None]:
# RMSE of the state estimate over all steps and components, next to the
# standard deviation of the simulated states as a baseline.
print(((x - x_pred) ** 2).mean() ** 0.5, np.std(x))

In [None]:
# Skip step 0 (the initial condition) when plotting.
time_slice = slice(1, None)

In [None]:
# Simulated state component 1 with a +/- one standard deviation band around its estimate.
fig = Figure(figsize=(8, 4), dpi=100, constrained_layout=1)
ax = fig.gca()
ax.plot(x[time_slice, 1])
ax.fill_between(
    np.arange(sim_horizon)[time_slice],
    (x_pred[time_slice, 1] - x_pred_std[time_slice, 1]),
    (x_pred[time_slice, 1] + x_pred_std[time_slice, 1]),
    color="C1",
    alpha=0.5,
)
fig

In [None]:
# Noiseless output component 0 with a +/- one standard deviation band around its prediction.
# Unlike the state plot, this one includes step 0.
fig = Figure(figsize=(8, 4), dpi=100, constrained_layout=1)
ax = fig.gca()
ax.plot(y_noiseless[:, 0])
# ax.plot(y_pred[:, 1])
ax.fill_between(
    np.arange(sim_horizon),
    y_pred[:, 0] - y_pred_std[:, 0],
    y_pred[:, 0] + y_pred_std[:, 0],
    color="C1",
    alpha=0.5,
)
fig

# vibe coded KF

In [None]:
# NOTE(review): KalmanFilter is not defined or imported in the visible part of
# this notebook, so this cell fails on a fresh kernel — confirm where it
# comes from.
# This run starts from x0 = 5, whereas the hand-written filter starts from 0.
kf = KalmanFilter(
    n_x=n_x, n_y=n_y, F=F, H_aug=H_aug, Q=Q, R=R, x0=5 * np.ones(n_x), P0=Q.copy()
)

# Per-step storage, separate from the arrays used by the hand-written filter.
joint_predictions = np.zeros((sim_horizon, n_x + n_y))
P_preds = np.zeros((sim_horizon, n_x + n_y, n_x + n_y))
P_posts = np.zeros((sim_horizon, n_x, n_x))

joint_predictions[0, :n_x] = 5  # initial x0
P_preds[0, :n_x, :n_x] = Q
P_posts[0] = Q

for i in trange(1, sim_horizon):
    kf.predict()
    kf.update(y[i])

    joint_predictions[i], P_preds[i] = kf.get_state()
    # NOTE(review): P_posts is filled from the state block of whatever
    # get_state() returns; whether that is a prior or a posterior covariance
    # depends on KalmanFilter — confirm.
    P_posts[i] = P_preds[i, :n_x, :n_x]