We first initialize the system and set parameters.

In [135]:
import numpy as np
from lds_funcs import *
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Problem size
N = 2    # latent state components: [temp, humidity]
M = 2    # observed components:     [temp, humidity]
T = 500  # number of timesteps

# Dynamics: temperature carries over with gain 1.0, humidity with gain 0.95;
# no coupling between the two components.
F = np.diag([1.0, 0.95])

# Emission: both state components are observed directly.
H = np.eye(2)

# Process-noise covariance (weather randomness), uncorrelated components.
Q = np.diag([0.1, 0.05])

# Measurement-noise covariance (sensor noise), uncorrelated components.
R = np.diag([0.5, 0.3])


Construct the ground-truth data by letting the system evolve and adding noise to both the state and the observations.

In [136]:
# Seeded generator so the simulated trajectory (and everything derived from
# it further down) is reproducible under Restart Kernel -> Run All.
SIM_SEED = 0
sim_rng = np.random.default_rng(SIM_SEED)

x_true = np.zeros((T, N))  # true states, one row per timestep
y_obs = np.zeros((T, M))   # noisy observations, one row per timestep

# Initial true state
x_true[0] = np.array([20.0, 0.6])  # 20°C, 60% humidity

# State evolution: x_t = F x_{t-1} + w_t, with w_t ~ N(0, Q)
for t in range(1, T):
    x_true[t] = F @ x_true[t-1] + sim_rng.multivariate_normal(np.zeros(N), Q)

# Observation: y_t = H x_t + v_t, with v_t ~ N(0, R)
for t in range(T):
    y_obs[t] = H @ x_true[t] + sim_rng.multivariate_normal(np.zeros(M), R)


Construct the initial estimates of the parameters.

In [137]:
# Starting values for the parameters that the iterative procedure refines.

# Seeded generator so the random starting point for H is the same on every
# run of this cell.
INIT_SEED = 1
init_rng = np.random.default_rng(INIT_SEED)

# Initial guess for the state-transition matrix
F_est_iter = np.array([
    [0.75, 0.10],
    [0.05, 0.75]
])

# Initial guess for the observation matrix: identity plus a random offset.
# The draw is a single length-2 vector, so broadcasting adds the SAME offset
# to both rows (column j of the identity is shifted by h_offset[j]).
h_offset = init_rng.multivariate_normal(
    np.zeros(2),
    np.array([[0.25, 0.15], [0.15, 0.25]]),
)
H_est_iter = np.eye(2) + h_offset

# Initial guess for the process-noise covariance
Q_est_iter = np.array([
    [0.45, 0.25],
    [0.25, 0.45]
])

# Initial guess for the measurement-noise covariance
# NOTE(review): this matrix is singular (determinant 0) -- confirm that a
# rank-deficient starting covariance is intended.
R_est_iter = np.array([
    [0.10, 0.1],
    [0.1, 0.10]
])


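Run the iterative algorithm: each iteration runs the Kalman filter and RTS smoother with the current parameter estimates, then re-estimates the parameters from the smoothed statistics.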

In [138]:
n_iters = 10

# Diagonal jitter added to the re-estimated covariances after symmetrizing.
COV_JITTER = 1e-3

# keep track of stats: matrix norm (np.linalg.norm default, i.e. Frobenius
# for 2-D input) of the difference between true and estimated parameter
Q_error = np.empty(n_iters)
H_error = np.empty(n_iters)
F_error = np.empty(n_iters)
R_error = np.empty(n_iters)

P_iter = np.eye(N) * 5.0 # very uncertain initially
x_iter = np.array([18.0, 0.5]) # initial guess for state

# NOTE(review): the errors below compare the estimates entry-wise against the
# simulation parameters. If the estimator does not pin down the coordinates
# of the latent state, the estimates may differ from the truth by an
# invertible change of state basis -- confirm before interpreting them.

# `it` rather than `iter`, so the builtin iter() is not shadowed.
for it in range(n_iters):

    # Filtering pass with the current parameter estimates.
    # Shapes as noted by the author:
    #   x_filt, x_pred: (T, N)
    #   P_filt, P_pred: (T, N, N)
    x_filt, P_filt, x_pred, P_pred = kalman_filter(
        y=y_obs,
        F=F_est_iter,
        Q=Q_est_iter,
        H=H_est_iter,
        R=R_est_iter,
        x0=x_iter,
        P0=P_iter
    )

    # Smoothing pass.
    # Shapes as noted by the author:
    #   x_smoothed: (T, N)
    #   P_smoothed: (T, N, N)
    x_smoothed, P_smoothed, C = rts_smoother(
        x_filt, P_filt, x_pred, P_pred, F_est_iter
    )

    # Statistics of the smoothed states.
    # Shapes as noted by the author:
    #   Ex: (T, N)
    #   Exx: (T, N, N)
    #   Exx_tm1: (T - 1, N, N)
    Ex, Exx, Exx_tm1 = compute_expectations(x_smoothed, P_smoothed, C)

    # Re-estimate all four parameters from those statistics and the data.
    F_est_iter, Q_est_iter, H_est_iter, R_est_iter = estimate_lds_params(Exx_tm1, Exx, Ex, y_obs)

    # Force exact symmetry and add a small diagonal term to both covariances.
    # Q is the state-noise covariance (N x N).
    Q_est_iter = 0.5 * (Q_est_iter + Q_est_iter.T)
    Q_est_iter += COV_JITTER * np.eye(N)

    # R is the observation-noise covariance, so its identity is M x M
    # (the original used N, which only worked because N == M).
    R_est_iter = 0.5 * (R_est_iter + R_est_iter.T)
    R_est_iter += COV_JITTER * np.eye(M)

    Q_error[it] = np.linalg.norm(Q - Q_est_iter)
    H_error[it] = np.linalg.norm(H - H_est_iter)
    R_error[it] = np.linalg.norm(R - R_est_iter)
    F_error[it] = np.linalg.norm(F - F_est_iter)

    # The smoothed estimate at t = 0 becomes the next iteration's prior.
    x_iter = x_smoothed[0] # shape (N)
    P_iter = P_smoothed[0] # shape (N, N)

Plot estimated data vs. real data

In [139]:
# Shared x-axis values for the per-iteration and per-timestep panels.
iteration_axis = list(range(n_iters))
time_axis = list(range(T))

fig = make_subplots(rows=1, cols=3)

# Panel 1: parameter errors per iteration.
# Traces are added in the order Q, F, R, H so the legend order is unchanged.
for error_series, trace_name in (
    (Q_error, "Q errors"),
    (F_error, "F errors"),
    (R_error, "R errors"),
    (H_error, "H errors"),
):
    fig.add_trace(
        go.Scatter(x=iteration_axis, y=error_series, name=trace_name),
        row=1, col=1,
    )

# Panels 2 and 3: true state vs. smoothed estimate from the last iteration,
# one panel per state component (column 0 = temp, column 1 = humidity).
for state_col, panel_col, true_name, est_name in (
    (0, 2, "True temp", "Est. temp"),
    (1, 3, "True humidity", "Est. humidity"),
):
    fig.add_trace(
        go.Scatter(x=time_axis, y=x_true[:, state_col], name=true_name),
        row=1, col=panel_col,
    )
    fig.add_trace(
        go.Scatter(x=time_axis, y=x_smoothed[:, state_col], name=est_name),
        row=1, col=panel_col,
    )

# Axis titles, keyed by panel column: (x title, y title).
axis_titles = {
    1: ("Iterations", "Error rate of params"),
    2: ("Time", "Temperature"),
    3: ("Time", "Humidity"),
}
for panel_col, (x_title, y_title) in axis_titles.items():
    fig.update_xaxes(title_text=x_title, row=1, col=panel_col)
    fig.update_yaxes(title_text=y_title, row=1, col=panel_col)

fig.show()

In [140]:
# Final estimates first, then a blank line, then the simulation parameters,
# each as a heading line followed by the matrix.
estimated_params = (
    ("F est.:", F_est_iter),
    ("Q est.:", Q_est_iter),
    ("H est.:", H_est_iter),
    ("R est.:", R_est_iter),
)
true_params = (
    ('F real:', F),
    ('Q real:', Q),
    ('H real:', H),
    ('R real:', R),
)

for heading, matrix in estimated_params:
    print(heading)
    print(matrix)

print()
for heading, matrix in true_params:
    print(heading)
    print(matrix)

F est.:
[[ 1.09006357  0.18000848]
 [-0.15616253  0.68954191]]
Q est.:
[[ 0.39397232 -0.33924908]
 [-0.33924908  0.42370614]]
H est.:
[[1.64588682 0.26942497]
 [0.64597186 1.26959863]]
R est.:
[[0.14010285 0.12798793]
 [0.12798793 0.13634027]]

F real:
[[1.   0.  ]
 [0.   0.95]]
Q real:
[[0.1  0.  ]
 [0.   0.05]]
H real:
[[1. 0.]
 [0. 1.]]
R real:
[[0.5 0. ]
 [0.  0.3]]


In [141]:
limit = 50

# Only the first `limit` timesteps are shown to keep the output short.
head = slice(limit)

# Each call prints the heading and the values on separate lines.
print('True temp values (limited):', x_true[head, 0], sep='\n')
print('Est. temp values (limited):', x_smoothed[head, 0], sep='\n')
print()
print('True humidity values (limited):', x_true[head, 1], sep='\n')
print('Est. humidity values (limited):', x_smoothed[head, 1], sep='\n')

True temp values (limited):
[20.         19.98995335 19.74086923 19.73991415 19.29193093 18.97265599
 18.72726698 18.7029728  18.37795635 18.63946717 18.53256309 18.27124786
 18.26772176 18.46018341 18.46282831 18.3077367  18.54846711 19.06782548
 19.26129242 19.58734815 19.00599345 19.35086319 19.28405567 19.22903741
 19.74675576 19.4270071  20.30867292 20.52763755 20.21546559 20.54237768
 20.57173962 20.56242164 20.31045846 20.22647888 19.74970304 20.22548982
 19.95341031 20.14178323 20.37874859 20.38659032 19.96047986 19.84201446
 20.3655203  19.79566032 19.69530156 19.04730921 18.94134881 19.4154432
 19.55461859 19.70876819]
Est. temp values (limited):
[12.55148591 12.87170589 13.58232902 13.02124344 12.46832857 12.83439088
 12.35760683 12.91011985 12.48701977 12.66307839 12.65997969 12.56766781
 12.3539164  11.79971285 12.54726093 11.77354935 12.31903155 13.22796671
 12.71771226 14.06856888 12.3558775  12.76333263 12.29537435 12.78947225
 12.67160574 12.05803713 14.09419735 12.745