In [1]:
import sys
sys.path.insert(1, 'data_utils')
sys.path.insert(1, 'models/')

from transform_functions import PCRNetTransform as transform
import transform_functions
from modelnet_reg_utils import ModelNet40Data, RegistrationData
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import numpy as np
import os
import open3d as o3d
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
import time
from models.feature_models import PointResNet, PointNet, AttentionPointResNet
from models.attention_pooling import AttentionPooling
from utils.load_model import load_model
from args import Args
import transforms3d
from scipy.spatial.transform import Rotation as R
arger = Args()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class get_model(nn.Module):
    def __init__(self,args,  feature_extractor = PointResNet):
        super(get_model, self).__init__()
        self.args = args
        self.feature_extractor = feature_extractor(self.args)
        if self.args.attention_pooling:
            self.attentional_pooling = AttentionPooling()
        else:
            self.attentional_pooling = None
    def forward(self, x):
        x_ap, x_mp = self.feature_extractor(x)
        if self.args.attention_pooling:
            x_feat= self.attentional_pooling(x_ap, x_mp)
        else:
            x_feat = x_mp
        return x_feat

In [3]:
class iPCRNet(nn.Module):
	def __init__(self, feature_model, droput=0.0):
		super().__init__()
		self.feature_model = feature_model

		self.linear = [nn.Linear(1024 * 2, 1024), nn.ReLU(),
				   	   nn.Linear(1024, 1024), nn.ReLU(),
				   	   nn.Linear(1024, 512), nn.ReLU(),
				   	   nn.Linear(512, 512), nn.ReLU(),
				   	   nn.Linear(512, 256), nn.ReLU()]

		if droput>0.0:
			self.linear.append(nn.Dropout(droput))
		self.linear.append(nn.Linear(256,7))

		self.linear = nn.Sequential(*self.linear)

	# Single Pass Alignment Module (SPAM)
	def spam(self, template_features, source, est_R, est_t):
		batch_size = source.size(0)
		self.source_features = self.feature_model(source.permute(0,2,1))
		y = torch.cat([template_features, self.source_features], dim=1)
		pose_7d = self.linear(y)
		pose_7d = transform.create_pose_7d(pose_7d)

		# Find current rotation and translation.
		identity = torch.eye(3).to(source).view(1,3,3).expand(batch_size, 3, 3).contiguous()
		est_R_temp = transform.quaternion_rotate(identity, pose_7d).permute(0, 2, 1)
		est_t_temp = transform.get_translation(pose_7d).view(-1, 1, 3)

		# update translation matrix.
		est_t = torch.bmm(est_R_temp, est_t.permute(0, 2, 1)).permute(0, 2, 1) + est_t_temp
		# update rotation matrix.
		est_R = torch.bmm(est_R_temp, est_R)
		
		source = transform.quaternion_transform(source, pose_7d)      # Ps' = est_R*Ps + est_t
		return est_R, est_t, source

	def forward(self, template, source, max_iteration=3):
		est_R = torch.eye(3).to(template).view(1, 3, 3).expand(template.size(0), 3, 3).contiguous()         # (Bx3x3)
		est_t = torch.zeros(1,3).to(template).view(1, 1, 3).expand(template.size(0), 1, 3).contiguous()     # (Bx1x3)
		template_features = self.feature_model(template.permute(0,2,1))
		if max_iteration == 1:
			est_R, est_t, source = self.spam(template_features, source, est_R, est_t)
		else:
			for i in range(max_iteration):
				est_R, est_t, source = self.spam(template_features, source, est_R, est_t)

		result = {'est_R': est_R,				# source -> template
				  'est_t': est_t,				# source -> template
				  'est_T': transform.convert2transformation(est_R, est_t),			# source -> template
				  'r': template_features - self.source_features,
				  'transformed_source': source}
		return result

In [4]:
BATCH_SIZE = 16
testset = RegistrationData(ModelNet40Data(train=False, download=True), is_testing=True, 
                           angle_range=90, translation_range=1, add_noise=False, shuffle_points=False)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=2)

In [5]:
def visualize_with_trans(source, template, trans):
    src_b = source.detach().cpu().numpy()
    tar_b = template.detach().cpu().numpy()
    trans_b = trans.detach().cpu().numpy()

    for i in range(src_b.shape[0]):
        src = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(src_b[i])).paint_uniform_color([1, 0.706, 0])
        tar = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(tar_b[i])).paint_uniform_color([0, 0.651, 0.929])
        src.transform(trans_b[0])
        o3d.visualization.draw_geometries([src, tar])

In [6]:
def display_open3d(template, source, transformed_source):
	template_ = o3d.geometry.PointCloud()
	source_ = o3d.geometry.PointCloud()
	transformed_source_ = o3d.geometry.PointCloud()
	template_.points = o3d.utility.Vector3dVector(template)
	source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0]))
	transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
	template_.paint_uniform_color([1, 0, 0])
	source_.paint_uniform_color([0, 1, 0])
	transformed_source_.paint_uniform_color([0, 0, 1])
	o3d.visualization.draw_geometries([template_, source_, transformed_source_])

In [7]:
def batch_inverse(T):
    """
    Invert a batch of 4x4 transformation matrices.
    Args:
        T: A torch tensor of shape (B, 4, 4), where B is the batch size.
    Returns:
        inv_T: A torch tensor of shape (B, 4, 4), where each 4x4 matrix
               is the inverse of the corresponding input matrix.
    """
    B = T.shape[0]
    inv_T = torch.zeros_like(T)
    for i in range(B):
        inv_T[i] = torch.inverse(T[i])
    return inv_T

In [8]:
def quat2mat(q):
    quat = q[:,0:4]
    trans = q[:,4:7]

    x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]

    B = quat.size(0)
    device = quat.device
    
    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                          2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                          2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3)
    transMat = torch.cat([rotMat, trans.unsqueeze(-1)], dim = -1)
    transMat = torch.cat([transMat, torch.tensor([0,0,0,1]).view(1,1,4).expand(B,1,4).to(device)], dim = 1)
    return transMat

In [9]:
from common.torch import to_numpy
from common.math import se3
from common.math_torch import se3
from common.math.so3 import dcm2euler

def compute_metrics(points_ref, points_src, gt_transforms, pred_transforms):
    """Compute metrics required in the paper
    """

    def square_distance(src, dst):
        return torch.sum((src[:, :, None, :] - dst[:, None, :, :]) ** 2, dim=-1)

    with torch.no_grad():

        # Euler angles, Individual translation errors (Deep Closest Point convention)
        # TODO Change rotation to torch operations
        r_gt_euler_deg = dcm2euler(gt_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        r_pred_euler_deg = dcm2euler(pred_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        t_gt = gt_transforms[:, :3, 3]
        t_pred = pred_transforms[:, :3, 3]
        r_mse = np.mean((r_gt_euler_deg - r_pred_euler_deg) ** 2, axis=1)
        r_mae = np.mean(np.abs(r_gt_euler_deg - r_pred_euler_deg), axis=1)
        t_mse = torch.mean((t_gt - t_pred) ** 2, dim=1)
        t_mae = torch.mean(torch.abs(t_gt - t_pred), dim=1)

        # Rotation, translation errors (isotropic, i.e. doesn't depend on error
        # direction, which is more representative of the actual error)
        concatenated = se3.concatenate(se3.inverse(gt_transforms), pred_transforms)
        rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2]
        residual_rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi
        residual_transmag = concatenated[:, :, 3].norm(dim=-1)

        # Modified Chamfer distance
        src_transformed = se3.transform(pred_transforms, points_src)
        src_clean = se3.transform(se3.concatenate(pred_transforms, se3.inverse(gt_transforms)), points_ref)
        dist_src = torch.min(square_distance(src_transformed, points_ref), dim=-1)[0]
        dist_ref = torch.min(square_distance(points_ref, src_clean), dim=-1)[0]
        chamfer_dist = torch.mean(dist_src, dim=1) + torch.mean(dist_ref, dim=1)

        metrics = {
            'l1_dist': torch.mean(torch.abs(points_ref - src_clean), dim=1),
            'r_mse': r_mse,
            'r_rmse': np.sqrt(r_mse),
            'r_mae': r_mae,
            't_mse': to_numpy(t_mse),
            't_rmse': to_numpy(torch.sqrt(t_mse)),
            't_mae': to_numpy(t_mae),
            'err_r_deg': to_numpy(residual_rotdeg),
            'err_t': to_numpy(residual_transmag),
            'chamfer_dist': to_numpy(chamfer_dist)
        }

    return metrics

def evaluate(device, model, test_loader):
    model.eval()
    test_loss = 0.0
    r_mse = []
    t_mse = []
    r_mae = []
    t_mae = []
    err_r_deg = []
    err_t = []
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader)):
            template, source, igt, igt_R, igt_t = data

            template = template.to(device)
            source = source.to(device)
            igt = igt.to(device).squeeze(1)
            igt_R = igt_R.to(device)
            igt_t = igt_t.to(device)
            source_original = source.clone()
            template_original = template.clone()
            igt_t = igt_t - torch.mean(source, dim=1).unsqueeze(1)
            source = source - torch.mean(source, dim=1, keepdim=True)
            template = template - torch.mean(template, dim=1, keepdim=True)

            output = model(template, source)
            gt_tsf = transform.convert2transformation(igt_R, igt_t).to(device)
            pred_tsf = output['est_T']
            est_R = output['est_R']
            est_t = output['est_t']
            metrics = compute_metrics(source, template, batch_inverse(gt_tsf), pred_tsf)
            
            r_mse.append(metrics['r_mse'])
            t_mse.append(metrics['t_mse'])
            r_mae.append(metrics['r_mae'])
            t_mae.append(metrics['t_mae'])
            err_r_deg.append(metrics['err_r_deg'])
            err_t.append(metrics['err_t'])
            # visualize_with_trans(source, template, pred_tsf)
            # visualize_with_trans(source, template, batch_inverse(gt_tsf))
            #display_open3d(template.detach().cpu().numpy()[0], source_original.detach().cpu().numpy()[0], source_tsf[0])
        print("rotation mse:", np.mean(np.array(r_mse).reshape(-1,1), axis=0))
        print("translation mse:", np.mean(np.array(t_mse).reshape(-1,1), axis=0))
        print("rotation mae:", np.mean(np.array(r_mae).reshape(-1,1), axis=0))
        print("translation mae:", np.mean(np.array(t_mse).reshape(-1,1), axis=0))
        print("rotation error:", np.mean(np.array(err_r_deg).reshape(-1,1), axis=0))
        print("translation error:", np.mean(np.array(err_t).reshape(-1,1), axis=0))

In [10]:
if not torch.cuda.is_available():
    device = 'cpu'
else: 
    device = 'cuda:0'
device = torch.device(device)

# Create PointNet Model.
ptresnet = get_model(arger, feature_extractor= AttentionPointResNet)
respcr = iPCRNet(feature_model=ptresnet)

In [11]:
best_model_path = "backup/best_pointresnet.t7"
LOAD = True
if os.path.isfile(best_model_path) and LOAD:
    print("Found checkpoint, loading weights")
    checkpoint = torch.load(best_model_path)
    start_epoch = checkpoint['epoch']
    respcr.load_state_dict(checkpoint['model'])
    respcr.to(device)
else:
    print("notfound")
    start_epoch = 0
    respcr.to(device)

Found checkpoint, loading weights


In [12]:
evaluate(device, respcr, test_loader)

100%|██████████| 154/154 [00:32<00:00,  4.71it/s]

rotation mse: [2888.81540847]
translation mse: [1.6581904e-06]
rotation mae: [30.4355322]
translation mae: [1.6581904e-06]
rotation error: [55.15109]
translation error: [0.00200028]





In [15]:
# Create PointNet Model.
ptnet = get_model(arger, feature_extractor = PointNet)
pcrnet = iPCRNet(feature_model=ptnet)

In [16]:
best_model_path = "backup/vanilla_pointnet_it3.t7"
LOAD = True
if os.path.isfile(best_model_path) and LOAD:
    print("Found checkpoint, loading weights")
    checkpoint = torch.load(best_model_path)
    start_epoch = checkpoint['epoch']
    pcrnet.load_state_dict(checkpoint['model'])
    pcrnet.to(device)
else:
    print("notfound")
    start_epoch = 0
    pcrnet.to(device)

Found checkpoint, loading weights


In [17]:
evaluate(device, pcrnet, test_loader)

100%|██████████| 617/617 [00:25<00:00, 24.62it/s]

rotation mse: [126.0874611]
translation mse: [7.859694e-05]
rotation mae: [6.36654402]
translation mae: [7.859694e-05]
rotation error: [12.99217]
translation error: [0.01199805]



