# Seminar 11 - Stereo Depth

There are several variants of the depth estimation task, depending on the available input data:
- Rectified stereo image pair + ground truth
- Non-rectified stereo image pair + ground truth
- Stereo image pair **without** ground truth
- ...

This seminar covers the supervised setting; in the homework you will move to the unsupervised one.

This task is based on two papers:

1) Mayer et al. "A Large Dataset to Train Convolutional Networks for Disparity, Optical Flow, and Scene Flow Estimation", CVPR 2016, ([pdf](https://openaccess.thecvf.com/content_cvpr_2016/papers/Mayer_A_Large_Dataset_CVPR_2016_paper.pdf), [poster](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/poster-MIFDB16.pdf), [supplementary materials](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/supplementary-MIFDB16.pdf), [project page](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/))

2) Fischer et al. "FlowNet: Learning Optical Flow with Convolutional Networks", ICCV 2015, [pdf](https://arxiv.org/pdf/1504.06852.pdf)


The first paper shows that a depth estimation network trained on purely synthetic images can achieve strong performance.

<img src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/Depth-estimation-DispNet.png?resize=768%2C599&ssl=1" style="width:60%">

In [None]:
#!L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
import time
%matplotlib inline

In [None]:
#!L
def get_computing_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    return device

device = get_computing_device()
print(f"Our main computing device is '{device}'")

## 0. Warm up

1) What is a rectified stereo pair? What is rectification needed for?

2) What is disparity? 

<details><summary>hint</summary>
<img src="https://cdn-images-1.medium.com/max/1200/1*8RmW8h5XxSADXpJT_rXJAA.png" style="width:60%">
<a href="https://www.gushiciku.cn/pl/2ThP">img src</a>
</details>

3) How does disparity help with depth computation?

<details><summary>hint</summary>
<img src="https://i.stack.imgur.com/7RtcV.png" style="width:60%">
<a href="https://stackoverflow.com/questions/56427239/convert-kinect-depth-intensity-to-distance-in-meter">img src</a>
</details>

Disparity is discrete and finite. How can we increase the maximal detectable depth?

- increase the baseline (you will probably lose objects that are too close, since they will be visible in only one camera)
- increase the focal length

Note: we can also get the X and Y coordinates with similar formulas and reconstruct the entire point cloud.
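
The relation behind the questions above is $Z = f \cdot B / d$, where $f$ is the focal length in pixels, $B$ the baseline and $d$ the disparity. The sketch below illustrates it with NumPy. The function names, the principal point `(cx, cy)` and all numeric values are assumptions chosen for illustration; they are not read from any calibration file.

```python
import numpy as np

def disparity_to_depth(disparity, focal_px, baseline_m):
    """Z = f * B / d for positive disparities; 0 where disparity is unknown."""
    disparity = np.asarray(disparity, dtype=np.float32)
    depth = np.zeros_like(disparity)
    valid = disparity > 0
    depth[valid] = focal_px * baseline_m / disparity[valid]
    return depth

def backproject(disparity, focal_px, baseline_m, cx, cy):
    """Return an (N, 3) array of X, Y, Z for pixels with valid disparity."""
    depth = disparity_to_depth(disparity, focal_px, baseline_m)
    v, u = np.nonzero(depth > 0)      # row (y) and column (x) indices
    z = depth[v, u]
    x = (u - cx) * z / focal_px       # pinhole model, same focal length for x and y
    y = (v - cy) * z / focal_px
    return np.stack([x, y, z], axis=1)

# Illustrative values only (assumed).
focal_px, baseline_m = 700.0, 0.5
disp = np.array([[0.0, 1.0], [35.0, 70.0]], dtype=np.float32)

depth = disparity_to_depth(disp, focal_px, baseline_m)
points = backproject(disp, focal_px, baseline_m, cx=0.5, cy=0.5)
print(depth)         # depth in metres, 0 for the unknown pixel
print(points.shape)  # one 3D point per valid pixel
```

With these numbers the smallest non-zero integer disparity, 1 px, corresponds to 350 m, which is exactly $f \cdot B$. Doubling either the baseline or the focal length doubles that limit.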

## 1. KITTI Stereo Depth 2012

http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo

The dataset contains 194 training image pairs and 195 test image pairs with hidden ground truth. The ground-truth depth was captured by lidar.

Zero disparity value means that disparity is unknown.

In [None]:
import yfile
import os
if not os.path.exists('./kitti_stereo_2012_training_data.zip'):
    yfile.download_from_yadisk("https://disk.yandex.ru/d/WrA28IyHHYENOw", 
                               "kitti_stereo_2012_training_data.zip",
                               target_dir='.')

# alternative, but looks like this way of downloading from g.drive is not working anymore for large files
#import gfile
#gfile.download_list(
#    'https://drive.google.com/file/d/12zitJCsOVmoCHII5Ym_t2AAORXb6WMyU',
#    filename='kitti_stereo_2012_training_data.zip',
#    target_dir='.')

In [None]:
!unzip -q ./kitti_stereo_2012_training_data.zip 

In [None]:
def normalize_disparity(img):
    img = img.astype(np.float32) / 256
    return img

In [None]:
img_name = "000002_10.png"

plt.figure(figsize=(15,5))

plt.subplot(2,2,1)
plt.title('left image')
plt.imshow(Image.open(f'./kitti_stereo_2012_training_data/train/colored_0/{img_name}')); 
plt.xticks([])
plt.yticks([])

plt.subplot(2,2,2)
plt.title('right image')
plt.imshow(Image.open(f'./kitti_stereo_2012_training_data/train/colored_1/{img_name}')); 
plt.xticks([])
plt.yticks([])

plt.subplot(2,2,3)
plt.title('disparity')
disp = np.array(Image.open(f'./kitti_stereo_2012_training_data/train/disp_noc/{img_name}'))
plt.imshow(normalize_disparity(disp), 'gray')
plt.xticks([])
plt.yticks([])
           
plt.subplot(2,2,4)
plt.title('valid disparity mask')
plt.imshow(disp > 0, 'gray')
plt.xticks([])
plt.yticks([])

In [None]:
sample_max_disparity = normalize_disparity(disp).max()
sample_shape = disp.shape

print(f'max disp = {sample_max_disparity} , disp shape {sample_shape}')

## 2. Dataset loading

In [None]:
from kitti_dataset import KITTIStereoRAM

In [None]:
# To get info about dataset implementation
# KITTIStereoRAM??

In [None]:
means = np.array([0.35715697, 0.37349922, 0.35886646] , dtype=np.float32)
stds = np.array([0.27408948, 0.2807328,  0.27994434], dtype=np.float32)

transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.Normalize(means, stds),
])
# Q: can we use RandomFlip?
# Q: can we use RandomRotation?
# Q: can we use RandomCrop?

# min kitti shape is [370, 1226], max shape is [376, 1242]
PAD_HEIGHT = 128*3 
PAD_WIDTH = 1280
CROP_WIDTH = 768
def transforms_train(left_image, right_image, disparity, valid_pixels_mask):
    disparity = torchvision.transforms.functional.to_tensor(disparity)
    valid_pixels_mask = torchvision.transforms.functional.to_tensor(valid_pixels_mask)
    left_image = transform_train(left_image)
    right_image = transform_train(right_image)
    left_image = pad_to_size(left_image, PAD_HEIGHT, PAD_WIDTH)
    right_image = pad_to_size(right_image, PAD_HEIGHT, PAD_WIDTH)
    disparity = pad_to_size(disparity, PAD_HEIGHT, PAD_WIDTH)
    valid_pixels_mask = pad_to_size(valid_pixels_mask, PAD_HEIGHT, PAD_WIDTH)

    shift = torch.randint(0, PAD_WIDTH-CROP_WIDTH, (1,))
    left_image = left_image[:,:,shift:shift+CROP_WIDTH]
    right_image = right_image[:, :, shift: shift+CROP_WIDTH]
    disparity = disparity[:, :, shift: shift+ CROP_WIDTH]
    valid_pixels_mask = valid_pixels_mask[:, :, shift: shift+CROP_WIDTH]
    return left_image, right_image, disparity, valid_pixels_mask


def pad_to_size(images, min_height, min_width):
    if images.shape[1] < min_height:
        images = torchvision.transforms.functional.pad(images, (0,0,0,min_height-images.shape[1]))
    if images.shape[2] < min_width:
        images = torchvision.transforms.functional.pad(images, (0,0, min_width - images.shape[2], 0))
    return images
        
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

def transforms_test(left_image, right_image, disparity, valid_pixels_mask):
    disparity = torchvision.transforms.functional.to_tensor(disparity)
    valid_pixels_mask = torchvision.transforms.functional.to_tensor(valid_pixels_mask)
    left_image = transform_test(left_image)
    right_image = transform_test(right_image)
    left_image = pad_to_size(left_image, PAD_HEIGHT, PAD_WIDTH)
    right_image = pad_to_size(right_image, PAD_HEIGHT, PAD_WIDTH)
    disparity = pad_to_size(disparity, PAD_HEIGHT, PAD_WIDTH)
    valid_pixels_mask = pad_to_size(valid_pixels_mask, PAD_HEIGHT, PAD_WIDTH)
    
    return left_image, right_image, disparity, valid_pixels_mask

In [None]:
train_loader = KITTIStereoRAM(root="./kitti_stereo_2012_training_data/", train=True, transforms=transforms_train)

train_batch_gen = torch.utils.data.DataLoader(train_loader, 
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=16)
val_loader = KITTIStereoRAM(root="./kitti_stereo_2012_training_data/", train=False, transforms=transforms_test)

val_batch_gen = torch.utils.data.DataLoader(val_loader, 
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=16)


In [None]:
for elem in train_batch_gen:
    left, right, target, mask = elem

    print("left.shape", left.shape)
    print("right.shape", right.shape)
    print("target.shape", target.shape)
    print("mask.shape", mask.shape)
    break

## 3. DispNet

[Paper](https://openaccess.thecvf.com/content_cvpr_2016/papers/Mayer_A_Large_Dataset_CVPR_2016_paper.pdf), [poster](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/poster-MIFDB16.pdf), [supplementary materials](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/supplementary-MIFDB16.pdf), [project page](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/)

In [None]:
# Q: how would you solve this task, knowing the dataset? E.g. which loss, how to consider two images, which architecture to use?

### 3.1 DispNet Simple

The simplest way to predict disparity is to concatenate the two images along the channel dimension and feed the result to a U-Net-like architecture.

[[FlowNet paper]](https://arxiv.org/pdf/1504.06852.pdf)

<img src="https://miro.medium.com/max/2400/0*LPtmtLr-mugr8OtN.png" style="width:80%">
<img src="https://miro.medium.com/max/692/0*blFDiciN3KbPNeov.png" style="width:80%">

Network architecture in more detail:

<img src="./dispnet.png" style="width:50%">

Question: why do we predict at several scales?

Answer: it helps gradient propagation; otherwise the network might rely only on the skip connections from the earliest layers. It also adds a smoothing inductive bias: adjacent pixels often have similar depth (although this does not hold at object edges).
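
To see at which resolutions the per-scale predictions live, one can push a dummy tensor through layers with the same stride pattern as the network defined in the next cell. This is only a sketch: single-channel 3x3 convolutions stand in for the real blocks (an assumption made to keep the check cheap), and the 384x768 input matches the padded height and crop width used in the training transform above.

```python
import torch

# Stand-ins for the encoder and decoder blocks: same strides, tiny channel counts.
down = torch.nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1)
up = torch.nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1)

x = torch.zeros(1, 1, 384, 768)   # PAD_HEIGHT x CROP_WIDTH
sizes = [tuple(x.shape[-2:])]
with torch.no_grad():
    for _ in range(6):            # six stride-2 stages, as in conv1..conv6
        x = down(x)
        sizes.append(tuple(x.shape[-2:]))
    upsampled = up(x)             # one decoder step doubles the resolution

print(sizes)
print(tuple(upsampled.shape[-2:]))
```

Six halvings take 384x768 down to 6x12, and a transposed convolution with kernel 4, stride 2 and padding 1 doubles the size again. This also suggests why the images are padded to a height of 384: it is divisible by 2^6 = 64, so encoder and decoder feature maps have matching sizes when they are concatenated.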

In [None]:
class ConvBNRelu(torch.nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, *args, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    
    
class UpConvBNRelu(torch.nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(in_channels, out_channels, *args, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    
    
class DispNetSimple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = ConvBNRelu(6, 64, kernel_size=(7,7), stride=2, padding=(3,3))
        self.conv2 = ConvBNRelu(64, 128, kernel_size=(5,5), stride=2, padding=(2,2))
        self.conv3 = torch.nn.Sequential(
            ConvBNRelu(128, 256, kernel_size=(5,5), stride=2, padding=(2,2)),
            ConvBNRelu(256, 256, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.conv4 = torch.nn.Sequential(
            ConvBNRelu(256, 512, kernel_size=(3,3), stride=2, padding=(1,1)),
            ConvBNRelu(512, 512, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.conv5 = torch.nn.Sequential(
            ConvBNRelu(512, 512, kernel_size=(3,3), stride=2, padding=(1,1)),
            ConvBNRelu(512, 512, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.conv6 = torch.nn.Sequential(
            ConvBNRelu(512, 1024, kernel_size=(3,3), stride=2, padding=(1,1)),
            ConvBNRelu(1024, 1024, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.pred6 = torch.nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=(1,1))
        
        self.upconv5 = UpConvBNRelu(1024, 512, kernel_size=4, stride=2, padding=(1,1))
        self.iconv5 = ConvBNRelu(1025, 512, kernel_size=3, stride=1, padding=(1,1))
        self.pred5 = torch.nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=(1,1))

        self.upconv4 = UpConvBNRelu(512, 256, kernel_size=4, stride=2, padding=(1,1))
        self.iconv4 = ConvBNRelu(256+512+1, 256, kernel_size=3, stride=1, padding=(1,1))
        self.pred4 = torch.nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=(1,1))
        
        self.upconv3 = UpConvBNRelu(256, 128, kernel_size=4, stride=2, padding=(1,1))
        self.iconv3 = ConvBNRelu(128+256+1, 128, kernel_size=3, stride=1, padding=(1,1))
        self.pred3 = torch.nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=(1,1))

        self.upconv2 = UpConvBNRelu(128, 64, kernel_size=4, stride=2, padding=(1,1))
        self.iconv2 = ConvBNRelu(64+128+1, 64, kernel_size=3, stride=1, padding=(1,1))
        self.pred2 = torch.nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=(1,1))

        self.upconv1 = UpConvBNRelu(64, 32, kernel_size=4, stride=2, padding=(1,1))
        self.iconv1 = ConvBNRelu(32+64+1, 32, kernel_size=3, stride=1, padding=(1,1))
        self.pred1 = torch.nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=(1,1))

        # The modules are first defined as class fields and then grouped into tuples.
        # This way torch registers them properly (for example, each occurs in model.parameters() exactly once).
        # Alternatively, nn.ModuleList lets you skip defining every module as a separate class field.
        self.encoder_layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)
        self.decoder_layers = (
            (self.upconv5, self.iconv5, self.pred5),
            (self.upconv4, self.iconv4, self.pred4),
            (self.upconv3, self.iconv3, self.pred3),
            (self.upconv2, self.iconv2, self.pred2),
            (self.upconv1, self.iconv1, self.pred1),
        )
        
    def forward(self, left_img, right_img):
        x = torch.cat([left_img, right_img], dim=1)

        # TODO apply dispnet

        return predictions_per_scale


Let's check that it works

In [None]:
dispnet = DispNetSimple()

In [None]:
print(f"{sum(p.numel() for p in dispnet.parameters()) / 1000_000:.1f} million parameters")

In [None]:
for sample in train_batch_gen:
    left, right, target, mask = sample
    res = dispnet(left, right)
    break

In [None]:
[pred.shape for pred in res]

### 3.2 Loss
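
The `compute_loss` function in the next cell combines two ideas: only pixels with known disparity contribute, and the Huber penalty is quadratic for small errors and linear for large ones. The toy check below illustrates both on a 2x2 example. All numbers are made up, and the boolean mask is built here as `target > 0`, which is an assumption about how validity is encoded.

```python
import torch
import torch.nn.functional as F

# Shape (batch, channel, height, width); zeros mark unknown disparity.
target = torch.tensor([[[[2.0, 0.0], [4.0, 0.0]]]])
mask = target > 0
pred = torch.tensor([[[[2.5, 9.0], [1.0, -7.0]]]])

# Only the two valid pixels are compared (errors 0.5 and 3.0, default delta=1.0).
masked = F.huber_loss(pred[mask], target[mask])
# Without the mask, predictions at unknown pixels are pulled towards zero.
unmasked = F.huber_loss(pred, target)

print(masked.item(), unmasked.item())
```

The error of 0.5 falls in the quadratic region and contributes 0.125, while the error of 3.0 falls in the linear region and contributes 2.5, so the masked loss is their mean, 1.3125. The unmasked value is much larger because it also penalizes the two pixels for which no ground truth exists.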

In [None]:
def compute_loss(predicted, target, mask):
    losses = []
    target_masked = target[mask]
    for scale_pred in predicted:
        scale_pred = torch.nn.functional.interpolate(
            scale_pred, size=target.shape[-2:], mode='bilinear', align_corners=True)
        scale_pred = scale_pred[mask]
        losses.append(torch.nn.functional.huber_loss(scale_pred, target_masked))
    total_loss = sum(losses) / len(losses)
    return total_loss, losses

# Q: can we downsample ground truth instead of upsampling predictions?

In [None]:
compute_loss(res, target, mask)

### 3.3 Training

In [None]:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

dispnet = DispNetSimple()
dispnet = dispnet.to(device)

opt = torch.optim.AdamW(dispnet.parameters(), lr=4e-3, weight_decay=5e-2)

In [None]:
def train_network(network, opt, num_epochs=20):
    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss = []
        val_loss = []
        train_scale_losses = []
        val_scale_losses = []

        network.train(True)

        for x_left, x_right, gt, valid_pixels_mask in tqdm.tqdm(train_batch_gen):
            opt.zero_grad()
            x_left = x_left.to(device)
            x_right = x_right.to(device)
            valid_pixels_mask = valid_pixels_mask.to(device)
            gt = gt.to(device)

            pred = network(x_left, x_right)

            loss, scale_losses = compute_loss(pred, gt, valid_pixels_mask)
            loss.backward()
            opt.step()

            train_loss.append(loss.item())
            train_scale_losses.append(np.array([elem.item() for elem in scale_losses]))

        network.train(False)
        with torch.no_grad():
            for x_left, x_right, gt, valid_pixels_mask in val_batch_gen:
                x_left = x_left.to(device)
                x_right = x_right.to(device)
                gt = gt.to(device)
                valid_pixels_mask = valid_pixels_mask.to(device)

                pred = network(x_left, x_right)

                loss, scale_losses = compute_loss(pred, gt, valid_pixels_mask)

                val_loss.append(loss.item())
                val_scale_losses.append(np.array([elem.item() for elem in scale_losses]))

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
        print("  training loss (in-iteration): \t{:.6f} , \t component loss: {}".format(
            np.mean(train_loss), np.mean(np.stack(train_scale_losses), axis=0)))
        print("  validation loss: \t\t\t{:.2f} , \t\t component loss: {}".format(
            np.mean(val_loss), np.mean(np.stack(val_scale_losses), axis=0)))

In [None]:
train_network(dispnet, opt, num_epochs=20)

In [None]:
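def visualize_result(network, img_index):
    network.train(False)
    # Inference only: disable autograd so no computation graph is kept in memory.
    with torch.no_grad():
        for i, (x_left, x_right, target, mask) in enumerate(val_batch_gen):
            if i != img_index:
                continue
            pred = network(x_left.to(device), x_right.to(device))
            # Visualize the last prediction in the returned list.
            pred = pred[-1].cpu()
            break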
        
    plt.figure(figsize=(20, 10))
    plt.subplot(3,1,1)
    plt.title('left image')
    plt.imshow(val_loader.left_images[img_index])
    plt.xticks([]), plt.yticks([])
    plt.subplot(3,1,2)
    plt.title('gt')
    plt.imshow(val_loader.targets[img_index])
    plt.xticks([]), plt.yticks([])
    plt.subplot(3,1,3)
    plt.title('pred')
    plt.imshow(pred.data.numpy()[0,0])
    plt.xticks([]), plt.yticks([])

In [None]:
visualize_result(dispnet, img_index=6)

After 20 epochs, predictions will very likely be oversmoothed.

### 3.4 DispNet-Corr1D

(The image is taken from the [poster](https://lmb.informatik.uni-freiburg.de/Publications/2016/MIFDB16/poster-MIFDB16.pdf).)

<img src="./dispnet-corr1d.png" style="width:80%">

In [None]:
class Corr1DLayer(torch.nn.Module):
    def __init__(self, max_disp):
        super().__init__()
        self.max_disp = max_disp

    def forward(self, left_img, right_img):
        corr_result = []
        for shift in range(0, self.max_disp):
            # YOUR CODE
            corr = ...
            corr_result.append(corr)
        corr_result = torch.stack(corr_result, dim=1)
        return corr_result
        
        
class DispNetCorr1D(torch.nn.Module):
    def __init__(self, max_disp=40):
        super().__init__()
        self.conv1 = ConvBNRelu(3, 64, kernel_size=(7,7), stride=2, padding=(3,3))
        self.conv2 = ConvBNRelu(64, 128, kernel_size=(5,5), stride=2, padding=(2,2))
        
        self.corr1d = Corr1DLayer(max_disp)
        self.conv_refinement = ConvBNRelu(128, 64, kernel_size=(3,3), stride=1, padding=(1,1))
        
        self.conv3 = torch.nn.Sequential(
            ConvBNRelu(64+max_disp, 256, kernel_size=(5,5), stride=2, padding=(2,2)),
            ConvBNRelu(256, 256, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.conv4 = torch.nn.Sequential(
            ConvBNRelu(256, 512, kernel_size=(3,3), stride=2, padding=(1,1)),
            ConvBNRelu(512, 512, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.conv5 = torch.nn.Sequential(
            ConvBNRelu(512, 512, kernel_size=(3,3), stride=2, padding=(1,1)),
            ConvBNRelu(512, 512, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.conv6 = torch.nn.Sequential(
            ConvBNRelu(512, 1024, kernel_size=(3,3), stride=2, padding=(1,1)),
            ConvBNRelu(1024, 1024, kernel_size=(3,3), stride=1, padding=(1,1)))
        self.pred6 = torch.nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=(1,1))
        
        self.upconv5 = UpConvBNRelu(1024, 512, kernel_size=4, stride=2, padding=(1,1))
        self.iconv5 = ConvBNRelu(1025, 512, kernel_size=3, stride=1, padding=(1,1))
        self.pred5 = torch.nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=(1,1))

        self.upconv4 = UpConvBNRelu(512, 256, kernel_size=4, stride=2, padding=(1,1))
        self.iconv4 = ConvBNRelu(256+512+1, 256, kernel_size=3, stride=1, padding=(1,1))
        self.pred4 = torch.nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=(1,1))
        
        self.upconv3 = UpConvBNRelu(256, 128, kernel_size=4, stride=2, padding=(1,1))
        self.iconv3 = ConvBNRelu(128+256+1, 128, kernel_size=3, stride=1, padding=(1,1))
        self.pred3 = torch.nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=(1,1))

        self.upconv2 = UpConvBNRelu(128, 64, kernel_size=4, stride=2, padding=(1,1))
        self.iconv2 = ConvBNRelu(64+128+1, 64, kernel_size=3, stride=1, padding=(1,1))
        self.pred2 = torch.nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=(1,1))

        self.upconv1 = UpConvBNRelu(64, 32, kernel_size=4, stride=2, padding=(1,1))
        self.iconv1 = ConvBNRelu(32+64+1, 32, kernel_size=3, stride=1, padding=(1,1))
        self.pred1 = torch.nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=(1,1))
        
        self.encoder_layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)
        self.decoder_layers = (
            (self.upconv5, self.iconv5, self.pred5),
            (self.upconv4, self.iconv4, self.pred4),
            (self.upconv3, self.iconv3, self.pred3),
            (self.upconv2, self.iconv2, self.pred2),
            (self.upconv1, self.iconv1, self.pred1),
        )

    def forward(self, left_img, right_img):
        # YOUR CODE

        return predictions_per_scale


In [None]:
dispnet = DispNetCorr1D()

In [None]:
print(f"{sum(p.numel() for p in dispnet.parameters()) / 1000_000:.1f} million parameters")

In [None]:
for sample in train_batch_gen:
    left, right, target, mask = sample
    res = dispnet(left, right)
    break

In [None]:
[pred.shape for pred in res]

### 3.5 DispNet-Corr1D training

In [None]:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

dispnet = DispNetCorr1D(max_disp=40)
dispnet = dispnet.to(device)

opt = torch.optim.AdamW(dispnet.parameters(), lr=4e-3, weight_decay=5e-2)

In [None]:
train_network(dispnet, opt, num_epochs=20)

In [None]:
visualize_result(dispnet, img_index=6)