In [35]:
import sys
sys.path.append('/home/xp/stereo_toolbox/')

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
torch.backends.cudnn.benchmark = True
import matplotlib.pyplot as plt
import cv2

# auto reload modules
%load_ext autoreload
%autoreload 2

from datasets import *
from visualization import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
def _to_numpy(tensor):
    """Drop singleton dims and copy a torch tensor to a host numpy array."""
    # detach() so tensors that require grad can still be converted
    return tensor.detach().squeeze().cpu().numpy()


def _minmax_normalize(img):
    """Linearly rescale ``img`` to [0, 1] for display.

    A constant image (max == min) would divide by zero and yield NaNs,
    so it is mapped to all zeros instead.
    """
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros(img.shape, dtype=np.float32)
    return (img - lo) / (hi - lo)


def show_figures(left, right, colored_disp, noc_mask, raw_left, raw_right):
    """Plot one stereo sample as a 2x3 grid of panels.

    Top row: min-max normalized left/right images and the colored disparity.
    Bottom row: raw left/right images and the NOC mask.

    Args:
        left, right: image tensors. After ``squeeze()`` they are transposed
            with ``(1, 2, 0)``, so they are assumed to be channel-first
            (C, H, W) — TODO confirm against the dataset.
        colored_disp: array passed to ``imshow`` unchanged (the caller is
            expected to have converted it to RGB already).
        noc_mask: mask tensor; scaled by 255 and cast to uint8, so it
            presumably holds values in [0, 1] or booleans — verify.
        raw_left, raw_right: unnormalized image tensors, also transposed
            with ``(1, 2, 0)`` (assumed channel-first).
    """
    left = _minmax_normalize(_to_numpy(left))
    right = _minmax_normalize(_to_numpy(right))
    raw_left, raw_right = _to_numpy(raw_left), _to_numpy(raw_right)
    noc_mask = _to_numpy(noc_mask)

    fig, axes = plt.subplots(2, 3, figsize=(24, 8))
    # (title, image, extra imshow kwargs), in row-major subplot order
    panels = [
        ('Left Image', left.transpose(1, 2, 0), {}),
        ('Right Image', right.transpose(1, 2, 0), {}),
        ('Colored Disparity', colored_disp, {}),
        ('Raw Left Image', raw_left.transpose(1, 2, 0), {}),
        ('Raw Right Image', raw_right.transpose(1, 2, 0), {}),
        ('NOC Mask', (noc_mask * 255.0).astype(np.uint8),
         {'vmin': 0, 'vmax': 255, 'cmap': 'gray'}),
    ]
    for ax, (title, image, imshow_kwargs) in zip(axes.flat, panels):
        ax.set_title(title)
        ax.imshow(image, **imshow_kwargs)
        ax.axis('off')

    plt.show()


In [None]:
## SceneFlow
# Preview the first batch of every SceneFlow split, once with training=True
# and once with training=False, to sanity-check loading.
# Requires `cv2` from the imports cell.

splits = ['train_cleanpass', 'train_finalpass', 'test_cleanpass', 'test_finalpass']
BATCH_SIZE = 3
NUM_WORKERS = 4

for split in splits:
    for training in (True, False):
        dataset = SceneFlow_Dataset(split=split, training=training, raw_data=True)
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

        # Only the first batch is inspected; an empty split yields None instead
        # of silently producing no output.
        batch = next(iter(dataloader), None)
        if batch is None:
            print('split: ', split, ' training: ', training, ' samples: ', len(dataset), ' (empty, skipped)')
            continue

        left, right, disp, noc_mask, raw_left, raw_right = batch
        print('split: ', split, ' training: ', training, ' samples: ', len(dataset), ' left shape: ', left.shape)

        colored_disp = colored_disparity_map_Spectral_r(disp[0])
        # the colored map is treated as BGR and converted to RGB for matplotlib
        colored_disp = cv2.cvtColor(colored_disp, cv2.COLOR_BGR2RGB)

        show_figures(left[0], right[0], colored_disp, noc_mask[0], raw_left[0], raw_right[0])