In [None]:
import matplotlib.pyplot as plt
import numpy as np

%load_ext autoreload
%aimport kalman_filter

In [None]:
def fquad(A, B, C, tvals):
    """Evaluate the quadratic ``A + B*t + C*t**2`` elementwise over ``tvals``.

    Parameters
    ----------
    A, B, C : scalar or array-like
        Constant, linear and quadratic coefficients.
    tvals : scalar or np.ndarray
        Evaluation point(s).

    Returns
    -------
    scalar or np.ndarray
        The polynomial value(s), broadcast against ``tvals``.
    """
    constant_and_linear = A + B * tvals
    quadratic_term = C * tvals**2
    return constant_and_linear + quadratic_term


def fquad_prime(A, B, C, tvals):
    """Derivative of ``fquad`` with respect to ``t``: ``B + 2*C*t``.

    ``A`` is accepted but unused, so this function can be called with the
    same argument list as ``fquad`` (a constant term has zero slope).

    Returns
    -------
    scalar or np.ndarray
        The slope value(s), broadcast against ``tvals``.
    """
    twice_curvature = 2 * C
    return B + twice_curvature * tvals


In [None]:
def make_column(fn, x0, xL, xQ, xstd, tvals):
    """Evaluate ``fn(x0, xL, xQ, tvals)`` and add i.i.d. Gaussian noise.

    Parameters
    ----------
    fn : callable
        Called as ``fn(x0, xL, xQ, tvals)``, e.g. ``fquad`` or ``fquad_prime``.
    x0, xL, xQ :
        Coefficients forwarded unchanged to ``fn``.
    xstd : float
        Standard deviation of the additive noise. Passing 0 yields the
        noise-free curve, but one normal draw per sample is still consumed
        from the global RNG.
    tvals : np.ndarray
        Sample times; ``tvals.shape[0]`` sets the number of noise draws.

    Returns
    -------
    np.ndarray
        ``fn(...)`` plus ``xstd * N(0, 1)`` noise from NumPy's global RNG.
    """
    clean_signal = fn(x0, xL, xQ, tvals)
    sample_count = tvals.shape[0]
    noise = xstd * np.random.randn(sample_count)
    return clean_signal + noise


In [None]:
def make_data(timestart, timestop, deltat, tstd, x0, xL, xQ, xstd, y0, yL, yQ, ystd):
    """Simulate noisy 2-D measurements of a point moving along quadratic paths.

    Positions are evaluated at the true sample times
    ``np.arange(timestart, timestop, deltat)`` and perturbed with Gaussian
    noise (``xstd`` / ``ystd``). The recorded timestamps are the true times
    plus a non-negative jitter ``|N(0, 1)| * tstd``. The true (noise-free)
    velocities are evaluated at the same true times as the positions.

    Parameters
    ----------
    timestart, timestop, deltat : float
        Range and step of the true sample times (``timestop`` excluded).
    tstd : float
        Scale of the non-negative timestamp jitter.
    x0, xL, xQ / y0, yL, yQ : float
        Constant, linear and quadratic coefficients of x(t) and y(t).
    xstd, ystd : float
        Standard deviations of the position measurement noise.

    Returns
    -------
    np.ndarray, shape (n_samples, 5)
        Columns: measured x, measured y, recorded time, true dx/dt, true dy/dt.
    """
    true_t = np.arange(timestart, timestop, deltat, dtype=float)
    xvals = make_column(fquad, x0, xL, xQ, xstd, true_t)
    yvals = make_column(fquad, y0, yL, yQ, ystd, true_t)
    # Recorded time = true time + jitter. (The previous `tvals += tvals + ...`
    # doubled every timestamp and then evaluated the velocities at those
    # corrupted times.)
    recorded_t = true_t + np.abs(np.random.randn(true_t.shape[0])) * tstd
    # Ground-truth velocities at the times the positions were sampled at.
    xprime = fquad_prime(x0, xL, xQ, true_t)
    yprime = fquad_prime(y0, yL, yQ, true_t)
    return np.column_stack([xvals, yvals, recorded_t, xprime, yprime])


In [None]:
# Simulation settings: true sample times run over [timestart, timestop) in
# steps of deltat; recorded timestamps get a non-negative jitter scaled by tstd.
timestart = 0
timestop = 10
deltat = 0.25
tstd = 0.1 * deltat

# x(t) = x0 + xL*t + xQ*t**2, measured with Gaussian noise of std xstd.
x0 = 200
xL = 5
xQ = 2.5
xstd = 0.1

# y(t) = y0 + yL*t + yQ*t**2, measured with Gaussian noise of std ystd.
y0 = 600
yL = -2.5
yQ = 1.2
ystd = 0.1

# make_data draws its noise from NumPy's global RNG; seed it so that
# Restart & Run All reproduces the same dataset and figures.
SEED = 0
np.random.seed(SEED)

dataset = make_data(
    timestart, timestop, deltat, tstd, x0, xL, xQ, xstd, y0, yL, yQ, ystd
)


In [None]:
# Per-channel measurement-noise estimate. It is currently unused: the filter
# below is built with its default settings.
# NOTE(review): entries 0-1 are variances (std**2) while entries 2-3 are
# standard deviations (sqrt of summed variances) -- confirm which form
# kalman_filter.Kalman2D expects before passing this in.
measurement_error = np.array(
    [
        xstd**2,
        ystd**2,
        np.sqrt(xstd**2 + tstd**2),
        np.sqrt(ystd**2 + tstd**2),
    ]
)

initial_state = np.zeros(3)  # NOTE(review): not used by any cell below

# kf = kalman_filter.Kalman2D(measurement_error)
# Passing the measurement-error estimate is optional; the filter's defaults
# work acceptably for this dataset.
kf = kalman_filter.Kalman2D()

# Collect one state vector per measurement and concatenate once at the end.
# Growing an array with np.concatenate inside the loop is quadratic.
state_rows = []
for datum in dataset:
    # Only the measured columns (x, y, recorded time) are fed to the filter;
    # the true velocities in columns 3-4 are kept for comparison only.
    kf.update(datum[:3])
    state_rows.append(
        np.concatenate([kf.get_position(), kf.get_time(), kf.get_velocity()])
    )

# Rows are laid out as [position..., time, velocity...]. The reshape assumes
# 2 + 1 + 2 = 5 entries, matching the dataset columns compared in the plots.
state_history = (
    np.concatenate(state_rows) if state_rows else np.array([])
).reshape(-1, 5)


In [None]:
# Measured vs. filtered position components, each plotted against its own
# (recorded / estimated) time axis.
fig, ax = plt.subplots()
ax.plot(dataset[:, 2], dataset[:, 0], "o", label="x input")
ax.plot(state_history[:, 2], state_history[:, 0], "+", label="x fit")
ax.plot(dataset[:, 2], dataset[:, 1], "2", label="y input")
ax.plot(state_history[:, 2], state_history[:, 1], "x", label="y fit")
ax.set(
    title="Position: measurements vs. Kalman estimate",
    xlabel="time",
    ylabel="position",
)
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# 2-D trajectory: measured (x, y) points vs. the filter's position estimates.
fig, ax = plt.subplots()
ax.plot(dataset[:, 0], dataset[:, 1], "3", label="input")
ax.plot(state_history[:, 0], state_history[:, 1], "4", label="fit")
# Zoom to the measured trajectory plus a 25% margin, so that filter estimates
# far from the data (presumably its start-up transient) don't dominate the
# axes. Deriving the limits from the data, instead of hard-coding them, keeps
# the view correct when the simulation parameters change.
x_meas, y_meas = dataset[:, 0], dataset[:, 1]
x_pad = 0.25 * np.ptp(x_meas) or 1.0
y_pad = 0.25 * np.ptp(y_meas) or 1.0
ax.set_xlim(x_meas.min() - x_pad, x_meas.max() + x_pad)
ax.set_ylim(y_meas.min() - y_pad, y_meas.max() + y_pad)
ax.set(title="Trajectory: measurements vs. Kalman estimate", xlabel="x", ylabel="y")
ax.grid(True)
ax.legend()
plt.show()


In [None]:
# True velocities (noise-free derivatives stored by make_data) vs. the
# filter's velocity estimate. The filter is only given positions and times,
# so the "true" series is ground truth, not an input.
fig, ax = plt.subplots()
ax.plot(dataset[:, 2], dataset[:, 3], "o", label="x velocity (true)")
ax.plot(state_history[:, 2], state_history[:, 3], "+", label="x velocity fit")
ax.plot(dataset[:, 2], dataset[:, 4], "2", label="y velocity (true)")
ax.plot(state_history[:, 2], state_history[:, 4], "x", label="y velocity fit")
ax.set(
    title="Velocity: ground truth vs. Kalman estimate",
    xlabel="time",
    ylabel="velocity",
)
ax.legend()
ax.grid(True)
plt.show()
