In [1]:
import numpy as np
from scipy.stats import multivariate_normal
from matplotlib import pyplot as plt
from copy import copy
from datetime import datetime, timedelta

from stonesoup.functions import gm_reduce_single

from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater
from stonesoup.predictor.imm import IMMPredictor
from stonesoup.updater.imm import IMMUpdater
from stonesoup.models.transition.linear import ConstantVelocity, \
    CombinedLinearGaussianTransitionModel, LinearGaussianTimeInvariantTransitionModel, RandomWalk

from stonesoup.models.measurement.linear import LinearGaussian

from stonesoup.types.state import GaussianState
from stonesoup.types.hypothesis import SingleHypothesis
from stonesoup.types.detection import Detection
from stonesoup.types.state import StateVector, CovarianceMatrix, \
    GaussianMixtureState, WeightedGaussianState
from stonesoup.types.track import Track

from stonesoup.simulator.simple import SingleTargetGroundTruthSimulator

from matplotlib.patches import Ellipse


In [2]:
from emgpb2.states import Gaussian as EMGPB2Gaussian
from emgpb2.states import GMM as EMGPB2GMM
from emgpb2.states import GMMSequence as EMGPB2GMMSequence
from emgpb2.models import RandomWalk as EMGPB2RandomWalk
from emgpb2.models import ConstantVelocity as EMGPB2ConstantVelocity
from emgpb2.EM import SKFEstimator

Simulate groundtruth and detections

In [3]:
# dimension of state vector and observation
state_dim = 2
obs_dim = 2

# define transition models
gt_transition_model_1 = CombinedLinearGaussianTransitionModel(
                        (RandomWalk(2.0 ** 2),
                         RandomWalk(2.0 ** 2)))
gt_transition_model_2 = CombinedLinearGaussianTransitionModel(
                        (RandomWalk(10.0 ** 2),
                         RandomWalk(10.0 ** 2)))

# define measurement models
mesurement_noise_1 = np.diag([0.5 ** 2, 0.5 ** 2])
measurement_model_1 = LinearGaussian(ndim_state=2, mapping=[0, 1], noise_covar=mesurement_noise_1)
mesurement_noise_2 = np.diag([0.75 ** 2, 0.75 ** 2])
measurement_model_2 = LinearGaussian(ndim_state=2, mapping=[0, 1], noise_covar=mesurement_noise_2)

# iteration numbers of two models
iter_model_1 = 220
iter_model_2 = 180

# generate groundtruth
gt_track = []
gt_time = []
timestamp_init = datetime.now()

# model 1
state_init_1 = GaussianState(StateVector(np.zeros((state_dim, 1))),
                           CovarianceMatrix(np.zeros((state_dim, state_dim))),
                           timestamp=timestamp_init)
gt_generator_1 = SingleTargetGroundTruthSimulator(gt_transition_model_1, state_init_1, number_steps=iter_model_1)
for time, gnd_paths in gt_generator_1.groundtruth_paths_gen():
    gnd_path = gnd_paths.pop()
    gt_track.append(gnd_path.state)
    gt_time.append(time)

# model 2
state_init_2 = GaussianState(gt_track[-1].state_vector, 
                             CovarianceMatrix(np.zeros((state_dim, state_dim))), 
                             timestamp=gt_time[-1])
gt_generator_2 = SingleTargetGroundTruthSimulator(gt_transition_model_2, state_init_2, number_steps=iter_model_2)
for time, gnd_paths in gt_generator_2.groundtruth_paths_gen():
    gnd_path = gnd_paths.pop()
    gt_track.append(gnd_path.state)
    gt_time.append(time)

# generate detections (adding noise)
# model 1
detection_track = []
for i in range(iter_model_1):
    m_ = Detection(measurement_model_1.function(gt_track[i].state_vector, measurement_model_1.rvs(1)), gt_time[i])
    detection_track.append(m_)
# model 2
for i in range(iter_model_2):
    ii = i + iter_model_1
    m_ = Detection(measurement_model_2.function(gt_track[ii].state_vector, measurement_model_2.rvs(1)), gt_time[ii])
    detection_track.append(m_)

Filter model. Testing by correct parameters.

In [4]:
# model transition probability
model_transition_matrix = np.array([[0.5, 0.5],
                                    [0.5, 0.5]])

# define two kalman predictors. re-use the groundtruth transition model.
predictor_1 = KalmanPredictor(gt_transition_model_1)
predictor_2 = KalmanPredictor(gt_transition_model_2)

# define IMM predictor
imm_predictor = IMMPredictor([predictor_1, predictor_2],
                             model_transition_matrix)

# define two kalman updaters. re-use the groundtruth measurement model.
updater_1 = KalmanUpdater(measurement_model_1)
updater_2 = KalmanUpdater(measurement_model_2)

# define IMM updater
imm_updater = IMMUpdater([updater_1, updater_2], model_transition_matrix)

# initialise state for kalman filter 1
ft_state_init_1 = WeightedGaussianState(StateVector(np.zeros((state_dim, 1))),
                                      CovarianceMatrix(np.eye(state_dim) * 1.0),
                                      timestamp=timestamp_init,
                                      weight=0.5)

# initialise state for kalman filter 2
ft_state_init_2 = WeightedGaussianState(StateVector(np.zeros((state_dim, 1))),
                                      CovarianceMatrix(np.eye(state_dim) * 1.0),
                                      timestamp=timestamp_init,
                                      weight=0.5)

# define Gaussian Mixture model for initial state
prior = GaussianMixtureState([ft_state_init_1, ft_state_init_2])
track = Track([copy(prior)])

# tracking
track_error = []
for i in range(1, len(detection_track)-1):
    # measurement
    measurement = detection_track[i]
    # State prediction
    prediction = imm_predictor.predict(track.state, timestamp=gt_time[i])
    # Measurement prediction
    meas_prediction = imm_updater.predict_measurement(prediction)
    # Association
    hyp = SingleHypothesis(prediction, measurement)
    # State update
    prior = imm_updater.update(hyp)
    track.append(prior)
    track_error.append(measurement_model_1.function(track[i].state_vector, np.zeros((obs_dim, 1)))
                       - measurement_model_1.function(gt_track[i].state_vector, np.zeros((obs_dim, 1))))

track_error = np.asarray(track_error)
track_error = np.squeeze(track_error)
rmse_tmp = []
for track_error_ in track_error:
    rmse_tmp.append(track_error_ @ track_error_.T)
rmse = np.sqrt(np.mean(np.array(rmse_tmp)))
print("RMSE: " + str(rmse))

RMSE: 0.8752274520492569


Gussing some IMM paramters.

In [5]:
init_P = [5.0, 5.0]
q = [1.0, 20.0]
r = [3.0, 3.0]

Filter model. Testing with correct parameters.

In [6]:
# define model transition probability
ft_model_transition_matrix = np.array([[0.5, 0.5],
                                       [0.5, 0.5]])
# define transition model 1 using guessed parameter
ft_transition_model_1 = CombinedLinearGaussianTransitionModel(
                        (RandomWalk(q[0] ** 2),
                         RandomWalk(q[0] ** 2)))
# define transition model 2 using guessed parameter
ft_transition_model_2 = CombinedLinearGaussianTransitionModel(
                        (RandomWalk(q[1] ** 2),
                         RandomWalk(q[1] ** 2)))

# define measurement model. use same measurement model for two Kalman filters.
ft_measurement_noise = np.diag([r[0] ** 2, r[0] ** 2])
ft_measurement_model = LinearGaussian(ndim_state=2, mapping=[0, 1],
                                   noise_covar=ft_measurement_noise)

# define two kalman updaters.
ft_predictor_1 = KalmanPredictor(ft_transition_model_1)
ft_predictor_2 = KalmanPredictor(ft_transition_model_2)

# define IMM predictor
ft_imm_predictor = IMMPredictor([ft_predictor_1, ft_predictor_2], ft_model_transition_matrix)

# define IMM updater. use same measurement model for two Kalman filters.
ft_updater = KalmanUpdater(ft_measurement_model)
ft_imm_updater = IMMUpdater([copy(ft_updater), copy(ft_updater)], ft_model_transition_matrix)

# define Gaussian Mixture model for initial state. We used the same initial state as previous filter.
prior = GaussianMixtureState([ft_state_init_1, ft_state_init_2])
track = Track([copy(prior)])
track_error = []

for i in range(1, len(detection_track)-1):
    # measurement
    measurement = detection_track[i]
    # State prediction
    prediction = ft_imm_predictor.predict(track.state, timestamp=gt_time[i])
    # Measurement prediction
    meas_prediction = ft_imm_updater.predict_measurement(prediction)
    # Association
    hyp = SingleHypothesis(prediction, measurement)
    # State update
    prior = ft_imm_updater.update(hyp)
    track.append(prior)
    track_error.append(ft_measurement_model.function(track[i].state_vector, np.zeros((2,1))) - ft_measurement_model.function(gt_track[i].state_vector, np.zeros((2,1))))

track_error = np.asarray(track_error)
track_error = np.squeeze(track_error)
rmse_tmp = []
for track_error_ in track_error:
    rmse_tmp.append(track_error_ @ track_error_.T)
rmse = np.sqrt(np.mean(np.array(rmse_tmp)))
print("RMSE: " + str(rmse))

RMSE: 2.21982124809413


Run EMGPB2 to estimate IMM parameters.

In [None]:
num_of_models = len(q)
gaussian_models = []
for i in range(num_of_models):
    gaussian_models.append(EMGPB2Gaussian(np.zeros([state_dim, 1]), (init_P[i] ** 2) * np.eye(state_dim)))
initial_gmm_state = EMGPB2GMM(gaussian_models)
# measurement sequence
detection_track_EMGPB2 = [ele.state_vector for ele in detection_track]
gmmsequence = EMGPB2GMMSequence(np.float64(detection_track_EMGPB2), initial_gmm_state)
dataset = [gmmsequence]
# Initial models. use guessed parameters.
randomwalk_models = []
for i in range(num_of_models):
    randomwalk_models.append(EMGPB2RandomWalk(q=q[i], r=r[i], state_dim=state_dim))

# Switching matrix probability
Z = np.ones((2, 2)) / 2

models_all, Z_all, dataset, LLs = SKFEstimator.EM(dataset, randomwalk_models, Z,
                                                 max_iters=300, threshold=1e-8, learn_H=True, learn_R=True,
                                                 learn_A=True, learn_Q=True, learn_init_state=False, learn_Z=True,
                                                 diagonal_Q=False, wishart_prior=False)


new_models = new_models_all[-1]

# new_models[0].Q = (new_models[0].Q + new_models[0].Q.T)/2
# new_models[1].Q = (new_models[1].Q + new_models[1].Q.T)/2
# new_models[0].R = (new_models[0].R + new_models[0].R.T)/2
# new_models[1].R = (new_models[1].R + new_models[1].R.T)/2

-2658.3708176157725
model -- 0 new_A: 
[[ 1.00069162 -0.00294135]
 [-0.00154716  1.00046209]]
model -- 1 new_A: 
[[ 0.97038623 -0.03502778]
 [-0.00425479  0.95888915]]
model -- 0 new_Q: 
[[ 6.20709899 -0.3337229 ]
 [-0.3337229   5.81807587]]
model -- 1 new_Q: 
[[151.72954891 -21.51413648]
 [-21.51413648 136.61121432]]
model -- 0 new_H: 
[[ 0.99616283 -0.00507871]
 [ 0.00121408  0.99731945]]
model -- 1 new_H: 
[[ 1.00149908  0.00418436]
 [-0.00181867  0.99510197]]
model -- 0 new_R: 
[[6.79500807 0.03032292]
 [0.03471326 6.7339821 ]]
model -- 1 new_R: 
[[13.09021009 -0.48300987]
 [-0.47261787 12.0669294 ]]
new_Z: 
[[0.69988727 0.29760647]
 [0.69988727 0.29760647]]
-----------------------------------------------
-2572.6775410035075
model -- 0 new_A: 
[[ 0.99977684 -0.00679462]
 [-0.00190629  0.99991336]]
model -- 1 new_A: 
[[ 0.97087083 -0.03412484]
 [-0.00181265  0.96166865]]
model -- 0 new_Q: 
[[10.46792553 -0.67455448]
 [-0.66893415  9.80798214]]
model -- 1 new_Q: 
[[138.60234692 -18.8

-2586.33126120612
model -- 0 new_A: 
[[ 0.99807457 -0.0121047 ]
 [-0.00294546  0.99816983]]
model -- 1 new_A: 
[[ 0.96947848 -0.03146824]
 [-0.00170906  0.95790993]]
model -- 0 new_Q: 
[[14.12722371 -0.88710764]
 [-0.87932837 13.58629488]]
model -- 1 new_Q: 
[[149.68238618 -19.13496237]
 [-19.12639545 133.58407465]]
model -- 0 new_H: 
[[ 0.99490731 -0.0024887 ]
 [-0.00497137  0.98061315]]
model -- 1 new_H: 
[[ 1.00044409  0.0012307 ]
 [-0.00255428  0.98416972]]
model -- 0 new_R: 
[[ 3.18803315 -0.17700051]
 [-0.1718799   3.32220879]]
model -- 1 new_R: 
[[ 5.69352454 -0.37359543]
 [-0.36800369  6.52671223]]
new_Z: 
[[0.73562997 0.26186376]
 [0.73562997 0.26186376]]
-----------------------------------------------
-2582.969294555136
model -- 0 new_A: 
[[ 0.99825305 -0.01244812]
 [-0.00280009  0.99858171]]
model -- 1 new_A: 
[[ 0.96971535 -0.03079431]
 [-0.00201386  0.95791396]]
model -- 0 new_Q: 
[[13.73845103 -0.86574765]
 [-0.85791646 13.19911081]]
model -- 1 new_Q: 
[[149.06788741 -19.

-2551.6251791904574
model -- 0 new_A: 
[[ 0.99931641 -0.01834559]
 [-0.00443654  0.99938787]]
model -- 1 new_A: 
[[ 9.73242132e-01 -2.36877855e-02]
 [-5.65467666e-04  9.62948070e-01]]
model -- 0 new_Q: 
[[10.03611949 -0.36873145]
 [-0.35997005  9.81439307]]
model -- 1 new_Q: 
[[139.99533172 -19.09831817]
 [-19.08774162 130.45951802]]
model -- 0 new_H: 
[[ 0.99476192 -0.00279299]
 [-0.00802188  0.97047564]]
model -- 1 new_H: 
[[ 0.99957507  0.00368728]
 [-0.00563446  0.97376817]]
model -- 0 new_R: 
[[1.98793671 0.02276133]
 [0.02821109 2.13918361]]
model -- 1 new_R: 
[[3.53116763 0.14422131]
 [0.15080889 4.16378856]]
new_Z: 
[[0.68303302 0.31446071]
 [0.68303302 0.31446071]]
-----------------------------------------------
-2549.3790478620335
model -- 0 new_A: 
[[ 0.99944092 -0.01880376]
 [-0.00480553  0.99888973]]
model -- 1 new_A: 
[[ 9.73456332e-01 -2.33098559e-02]
 [-2.18208149e-04  9.63732787e-01]]
model -- 0 new_Q: 
[[ 9.77842988 -0.30524401]
 [-0.29638842  9.58925236]]
model -- 1 

-2527.2484847048318
model -- 0 new_A: 
[[ 1.00122871 -0.02487567]
 [-0.00924186  0.98550982]]
model -- 1 new_A: 
[[ 0.97503226 -0.02005878]
 [ 0.0031913   0.97394027]]
model -- 0 new_Q: 
[[7.53037515 0.364881  ]
 [0.37471025 7.41147673]]
model -- 1 new_Q: 
[[131.22271687 -19.45052991]
 [-19.43748215 126.61563608]]
model -- 0 new_H: 
[[ 0.99489029 -0.00184874]
 [-0.00991184  0.96313828]]
model -- 1 new_H: 
[[ 0.99962168  0.0055702 ]
 [-0.00641274  0.96980131]]
model -- 0 new_R: 
[[1.43024956 0.17691717]
 [0.18282746 1.51543701]]
model -- 1 new_R: 
[[2.57996358 0.43553995]
 [0.4433726  2.7734172 ]]
new_Z: 
[[0.64272992 0.35476381]
 [0.64272992 0.35476381]]
-----------------------------------------------
-2525.4813715334367
model -- 0 new_A: 
[[ 1.00137645 -0.02558297]
 [-0.00952124  0.98355713]]
model -- 1 new_A: 
[[ 0.97514499 -0.01979011]
 [ 0.00335881  0.97490302]]
model -- 0 new_Q: 
[[7.38739296 0.40568979]
 [0.41560286 7.23501101]]
model -- 1 new_Q: 
[[130.53357708 -19.45730055]
 [-

-2511.389760657567
model -- 0 new_A: 
[[ 1.00269524 -0.03112078]
 [-0.0110457   0.96562716]]
model -- 1 new_A: 
[[ 0.97585172 -0.01823684]
 [ 0.0039763   0.98166163]]
model -- 0 new_Q: 
[[6.38067006 0.68766059]
 [0.69853441 5.89503679]]
model -- 1 new_Q: 
[[124.6716772  -19.39088691]
 [-19.37531745 123.47803044]]
model -- 0 new_H: 
[[ 9.94975770e-01 -6.59762175e-04]
 [-1.02182994e-02  9.60805676e-01]]
model -- 1 new_H: 
[[ 1.00053688  0.0066915 ]
 [-0.0064517   0.96747207]]
model -- 0 new_R: 
[[1.15385883 0.22050988]
 [0.22688025 1.10051317]]
model -- 1 new_R: 
[[2.05446603 0.48916925]
 [0.49825421 1.80294039]]
new_Z: 
[[0.61611415 0.38137959]
 [0.61611415 0.38137959]]
-----------------------------------------------
-2510.6298588793165
model -- 0 new_A: 
[[ 1.00277902 -0.03131193]
 [-0.01108382  0.96472795]]
model -- 1 new_A: 
[[ 0.97588054 -0.01822904]
 [ 0.00395407  0.98188972]]
model -- 0 new_Q: 
[[6.3274238  0.70213975]
 [0.71310978 5.82930924]]
model -- 1 new_Q: 
[[124.32666038 -1

-2505.2354789273727
model -- 0 new_A: 
[[ 1.00344099 -0.0321834 ]
 [-0.01099581  0.95893616]]
model -- 1 new_A: 
[[ 0.9760892  -0.01843363]
 [ 0.003418    0.98282194]]
model -- 0 new_Q: 
[[5.94270094 0.77782128]
 [0.7898331  5.3621939 ]]
model -- 1 new_Q: 
[[121.71389478 -18.96311316]
 [-18.94563377 121.92394103]]
model -- 0 new_H: 
[[ 9.95654923e-01  7.62034078e-05]
 [-9.62241671e-03  9.60415849e-01]]
model -- 1 new_H: 
[[ 1.00108275  0.00716723]
 [-0.00662618  0.96587526]]
model -- 0 new_R: 
[[1.00629722 0.22248267]
 [0.22932471 0.84999066]]
model -- 1 new_R: 
[[1.72533802 0.46888293]
 [0.47887532 1.26649607]]
new_Z: 
[[0.60291246 0.39458127]
 [0.60291246 0.39458127]]
-----------------------------------------------
-2504.9169376641103
model -- 0 new_A: 
[[ 1.00348839 -0.03222561]
 [-0.01096173  0.95863961]]
model -- 1 new_A: 
[[ 0.97610109 -0.0184536 ]
 [ 0.00335847  0.98283145]]
model -- 0 new_Q: 
[[5.91909558 0.77936811]
 [0.79146878 5.33434342]]
model -- 1 new_Q: 
[[121.54680048 -

-2502.357807491753
model -- 0 new_A: 
[[ 1.00396453 -0.03275392]
 [-0.01049795  0.95662673]]
model -- 1 new_A: 
[[ 0.97617623 -0.01859273]
 [ 0.00271353  0.98264061]]
model -- 0 new_Q: 
[[5.71767613 0.76605906]
 [0.77904336 5.11127348]]
model -- 1 new_Q: 
[[120.12511308 -18.47159664]
 [-18.45280373 121.05554519]]
model -- 0 new_H: 
[[ 9.96450754e-01  7.50815383e-04]
 [-9.07318865e-03  9.60551660e-01]]
model -- 1 new_H: 
[[ 1.00135094  0.0073508 ]
 [-0.00678521  0.96480404]]
model -- 0 new_R: 
[[0.90392429 0.21167356]
 [0.21891149 0.69179174]]
model -- 1 new_R: 
[[1.4873755  0.41823327]
 [0.42881786 0.96059046]]
new_Z: 
[[0.5950825  0.40241123]
 [0.5950825  0.40241123]]
-----------------------------------------------
-2502.1856487155637
model -- 0 new_A: 
[[ 1.00400595 -0.03281183]
 [-0.01045184  0.95651911]]
model -- 1 new_A: 
[[ 0.97617879 -0.01859914]
 [ 0.00265779  0.98260874]]
model -- 0 new_Q: 
[[5.7029366  0.76298916]
 [0.7760448  5.0963493 ]]
model -- 1 new_Q: 
[[120.02352945 -1

-2500.6633585868185
model -- 0 new_A: 
[[ 1.00446286 -0.03356892]
 [-0.00993342  0.95571503]]
model -- 1 new_A: 
[[ 0.97617399 -0.01862266]
 [ 0.00208458  0.98221992]]
model -- 0 new_Q: 
[[5.56094256 0.71889742]
 [0.73264003 4.96422409]]
model -- 1 new_Q: 
[[119.0831433  -18.00644814]
 [-17.98678748 120.40370796]]
model -- 0 new_H: 
[[ 0.99714978  0.00137579]
 [-0.00868623  0.96070126]]
model -- 1 new_H: 
[[ 1.00146932  0.00739602]
 [-0.00691873  0.96405821]]
model -- 0 new_R: 
[[0.82086482 0.19466884]
 [0.20220625 0.58343307]]
model -- 1 new_R: 
[[1.29774433 0.36112084]
 [0.3720687  0.76949696]]
new_Z: 
[[0.58949514 0.40799859]
 [0.58949514 0.40799859]]
-----------------------------------------------
-2500.550205225682
model -- 0 new_A: 
[[ 1.00450479 -0.03364828]
 [-0.009886    0.9556645 ]]
model -- 1 new_A: 
[[ 0.97617122 -0.01862117]
 [ 0.002036    0.98218377]]
model -- 0 new_Q: 
[[5.54934129 0.71425642]
 [0.728053   4.95435792]]
model -- 1 new_Q: 
[[119.00982308 -17.9694104 ]
 [-1

-2499.4824608332506
model -- 0 new_A: 
[[ 1.00496533 -0.03462618]
 [-0.00937255  0.9552318 ]]
model -- 1 new_A: 
[[ 0.97612691 -0.01857201]
 [ 0.001539    0.98180138]]
model -- 0 new_Q: 
[[5.43031505 0.66018037]
 [0.67448814 4.86087497]]
model -- 1 new_Q: 
[[118.29023773 -17.58124018]
 [-17.56102917 119.82468372]]
model -- 0 new_H: 
[[ 0.99771027  0.00192596]
 [-0.0084237   0.96080993]]
model -- 1 new_H: 
[[ 1.00149392  0.00736594]
 [-0.00704391  0.9634945 ]]
model -- 0 new_R: 
[[0.74901647 0.17543084]
 [0.18318045 0.50453416]]
model -- 1 new_R: 
[[1.13995031 0.30705911]
 [0.31821056 0.63943266]]
new_Z: 
[[0.58497627 0.41251746]
 [0.58497627 0.41251746]]
-----------------------------------------------
-2499.3984298481737
model -- 0 new_A: 
[[ 1.00500661 -0.03472373]
 [-0.0093271   0.95520047]]
model -- 1 new_A: 
[[ 0.97612216 -0.01856482]
 [ 0.00149717  0.98176849]]
model -- 0 new_Q: 
[[5.42013357 0.65512703]
 [0.66947437 4.85349016]]
model -- 1 new_Q: 
[[118.23130119 -17.54777878]
 [-

-2498.5808950827377
model -- 0 new_A: 
[[ 1.00544139 -0.03586642]
 [-0.00885077  0.95490956]]
model -- 1 new_A: 
[[ 0.97607025 -0.01846161]
 [ 0.00107471  0.98143167]]
model -- 0 new_Q: 
[[5.31411196 0.60023792]
 [0.6149557  4.7816511 ]]
model -- 1 new_Q: 
[[117.63784325 -17.20131258]
 [-17.18078019 119.28382923]]
model -- 0 new_H: 
[[ 0.99813507  0.00239776]
 [-0.00825248  0.96088559]]
model -- 1 new_H: 
[[ 1.00146467  0.00729525]
 [-0.00716233  0.96304208]]
model -- 0 new_R: 
[[0.68546957 0.15618448]
 [0.16407618 0.44440122]]
model -- 1 new_R: 
[[1.00662239 0.25925511]
 [0.27049717 0.54535422]]
new_Z: 
[[0.58110737 0.41638636]
 [0.58110737 0.41638636]]
-----------------------------------------------
-2498.515184043111
model -- 0 new_A: 
[[ 1.0054785  -0.03597478]
 [-0.00881012  0.9548872 ]]
model -- 1 new_A: 
[[ 0.97606589 -0.01845043]
 [ 0.00103978  0.98140337]]
model -- 0 new_Q: 
[[5.30502134 0.59540998]
 [0.61015605 4.77589225]]
model -- 1 new_Q: 
[[117.58844991 -17.17185205]
 [-1

-2497.8730618328045
model -- 0 new_A: 
[[ 1.0058501  -0.03717781]
 [-0.00839952  0.95467931]]
model -- 1 new_A: 
[[ 9.76025843e-01 -1.83161631e-02]
 [ 6.93953804e-04  9.81117168e-01]]
model -- 0 new_Q: 
[[5.21158018 0.54556194]
 [0.56056874 4.71999684]]
model -- 1 new_Q: 
[[117.09065446 -16.87103054]
 [-16.8503475  118.78703189]]
model -- 0 new_H: 
[[ 0.99844614  0.00279691]
 [-0.00814667  0.96093398]]
model -- 1 new_H: 
[[ 1.00140993  0.00720782]
 [-0.00727096  0.96266773]]
model -- 0 new_R: 
[[0.629137   0.13817842]
 [0.14615804 0.39706864]]
model -- 1 new_R: 
[[0.89392913 0.21860427]
 [0.22985435 0.47451212]]
new_Z: 
[[0.57776282 0.41973091]
 [0.57776282 0.41973091]]
-----------------------------------------------
-2497.8215223533616
model -- 0 new_A: 
[[ 1.00588023 -0.03728615]
 [-0.00836574  0.95466363]]
model -- 1 new_A: 
[[ 9.76023018e-01 -1.83033902e-02]
 [ 6.65957172e-04  9.81093356e-01]]
model -- 0 new_Q: 
[[5.20371757 0.54137721]
 [0.55640353 4.71554746]]
model -- 1 new_Q: 


-2497.321558881849
model -- 0 new_A: 
[[ 1.00616976 -0.03843028]
 [-0.00803503  0.95452376]]
model -- 1 new_A: 
[[ 9.76000877e-01 -1.81651140e-02]
 [ 3.93558066e-04  9.80853728e-01]]
model -- 0 new_Q: 
[[5.12465634 0.49975523]
 [0.51495703 4.67282892]]
model -- 1 new_Q: 
[[116.63762599 -16.59102828]
 [-16.57032469 118.34882172]]
model -- 0 new_H: 
[[ 0.99867276  0.00313159]
 [-0.00808488  0.96095754]]
model -- 1 new_H: 
[[ 1.00134651  0.00711917]
 [-0.00736805  0.9623538 ]]
model -- 0 new_R: 
[[0.57949485 0.12203992]
 [0.13006627 0.35894043]]
model -- 1 new_R: 
[[0.79919405 0.18493159]
 [0.19612824 0.4196531 ]]
new_Z: 
[[0.57490836 0.42258537]
 [0.57490836 0.42258537]]
-----------------------------------------------
-2497.2817946695677
model -- 0 new_A: 
[[ 1.00619237 -0.03852846]
 [-0.00800858  0.95451389]]
model -- 1 new_A: 
[[ 9.75999613e-01 -1.81531093e-02]
 [ 3.71843617e-04  9.80833852e-01]]
model -- 0 new_Q: 
[[5.11815511 0.49638017]
 [0.51159469 4.66946819]]
model -- 1 new_Q: 
[

-2496.899044249313
model -- 0 new_A: 
[[ 1.00640421 -0.03952665]
 [-0.00775414  0.95443276]]
model -- 1 new_A: 
[[ 9.75992267e-01 -1.80317643e-02]
 [ 1.62322249e-04  9.80633927e-01]]
model -- 0 new_Q: 
[[5.054004   0.46361309]
 [0.47893744 4.63746468]]
model -- 1 new_Q: 
[[116.2697858  -16.35740993]
 [-16.33678307 117.97632173]]
model -- 0 new_H: 
[[ 0.99884194  0.00341052]
 [-0.00805077  0.96095834]]
model -- 1 new_H: 
[[ 1.00128244  0.00703773]
 [-0.00745376  0.96208876]]
model -- 0 new_R: 
[[0.53603431 0.10796064]
 [0.11600311 0.32766252]]
model -- 1 new_R: 
[[0.71992023 0.15747562]
 [0.16857434 0.37622101]]
new_Z: 
[[0.5725092  0.42498453]
 [0.5725092  0.42498453]]
-----------------------------------------------
-2496.8688135800317
model -- 0 new_A: 
[[ 1.0064204  -0.03960956]
 [-0.00773408  0.95442761]]
model -- 1 new_A: 
[[ 9.75992102e-01 -1.80218257e-02]
 [ 1.45710593e-04  9.80617345e-01]]
model -- 0 new_Q: 
[[5.04881589 0.46101028]
 [0.47634215 4.63496731]]
model -- 1 new_Q: 
[

-2496.5786257918453
model -- 0 new_A: 
[[ 1.00657052 -0.04043735]
 [-0.00754213  0.95438952]]
model -- 1 new_A: 
[[ 9.75994207e-01 -1.79246805e-02]
 [-1.47015086e-05  9.80450457e-01]]
model -- 0 new_Q: 
[[4.99807112 0.43598788]
 [0.45138027 4.61127678]]
model -- 1 new_Q: 
[[115.97401599 -16.16360544]
 [-16.14312361 117.66613088]]
model -- 0 new_H: 
[[ 0.99897323  0.00364307]
 [-0.00803348  0.96094028]]
model -- 1 new_H: 
[[ 1.00122113  0.00696699]
 [-0.00752922  0.96186362]]
model -- 0 new_R: 
[[0.49811519 0.09585512]
 [0.10389179 0.30159198]]
model -- 1 new_R: 
[[0.65361498 0.13522209]
 [0.14619333 0.34115369]]
new_Z: 
[[0.57050717 0.42698656]
 [0.57050717 0.42698656]]
-----------------------------------------------
-2496.5557216591787
model -- 0 new_A: 
[[ 1.00658191 -0.04050528]
 [-0.00752702  0.95438749]]
model -- 1 new_A: 
[[ 9.75994686e-01 -1.79169097e-02]
 [-2.74684926e-05  9.80436598e-01]]
model -- 0 new_Q: 
[[4.99399083 0.43401216]
 [0.44940826 4.6094308 ]]
model -- 1 new_Q: 


-2496.335365827842
model -- 0 new_A: 
[[ 1.00668757 -0.04118071]
 [-0.00738187  0.95437667]]
model -- 1 new_A: 
[[ 9.76002025e-01 -1.78417196e-02]
 [-1.51616169e-04  9.80296709e-01]]
model -- 0 new_Q: 
[[4.95411268 0.4150259 ]
 [0.43044695 4.59187943]]
model -- 1 new_Q: 
[[115.73510401 -16.00242936]
 [-15.98213582 117.40867511]]
model -- 0 new_H: 
[[ 0.99907892  0.00383877]
 [-0.00802641  0.96090872]]
model -- 1 new_H: 
[[ 1.00116403  0.00690741]
 [-0.00759569  0.96167081]]
model -- 0 new_R: 
[[0.46501876 0.08550167]
 [0.09351775 0.27954025]]
model -- 1 new_R: 
[[0.59793869 0.11715122]
 [0.12797808 0.31231653]]
new_Z: 
[[0.56883169 0.42866204]
 [0.56883169 0.42866204]]
-----------------------------------------------
-2496.317906689054
model -- 0 new_A: 
[[ 1.00669561 -0.0412361 ]
 [-0.00737037  0.95437655]]
model -- 1 new_A: 
[[ 9.76002838e-01 -1.78357351e-02]
 [-1.61584592e-04  9.80285048e-01]]
model -- 0 new_Q: 
[[4.95090124 0.41352373]
 [0.4289458  4.59050667]]
model -- 1 new_Q: 
[[

-2496.1491329882365
model -- 0 new_A: 
[[ 1.00677052 -0.04178881]
 [-0.00725899  0.95438173]]
model -- 1 new_A: 
[[ 9.76012755e-01 -1.77776607e-02]
 [-2.59437188e-04  9.80166853e-01]]
model -- 0 new_Q: 
[[4.91939984 0.39901869]
 [0.41444122 4.57739825]]
model -- 1 new_Q: 
[[115.53955361 -15.86740732]
 [-15.84732569 117.19353742]]
model -- 0 new_H: 
[[ 0.99916647  0.00400577]
 [-0.00802565  0.96086866]]
model -- 1 new_H: 
[[ 1.0011117   0.00685812]
 [-0.00765438  0.96150416]]
model -- 0 new_R: 
[[0.4360384  0.07663973]
 [0.08462596 0.26063347]]
model -- 1 new_R: 
[[0.55085861 0.10237276]
 [0.11304805 0.288195  ]]
new_Z: 
[[0.56741701 0.43007673]
 [0.56741701 0.43007673]]
-----------------------------------------------
-2496.1356850555876
model -- 0 new_A: 
[[ 1.00677626 -0.04183438]
 [-0.00725009  0.95438268]]
model -- 1 new_A: 
[[ 9.76013721e-01 -1.77730103e-02]
 [-2.67376312e-04  9.80156954e-01]]
model -- 0 new_Q: 
[[4.9168505  0.39786386]
 [0.41328559 4.57636772]]
model -- 1 new_Q: 


-2496.0049415433696
model -- 0 new_A: 
[[ 1.00683005 -0.04229228]
 [-0.00716296  0.95439657]]
model -- 1 new_A: 
[[ 9.76024710e-01 -1.77274765e-02]
 [-3.46079825e-04  9.80056136e-01]]
model -- 0 new_Q: 
[[4.89170094 0.38663486]
 [0.40204094 4.56647725]]
model -- 1 new_Q: 
[[115.37689392 -15.75323494]
 [-15.73337486 117.01181293]]
model -- 0 new_H: 
[[ 0.99924058  0.00415052]
 [-0.00802884  0.96082392]]
model -- 1 new_H: 
[[ 1.00106418  0.00681775]
 [-0.00770639  0.96135879]]
model -- 0 new_R: 
[[0.41053582 0.06902085]
 [0.07697201 0.24422155]]
model -- 1 new_R: 
[[0.51070138 0.09016685]
 [0.10069007 0.2677015 ]]
new_Z: 
[[0.56620927 0.43128446]
 [0.56620927 0.43128446]]
-----------------------------------------------
-2495.9944592574966
model -- 0 new_A: 
[[ 1.0068342  -0.04233033]
 [-0.00715592  0.95439809]]
model -- 1 new_A: 
[[ 9.76025729e-01 -1.77237906e-02]
 [-3.52531489e-04  9.80047647e-01]]
model -- 0 new_Q: 
[[4.88965254 0.38573376]
 [0.40113788 4.5656958 ]]
model -- 1 new_Q: 


-2495.8919761460565
model -- 0 new_A: 
[[ 1.00687335 -0.04271536]
 [-0.00708637  0.95441642]]
model -- 1 new_A: 
[[ 9.76036960e-01 -1.76873327e-02]
 [-4.17083025e-04  9.79960762e-01]]
model -- 0 new_Q: 
[[4.86932708 0.37690408]
 [0.39228247 4.55816002]]
model -- 1 new_Q: 
[[115.23940592 -15.65572628]
 [-15.63608818 116.85650296]]
model -- 0 new_H: 
[[ 0.99930437  0.00427785]
 [-0.00803451  0.96077717]]
model -- 1 new_H: 
[[ 1.00102121  0.00678494]
 [-0.00775267  0.96123087]]
model -- 0 new_R: 
[[0.38796173 0.06242874]
 [0.07034238 0.22981553]]
model -- 1 new_R: 
[[0.4761307  0.07997344]
 [0.09034829 0.25004535]]
new_Z: 
[[0.56516656 0.43232717]
 [0.56516656 0.43232717]]
-----------------------------------------------
-2495.883713598475
model -- 0 new_A: 
[[ 1.00687639 -0.0427476 ]
 [-0.0070807   0.95441819]]
model -- 1 new_A: 
[[ 9.76037975e-01 -1.76843510e-02]
 [-4.22423637e-04  9.79953412e-01]]
model -- 0 new_Q: 
[[4.86766254 0.37619034]
 [0.39156607 4.55756166]]
model -- 1 new_Q: 
[

-2495.8024970861707
model -- 0 new_A: 
[[ 1.00690524 -0.04307644]
 [-0.00702414  0.95443834]]
model -- 1 new_A: 
[[ 9.76049003e-01 -1.76544931e-02]
 [-4.76309332e-04  9.79877806e-01]]
model -- 0 new_Q: 
[[4.85103602 0.36914789]
 [0.38449199 4.55176638]]
model -- 1 new_Q: 
[[115.12147936 -15.57163877]
 [-15.55221746 116.72226693]]
model -- 0 new_H: 
[[ 0.99936004  0.00439134]
 [-0.0080417   0.96073011]]
model -- 1 new_H: 
[[ 1.00098242  0.00675844]
 [-0.00779403  0.96111742]]
model -- 0 new_R: 
[[0.36785534 0.05668323]
 [0.0645587  0.21704339]]
model -- 1 new_R: 
[[0.4460936  0.07136409]
 [0.08159666 0.23464381]]
new_Z: 
[[0.56425676 0.43323697]
 [0.56425676 0.43323697]]
-----------------------------------------------
-2495.7959094128823
model -- 0 new_A: 
[[ 1.0069075  -0.04310419]
 [-0.00701947  0.95444022]]
model -- 1 new_A: 
[[ 9.76049989e-01 -1.76520195e-02]
 [-4.80806942e-04  9.79871373e-01]]
model -- 0 new_Q: 
[[4.8496636  0.36857384]
 [0.38391489 4.55130382]]
model -- 1 new_Q: 


-2495.730831161418
model -- 0 new_A: 
[[ 1.00692906 -0.04338917]
 [-0.00697262  0.9544611 ]]
model -- 1 new_A: 
[[ 9.76060606e-01 -1.76270081e-02]
 [-5.26533835e-04  9.79804875e-01]]
model -- 0 new_Q: 
[[4.83586953 0.36286891]
 [0.37817522 4.5468061 ]]
model -- 1 new_Q: 
[[115.01887971 -15.49842407]
 [-15.47921112 116.60490344]]
model -- 0 new_H: 
[[ 0.99940919  0.00449364]
 [-0.00804977  0.96068383]]
model -- 1 new_H: 
[[ 1.0009474   0.00673721]
 [-0.00783117  0.96101611]]
model -- 0 new_R: 
[[0.34983555 0.05163738]
 [0.05947519 0.20561864]]
model -- 1 new_R: 
[[0.4197643  0.06401233]
 [0.07410991 0.22106101]]
new_Z: 
[[0.56345467 0.43403907]
 [0.56345467 0.43403907]]
-----------------------------------------------
-2495.7255267064706
model -- 0 new_A: 
[[ 1.00693075 -0.04341338]
 [-0.00696872  0.95446301]]
model -- 1 new_A: 
[[ 9.76061548e-01 -1.76249149e-02]
 [-5.30379152e-04  9.79799191e-01]]
model -- 0 new_Q: 
[[4.83472394 0.36240056]
 [0.37770364 4.54644583]]
model -- 1 new_Q: 
[

-2495.672907854605
model -- 0 new_A: 
[[ 1.00694702 -0.04366348]
 [-0.00692929  0.95448388]]
model -- 1 new_A: 
[[ 9.76071650e-01 -1.76035586e-02]
 [-5.69731946e-04  9.79740168e-01]]
model -- 0 new_Q: 
[[4.82314659 0.35771658]
 [0.37298367 4.54293378]]
model -- 1 new_Q: 
[[114.92854691 -15.43411787]
 [-15.41510324 116.50127141]]
model -- 0 new_H: 
[[ 0.99945301  0.00458674]
 [-0.00805831  0.96063898]]
model -- 1 new_H: 
[[ 1.00091573  0.00672038]
 [-0.00786466  0.96092509]]
model -- 0 new_R: 
[[0.33358898 0.04717265]
 [0.05497398 0.19531791]]
model -- 1 new_R: 
[[0.39649385 0.05766903]
 [0.06763933 0.20896606]]
new_Z: 
[[0.56274113 0.4347526 ]
 [0.56274113 0.4347526 ]]
-----------------------------------------------
-2495.6686011568563
model -- 0 new_A: 
[[ 1.0069483  -0.04368485]
 [-0.00692599  0.95448576]]
model -- 1 new_A: 
[[ 9.76072543e-01 -1.76017552e-02]
 [-5.73062931e-04  9.79735100e-01]]
model -- 0 new_Q: 
[[4.82217984 0.35732951]
 [0.37259332 4.54265182]]
model -- 1 new_Q: 
[

-2495.625739989075
model -- 0 new_A: 
[[ 1.00696071 -0.04390682]
 [-0.00689233  0.95450625]]
model -- 1 new_A: 
[[ 9.76082084e-01 -1.75832052e-02]
 [-6.07349958e-04  9.79682275e-01]]
model -- 0 new_Q: 
[[4.81236654 0.35343649]
 [0.36866427 4.53990143]]
model -- 1 new_Q: 
[[114.84819993 -15.37718602]
 [-15.35835906 116.40896557]]
model -- 0 new_H: 
[[ 0.99949245  0.00467219]
 [-0.00806701  0.96059594]]
model -- 1 new_H: 
[[ 1.00088703  0.00670724]
 [-0.00789497  0.96084288]]
model -- 0 new_R: 
[[0.31885769 0.04319351]
 [0.05095994 0.18596494]]
model -- 1 new_R: 
[[0.37576907 0.05214274]
 [0.06199347 0.19810404]]
new_Z: 
[[0.56210134 0.4353924 ]
 [0.56210134 0.4353924 ]]
-----------------------------------------------
-2495.622220583498
model -- 0 new_A: 
[[ 1.0069617  -0.04392588]
 [-0.00688949  0.95450808]]
model -- 1 new_A: 
[[ 9.76082925e-01 -1.75816266e-02]
 [-6.10268961e-04  9.79677722e-01]]
model -- 0 new_Q: 
[[4.81154323 0.35311318]
 [0.36833771 4.5396806 ]]
model -- 1 new_Q: 
[[

-2495.5870863163645
model -- 0 new_A: 
[[ 1.00697122 -0.04412475]
 [-0.00686039  0.95452794]]
model -- 1 new_A: 
[[ 9.76091907e-01 -1.75652847e-02]
 [-6.40467134e-04  9.79630092e-01]]
model -- 0 new_Q: 
[[4.80314388 0.34984441]
 [0.3650336  4.53752428]]
model -- 1 new_Q: 
[[114.77609447 -15.32641668]
 [-15.30776672 116.32609458]]
model -- 0 new_H: 
[[ 0.99952823  0.00475119]
 [-0.0080757   0.96055488]]
model -- 1 new_H: 
[[ 1.00086098  0.00669718]
 [-0.00792252  0.96076827]]
model -- 0 new_R: 
[[0.3054288  0.0396229 ]
 [0.04735617 0.17741898]]
model -- 1 new_R: 
[[0.3571812  0.04728515]
 [0.05702371 0.18827575]]
new_Z: 
[[0.56152354 0.4359702 ]
 [0.56152354 0.4359702 ]]
-----------------------------------------------
-2495.584192604559
model -- 0 new_A: 
[[ 1.00697197 -0.04414191]
 [-0.00685792  0.95452972]]
model -- 1 new_A: 
[[ 9.76092699e-01 -1.75638848e-02]
 [-6.43050801e-04  9.79625972e-01]]
model -- 0 new_Q: 
[[4.8024357  0.34957133]
 [0.36475736 4.53735101]]
model -- 1 new_Q: 
[

-2495.555238642726
model -- 0 new_A: 
[[ 1.00697928 -0.04432162]
 [-0.00683251  0.95454883]]
model -- 1 new_A: 
[[ 9.76101150e-01 -1.75493076e-02]
 [-6.69899681e-04  9.79582737e-01]]
model -- 0 new_Q: 
[[4.79518106 0.34679747]
 [0.36194929 4.53566076]]
model -- 1 new_Q: 
[[114.71085829 -15.2808376 ]
 [-15.26235432 116.25113883]]
model -- 0 new_H: 
[[ 0.99956092  0.00482469]
 [-0.00808423  0.9605159 ]]
model -- 1 new_H: 
[[ 1.00083726  0.00668974]
 [-0.00794764  0.96070029]]
model -- 0 new_R: 
[[0.29312612 0.03639841]
 [0.04410033 0.16956649]]
model -- 1 new_R: 
[[0.34040219 0.04298032]
 [0.05261369 0.17932353]]
new_Z: 
[[0.56099827 0.43649546]
 [0.56099827 0.43649546]]
-----------------------------------------------
-2495.5528488030177
model -- 0 new_A: 
[[ 1.00697986 -0.04433719]
 [-0.00683034  0.95455054]]
model -- 1 new_A: 
[[ 9.76101896e-01 -1.75480518e-02]
 [-6.72206944e-04  9.79578985e-01]]
model -- 0 new_Q: 
[[4.79456688 0.34656465]
 [0.36171342 4.53552518]]
model -- 1 new_Q: 
[

Filter model. Testing with learnt parameters from EMGPB2.

In [None]:
em_transition_model_1 = LinearGaussianTimeInvariantTransitionModel(transition_matrix=new_models[0].A, covariance_matrix=new_models[0].Q)
em_transition_model_2 = LinearGaussianTimeInvariantTransitionModel(transition_matrix=new_models[1].A, covariance_matrix=new_models[1].Q)
em_measurement_noise_1 = new_models[0].R
em_measurement_noise_2 = new_models[1].R

em_measurement_model_1 = LinearGaussian(ndim_state=2, mapping=[0, 1], 
                                        noise_covar=em_measurement_noise_1)
em_measurement_model_2 = LinearGaussian(ndim_state=2, mapping=[0, 1], 
                                        noise_covar=em_measurement_noise_2)

em_model_transition_matrix = np.array([[0.5, 0.5],
                                       [0.5, 0.5]])
em_predictor_1 = KalmanPredictor(em_transition_model_1)
em_predictor_2 = KalmanPredictor(em_transition_model_2)
em_imm_predictor = IMMPredictor([em_predictor_1, em_predictor_2], em_model_transition_matrix)

em_updater_1 = KalmanUpdater(em_measurement_model_1)
em_updater_2 = KalmanUpdater(em_measurement_model_2)
em_imm_updater = IMMUpdater([em_updater_1, em_updater_2], em_model_transition_matrix)

em_state_init = WeightedGaussianState(StateVector([[0], [0]]),
                                      CovarianceMatrix(np.diag([5.0, 5.0])),
                                      timestamp=timestamp_init,
                                      weight=0.5)

prior1 = copy(em_state_init)
prior = GaussianMixtureState([prior1, prior1])
track = Track([copy(prior)])
track_error = []

for i in range(1, len(detection_track)):
    measurement = detection_track[i]
    
    # State prediction
    prediction = em_imm_predictor.predict(track.state, timestamp=gt_time[i])
    # Measurement prediction
    meas_prediction = em_imm_updater.predict_measurement(prediction)
    # Association
    hyp = SingleHypothesis(prediction, measurement)
    # State update
    prior = em_imm_updater.update(hyp)
    track.append(prior)
    track_error.append(em_measurement_model_1.function(track[i].state_vector, np.zeros((2,1)))
                       - em_measurement_model_1.function(gt_track[i].state_vector, np.zeros((2,1))))

track_error = np.asarray(track_error)
track_error = np.squeeze(track_error)
rmse_tmp = []
for track_error_ in track_error:
    rmse_tmp.append(track_error_ @ track_error_.T)
rmse = np.sqrt(np.mean(np.array(rmse_tmp)))
print("RMSE: " + str(rmse))