In [3]:
%load_ext autoreload

In [23]:
%autoreload
import torch
import torch.nn as nn
from collections import OrderedDict
from models.SceneNetStereoJointIter import SceneNetStereoJointIter
from models.SceneNetMonoJointIter import SceneNetMonoJointIter
from datasets.kitti_2015_train import KITTI_2015_MonoSceneFlow_Full
from datasets.kitti_raw_monosf import KITTI_Odom_Test

from losses_eval import Eval_SceneFlow_KITTI_Test, Eval_SceneFlow_KITTI_Train
from losses import Loss_SceneFlow_SelfSup_JointIter

from augmentations import Augmentation_Resize_Only
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from tqdm import tqdm
from pprint import pprint

class Args:
    model_name='scenenet_joint_stereo_iter'
    cuda = True
    use_bn = True
    momentum = 0.9
    beta = 0.999
    weight_decay=0.0
    use_mask = True
    sf_lr_w = 0.0
    pose_lr_w = 0.0
    mask_lr_w = 1.0
    disp_lr_w = 1.0
    disp_pts_w = 0.0
    sf_pts_w = 0.2
    sf_sm_w = 200
    pose_sm_w = 200
    pose_pts_w = 0.2
    disp_sm_w = 0.2
    disp_smooth_w = 0.1
    mask_reg_w = 0.2
    encoder_name="pwc"
    static_cons_w = 1.0
    mask_cons_w = 0.2
    mask_sm_w = 0.1
    flow_diff_thresh=1e-3
    evaluation=True
    finetuning=False

args = Args()

model = SceneNetMonoJointIter(args).cuda()
state_dict = torch.load('pretrained/joint_iter/latest.ckpt')['model']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.eval()

dataset_09 = KITTI_Odom_Test(args, "/external/datasets/kitti_data_jpg/", "09")
dataloader_09 = DataLoader(dataset_09, shuffle=False, batch_size=1, pin_memory=True)
dataset_10 = KITTI_Odom_Test(args, "/external/datasets/kitti_data_jpg/", "10")
dataloader_10 = DataLoader(dataset_10, shuffle=False, batch_size=1, pin_memory=True)

aug = Augmentation_Resize_Only(args).cuda()

In [49]:
# from https://github.com/tinghuiz/SfMLearner
def compute_ate(gtruth_xyz, pred_xyz):
    alignment_error = pred_xyz - gtruth_xyz
    rmse = np.sqrt(sum(alignment_error ** 2)) / gtruth_xyz.shape[0]
    
    return rmse

ates = []
for i, data in enumerate(tqdm(dataloader_09)):
    with torch.no_grad():
        # Get input and target tensor keys
        input_keys = list(filter(lambda x: "input" in x, data.keys()))
        target_keys = list(filter(lambda x: "target" in x, data.keys()))
        tensor_keys = input_keys + target_keys
        
        # Possibly transfer to Cuda
        for k, v in data.items():
            if k in tensor_keys:
                data[k] = v.cuda(non_blocking=True)
                
        aug_data = aug(data)
        out = model(aug_data)
        pose_f = out['pose_f'][0][:, :3, 3].double().cpu()
        gt_t = data['target_pose'][:, :3, 3].cpu()
        ate = compute_ate(gt_t, pose_f)
        ates.append(np.array(ate))
        
        if i == 20:
            break

ates = np.array(ates)
print("\n   Trajectory error: {:0.3f}, std: {:0.3f}\n".format(np.mean(ates), np.std(ates)))

# save_path = os.path.join(opt.load_weights_folder, "poses.npy")
# np.save(save_path, pred_poses)
# print("-> Predictions saved to", save_path)

  1%|▏         | 20/1590 [00:03<05:10,  5.05it/s]

[array([0.02557536, 0.01086417, 0.28617973]), array([0.01843651, 0.00904032, 0.29097395]), array([0.01458911, 0.00690351, 0.29686111]), array([0.00952712, 0.00596559, 0.30431479]), array([0.00442635, 0.00465334, 0.31390454]), array([0.00053958, 0.00523094, 0.32177713]), array([0.0066841 , 0.00784034, 0.33625371]), array([0.01206659, 0.00937131, 0.3538369 ]), array([0.0139446 , 0.00723075, 0.35837967]), array([0.00935413, 0.00575072, 0.37183138]), array([0.00736904, 0.00458036, 0.38441855]), array([0.01021255, 0.00606851, 0.39332453]), array([0.02010878, 0.00835971, 0.40087475]), array([0.01955232, 0.00861912, 0.41565599]), array([0.01725722, 0.00872964, 0.42705677]), array([0.01036243, 0.00818809, 0.44109209]), array([0.00436428, 0.00656745, 0.45122663]), array([0.00533916, 0.00564975, 0.46720161]), array([0.00923828, 0.00611654, 0.48053615]), array([0.00452175, 0.00596867, 0.50123691]), array([0.00051625, 0.00685171, 0.50337788])]

   Trajectory error: 0.134, std: 0.182




