In [None]:
import numpy as np
import pr3_utils 
from pr3_torch_utils import *
from stereo import *
from tqdm import tqdm

In [None]:
# ---- run configuration ----
dataset = "03"
# A landmark stays in the active EKF window while it was observed in any of the
# last `time_window_length` steps (see the update step of the EKF loop).
time_window_length = 100
# "all": use every tracked feature; "selected": use the subset given by the saved selection mask.
feature_type = "selected"

time_stamp,features,linear_velocity,angular_velocity,k,b,imu_T_cam = \
    load_data_torch(f"../data/{dataset}.npz")

num_time_stamp = time_stamp.shape[0]
# tau[i] = time_stamp[i+1] - time_stamp[i]: duration of each prediction step.
tau = (time_stamp[1:] - time_stamp[:-1]).to(dtype=DTYPE, device=DEVICE)
# Generalized velocity per time stamp, [linear; angular] stacked along dim 1
# (assumed 3 + 3 = 6 components, matching the 6x6 pose covariance used later).
velocity = torch.concatenate([linear_velocity, angular_velocity], dim=1)

# Initial landmark positions (one 3-D row per feature), loaded from the saved
# EKF-mapping outputs (per the file names).
if feature_type == "all":
    M_init = np.load(f"../data/{dataset}_EKF_mapping_M_init_all.npy")
elif feature_type == "selected":
    M_init = np.load(f"../data/{dataset}_EKF_mapping_M_init_selected.npy")
    M_selection_mask = np.load(f"../data/{dataset}_EKF_mapping_M_mask_selected.npy")
    # Apply the same mask to landmarks and feature tracks so their indices stay aligned.
    M_init = M_init[M_selection_mask,:]
    features = features[:,M_selection_mask,:]
else:
    raise ValueError(f"unknown feature_type: {feature_type!r}")
num_features = features.shape[1]
print("num_features", num_features)

In [None]:
# Camera-from-IMU transform: the inverse of the provided IMU-from-camera extrinsics.
cam_T_imu = inversePose(imu_T_cam)

# Intrinsics read from k: focal lengths and principal point (pixels).
fsu, fsv = k[0, 0], k[1, 1]
cu, cv = k[0, 2], k[1, 2]

# Stereo calibration matrix. The first two rows produce the left-image (u, v) and the
# last two the right-image (u, v); the -fsu*b term carries the stereo baseline b.
left_rows = [
    [fsu, 0, cu, 0],
    [0, fsv, cv, 0],
]
right_rows = [
    [fsu, 0, cu, -fsu*b],
    [0, fsv, cv, 0],
]
Ks = torch.tensor(left_rows + right_rows, dtype=DTYPE, device=DEVICE)

# 3x4 selector dropping the homogeneous coordinate: P @ [x, y, z, 1]^T = [x, y, z]^T.
P = torch.eye(3, 4, dtype=DTYPE, device=DEVICE)

In [None]:
# ---- noise model ----
# Motion noise added to the 6x6 pose covariance at every predict step
# (component order presumably follows `velocity`: linear then angular).
W = torch.diag(torch.tensor(
    [1e-8, 1e-8, 1e-5, 1e-5, 1e-5, 1e-3], dtype=DTYPE, device=DEVICE))
# Measurement noise for one 4-value stereo observation (one copy per observed landmark).
# NOTE(review): V is NOT positive semi-definite. The principal minor on indices (1, 3),
# [[10, 15], [15, 10]], has determinant -125, so V is not a valid covariance and can make
# H @ Sigma @ H.T + V indefinite. Values kept to preserve current results; TODO choose a valid V.
V = torch.tensor([
    [10, 5, 0, 0],
    [ 5,10, 0,15],
    [ 0, 0,10, 5],
    [ 0,15, 5,10],
], dtype=DTYPE, device=DEVICE)

# ---- diagnostic logs (filled by the EKF loop) ----
log_M_covar_norm_t = []        # time stamps at which the covariance trace is logged
log_M_covar_norm = []          # trace of the full joint covariance at those time stamps
log_T_predict_innovation = []  # per-step magnitude logged at the predict step
log_T_update_innovation = []   # per-step norm of the update-step pose correction

# ---- state initialization ----
# Pose mean (4x4) and pose covariance (6x6) for every time stamp.
T_mean = torch.zeros([num_time_stamp, 4, 4],
                     dtype=DTYPE, device=DEVICE)
T_covar = torch.zeros([num_time_stamp, 6, 6],
                     dtype=DTYPE, device=DEVICE)
# Landmark means, one 3-D row per feature.
M_mean = torch.from_numpy(M_init).to(dtype=DTYPE, device=DEVICE)
# Joint covariance over [landmarks (3 each); pose (6)]. It uses prior variance 5 per
# landmark coordinate and W for the initial pose.
covar_flat = 5*torch.eye(num_features*3+6,
                         dtype=DTYPE, device=DEVICE)
covar_flat[-6:,-6:] = W
# (block, 3, block, 3) view sharing storage with covar_flat; the last two blocks are the pose.
covar = covar_flat.view(num_features+2, 3, num_features+2, 3)
# Start at the identity pose.
T_mean[0,:,:] = torch.eye(4, dtype=DTYPE, device=DEVICE)
# Record the initial pose covariance as well, instead of leaving step 0 at zero.
T_covar[0,:,:] = covar_flat[-6:,-6:]
# present_mask[t, i]: feature i is used as an observation at step t (filled in the EKF loop).
present_mask = torch.zeros([num_time_stamp, num_features], dtype=torch.bool, device=DEVICE)

# ---- EKF SLAM loop with a sliding time window ----
# The joint state is [landmarks (3 each); pose perturbation (6)] with covariance covar_flat.
# Each step does two things:
#   1. Predict the pose from the IMU velocity.
#   2. Update the landmarks observed within the last time_window_length steps, plus the pose,
#      using the current stereo observations.
bar = tqdm(range(1, num_time_stamp))
for t in bar:
    # -- predict step: propagate the pose mean and the pose rows/columns of the covariance --
    T_mean[t,:,:] = T_mean[t-1,:,:] @ twist2pose(tau[t-1]*axangle2twist(velocity[t]))
    F = twist2pose(-tau[t-1]*axangle2adtwist(velocity[t]))
    # Log the motion applied in this step (tau * velocity). This puts it on the same per-step
    # scale as the update-step correction logged below (a raw velocity norm would not be).
    log_T_predict_innovation.append(torch.linalg.norm(tau[t-1]*velocity[t]).item())
    covar_flat[-6:,:-6] = F @ covar_flat[-6:,:-6]
    covar_flat[:-6,-6:] = covar_flat[-6:,:-6].T
    covar_flat[-6:,-6:] = F @ covar_flat[-6:,-6:] @ F.T + W
    T_covar[t,:,:] = covar_flat[-6:,-6:]

    # -- update step --
    present_mask[t,:] = get_seeing_mask_torch(features[t,:,:], d_min=8, d_max=35)
    present_now = present_mask[t,:]
    Nt = int(torch.sum(present_now))
    if Nt == 0:
        # Nothing observed: keep the predicted state.
        log_T_update_innovation.append(0)
        continue
    # Active window: every landmark observed in steps [t - time_window_length, t].
    time_window_start = max(0, t - time_window_length)
    window_mask = torch.any(present_mask[time_window_start:t+1, :], dim=0)
    Mt = int(torch.sum(window_mask))
    bar.set_postfix({"Nt":Nt, "Mt":Mt})

    # Flat row/column indices of the windowed state inside covar_flat. Block b, coordinate c
    # sits at 3*b + c; the two trailing blocks (always kept) hold the 6 pose entries.
    extend_window_mask = torch.concatenate([
        window_mask,
        torch.tensor([True, True], device=DEVICE)])
    state_blocks = torch.nonzero(extend_window_mask).squeeze(1)
    state_idx = (3*state_blocks.unsqueeze(1) + torch.arange(3, device=DEVICE)).reshape(-1)
    rows, cols = state_idx.unsqueeze(1), state_idx.unsqueeze(0)
    window_covar_flat = covar_flat[rows, cols]  # advanced indexing: this is a copy

    # Current estimates of the observed landmarks, in homogeneous coordinates.
    present_M_mean = M_mean[present_now, :]
    present_M_mean_homo = torch.hstack([
        present_M_mean,
        torch.ones([Nt,1], dtype=present_M_mean.dtype, device=present_M_mean.device)])

    # -- observation Jacobian H, reshaped to (Nt*4) x ((Mt+2)*3) --
    H = torch.zeros([Nt, 4, Mt+2, 3], dtype=DTYPE, device=DEVICE)
    imu_T_world = inversePose(T_mean[t,:,:])
    cam_T_world = cam_T_imu @ imu_T_world
    Ks_pJ = Ks @ projectionJacobian(present_M_mean_homo @ cam_T_world.T)
    # dh/dM: observation j depends only on its own landmark, at window position window_pos[j].
    H_M = Ks_pJ @ cam_T_world @ P.T
    window_pos = torch.nonzero(present_now[window_mask]).squeeze(1)
    H[torch.arange(Nt, device=DEVICE), :, window_pos, :] = H_M
    # dh/dT for the right-multiplied pose correction applied below.
    H[:,:,-2:,:] = (
        -Ks_pJ @ cam_T_imu @ odot(present_M_mean_homo @ imu_T_world.T)
    ).reshape(Nt, 4, 2, 3)
    H = H.reshape([Nt*4, (Mt+2)*3])

    # -- Kalman gain and innovation --
    S = H @ window_covar_flat @ H.T + torch.block_diag(*([V] * Nt))
    K = window_covar_flat @ H.T @ torch.linalg.inv(S)
    predicted_obs = projection(present_M_mean_homo @ cam_T_world.T) @ Ks.T
    innovation = K @ (features[t,present_now,:] - predicted_obs).reshape(-1)
    if torch.isnan(innovation).any():
        print("nan at", t)
        break
    innovation_M = innovation[:Mt*3]
    innovation_T = innovation[Mt*3:]

    # -- apply the correction --
    M_mean[window_mask] += innovation_M.reshape(-1, 3)
    T_mean[t,:,:] = T_mean[t,:,:] @ axangle2pose(innovation_T)
    log_T_update_innovation.append(torch.linalg.norm(innovation_T).item())
    updated_covar = (torch.eye((Mt+2)*3, dtype=DTYPE, device=DEVICE) - K@H) @ window_covar_flat
    # Write the windowed block back into covar_flat in place; `covar` views the same storage.
    # Do not use the chained form covar[m,:,:,:][:,:,m,:] = ...: the first advanced index
    # returns a copy, so that assignment would be silently discarded.
    # Symmetrize to remove round-off asymmetry of (I - K H) Sigma.
    covar_flat[rows, cols] = (updated_covar + updated_covar.T) / 2
    T_covar[t,:,:] = covar_flat[-6:,-6:]  # posterior pose covariance

    # -- log the trace of the joint covariance every 50 steps --
    if t%50 == 0:
        log_M_covar_norm_t.append(time_stamp[t].item())
        log_M_covar_norm.append(torch.trace(covar_flat).item())

In [None]:
# Estimated trajectory with the estimated landmark map overlaid.
fig,ax = pr3_utils.visualize_trajectory("EKF_SLAM_time_window", T_mean.detach().cpu().numpy(), show_ori=False)

# Only draw landmarks within plot_bound of the trajectory's x/y bounding box,
# so far-away landmark estimates do not dominate the axes.
plot_bound = 200
x_min, x_max = torch.min(T_mean[:,0,3]), torch.max(T_mean[:,0,3])
y_min, y_max = torch.min(T_mean[:,1,3]), torch.max(T_mean[:,1,3])
plot_mask = (x_min-plot_bound < M_mean[:,0]) & (M_mean[:,0] < x_max+plot_bound) & \
            (y_min-plot_bound < M_mean[:,1]) & (M_mean[:,1] < y_max+plot_bound)
ax.scatter(M_mean[plot_mask, 0].cpu(), M_mean[plot_mask, 1].cpu(),
           s=0.1, c='C4',label=f"features({feature_type})")

ax.legend()
fig.savefig(f"../img/{dataset}_EKF_SLAM_TW_{feature_type}", dpi=300)
plt.show()

In [None]:
# Compare the EKF-SLAM trajectory (with pose uncertainty) against the saved
# EKF-localization trajectory (loaded from disk, per the file name).
T = np.load(f"../data/{dataset}_EKF_localization_T_mean.npy")
x = T[:, 0, 3]  # localization x positions
y = T[:, 1, 3]  # localization y positions
fig,ax = pr3_utils.visualize_trajectory("EKF_SLAM_time_window",
                                        T_mean=T_mean.detach().cpu().numpy(),
                                        T_covar=T_covar.detach().cpu().numpy(),
                                        show_ori=False, show_var = True)
ax.plot(x,y,label="EKF_localization", c="C2", ls=":")
# Mark the first (square) and last (circle) localization positions.
ax.scatter(x[ 0],y[0],marker='s',c="C1")
ax.scatter(x[-1],y[-1],marker='o',c="C1")
ax.legend()
fig.savefig(f"../img/{dataset}_EKF_SLAM_TW_{feature_type}_trj", dpi=300)
plt.show()

In [None]:
# Per-step pose magnitudes logged by the EKF loop: predict step vs. update step.
# If the loop stopped early (NaN guard), the logs are shorter than time_stamp[1:] and can
# differ in length from each other. Each series is therefore paired with its own
# prefix of the step time stamps.
step_times = time_stamp[1:].cpu()
plt.plot(step_times[:len(log_T_predict_innovation)], log_T_predict_innovation, label="predict")
plt.plot(step_times[:len(log_T_update_innovation)], log_T_update_innovation, label="update")
plt.xlabel("time stamp")
plt.ylabel("|T innovation|")
plt.legend()
plt.savefig(f"../img/{dataset}_EKF_SLAM_TW_{feature_type}_innovation", dpi=300)
plt.show()

In [None]:
# Trace of the joint covariance at each logged time stamp:
# a connecting line plus a marker at every sample.
trace_times, trace_values = log_M_covar_norm_t, log_M_covar_norm
fig, ax = plt.subplots()
ax.plot(trace_times, trace_values)
ax.scatter(trace_times, trace_values)
ax.set_xlabel("time stamp")
ax.set_ylabel("tr(covar)")
fig.savefig(f"../img/{dataset}_EKF_SLAM_TW_{feature_type}_covar_tr", dpi=300)
plt.show()

In [None]:
# Heat map of the 6x6 pose block of the joint covariance as left by the EKF loop.
# The colorbar supplies the value scale the bare image lacks.
fig, ax = plt.subplots()
im = ax.imshow(covar_flat[-6:,-6:].cpu().numpy())
fig.colorbar(im, ax=ax)
ax.set_title("pose covariance (after EKF loop)")
fig.savefig(f"../img/{dataset}_EKF_SLAM_TW_{feature_type}_T_covar", dpi=300)
plt.show()