In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import cv2
import random
import torch.optim as optim
import matplotlib.pyplot as plt
from IPython.display import clear_output
from scipy.spatial.transform import Rotation as R
import torchvision.models as models
import itertools
import os

In [None]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def quat_to_R_matrix(q):
    """Convert a quaternion tensor ordered (w, x, y, z) into a 3x3 rotation matrix.

    The quaternion is normalized first, so it need not be unit-norm. The result is
    built with torch.stack, so gradients flow back to `q`, and it is cast to float32.
    """
    w, x, y, z = q / q.norm(p=2)
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z
    rows = [
        [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
        [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
        [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
    ]
    return torch.stack([torch.stack(row) for row in rows]).float()

In [4]:
def cal_gt_relpose_360(T1_ls, Q1_ls, T2_ls, Q2_ls, diff_t=2.5, diff_r=25.0):
    """Compute ground-truth relative poses of every source image w.r.t. every template image.

    Args:
        T1_ls: list of source-image translations w.r.t. the global frame G (3 values each).
        Q1_ls: list of source-image quaternions (w, x, y, z) w.r.t. G.
        T2_ls: list of template-image translations w.r.t. G (3 values each).
        Q2_ls: list of template-image quaternions (w, x, y, z) w.r.t. G.
        diff_t: maximum allowed absolute value of each component of the relative translation.
        diff_r: maximum allowed absolute value, in degrees, of each 'xyz' Euler angle of the
            relative rotation.

    Returns:
        T_rel_ck_ci: (N, M, 3) float tensor, R_k @ (T_i - T_k) for source i and template k.
        Q_rel_ck_ci: (N, M, 4) float tensor of unit quaternions (w, x, y, z) of R_i @ R_k.T.
        fl_ci_ck: (N, M) bool tensor, True where the translation and rotation are both within limits.
        All tensors are moved to the module-level `device`.
    """
    # Template poses do not depend on the source index: convert them once instead of N*M times.
    T_temps = [np.array(t).reshape(-1, 1) for t in T2_ls]
    R_temps = [quat_to_R_matrix(torch.tensor(q)).numpy() for q in Q2_ls]

    T_rel_k_i_list = []
    Q_rel_ck_ci_list = []
    flags = []
    for i in range(len(T1_ls)):
        T_i = np.array(T1_ls[i]).reshape(-1, 1)
        R_i = quat_to_R_matrix(torch.tensor(Q1_ls[i])).numpy()
        Trel_i, Qrel_i, fl_i = [], [], []
        for T_k, R_k in zip(T_temps, R_temps):
            Trel_k_i = R_k @ (T_i - T_k)
            fl_t = bool(np.all(np.abs(Trel_k_i) <= diff_t))

            # Build the Rotation once and derive both the Euler angles and the quaternion from it.
            rot_rel = R.from_matrix(R_i @ R_k.T)
            eul_rel_ck_ci = rot_rel.as_euler('xyz', degrees=True)
            fl_r = bool(np.all(np.abs(eul_rel_ck_ci) <= diff_r))
            fl_i.append(fl_t and fl_r)

            # scipy returns scalar-last (x, y, z, w); reorder to (w, x, y, z) like quat_to_R_matrix.
            q_xyzw = rot_rel.as_quat()
            q_rel_ck_ci = torch.tensor([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]])
            q_rel_ck_ci = q_rel_ck_ci / q_rel_ck_ci.norm(p=2)  # unit-norm quaternion
            Trel_i.append(Trel_k_i.flatten().tolist())
            Qrel_i.append(q_rel_ck_ci.flatten().tolist())

        T_rel_k_i_list.append(Trel_i)
        Q_rel_ck_ci_list.append(Qrel_i)
        flags.append(fl_i)

    T_rel_ck_ci = torch.tensor(T_rel_k_i_list).float().to(device)  # N x M x 3 relative translations
    Q_rel_ck_ci = torch.tensor(Q_rel_ck_ci_list).float().to(device)  # N x M x 4 relative quaternions
    fl_ci_ck = torch.tensor(flags).bool().to(device)  # N x M validity flags
    return T_rel_ck_ci, Q_rel_ck_ci, fl_ci_ck

In [5]:
def geom_cstr_comb_flags360(T1_list, Q1_list, T2_list, Q2_list, diff_t=2.5, diff_r=25.0):
    """Return flags marking which camera pairs lie within translation and rotation limits.

    Source cameras (T1_list/Q1_list) come first, followed by template cameras
    (T2_list/Q2_list). Quaternions are ordered (w, x, y, z). Entry [i, j] of the result is
    True when every component of R_j @ (T_i - T_j) is within +-diff_t and every 'xyz' Euler
    angle (degrees) of R_i @ R_j.T is within +-diff_r.

    Returns:
        (N1+N2, N1+N2) bool tensor on the module-level `device`.
    """
    all_Q_list = list(Q1_list) + list(Q2_list)
    all_T_list = list(T1_list) + list(T2_list)

    # Convert every pose once instead of once per (i, j) pair.
    all_T = [np.array(t).reshape(-1, 1) for t in all_T_list]
    # scipy expects scalar-last (x, y, z, w) quaternions.
    all_R = [np.array(R.from_quat([q[1], q[2], q[3], q[0]]).as_matrix()) for q in all_Q_list]

    n_cams = len(all_Q_list)
    geom_flag = []
    for i in range(n_cams):
        T_i, R_i = all_T[i], all_R[i]
        row = []
        for j in range(n_cams):
            Trel_k_i = all_R[j] @ (T_i - all_T[j])
            fl_t = bool(np.all(np.abs(Trel_k_i) <= diff_t))
            eul_rel = R.from_matrix(R_i @ all_R[j].T).as_euler('xyz', degrees=True)
            fl_r = bool(np.all(np.abs(eul_rel) <= diff_r))
            row.append(fl_t and fl_r)
        geom_flag.append(row)
    return torch.tensor(geom_flag).bool().to(device)

In [6]:
def calc_cep_err(arr1, arr2):
    """Median absolute difference between two arrays, ignoring NaN entries.

    This is the value reported as 'CEP' in the plots.
    """
    abs_diff = np.abs(arr1 - arr2)
    return np.nanmedian(abs_diff)

In [7]:
def clip_val(val):
    """Reflect angles (degrees) lying beyond +-90 about +-90, modifying `val` in place.

    Entries >= 90 become 180 - val and entries <= -90 become -180 - val; all
    other entries are unchanged. The same array object is returned.
    """
    folded = np.where(val >= 90, 180 - val, np.where(val <= -90, -180 - val, val))
    val[...] = folded
    return val

In [8]:
def compute_pose_loss360(pred_Q_rel_ck_ci_tsr, pred_T_rel_ck_ci_tsr, flag, gt_Q_rel_ck_ci_tsr, gt_T_rel_k_i_tsr, 
                        criterion, l1=1e4, l2=1e2):
    """Supervised pose loss over every (source, template) pair marked valid in `flag`.

    Args:
        pred_Q_rel_ck_ci_tsr: predicted quaternions indexed [i, k, 0:4]; normalized here before comparison.
        pred_T_rel_ck_ci_tsr: predicted relative translations indexed [i, k, :].
        flag: bool tensor indexed [i, k]; only flagged pairs contribute.
        gt_Q_rel_ck_ci_tsr: ground-truth quaternions indexed [i, k, :]; its first two dims set the loop range.
        gt_T_rel_k_i_tsr: ground-truth relative translations indexed [i, k, :].
        criterion: element-wise loss callable (e.g. nn.SmoothL1Loss()).
        l1, l2: weights for the quaternion and translation terms.

    Returns:
        (ql, tl): mean weighted quaternion loss and mean weighted translation loss over the
        flagged pairs. When no pair is flagged both are 0.0 (the same float fallback the
        composed losses use), instead of raising UnboundLocalError.
    """
    quat_loss = []
    trans_loss = []
    for i in range(gt_Q_rel_ck_ci_tsr.shape[0]):
        for k in range(gt_Q_rel_ck_ci_tsr.shape[1]):
            if not flag[i, k].item():
                continue
            pred_q = pred_Q_rel_ck_ci_tsr[i, k, 0:4]
            pred_q = pred_q / pred_q.norm(p=2)  # compare unit quaternions only
            quat_loss.append(l1 * criterion(pred_q, gt_Q_rel_ck_ci_tsr[i, k, :]))
            trans_loss.append(l2 * criterion(pred_T_rel_ck_ci_tsr[i, k, :], gt_T_rel_k_i_tsr[i, k, :]))
    ql = sum(quat_loss) / len(quat_loss) if quat_loss else 0.0
    tl = sum(trans_loss) / len(trans_loss) if trans_loss else 0.0
    return ql, tl

In [9]:
def comb_rgbd_data(img0, sz=300):
    """Convert an image into a normalized single-channel tensor of size sz x sz.

    Args:
        img0: BGR color image of shape (H, W, 3), as returned by cv2.imread, or an already
            grayscale image of shape (H, W) or (H, W, 1).
        sz: output side length. If using the ViT backbone, set sz=224.

    Returns:
        float32 tensor of shape (1, sz, sz): pixel values divided by 255 and bilinearly resized.

    Raises:
        ValueError: if `img0` is None (cv2.imread returns None when a file cannot be read).
    """
    if img0 is None:
        raise ValueError("comb_rgbd_data received None; the image probably failed to load")
    img = np.asarray(img0)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img.astype('float32') / 255.0
    img = torch.tensor(img).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W) as required by F.interpolate
    img = F.interpolate(img, size=(sz, sz), mode='bilinear', align_corners=False)
    return img.squeeze(0)

In [10]:
def prep_inputs3(img1_list, img2_list):
    """Preprocess source and template image lists into batched tensors on `device`.

    Each image goes through comb_rgbd_data; the results are stacked along a new batch
    dimension, giving (N, 1, sz, sz) for the sources and (M, 1, sz, sz) for the templates.
    """
    def _to_batch(images):
        return torch.stack([comb_rgbd_data(image) for image in images], 0).float().to(device)

    return _to_batch(img1_list), _to_batch(img2_list)

In [11]:
class ExtendedResidualBlock2(nn.Module):
    """Fully connected block: tanh(fc1(x)) followed by `n_layers` Linear layers with skip connections.

    The input is mapped to `out_features` by `shortcut` (an identity when the sizes match,
    otherwise a Linear layer). That projection is added after every third inner layer and
    once more at the end.

    NOTE(review): the inner Linear layers are applied back-to-back without activations.
    Adding activations would change what existing checkpoints compute, so it is left as is.
    """

    def __init__(self, in_features, out_features, n_layers=5):
        super(ExtendedResidualBlock2, self).__init__()
        self.n_layers = n_layers
        self.fc1 = nn.Linear(in_features, out_features)
        self.shortcut = nn.Sequential()
        if in_features != out_features:
            self.shortcut = nn.Sequential(nn.Linear(in_features, out_features))

        # Keep the `fcn_dict.<i>.0` layout so previously saved state_dicts still load.
        # (An unused, parameter-free `fcn_dict2` ModuleDict was removed.)
        self.fcn_dict = nn.ModuleDict()
        for i in range(n_layers):
            self.fcn_dict[str(i)] = nn.Sequential(nn.Linear(out_features, out_features))

    def forward(self, x):
        # The shortcut is deterministic: compute it once and reuse it for every skip connection.
        skip = self.shortcut(x)
        out = torch.tanh(self.fc1(x))
        for i in range(self.n_layers):
            out = self.fcn_dict[str(i)](out)
            if (i + 1) % 3 == 0:
                out = out + skip
        return out + skip

class EfficientNetwork2(nn.Module):
    """MLP regression head.

    Linear(inp -> 1024), then five ExtendedResidualBlock2 stages halving the width
    (1024 -> 512 -> 256 -> 128 -> 64 -> 32), then Linear(32 -> 16) and Linear(16 -> out).
    No activation is inserted between the stages themselves.
    """

    def __init__(self, inp, out, n_layers=5):
        super(EfficientNetwork2, self).__init__()
        self.input_layer = nn.Linear(inp, 1024)
        widths = (1024, 512, 256, 128, 64, 32)
        # Attribute names residual_block1..5 are part of the checkpoint keys; keep them.
        for stage, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, "residual_block" + str(stage), ExtendedResidualBlock2(w_in, w_out, n_layers))
        self.fc3 = nn.Linear(32, 16)
        self.output_layer = nn.Linear(16, out)

    def forward(self, x):
        hidden = self.input_layer(x)
        for stage in range(1, 6):
            hidden = getattr(self, "residual_block" + str(stage))(hidden)
        return self.output_layer(self.fc3(hidden))

In [12]:
class ModResNet2(nn.Module):
    """ResNet-101 feature extractor with configurable input channels and output feature size.

    Args:
        in_chans: number of input image channels (SiamesePoseNet3b_dec uses 1).
        out: size of the output feature vector produced by the final Linear layer.
    """

    def __init__(self, in_chans, out):
        super(ModResNet2, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision in favour of
        # `weights=...`; it is kept for compatibility with older torchvision versions.
        original_model = models.resnet101(pretrained=True)

        old_conv = original_model.conv1
        # Replace the 3-channel stem convolution with one accepting `in_chans` channels.
        # The new convolution is freshly initialized (its pretrained weights are not reused).
        original_model.conv1 = nn.Conv2d(
            in_channels=in_chans,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None)  # nn.Conv2d expects a bool, not the bias tensor/None

        self.features = nn.Sequential(
            original_model.conv1,
            original_model.bn1,
            original_model.relu,
            original_model.maxpool,
            original_model.layer1,
            original_model.layer2,
            original_model.layer3,
            original_model.layer4
        )
        self.avgpool = original_model.avgpool

        # New linear head mapping the pooled ResNet features to `out` values.
        self.fc = nn.Linear(original_model.fc.in_features, out)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

In [13]:
from torchvision.models import vit_b_16

In [14]:
class CustomViT(nn.Module):
    """ViT-B/16 backbone adapted to `num_channels` inputs with a Linear head of size `output_size`.

    Inputs must match the ViT input resolution (comb_rgbd_data's docstring recommends sz=224).
    """

    def __init__(self, num_channels=1, output_size=512):
        super(CustomViT, self).__init__()
        self.vit = vit_b_16(pretrained=True)
        old_proj = self.vit.conv_proj
        # Patch-embedding convolution re-created for `num_channels` input channels.
        self.vit.conv_proj = nn.Conv2d(num_channels, old_proj.out_channels,
                                       kernel_size=old_proj.kernel_size,
                                       stride=old_proj.stride,
                                       padding=old_proj.padding,
                                       bias=False)
        # Honour `output_size` (the head size used to be hard-coded to 512, ignoring the argument).
        self.vit.heads = nn.Linear(self.vit.heads.head.in_features, output_size)

    def forward(self, x):
        return self.vit(x)

In [15]:
class SiamesePoseNet3b_dec(nn.Module):
    """Siamese relative-pose regressor.

    A shared backbone embeds the source and template batches. Every (source, template)
    pair is represented by the difference of the two embeddings, from which two MLP
    heads predict an unnormalized quaternion (w, x, y, z) and a relative translation.
    """

    def __init__(self):
        super(SiamesePoseNet3b_dec, self).__init__()
        self.model = ModResNet2(1, 512)  # swap in CustomViT() to use a vision-transformer backbone
        self.lin4c = EfficientNetwork2(512, 4, 2)  # quaternion head
        self.lin4d = EfficientNetwork2(512, 3, 2)  # translation head

    def forward(self, rgbd1, rgbd2):
        """Return (q_wxyz, xyz) of shapes (N, M, 4) and (N, M, 3) for N source and M template images."""
        emb_src = self.model(rgbd1)
        emb_tmp = self.model(rgbd2)
        # Pairwise differences by broadcasting: (N, 1, D) - (1, M, D) -> (N, M, D).
        pair_diff = emb_src[:, None, :] - emb_tmp[None, :, :]
        return self.lin4c(pair_diff), self.lin4d(pair_diff)

In [16]:
def comp_inter_rot(i, k, pred_Q_rel_ck_ci_tsr, gt_Q1_tsr, gt_Q2_tsr, N1):
    """Return the relative rotation matrix R_ck_ci (= R_i @ R_k.T) between cameras i and k.

    Camera indices below N1 are source cameras (domain 1, poses predicted). Indices >= N1
    are template cameras (domain 2, poses known), stored at offset index - N1.
      - i source, k template: prediction for pair (i, k - N1).
      - both source: composed through a randomly chosen template camera.
      - both template: computed from the template quaternions `gt_Q2_tsr`.
      - i template, k source: transpose of the prediction for pair (k, i - N1).

    `gt_Q1_tsr` is not used; it is kept for call compatibility.
    """
    N2 = pred_Q_rel_ck_ci_tsr.shape[1]
    if i < N1 and k >= N1:
        return quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[i, k - N1, 0:4])
    if i < N1 and k < N1:
        rind = np.random.randint(0, N2)  # random template camera used as a pivot
        r_r = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[k, rind, 0:4])
        r_l = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[i, rind, 0:4])
        return r_l @ r_r.T  # rotation from ck to ci
    if i >= N1 and k >= N1:
        r_r = quat_to_R_matrix(gt_Q2_tsr[k - N1, :])
        r_l = quat_to_R_matrix(gt_Q2_tsr[i - N1, :])
        return r_l @ r_r.T
    # i is a template camera, k a source camera: invert the predicted (k, i) relation.
    pred_R_ci_ck = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[k, i - N1, 0:4])
    return pred_R_ci_ck.T

In [None]:
def composed_rot_loss(n_ins, gt_Q1_tsr, gt_Q2_tsr, gt_Q_rel_ck_ci_tsr, pred_Q_rel_ck_ci_tsr, all_comb_flag,
                      criterion, l=1e3, Nsel=300, n_cams_sel=10):
    """Rotation cycle-consistency loss through chains of `n_ins` intermediate cameras.

    Cameras 0..N1-1 are source images (N1 = pred_Q_rel_ck_ci_tsr.shape[0]); the following
    gt_Q2_tsr.shape[0] indices are template images. A random subset of `n_cams_sel` cameras
    is drawn. For every pair (i, j) in the subset with all_comb_flag[i, j] set, up to `Nsel`
    ordered chains (p_0, ..., p_{n_ins-1}) of distinct cameras from the same subset are
    sampled. A chain is used when it mixes both domains and when all_comb_flag[i, p_0],
    all_comb_flag[j, p_last] and every all_comb_flag[p_m, p_m+1] are set.

    Writing C(a, b) = comp_inter_rot(a, b, ...), the composed rotation
        C(i, p_0) @ C(p_0, p_1) @ ... @ C(p_last, j)
    is compared with the direct rotation C(i, j). The direct rotation is itself a prediction
    when i and j lie in different domains.

    Args:
        n_ins: number of intermediate cameras per chain (>= 1).
        gt_Q1_tsr, gt_Q2_tsr: quaternions of the source / template cameras.
        gt_Q_rel_ck_ci_tsr: unused; kept for call compatibility.
        pred_Q_rel_ck_ci_tsr: predicted relative quaternions indexed [source, template, 0:4].
        all_comb_flag: camera connectivity flags indexed [camera, camera].
        criterion: loss applied to pairs of 3x3 rotation matrices.
        l: weight applied to each term.
        Nsel: maximum number of chains sampled per (i, j) pair.
        n_cams_sel: size of the random camera subset (10 in the original implementation).

    Returns:
        The mean weighted loss when at least 10 terms were collected, otherwise 0.0.
    """
    N1 = pred_Q_rel_ck_ci_tsr.shape[0]
    cams = np.arange(N1 + gt_Q2_tsr.shape[0])
    np.random.shuffle(cams)
    cams = cams[0:n_cams_sel].tolist()
    # Intermediate cameras come from the selected subset itself. Permuting positions
    # 0..len-1 instead would always pick cameras 0..9, regardless of the shuffle.
    perms = list(itertools.permutations(cams, n_ins))

    def rot(a, b):
        return comp_inter_rot(a, b, pred_Q_rel_ck_ci_tsr, gt_Q1_tsr, gt_Q2_tsr, N1)

    comb_rot_loss = []
    for i in cams:
        for j in cams:
            if not all_comb_flag[i, j].item():
                continue
            # Sample only for connected pairs so no random draws are wasted.
            perms_sel = random.sample(perms, Nsel) if len(perms) > Nsel else perms
            for perm_k in perms_sel:
                chain = list(perm_k)
                ext_arr = np.array(chain + [i, j])
                if np.all(ext_arr < N1) or np.all(ext_arr >= N1):
                    continue  # the chain stays within a single domain
                if not (all_comb_flag[i, chain[0]].item() and all_comb_flag[j, chain[-1]].item()):
                    continue
                if not all(all_comb_flag[a, b].item() for a, b in zip(chain[:-1], chain[1:])):
                    continue

                direct = rot(i, j)
                left = rot(i, chain[0])
                right = rot(chain[-1], j)
                inner = [rot(a, b) for a, b in zip(chain[:-1], chain[1:])]
                # Multiply right-to-left so the product is left @ inner[0] @ ... @ inner[-1] @ right.
                # Applying the inner matrices in forward order reversed the chain when n_ins >= 3.
                composed = right
                for inner_mat in reversed(inner):
                    composed = inner_mat @ composed
                composed = left @ composed
                comb_rot_loss.append(l * criterion(composed, direct))

    if len(comb_rot_loss) >= 10:
        return sum(comb_rot_loss) / len(comb_rot_loss)
    return 0.0

In [18]:
def comp_inter_trans(i, k, pred_Q_rel_ck_ci_tsr, pred_T_rel_ck_ci_tsr, gt_T1_tsr, gt_T2_tsr, gt_Q1_tsr, gt_Q2_tsr, N1, id=True):
    """Return the relative translation T_ck_ci = R_k @ (T_i - T_k) as a 3x1 column.

    This is the same convention that cal_gt_relpose_360 uses. Indices below N1 are source
    cameras; indices >= N1 are template cameras, stored at offset index - N1.
      - i source, k template: prediction for pair (i, k - N1).
      - both source: composed through a randomly chosen template camera.
      - both template: computed from the template poses gt_T2_tsr / gt_Q2_tsr.
      - i template, k source: inverse of the prediction for pair (k, i - N1).

    gt_T1_tsr, gt_Q1_tsr and `id` are not used; they are kept for call compatibility.
    """
    n_temp = pred_T_rel_ck_ci_tsr.shape[1]
    i_src = i < N1
    k_src = k < N1

    if i_src and not k_src:
        return pred_T_rel_ck_ci_tsr[i, k - N1, :].reshape(-1, 1)

    if i_src and k_src:
        pivot = np.random.randint(0, n_temp)  # route through a random template camera
        R_pivot_k = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[k, pivot, 0:4])
        t_k = pred_T_rel_ck_ci_tsr[k, pivot, 0:3].reshape(-1, 1)
        t_i = pred_T_rel_ck_ci_tsr[i, pivot, 0:3].reshape(-1, 1)
        return R_pivot_k @ (t_i - t_k)

    if not i_src and not k_src:
        t_k = gt_T2_tsr[k - N1, :].reshape(-1, 1)
        t_i = gt_T2_tsr[i - N1, :].reshape(-1, 1)
        R_k = quat_to_R_matrix(gt_Q2_tsr[k - N1, :])
        return R_k @ (t_i - t_k)

    # i is a template camera and k a source camera: invert the predicted (k -> i) relation.
    t_rev = pred_T_rel_ck_ci_tsr[k, i - N1, :].reshape(-1, 1)
    R_rev = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[k, i - N1, 0:4])
    return -1.0 * R_rev @ t_rev

In [19]:
def comp_R_trans(k, j, pred_Q_rel_ck_ci_tsr, gt_Q1_tsr, gt_Q2_tsr, N1):
    """Return R_cj_ck, the rotation composed_trans_loss uses to carry T_cj_ci into camera k's frame.

    The convention follows cal_gt_relpose_360, where the relative rotation of source i
    w.r.t. template k is R_i @ R_k.T. Indices below N1 are source cameras; indices >= N1
    are template cameras (offset by N1). `gt_Q1_tsr` is not used.
    """
    n_temp = pred_Q_rel_ck_ci_tsr.shape[1]
    j_src = j < N1
    k_src = k < N1

    if j_src and not k_src:
        return quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[j, k - N1, 0:4]).T

    if j_src and k_src:
        pivot = np.random.randint(0, n_temp)  # route through a random template camera
        R_k_pivot = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[k, pivot, 0:4])
        R_j_pivot = quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[j, pivot, 0:4])
        return R_k_pivot @ R_j_pivot.T

    if not j_src and not k_src:
        R_k = quat_to_R_matrix(gt_Q2_tsr[k - N1, :])
        R_j = quat_to_R_matrix(gt_Q2_tsr[j - N1, :])
        return R_k @ R_j.T

    return quat_to_R_matrix(pred_Q_rel_ck_ci_tsr[k, j - N1, 0:4])

In [20]:
def composed_trans_loss(gt_Q1_tsr, gt_Q2_tsr, gt_T_rel_k_i_tsr, gt_Q_rel_ck_ci_tsr, 
                        pred_Q_rel_ck_ci_tsr, pred_T_rel_ck_ci_tsr, gt_T1_tsr, gt_T2_tsr, 
                        all_comb_flag, criterion, l=10.0):
    """Translation cycle-consistency loss over camera triplets (i, j, k).

    Cameras 0..N1-1 are source images and the rest are templates. The triplets used are
    those where (i, j), (j, k) and (i, k) are all set in `all_comb_flag` and the three
    cameras do not all lie in one domain. For each such triplet, the identity
        T_ck_ci = T_ck_cj + R_cj_ck @ T_cj_ci
    is enforced between the directly obtained T_ck_ci and the one composed through camera j.

    gt_T_rel_k_i_tsr and gt_Q_rel_ck_ci_tsr are not used; they are kept for call compatibility.

    Returns:
        The mean weighted loss when at least 10 triplets contribute, otherwise 0.0.
    """
    N1 = pred_T_rel_ck_ci_tsr.shape[0]
    order = np.arange(N1 + gt_T2_tsr.shape[0])
    np.random.shuffle(order)

    def trans(a, b):
        return comp_inter_trans(a, b, pred_Q_rel_ck_ci_tsr, pred_T_rel_ck_ci_tsr, gt_T1_tsr,
                                gt_T2_tsr, gt_Q1_tsr, gt_Q2_tsr, N1)

    losses = []
    for i in order:
        for j in order:
            if not all_comb_flag[i, j].item():
                continue
            for k in order:
                if not (all_comb_flag[j, k].item() and all_comb_flag[i, k].item()):
                    continue
                triplet = np.array([i, j, k])
                if np.all(triplet < N1) or np.all(triplet >= N1):
                    continue  # all three cameras lie in one domain
                T_ck_ci = trans(i, k)
                T_ck_cj = trans(j, k)
                T_cj_ci = trans(i, j)
                R_cj_ck = comp_R_trans(k, j, pred_Q_rel_ck_ci_tsr, gt_Q1_tsr, gt_Q2_tsr, N1)
                losses.append(l * criterion(T_ck_cj + R_cj_ck @ T_cj_ci, T_ck_ci))

    if len(losses) < 10:
        return 0.0
    return sum(losses) / len(losses)

In [35]:
def training_loop(mod1, optimz, img_src_tsr, img_temp_tsr, T_rel_tsr, Q_rel_tsr, flags_tsr, gflag, T1_tsr, 
                  Q1_tsr, T2_tsr, Q2_tsr, criterion):
    """Run one optimization step and return every loss term as a Python float (or None).

    The total loss is built from:
      - the supervised quaternion and translation terms (compute_pose_loss360),
      - rotation cycle losses through 1, 2 and 3 intermediate cameras (composed_rot_loss),
      - the translation triplet loss (composed_trans_loss).
    The optimizer only steps when the total is a tensor.

    Returns:
        [quat_data, trans_data, rot_cycle_1, rot_cycle_2, rot_cycle_3, trans_cycle,
         rot_total, trans_total, total]; any term that is not a tensor is reported as None.
    """
    pose_pred = mod1(img_src_tsr, img_temp_tsr)
    pred_Q, pred_T = pose_pred[0], pose_pred[1]

    # Supervised data terms.
    quat_dat, trans_dat = compute_pose_loss360(pred_Q, pred_T, flags_tsr,
                                               Q_rel_tsr, T_rel_tsr, criterion=criterion)

    # Geometric constraint losses: the first argument is the number of intermediate poses,
    # and Nsel is the number of random chain samples (as specified in the paper).
    g1_loss = composed_rot_loss(1, Q1_tsr, Q2_tsr, Q_rel_tsr, pred_Q, gflag, criterion=criterion)
    g2_loss = composed_rot_loss(2, Q1_tsr, Q2_tsr, Q_rel_tsr, pred_Q, gflag, criterion=criterion, Nsel=20)
    g3_loss = composed_rot_loss(3, Q1_tsr, Q2_tsr, Q_rel_tsr, pred_Q, gflag, criterion=criterion, Nsel=20)
    loss_rot = 0.0 + quat_dat + g1_loss + g2_loss + g3_loss

    tr1_loss = composed_trans_loss(Q1_tsr, Q2_tsr, T_rel_tsr, Q_rel_tsr, pred_Q,
                                   pred_T, T1_tsr, T2_tsr, gflag, criterion=criterion)
    loss_trns = 0.0 + trans_dat + tr1_loss

    loss = 0.0 + loss_rot + loss_trns

    if torch.is_tensor(loss):
        optimz.zero_grad()
        loss.backward()
        optimz.step()

    terms = [quat_dat, trans_dat, g1_loss, g2_loss, g3_loss, tr1_loss, loss_rot, loss_trns, loss]
    with torch.no_grad():
        return [term.item() if torch.is_tensor(term) else None for term in terms]

In [22]:
def plot_progress(comb_loss):
    """Plot the history of each loss term in its own subplot.

    Args:
        comb_loss: list of per-step loss lists, each as returned by training_loop. None
            entries are skipped. When comb_loss is empty nothing is drawn (indexing
            comb_loss[0] would otherwise raise IndexError).
    """
    if not comb_loss:
        return
    clear_output(wait=True)
    n_terms = len(comb_loss[0])
    plt.figure(figsize=(18, 5))
    for col in range(n_terms):
        plt.subplot(1, n_terms, col + 1)
        history = [step[col] for step in comb_loss if step[col] is not None]
        if len(history) > 1:
            plt.plot(history, label='\nCur. loss: ' + str(round(history[-1], 5)))
            plt.xlabel('Epoch')
            plt.legend()
    plt.suptitle('Loss History')

In [None]:
# Loss, model and optimizer setup.
criterion1 = nn.SmoothL1Loss().to(device)
mod1 = SiamesePoseNet3b_dec()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
mod1.load_state_dict(torch.load('./resnet_backbone.pth', map_location=device))
mod1 = mod1.to(device).train()
optimz = optim.Adam([
    {'params': mod1.parameters(), 'lr': 1e-4},
])

In [24]:
def reprep(all_calc_p, all_calc_g):
    """Flatten per-batch prediction / ground-truth arrays into per-axis Python lists.

    Columns 0, 1 and 2 of every array are appended, in order, to the lists of their axis.
    Returns ((pred_0, gt_0), (pred_1, gt_1), (pred_2, gt_2)).
    """
    pred_cols = ([], [], [])
    gt_cols = ([], [], [])
    for p, g in zip(all_calc_p, all_calc_g):
        for axis in range(3):
            pred_cols[axis].extend(p[:, axis].tolist())
            gt_cols[axis].extend(g[:, axis].tolist())
    return tuple((pred_cols[axis], gt_cols[axis]) for axis in range(3))

In [25]:
def repcep(dimx, dimy, dimz, nval):
    """Per-axis CEP error (calc_cep_err) over the last `nval` samples.

    Each dim* argument is a sequence whose first element holds the predictions and whose
    last element holds the ground truth. Returns [[err_x], [err_y], [err_z]].
    Note that nval == 0 selects every sample, because [-0:] is the whole array.
    """
    errors = []
    for dim in (dimx, dimy, dimz):
        pred_tail = np.array(dim[0])[-nval:]
        gt_tail = np.array(dim[-1])[-nval:]
        errors.append([calc_cep_err(pred_tail, gt_tail)])
    return errors

In [26]:
%matplotlib inline
def _plot_xyz_comparison(series, errs, gt_prefix, gt_unit, pd_prefix, err_unit, title, ms=3):
    """Draw one 1x3 figure comparing ground truth ([-1]) and prediction ([0]) for the X, Y, Z axes."""
    plt.figure(figsize=(18, 5))
    for axis, name in enumerate(['X', 'Y', 'Z']):
        plt.subplot(1, 3, axis + 1)
        plt.plot(series[axis][-1], '--*', markersize=ms, label=gt_prefix + name + gt_unit)
        plt.plot(series[axis][0], '--*', markersize=ms, label=pd_prefix + name +
                 '\nCEP: ' + str(round(errs[axis][0], 3)) + err_unit)
        plt.legend()
    plt.suptitle(title)


def plot_global_est2(all_calc, all_pred, all_gt, all_calc_t, all_pred_t, all_gt_t, 
                     all_rel_rot_pd, all_rel_rot_gt, all_rel_tr_pd, all_rel_tr_gt, nval2, 
                     all_global_rot_pd, all_global_rot_gt, all_global_t_pd, all_global_t_gt, 
                     nval3, nval4, c_all_gt_t_rel, c_all_pd_t_rel):
    """Plot predicted vs. ground-truth relative/global rotations and translations per axis.

    Nothing is drawn unless the last entries of `all_rel_rot_pd` and `all_global_rot_pd`
    are non-empty. The legends show CEP errors (repcep) over the last nval2 samples
    (relative rotation), nval3 samples (global rotation and translation) and nval4 samples
    (relative translation). The relative-rotation figure is also saved to a PNG file.

    NOTE(review): all_calc, all_pred, all_gt, all_calc_t, all_pred_t, all_gt_t,
    all_rel_tr_pd and all_rel_tr_gt are not used; they are kept for call compatibility.
    """
    if not (len(all_rel_rot_pd[-1]) >= 1 and len(all_global_rot_pd[-1]) >= 1):
        return

    rot_rell = reprep(all_rel_rot_pd, all_rel_rot_gt)
    gb_rot = reprep(all_global_rot_pd, all_global_rot_gt)
    gb_trans = reprep(all_global_t_pd, all_global_t_gt)
    rot_err_rell = repcep(rot_rell[0], rot_rell[1], rot_rell[2], nval2)
    gb_rot_err = repcep(gb_rot[0], gb_rot[1], gb_rot[2], nval3)
    gb_trn_err = repcep(gb_trans[0], gb_trans[1], gb_trans[2], nval3)
    trel_pd_gt = reprep(c_all_pd_t_rel, c_all_gt_t_rel)
    trel_err = repcep(trel_pd_gt[0], trel_pd_gt[1], trel_pd_gt[2], nval4)

    clear_output(wait=True)
    _plot_xyz_comparison(rot_rell, rot_err_rell, 'GT_rel_rot_', ' deg', 'pred_rel_rot_', ' deg.',
                         'Relative Rotation Estimates')
    plt.savefig('Relative_Rotation_Estimatesc_128_nicp_nsup4_sb_vit.png')
    _plot_xyz_comparison(trel_pd_gt, trel_err, 'GT_rel_trans_', ' m', 'calc_rel_trans_', ' m',
                         'Computed Relative Translation Estimates')
    _plot_xyz_comparison(gb_rot, gb_rot_err, 'GT_rot_', ' deg', 'pred_rot_', ' deg.',
                         'Global Rotation Estimates')
    _plot_xyz_comparison(gb_trans, gb_trn_err, 'GT_trans_', ' m', 'pred_trans_', ' m',
                         'Global Translation Estimates')
    plt.show()

In [27]:
def valid_rots(qv, thr=25, nu=True):
    """Convert quaternions to Euler angles / rotation matrices and flag the small rotations.

    Args:
        qv: (N, 4) numpy array of quaternions ordered (w, x, y, z); they need not be unit-norm.
        thr: threshold in degrees; a row is valid when every |'xyz' Euler angle| <= thr.
        nu: if True, the returned matrices and angles are restricted to the valid rows.

    Returns:
        p_ind: list of N bools marking the valid rows.
        pred_R_ci_ck: transposed rotation matrices, shape (N, 3, 3) or (N_valid, 3, 3).
        pred_eul_ck_ci: 'xyz' Euler angles in degrees, shape (N, 3) or (N_valid, 3).
    """
    # scipy expects scalar-last (x, y, z, w) quaternions.
    q_xyzw = np.hstack((qv[:, 1].reshape(-1, 1), qv[:, 2].reshape(-1, 1),
                        qv[:, 3].reshape(-1, 1), qv[:, 0].reshape(-1, 1)))
    q_xyzw = q_xyzw / np.linalg.norm(q_xyzw, axis=1, keepdims=True)
    rot = R.from_quat(q_xyzw)  # built once and reused for both representations
    pred_eul_ck_ci = rot.as_euler('xyz', degrees=True)
    pred_R_ci_ck = rot.as_matrix().transpose(0, 2, 1)
    p_ind = np.all(np.abs(pred_eul_ck_ci) <= thr, axis=1).tolist()

    if nu:
        pred_eul_ck_ci = pred_eul_ck_ci[p_ind, :]
        pred_R_ci_ck = pred_R_ci_ck[p_ind, :, :]
    return p_ind, pred_R_ci_ck, pred_eul_ck_ci

In [28]:
def valid_trans(tv, thr=2.5):
    """Flag rows of an (N, D) array whose every component satisfies |value| <= thr; returns a list of bools."""
    within_limit = np.abs(tv) <= thr
    return within_limit.all(axis=1).tolist()

In [29]:
def quaternion_to_matrix2(qv):
    """Return transposed rotation matrices for an (N, 4) array of (w, x, y, z) quaternions.

    The quaternions are normalized first; the result has shape (N, 3, 3).
    """
    xyzw = qv[:, [1, 2, 3, 0]]  # scipy expects scalar-last order
    xyzw = xyzw / np.linalg.norm(xyzw, axis=1, keepdims=True)
    return np.swapaxes(R.from_quat(xyzw).as_matrix(), 1, 2)

In [30]:
def euler_to_rotation_matrix2(qv):
    """Return transposed rotation matrices for an (N, >=3) array of 'xyz' Euler angles in degrees.

    Only the first three columns are used. Despite its name, `qv` holds angles, not
    quaternions. The result has shape (N, 3, 3).
    """
    angles = qv[:, :3]
    mats = R.from_euler('xyz', angles, degrees=True).as_matrix()
    return np.swapaxes(mats, 1, 2)

In [31]:
def _cos_sim_weights(feas_2, feas_i):
    """Column of weights proportional to the cosine similarity between each row of feas_2 and feas_i.

    NOTE(review): negative similarities or a near-zero sum make these weights unstable.
    """
    cs_sim = np.dot(feas_2, feas_i.T) / (np.linalg.norm(feas_i) * np.linalg.norm(feas_2, axis=1, keepdims=True))
    return (cs_sim / cs_sim.sum()).reshape(-1, 1)


@torch.no_grad()
def _global_eval2_core(rgbd_tsr, p_gt, modn):
    """Body of global_eval2; expects `modn` to already be in eval mode."""
    gt_T1_tsr, gt_Q1_tsr, gt_T2_tsr, gt_Q2_tsr, gt_T_rel_k_i_tsr, gt_Q_rel_ck_ci_tsr, _flags = p_gt
    rgbd1_tsr_src, rgbd2_tsr_temp = rgbd_tsr
    f1_rgb = modn.model(rgbd1_tsr_src)   # (B1, D)
    f2_rgb = modn.model(rgbd2_tsr_temp)  # (B2, D)
    # Same pairwise embedding difference as SiamesePoseNet3b_dec.forward; the embeddings
    # themselves are also needed below for the similarity weights.
    out_prod = f1_rgb.unsqueeze(1) - f2_rgb.unsqueeze(0)  # (B1, B2, D)
    pred_Q_rel_ck_ci_tsr = modn.lin4c(out_prod)
    pred_T_rel_ck_ci_tsr = modn.lin4d(out_prod)
    feas_src = f1_rgb.cpu().numpy()
    feas_temps = f2_rgb.cpu().numpy()
    all_plt = []  # returned for compatibility; never filled
    N1 = gt_Q1_tsr.shape[0]

    all_gt_trans, all_gt_rot = [], []
    all_pd_trans, all_pd_rot = [], []
    all_gt_t_rel, all_pd_t_rel = [], []
    all_gt_rot_rel, all_pd_rot_rel = [], []

    # Global rotations: combine the per-template estimates for each source image.
    for i in range(N1):
        pred_qi_ck_ci = pred_Q_rel_ck_ci_tsr[i, :, 0:4].cpu().numpy()
        p_ind1, pred_Ri_ci_ck, pred_eul_ck_ci = valid_rots(pred_qi_ck_ci, nu=False)
        if not np.any(p_ind1):
            continue
        pred_Ri_ci_ck, pred_eul_ck_ci = pred_Ri_ci_ck[p_ind1, :, :], pred_eul_ck_ci[p_ind1, :]
        _, _, gt_rel_rot_i = valid_rots(gt_Q_rel_ck_ci_tsr[i, :, 0:4].cpu().numpy(), nu=False)
        gt_rel_rot_i = gt_rel_rot_i[p_ind1, :]
        feas_i = feas_src[i, :].reshape(1, -1)
        gt_R_g_ci = quat_to_R_matrix(gt_Q1_tsr[i, :]).cpu().numpy()
        gt_rpy_g_ci = clip_val(np.array(R.from_matrix(gt_R_g_ci).as_euler('xyz', degrees=True)))

        gt_R_g_ck = quaternion_to_matrix2(gt_Q2_tsr[p_ind1, :].cpu().numpy()).transpose(0, 2, 1)
        pred_R_g_ci = np.matmul(pred_Ri_ci_ck.transpose(0, 2, 1), gt_R_g_ck)
        pred_rpy_g_ci = clip_val(np.array(R.from_matrix(pred_R_g_ci).as_euler('xyz', degrees=True)))

        wt_i = _cos_sim_weights(feas_temps[p_ind1, :], feas_i)
        if len(pred_rpy_g_ci) >= 1:
            all_pd_rot.append(np.sum(wt_i * pred_rpy_g_ci, 0).flatten())
            all_gt_rot.append(gt_rpy_g_ci.flatten())
        all_gt_rot_rel.append(gt_rel_rot_i)
        all_pd_rot_rel.append(pred_eul_ck_ci)

    # Global translations.
    N2 = gt_T1_tsr.shape[0]
    pred_T_src = pred_T_rel_ck_ci_tsr[0:N2, :, :]  # loop-invariant slice
    for i in range(N2):
        pred_ti_ck_ci = pred_T_src[i, :, 0:3].cpu().numpy()
        p_ind1 = valid_trans(pred_ti_ck_ci)
        if not np.any(p_ind1):
            continue
        feas_i = feas_src[i, :].reshape(1, -1)
        gt_T_g_ci = gt_T1_tsr[i, :].flatten().cpu().numpy()
        pred_ti_ck_ci = pred_ti_ck_ci[p_ind1, :]
        gt_R_ck_g = quaternion_to_matrix2(gt_Q2_tsr[p_ind1, :].cpu().numpy())  # R_k.T for each template
        T_g_ck = gt_T2_tsr[p_ind1, :].cpu().numpy()
        # Invert T_ck_ci = R_k @ (T_i - T_k): T_i = T_k + R_k.T @ T_ck_ci, once per template.
        pred_T_g_ci = T_g_ck + np.matmul(gt_R_ck_g, pred_ti_ck_ci[:, :, np.newaxis]).squeeze(2)

        wt_i = _cos_sim_weights(feas_temps[p_ind1, :], feas_i)
        if len(pred_T_g_ci) >= 1:
            all_pd_trans.append(np.sum(wt_i * pred_T_g_ci, 0).flatten())
            all_gt_trans.append(gt_T_g_ci)
        all_gt_t_rel.append(gt_T_rel_k_i_tsr[i, :, 0:3][p_ind1, :].cpu().numpy())
        all_pd_t_rel.append(pred_ti_ck_ci)

    rots = (np.array(all_pd_rot), np.array(all_gt_rot))
    trans = (np.array(all_pd_trans), np.array(all_gt_trans))
    if len(all_gt_rot_rel) >= 1 and len(all_gt_t_rel) >= 1:
        rel_rots = (np.vstack(all_pd_rot_rel), np.vstack(all_gt_rot_rel))
        return rots, trans, np.vstack(all_gt_t_rel), np.vstack(all_pd_t_rel), rel_rots, all_plt
    return None


def global_eval2(rgbd_tsr, p_gt, mod1):
    """Estimate global poses of the source images from their predicted relative poses.

    For each source image, the predicted relative rotations and translations to the
    templates that pass valid_rots / valid_trans are mapped to the global frame using the
    known template poses. The results are averaged with weights proportional to the cosine
    similarity between the source and template embeddings.

    Args:
        rgbd_tsr: (source_batch, template_batch) image tensors.
        p_gt: (gt_T1, gt_Q1, gt_T2, gt_Q2, gt_T_rel, gt_Q_rel, flags) ground-truth tensors.
        mod1: model exposing .model, .lin4c and .lin4d (e.g. SiamesePoseNet3b_dec).

    Returns:
        (rots, trans, all_gt_t_rel, all_pd_t_rel, rel_rots, all_plt), where rots and trans
        are (pred, gt) pairs of global 'xyz' Euler angles (degrees, folded by clip_val) and
        global translations; or None if no valid relative rotation or translation was found.

    The model is evaluated in eval mode, and its previous train/eval mode is restored
    afterwards so that a training loop calling this function is not left with frozen
    BatchNorm statistics.
    """
    was_training = mod1.training
    try:
        return _global_eval2_core(rgbd_tsr, p_gt, mod1.eval())
    finally:
        mod1.train(was_training)

In [None]:
#TODO
## Load the source images, their translations wrt G, and the quaternions denoting rotations from G to the camera frame.
## If Gazebo is used, these images and their poses can be subscribed to and obtained via ROS in real time during training.
## If the background contains diverse elements, it may be necessary to mask them using a pretrained model, as suggested in the paper.

path_to_src_imgs = './source_images/'                          # replace with your own directory
path_to_src_translations = './source_images/source_trans.npy'  # replace with your own path
path_to_src_quaternions = './source_images/source_quat.npy'    # replace with your own path
T1_list = np.load(path_to_src_translations).tolist()
Q1_list = np.load(path_to_src_quaternions).tolist()
# Translations, quaternions and images are paired by index, so the pose arrays must agree in length.
assert len(T1_list) == len(Q1_list), (
    f'source translations ({len(T1_list)}) and quaternions ({len(Q1_list)}) differ in length')
img1_list = []
for i in range(len(T1_list)):
    img_path = os.path.join(path_to_src_imgs, f'img_{i}.png')
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when a file is missing or unreadable;
    # fail here rather than deep inside preprocessing/training.
    if img is None:
        raise FileNotFoundError(f'Could not read source image: {img_path}')
    img1_list.append(img)

In [None]:
#TODO
## Load the template images, their translations wrt G, and the quaternions denoting rotations from G to the camera frame.
## If Gazebo is used, it's recommended to first capture and store the template images and poses, then load them during training and deployment.
path_to_temp_imgs = './template_images/'                           # replace with your own directory
path_to_temp_translations = './template_images/temp_poseTrans_.npy'  # replace with your own path
path_to_temp_quaternions = './template_images/temp_poseQuat_.npy'    # replace with your own path
T2_list = np.load(path_to_temp_translations).tolist()
Q2_list = np.load(path_to_temp_quaternions).tolist()
# Translations, quaternions and images are paired by index, so the pose arrays must agree in length.
assert len(T2_list) == len(Q2_list), (
    f'template translations ({len(T2_list)}) and quaternions ({len(Q2_list)}) differ in length')
img2_list = []
for i in range(len(T2_list)):
    img_path = os.path.join(path_to_temp_imgs, f'temp_{i}.png')
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when a file is missing or unreadable;
    # fail here rather than deep inside preprocessing/training.
    if img is None:
        raise FileNotFoundError(f'Could not read template image: {img_path}')
    img2_list.append(img)

In [None]:
## Training loop: each epoch draws random source/template batches, computes the ground-truth
## relative poses, runs one optimisation step, checkpoints the best model, and periodically
## plots the loss curve and runs a global-pose evaluation.

def _sample_pose_batch(T_list, Q_list, img_list, b_size):
    """Randomly draw (without replacement) an index-aligned batch of poses and images.

    T_list, Q_list and img_list are parallel lists (translation, quaternion, image) that share indices.
    The batch size is clamped to the number of available samples so random.sample never raises
    ValueError on small datasets.
    Returns (T_batch, Q_batch, img_batch) as lists.
    """
    n_pick = min(b_size, len(T_list))
    inds = random.sample(range(len(T_list)), n_pick)
    return ([T_list[k] for k in inds],
            [Q_list[k] for k in inds],
            [img_list[k] for k in inds])

prev_loss = float('inf')  # best loss so far; inf so the first valid epoch always checkpoints
comb_loss = []
all_global_ests = []
all_calc_r, all_pred_r, all_gt_r = [], [], []
all_calc_t, all_pred_t, all_gt_t = [], [], []
all_rel_rot_pd, all_rel_rot_gt, all_rel_tr_pd, all_rel_tr_gt = [], [], [], []
all_global_rot_pd, all_global_rot_gt, all_global_t_pd, all_global_t_gt = [], [], [], []
a_nval2 = 0
a_nval3 = 0
a_nval4 = 0
c_all_gt_t_rel, c_all_pd_t_rel = [], []
b_size_src = 20 ## set according to your GPU capacity
b_size_temp = 20 ## set according to your GPU capacity
nepochs = 500
plot_every = 10   # epochs between loss-curve plots
eval_every = 15   # epochs between global-pose evaluations
run_eval = True   # set to False to skip the periodic evaluation
ckpt_path = './gbdtposenet.pth'

for epoch in range(nepochs):
    T1_list_bs, Q1_list_bs, img1_list_bs = _sample_pose_batch(T1_list, Q1_list, img1_list, b_size_src)
    T2_list_bs, Q2_list_bs, img2_list_bs = _sample_pose_batch(T2_list, Q2_list, img2_list, b_size_temp)

    out1 = cal_gt_relpose_360(T1_list_bs, Q1_list_bs, T2_list_bs, Q2_list_bs)
    T_rel_tsr, Q_rel_tsr, flags_tsr = out1
    gflag = geom_cstr_comb_flags360(T1_list_bs, Q1_list_bs, T2_list_bs, Q2_list_bs)
    img_src_tsr, img_temp_tsr = prep_inputs3(img1_list_bs, img2_list_bs)
    T1_tsr, Q1_tsr = torch.tensor(T1_list_bs).float().to(device), torch.tensor(Q1_list_bs).float().to(device)
    T2_tsr, Q2_tsr = torch.tensor(T2_list_bs).float().to(device), torch.tensor(Q2_list_bs).float().to(device)
    out2 = training_loop(mod1, optimz, img_src_tsr, img_temp_tsr, T_rel_tsr, Q_rel_tsr,
                         flags_tsr, gflag, T1_tsr, Q1_tsr, T2_tsr, Q2_tsr, criterion1)

    if out2 is not None:
        comb_loss.append(out2)
        # NOTE(review): out2[-1] is presumably the epoch's total loss — confirm against training_loop.
        if out2[-1] < prev_loss:
            torch.save(mod1.state_dict(), ckpt_path)
            prev_loss = out2[-1]

    if epoch % plot_every == 0 and len(comb_loss) >= 1:
        plot_progress(comb_loss)

    if run_eval and epoch % eval_every == 0:
        # NOTE(review): evaluation batches come from the same source/template pools used for
        # training, so this measures fit on seen data rather than generalization.
        T1_list_bs_test, Q1_list_bs_test, img1_list_bs_test = _sample_pose_batch(
            T1_list, Q1_list, img1_list, b_size_src)
        # Template indices must come from the template pool (and use the template batch size);
        # indexing T2_list with source-pool indices can go out of range.
        T2_list_bs_test, Q2_list_bs_test, img2_list_bs_test = _sample_pose_batch(
            T2_list, Q2_list, img2_list, b_size_temp)

        out3 = cal_gt_relpose_360(T1_list_bs_test, Q1_list_bs_test, T2_list_bs_test, Q2_list_bs_test)
        T_rel_tsr_test, Q_rel_tsr_test, flags_tsr_test = out3
        rgbd_tsr_test = prep_inputs3(img1_list_bs_test, img2_list_bs_test)
        T1_tsr_test, Q1_tsr_test = torch.tensor(T1_list_bs_test).float().to(device), torch.tensor(Q1_list_bs_test).float().to(device)
        T2_tsr_test, Q2_tsr_test = torch.tensor(T2_list_bs_test).float().to(device), torch.tensor(Q2_list_bs_test).float().to(device)

        p_gt = (T1_tsr_test, Q1_tsr_test, T2_tsr_test, Q2_tsr_test, T_rel_tsr_test, Q_rel_tsr_test, flags_tsr_test)
        out4 = global_eval2(rgbd_tsr_test, p_gt, mod1)

        # global_eval2 returns None when no relative estimates were collected for this batch.
        if out4 is not None:
            rots_gb, trans_gb, all_gt_t_rel, all_pd_t_rel, rots_m, all_plt = out4
            # Each of rots_m, rots_gb, trans_gb is a (predicted, ground-truth) pair.
            all_rel_rot_pd.append(rots_m[0])
            all_rel_rot_gt.append(rots_m[-1])
            all_global_rot_pd.append(rots_gb[0])
            all_global_rot_gt.append(rots_gb[-1])
            all_global_t_pd.append(trans_gb[0])
            all_global_t_gt.append(trans_gb[-1])
            c_all_gt_t_rel.append(all_gt_t_rel)
            c_all_pd_t_rel.append(all_pd_t_rel)

            nval2 = rots_m[0].shape[0]
            nval3 = rots_gb[0].shape[0]
            nval4 = all_gt_t_rel.shape[0]
            a_nval2 += nval2
            a_nval3 += nval3
            a_nval4 += nval4
            plot_global_est2(all_calc_r, all_pred_r, all_gt_r, all_calc_t, all_pred_t, all_gt_t,
                             all_rel_rot_pd, all_rel_rot_gt, all_rel_tr_pd, all_rel_tr_gt, a_nval2,
                             all_global_rot_pd, all_global_rot_gt, all_global_t_pd, all_global_t_gt,
                             a_nval3, a_nval4, c_all_gt_t_rel, c_all_pd_t_rel)