In [2]:
%load_ext autoreload
%autoreload
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from models.ResModel import ResModel
from models.Model import Model
from models.JointModel import JointModel
from losses_eval import Eval_SceneFlow_KITTI_Train
from datasets.kitti_2015_train import KITTI_2015_MonoSceneFlow
from augmentations import Augmentation_Resize_Only
from collections import OrderedDict

class Args:
    """Static configuration namespace for this evaluation notebook.

    Stands in for parsed command-line arguments: an instance is passed to the
    JointModel, Eval_SceneFlow_KITTI_Train, Augmentation_Resize_Only and
    KITTI_2015_MonoSceneFlow constructors in the setup cell. All values are
    class attributes, so every instance shares them. The grouping below is by
    attribute name only; how each value is consumed lives in those project
    modules, which are not visible here.
    """

    # --- runtime / evaluation ---
    cuda = True
    evaluation = True
    batch_size = 1
    num_examples = 200

    # --- network structure ---
    encoder_name = 'resnet'
    pt_encoder = False
    use_bn = False
    num_scales = 4
    do_pose_c2f = False
    use_disp_min = False

    # --- optimizer settings (presumably unused when only evaluating) ---
    momentum = 0.9
    beta = 0.999
    weight_decay = 0.0

    # --- mask-related flags ---
    train_exp_mask = False
    train_census_mask = True
    apply_mask = True
    apply_flow_mask = False
    use_static_mask = False

    # --- `_w`-suffixed values: presumably loss-term weights (TODO confirm in the loss module) ---
    ssim_w = 0.85
    disp_sm_w = 0.2
    disp_smooth_w = 0.1
    disp_lr_w = 0.1
    disp_pts_w = 0.0
    flow_pts_w = 0.0
    pose_pts_w = 0.0
    sf_sm_w = 200
    flow_sm_w = 200
    mask_reg_w = 0.0
    mask_sm_w = 0.0
    mask_cons_w = 0.0
    static_cons_w = 0.0

    # --- miscellaneous ---
    flow_diff_thresh = 1e-3
    flow_reduce_mode = "sum"

args = Args()

# Machine-specific locations, kept as named constants so they are easy to adjust.
CHECKPOINT_PATH = 'pretrained/bottleneck_latest.ckpt'
KITTI_2015_ROOT = '/external/datasets/kitti2015/'
# Key prefix the original code sliced off (7 chars); presumably the `module.` prefix
# added when the model was wrapped in nn.DataParallel / DistributedDataParallel for training.
DATA_PARALLEL_PREFIX = 'module.'

model = JointModel(args).cuda()
loss = Eval_SceneFlow_KITTI_Train(args).cuda()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {num_params} learnable parameters")

# Load to CPU first: load_state_dict copies the tensors into the model's (CUDA)
# parameters, so the checkpoint does not require the GPU it was saved from.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')['model']
# Strip the wrapper prefix only where it is present, instead of blindly chopping
# 7 characters off every key (which would corrupt un-prefixed keys).
new_state_dict = OrderedDict(
    (k[len(DATA_PARALLEL_PREFIX):] if k.startswith(DATA_PARALLEL_PREFIX) else k, v)
    for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)
model = model.eval()

augmentation = Augmentation_Resize_Only(args).cuda()

val_dataset = KITTI_2015_MonoSceneFlow(args, data_root=KITTI_2015_ROOT)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size, pin_memory=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The model has 22697562 learnable parameters


In [8]:
%autoreload
final_loss_dict = {}
for i, data in enumerate(tqdm(val_loader)):
    with torch.no_grad():
        input_keys = list(filter(lambda x: "input" in x, data.keys()))
        target_keys = list(filter(lambda x: "target" in x, data.keys()))
        tensor_keys = input_keys + target_keys
        
        for k, v in data.items():
            if k in tensor_keys:
                data[k] = v.cuda(non_blocking=True)
                
        aug_data = augmentation(data)
        out = model(aug_data)
        loss_dict = loss(out, data)
        for k, v in loss_dict.items():
            if k not in final_loss_dict:
                final_loss_dict[k] = v
            else:
                final_loss_dict[k] += v
        break

  0%|          | 0/200 [00:01<?, ?it/s]


In [22]:
# Average the accumulated metrics over the whole validation set. Each accumulated
# value is assumed to be a sum of per-batch means (TODO confirm in
# Eval_SceneFlow_KITTI_Train), so multiplying by the batch size turns it into a
# per-sample sum before dividing by the number of samples. Using len(val_dataset)
# instead of a hard-coded 200 keeps this correct if the dataset size changes.
num_val_samples = len(val_dataset)
update = {k: (v * args.batch_size) / num_val_samples for k, v in final_loss_dict.items()}
update

{'d_abs': tensor(0.1521, device='cuda:0'),
 'd_sq': tensor(1.5342, device='cuda:0'),
 'd1': tensor(0.4295, device='cuda:0'),
 'acc1': tensor(0.8048, device='cuda:0'),
 'acc2': tensor(0.9453, device='cuda:0'),
 'acc3': tensor(0.9809, device='cuda:0'),
 'f_epe': tensor(11.9924, device='cuda:0'),
 'f1': tensor(0.3625, device='cuda:0'),
 'd2': tensor(0.5560, device='cuda:0'),
 'sf': tensor(0.6954, device='cuda:0')}