## Reference: https://github.com/cogsys-tuebingen/mobilestereonet

## Load libraries

In [None]:
from mobilestereonet import *
import os
import random
import numpy as np
from PIL import Image
import copy

import torch
import torch.nn.parallel
import torch.optim as optim
import torch.nn as nn

from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
import re
import torchvision.transforms as transforms

import torch.utils.data
from torch.autograd import Variable, Function

from torch import Tensor

from skimage import io
import cv2

In [None]:
import warnings
# Install a blanket "ignore" filter: no category or module is given, so every
# warning raised after this cell runs is suppressed.
warnings.filterwarnings("ignore")

## Loss function

In [None]:
def model_loss(disp_ests, disp_gt, mask):
    """Return the weighted sum of smooth-L1 losses over the disparity estimates.

    Args:
        disp_ests: Sequence of estimated disparity tensors; each one is
            indexed with ``mask``.
        disp_gt: Ground-truth disparity tensor, indexed with ``mask``.
        mask: Boolean tensor selecting the pixels that contribute to the loss.

    Returns:
        The sum of the per-estimate mean smooth-L1 losses, scaled by the
        weights 0.5, 0.5, 0.7 and 1.0 in order.  ``zip`` stops at the shorter
        sequence, so estimates beyond the fourth are ignored.
    """
    weights = [0.5, 0.5, 0.7, 1.0]
    # reduction='mean' is the supported spelling of the deprecated
    # size_average=True and yields the same value.
    return sum(
        weight * F.smooth_l1_loss(disp_est[mask], disp_gt[mask], reduction='mean')
        for disp_est, weight in zip(disp_ests, weights)
    )

## sceneflow dataset

In [None]:
#download sample dataset

!wget https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlow/assets/Sampler.tar.gz

In [None]:
!tar zxvf /content/Sampler.tar.gz

In [None]:
def get_transform():
    """Build the image preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that first converts the image with
        ``ToTensor`` and then normalises it per channel with a fixed
        mean and standard deviation.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([transforms.ToTensor(), normalize])

In [None]:
def read_all_lines(filename):
    """Return every line of ``filename`` with trailing whitespace removed."""
    stripped = []
    with open(filename) as handle:
        for raw_line in handle:
            stripped.append(raw_line.rstrip())
    return stripped

In [None]:
def pfm_imread(filename):
    """Read a PFM image file.

    Args:
        filename: Path of the file to read.

    Returns:
        A ``(data, scale)`` tuple.  ``data`` is a float array of shape
        ``(height, width, 3)`` for a ``PF`` header or ``(height, width)`` for
        a ``Pf`` header, flipped vertically relative to the order stored in
        the file.  ``scale`` is the absolute value of the scale line.

    Raises:
        Exception: If the first line is neither ``PF`` nor ``Pf``, or the
            dimension line does not match ``<width> <height>``.
    """
    # The context manager guarantees the handle is closed, including when a
    # malformed header makes us raise part-way through.
    with open(filename, 'rb') as file:
        header = file.readline().decode('utf-8').rstrip()
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        # The sign of the scale line selects the byte order of the samples.
        scale = float(file.readline().rstrip())
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        # Everything after the three header lines is raw 32-bit float data.
        data = np.fromfile(file, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data, scale

In [None]:
class SceneFlowDataset(Dataset):
    """Stereo dataset yielding left/right images and a disparity map.

    Samples are described by a list file in which every line holds three
    whitespace-separated paths (left image, right image, disparity file),
    each joined onto ``datapath`` when loaded.
    """

    def __init__(self, datapath, list_filename, training):
        self.datapath = datapath
        self.training = training
        (self.left_filenames,
         self.right_filenames,
         self.disp_filenames) = self.load_path(list_filename)

    def load_path(self, list_filename):
        """Parse the list file into three parallel lists of relative paths."""
        rows = [line.split() for line in read_all_lines(list_filename)]
        left_images = [row[0] for row in rows]
        right_images = [row[1] for row in rows]
        disp_images = [row[2] for row in rows]
        return left_images, right_images, disp_images

    def load_image(self, filename):
        """Open ``filename`` and return it converted to RGB."""
        image = Image.open(filename)
        return image.convert('RGB')

    def load_disp(self, filename):
        """Read a PFM disparity file as a contiguous float32 array (the scale is dropped)."""
        data, _scale = pfm_imread(filename)
        return np.ascontiguousarray(data, dtype=np.float32)

    def __len__(self):
        return len(self.left_filenames)

    def __getitem__(self, index):
        """Load, crop and normalise sample ``index``.

        Training samples get a random 512x256 crop; evaluation samples get a
        fixed 960x512 crop anchored at the bottom-right corner, plus padding
        and filename entries.
        """
        left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index]))
        right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index]))
        disparity = self.load_disp(os.path.join(self.datapath, self.disp_filenames[index]))

        w, h = left_img.size
        if self.training:
            crop_w, crop_h = 512, 256
            # random crop
            x1 = random.randint(0, w - crop_w)
            y1 = random.randint(0, h - crop_h)
        else:
            crop_w, crop_h = 960, 512
            # fixed crop ending at the bottom-right corner
            x1, y1 = w - crop_w, h - crop_h
        x2, y2 = x1 + crop_w, y1 + crop_h

        crop_box = (x1, y1, x2, y2)
        left_img = left_img.crop(crop_box)
        right_img = right_img.crop(crop_box)
        disparity = disparity[y1:y2, x1:x2]

        # to tensor, normalize
        processed = get_transform()
        sample = {"left": processed(left_img),
                  "right": processed(right_img),
                  "disparity": disparity}
        if not self.training:
            sample["top_pad"] = 0
            sample["right_pad"] = 0
            sample["left_filename"] = self.left_filenames[index]
        return sample

## Hyperparameters

In [None]:
# Hyperparameters and paths shared by the cells below.

seed = 1  # passed to the torch seeding calls

# Dataset root and the list files naming the train / test samples.
datapath = "/content/"
trainlist = '/content/trainsamplelist.txt'
testlist = '/content/testsamplelist.txt'

# DataLoader batch sizes.
batch_size = 1
test_batch_size = 1

# Passed to MSNet2D; also the upper bound of the valid-disparity mask.
maxdisp = 192

lr = 0.001  # base learning rate

# Epoch range iterated by train().
start_epoch = 0
epochs = 2

# "<epoch>,<epoch>,...:<rate>" schedule string parsed by adjust_learning_rate.
lrepochs = "1,2:2"

In [None]:
#set seeds
# SceneFlowDataset picks its training crops with random.randint, so Python's
# own generator has to be seeded alongside torch's.
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
#os.makedirs(args.logdir, exist_ok=True)

In [None]:
def adjust_learning_rate(optimizer, epoch, base_lr, lrepochs):
    """Set the learning rate of every parameter group for ``epoch``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: Current epoch index.
        base_lr: Learning rate before any downscaling.
        lrepochs: Schedule string ``"e1,e2,...:rate"``.  The base rate is
            divided by ``rate`` once for each leading listed epoch that
            ``epoch`` has reached; scanning stops at the first one it has not.
    """
    fields = lrepochs.split(':')
    assert len(fields) == 2
    epoch_field, rate_field = fields

    # epochs at which to downscale (before ':') and the divisor (after ':')
    milestones = [int(token) for token in epoch_field.split(',')]
    divisor = float(rate_field)
    print("downscale epochs: {}, downscale rate: {}".format(milestones, divisor))

    new_lr = base_lr
    for milestone in milestones:
        if epoch < milestone:
            break
        new_lr /= divisor
    print("setting learning rate to {}".format(new_lr))

    for group in optimizer.param_groups:
        group['lr'] = new_lr

## create dataset

In [None]:
# Build the train / test datasets and wrap each one in a DataLoader.
train_dataset = SceneFlowDataset(datapath, trainlist, training=True)
test_dataset = SceneFlowDataset(datapath, testlist, training=False)
TrainImgLoader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)
TestImgLoader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

In [None]:
train_dataset[0]

{'disparity': array([[ 45.36875 ,  45.28208 ,  45.195408, ...,  16.93244 ,  16.952013,
          16.997654],
        [ 45.36269 ,  45.27602 ,  45.189346, ...,  16.911587,  16.944834,
          16.991657],
        [ 45.35663 ,  45.269962,  45.18329 , ...,  16.90439 ,  16.941797,
          16.970854],
        ...,
        [214.51234 , 214.5125  , 214.51254 , ..., 214.51263 , 214.51254 ,
         214.5125  ],
        [215.51253 , 215.51266 , 215.51257 , ..., 215.51247 , 215.51253 ,
         215.51253 ],
        [216.51242 , 216.51237 , 216.51218 , ..., 216.51222 , 216.51218 ,
         216.51228 ]], dtype=float32),
 'left': tensor([[[-1.6384, -1.6555, -1.6727,  ...,  2.2318,  2.2318,  2.2318],
          [-1.6727, -1.7069, -1.7069,  ...,  2.2318,  2.2318,  2.2318],
          [-1.6384, -1.6727, -1.6555,  ...,  2.2318,  2.2318,  2.2318],
          ...,
          [ 0.4851,  0.5193,  0.5364,  ...,  0.2282,  0.2111,  0.2282],
          [ 0.4679,  0.5364,  0.6049,  ...,  0.1597,  0.2453,  0.3138]

## create model

In [None]:
# Build MSNet2D inside an nn.DataParallel wrapper, move it to the GPU and
# create the Adam optimizer over its parameters.
model = nn.DataParallel(MSNet2D(maxdisp))
model.cuda()
adam_betas = (0.9, 0.999)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=adam_betas)

## metric

In [None]:
def make_nograd_func(func):
    """Decorator that runs ``func`` inside ``torch.no_grad()``.

    Args:
        func: Callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` with gradient tracking disabled and returns its result.
        The wrapper keeps ``func``'s name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            return func(*f_args, **f_kwargs)

    return wrapper

In [None]:
def check_shape_for_metric_computation(*vars):
    """Assert that every argument is 3-dimensional and has the first argument's size."""
    assert isinstance(vars, tuple)
    for tensor in vars:
        shape = tensor.size()
        assert len(shape) == 3
        assert shape == vars[0].size()


# a wrapper to compute metrics for each image individually
def compute_metric_for_each_image(metric_func):
    """Decorator that applies ``metric_func`` per image and averages the results.

    The wrapper takes batched ``D_ests``, ``D_gts`` and ``masks`` (3-D and
    identically sized, as checked by ``check_shape_for_metric_computation``)
    plus optional extra arguments.  Extra arguments that are tensors are
    indexed per image; anything else is passed through unchanged.

    An image is skipped when its mask selects no pixel, or when the masked
    fraction is less than 10% of the fraction of pixels with positive ground
    truth.  If every image is skipped the wrapper returns a float32 zero
    tensor on ``D_gts``' device; otherwise it returns the mean of the
    per-image metric values.
    """
    def wrapper(D_ests, D_gts, masks, *nargs):
        check_shape_for_metric_computation(D_ests, D_gts, masks)
        bn = D_gts.shape[0]  # batch size
        results = []  # a list to store results for each image
        # compute result one by one
        for idx in range(bn):
            # if tensor, then pick idx, else pass the same value
            cur_nargs = [x[idx] if isinstance(x, (Tensor, Variable)) else x for x in nargs]
            valid_fraction = masks[idx].float().mean()
            gt_fraction = (D_gts[idx] > 0).float().mean()
            # With an empty mask and no positive ground truth the ratio is
            # 0/0 = NaN, which compares False against 0.1; without the
            # explicit zero check such an image would reach metric_func, and
            # the metrics in this file average over the masked pixels.
            if valid_fraction == 0 or valid_fraction / gt_fraction < 0.1:
                print("masks[idx].float().mean() too small, skip")
            else:
                ret = metric_func(D_ests[idx], D_gts[idx], masks[idx], *cur_nargs)
                results.append(ret)
        if len(results) == 0:
            print("masks[idx].float().mean() too small for all images in this batch, return 0")
            return torch.tensor(0, dtype=torch.float32, device=D_gts.device)
        else:
            return torch.stack(results).mean()

    return wrapper


@make_nograd_func
@compute_metric_for_each_image
def D1_metric(D_est, D_gt, mask):
    """Fraction of masked pixels whose error is both > 3 and > 5% of the ground truth magnitude."""
    gt_valid = D_gt[mask]
    abs_err = torch.abs(gt_valid - D_est[mask])
    is_outlier = (abs_err > 3) & (abs_err / gt_valid.abs() > 0.05)
    return is_outlier.float().mean()


@make_nograd_func
@compute_metric_for_each_image
def Thres_metric(D_est, D_gt, mask, thres):
    """Fraction of masked pixels whose absolute error exceeds ``thres`` (an int or float)."""
    assert isinstance(thres, (int, float))
    abs_err = torch.abs(D_gt[mask] - D_est[mask])
    exceeds = abs_err > thres
    return exceeds.float().mean()


@make_nograd_func
@compute_metric_for_each_image
def EPE_metric(D_est, D_gt, mask):
    """Mean absolute difference between estimate and ground truth over the masked pixels."""
    D_est, D_gt = D_est[mask], D_gt[mask]
    # reduction='mean' is the supported spelling of the deprecated
    # size_average=True and yields the same value.
    return F.l1_loss(D_est, D_gt, reduction='mean')

## Training settings

In [None]:
def make_iterative_func(func):
    """Decorator that applies ``func`` to every leaf of a nested container.

    Lists, tuples and dicts are traversed recursively and rebuilt as a list,
    a plain tuple, or a dict with the same keys; any other value is passed
    to ``func`` directly.  The wrapper keeps ``func``'s name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        if isinstance(vars, tuple):
            return tuple(wrapper(x) for x in vars)
        if isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        return func(vars)

    return wrapper

In [None]:
@make_iterative_func
def tensor2float(vars):
    """Convert a tensor leaf to a Python number; floats pass through unchanged.

    Raises:
        NotImplementedError: For any leaf that is neither a float nor a tensor.
    """
    if isinstance(vars, torch.Tensor):
        return vars.item()
    if isinstance(vars, float):
        return vars
    raise NotImplementedError("invalid input type for tensor2float")

In [None]:
@make_iterative_func
def check_allfloat(vars):
    """Assert that ``vars`` is a Python ``float``.

    The ``make_iterative_func`` decorator applies the check to every leaf of
    nested lists, tuples and dicts.
    """
    assert isinstance(vars, float)

In [None]:
class AverageMeterDict(object):
    """Accumulate dicts of floats (or lists/tuples of floats) and average them."""

    def __init__(self):
        self.data = None  # running sums, shaped like the first dict passed to update()
        self.count = 0    # number of update() calls so far

    def update(self, x):
        """Add the values of ``x`` to the running sums.

        Args:
            x: Dict whose values are floats or lists/tuples of floats.  Every
                leaf must be a float (asserted by ``check_allfloat``).

        Raises:
            NotImplementedError: If a value is neither a float nor a
                list/tuple, e.g. a nested dict.
        """
        check_allfloat(x)
        self.count += 1
        if self.data is None:
            self.data = copy.deepcopy(x)
            return
        for k1, v1 in x.items():
            if isinstance(v1, float):
                self.data[k1] += v1
            elif isinstance(v1, (tuple, list)):
                # Work on a list copy so tuple entries, which cannot be
                # modified in place, can be accumulated as well.
                sums = list(self.data[k1])
                for idx, v2 in enumerate(v1):
                    sums[idx] += v2
                self.data[k1] = tuple(sums) if isinstance(self.data[k1], tuple) else sums
            else:
                # The original used `assert` on the exception instance, which
                # is always truthy, so unsupported values were silently
                # left out of the sums.
                raise NotImplementedError("error input type for update AvgMeterDict")

    def mean(self):
        """Return the accumulated sums divided by the number of updates."""
        @make_iterative_func
        def get_mean(v):
            return v / float(self.count)

        return get_mean(self.data)

## train function

In [None]:
def train():
    """Run the train / evaluate loop over epochs ``start_epoch`` .. ``epochs - 1``.

    Each epoch first updates the learning rate, then runs ``train_sample`` on
    every training batch and ``test_sample`` on every test batch, printing the
    loss per batch.  Test scalars are accumulated in an ``AverageMeterDict``
    and averaged.  Logging and checkpointing are kept only as commented-out
    scaffolding.
    """
    best_checkpoint_loss = 100
    for epoch_idx in range(start_epoch, epochs):
        adjust_learning_rate(optimizer, epoch_idx, lr, lrepochs)

        # training
        num_train_batches = len(TrainImgLoader)
        for batch_idx, sample in enumerate(TrainImgLoader):
            global_step = num_train_batches * epoch_idx + batch_idx
            #start_time = time.time()
            #do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs, image_outputs = train_sample(sample, compute_metrics=True)
            #if do_summary:
                #save_scalars(logger, 'train', scalar_outputs, global_step)
                #save_images(logger, 'train', image_outputs, global_step)
            del scalar_outputs, image_outputs
            train_message = 'Epoch {}/{}, Iter {}/{}, train loss = {:.3f}'.format(
                epoch_idx, epochs, batch_idx, num_train_batches, loss)
            print(train_message)
        # saving checkpoints
        #if (epoch_idx + 1) % args.save_freq == 0:
            #checkpoint_data = {'epoch': epoch_idx, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            #torch.save(checkpoint_data, "{}/checkpoint_{:0>6}.ckpt".format(args.logdir, epoch_idx))
        #gc.collect()

        # testing
        test_meter = AverageMeterDict()
        num_test_batches = len(TestImgLoader)
        for batch_idx, sample in enumerate(TestImgLoader):
            global_step = num_test_batches * epoch_idx + batch_idx
            #start_time = time.time()
            #do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs, image_outputs = test_sample(sample, compute_metrics=True)
            #if do_summary:
                #save_scalars(logger, 'test', scalar_outputs, global_step)
                #save_images(logger, 'test', image_outputs, global_step)
            test_meter.update(scalar_outputs)
            del scalar_outputs, image_outputs
            test_message = 'Epoch {}/{}, Iter {}/{}, test loss = {:.3f}'.format(
                epoch_idx, epochs, batch_idx, num_test_batches, loss)
            print(test_message)
        avg_test_scalars = test_meter.mean()

        #save_scalars(logger, 'fulltest', avg_test_scalars, len(TrainImgLoader) * (epoch_idx + 1))
        #print("avg_test_scalars", avg_test_scalars)

        # saving new best checkpoint
        #if avg_test_scalars['loss'] < best_checkpoint_loss:
            #best_checkpoint_loss = avg_test_scalars['loss']
            #print("Overwriting best checkpoint")
            #checkpoint_data = {'epoch': epoch_idx, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            #torch.save(checkpoint_data, "{}/best.ckpt".format(args.logdir))

        #gc.collect()

In [None]:
# train one sample
def train_sample(sample, compute_metrics=False):
    """Run one optimisation step on ``sample``.

    Args:
        sample: Batch dict with ``'left'``, ``'right'`` and ``'disparity'``
            tensors.
        compute_metrics: When true, also compute EPE, D1 and the 1/2/3
            threshold metrics for every disparity estimate.

    Returns:
        ``(loss, scalar_outputs, image_outputs)`` where the loss and the
        scalar outputs have been converted with ``tensor2float``.
    """
    model.train()

    imgL = sample['left'].cuda()
    imgR = sample['right'].cuda()
    disp_gt = sample['disparity'].cuda()

    optimizer.zero_grad()

    disp_ests = model(imgL, imgR)
    # only pixels with ground truth strictly between 0 and maxdisp count
    mask = (disp_gt < maxdisp) & (disp_gt > 0)
    loss = model_loss(disp_ests, disp_gt, mask)

    scalar_outputs = {"loss": loss}
    image_outputs = {"disp_est": disp_ests, "disp_gt": disp_gt, "imgL": imgL, "imgR": imgR}
    if compute_metrics:
        with torch.no_grad():
            #image_outputs["errormap"] = [disp_error_image_func.apply(disp_est, disp_gt) for disp_est in disp_ests]
            scalar_outputs["EPE"] = [EPE_metric(est, disp_gt, mask) for est in disp_ests]
            scalar_outputs["D1"] = [D1_metric(est, disp_gt, mask) for est in disp_ests]
            for key, thres in (("Thres1", 1.0), ("Thres2", 2.0), ("Thres3", 3.0)):
                scalar_outputs[key] = [Thres_metric(est, disp_gt, mask, thres) for est in disp_ests]
    loss.backward()
    optimizer.step()

    return tensor2float(loss), tensor2float(scalar_outputs), image_outputs

In [None]:
# test one sample
@make_nograd_func
def test_sample(sample, compute_metrics=True):
    """Evaluate the model on ``sample`` without tracking gradients.

    Args:
        sample: Batch dict with ``'left'``, ``'right'`` and ``'disparity'``
            tensors.
        compute_metrics: Accepted for symmetry with ``train_sample``; the
            scalar metrics are computed regardless of its value.

    Returns:
        ``(loss, scalar_outputs, image_outputs)`` where the loss and the
        scalar outputs have been converted with ``tensor2float``.
    """
    model.eval()

    imgL = sample['left'].cuda()
    imgR = sample['right'].cuda()
    disp_gt = sample['disparity'].cuda()

    disp_ests = model(imgL, imgR)
    # only pixels with ground truth strictly between 0 and maxdisp count
    mask = (disp_gt < maxdisp) & (disp_gt > 0)
    loss = model_loss(disp_ests, disp_gt, mask)

    scalar_outputs = {"loss": loss}
    image_outputs = {"disp_est": disp_ests, "disp_gt": disp_gt, "imgL": imgL, "imgR": imgR}

    scalar_outputs["D1"] = [D1_metric(est, disp_gt, mask) for est in disp_ests]
    scalar_outputs["EPE"] = [EPE_metric(est, disp_gt, mask) for est in disp_ests]
    for key, thres in (("Thres1", 1.0), ("Thres2", 2.0), ("Thres3", 3.0)):
        scalar_outputs[key] = [Thres_metric(est, disp_gt, mask, thres) for est in disp_ests]

    #if compute_metrics:
        #image_outputs["errormap"] = [disp_error_image_func.apply(disp_est, disp_gt) for disp_est in disp_ests]

    return tensor2float(loss), tensor2float(scalar_outputs), image_outputs

In [None]:
train()

downscale epochs: [1, 2], downscale rate: 2.0
setting learning rate to 0.001
Epoch 0/2, Iter 0/9, train loss = 40.860
Epoch 0/2, Iter 1/9, train loss = 30.813
Epoch 0/2, Iter 2/9, train loss = 90.524
Epoch 0/2, Iter 3/9, train loss = 64.898
Epoch 0/2, Iter 4/9, train loss = 133.306
Epoch 0/2, Iter 5/9, train loss = 19.360
Epoch 0/2, Iter 6/9, train loss = 51.668
Epoch 0/2, Iter 7/9, train loss = 158.464
Epoch 0/2, Iter 8/9, train loss = 85.385
Epoch 0/2, Iter 0/9, test loss = 22.941
Epoch 0/2, Iter 1/9, test loss = 22.841
Epoch 0/2, Iter 2/9, test loss = 22.840
Epoch 0/2, Iter 3/9, test loss = 14.385
Epoch 0/2, Iter 4/9, test loss = 13.056
Epoch 0/2, Iter 5/9, test loss = 13.674
Epoch 0/2, Iter 6/9, test loss = 9.964
Epoch 0/2, Iter 7/9, test loss = 12.034
Epoch 0/2, Iter 8/9, test loss = 13.644
downscale epochs: [1, 2], downscale rate: 2.0
setting learning rate to 0.0005
Epoch 1/2, Iter 0/9, train loss = 55.692
Epoch 1/2, Iter 1/9, train loss = 97.590
Epoch 1/2, Iter 2/9, train loss =

## prediction

In [None]:
@make_iterative_func
def tensor2numpy(vars):
    """Convert a tensor leaf to a NumPy array on the CPU; arrays pass through unchanged.

    Raises:
        NotImplementedError: For any leaf that is neither an ndarray nor a tensor.
    """
    if isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy()
    if isinstance(vars, np.ndarray):
        return vars
    raise NotImplementedError("invalid input type for tensor2numpy")

In [None]:
def kitti_colormap(disparity, maxval=-1):
    """
    A utility function to reproduce KITTI fake colormap
    Arguments:
      - disparity: numpy float32 array of dimension HxW
      - maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used)

    Returns a numpy uint8 array of shape HxWx3.
    """
    if maxval < 0:
        maxval = np.max(disparity)

    colormap = np.asarray([
        [0, 0, 0, 114], [0, 0, 1, 185], [1, 0, 0, 114], [1, 0, 1, 174],
        [0, 1, 0, 114], [0, 1, 1, 185], [1, 1, 0, 114], [1, 1, 1, 0],
    ])
    weights = np.asarray([
        8.771929824561404, 5.405405405405405, 8.771929824561404, 5.747126436781609,
        8.771929824561404, 5.405405405405405, 8.771929824561404, 0,
    ])
    cumsum = np.asarray([0, 0.114, 0.299, 0.413, 0.587, 0.701, 0.8859999999999999, 0.9999999999999999])

    # Normalised disparity clipped to [0, 1], with a trailing axis so it
    # broadcasts against the 8 bin edges.
    values = np.expand_dims(np.minimum(np.maximum(disparity / maxval, 0.), 1.), -1)
    offsets = values - cumsum
    # Edges below the value are pushed to -1000 so argmax lands on the first
    # edge that is not below it; the bin index is one less.
    diffs = np.where(offsets > 0, -1000, offsets)
    index = np.argmax(diffs, axis=-1) - 1

    # Blend weight of the lower colour within the selected bin.
    w = 1 - (values[:, :, 0] - cumsum[index]) * weights[index]

    lower = colormap[index]
    upper = colormap[index + 1]
    colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3])
    # Colormap columns 0, 1, 2 are written to output channels 2, 1, 0.
    for out_channel, src_channel in ((2, 0), (1, 1), (0, 2)):
        colored_disp[:, :, out_channel] = (w * lower[:, :, src_channel]
                                           + (1. - w) * upper[:, :, src_channel])

    # Pixels with non-positive disparity are blanked.
    return (colored_disp * np.expand_dims((disparity > 0), -1) * 255).astype(np.uint8)

In [None]:
def test(colored=True):
    """Predict a disparity map for every test sample and save it as an image.

    Files are written to ``./predictions_scene``; each name is built from the
    components of the sample's left filename after the first two, joined
    with underscores.

    Args:
        colored: When true, save the KITTI-style colour rendering with
            ``cv2.imwrite``; otherwise save the disparity multiplied by 256
            as a 16-bit image with ``io.imsave``.
    """
    print("Generating the disparity maps...")

    os.makedirs('./predictions_scene', exist_ok=True)

    for sample in TestImgLoader:
        disp_est_np = tensor2numpy(test_sample(sample))
        top_pad_np = tensor2numpy(sample["top_pad"])
        right_pad_np = tensor2numpy(sample["right_pad"])
        left_filenames = sample["left_filename"]

        for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames):

            assert len(disp_est.shape) == 2

            #disp_est = np.array(disp_est[top_pad:, :-right_pad], dtype=np.float32)

            # Comparing against np.array([]) is not a valid emptiness test for
            # a 2-D array (the shapes do not broadcast); check the element
            # count instead.
            if disp_est.size == 0:
                continue

            name = fn.split('/')
            fn = os.path.join("predictions_scene", '_'.join(name[2:]))

            if colored:
                disp_est = kitti_colormap(disp_est)
                cv2.imwrite(fn, disp_est)
            else:
                disp_est = np.round(disp_est * 256).astype(np.uint16)
                io.imsave(fn, disp_est)

    print("Done.")

In [None]:
@make_nograd_func
def test_sample(sample):
    """Return the model's last disparity estimate for ``sample`` (eval mode, no gradients)."""
    model.eval()
    left = sample['left'].cuda()
    right = sample['right'].cuda()
    predictions = model(left, right)
    return predictions[-1]

In [None]:
test()

Generating the disparity maps...
Done.
