In [1]:
import numpy as np
from scipy.stats import multivariate_normal
from matplotlib import pyplot as plt
from copy import copy
from datetime import datetime, timedelta

from stonesoup.functions import gm_reduce_single

from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater
from stonesoup.predictor.imm import IMMPredictor
from stonesoup.updater.imm import IMMUpdater
from stonesoup.models.transition.linear import ConstantVelocity, \
    CombinedLinearGaussianTransitionModel, LinearGaussianTimeInvariantTransitionModel

from stonesoup.models.measurement.linear import LinearGaussian

from stonesoup.types.state import GaussianState
from stonesoup.types.hypothesis import SingleHypothesis
from stonesoup.types.detection import Detection
from stonesoup.types.state import StateVector, CovarianceMatrix, \
    GaussianMixtureState, WeightedGaussianState
from stonesoup.types.track import Track

from stonesoup.simulator.simple import SingleTargetGroundTruthSimulator

from matplotlib.patches import Ellipse


In [2]:
from bayou.datastructures import Gaussian as EMGPB2Gaussian
from bayou.datastructures import GMM as EMGPB2GMM
from bayou.datastructures import GMMSequence as EMGPB2GMMSequence
from bayou.models import LinearModel as EMGPB2LinearModel
from bayou.models import ConstantVelocity as EMGPB2ConstantVelocity
from bayou.expmax.skf import SKF

Plot function

In [3]:
def plot_cov_ellipse(cov, pos, nstd=2, ax=None, **kwargs):
    """
    Plots an `nstd` sigma error ellipse based on the specified covariance
    matrix (`cov`). Additional keyword arguments are passed on to the
    ellipse patch artist.
    Parameters
    ----------
        cov : The 2x2 covariance matrix to base the ellipse on
        pos : The location of the center of the ellipse. Expects a 2-element
            sequence of [x0, y0].
        nstd : The radius of the ellipse in numbers of standard deviations.
            Defaults to 2 standard deviations.
        ax : The axis that the ellipse will be plotted on. Defaults to the
            current axis.
        Additional keyword arguments are pass on to the ellipse patch.
    Returns
    -------
        A matplotlib ellipse artist
    """

    def eigsorted(cov):
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    if ax is None:
        ax = plt.gca()

    vals, vecs = eigsorted(cov)
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))

    # Width and height are "full" widths, not radius
    width, height = 2 * nstd * np.sqrt(vals)
    ellip = Ellipse(xy=pos, width=width, height=height, angle=theta,
                    **kwargs)

    ax.add_artist(ellip)
    return ellip


Groundtruth and detections

In [4]:
gt_transition_model_1 = CombinedLinearGaussianTransitionModel(
                        (ConstantVelocity(1.0),
                         ConstantVelocity(1.0)))

gt_track = []
gt_time = []
iter_model_1 = 120
iter_model_2 = 80

timestamp_init = datetime.now()
state_init_1 = GaussianState(StateVector([[0], [0], [0], [0]]),
                           CovarianceMatrix(np.diag([0.0, 0.0, 0.0, 0.0])),
                           timestamp=timestamp_init)
gt_generator_1 = SingleTargetGroundTruthSimulator(gt_transition_model_1, state_init_1, number_steps=iter_model_1)

for time, gnd_paths in gt_generator_1.groundtruth_paths_gen():
    gnd_path = gnd_paths.pop()
    gt_track.append(gnd_path.state)
    gt_time.append(time)

gt_transition_model_2 = CombinedLinearGaussianTransitionModel(
                    (ConstantVelocity(8.0),
                    ConstantVelocity(8.0)))
state_init_2 = GaussianState(gt_track[-1].state_vector,
                           CovarianceMatrix(np.diag([0.0, 0.0, 0.0, 0.0])),
                           timestamp=gt_time[-1])
gt_generator_2 = SingleTargetGroundTruthSimulator(gt_transition_model_2, state_init_2, number_steps=iter_model_2)
for time, gnd_paths in gt_generator_2.groundtruth_paths_gen():
    gnd_path = gnd_paths.pop()
    gt_track.append(gnd_path.state)
    gt_time.append(time)

detection_track = []
mesurement_noise_1 = np.diag([0.75 ** 2, 0.75 ** 2])
measurement_model_1 = LinearGaussian(ndim_state=4, mapping=[0, 2], noise_covar=mesurement_noise_1)
for i in range(iter_model_1):
    m_ = Detection(measurement_model_1.function(gt_track[i].state_vector, measurement_model_1.rvs(1)), gt_time[i])
    detection_track.append(m_)

mesurement_noise_2 = np.diag([0.5 ** 2, 0.5 ** 2])
measurement_model_2 = LinearGaussian(ndim_state=4, mapping=[0, 2], noise_covar=mesurement_noise_2)
for i in range(iter_model_2):
    ii = i + iter_model_1
    m_ = Detection(measurement_model_2.function(gt_track[ii].state_vector, measurement_model_2.rvs(1)), gt_time[ii])
    detection_track.append(m_)

# print(gt_track)
# print(detection_track)
# print(gt_time)

Filter model. Testing without EMGPB2 estimated parameters.

In [12]:
ft_transition_model_1 = CombinedLinearGaussianTransitionModel(
                        (ConstantVelocity(1.0),
                         ConstantVelocity(1.0)))
ft_transition_model_2 = CombinedLinearGaussianTransitionModel(
                        (ConstantVelocity(15.0),
                         ConstantVelocity(15.0)))
ft_measurement_noise = np.diag([3.0 ** 2, 3.0 ** 2])

ft_measurement_model = LinearGaussian(ndim_state=4, mapping=[0, 2],
                                   noise_covar=ft_measurement_noise)

ft_model_transition_matrix = np.array([[0.5, 0.5],
                                       [0.5, 0.5]])
ft_predictor_1 = KalmanPredictor(ft_transition_model_1)
ft_predictor_2 = KalmanPredictor(ft_transition_model_2)
ft_imm_predictor = IMMPredictor([ft_predictor_1, ft_predictor_2], ft_model_transition_matrix)

ft_updater = KalmanUpdater(ft_measurement_model)
ft_imm_updater = IMMUpdater([ft_updater, ft_updater], ft_model_transition_matrix)

ft_state_init = WeightedGaussianState(StateVector([[0], [0], [0], [0]]),
                                      CovarianceMatrix(np.diag([5.0, 5.0, 5.0, 5.0])),
                                      timestamp=timestamp_init,
                                      weight=0.5)

prior1 = copy(ft_state_init)
prior = GaussianMixtureState([prior1, prior1])
track = Track([copy(prior)])
track_error = []

for i in range(1, len(detection_track)-1):
    measurement = detection_track[i]
    # State prediction
    prediction = ft_imm_predictor.predict(track.state, timestamp=gt_time[i])
    # Measurement prediction
    meas_prediction = ft_imm_updater.predict_measurement(prediction)
    # State update
    hyp = SingleHypothesis(prediction, measurement)
    prior = ft_imm_updater.update(hyp)
    track.append(prior)
    track_error.append(ft_measurement_model.function(track[i].state_vector, np.zeros((2,1))) - ft_measurement_model.function(gt_track[i].state_vector, np.zeros((2,1))))
    '''
    fig, (ax1, ax2) = plt.subplots(2,1)
    # Generate plots
    ax1.cla()
    ax2.cla()
    # PLot true trajectory
    data = np.array([state.state_vector for state in gt_track])
    ax1.plot(data[:, 0], data[:, 2], 'b-')
    # PLot estimated trajectory
    data = np.array([state.state_vector for state in track.states])
    ax1.plot(data[:, 0], data[:, 2], 'r-')
    # Plot innovation covariance
    plot_cov_ellipse(meas_prediction.covar,
                     meas_prediction.mean, edgecolor='b',
                     facecolor='none', ax=ax1)
    # Plot estimated covariance
    plot_cov_ellipse(track.state.covar[[0, 2], :][:, [0, 2]],
                     track.state.mean[[0, 2], :], edgecolor='r',
                     facecolor='none', ax=ax1)
    # Visualise model weights
    ax2.bar([1,2], prior.weights.ravel())
    plt.pause(0.0001)
    '''

track_error = np.asarray(track_error)
track_error = np.squeeze(track_error)
rmse_tmp = []
for track_error_ in track_error:
    rmse_tmp.append(track_error_ @ track_error_.T)
rmse = np.sqrt(np.mean(np.array(rmse_tmp)))
print("RMSE: " + str(rmse))

RMSE: 1.166523877015827


Filter model. Testing with correct parameters.

In [6]:
model_transition_matrix = np.array([[0.5, 0.5],
                                    [0.5, 0.5]])

predictor_1 = KalmanPredictor(gt_transition_model_1)
predictor_2 = KalmanPredictor(gt_transition_model_2)
imm_predictor = IMMPredictor([predictor_1, predictor_2],
                             model_transition_matrix)

updater_1 = KalmanUpdater(measurement_model_1)
updater_2 = KalmanUpdater(measurement_model_2)
imm_updater = IMMUpdater([updater_1, updater_2], model_transition_matrix)

prior1 = copy(ft_state_init)
prior = GaussianMixtureState([prior1, prior1])
track = Track([copy(prior)])
track_error = []

for i in range(1, len(detection_track)-1):
    measurement = detection_track[i]
    # State prediction
    prediction = imm_predictor.predict(track.state, timestamp=gt_time[i])
    # Measurement prediction
    meas_prediction = imm_updater.predict_measurement(prediction)
    # State update
    hyp = SingleHypothesis(prediction, measurement)
    prior = imm_updater.update(hyp)
    track.append(prior)
    track_error.append(ft_measurement_model.function(track[i].state_vector, np.zeros((2,1))) - ft_measurement_model.function(gt_track[i].state_vector, np.zeros((2,1))))
    # print(ft_measurement_model.function(track[i].state_vector, np.zeros((2,1))))
    # print(ft_measurement_model.function(gt_track[i].state_vector, np.zeros((2,1))))
    # print('----------------------')
    
    '''
    fig, (ax1, ax2) = plt.subplots(2,1)
    # Generate plots
    ax1.cla()
    ax2.cla()
    # PLot true trajectory
    data = np.array([state.state_vector for state in gt_track])
    ax1.plot(data[:, 0], data[:, 2], 'b-')
    # PLot estimated trajectory
    data = np.array([state.state_vector for state in track.states])
    ax1.plot(data[:, 0], data[:, 2], 'r-')
    # Plot innovation covariance
    plot_cov_ellipse(meas_prediction.covar,
                     meas_prediction.mean, edgecolor='b',
                     facecolor='none', ax=ax1)
    # Plot estimated covariance
    plot_cov_ellipse(track.state.covar[[0, 2], :][:, [0, 2]],
                     track.state.mean[[0, 2], :], edgecolor='r',
                     facecolor='none', ax=ax1)
    # Visualise model weights
    ax2.bar([1,2], prior.weights.ravel())
    plt.pause(0.0001)
    '''

track_error = np.asarray(track_error)
track_error = np.squeeze(track_error)
rmse_tmp = []
for track_error_ in track_error:
    rmse_tmp.append(track_error_ @ track_error_.T)
rmse = np.sqrt(np.mean(np.array(rmse_tmp)))
print("RMSE: " + str(rmse))

RMSE: 0.8727326829575636


EMGPB2

In [7]:
F = np.asarray([
        [1, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 1],
        [0, 0, 0, 1]
    ])
H = np.asanyarray([
    [1, 0, 0, 0],
    [0, 0, 1, 0]
])
def get_Q(Q_sig, dt=1):
    Q = (Q_sig ** 2) * np.asarray([
        [(1/3)*np.power(dt, 3), (1/2)*np.power(dt, 2), 0, 0],
        [(1/2) * np.power(dt, 2), dt, 0, 0],
        [0, 0, (1/3)*np.power(dt, 3), (1/2)*np.power(dt, 2)],
        [0, 0, (1/2) * np.power(dt, 2), dt]
    ])
    return Q

In [13]:
g1 = EMGPB2Gaussian(np.zeros([4, 1]), 10.0*np.eye(4))
g2 = EMGPB2Gaussian(np.zeros([4, 1]), 10.0*np.eye(4))
initial_gmm_state = EMGPB2GMM([g1, g2])

# measurements = 5 * np.random.randn(200, 2, 1) + 1

detection_track_EMGPB2 = [ele.state_vector for ele in detection_track]
gmmsequence = EMGPB2GMMSequence(np.float64(detection_track_EMGPB2), initial_gmm_state)

m1 = EMGPB2LinearModel(F, get_Q(1.0), H, (3.0 ** 2)*np.eye(2))
m2 = EMGPB2LinearModel(F, get_Q(15.0), H, (3.0 ** 2)*np.eye(2))
initial_models = [m1, m2]

Z = np.ones([2, 2]) / 2
dataset = [gmmsequence]

new_models_all, Z, dataset, LL = SKF.EM(dataset, initial_models, Z,
                                    max_iters=50, threshold=1e-5, learn_H=True, learn_R=True,
                                    learn_A=True, learn_Q=True, learn_init_state=False, learn_Z=True,
                                    diagonal_Q=False, wishart_prior=False)

new_models = new_models_all[-1]

new_models[0].Q = (new_models[0].Q + new_models[0].Q.T)/2
new_models[1].Q = (new_models[1].Q + new_models[1].Q.T)/2

-1204.6341076225829
model -- 0 new_A: 
[[ 1.00147473e+00  8.99876690e-01  9.30248296e-04  2.84671717e-02]
 [ 1.29940930e-03  8.98620613e-01  1.21817563e-03  1.99497909e-02]
 [-1.24492063e-03  4.49332404e-02  1.00011101e+00  9.48370192e-01]
 [-1.28191975e-03  5.17582150e-02 -1.98593219e-04  9.48933430e-01]]
model -- 1 new_A: 
[[ 1.00200846e+00  8.57463413e-01  1.27355212e-03  2.40684347e-02]
 [ 2.66274387e-03  8.06031510e-01  2.06763194e-03  3.85288146e-02]
 [-1.29666352e-03  3.19782055e-02  1.00033231e+00  9.34146298e-01]
 [-1.11579911e-03  2.17781285e-02  2.70016748e-04  9.41145118e-01]]
model -- 0 new_Q: 
[[5.98331983 5.14717399 0.55601069 0.67425339]
 [5.14717399 8.77018426 0.562453   0.96403675]
 [0.55601069 0.562453   6.07028075 5.23525757]
 [0.67425339 0.96403675 5.23525757 8.73658924]]
model -- 1 new_Q: 
[[11.19285947 10.3975651   3.53277583  3.41268748]
 [10.3975651  16.95952873  3.63965138  4.76653249]
 [ 3.53277583  3.63965138  9.64389553  8.95534492]
 [ 3.41268748  4.7665324

-972.8576504303447
model -- 0 new_A: 
[[ 1.00166538e+00  8.81624560e-01  1.19551405e-03  2.99200404e-02]
 [ 1.54461943e-04  9.55364278e-01  1.17816980e-03 -7.28247785e-03]
 [-1.24912669e-03  3.98533595e-02  1.00025458e+00  9.44855651e-01]
 [-2.28241862e-04  1.71359297e-02 -4.94119201e-04  9.94031401e-01]]
model -- 1 new_A: 
[[ 1.00049431e+00  9.51482223e-01  3.11047851e-04 -1.62133029e-02]
 [-2.82938662e-04  9.90242384e-01  6.36089763e-04 -2.18420597e-02]
 [-2.09968556e-03  8.91521416e-02  9.99538341e-01  9.13100881e-01]
 [-4.69258082e-04  3.62969528e-02 -9.18796199e-04  9.88696470e-01]]
model -- 0 new_Q: 
[[3.51098179 3.05306505 1.15906038 0.82863069]
 [3.0529718  5.72735825 0.92132734 1.32034994]
 [1.15169623 0.91734861 2.93841528 2.43192297]
 [0.82462149 1.31259672 2.43209309 4.50363058]]
model -- 1 new_Q: 
[[6.62295842 3.80788861 3.49933104 1.21151166]
 [3.80767445 5.96383795 1.45750414 1.42251951]
 [3.491792   1.45349417 4.73747529 2.74642448]
 [1.20723332 1.41464141 2.74647745 4.

-914.5969887870691
model -- 0 new_A: 
[[ 1.00167471e+00  8.79903002e-01  1.19234440e-03  2.81777880e-02]
 [-1.53274338e-04  9.78705718e-01  9.20970745e-04 -1.01386135e-02]
 [-1.25793968e-03  3.89744658e-02  1.00025356e+00  9.42253123e-01]
 [-3.44300714e-04  2.97025016e-02 -7.01087134e-04  9.95038175e-01]]
model -- 1 new_A: 
[[ 9.96923570e-01  9.98165701e-01  2.07593293e-04 -3.36983786e-01]
 [-1.35505812e-03  1.04385968e+00  4.45969398e-04 -6.20570192e-02]
 [-4.73040081e-03  1.17989462e-01  9.99291945e-01  6.68490720e-01]
 [-8.84691801e-04  5.74506551e-02 -1.21966163e-03  9.77159494e-01]]
model -- 0 new_Q: 
[[2.72340633 2.25616338 0.89569885 0.30192245]
 [2.2564855  4.28955307 0.50433888 0.50947312]
 [0.88606716 0.49883092 2.27883072 1.94965499]
 [0.29628273 0.49931321 1.94939813 3.75860142]]
model -- 1 new_Q: 
[[41.78830319  4.93241951 31.51420082  1.44184076]
 [ 4.93115647  5.36275257  2.31669698  0.76027013]
 [31.50481957  2.31196971 26.41665755  2.8686495 ]
 [ 1.43520372  0.75015488

-881.2762030889451
model -- 0 new_A: 
[[ 1.00160688e+00  8.82963655e-01  1.17895460e-03  2.56884742e-02]
 [-2.20176850e-04  9.84286187e-01  8.43411610e-04 -1.19160523e-02]
 [-1.34403317e-03  4.15667210e-02  1.00027426e+00  9.38526513e-01]
 [-4.58522892e-04  3.49547428e-02 -7.13860196e-04  9.91132170e-01]]
model -- 1 new_A: 
[[ 9.97648441e-01  6.96142956e-01  1.76861017e-03 -8.31858443e-01]
 [-3.84402902e-03  1.15135709e+00  1.13065806e-03 -1.00717102e-01]
 [-3.49336552e-03 -1.10583579e-01  9.99094966e-01  3.35684737e-01]
 [-2.19335560e-03  1.70424505e-01 -2.82247622e-03  1.01248447e+00]]
model -- 0 new_Q: 
[[2.03495333 2.06278904 0.37439458 0.07176061]
 [2.06295694 4.05475369 0.38621861 0.27520403]
 [0.37116484 0.38361806 1.84384029 1.93039205]
 [0.06912591 0.26877479 1.93029526 3.8075323 ]]
model -- 1 new_Q: 
[[ 3.80145929  1.50488066  1.21980873 -0.91618939]
 [ 1.5017958   7.5537215  -0.77070294  1.45360577]
 [ 1.21848002 -0.77181848  2.07103258  0.76225318]
 [-0.91925083  1.44747908

-871.0066405759949
model -- 0 new_A: 
[[ 1.00166333e+00  8.73449274e-01  1.27940953e-03  2.83155274e-02]
 [-1.62944328e-04  9.75050607e-01  9.35944930e-04 -7.86051247e-03]
 [-1.43949717e-03  4.14616623e-02  1.00043714e+00  9.31278932e-01]
 [-3.57561072e-04  2.63450931e-02 -4.81946438e-04  9.87771835e-01]]
model -- 1 new_A: 
[[ 9.98599717e-01  6.02837695e-01  3.66779785e-03 -8.85050085e-01]
 [-8.81004414e-04  9.91201870e-01  2.08944817e-03 -1.13242071e-01]
 [-3.51512405e-03 -8.28839793e-02  9.97760614e-01  3.83386970e-01]
 [-1.52917424e-03  2.07779965e-01 -6.06646085e-03  1.09728898e+00]]
model -- 0 new_Q: 
[[1.93453945 1.89476702 0.35683352 0.1288523 ]
 [1.89459078 3.70781405 0.39536706 0.4362985 ]
 [0.35977335 0.39643904 1.77267188 1.86084337]
 [0.13009798 0.43701076 1.86109357 3.54235696]]
model -- 1 new_Q: 
[[ 3.95909373  4.41871963 -0.53441688 -1.41335216]
 [ 4.4186681  18.94118614 -2.98625595  3.32951933]
 [-0.53159636 -2.98501049  2.41173743  1.76282001]
 [-1.41138408  3.32899849

-864.5790121835246
model -- 0 new_A: 
[[ 1.00181777e+00  8.53266209e-01  1.53740074e-03  3.18740800e-02]
 [-1.05336538e-04  9.62530853e-01  1.11668513e-03 -5.33047330e-03]
 [-1.68165969e-03  5.05675429e-02  1.00050570e+00  9.21420743e-01]
 [-3.84124100e-04  3.21903051e-02 -5.22844312e-04  9.85007277e-01]]
model -- 1 new_A: 
[[ 9.98021078e-01  6.13963407e-01  4.46627858e-03 -9.40647518e-01]
 [-7.34202191e-04  1.00533067e+00  1.78012998e-03 -1.18521339e-01]
 [-3.09122973e-03 -1.01123675e-01  9.97499059e-01  4.04824278e-01]
 [-1.38833350e-04  1.57336309e-01 -7.07631954e-03  1.14940491e+00]]
model -- 0 new_Q: 
[[1.81650711 1.64697612 0.35222258 0.21413532]
 [1.64621272 3.26150864 0.34276617 0.64501415]
 [0.35938558 0.3465445  1.74736959 1.80699247]
 [0.21844848 0.6515538  1.80772907 3.31744408]]
model -- 1 new_Q: 
[[ 4.04661733  4.88802423 -1.11824714 -1.86159174]
 [ 4.89015141 20.70856152 -3.29757939  3.57949354]
 [-1.1121887  -3.29405425  2.47684716  2.18120922]
 [-1.85476543  3.58131409

-862.3300767991748
model -- 0 new_A: 
[[ 1.00187377e+00  8.40750172e-01  1.74475035e-03  3.20094714e-02]
 [-1.18026417e-04  9.59712685e-01  1.18394611e-03 -5.82175337e-03]
 [-1.83532419e-03  5.16018657e-02  1.00063716e+00  9.14962884e-01]
 [-3.64119105e-04  3.19710080e-02 -5.23403042e-04  9.85002077e-01]]
model -- 1 new_A: 
[[ 9.97924074e-01  6.08260907e-01  4.74240162e-03 -9.59798731e-01]
 [-7.03938305e-04  1.02149174e+00  1.43219091e-03 -1.09961385e-01]
 [-3.01494962e-03 -1.06355700e-01  9.97489805e-01  4.09100451e-01]
 [ 2.87971551e-04  1.39436475e-01 -7.26845239e-03  1.16185500e+00]]
model -- 0 new_Q: 
[[1.73047636 1.51568273 0.33826914 0.22815913]
 [1.51429155 3.12776002 0.29319858 0.72410524]
 [0.34863649 0.29904566 1.77082103 1.82882744]
 [0.234906   0.73527621 1.83005894 3.26984204]]
model -- 1 new_Q: 
[[ 3.88388667  4.6665134  -1.17659725 -2.0255336 ]
 [ 4.67042016 22.52541931 -3.3830366   4.03542153]
 [-1.16805987 -3.37756356  2.48393993  2.27113271]
 [-2.01499657  4.03904503

-861.6271892811037
model -- 0 new_A: 
[[ 1.00191528e+00  8.29859872e-01  1.93743151e-03  3.14194295e-02]
 [-1.24610000e-04  9.58884239e-01  1.20887738e-03 -6.36595661e-03]
 [-1.95227479e-03  5.01513329e-02  1.00078253e+00  9.09734950e-01]
 [-3.51549798e-04  3.12138594e-02 -5.15020295e-04  9.85097491e-01]]
model -- 1 new_A: 
[[ 9.98032279e-01  5.94302983e-01  4.92414173e-03 -9.68278533e-01]
 [-6.85999320e-04  1.02997976e+00  1.26723482e-03 -1.01967258e-01]
 [-3.03950822e-03 -1.06194713e-01  9.97497595e-01  4.09847115e-01]
 [ 3.90929026e-04  1.35146846e-01 -7.33630351e-03  1.16533946e+00]]
model -- 0 new_Q: 
[[1.68713472 1.4678692  0.32173115 0.20502887]
 [1.46585482 3.16049237 0.25280275 0.74757001]
 [0.33476897 0.26046869 1.81130414 1.87388289]
 [0.21389263 0.76296617 1.87557851 3.27944249]]
model -- 1 new_Q: 
[[ 3.77587963  4.4380915  -1.16544929 -2.10707403]
 [ 4.44383953 23.97084944 -3.438105    4.39777931]
 [-1.15512231 -3.43106576  2.48243384  2.28567868]
 [-2.09361363  4.40352728

-861.3143904855226
model -- 0 new_A: 
[[ 1.00195913e+00  8.18823736e-01  2.13318624e-03  3.07345700e-02]
 [-1.27945960e-04  9.58347667e-01  1.22511466e-03 -6.83178049e-03]
 [-2.05173237e-03  4.79804050e-02  1.00092692e+00  9.05054012e-01]
 [-3.46883649e-04  3.08258066e-02 -5.08761030e-04  9.85037068e-01]]
model -- 1 new_A: 
[[ 9.98169205e-01  5.81563845e-01  5.04339760e-03 -9.72834012e-01]
 [-6.82037942e-04  1.03597397e+00  1.17369609e-03 -9.59426348e-02]
 [-3.08272500e-03 -1.05264428e-01  9.97515560e-01  4.09533677e-01]
 [ 4.26064438e-04  1.32791698e-01 -7.35760983e-03  1.16644792e+00]]
model -- 0 new_Q: 
[[1.65062931 1.44588398 0.31063851 0.1797305 ]
 [1.44326891 3.23974391 0.22779891 0.76257979]
 [0.32592858 0.2371687  1.84961965 1.91885337]
 [0.19046867 0.78211434 1.92095224 3.30094567]]
model -- 1 new_Q: 
[[ 3.71027206  4.24761648 -1.14915875 -2.1767051 ]
 [ 4.25531727 25.21041512 -3.48580601  4.69764658]
 [-1.13750774 -3.47743454  2.48185804  2.29098902]
 [-2.16076953  4.70578887

Filter model. Testing with learnt parameters from EMGPB2.

In [14]:
em_transition_model_1 = LinearGaussianTimeInvariantTransitionModel(transition_matrix=new_models[0].A, covariance_matrix=new_models[0].Q)
em_transition_model_2 = LinearGaussianTimeInvariantTransitionModel(transition_matrix=new_models[1].A, covariance_matrix=new_models[1].Q)
em_measurement_noise_1 = new_models[0].R
em_measurement_noise_2 = new_models[1].R

em_measurement_model_1 = LinearGaussian(ndim_state=4, mapping=[0, 2],
                                   noise_covar=em_measurement_noise_1)
em_measurement_model_2 = LinearGaussian(ndim_state=4, mapping=[0, 2],
                                   noise_covar=em_measurement_noise_2)

em_model_transition_matrix = np.array([[0.5, 0.5],
                                       [0.5, 0.5]])
em_predictor_1 = KalmanPredictor(em_transition_model_1)
em_predictor_2 = KalmanPredictor(em_transition_model_2)
em_imm_predictor = IMMPredictor([em_predictor_1, em_predictor_2], em_model_transition_matrix)

em_updater_1 = KalmanUpdater(em_measurement_model_1)
em_updater_2 = KalmanUpdater(em_measurement_model_2)
em_imm_updater = IMMUpdater([em_updater_1, em_updater_2], em_model_transition_matrix)

em_state_init = WeightedGaussianState(StateVector([[0], [0], [0], [0]]),
                                      CovarianceMatrix(np.diag([5.0, 5.0, 5.0, 5.0])),
                                      timestamp=timestamp_init,
                                      weight=0.5)

prior1 = copy(em_state_init)
prior = GaussianMixtureState([prior1, prior1])
track = Track([copy(prior)])
track_error = []

for i in range(1, len(detection_track)-1):
    measurement = detection_track[i]
    
    # State prediction
    prediction = em_imm_predictor.predict(track.state, timestamp=gt_time[i])
    # Measurement prediction
    meas_prediction = em_imm_updater.predict_measurement(prediction)
    # State update
    hyp = SingleHypothesis(prediction, measurement)
    prior = em_imm_updater.update(hyp)
    track.append(prior)
    track_error.append(em_measurement_model_1.function(track[i].state_vector, np.zeros((2,1))) - em_measurement_model_1.function(gt_track[i].state_vector, np.zeros((2,1))))
    
    '''
    fig, (ax1, ax2) = plt.subplots(2,1)
    # Generate plots
    ax1.cla()
    ax2.cla()
    # PLot true trajectory
    data = np.array([state.state_vector for state in gt_track])
    ax1.plot(data[:, 0], data[:, 2], 'b-')
    # PLot estimated trajectory
    data = np.array([state.state_vector for state in track.states])
    ax1.plot(data[:, 0], data[:, 2], 'r-')
    # Plot innovation covariance
    plot_cov_ellipse(meas_prediction.covar,
                     meas_prediction.mean, edgecolor='b',
                     facecolor='none', ax=ax1)
    # Plot estimated covariance
    plot_cov_ellipse(track.state.covar[[0, 2], :][:, [0, 2]],
                     track.state.mean[[0, 2], :], edgecolor='r',
                     facecolor='none', ax=ax1)
    # Visualise model weights
    ax2.bar([1,2], prior.weights.ravel())
    plt.pause(0.0001)
    '''

track_error = np.asarray(track_error)
track_error = np.squeeze(track_error)
rmse_tmp = []
for track_error_ in track_error:
    rmse_tmp.append(track_error_ @ track_error_.T)
rmse = np.sqrt(np.mean(np.array(rmse_tmp)))
print("RMSE: " + str(rmse))

RMSE: 0.8940777803304147


In [10]:
get_Q(0.5)

array([[0.08333333, 0.125     , 0.        , 0.        ],
       [0.125     , 0.25      , 0.        , 0.        ],
       [0.        , 0.        , 0.08333333, 0.125     ],
       [0.        , 0.        , 0.125     , 0.25      ]])