Skip to content

qq751220449/Kalman_Filter_Numpy

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 

Repository files navigation

Implementation of Kalman filter in 30 lines using Numpy.

原始链接:https://github.com/zziz/kalman-filter

修改其中参数不合理的地方,整体并未做太大改变,具体参考代码。

Running: python kalman-filter.py

import numpy as np


class KalmanFilter(object):

    def __init__(self, F=None, B=None, H=None, Q=None, R=None, P=None, x0=None):

        if (F is None or H is None):
            raise ValueError("Set proper system dynamics.")

        self.n = F.shape[1]
        self.m = H.shape[0]

        self.F = F              # 状态转移矩阵 大小 n*n  n表示选取的状态变量
        self.H = H              # 测量数据的大小 大小为 m*n  m是测量变量的大小
        self.B = 0 if B is None else B                          # 输入矩阵 n*k
        self.Q = np.eye(self.n) if Q is None else Q             # 系统噪声协方差矩阵 n*n
        self.R = np.eye(self.m) if R is None else R             # 测量噪声协方差矩阵 大小跟测量变量有关系 m*m
        self.P = np.eye(self.n) if P is None else P             # 状态变量误差协方差矩阵  大小 n*n
        self.x = np.zeros((self.n, 1)) if x0 is None else x0    # 状态变量 大小 n*1

    def predict(self, u=0):
        self.x = np.dot(self.F, self.x) + np.dot(self.B, u)
        self.P = np.dot(np.dot(self.F, self.P), self.F.T) + self.Q
        return self.x

    def update(self, z):
        y = z - np.dot(self.H, self.x)
        S = self.R + np.dot(self.H, np.dot(self.P, self.H.T))
        K = np.dot(np.dot(self.P, self.H.T), np.linalg.inv(S))
        self.x = self.x + np.dot(K, y)                          # 计算当前时刻的估计值
        I = np.eye(self.n)
        self.P = np.dot(I - np.dot(K, self.H), self.P)
        return self.x


def example():
    dt = 1.0 / 60
    F = np.array([[1, dt, 0], [0, 1, dt], [0, 0, 1]])
    H = np.array([1, 0, 0]).reshape(1, 3)                   # 这里测量变量只有一个
    Q = np.array([[0.05, 0.05, 0.0], [0.05, 0.05, 0.0], [0.0, 0.0, 0.0]])
    R = np.array([0.5]).reshape(1, 1)

    x = np.linspace(-10, 10, 1000)
    measurements = - (x ** 2 + 2 * x - 2) + np.random.normal(0, 2, 1000)     # 模拟测量值z
    src_array = - (x ** 2 + 2 * x - 2)

    kf = KalmanFilter(F=F, H=H, Q=Q, R=R)
    predictions = []

    for z in measurements:
        kf.predict()
        predictions.append(np.dot(kf.H, kf.update(z))[0])

    import matplotlib.pyplot as plt
    plt.plot(range(len(measurements)), measurements, label='Measurements')
    plt.plot(range(len(src_array)), np.array(src_array), label='src single')
    plt.plot(range(len(predictions)), np.array(predictions), label='Kalman Filter Correction')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    example()

Output

Result

Result

可以看出跟踪时间越久,噪声干扰越小。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Languages