In [None]:
import os
import yaml
import pickle
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib
from statistics import mean
from plot import remove_repetitive_labels
from torch.utils.data import DataLoader, Dataset

from transformer_model import *
from process_data import Task_data
from outlier_detection import detect_outlier
from transformations import *


# Pick the compute device once; later cells move the model and inputs onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Cuda available: ", torch.cuda.is_available())
# print("Current device: ", torch.cuda.current_device())

# ---- Task configuration ----
task_config_dir = '../Process_data/postprocessed/2022-10-06'
config_path = os.path.join(task_config_dir, 'task_config.yaml')
with open(config_path) as file:
    # NOTE(review): FullLoader can build more than plain data types; if the
    # config is plain YAML, yaml.safe_load would be the safer choice.
    config = yaml.load(file, Loader=yaml.FullLoader)
project_dir = config["project_path"] # Modify this to your need
base_dir = os.path.join(project_dir, config["postprocessed_dir"])
triangulation = 'dlc3d'
template_dir = os.path.join(base_dir, f'transformations/{triangulation}')
individuals = config["individuals"] # The objects that we will place a reference frame on
objs = config["objects"]
data = Task_data(base_dir, triangulation, individuals)
data.dims = ['x', 'y', 'z']

n_actions = data.get_number_of_actions()
gripper_trajs_truncated = data.get_gripper_trajectories_for_each_action()

# ---- Pre-processed trajectories and object transforms ----
processed_dir = os.path.join(base_dir, 'processed', triangulation)


def _load_processed(filename):
    """Unpickle one file from `processed_dir` and return its content.

    NOTE(review): pickle.load can execute arbitrary code; only point this at
    files produced by the project's own preprocessing.
    """
    with open(os.path.join(processed_dir, filename), 'rb') as handle:
        return pickle.load(handle)


gripper_trajs_in_obj_all_actions = _load_processed('gripper_trajs_in_obj_aligned_filtered.pickle')
HTs_obj_in_ndi_all_actions = _load_processed('HTs_obj_in_ndi.pickle')
gripper_trajs_in_grouped_objs_all_actions = _load_processed('gripper_traj_in_grouped_objs_aligned_filtered.pickle')
HTs_grouped_objs_in_ndi_all_actions = _load_processed('HTs_grouped_objs_in_ndi.pickle')

# ---- Select one action and merge single-object and grouped-object entries ----
ind = 0  # index of action to be tested
gripper_trajs_in_ndi = gripper_trajs_truncated[ind]
gripper_traj_in_obj = gripper_trajs_in_obj_all_actions[ind]
gripper_traj_in_grouped_obj = gripper_trajs_in_grouped_objs_all_actions[ind]
gripper_traj_in_generalized_obj = gripper_traj_in_obj | gripper_traj_in_grouped_obj

HTs_obj_in_ndi = HTs_obj_in_ndi_all_actions[ind]
HTs_grouped_obj_in_ndi = HTs_grouped_objs_in_ndi_all_actions[ind]
HTs_generalized_obj_in_ndi = HTs_obj_in_ndi | HTs_grouped_obj_in_ndi

# ---- Outlier detection ----
std_thres = 3
# NOTE(review): this rebinds `individuals`, replacing the list read from the config above.
individuals = ['cup', 'teabag1']
outliers = []
for individual in individuals:
    n_std = std_thres
    outlier_individual = detect_outlier(gripper_traj_in_generalized_obj[individual], n=n_std)
    print(f'The outliers for individual {individual} are {outlier_individual}')
    outliers.extend(outlier_individual)
# De-duplicate demos flagged for more than one individual.
outliers = list(set(outliers))
bad_demos = outliers

# ---- Train / test pools: first 40 sorted demos vs. the rest, outliers removed ----
demos = sorted(HTs_obj_in_ndi.keys())
train_demos_pool = [candidate for candidate in demos[:40] if candidate not in bad_demos]
test_demos_pool = [candidate for candidate in demos[40:] if candidate not in bad_demos]
# Train model
print(f'The number of training pool is: {len(train_demos_pool)}')
print(f'The number of outliers is: {len(outliers)}')
n_dims = len(data.dims)
n_train = 32

In [None]:
# Draw the demos used for training and testing from the outlier-free pools.
# random.sample raises ValueError when asked for more items than the pool
# holds, so the requested sizes are clamped to what is actually available.
n_test = 8
n_train_used = min(n_train, len(train_demos_pool))
if n_train_used < n_train:
    print(f'Only {n_train_used} training demos available (requested {n_train}).')
train_demos = random.sample(train_demos_pool, k=n_train_used)

test_demos_pool_updated = [demo for demo in test_demos_pool if demo not in train_demos]
n_test_used = min(n_test, len(test_demos_pool_updated))
if n_test_used < n_test:
    print(f'Only {n_test_used} test demos available (requested {n_test}).')
if n_test_used == 0:
    raise ValueError('No test demos left after removing outliers and training demos.')
test_demos = random.sample(test_demos_pool_updated, k=n_test_used)

# Reference demo: its coordinates and the two time bases kept alongside it.
test_demo = test_demos[0]
ground_truth = gripper_trajs_in_ndi[test_demo][data.dims].to_numpy()
t_pmp = np.linspace(0, 1, ground_truth.shape[0])
t_gmm = gripper_trajs_in_ndi[test_demo]['Time'].to_numpy()

# Empty containers; nothing in this cell fills them.
data_temp = []
times_temp = []
data_all_frames_tp_pmp = {}
data_all_frames_pmp = {}

In [None]:
def _stack_demo_positions(demo_ids):
    """Collect object positions and gripper trajectories for the given demos.

    For every demo the zero point is passed through `lintrans` with the
    'cup' and 'teabag1' transforms from `HTs_obj_in_ndi`, and the x/y/z
    columns of the 'global' gripper trajectory are extracted.

    Returns:
        (obj_pos, traj_pos): two numpy arrays with one entry per demo. Each
        obj_pos entry stacks the cup row first and the teabag row second.
        NOTE(review): np.array only yields a regular 3-D trajectory array if
        all demos have the same number of samples - confirm upstream alignment.
    """
    obj_pos, traj_pos = [], []
    for demo in demo_ids:
        frames = HTs_obj_in_ndi[demo]
        teabag_pos = lintrans(np.zeros([1, 3]), frames['teabag1'])
        cup_pos = lintrans(np.zeros([1, 3]), frames['cup'])
        # Row order matters downstream: cup = row 0, teabag = row 1.
        obj_pos.append(np.concatenate([cup_pos, teabag_pos]))
        traj_pos.append(gripper_traj_in_obj['global'][demo][['x','y','z']].to_numpy())
    return np.array(obj_pos), np.array(traj_pos)


train_obj_pos, train_traj_pos = _stack_demo_positions(train_demos)
test_obj_pos, test_traj_pos = _stack_demo_positions(test_demos)

In [None]:
def random_rotation(x, axis='x'):
    """Rotate all rows of `x` by a random angle about a randomly chosen row.

    Args:
        x: 2-D array of points, one point per row. The rotation block is 3x3,
            so rows are presumably xyz coordinates - TODO confirm against
            `lintrans`.
        axis: Euler axis name passed to `R.from_euler`.

    Returns:
        The points after subtracting the pivot row, applying the rotation via
        `lintrans`, and adding the pivot back.
    """
    # Whole-degree angle in [0, 360) and the index of the pivot row.
    degree, idx = random.randrange(0, 360), random.randrange(0, x.shape[0])
    rot = R.from_euler(axis, degree, degrees=True)
    # Start from the identity so the homogeneous element H[3, 3] is 1. A
    # zero-initialised matrix leaves it at 0, which is not a valid transform.
    H = np.eye(4)
    H[:3, :3] = rot.as_matrix()
    pivot = x[idx]
    return lintrans(x - pivot, H) + pivot

class TrajectoryDataset(Dataset):
    """Dataset pairing per-demo object rows with trajectory rows.

    Each item is a 2-D tensor whose rows are the object entries followed by
    the trajectory entries. The columns are the (optionally transformed)
    coordinates followed by one-hot tag columns: one tag per object plus a
    final tag shared by every trajectory row.

    Args:
        obj_data: array-like indexed as [demo, object, coordinate].
        traj_data: array-like indexed as [demo, step, coordinate].
        transform: callable applied to the stacked coordinate array of one
            item before the tags are appended (identity by default).
    """

    def __init__(self, obj_data, traj_data, transform=lambda a : a):
        self.traj_data = torch.tensor(traj_data)
        self.obj_data = torch.tensor(obj_data)
        self.length = len(self.obj_data)
        self.transform = transform

    def __getitem__(self, idx):
        traj_data = self.traj_data[idx]
        obj_data = self.obj_data[idx]
        n_obj, n_steps = obj_data.shape[0], traj_data.shape[0]
        x = np.concatenate([obj_data.numpy(), traj_data.numpy()])
        # Transformation process
        x = self.transform(x)
        # Tag objects and trajectory: object k gets a one-hot in column k and
        # every trajectory row gets the last column. Building the tags from the
        # row counts avoids relying on the coordinate width being n_obj + 1
        # and also works for an empty trajectory.
        tags = np.zeros([n_obj + n_steps, n_obj + 1])
        tags[:n_obj, :n_obj] = np.eye(n_obj)
        tags[n_obj:, -1] = 1
        x = np.concatenate([x, tags], axis=-1)
        return torch.tensor(x)

    def __len__(self):
        return self.length

    
# Scalar normalisation statistics, both taken from the training split only:
# the mean over every object coordinate (kept as a broadcastable array) and
# the standard deviation over every trajectory coordinate.
# `keepdims` is a boolean flag; the previous value 2 only worked by being truthy.
train_mean = np.mean(train_obj_pos, keepdims=True)
train_std = np.std(train_traj_pos)

# NOTE(review): these assignments overwrite their inputs, so running this cell
# twice normalises the data twice. Re-run the cell that builds the raw arrays first.
train_obj_pos = (train_obj_pos - train_mean)/train_std
train_traj_pos = (train_traj_pos - train_mean)/train_std
test_obj_pos = (test_obj_pos - train_mean)/train_std
test_traj_pos = (test_traj_pos - train_mean)/train_std

# Create dataloaders
# NOTE(review): the test set is also randomly rotated and shuffled, so
# validation numbers vary between runs - confirm this is intended.
training_data = TrajectoryDataset(train_obj_pos, train_traj_pos, transform=random_rotation)
test_data = TrajectoryDataset(test_obj_pos, test_traj_pos, transform=random_rotation)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [None]:
# %matplotlib notebook
# Show the first few training samples after the random-rotation augmentation.
matplotlib.rcParams.update({'font.size': 10})
fig = plt.figure(figsize = (8, 6))
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.set_facecolor('white')
ax.locator_params(nbins=3, axis='z')
colors = ['red', 'blue', 'yellow', 'orange', 'green', 'purple','pink']
for i_batch, sample_batched in enumerate(training_data):
    if i_batch > 5: break
    # Object rows are stacked cup first, teabag second when the position
    # arrays are built, so row 0 is the cup and row 1 the teabag.
    cup_pos = sample_batched[0,]
    teabag_pos = sample_batched[1,]
    traj_pos = sample_batched[2:]
    # Display axes: data column 2 -> x, column 1 -> y, negated column 0 -> z.
    line = ax.plot(traj_pos[:,2], traj_pos[:,1], -traj_pos[:,0], '--', color=colors[i_batch%len(colors)], 
                   label = f'demo {i_batch}')
    ax.plot(teabag_pos[2], teabag_pos[1], -teabag_pos[0], 'o',
            color=colors[i_batch%len(colors)], label=f'teabag {i_batch}')
    ax.plot(cup_pos[2], cup_pos[1], -cup_pos[0], 'x',
            color=colors[i_batch%len(colors)], label=f'cup {i_batch}')
# NOTE(review): the plotted values are the normalised coordinates, so the
# "(mm)" in the axis labels is nominal - confirm the intended label.
ax.set_xlabel('x (mm)')
ax.set_ylabel('y (mm)')
ax.set_zlabel('z (mm)')
# Equal scaling on all three axes, derived from the current axis limits.
ax.set_box_aspect([ub - lb for lb, ub in (getattr(ax, f'get_{a}lim')() for a in 'xyz')])
handles, labels = ax.get_legend_handles_labels()
newHandles_temp, newLabels_temp = remove_repetitive_labels(handles, labels)
newLabels, newHandles = [], []

for handle, label in zip(newHandles_temp, newLabels_temp):
    if label not in ['start', 'middle', 'end']:
        newLabels.append(label)
        newHandles.append(handle)
plt.legend(newHandles, newLabels, loc = 'upper left',  prop={'size': 10})
plt.show()

In [None]:
# Select model type (swap in the commented EncoderModel to try the other architecture)
# model = EncoderModel(traj_dim=6, embed_dim=256, nhead=8,
#                         d_hid=512).to(device)
model = TFModel(
    traj_dim=6,
    embed_dim=256,
    nhead=8,
    d_hid=512,
)
model = model.to(device)

# Adam over all model parameters.
adam = optim.Adam(params=model.parameters(), lr=1e-5)

def loss_func(pred, truth):
    """Return the mean-squared error between `pred` and `truth` as a scalar tensor."""
    mse = F.mse_loss(input=pred, target=truth, reduction='mean')
    return mse

def train(model, optimizer, t_dataloader, v_dataloader=None, epochs=100):
    """Train `model` to reconstruct trajectory rows whose coordinates are masked.

    For every batch, rows from index 2 onward have their first three feature
    columns zeroed before the batch is fed to the model. The model output is
    compared with the unmasked rows from index 2 onward using `loss_func`.

    Args:
        model: module called as `model(batch)`; switched between train and
            eval mode each epoch.
        optimizer: optimizer stepping the model's parameters.
        t_dataloader: iterable of training batches indexed as
            [batch, row, feature].
        v_dataloader: optional iterable of validation batches with the same
            layout. When None, the validation loss is reported as nan.
        epochs: number of passes over `t_dataloader`.

    Returns:
        None. One summary line is printed per epoch.
    """
    for i in range(epochs):
        train_losses, valid_losses = [], []
        model.train()
        for sample_batched in t_dataloader:
            # input modification
            sample_batched = sample_batched.float().to(device)
            masked_input = sample_batched.clone()
#             masked_input = random_mask(masked_input, [2, masked_input.shape[1]], percent)
            masked_input[:, 2:, :3] = 0
            optimizer.zero_grad()
            output = model(masked_input)
            loss = loss_func(output, sample_batched[:, 2:])
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        if v_dataloader is not None:
            # No gradients are needed for validation; skipping them saves memory.
            with torch.no_grad():
                for sample_batched in v_dataloader:
                    sample_batched = sample_batched.float().to(device)
                    masked_input = sample_batched.clone()
                    masked_input[:, 2:, :3] = 0
                    output = model(masked_input)
                    loss = loss_func(output, sample_batched[:, 2:])
                    valid_losses.append(loss.item())

        # statistics.mean raises on an empty list (no validation loader or an
        # empty loader), so report nan for a missing average instead.
        train_avg = mean(train_losses) if train_losses else float('nan')
        valid_avg = mean(valid_losses) if valid_losses else float('nan')
        print(f"Epoch:{i}, train/valid loss:{round(train_avg,4)}/{round(valid_avg,4)}")
        
# train(model, adam, train_dataloader, test_dataloader, epochs=50000)

In [None]:
# Save and load models
# torch.save(model, "saved_model.pt")
# map_location places the stored weights on the device selected for this
# session, so a checkpoint saved on a GPU machine also loads on a CPU-only one.
# .eval() returns the model itself and puts it in inference mode for the
# prediction cells that follow.
# NOTE(review): the file stores a whole pickled model. Only load trusted
# checkpoints, and on recent PyTorch releases this may additionally need
# weights_only=False - confirm against the installed version.
model = torch.load("saved_model.pt", map_location=device).eval()

In [None]:
%matplotlib notebook
# Predict the trajectory of the first test sample from its object rows alone.
demo_input = next(iter(test_data))
masked_test_input = demo_input.clone().float().to(device)
masked_test_input[2:,:3] = 0  # hide the trajectory coordinates from the model
masked_test_input = masked_test_input.unsqueeze(0)

predicted_traj = model(masked_test_input).cpu().detach().numpy()[0]
matplotlib.rcParams.update({'font.size': 10})
fig = plt.figure(figsize = (7, 4))
ax2 = fig.add_subplot(1, 1, 1, projection='3d')
colors = ['red', 'blue', 'yellow', 'orange', 'green', 'purple','pink']
# Object rows are stacked cup first, teabag second when the position arrays
# are built, so row 0 is the cup and row 1 the teabag.
cup_pos = demo_input[0,]
teabag_pos = demo_input[1,]
traj_pos = demo_input[2:]
# The training loop compares the model output with input rows 2 onward, so
# the output is expected to hold trajectory rows only. Keeping the last
# len(traj_pos) rows lines the prediction up with the ground truth without
# discarding predicted points.
predicted_xyz = predicted_traj[-traj_pos.shape[0]:]
line = ax2.plot(traj_pos[:,2], traj_pos[:,1], -traj_pos[:,0], '--', color='red', 
               label = f'ground truth')
ax2.plot(teabag_pos[2], teabag_pos[1], -teabag_pos[0], 'o',
        color='red', label=f'teabag')
ax2.plot(cup_pos[2], cup_pos[1], -cup_pos[0], 'x',
        color='red', label=f'cup')


line = ax2.plot(predicted_xyz[:,2], predicted_xyz[:,1], -predicted_xyz[:,0], '--', color='blue', 
               label = f'predicted')


ax2.set_xlabel('x (mm)')
ax2.set_ylabel('y (mm)')
ax2.set_zlabel('z (mm)')
# Equal scaling on all three axes, derived from the current axis limits.
ax2.set_box_aspect([ub - lb for lb, ub in (getattr(ax2, f'get_{a}lim')() for a in 'xyz')])
handles, labels = ax2.get_legend_handles_labels()
newHandles_temp, newLabels_temp = remove_repetitive_labels(handles, labels)
newLabels, newHandles = [], []

for handle, label in zip(newHandles_temp, newLabels_temp):
    if label not in ['start', 'middle', 'end']:
        newLabels.append(label)
        newHandles.append(handle)
# The filtered handles were collected but never shown; draw the legend.
ax2.legend(newHandles, newLabels, loc = 'upper left',  prop={'size': 10})
plt.show()

In [None]:
# Per-coordinate comparison of ground truth and prediction over the sample index.
axis = ['z', 'y', 'x']
matplotlib.rcParams.update({'font.size': 10})
fig = plt.figure(figsize = (5, 4))
axs = fig.subplots(3, 1)
colors = ['red', 'blue', 'yellow', 'orange', 'green', 'purple','pink']
# Row 0 of a sample is the cup and row 1 the teabag (stacking order of the
# object positions).
teabag_pos = demo_input[1,]

true_traj = demo_input[2:]
# Keep the last len(true_traj) predicted rows so both curves share the same
# sample index; slicing the prediction with [2:] would shift it by two steps
# when the model returns trajectory rows only, as the training loss implies.
aligned_pred = predicted_traj[-true_traj.shape[0]:]

# NOTE(review): every column is negated here, while the 3-D plot negates only
# column 0 - confirm which sign convention is intended.
for i in range(3):
    axs[i].plot(-true_traj[:,i], 'o', color='red',  label = f' ground truth')
    axs[i].plot(-aligned_pred[:,i], 'x', color='blue', label=f'predict')
    axs[i].set_xlabel('time')
    axs[i].set_ylabel(axis[i])
    axs[i].set_title(f'{axis[i]}-axis vs Time')

# The curves carry labels; show them once and keep titles/labels from overlapping.
axs[0].legend()
fig.tight_layout()
plt.show()

In [None]:
# Reconstruction loss on the first training sample.
demo_input = next(iter(training_data))
masked_test_input = demo_input.clone().float().to(device)
masked_test_input[2:,:3] = 0  # hide the trajectory coordinates from the model
masked_test_input = masked_test_input.unsqueeze(0)

# Take the single batch element ([0]) rather than slicing the batch axis, and
# compare against the trajectory rows only (rows 2 onward), as the training
# loop does. Keeping the last len(target) output rows aligns the two tensors.
train_sample_target = demo_input[2:].float()
with torch.no_grad():
    train_sample_pred = model(masked_test_input)[0].cpu()
train_sample_pred = train_sample_pred[-train_sample_target.shape[0]:]

loss_func(train_sample_pred, train_sample_target)