In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import cv2
import random
import matplotlib.pyplot as plt
from IPython.display import clear_output
import math
import rospy
from scipy.spatial.transform import Rotation as R
import torchvision.models as models
import os

In [None]:
# Register this notebook with the ROS master under the node name "gbdtpose".
rospy.init_node(name="gbdtpose")

In [None]:
# Use the first CUDA GPU when one is available; otherwise run everything on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class ExtendedResidualBlock2(nn.Module):
    """MLP block: tanh(fc1(x)) followed by ``n_layers`` linear layers with periodic skips.

    The (optionally projected) input ``shortcut(x)`` is added after every third inner
    layer and once more at the end.

    Args:
        in_features: width of the input.
        out_features: width of the output and of every inner layer.
        n_layers: number of inner ``Linear(out_features, out_features)`` layers.
    """

    def __init__(self, in_features, out_features, n_layers=5):
        super(ExtendedResidualBlock2, self).__init__()
        self.n_layers = n_layers
        self.fc1 = nn.Linear(in_features, out_features)
        # An empty Sequential is the identity; project only when the widths differ.
        self.shortcut = nn.Sequential()
        if in_features != out_features:
            self.shortcut = nn.Sequential(
                nn.Linear(in_features, out_features))

        # Keep this structure: saved checkpoints use keys like "fcn_dict.<i>.0.weight".
        self.fcn_dict = nn.ModuleDict()
        for i in range(n_layers):
            self.fcn_dict[str(i)] = nn.Sequential(nn.Linear(out_features, out_features))

    def forward(self, x):
        # The skip term depends only on x, so compute it once rather than on every use.
        skip = self.shortcut(x)
        out = torch.tanh(self.fc1(x))
        # NOTE(review): there is no non-linearity between the inner layers; kept as trained.
        for i in range(self.n_layers):
            out = self.fcn_dict[str(i)](out)
            if (i + 1) % 3 == 0:
                out = out + skip
        return out + skip

class EfficientNetwork2(nn.Module):
    """MLP regressor: inp -> 1024 -> five ExtendedResidualBlock2 stages -> 16 -> out.

    The stages narrow the width 1024 -> 512 -> 256 -> 128 -> 64 -> 32.
    """

    def __init__(self, inp, out, n_layers=5):
        super(EfficientNetwork2, self).__init__()
        self.input_layer = nn.Linear(inp, 1024)
        widths = (1024, 512, 256, 128, 64, 32)
        # Attribute names residual_block1..5 are part of the checkpoint's state_dict keys.
        for stage, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'residual_block%d' % stage, ExtendedResidualBlock2(w_in, w_out, n_layers))
        self.fc3 = nn.Linear(32, 16)
        self.output_layer = nn.Linear(16, out)

    def forward(self, x):
        hidden = self.input_layer(x)
        for stage in range(1, 6):
            hidden = getattr(self, 'residual_block%d' % stage)(hidden)
        return self.output_layer(self.fc3(hidden))

In [None]:
class ModResNet2(nn.Module):
    """ResNet-101 feature extractor with a configurable input-channel count and output width.

    The stem convolution is replaced so the network accepts ``in_chans`` channels, and the
    classifier head is replaced by ``Linear(fc.in_features, out)``.

    Args:
        in_chans: number of input image channels (this notebook uses 1, grayscale).
        out: size of the output feature vector.
        pretrained: forwarded to ``torchvision.models.resnet101``. The replacement conv1 and
            fc are randomly initialised either way. Pass False when a full checkpoint is
            loaded afterwards, to skip the ImageNet download.
    """

    def __init__(self, in_chans, out, pretrained=True):
        super(ModResNet2, self).__init__()
        original_model = models.resnet101(pretrained=pretrained)
        old_conv1 = original_model.conv1
        original_model.conv1 = nn.Conv2d(
                    in_channels=in_chans,
                    out_channels=old_conv1.out_channels,
                    kernel_size=old_conv1.kernel_size,
                    stride=old_conv1.stride,
                    padding=old_conv1.padding,
                    # `bias` expects a bool; passing the old bias tensor itself would fail
                    # ("ambiguous truth value") if the original conv had one.
                    bias=old_conv1.bias is not None)

        self.features = nn.Sequential(
            original_model.conv1,
            original_model.bn1,
            original_model.relu,
            original_model.maxpool,
            original_model.layer1,
            original_model.layer2,
            original_model.layer3,
            original_model.layer4
        )
        self.avgpool = original_model.avgpool
        self.fc = nn.Linear(original_model.fc.in_features, out)

    def forward(self, x):
        """Return a (batch, out) feature tensor for an image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [None]:
class SiamesePoseNet3b_dec(nn.Module):
    """Siamese relative-pose network.

    A shared backbone (``ModResNet2(1, 512)``) embeds both image batches. Every
    (first, second) pair of embeddings is differenced, and two MLP heads regress a
    4-vector (quaternion, read as w, x, y, z downstream) and a 3-vector (translation)
    per pair.
    """

    def __init__(self):
        super(SiamesePoseNet3b_dec, self).__init__()
        self.model = ModResNet2(1, 512)            # shared feature extractor
        self.lin4c = EfficientNetwork2(512, 4, 2)  # rotation head
        self.lin4d = EfficientNetwork2(512, 3, 2)  # translation head

    def forward(self, rgbd1, rgbd2):
        """Return ``(p_wxyz, p_xyz)`` with shapes (B1, B2, 4) and (B1, B2, 3)."""
        emb1 = self.model(rgbd1)
        emb2 = self.model(rgbd2)
        n1, d1 = emb1.shape
        n2, d2 = emb2.shape
        pair_diff = emb1[:, None, :].expand(n1, n2, d1) - emb2[None, :, :].expand(n1, n2, d2)
        return self.lin4c(pair_diff), self.lin4d(pair_diff)

In [None]:
mod1_e = SiamesePoseNet3b_dec().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
mod1_e.load_state_dict(torch.load('./gbdtposenet.pth', map_location=device))
mod1_e = mod1_e.eval()

In [None]:
def comb_rgbd_data(img, sz=300):
    """Scale a single-channel image to [0, 1] and resize it to a (1, 1, sz, sz) tensor.

    ``img`` is a grayscale (H x W) image, presumably uint8 since it is divided by 255.
    Resizing is bilinear. If using a ViT backbone, pass sz=224.
    """
    gray = np.array(img).astype('float32') / 255.0
    batch = torch.tensor(gray)[None, None]  # add batch and channel dims -> (1, 1, H, W)
    return F.interpolate(batch, size=(sz, sz), mode='bilinear', align_corners=False)


In [None]:
def prep_rgb_inputs(img):
    """Convert an OpenCV image into the (1, 1, 300, 300) float tensor fed to the backbone.

    Accepts a 3-channel BGR image (converted with ``cv2.COLOR_BGR2GRAY``) or an image that
    is already single-channel (H x W or H x W x 1). Single-channel input, such as the frames
    built by the ROS image callback, used to crash the BGR->gray conversion.

    Raises:
        ValueError: if ``img`` is None (e.g. a failed ``cv2.imread``).
    """
    if img is None:
        raise ValueError("prep_rgb_inputs: image is None")
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    rgbd1_tsr = comb_rgbd_data(img).float()
    return rgbd1_tsr

In [None]:
def create_temp_feas(mod1, img_list2):
    """Compute backbone features for every template image (run once, offline).

    Args:
        mod1: model whose ``model`` attribute is the shared feature extractor.
        img_list2: sequence of template images, e.g. as loaded with ``cv2.imread``.

    Returns:
        Tensor stacking one feature row per template, in the same order as ``img_list2``.

    Raises:
        ValueError: if ``img_list2`` is empty or contains a ``None`` entry
            (``cv2.imread`` returns None for unreadable paths).
    """
    if len(img_list2) == 0:
        raise ValueError("create_temp_feas: no template images were provided")
    all_feas = []
    with torch.no_grad():
        for idx, temp_img in enumerate(img_list2):
            if temp_img is None:
                raise ValueError(f"create_temp_feas: template image {idx} is None (failed to load?)")
            rgbd2_i = prep_rgb_inputs(temp_img)
            all_feas.append(mod1.model(rgbd2_i.to(device)))
    return torch.vstack(all_feas)

In [None]:
def quaternion_to_matrix2(qv):
    """Turn N quaternions stored as (w, x, y, z) rows into transposed rotation matrices.

    The rows are reordered to SciPy's scalar-last (x, y, z, w) layout and normalised.
    For each quaternion the transpose of its rotation matrix is returned, as an
    (N, 3, 3) array.
    """
    xyzw = np.hstack([qv[:, col].reshape(-1, 1) for col in (1, 2, 3, 0)])
    xyzw = xyzw / np.linalg.norm(xyzw, axis=1, keepdims=True)
    mats = R.from_quat(xyzw).as_matrix()
    return np.transpose(mats, (0, 2, 1))

In [None]:
def valid_rots(qv, thr=25, nu=False):
    """Flag predicted relative rotations whose Euler angles are all within +/-thr degrees.

    Args:
        qv: (N, 4) numpy array of quaternions ordered (w, x, y, z); they need not be unit length.
        thr: per-axis limit, in degrees, on the 'xyz' Euler angles.
        nu: if True, drop the rejected rows from the returned matrices and angles.

    Returns:
        ``(p_ind, pred_R_ci_ck, pred_eul_ck_ci)``:
        p_ind is a boolean mask of length N, True where every |angle| <= thr.
        pred_R_ci_ck holds the (N, 3, 3) transposed rotation matrices (filtered if ``nu``).
        pred_eul_ck_ci holds the (N, 3) 'xyz' Euler angles in degrees (filtered if ``nu``).
    """
    # SciPy expects scalar-last (x, y, z, w): reorder, then normalise.
    pred_q_ck_ci = qv[:, [1, 2, 3, 0]]
    pred_q_ck_ci = pred_q_ck_ci / np.linalg.norm(pred_q_ck_ci, axis=1, keepdims=True)
    rot = R.from_quat(pred_q_ck_ci)  # build once, reuse for both representations
    pred_eul_ck_ci = rot.as_euler('xyz', degrees=True)
    pred_R_ci_ck = rot.as_matrix().transpose(0, 2, 1)
    p_ind = np.all(np.abs(pred_eul_ck_ci) <= thr, axis=1)
    if nu:
        pred_eul_ck_ci = pred_eul_ck_ci[p_ind, :]
        pred_R_ci_ck = pred_R_ci_ck[p_ind, :, :]
    return p_ind, pred_R_ci_ck, pred_eul_ck_ci

In [None]:
def clip_val(val):
    """Fold angles (degrees) beyond +/-90 back into [-90, 90].

    Values >= 90 map to 180 - v; values <= -90 map to -180 - v; all others are unchanged.
    Works on a copy, so the caller's array is no longer modified in place.
    """
    out = np.array(val, copy=True)
    id1 = out >= 90
    id2 = out <= -90
    out[id1] = 180 - out[id1]
    out[id2] = -180 - out[id2]
    return out

In [None]:
## If Gazebo is used, R_c_b is the rotation matrix from the camera to body frame of the UAV
# Built as Rz(yaw) @ Ry(pitch) @ Rx(roll), using exact pi rather than 3.141592.
roll_c_u, pitch_c_u, yaw_c_u = -(math.pi / 2), 0, math.pi / 2
rx = np.array([[1, 0, 0], [0, math.cos(roll_c_u), -math.sin(roll_c_u)], [0, math.sin(roll_c_u), math.cos(roll_c_u)]])
ry = np.array([[math.cos(pitch_c_u), 0, math.sin(pitch_c_u)], [0, 1, 0], [-math.sin(pitch_c_u), 0, math.cos(pitch_c_u)]])
rz = np.array([[math.cos(yaw_c_u), -math.sin(yaw_c_u), 0], [math.sin(yaw_c_u), math.cos(yaw_c_u), 0], [0, 0, 1]])
R_c_b = rz @ ry @ rx

In [None]:
#TODO
## load the template images, their translations wrt G and quaternions denoting rotations from G to the camera frame
path_to_temp_imgs = './...'  # directory containing temp_0.png, temp_1.png, ...
path_to_temp_translations = './...'
path_to_temp_quaternions = './...'
T2_list = np.load(path_to_temp_translations).tolist()
Q2_list = np.load(path_to_temp_quaternions).tolist()
if len(T2_list) != len(Q2_list):
    raise ValueError(f"Template translations ({len(T2_list)}) and quaternions ({len(Q2_list)}) differ in length")
img2_list = []
for i in range(len(T2_list)):
    # os.path.join works whether or not the directory path ends with a separator.
    temp_img_path = os.path.join(path_to_temp_imgs, 'temp_' + str(i) + '.png')
    temp_img = cv2.imread(temp_img_path)
    if temp_img is None:  # cv2.imread does not raise on a missing/unreadable file
        raise FileNotFoundError(f"Could not read template image: {temp_img_path}")
    img2_list.append(temp_img)

In [None]:
out_temps = create_temp_feas(mod1_e, img2_list)
# global_eval360 indexes template features and template poses with the same mask; keep them row-aligned.
assert out_temps.shape[0] == len(T2_list) == len(Q2_list), "template features and template poses are misaligned"
out_temps.shape

In [None]:
def softmax(x):
    """Numerically stable softmax along axis 0.

    Turns the similarity scores between a source image and the templates into weights
    that sum to 1, used for fusing the pose proposals.
    """
    exps = np.exp(x - np.max(x, axis=0, keepdims=True))  # shift by the max to avoid overflow
    normaliser = np.sum(exps, axis=0, keepdims=True)
    return exps / normaliser

In [None]:
def valid_trans(tv, thr=2.5):
    """Return a list of bools, one per row of ``tv``, True when every component is within +/-thr."""
    within_bounds = np.abs(tv) <= thr
    return np.all(within_bounds, axis=1).tolist()

In [None]:
def global_eval360(rgbd1_tsr, p_gt, mod1, temp_feas=None):
    """Estimate the global pose of each query image against all template images.

    Each query feature is differenced against every template feature. The heads
    ``lin4c`` (quaternion, w, x, y, z) and ``lin4d`` (translation) predict one relative
    pose per template. Implausible proposals are rejected with ``valid_trans`` and
    ``valid_rots``. The survivors are chained with their template's known pose and fused
    using softmax weights over the query/template cosine similarity.

    Args:
        rgbd1_tsr: query image batch, as built by ``prep_rgb_inputs``; dim 0 is the number
            of query images.
        p_gt: tuple ``(T, Q)`` of template translations and quaternions (w, x, y, z), as
            tensors row-aligned with the template features (T presumably B2 x 3).
        mod1: model exposing ``model``, ``lin4c`` and ``lin4d``.
        temp_feas: template features (B2 x D). Defaults to the notebook-global
            ``out_temps`` for backward compatibility.

    Returns:
        ``(rots, trans, rots_unc, tran_unc)``, with one row per query image that had at
        least one valid proposal: fused roll/pitch/yaw in degrees, fused translation,
        and per-axis variances of the proposals. Returns None if no proposal survived.
    """
    if temp_feas is None:
        temp_feas = out_temps  # notebook-global features from create_temp_feas
    modn = mod1.eval()
    gt_T2_tsr, gt_Q2_tsr = p_gt

    # Inference only: skip building an autograd graph (this runs on every camera frame).
    with torch.no_grad():
        f1_rgb = modn.model(rgbd1_tsr.to(device))
        f2_rgb = temp_feas
        B1, D1 = f1_rgb.shape
        B2, D2 = f2_rgb.shape
        # All (query, template) feature differences: shape (B1, B2, D).
        out_prod = f1_rgb.unsqueeze(1).expand(B1, B2, D1) - f2_rgb.unsqueeze(0).expand(B1, B2, D2)
        pred_Q_rel_ck_ci = modn.lin4c(out_prod).cpu().numpy()
        pred_T_rel_ck_ci = modn.lin4d(out_prod).cpu().numpy()

    feas_query = f1_rgb.cpu().numpy()
    feas_temps = f2_rgb.detach().cpu().numpy()
    gt_Q2_np = gt_Q2_tsr.detach().cpu().numpy()
    gt_T2_np = gt_T2_tsr.detach().cpu().numpy()
    R_c_b_T = np.asarray(R_c_b).T  # hoisted: identical for every proposal

    N1 = rgbd1_tsr.shape[0]  ## If multiple cameras are used at once, then N1 > 1 otherwise, N1 = 1
    all_pd_trans, all_pd_rot = [], []
    all_t_unc, all_r_unc = [], []
    for i in range(N1):
        pred_qi_ck_ci = pred_Q_rel_ck_ci[i, :, 0:4]
        pred_ti_ck_ci = pred_T_rel_ck_ci[i, :, 0:3]
        # Keep only proposals whose relative translation AND rotation are both plausible.
        pind_1 = valid_trans(pred_ti_ck_ci)
        pind_2, pred_Ri_ci_ck, _ = valid_rots(pred_qi_ck_ci)
        pind = np.asarray(pind_1, dtype=bool) & np.asarray(pind_2, dtype=bool)
        if not pind.any():
            continue

        pred_Ri_ci_ck = pred_Ri_ci_ck[pind]
        pred_ti_ck_ci = pred_ti_ck_ci[pind]
        # Chain each relative pose with the known pose of its template.
        gt_R_g_ck = quaternion_to_matrix2(gt_Q2_np[pind]).transpose(0, 2, 1)
        gt_R_ck_g = gt_R_g_ck.transpose(0, 2, 1)
        T_g_ck = gt_T2_np[pind]
        pred_R_g_ci = np.matmul(pred_Ri_ci_ck.transpose(0, 2, 1), gt_R_g_ck)
        pred_R_bi_g = np.matmul(pred_R_g_ci.transpose(0, 2, 1), R_c_b_T)  # broadcasts over proposals
        p_pred_rpy_bi_g = clip_val(np.array(R.from_matrix(pred_R_bi_g).as_euler('xyz', degrees=True)))
        p_pred_T_g_ci = T_g_ck + np.matmul(gt_R_ck_g, pred_ti_ck_ci[:, :, np.newaxis]).squeeze(2)

        # Weight proposals by softmax of the cosine similarity between query and template features.
        feas_i = feas_query[i].reshape(1, -1)
        feas_2 = feas_temps[pind]
        cs_sim = np.dot(feas_2, feas_i.T) / (np.linalg.norm(feas_i) * np.linalg.norm(feas_2, axis=1, keepdims=True))
        wt_i = softmax(cs_sim.reshape(-1, 1))

        all_pd_trans.append(np.nansum(wt_i * p_pred_T_g_ci, 0).flatten())
        all_pd_rot.append(np.nansum(wt_i * p_pred_rpy_bi_g, 0).flatten())
        all_t_unc.append(np.var(p_pred_T_g_ci, 0))
        all_r_unc.append(np.var(p_pred_rpy_bi_g, axis=0))

    if not all_pd_rot:
        return None  # no query image produced a valid proposal; callers must handle this
    rots = np.array(all_pd_rot)
    trans = np.array(all_pd_trans)
    rots_unc, tran_unc = np.array(all_r_unc), np.array(all_t_unc)
    return rots, trans, rots_unc, tran_unc

In [None]:
# Template translations / quaternions as float32 tensors on the inference device.
T2_main_tsr = torch.tensor(T2_list, dtype=torch.float32).to(device)
Q2_main_tsr = torch.tensor(Q2_list, dtype=torch.float32).to(device)

In [None]:
def transform2enu(x_g, rot_g, t_unc_g, rot_unc_g):
    """Express a pose estimated in frame G in the ENU frame used for ArduPilot navigation.

    Args:
        x_g: position in G (3 values).
        rot_g: 'xyz' Euler angles in degrees, as produced by ``global_eval360``.
        t_unc_g: per-axis position variances in G (3 values).
        rot_unc_g: per-axis variances of ``rot_g`` in degrees^2 (3 values).

    Returns:
        ``(x_enu, quat_enu, cov_enu)``:
        x_enu is the ENU position, shape (3,).
        quat_enu is the orientation quaternion in SciPy (x, y, z, w) order, shape (4,).
        cov_enu is a 6x6 diagonal covariance flattened row-major to 36 values, with
        rotational variances converted to rad^2 as ROS pose covariances expect.
    """
    # Fixed G -> ENU rotation: x_enu = y_g, y_enu = -x_g, z_enu = z_g.
    R_g_enu = np.array([[0.0, 1.0, 0.0],
                        [-1.0, 0.0, 0.0],
                        [0.0, 0.0, 1.0]])
    R_b_g = R.from_euler('xyz', rot_g, degrees=True).as_matrix()
    x_enu = R_g_enu @ np.asarray(x_g).reshape(-1, 1)
    # NOTE(review): frame convention of R_b_g vs. the transpose below not verified; kept as originally written.
    quat_enu = R.from_matrix((R_g_enu @ R_b_g).T).as_quat()
    # R_g_enu is a signed permutation, so |R @ var| just permutes the variance vector.
    var_trans_enu = np.abs(R_g_enu @ np.asarray(t_unc_g).reshape(-1, 1)).flatten()
    # ROS covariances are in radians; convert the degree^2 angle variances first.
    rot_unc_rad2 = np.asarray(rot_unc_g, dtype=float) * (np.pi / 180.0) ** 2
    var_rot_enu = np.abs(R_g_enu @ rot_unc_rad2.reshape(-1, 1)).flatten()
    cov_enu = np.diag(np.concatenate([var_trans_enu, var_rot_enu]))
    return x_enu.flatten(), quat_enu.flatten(), cov_enu.flatten()

In [None]:
from sensor_msgs.msg import Image
from sensor_msgs.msg import CameraInfo
from rospy.numpy_msg import numpy_msg
from nav_msgs.msg import Odometry
from geometry_msgs.msg import PoseStamped
from geometry_msgs.msg import PoseWithCovarianceStamped

In [None]:
real_img = None


def real_imcallback(msg):
    """ROS subscriber callback: store the latest camera frame in the global ``real_img``.

    The raw byte buffer is reshaped to (height, width, channels).
    NOTE(review): assumes 8-bit pixels and no row padding (msg.step == width * channels),
    presumably true for the T265 fisheye stream; confirm for other cameras.
    """
    global real_img
    real_img = np.frombuffer(msg.data, dtype=np.uint8).reshape(msg.height, msg.width, -1)


## Subscribe to T265 fisheye camera
#TODO
camera_topic = "/camera/fisheye1/image_raw"  # provide the camera topic name here
real_uav_img = rospy.Subscriber(camera_topic, numpy_msg(Image), callback=real_imcallback)

use_ekf = False  ## If integration with T265 is needed, set this parameter to True
if use_ekf:
    real_uav_odom = rospy.Publisher("/gbdtpose", Odometry, queue_size=10)
    real_odom = Odometry()  # message type must match the publisher's declared type
else:
    real_uav_odom = rospy.Publisher("/mavros/vision_pose/pose_cov", PoseWithCovarianceStamped, queue_size=10)
    real_odom = PoseWithCovarianceStamped()

# Offline debugging only: when True, a fixed image from disk replaces the live camera frame.
use_test_image = False
test_image_path = './tempfisheye2_360_1/temp_0.png'

rate = rospy.Rate(30)  ## Normal rate for transmitting poses to the UAV flight controller
while not rospy.is_shutdown():
    # Take a local reference: the subscriber callback may replace real_img at any time.
    frame = cv2.imread(test_image_path) if use_test_image else real_img
    if frame is not None:
        # prep_rgb_inputs applies BGR->gray, so give single-channel frames three channels first.
        if frame.ndim == 2 or frame.shape[2] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        rgbd1_tsr_src = prep_rgb_inputs(frame)
        p_gt = (T2_main_tsr.to(device), Q2_main_tsr.to(device))
        out3 = global_eval360(rgbd1_tsr_src, p_gt, mod1_e)
        if out3 is None:
            # global_eval360 returns None when every proposal was rejected; skip this frame.
            rospy.logwarn_throttle(5.0, "gbdtpose: no valid pose proposal for this frame")
        else:
            rots, trans, rot_unc, t_unc = out3
            xyz_enu, quat_enu, cov_enu = transform2enu(trans.flatten(), rots.flatten(), t_unc.flatten(), rot_unc.flatten())
            real_odom.header.stamp = rospy.Time.now()
            real_odom.pose.pose.position.x = xyz_enu[0]
            real_odom.pose.pose.position.y = xyz_enu[1]
            real_odom.pose.pose.position.z = xyz_enu[2]
            # transform2enu returns SciPy's (x, y, z, w) quaternion order.
            real_odom.pose.pose.orientation.x = quat_enu[0]
            real_odom.pose.pose.orientation.y = quat_enu[1]
            real_odom.pose.pose.orientation.z = quat_enu[2]
            real_odom.pose.pose.orientation.w = quat_enu[3]
            real_odom.pose.covariance = cov_enu.tolist()
            real_uav_odom.publish(real_odom)
    rate.sleep()