In [None]:
from typing import override
from filterpy.kalman import ExtendedKalmanFilter
import numpy as np
from scipy.spatial.transform import Rotation

def skew_symmetric(v):
    """Return a skew-symmetric matrix built from ``v``.

    * If ``v`` has exactly three elements (shape ``(3,)``, ``(3, 1)`` or
      ``(1, 3)``), return the 3x3 cross-product matrix ``[v]x`` such that
      ``skew_symmetric(v) @ w == np.cross(v, w)``.
    * Otherwise, return the skew-symmetric part ``(v - v.T) / 2`` of ``v``.

    The previous implementation only had the second branch, which silently
    yields all zeros for a 1-D vector because ``v.T`` is ``v`` itself.
    """
    v = np.asarray(v)
    if v.size == 3:
        x, y, z = v.reshape(3)
        return np.array(
            [
                [0.0, -z, y],
                [z, 0.0, -x],
                [-y, x, 0.0],
            ],
            dtype=float,
        )
    return (v - v.T) / 2

class FlightFilter:
    def __init__(self, x_nom, P, sigma_a_noise, sigma_w_noise, sigma_a_walk, sigma_w_walk):
        self.f = ExtendedKalmanFilter(dim_x=3*6, dim_z=3, dim_u=3*2)
        self.f.x = np.zeros(3*6)
        self.f.P = P
        self.sigma_a_noise = sigma_a_noise
        self.sigma_w_noise = sigma_w_noise
        self.sigma_a_walk = sigma_a_walk
        self.sigma_w_walk = sigma_w_walk
        self.x_nom = x_nom
    
    def get_rotation_matrix(self):
        q = self.x_nom[6:10]
        return Rotation.from_quat(q, scalar_first=True).as_matrix()
    
    def get_F_x(self, u, dt):
        a_m = u[0:3]
        w_m = u[3:6]
        a_b = self.x_nom[10:13]
        w_b = self.x_nom[13:16]
        R = self.get_rotation_matrix()
        F_x = np.eye(3*6)
        F_x[0:3,3:6] = np.eye(3)*dt

        F_x[3:6,6:9] = -R @ skew_symmetric(a_m-a_b)*dt
        F_x[3:6,9:12] = -R * dt
        F_x[3:6,15:18] = np.eye(3)*dt
        
        F_x[6:9,6:9] = Rotation.from_rotvec((w_m-w_b)*dt).as_matrix().T
        F_x[6:9,12:15] = -np.eye(3)*dt
        return F_x
    
    def predict(self, u, dt):
        V_i = np.eye(3)*self.sigma_a_noise*dt**2
        Theta_i = np.eye(3)*self.sigma_w_noise*dt**2
        A_i = np.eye(3)*self.sigma_a_walk*dt
        Omega_i = np.eye(3)*self.sigma_w_walk*dt
        F_i = np.vstack((
            np.zeros((1,3*4)),
            np.eye(3*4),
            np.zeros((1,3*4)),
            ))
        Q_i = np.diag([V_i, Theta_i, A_i, Omega_i])
        self.f.Q = F_i @ Q_i @ F_i.T
        F_x = self.get_F_x(u, dt)

        # update error state: x, P
        # no update to x, since it's the error state with mean of 0 always
        self.f.P = F_x @ self.f.P @ F_x.T + self.f.Q
        
        # update nominal state: x_nom
        # euler's approximation
        R = self.get_rotation_matrix()
        a_m = u[0:3]
        w_m = u[3:6]
        a_b = self.x_nom[10:13]
        w_b = self.x_nom[13:16]
        g = self.x_nom[16:19]
        v = self.x_nom[3:6]
        self.x_nom[0:3] += v*dt + 0.5*(R@(a_m-a_b)+g)*dt**2
        self.x_nom[3:6] += (R@(a_m-a_b)+g)*dt
        self.x_nom[6:10] = (Rotation.from_quat(self.x_nom[6:10], scalar_first=True) * Rotation.from_rotvec((w_m-w_b)*dt)).as_quat(scalar_first=True)

    def get_X_dx(self):
        qw, qx, qy, qz = self.x_nom[6:10]
        Q_dtheta = 0.5 * np.array([
            [-qx, -qy, -qz],
            [qw, -qz, qy],
            [qz, qw, -qx],
            [-qy, qx, qw]
        ])
        X_dx = np.zeros((19, 18))
        X_dx[0:6,0:6] = np.eye(6)
        X_dx[6:10,6:9] = Q_dtheta
        X_dx[10:19,9:18] = np.eye(9)
        return X_dx

    def update(self, h, z, R, H_x):
        H = H_x @ self.get_X_dx()
        y = z - h(self.x_nom)
        S = H @ self.f.P @ H.T + R
        K = self.f.P @ H.T @ np.linalg.inv(S)
        
        # update error state: x, P
        self.f.x = self.f.x + K @ y
        A = np.eye(18) - K @ H
        self.f.P = A @ self.f.P @ A.T + K @ R @ K.T
        
        # inject error state into nominal state
        self.x_nom[0:6] += self.f.x[0:6]
        self.x_nom[6:10] = (Rotation.from_quat(self.x_nom[6:10], scalar_first=True) * Rotation.from_rotvec(self.f.x[6:9])).as_quat(scalar_first=True)
        self.x_nom[10:19] += self.f.x[9:18]

        # reset error state to 0, adjust x, P to account for injection
        self.f.x = np.zeros(18)
        G = np.eye(18)
        G[6:9,6:9] -= skew_symmetric(0.5*self.f.x[6:9])
        self.f.P = G @ self.f.P @ G.T

