In [None]:
import sys
sys.path.insert(0, '/home/xp/stereo_toolbox/')

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
torch.backends.cudnn.benchmark = True
import matplotlib.pyplot as plt
import argparse
import cv2

# auto reload modules
%load_ext autoreload
%autoreload 2

from stereo_toolbox.datasets import *
from stereo_toolbox.models import *
from stereo_toolbox.visualization import *

In [None]:
def show_left_gt_pred_error(left, gt, pred):
    """Show left image, GT disparity, predicted disparity and error map in one row.

    Args:
        left: left image tensor, transposed (1, 2, 0) for display, i.e. treated
            as channels-first (C, H, W). Dtype/value range is whatever the
            dataset's 'raw_left' provides -- TODO confirm it is imshow-ready
            (uint8 or float in [0, 1]).
        gt: ground-truth disparity tensor (presumably (H, W); TODO confirm).
        pred: predicted disparity tensor, same layout as ``gt`` (assumed).

    GT and prediction are colored against the same upper bound (``gt.max()``)
    so the two panels share one color scale.
    """
    # detach() lets tensors that still carry autograd history be converted;
    # cpu() returns the tensor itself when it already lives on the host.
    left = left.detach().cpu().numpy()
    gt = gt.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()

    # One explicit upper bound for both disparity panels, so colors are comparable.
    max_disp = gt.max()
    colored_gt = colored_disparity_map_KITTI(gt, maxval=max_disp)
    colored_pred = colored_disparity_map_KITTI(pred, maxval=max_disp)
    colored_error = colored_error_map_KITTI(pred, gt)

    # The colorizer outputs are treated as BGR (OpenCV channel order) and
    # converted to RGB, which is what matplotlib's imshow expects.
    panels = [
        ('Left Image', left.transpose(1, 2, 0)),
        ('Ground Truth', cv2.cvtColor(colored_gt, cv2.COLOR_BGR2RGB)),
        ('Predicted', cv2.cvtColor(colored_pred, cv2.COLOR_BGR2RGB)),
        ('Error', cv2.cvtColor(colored_error, cv2.COLOR_BGR2RGB)),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(24, 12))
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image)
        ax.axis('off')

    plt.show()
    # This helper is called once per sample in a loop; close the figure so
    # non-inline backends do not keep every figure alive.
    plt.close(fig)

In [None]:
# Visualize IGEVStereo predictions on KITTI 2015 (split 'train_all') next to
# the ground truth and a KITTI-style colored error map.
from itertools import islice

# NOTE(review): absolute local path -- consider making this configurable.
CHECKPOINT_PATH = '/home/xp/stereo_toolbox/stereo_toolbox/models/IGEVStereo/sceneflow.pth'
NUM_VIS_SAMPLES = 6  # number of dataset samples to run and plot

model = load_checkpoint_flexible(IGEVStereo(), CHECKPOINT_PATH)
model = model.cuda().eval()

dataset = KITTI2015_Dataset(split='train_all', training=False)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# islice stops after NUM_VIS_SAMPLES batches without fetching another one.
for data in islice(dataloader, NUM_VIS_SAMPLES):
    # Only the network inputs need to be on the GPU.
    left = data['left'].cuda()
    right = data['right'].cuda()

    with torch.no_grad():
        pred = model(left, right)

    # raw_left and gt_disp are only plotted, so they stay on the host instead
    # of round-tripping through the GPU.
    show_left_gt_pred_error(
        data['raw_left'].squeeze(),
        data['gt_disp'].squeeze(),
        pred.squeeze(),
    )