In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
from pykalman.datasets import load_robot
from pykalman import KalmanFilter

In [None]:
# Load the MATLAB drift-fit export. The file is expected to provide the
# variables `fitx` and `fity`; the cells below read only `fitx`.
# NOTE(review): absolute Google-Drive (Colab) path — requires the Drive to be
# mounted; consider making this configurable for other environments.
import scipy.io as sio

MAT_PATH = '/content/drive/MyDrive/QPG Research/Velion/SuperRes Data/fitDrift_1ptKalman.mat'
data = sio.loadmat(MAT_PATH)

In [None]:
def gaussian(x, sigma):
    """Evaluate a zero-mean normal probability density.

    Parameters
    ----------
    x : float or numpy.ndarray
        Point(s) at which the density is evaluated.
    sigma : float or numpy.ndarray
        Standard deviation of the distribution (broadcast against ``x``).

    Returns
    -------
    float or numpy.ndarray
        ``exp(-x**2 / (2*sigma**2)) / sqrt(2*pi*sigma**2)``.
    """
    variance = sigma**2
    normalisation = np.sqrt(2 * np.pi * variance)
    return np.exp(-x**2 / (2 * variance)) / normalisation

In [None]:
# --- Build and fit a constant-velocity Kalman model for the x trace ---------

# Column 0 of 'fitx' is used as the position trace.
# NOTE(review): assumes this column is the x position of the tracked point in
# nm (the plot labels below say nm) — confirm against the .mat producer.
x_data = data['fitx'][:, 0]
x = x_data  # short alias used by the rest of the notebook

# Frame index used as the time axis.
ts = np.arange(len(x))

# State = [position, velocity] with a constant-velocity model:
#   pos[t+1] = pos[t] + vel[t],   vel[t+1] = vel[t]
transition_matrix = [[1, 1],
                     [0, 1]]

# No control input, so the state transition has no offset.
transition_offset = [0, 0]

# Both state components are observed directly (position, plus the
# finite-difference velocity built below), hence an identity observation matrix.
observation_matrix = np.eye(2)

# No offset between the state and the observations.
observation_offset = [0, 0]

# Process noise. A tiny value allows almost no random acceleration, so the
# smoothed trace follows only the slow drift.
transition_covariance = np.eye(2) * 1e-5

# Measurement noise. This is only the starting point; EM re-estimates it below.
observation_covariance = np.eye(2)

# Initial-state prior (also re-estimated by EM).
initial_state_mean = [0, 0]
initial_state_covariance = np.eye(2)

kf = KalmanFilter(
    transition_matrices=transition_matrix,
    observation_matrices=observation_matrix,
    transition_covariance=transition_covariance,
    observation_covariance=observation_covariance,
    transition_offsets=transition_offset,
    observation_offsets=observation_offset,
    initial_state_mean=initial_state_mean,
    initial_state_covariance=initial_state_covariance,
    # Only these are learned by EM; the transition model and its (tiny)
    # process noise stay fixed.
    em_vars=['initial_state_mean', 'initial_state_covariance', 'observation_covariance']
)

# Observations: column 0 = measured position, column 1 = frame-to-frame
# position difference used as a velocity measurement. Frame 0 has no previous
# frame, so its velocity entry is left at 0 as a placeholder.
observations = np.zeros((len(x), 2))
observations[:, 0] = x
observations[1:, 1] = np.diff(x)

# Mark missing values (NaN positions, and the differences they contaminate) as
# masked so pykalman treats them as unobserved. Assigning `np.ma.masked` into a
# plain ndarray does NOT mask anything — the NaNs just get overwritten with a
# number (0.0) — so a real masked array has to be built.
observations = np.ma.masked_invalid(observations)

# Run EM one iteration at a time so the log-likelihood after every step is
# recorded; for EM it should not decrease from one iteration to the next.
n_em_iterations = 5
loglikelihoods = np.zeros(n_em_iterations)
for i in range(n_em_iterations):
    kf = kf.em(X=observations, n_iter=1)
    loglikelihoods[i] = kf.loglikelihood(observations)

# Smooth over the whole trace with the learned parameters. `smooth` returns
# (state means, state covariances); keep the means: column 0 = smoothed
# position, column 1 = smoothed velocity.
smoothed_state_estimates = kf.smooth(observations)[0]

# --- Visualise raw vs. smoothed trace and the noise distributions ----------

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

# Shared text styles for axis labels and tick labels.
labelfont = {'fontsize': 18, 'fontname': 'Times New Roman'}
tickfont = {'fontsize': 15, 'fontname': 'Times New Roman'}

# Top-left: raw measured position vs. frame.
axs[0, 0].plot(ts, x, '.', label='Measured', color=plt.cm.ocean(0.8))
axs[0, 0].legend()
axs[0, 0].set_xlabel('Time (frames)', labelfont)
axs[0, 0].set_ylabel('X (nm)', labelfont)

# Top-right: smoothed position, drawn on the same axis ranges as the raw trace
# so the two panels are directly comparable.
axs[0, 1].plot(ts, smoothed_state_estimates[:, 0], label='Filtered', color=plt.cm.ocean(0.4))
xlims = axs[0, 0].get_xlim()
ylims = axs[0, 0].get_ylim()
axs[0, 1].set_xlim(xlims)
axs[0, 1].set_ylim(ylims)
axs[0, 1].legend()
axs[0, 1].set_xlabel('Time (frames)', labelfont)
axs[0, 1].set_ylabel('Smoothed X (nm)', labelfont)

# Offsets (from each distribution's mean) at which reference Gaussians are drawn.
xplot = np.linspace(-3, 3, 200)

# Bottom-left: distribution of the raw positions, overlaid with a Gaussian
# having the same mean and standard deviation as the data (NaNs ignored).
sigma_x_data = np.nanstd(x)
mu_x_data = np.nanmean(x)
axs[1, 0].hist(x, bins=50, edgecolor='k', fc=plt.cm.ocean(0.7), density=True)
axs[1, 0].plot(xplot + mu_x_data, gaussian(xplot, sigma_x_data), lw=2, color='k',
               label='Fitted data distribution')

# Bottom-right: residuals after removing the smoothed drift, overlaid with a
# zero-mean Gaussian whose width is the EM-learned position measurement noise.
detrended_x = x - smoothed_state_estimates[:, 0]
axs[1, 1].hist(detrended_x, bins=50, edgecolor='k', fc=plt.cm.ocean(0.7), density=True)
sigma_x_kf = np.sqrt(kf.observation_covariance[0, 0])
axs[1, 1].plot(xplot, gaussian(xplot, sigma_x_kf), lw=2, color='k', label='KF-EM distribution')

for ax, xlabel in ((axs[1, 0], 'Position before detrending (nm)'),
                   (axs[1, 1], 'Detrended position (nm)')):
    ax.set_xlabel(xlabel, labelfont)
    ax.set_ylabel('PDF', labelfont)
    ax.legend()

# Give both histograms the same y-range (up to the taller one) and zoom in x.
# NOTE(review): the x-range is fixed to [-2, 2]; if the raw trace is not
# centred near 0 its histogram will fall outside this view.
ylims = [0, max(max(axs[1, 1].get_ylim()), max(axs[1, 0].get_ylim()))]
for ax in (axs[1, 0], axs[1, 1]):
    ax.set_xlim([-2, 2])
    ax.set_ylim(ylims)

# Uniform tick-label size on every panel.
for ax in axs.flat:
    ax.tick_params(labelsize=tickfont['fontsize'])

fig.tight_layout()

# Learned measurement-noise covariance (position/velocity observations).
print(kf.observation_covariance)

plt.show()