In [55]:
from splatart.networks.PoseEstimator import PoseEstimator
from splatart.managers.SplatManager import SplatManagerSingle
import json
import numpy as np
import torch

In [7]:
manager_path = '../arm_accuracy_pths/seg_learned_manager_0.pth'
pose_estimator_path = '../arm_accuracy_pths/pose_estimator.pth'
gauss_param_path = '../arm_accuracy_pths/part_gauss_params.pth'
transforms_path = '../arm_accuracy_pths/transforms.json'

In [15]:
splatManager = torch.load(manager_path)
poseEstimator = torch.load(pose_estimator_path)
gaussParams = torch.load(gauss_param_path)
transforms = json.load(open(transforms_path, 'r'))

DataParser info

In [17]:
dp_scale, dp_tf_matrix = splatManager.dataparser_scale, splatManager.dataparser_tf_matrix.cpu()
dp_scale, dp_tf_matrix

(0.3335080532301977,
 tensor([[ 9.9778e-01, -8.8128e-04, -6.6559e-02,  3.4200e-09],
         [-8.8128e-04,  9.9965e-01, -2.6447e-02,  1.1238e-08],
         [ 6.6559e-02,  2.6447e-02,  9.9743e-01,  1.5395e-10],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]]))

In [43]:
gauss_means = [gaussParams[0][i]['means'].detach().cpu() for i in range(poseEstimator.num_parts)]
len(gauss_means)

12

PoseEstimator info

In [44]:
poseEstimator.num_parts

12

In [45]:
trans_estim = poseEstimator.part_means[0, :, :].detach().cpu()
rot_estim = poseEstimator.part_eulers[0, :, :].detach().cpu()

GT Transforms info

In [46]:
gt_transforms = transforms['gt_part_world_poses']['0']

ADD Metric per Part

In [58]:
from scipy.spatial.transform import Rotation as R

add_values = {}

for i in range(poseEstimator.num_parts):
    part_name = list(gt_transforms.keys())[i]

    gt_transform = np.array(gt_transforms[part_name])
    gt_transform = torch.from_numpy(gt_transform).float()

    # convert translation and euler angles to 4x4 rotation matrix
    pred_transform = torch.eye(4)
    pred_transform[:3, :3] = torch.from_numpy(R.from_euler('xyz', rot_estim[i]).as_matrix())
    pred_transform[3, :3] = trans_estim[i]
    pred_transform = pred_transform.float()

    # invert dataparser transform
    pred_transform = torch.linalg.inv(dp_tf_matrix) @ pred_transform
    pred_transform = pred_transform / dp_scale

    # apply to gaussian means
    pred_gauss_mean = gauss_means[i]
    pred_gauss_mean = torch.cat((pred_gauss_mean, torch.ones(pred_gauss_mean.shape[0], 1)), dim=1)
    pred_gauss_mean = pred_transform @ pred_gauss_mean.T
    pred_gauss_mean = pred_gauss_mean.T[:, :3]

    gt_gauss_mean = torch.from_numpy(np.array(gt_transform[:3, 3])).float()
    gt_gauss_mean = gt_gauss_mean.unsqueeze(0)
    gt_gauss_mean = torch.cat((gt_gauss_mean, torch.ones(gt_gauss_mean.shape[0], 1)), dim=1)
    gt_gauss_mean = torch.linalg.inv(gt_transform) @ gt_gauss_mean.T
    gt_gauss_mean = gt_gauss_mean.T[:, :3]

    # compute ADD metric
    add = torch.norm(pred_gauss_mean - gt_gauss_mean, dim=1).mean()
    add_values[part_name] = add.item()


In [59]:
add_values

{'panda_link0': nan,
 'panda_link1': 0.09673026949167252,
 'panda_link2': 0.27024930715560913,
 'panda_link3': 0.3905125558376312,
 'panda_link4': 0.5522987246513367,
 'panda_link5': 0.6008105278015137,
 'panda_link6': 0.41271713376045227,
 'panda_link7': 0.44242098927497864,
 'panda_link8': 0.46363022923469543,
 'panda_hand': 0.44844189286231995,
 'panda_leftfinger': 0.5124382376670837,
 'panda_rightfinger': 0.4623609185218811}

Function

In [60]:
def ADD_metric(splatManager:SplatManagerSingle, poseEstimator:PoseEstimator, gaussParams:list, transforms:dict, scene_id:int):
    """
    Computes the ADD metric for a given prediction and ground truth.
    """

    # dataparser and means
    dp_scale, dp_tf_matrix = splatManager.dataparser_scale, splatManager.dataparser_tf_matrix.cpu()
    gauss_means = [gaussParams[scene_id][i]['means'].detach().cpu() for i in range(poseEstimator.num_parts)]

    # estimated transforms
    trans_estim = poseEstimator.part_means[scene_id, :, :].detach().cpu()
    rot_estim = poseEstimator.part_eulers[scene_id, :, :].detach().cpu()

    # ground truth transforms
    gt_transforms = transforms['gt_part_world_poses'][str(scene_id)]

    add_values = {}
    for i in range(poseEstimator.num_parts):
        part_name = list(gt_transforms.keys())[i]

        gt_transform = np.array(gt_transforms[part_name])
        gt_transform = torch.from_numpy(gt_transform).float()

        # convert translation and euler angles to 4x4 rotation matrix
        pred_transform = torch.eye(4)
        pred_transform[:3, :3] = torch.from_numpy(R.from_euler('xyz', rot_estim[i]).as_matrix())
        pred_transform[3, :3] = trans_estim[i]
        pred_transform = pred_transform.float()

        # invert dataparser transform
        pred_transform = torch.linalg.inv(dp_tf_matrix) @ pred_transform
        pred_transform = pred_transform / dp_scale

        # apply to gaussian means
        pred_gauss_mean = gauss_means[i]
        pred_gauss_mean = torch.cat((pred_gauss_mean, torch.ones(pred_gauss_mean.shape[0], 1)), dim=1)
        pred_gauss_mean = pred_transform @ pred_gauss_mean.T
        pred_gauss_mean = pred_gauss_mean.T[:, :3]

        gt_gauss_mean = torch.from_numpy(np.array(gt_transform[:3, 3])).float()
        gt_gauss_mean = gt_gauss_mean.unsqueeze(0)
        gt_gauss_mean = torch.cat((gt_gauss_mean, torch.ones(gt_gauss_mean.shape[0], 1)), dim=1)
        gt_gauss_mean = torch.linalg.inv(gt_transform) @ gt_gauss_mean.T
        gt_gauss_mean = gt_gauss_mean.T[:, :3]

        # compute ADD metric
        add = torch.norm(pred_gauss_mean - gt_gauss_mean, dim=1).mean()
        add_values[part_name] = add.item()

    return add_values

    

In [61]:
add_values = ADD_metric(splatManager, poseEstimator, gaussParams, transforms, 0)
add_values

{'panda_link0': nan,
 'panda_link1': 0.09673026949167252,
 'panda_link2': 0.27024930715560913,
 'panda_link3': 0.3905125558376312,
 'panda_link4': 0.5522987246513367,
 'panda_link5': 0.6008105278015137,
 'panda_link6': 0.41271713376045227,
 'panda_link7': 0.44242098927497864,
 'panda_link8': 0.46363022923469543,
 'panda_hand': 0.44844189286231995,
 'panda_leftfinger': 0.5124382376670837,
 'panda_rightfinger': 0.4623609185218811}