In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

In [2]:
class KalmanFilter(object):
    """Linear-Gaussian state-space model: Kalman filter, backward smoother and EM update.

    Model assumed by the formulas in this class::

        z[n] = A z[n-1] + w,   w ~ N(0, Gamma)     (hidden state, dimension M)
        x[n] = C z[n]   + v,   v ~ N(0, Sigma)     (observation, dimension N)
        z[0] ~ N(u0, V0)

    Storage convention (every array is indexed by time step):

    * ``u[i]``, ``V[i]`` -- filtered mean / covariance of z[i].
    * ``P[i]``           -- ``A V[i] A^T + Gamma``, the predicted covariance of z[i+1].
    * ``K[i]``           -- gain built from ``P[i]``; it is the gain applied when
      ``x[i+1]`` is assimilated by ``forward(i + 1)``.
    * ``u_hat[i]``, ``V_hat[i]``, ``J[i]`` -- smoothed mean / covariance and smoother gain.

    The recursions are driven by the caller: ``forward(i)`` for i = 1..T-1, then
    ``u_hat[-1] = u[-1]`` and ``V_hat[-1] = V[-1]``, then ``backward(i)`` for
    i = T-2..0, then ``learning(M, N)``.
    """

    def __init__(self, A = None, C = None, Gamma = None, Sigma = None, P = None, u0 = None, V0 = None,x=None):
        """Store the model, allocate the work arrays and process the first observation.

        Args:
            A: M x M state-transition matrix (required).
            C: N x M emission matrix (required).
            Gamma: M x M transition-noise covariance; identity when omitted.
            Sigma: N x N emission-noise covariance; identity when omitted.
            P: optional M x M matrix used to pre-fill every ``P[i]``; identity when
                omitted. ``P[0]`` is recomputed immediately and ``forward`` overwrites
                the remaining entries.
            u0: length-M prior mean of z[0]; zeros when omitted.
            V0: M x M prior covariance of z[0]; identity when omitted.
            x: T x N array of observations (required, T >= 1).

        Raises:
            ValueError: if A, C or x is missing, or x is not a non-empty 2-D array.
        """
        if(A is None or C is None):
            raise ValueError("Set proper system dynamics.")
        if x is None:
            raise ValueError("Observations x (T x N) are required.")

        self.x = np.asarray(x, dtype=float) # T x N
        if self.x.ndim != 2 or self.x.shape[0] < 1:
            raise ValueError("x must be a 2-D array of shape (T, N) with T >= 1.")

        self.A = np.asarray(A, dtype=float) # state-transition matrix, M x M
        self.C = np.asarray(C, dtype=float) # emission matrix, N x M

        self.M = self.A.shape[0] # dimension of hidden states
        self.T = self.x.shape[0] # number of observations
        self.N = self.x.shape[1] # dimension of each observation

        # Gamma: covariance of the noise added to the hidden-state transition, M x M
        self.Gamma = np.eye(self.M) if Gamma is None else np.asarray(Gamma, dtype=float)
        # Sigma: covariance of the noise added to the emission, N x N.
        # The default has to be N x N: it is added to C V C^T, which is N x N.
        self.Sigma = np.eye(self.N) if Sigma is None else np.asarray(Sigma, dtype=float)

        self.P = np.zeros((self.T, self.M, self.M)) # T x M x M
        self.P[:] = np.eye(self.M) if P is None else P # same matrix broadcast to every time step
        self.u = np.zeros((self.T, self.M)) # T x M
        self.V = np.zeros((self.T, self.M, self.M)) # T x M x M
        self.K = np.zeros((self.T, self.M, self.N)) # T x M x N
        self.c = np.zeros((self.T)) # T; allocated but not filled in by this class

        # for backward passing
        self.u_hat = np.zeros((self.T, self.M)) # T x M
        self.V_hat = np.zeros((self.T, self.M, self.M)) # T x M x M
        self.J = np.zeros((self.T, self.M, self.M)) # T x M x M

        # prior on the first hidden state
        self.u0 = np.zeros(self.M) if u0 is None else np.asarray(u0, dtype=float) # M
        self.V0 = np.eye(self.M) if V0 is None else np.asarray(V0, dtype=float) # M x M

        self._filter_first_step()

    def _filter_first_step(self):
        """Assimilate x[0] starting from the prior N(u0, V0) and predict one step ahead."""
        I = np.eye(self.M)

        # Measurement update for x[0]: the gain is built from the prior covariance V0.
        S_first = self.C @ self.V0 @ self.C.T + self.Sigma
        K_first = self.V0 @ self.C.T @ np.linalg.inv(S_first)
        self.u[0] = self.u0 + K_first @ (self.x[0] - self.C @ self.u0)
        self.V[0] = (I - K_first @ self.C) @ self.V0

        # Prediction for step 1. K[0] must pair P[0] with the innovation covariance
        # built from P[0] (not from V0), because forward(1) uses it as its gain.
        self.P[0] = self.A @ self.V[0] @ self.A.T + self.Gamma
        S_next = self.C @ self.P[0] @ self.C.T + self.Sigma
        self.K[0] = self.P[0] @ self.C.T @ np.linalg.inv(S_next)

    def forward(self,i):
        """Filtering step: compute u[i], V[i], P[i] and K[i] from step i-1 and x[i].

        Args:
            i: time index with 1 <= i < T.

        Raises:
            IndexError: if i is outside that range (i = 0 would otherwise wrap
                around and silently read the last time step).
        """
        if not 1 <= i < self.T:
            raise IndexError(f"forward() expects 1 <= i < {self.T}, got {i}")

        I = np.eye(self.M)
        gain = self.K[i-1]                      # gain for assimilating x[i], built from P[i-1]
        predicted_mean = self.A @ self.u[i-1]   # mean of z[i] before seeing x[i]

        self.V[i] = (I - gain @ self.C) @ self.P[i-1]
        self.P[i] = self.A @ self.V[i] @ self.A.T + self.Gamma
        S_next = self.C @ self.P[i] @ self.C.T + self.Sigma
        self.K[i] = self.P[i] @ self.C.T @ np.linalg.inv(S_next)

        self.u[i] = predicted_mean + gain @ (self.x[i] - self.C @ predicted_mean)

    def backward(self,i):
        """Smoothing step: compute J[i], u_hat[i] and V_hat[i] from step i+1.

        ``u_hat[i+1]`` and ``V_hat[i+1]`` must already hold smoothed values; for the
        last time step the caller copies them from ``u[-1]`` / ``V[-1]``.

        Args:
            i: time index with 0 <= i < T-1.

        Raises:
            IndexError: if i is outside that range.
        """
        if not 0 <= i < self.T - 1:
            raise IndexError(f"backward() expects 0 <= i < {self.T - 1}, got {i}")

        self.J[i] = self.V[i] @ self.A.T @ np.linalg.inv(self.P[i])
        self.u_hat[i] = self.u[i] + self.J[i] @ (self.u_hat[i+1] - self.A @ self.u[i])
        self.V_hat[i] = self.V[i] + self.J[i] @ (self.V_hat[i+1] - self.P[i]) @ self.J[i].T

    def learning(self,M,N):
        """EM M-step: re-estimate u0, V0, A, Gamma, C and Sigma from the smoothed moments.

        Args:
            M: hidden-state dimension (sizes the M x M accumulators).
            N: observation dimension (sizes the N x N and N x M accumulators).

        The noise covariances are evaluated with the *new* A and C, so each of them
        is the maximiser that belongs to the updated dynamics.
        """
        # Copies: u_hat[0] / V_hat[0] are overwritten in place by the next smoothing
        # pass and must not change the stored prior behind the caller's back.
        # cov(z[0]) = V_hat[0] + u_hat[0] u_hat[0]^T - u0 u0^T reduces to V_hat[0].
        self.u0 = self.u_hat[0].copy()
        self.V0 = self.V_hat[0].copy()

        # Second moments over consecutive pairs (n = 1..T-1)
        cross = np.zeros((M,M))     # sum E[z[n] z[n-1]^T]
        lagged = np.zeros((M,M))    # sum E[z[n-1] z[n-1]^T]
        current = np.zeros((M,M))   # sum E[z[n] z[n]^T]
        for i in range(1,self.T,1):
            cross += self.V_hat[i] @ self.J[i-1].T + np.outer(self.u_hat[i], self.u_hat[i-1])
            lagged += self.V_hat[i-1] + np.outer(self.u_hat[i-1], self.u_hat[i-1])
            current += self.V_hat[i] + np.outer(self.u_hat[i], self.u_hat[i])

        # Moments over every time step (n = 0..T-1)
        obs_state = np.zeros((N,M))     # sum x[n] E[z[n]]^T
        state_state = np.zeros((M,M))   # sum E[z[n] z[n]^T]
        obs_obs = np.zeros((N,N))       # sum x[n] x[n]^T
        for i in range(self.T):
            obs_state += np.outer(self.x[i], self.u_hat[i])
            state_state += self.V_hat[i] + np.outer(self.u_hat[i], self.u_hat[i])
            obs_obs += np.outer(self.x[i], self.x[i])

        A_new = cross @ np.linalg.inv(lagged)
        self.A = A_new
        self.Gamma = 1/(self.T-1) * (
            current - A_new @ cross.T - cross @ A_new.T + A_new @ lagged @ A_new.T
        )

        C_new = obs_state @ np.linalg.inv(state_state)
        self.C = C_new
        self.Sigma = 1/self.T * (
            obs_obs - C_new @ obs_state.T - obs_state @ C_new.T + C_new @ state_state @ C_new.T
        )




In [9]:
def generate_examples(A, C, Gamma,Sigma,u0,V0,M,N,T,rng=None):
    """Sample one trajectory from a linear-Gaussian state-space model.

    z[0] ~ N(u0, V0), z[t] ~ N(A z[t-1], Gamma), x[t] ~ N(C z[t], Sigma).

    Args:
        A: M x M transition matrix.
        C: N x M emission matrix.
        Gamma: M x M transition-noise covariance.
        Sigma: N x N emission-noise covariance.
        u0: length-M mean of the first hidden state.
        V0: M x M covariance of the first hidden state.
        M: hidden-state dimension.
        N: observation dimension.
        T: number of time steps.
        rng: optional object exposing ``multivariate_normal(mean, cov)``, e.g.
            ``np.random.default_rng(seed)``. When omitted the global ``np.random``
            state is used, exactly as before.

    Returns:
        (z, x): arrays of shape (T, M) and (T, N).

    Raises:
        ValueError: if u0 / V0 do not match the state dimension M.
    """
    if np.shape(u0) != (M,) or np.shape(V0) != (M, M):
        raise ValueError(
            f"u0 must have shape ({M},) and V0 shape ({M}, {M}); "
            f"got {np.shape(u0)} and {np.shape(V0)}."
        )
    sampler = np.random if rng is None else rng

    z = np.zeros((T,M))
    x = np.zeros((T,N))
    if T == 0:
        # Nothing to draw; return the empty sequences instead of failing on z[0].
        return z,x

    z[0] = sampler.multivariate_normal(u0,V0)
    x[0] = sampler.multivariate_normal(np.matmul(C,z[0]),Sigma)
    for t in range(1,T,1):
        z[t] = sampler.multivariate_normal(np.matmul(A,z[t-1]),Gamma)
        x[t] = sampler.multivariate_normal(np.matmul(C,z[t]),Sigma)
    return z,x


In [4]:
def main():
	"""Simulate a 3-state / 2-observation linear-Gaussian sequence and fit it with EM.

	Each EM iteration runs the forward filter, the backward smoother and the
	parameter update, then re-processes the first observation with the new
	parameters so the next forward pass starts from a consistent state.

	Returns:
		(kf, z, x): the fitted KalmanFilter, the simulated hidden states (T x M)
		and the simulated observations (T x N).
	"""
	n_states = 3 # M
	n_obs = 2 # N
	n_time = 100 # T
	max_iter = 100

	# z: T x M
	# x : T x N

	# Ground-truth parameters used to simulate the data
	A = np.array([[0.75, 0.433, -0.5],[-0.217, 0.875, 0.433],[0.625, -0.217, 0.75]])
	Gamma = np.array([[1.5, 0.1, 0.0], [0.1, 2.0, 0.3], [0.0, 0.3, 1.0]])
	C = np.array([[1.0,1.0,0.0],[0.0,1.0,1.0]])
	Sigma = np.array([[1.0,0.2], [0.2,2.0]])

	# The prior on z[0] has to live in the n_states-dimensional state space and V0
	# has to be a valid covariance. The previous 2-element u0 could not be stored
	# in the 3-element z[0], and the previous 2x2 V0 had a negative eigenvalue.
	# NOTE(review): the mean reuses the values of the commented-out z[0] in
	# generate_examples and V0 is a small isotropic covariance -- confirm.
	u0 = np.array([23.0, 24.0, 25.0])
	V0 = 0.1 * np.eye(n_states)

	# Starting point for EM
	A_init = np.array([[1.0, 1.1, 1.2],[1.3, 1.4, 1.5],[1.6, 1.7, 1.8]])
	C_init = np.array([[1.0,1.0,1.0], [1.0, 1.0,1.0]])
	Gamma_init = np.array([[1.0, 0.5, 0.5], [0.5,1.0, 0.5],[0.5, 0.5, 1.0]])
	Sigma_init = np.array([[1.0,0.5], [0.5,1.0]])
	u0_init = np.array([10.0,10.0,10.0])
	V0_init = np.array([[1.0, 0.5, 0.5], [0.5,1.0, 0.5],[0.5, 0.5, 1.0]])

	z,x = generate_examples(A,C,Gamma,Sigma,u0,V0,n_states,n_obs,n_time)
	kf = KalmanFilter(A = A_init, C = C_init, Gamma = Gamma_init, Sigma = Sigma_init, u0=u0_init, V0=V0_init,x=x)

	# No likelihood-based early stopping: kf.c is never filled in, so the loop
	# always runs max_iter iterations.
	for ite in range(max_iter):
		print(f'The current iteration is: {ite}')

		# E-step: forward filter ...
		for t in range(1,kf.T,1):
			kf.forward(t)
		# ... the smoother starts from the last filtered estimate ...
		kf.u_hat[-1] = kf.u[-1]
		kf.V_hat[-1] = kf.V[-1]
		# ... and runs backwards to the first time step.
		for t in range(kf.T-2,-1,-1):
			kf.backward(t)

		# M-step
		kf.learning(z.shape[1],x.shape[1])

		# Re-process x[0] with the updated parameters. The gain for x[0] is built
		# from the prior covariance V0; K[0] (read by forward(1)) is built from
		# P[0] together with the innovation covariance that belongs to P[0].
		I = np.eye(kf.M)
		S_first = kf.C @ kf.V0 @ kf.C.T + kf.Sigma
		K_first = kf.V0 @ kf.C.T @ np.linalg.inv(S_first)
		kf.u[0] = kf.u0 + K_first @ (kf.x[0] - kf.C @ kf.u0)
		kf.V[0] = (I - K_first @ kf.C) @ kf.V0
		kf.P[0] = kf.A @ kf.V[0] @ kf.A.T + kf.Gamma
		S_next = kf.C @ kf.P[0] @ kf.C.T + kf.Sigma
		kf.K[0] = kf.P[0] @ kf.C.T @ np.linalg.inv(S_next)
	print(kf.A,kf.C,kf.Gamma,kf.Sigma,kf.u0,kf.V0)
	return kf,z,x

# Run the experiment once; keep the fitted model and the simulated data for the cells below.
kf, z, x = main()

The current iteration is: 0
The current iteration is: 1
The current iteration is: 2
The current iteration is: 3
The current iteration is: 4
The current iteration is: 5
The current iteration is: 6
The current iteration is: 7
The current iteration is: 8
The current iteration is: 9
The current iteration is: 10
The current iteration is: 11
The current iteration is: 12
The current iteration is: 13
The current iteration is: 14
The current iteration is: 15
The current iteration is: 16
The current iteration is: 17
The current iteration is: 18
The current iteration is: 19
The current iteration is: 20
The current iteration is: 21
The current iteration is: 22
The current iteration is: 23
The current iteration is: 24
The current iteration is: 25
The current iteration is: 26
The current iteration is: 27
The current iteration is: 28
The current iteration is: 29
The current iteration is: 30
The current iteration is: 31
The current iteration is: 32
The current iteration is: 33
The current iteration is

In [5]:
def is_symmetric_positive_semidefinite(matrix, tol=1e-10):
    """Return True when `matrix` is square, symmetric and positive semidefinite.

    Args:
        matrix: array-like to test.
        tol: relative tolerance for the eigenvalue test. An eigenvalue counts as
            non-negative when it is >= -tol * max(1, largest |eigenvalue|), so
            rounding noise on a rank-deficient covariance is not reported as a
            violation.

    Returns:
        bool: False for non-square or non-symmetric input, otherwise whether all
        eigenvalues are non-negative within the tolerance.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False  # Not square, so it cannot be symmetric
    if not np.allclose(matrix, matrix.T):
        return False  # Not symmetric
    # eigvalsh is the symmetric solver: it always returns real eigenvalues.
    # Averaging with the transpose removes the asymmetry that allclose tolerated.
    eigenvalues = np.linalg.eigvalsh((matrix + matrix.T) / 2.0)
    scale = np.abs(eigenvalues).max(initial=1.0)
    return bool(np.all(eigenvalues >= -tol * scale))

# Sanity check of the helper on a fixed symmetric 2x2 matrix.
B = np.array([
    [35.62538664, 35.93218881],
    [35.93218881, 36.40284164],
])
print(is_symmetric_positive_semidefinite(B))

True


In [10]:
# Draw a fresh sequence from the parameters currently stored on kf.
z_sim, x_sim = generate_examples(
    kf.A, kf.C, kf.Gamma, kf.Sigma, kf.u0, kf.V0, kf.M, kf.N, kf.T
)


In [11]:
# Show the observations simulated from the parameters stored on kf.
print(x_sim)

[[53.33376555 37.72496621]
 [53.40651879 42.13647076]
 [34.78749383 56.67516233]
 [41.28236028 48.33494177]
 [49.39160633 38.89428216]
 [42.75951873 58.20069382]
 [50.69071956 56.50526704]
 [43.83716816 63.74638576]
 [58.89789592 46.28644553]
 [41.10930395 67.44664232]
 [62.0105482  50.55186965]
 [48.46824439 61.18416159]
 [41.52877475 71.1523303 ]
 [51.58214506 74.2226579 ]
 [61.00947297 64.04765823]
 [62.99360901 54.92897604]
 [57.69004052 54.04350087]
 [61.50915989 62.55558352]
 [55.52740723 66.08157237]
 [51.29951426 83.97303629]
 [67.49028317 52.33724024]
 [57.56702214 75.37337452]
 [78.42214865 51.81252899]
 [62.38159046 70.21496891]
 [70.14284323 51.97237941]
 [71.11152481 58.32306589]
 [48.25322348 72.04527403]
 [65.89080913 55.77179616]
 [37.91117171 83.95479088]
 [56.71857345 72.84146212]
 [69.43472668 60.9926026 ]
 [70.97522116 62.14984435]
 [69.16941862 56.21155924]
 [60.23620888 67.44736727]
 [82.06298984 47.15737381]
 [73.07114735 53.17175463]
 [59.74580186 71.27660158]
 

In [8]:
# Show the observation sequence returned by main().
print(x)

[[46.09650138 47.25578821]
 [41.71613588 53.04199868]
 [41.68100982 57.0804984 ]
 [45.75529331 55.97437267]
 [51.27023074 49.82444082]
 [59.95926285 44.61991018]
 [58.54021563 40.22939766]
 [51.08910555 42.87023399]
 [44.4016467  50.97916512]
 [38.70401042 57.50184408]
 [38.76558054 51.77617302]
 [43.13239053 50.76773642]
 [53.73966411 42.54451579]
 [53.88217249 39.54975382]
 [53.57227973 44.51157718]
 [47.94230759 45.50863175]
 [43.41563196 52.54134326]
 [42.46970122 63.57532283]
 [45.82550241 55.08820464]
 [48.47720651 46.31084107]
 [56.66859218 41.80275531]
 [55.94980601 42.85884298]
 [48.57508837 43.31331638]
 [51.70845512 52.71253906]
 [46.67040725 56.87316539]
 [50.3262906  63.04010256]
 [56.48550367 58.9759535 ]
 [60.27302281 53.56644985]
 [66.63694896 53.42976379]
 [63.78739839 49.68368045]
 [57.68140163 55.10752071]
 [54.02297531 60.34041284]
 [53.36567807 66.66757092]
 [52.48641909 64.13874148]
 [55.71990177 63.81519489]
 [62.1021282  55.77480743]
 [64.79673103 50.24061906]
 