In [11]:
import numpy as np
from slds_funcs import *
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Dimensions
N = 2 # latent state: [temp, humidity]
M = 2 # observed: [temp, humidity]
T = 500
K = 2 # number of regimes

# regime transition matrix
# A[i, j] = P(s_t = j | s_{t-1} = i)
A = np.array([
    [0.95, 0.05], # from regime 0: mostly stays
    [0.10, 0.90] # from regime 1: sometimes switches
])

# initial regime prob
pi = np.array([0.70, 0.30]) # arbitrarily chosen, more probable to start in steady weather

# regime specific dynamics

# Regime 0: calm weather
F_0 = np.array([
    [1.0, 0.0],
    [0.0, 0.98]
])
Q_0 = np.array([
    [0.05, 0.0],
    [0.0, 0.02]
])

# Regime 1: volatile weather
F_1 = np.array([
    [1.0, -0.10],
    [0.0, 0.90]
])
Q_1 = np.array([
    [0.15, 0.02],
    [0.02, 0.10]
])

F = np.stack([F_0, F_1]) # (K, N, N)
Q = np.stack([Q_0, Q_1]) # (K, N, N)

# observation model (shared in this case)
H_0 = np.eye(2)
H_1 = np.eye(2)
H = np.stack([H_0, H_1]) # (K, M, N)

R_0 = np.array([
    [0.5, 0.0],
    [0.0, 0.3]
])
R_1 = np.array([
    [0.8, 0.1],
    [0.1, 0.5]
])
R = np.stack([R_0, R_1]) # (K, M, M)


In [12]:
regimes = np.empty(T, dtype=int)
regimes[0] = np.random.choice(2, p=pi)

x_true = np.zeros((T, N))
y_obs = np.zeros((T, M))

# Initial true state
x_true[0] = np.array([20.0, 0.6]) # 20Â°C, 60% humidity

for t in range(1, T):
    regimes[t] = np.random.choice(2, p=A[regimes[t - 1]])
    x_true[t] = F[regimes[t]] @ x_true[t-1] + np.random.multivariate_normal(np.zeros(N), Q[regimes[t]])
    y_obs[t] = H[regimes[t]] @ x_true[t] + np.random.multivariate_normal(np.zeros(M), R[regimes[t]])


In [13]:
import numpy as np

# --- Regime 0: Calm Weather ---
# State transition (slightly perturbed from identity, but stable)
F_est_iter_0 = np.array([
    [0.90, 0.05],   # Slightly less persistent than true F_0
    [0.02, 0.95]    # Slightly more persistent humidity
])

# Observation model (perturbed identity)
H_est_iter_0 = np.eye(2) + np.random.multivariate_normal(
    np.zeros(2),
    np.array([
        [0.10, 0.05],  # Smaller perturbation than before
        [0.05, 0.10]
    ])
)

# Process noise (overestimated, but positive definite)
Q_est_iter_0 = np.array([
    [0.60, 0.05],   # Overestimated variance for temperature
    [0.05, 0.50]    # Overestimated variance for humidity
])

# Measurement noise (underestimated)
R_est_iter_0 = np.array([
    [0.30, 0.05],   # Slightly underestimated sensor noise
    [0.05, 0.20]
])

# --- Regime 1: Volatile Weather ---
# State transition (more volatile than true F_1)
F_est_iter_1 = np.array([
    [0.85, -0.15],  # More volatile temperature, stronger coupling
    [0.10, 0.80]    # Less persistent humidity
])

# Observation model (perturbed identity)
H_est_iter_1 = np.eye(2) + np.random.multivariate_normal(
    np.zeros(2),
    np.array([
        [0.15, 0.10],  # Slightly larger perturbation
        [0.10, 0.15]
    ])
)

# Process noise (overestimated)
Q_est_iter_1 = np.array([
    [0.30, 0.10],   # Overestimated variance for temperature
    [0.10, 0.25]    # Overestimated variance for humidity
])

# Measurement noise (overestimated)
R_est_iter_1 = np.array([
    [0.30, 0.15],   # Overestimated sensor noise
    [0.15, 0.20]
])

# Stack for SLDS
F_est_iter = np.stack([F_est_iter_0, F_est_iter_1])  # (K, N, N)
Q_est_iter = np.stack([Q_est_iter_0, Q_est_iter_1])  # (K, N, N)
H_est_iter = np.stack([H_est_iter_0, H_est_iter_1])  # (K, M, N)
R_est_iter = np.stack([R_est_iter_0, R_est_iter_1])  # (K, M, M)

In [22]:
x_iter = np.array([18.0, 0.5]) # initial guess
P_iter = np.eye(N) * 5.0 # very uncertain initially

logliks = np.empty((K, T))

A_est = np.empty((K, K))
pi_est = np.empty(K)

Exx_total = np.empty((K, T, N, N))
Ex_total = np.empty((K, T, N))
Exx_tm1_total = np.empty((K, T - 1, N, N))

n_iters = 10

# keep track of stats
Q_error = np.empty((n_iters, K))
H_error = np.empty((n_iters, K))
F_error = np.empty((n_iters, K))
R_error = np.empty((n_iters, K))
A_error = np.empty(n_iters)

logliks_process = np.empty((n_iters, K, T))

for iter in range(n_iters):

    for k in range(K):

        x_filt, P_filt, x_pred, P_pred, loglik = kalman_filter_regime(
            y_obs, F_est_iter[k], Q_est_iter[k], H_est_iter[k], R_est_iter[k], x_iter, P_iter
        )

        """
        x_filt, x_pred: shape (T, N)
        P_filt, P_pred: shape (T, N, N)
        loglik: shape (T)
        """

        logliks[k] = loglik

        logliks_process[iter, k] = loglik

        eigvals_F = np.linalg.eigvals(F_est_iter[k])
        if np.any(np.abs(eigvals_F) > 1):
            F_est_iter[k] = F_est_iter[k] * 0.9  # Scale down F to stabilize

        x_smoothed, P_smoothed, C = rts_smoother(
            x_filt, P_filt, x_pred, P_pred, F_est_iter[k]
        )

        """
        x_smoothed: shape (T, N)
        P_smoothed: shape (T, N, N)
        """

        x_iter, P_iter = x_smoothed[0], P_smoothed[0]

        Ex, Exx, Exx_tm1 = compute_expectations(x_smoothed, P_smoothed, C)

        """
        Ex: shape (T, N)
        Exx: shape (T, N, N)
        Exx_tm1: shape (T - 1, N, N)
        """

        Ex_total[k] = Ex
        Exx_total[k] = Exx
        Exx_tm1_total[k] = Exx_tm1

    gamma, xi = forward_backward_regimes(logliks, A, pi)

    """
    gamma: shape (T, K)
    xi: shape (T, K, K)
    """

    for k in range(K):
        F_est_iter[k], Q_est_iter[k], H_est_iter[k], R_est_iter[k] = estimate_slds_params(
            Exx_tm1_total[k], Exx_total[k], Ex_total[k], y_obs, gamma, k
        )

    # update transition matrix A
    for i in range(K):
        den = np.sum(gamma[:-1, i])
        for j in range(K):
            num = np.sum(xi[:-1, i, j])
            A_est[i, j] = num / den
            
        A_est[i, :] /= np.sum(A_est[i, :])

    pi_est = gamma[0] / np.sum(gamma[0])

    A_error[iter] = np.linalg.norm(A - A_est)
    for k in range(K):
        Q_error[:, k] = np.linalg.norm(Q[k] - Q_est_iter)
        R_error[:, k] = np.linalg.norm(R[k] - R_est_iter)
        H_error[:, k] = np.linalg.norm(H[k] - H_est_iter)
        F_error[:, k] = np.linalg.norm(F[k] - F_est_iter)

# take mean of errors
# Q_error_mean = np.mean(Q_error, axis=-1)
# R_error_mean = np.mean(R_error, axis=-1)
# H_error_mean = np.mean(H_error, axis=-1)
# F_error_mean = np.mean(F_error, axis=-1)

mean_log_lik = np.mean(logliks_process, axis=-1) # (n_iters, K)

in RTS smoother, gotten P_pred:
[[[4.21949356 0.30984909]
  [0.30984909 6.25597852]]

 [[0.68570401 0.37730906]
  [0.37730906 1.87854299]]

 [[0.50358387 0.35961954]
  [0.35961954 1.75609436]]

 ...

 [[0.44437665 0.35796947]
  [0.35796947 1.7453675 ]]

 [[0.44437665 0.35796947]
  [0.35796947 1.7453675 ]]

 [[0.44437665 0.35796947]
  [0.35796947 1.7453675 ]]]
in RTS smoother, gotten P_pred:
[[[29.78859753 -4.32758444]
  [-4.32758444  0.7683502 ]]

 [[18.63840006 -1.56083149]
  [-1.56083149  0.14106172]]

 [[18.55636129 -1.58714357]
  [-1.58714357  0.13897538]]

 ...

 [[18.33647021 -1.59043871]
  [-1.59043871  0.1389316 ]]

 [[18.33647021 -1.59043871]
  [-1.59043871  0.1389316 ]]

 [[18.33647021 -1.59043871]
  [-1.59043871  0.1389316 ]]]
in RTS smoother, gotten P_pred:
[[[ 6.28756495 -0.26677384]
  [-0.26677384  1.31264012]]

 [[ 0.71240235  0.3227331 ]
  [ 0.3227331   1.71774259]]

 [[ 0.51479289  0.35506637]
  [ 0.35506637  1.76376284]]

 ...

 [[ 0.44740971  0.36375122]
  [ 0.363751

In [15]:
error_fig = make_subplots(rows=1, cols=2)

error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=Q_error[:, 0], name="Q errors regime 1"), row=1, col=1)
error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=F_error[:, 0], name="F errors regime 1"), row=1, col=1)
error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=H_error[:, 0], name="H errors regime 1"), row=1, col=1)
error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=R_error[:, 0], name="R errors regime 1"), row=1, col=1)

error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=Q_error[:, 1], name="Q errors regime 2"), row=1, col=2)
error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=F_error[:, 1], name="F errors regime 2"), row=1, col=2)
error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=H_error[:, 1], name="H errors regime 2"), row=1, col=2)
error_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=R_error[:, 1], name="R errors regime 2"), row=1, col=2)

error_fig.show()

In [16]:
ll_fig = make_subplots(rows=1, cols=2)

ll_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=mean_log_lik[:, 0], name="LL regime 1"), row=1, col=1)
ll_fig.add_trace(go.Scatter(x=list(range(n_iters)), y=mean_log_lik[:, 1], name="LL regime 2"), row=1, col=2)

ll_fig.show()

In [17]:
for k in range(K):
    print('regime', k)
    print('Q est:')
    print(Q_est_iter[k])
    print('Q real:'); print(Q[k])
    print()
    print('H est:')
    print(H_est_iter[k])
    print('H real:'); print(H[k])
    print()
    print('R est:')
    print(R_est_iter[k])
    print('R real:'); print(R[k])
    print()
    print('F est:')
    print(F_est_iter[k])
    print('F real:'); print(F[k])

regime 0
Q est:
[[0.28485264 0.16603264]
 [0.16603264 0.8124887 ]]
Q real:
[[0.05 0.  ]
 [0.   0.02]]

H est:
[[-0.09415082 -0.78985503]
 [-0.93371919  0.18058356]]
H real:
[[1. 0.]
 [0. 1.]]

R est:
[[0.4353946  0.01507029]
 [0.01507029 0.25882411]]
R real:
[[0.5 0. ]
 [0.  0.3]]

F est:
[[ 0.79848995  0.03890833]
 [-0.02879631  1.00521863]]
F real:
[[1.   0.  ]
 [0.   0.98]]
regime 1
Q est:
[[ 1.39035072 -0.16331927]
 [-0.16331927  0.28413259]]
Q real:
[[0.15 0.02]
 [0.02 0.1 ]]

H est:
[[ 0.85270882 -3.47726444]
 [ 0.0442993   0.22337233]]
H real:
[[1. 0.]
 [0. 1.]]

R est:
[[6.11201878 0.69049289]
 [0.69049289 0.43703947]]
R real:
[[0.8 0.1]
 [0.1 0.5]]

F est:
[[ 0.81751333 -1.64951065]
 [-0.09064839  0.21383993]]
F real:
[[ 1.  -0.1]
 [ 0.   0.9]]


In [18]:
param_fig = make_subplots(rows=1, cols=2)

param_fig.add_trace(go.Scatter(x=list(range(T)), y=x_smoothed[:, 0], name="Estimated temp"), row=1, col=1)
param_fig.add_trace(go.Scatter(x=list(range(T)), y=x_true[:, 0], name="True temp"), row=1, col=1)

param_fig.add_trace(go.Scatter(x=list(range(T)), y=x_smoothed[:, 1], name="Estimated humidity"), row=1, col=2)
param_fig.add_trace(go.Scatter(x=list(range(T)), y=x_true[:, 1], name="True humidity"), row=1, col=2)

param_fig.show()

In [23]:
print('true regime transition matrix:')
print(A)
print('estimated:')
print(A_est)

true regime transition matrix:
[[0.95 0.05]
 [0.1  0.9 ]]
estimated:
[[0.62422645 0.37577355]
 [0.03062    0.96938   ]]
