In [None]:
import os
import numpy as np
import pandas as pd
import pickle as pkl
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from scipy.spatial.transform import Rotation as R
from keypoint_util import KeypointSelector


def attach_grip_states(df, grip):
    """Return a copy of ``df`` with an ``'action'`` column holding gripper events.

    Each row of ``grip`` (columns ``'Timestamp'`` and ``'Gripper state'``) is
    written to the single row of ``df`` whose ``'Time'`` is closest to that
    timestamp; every other row gets ``None``. If several grip events map to
    the same row, the later one in ``grip`` wins.

    Args:
        df: pose frame with a ``'Time'`` column. It is not modified.
        grip: gripper events with ``'Timestamp'`` and ``'Gripper state'``.

    Returns:
        A new DataFrame: ``df`` plus the ``'action'`` column.
    """
    new_df = df.copy()
    new_df['action'] = None
    if len(new_df) == 0:
        # Nothing to match against.
        return new_df
    # Look the column up by name: if ``df`` already had an 'action' column it
    # keeps its original position, so it is not necessarily the last one.
    action_col = new_df.columns.get_loc('action')
    times = df['Time']
    # Zip the two columns rather than using iterrows(): iterrows packs each
    # row into one Series and upcasts mixed dtypes (e.g. int states -> float).
    for check_time, state in zip(grip['Timestamp'], grip['Gripper state']):
        # Positional index of the nearest time (first one on ties).
        nearest = (times - check_time).abs().argmin()
        new_df.iloc[nearest, action_col] = state
    return new_df

def extract_points(dfs, period, backward=False):
    """Stack the rows of every demo that fall in each fixed-size window.

    Windows are ``period`` rows long and indexed by ``t = period, 2*period, ...``
    up to (but excluding) the length of the longest demo. For window ``t`` a
    demo contributes rows ``[t - period, t)`` when ``backward`` is False, or
    rows ``[len - t, len - t + period)`` (counting back from its end) when
    ``backward`` is True, and only if ``t < len(demo)``.

    NOTE(review): because of the strict ``<`` the window that ends exactly at
    ``len(demo)`` (forward) / starts at row 0 (backward) is never included —
    confirm this is intended.

    Returns:
        (frame_periods, active_demos): per window, the row-wise concatenation
        of the contributing slices as a numpy array, and how many demos
        contributed to it.
    """
    lengths = [len(df) for df in dfs]
    window_ends = [t for t in range(max(lengths)) if t > 0 and t % period == 0]

    frame_periods, active_demos = [], []
    for t in window_ends:
        chunks = []
        for df, n_rows in zip(dfs, lengths):
            if t >= n_rows:
                continue
            lo = n_rows - t if backward else t - period
            chunks.append(df.iloc[lo:lo + period].to_numpy())
        active_demos.append(len(chunks))
        frame_periods.append(np.concatenate(chunks))
    return frame_periods, active_demos

def get_mean_cov(pts_lst, rho=1, shrink=False):
    """Per-window sample mean and regularised sample covariance.

    Args:
        pts_lst: sequence of 2-D arrays, one per window, rows = samples,
            columns = coordinates.
        rho: value added to every variance (diagonal entry) of each
            covariance; with ``rho > 0`` the result is positive definite.
        shrink: currently has no effect (kept for call compatibility).

    Returns:
        (means, covs): two lists aligned with ``pts_lst``.
    """
    means, covs = [], []
    for pts in pts_lst:
        sample_cov = np.cov(pts, rowvar=False)
        n_dims = pts.shape[1]
        covs.append(sample_cov + np.diag(np.full(n_dims, rho)))
        means.append(np.mean(pts, axis=0))
    return means, covs

def get_mean_cov_hats(ref_means, ref_covs, min_len=None):
    """Fuse per-reference Gaussians into one Gaussian per step.

    For every step ``p`` the Gaussians N(ref_means[r][p], ref_covs[r][p]) of
    all references ``r`` are combined by summing precisions:

        sigma_hat = (sum_r inv(S_r)) ** -1
        mu_hat    = sigma_hat @ sum_r inv(S_r) @ mu_r

    Args:
        ref_means: one sequence of mean vectors per reference.
        ref_covs: one sequence of covariance matrices per reference, aligned
            with ``ref_means``.
        min_len: number of steps to fuse. ``None`` (default) uses the shortest
            sequence in ``ref_means`` and ``ref_covs``; an explicit 0 yields
            empty results.

    Returns:
        (mean_hats, sigma_hats) as numpy arrays with one entry per step.

    Raises:
        ValueError: if ``min_len`` is None and there are no references.
        numpy.linalg.LinAlgError: if a covariance or precision sum is singular.
    """
    if min_len is None:
        # Bound by both inputs so a shorter covariance list cannot overrun.
        min_len = min(len(seq) for seq in list(ref_means) + list(ref_covs))
    num_refs = len(ref_means)

    mean_hats, sigma_hats = [], []
    for p in range(min_len):
        # Invert each covariance once and reuse it for both sums.
        precisions = [np.linalg.inv(ref_covs[r][p]) for r in range(num_refs)]
        precision_sum = precisions[0]
        weighted_sum = np.matmul(precisions[0], ref_means[0][p])
        for r in range(1, num_refs):
            precision_sum = precision_sum + precisions[r]
            weighted_sum = weighted_sum + np.matmul(precisions[r], ref_means[r][p])
        sigma_hat = np.linalg.inv(precision_sum)
        sigma_hats.append(sigma_hat)
        mean_hats.append(np.matmul(sigma_hat, weighted_sum))
    return np.array(mean_hats), np.array(sigma_hats)

# Load every demonstration instance: one sub-folder per instance, each holding
# a gripper-event CSV and a pose CSV.
folder = os.path.join('ndi_servo', '2022-04-13-morning')
# Only sub-directories are instances; skip stray files (e.g. .DS_Store) that
# would otherwise make the inner os.listdir fail.
instances = sorted(
    name for name in os.listdir(folder)
    if os.path.isdir(os.path.join(folder, name))
)
all_dfs = []
for instance in instances:
    instance_folder = os.path.join(folder, instance)
    # The first two files in sorted order are taken as (gripper CSV, pose CSV).
    # NOTE(review): this relies purely on file-name ordering — confirm the naming convention.
    file1, file2 = sorted(os.listdir(instance_folder))[:2]
    grip_file = os.path.join(instance_folder, file1)
    state_file = os.path.join(instance_folder, file2)
    grip_states = pd.read_csv(grip_file)
    pos_states = pd.read_csv(state_file, index_col='Unnamed: 0')
    dframe = attach_grip_states(pos_states, grip_states)
    # Fail early, with a useful message, if the rotation-vector columns are missing.
    missing = [col for col in ('Rx', 'Ry', 'Rz') if col not in dframe.columns]
    if missing:
        raise KeyError(f"{state_file} is missing columns {missing}")
    all_dfs.append(dframe)

In [None]:
# Display the sorted instance folder names collected under `folder`.
instances

In [None]:
# Disabled: would add, for every pose, the x-axis unit vector rotated by the
# rotation vector (Rx, Ry, Rz) as new columns xu, xv, xw.
# for df in all_dfs:
#     xax = [1,0,0]
#     r = R.from_rotvec(df.loc[:,['Rx', 'Ry', 'Rz']])
#     df.loc[:, ['xu','xv','xw']] = r.apply(xax)

# Set up the key-point search. KeypointSelector comes from the local
# keypoint_util module. NOTE(review): the meaning of the positional arguments
# (16, 8, 0.07, 0.05) is not visible here — document them or pass them as keywords.
selector = KeypointSelector(16, 8, 0.07, 0.05, rotation_max=.7, 
                            velocity_window_size=5, rotation_window_size=5)

In [None]:
# Cut each demo down to one task segment delimited by a pair of 'grasp' frames.
TASK_SEGMENT = 1  # 0-based index of the grasp pair to keep (the second pair)

task_dfs = []
for f, df in enumerate(all_dfs):
    cond_df, keypoints, vel_df, rot_df = selector.search(df)
    grasp_indices = cond_df.index[cond_df['condition'] == 'grasp'].tolist()
    # Pair consecutive grasp frames: (g0, g1), (g2, g3), ...; an odd trailing one is dropped.
    task_periods = [grasp_indices[i * 2:i * 2 + 2] for i in range(len(grasp_indices) // 2)]
    if len(task_periods) <= TASK_SEGMENT:
        raise IndexError(
            f"demo {f}: found {len(task_periods)} grasp pair(s), "
            f"need at least {TASK_SEGMENT + 1}")

    # Set start and end frames.
    start, end = task_periods[TASK_SEGMENT]
    # NOTE(review): start/end are index labels of cond_df but are used as
    # .iloc positions; equivalent only if cond_df has a default RangeIndex — confirm.
    save_df = cond_df.iloc[start:end].copy().reset_index(drop=True)
    task_dfs.append(save_df)

In [None]:
# Re-express every demo relative to its own first sample (start frame) and its
# own last sample (end frame), then summarise each window of `period` samples.
period = 10
xyz_cols = ['x', 'y', 'z']
demo_start_pts, demo_end_pts = [], []
task_start_dfs, task_end_dfs = [], []
for df in task_dfs:
    positions = df[xyz_cols]
    start_pt, end_pt = positions.iloc[0].tolist(), positions.iloc[-1].tolist()
    demo_start_pts.append(start_pt)
    demo_end_pts.append(end_pt)
    task_start_dfs.append(positions - start_pt)
    task_end_dfs.append(positions - end_pt)

# Pick out every window of `period` samples (end frame: counted back from the end).
start_checkpts, snum_act_demos = extract_points(task_start_dfs, period=period)
end_checkpts, enum_act_demos = extract_points(task_end_dfs, period=period, backward=True)

# Per-window mean and rho-regularised covariance (`shrink` has no effect in get_mean_cov).
start_mean, start_cov = get_mean_cov(start_checkpts, shrink=False)
end_mean, end_cov = get_mean_cov(end_checkpts, shrink=True)

In [None]:
from statistics import mean


def _extrapolate_covs(covs, num_demos, min_demos=3):
    """Replace covariances estimated from too few demos by extrapolated ones.

    Walks the windows in order. While at least ``min_demos`` demos contribute
    to window ``s``, it records the growth ratio trace(cov[s]) / trace(cov[s-1]).
    When fewer demos contribute, cov[s] is replaced by cov[s-1] inflated
    isotropically so its trace grows by the mean ratio recorded so far.
    ``covs`` is modified in place and also returned.

    Raises:
        ValueError: if a window needs extrapolation before any growth ratio
            has been recorded (e.g. the very first window).
    """
    multipliers = []
    for s, num in enumerate(num_demos):
        if num < min_demos:
            if not multipliers:
                raise ValueError(
                    f"window {s} has only {num} demo(s) but no growth ratio "
                    f"has been observed yet to extrapolate from")
            scaling = mean(multipliers) - 1
            prev = covs[s - 1]
            dim = covs[s].shape[1]
            covs[s] = prev + np.diag(np.full(dim, scaling * np.trace(prev) / dim))
        elif s > 0:
            multipliers.append(np.trace(covs[s]) / np.trace(covs[s - 1]))
    return covs


# Extrapolate forward (start frame) and backward (end frame). End-frame
# windows were extracted counting back from each demo's end, so they are
# reversed afterwards to run forward in time like the start-frame ones.
_extrapolate_covs(start_cov, snum_act_demos)
_extrapolate_covs(end_cov, enum_act_demos)
end_cov.reverse()
end_mean.reverse()

In [None]:
# Every demo in its start-point local frame, drawn as 3D lines.
demo_data = [px.line_3d(df, x='x', y='y', z='z').data for df in task_start_dfs]
start_mean_df = pd.DataFrame(start_mean, columns=['x', 'y', 'z'])

# Concatenate the per-demo trace tuples into one figure.
data_cat = demo_data[0]
for traces in demo_data[1:]:
    data_cat += traces
fig_main = go.Figure(data=data_cat)
fig_main.update_traces(marker=dict(size=5))
fig_main.show()


In [None]:
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms

def ellipse(mean, cov, dim0, dim1, ax, n_std=1, facecolor='none', **kwargs):
    """Draw the ``n_std`` covariance ellipse of the (dim0, dim1) marginal on ``ax``.

    Args:
        mean: mean vector; entries ``dim0``/``dim1`` give the centre.
        cov: full covariance matrix; only the (dim0, dim1) 2x2 block is used.
        dim0, dim1: coordinates mapped to the horizontal / vertical axis.
        ax: matplotlib Axes to draw on.
        n_std: number of standard deviations the ellipse spans.
        facecolor, **kwargs: forwarded to ``matplotlib.patches.Ellipse``.

    Returns:
        The patch added to ``ax``.
    """
    var_h, var_v = cov[dim0, dim0], cov[dim1, dim1]
    pearson = cov[dim0, dim1] / np.sqrt(var_h * var_v)
    # Unit-variance ellipse of the correlation matrix: its principal axes lie
    # on the diagonals (hence the 45 degree rotation below) with half-lengths
    # sqrt(1 + pearson) and sqrt(1 - pearson).
    patch = Ellipse((0, 0),
                    width=2 * np.sqrt(1 + pearson),
                    height=2 * np.sqrt(1 - pearson),
                    facecolor=facecolor, **kwargs)

    # Rotate onto the diagonals, stretch each axis by n_std standard
    # deviations, then move to the mean.
    transf = (transforms.Affine2D()
              .rotate_deg(45)
              .scale(np.sqrt(var_h) * n_std, np.sqrt(var_v) * n_std)
              .translate(mean[dim0], mean[dim1]))
    patch.set_transform(transf + ax.transData)
    return ax.add_patch(patch)


In [None]:
# Overlay the start-frame covariance ellipses of the first 120 windows, all
# centred on the origin, to compare their spread.
fig, ax0 = plt.subplots()
fig.set_figheight(10)
fig.set_figwidth(10)
# Mark the shared origin; a single point drawn only as a dashed line is invisible.
ax0.plot(0, 0, 'x', color='purple')
for c in start_cov[:120]:
    ellipse([0, 0, 0], c, 2, 1, ax0, n_std=1, edgecolor='red')
ax0.set_title('Start-frame covariance ellipses (first 120 windows)')
ax0.set_xlabel('z')
ax0.set_ylabel('y')
plt.show()

In [None]:
# end reference point


def _plot_local_view(ax, dfs, pts, mu, cov, dim_h, dim_v, n, ref_at_start):
    """Draw one 2-D projection of a local-frame view onto ``ax``.

    Plots every demo trajectory in ``dfs`` (columns 'x', 'y', 'z'), the window
    samples ``pts`` (columns in x, y, z order) and the ``n``-std ellipse of
    (``mu``, ``cov``). ``dim_h``/``dim_v`` select the horizontal/vertical
    coordinate (0=x, 1=y, 2=z).

    Start points are marked with a purple 'x', end points with a black '+'.
    The frame's reference point (the start if ``ref_at_start`` else the end)
    is marked once, from the last demo, because in these local frames it is
    the origin for every demo; the opposite end is marked for every demo.
    """
    col_h, col_v = 'xyz'[dim_h], 'xyz'[dim_v]
    start_mark = (0, 'x', 'purple')
    end_mark = (-1, '+', 'black')
    ref_mark, other_mark = (start_mark, end_mark) if ref_at_start else (end_mark, start_mark)

    other_idx, other_fmt, other_color = other_mark
    for df in dfs:
        ax.plot(df[col_h], df[col_v], color='blue', linestyle="dashed", alpha=0.2)
        ax.plot(df[col_h].iloc[other_idx], df[col_v].iloc[other_idx], other_fmt, color=other_color)
    ref_idx, ref_fmt, ref_color = ref_mark
    last = dfs[-1]
    ax.plot(last[col_h].iloc[ref_idx], last[col_v].iloc[ref_idx], ref_fmt, color=ref_color)
    ax.plot(pts[:, dim_h], pts[:, dim_v], 'o', color='green')
    ellipse(mu, cov, dim_h, dim_v, ax, n_std=n, edgecolor='red')


def plot_start_end(frame, n):
    """Plot window ``frame`` in the start and end local frames with ``n``-std ellipses.

    Four stacked panels: start frame z-y and z-x, end frame z-y and z-x.
    Reads the notebook globals task_start_dfs/task_end_dfs,
    start_checkpts/end_checkpts, start_mean/start_cov and end_mean/end_cov.
    """
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4)
    fig.set_figheight(14)
    fig.set_figwidth(10)

    # Start point view.
    cur_pts_start = start_checkpts[frame]
    _plot_local_view(ax1, task_start_dfs, cur_pts_start, start_mean[frame], start_cov[frame],
                     2, 1, n, ref_at_start=True)
    ax1.set_title('Start Point Local Frame')
    ax1.set_xlabel('z')
    ax1.set_ylabel('y')

    _plot_local_view(ax2, task_start_dfs, cur_pts_start, start_mean[frame], start_cov[frame],
                     2, 0, n, ref_at_start=True)
    ax2.set_ylabel('x')
    ax2.invert_yaxis()

    # End point view. end_mean/end_cov were reversed after extraction but
    # end_checkpts was not, so the samples are indexed from the back.
    cur_pts_end = end_checkpts[len(end_checkpts) - frame - 1]
    _plot_local_view(ax3, task_end_dfs, cur_pts_end, end_mean[frame], end_cov[frame],
                     2, 1, n, ref_at_start=False)
    ax3.set_title('End Point Local Frame')
    ax3.set_ylabel('y')

    _plot_local_view(ax4, task_end_dfs, cur_pts_end, end_mean[frame], end_cov[frame],
                     2, 0, n, ref_at_start=False)
    ax4.set_xlabel('z')
    ax4.set_ylabel('x')
    ax4.invert_yaxis()
    plt.show()

# Inspect window 5 with 2-standard-deviation ellipses.
frame, n = 5, 2
plot_start_end(frame, n=n)

In [None]:
# Place demo `inst`'s start- and end-frame Gaussians in the global frame and
# fuse them per window.
inst = 0
t_start_pt = demo_start_pts[inst]
t_end_pt = demo_end_pts[inst]
glob_start_mean = [row + t_start_pt for row in start_mean]
glob_end_mean = [row + t_end_pt for row in end_mean]
opt_mean, opt_cov = get_mean_cov_hats([glob_start_mean, glob_end_mean], [start_cov, end_cov])

frame = 8
n = 1
fig, (ax5, ax6) = plt.subplots(2)
fig.set_figheight(10)
fig.set_figwidth(12)

# Both panels put z on the horizontal axis; the vertical axis is y (top) or x (bottom).
# Yellow/orange: start/end-frame ellipses (n std); filled red: fused ellipse (1 std).
for ax, dim_v in ((ax5, 1), (ax6, 0)):
    ax.plot(t_end_pt[2], t_end_pt[dim_v], '+', color='black')
    ax.plot(t_start_pt[2], t_start_pt[dim_v], 'x', color='purple')
    ellipse(glob_start_mean[frame], start_cov[frame], 2, dim_v, ax, n_std=n, edgecolor='yellow')
    ellipse(glob_end_mean[frame], end_cov[frame], 2, dim_v, ax, n_std=n, edgecolor='orange')
    ellipse(opt_mean[frame], opt_cov[frame], 2, dim_v, ax, n_std=1, edgecolor='red', facecolor='red')
ax5.set_title('Testing Global Frame')
ax5.set_xlabel('z')
ax5.set_ylabel('y')
ax6.set_xlabel('z')
ax6.set_ylabel('x')
ax6.invert_yaxis()

In [None]:
# Shape of the fused mean trajectory: (number of fused windows, number of coordinates).
opt_mean.shape

In [None]:
# Compare demo `inst`'s recorded trajectory (dashed blue) with the fused mean
# trajectory (red), on z-y (top) and z-x (bottom).
fig, (ax7, ax8) = plt.subplots(2)
fig.set_figheight(10)
fig.set_figwidth(12)

demo = task_dfs[inst]
for ax, dim_v, col_v in ((ax7, 1, 'y'), (ax8, 0, 'x')):
    ax.plot(demo['z'], demo[col_v], color='blue', linestyle="dashed", alpha=0.2)
    ax.plot(opt_mean[:, 2], opt_mean[:, dim_v], color='red', alpha=0.5)
    ax.scatter(opt_mean[:, 2], opt_mean[:, dim_v], color='red', alpha=0.5)
    ax.plot(t_end_pt[2], t_end_pt[dim_v], '+', color='black')
    ax.plot(t_start_pt[2], t_start_pt[dim_v], 'x', color='purple')
    ax.set_xlabel('z')
    ax.set_ylabel(col_v)
ax7.set_title('Testing Global Frame')
ax8.invert_yaxis()