# 6: Stereo Matching Fundamentals

A key problem in computer vision is estimating the depth of a scene from a pair of stereo images. The central step, finding which pixels in the two images correspond to the same point in the scene, is known as stereo matching. In this notebook, we will learn the fundamentals of stereo matching and implement a simple stereo matching algorithm.

## Dataset

For this notebook I have used data from the newest Middlebury stereo datasets (2021). The datasets can be found here:

- [2021 Mobile stereo datasets with ground truth](https://vision.middlebury.edu/stereo/data/scenes2021/)
- [Download all as zip](https://vision.middlebury.edu/stereo/data/scenes2021/zip/all.zip)

Only the "traproom1" scene is included in this repository. You can download the rest using the zip link above. To integrate another scene with the repository:

- Place the desired scene folders in the `test_data` folder.
- Add the scene to `TestDataPaths`, similar to `traproom1`:
    ```python
    class TestDataPaths:
        ...
        traproom1_dir: Path = _test_data_dir / "traproom1"
        your_folder_dir: Path = _test_data_dir / "your_folder"
    ```

### Custom

I have also included some data from my own setup that is better suited for getting started with stereo matching. The data is in the `test_data` folder in the `stereo_data_0` and `stereo_data_1` directories, and both are read with the same function. These datasets do not have any ground truth data.

## Load data

Let's start by loading the [traproom1](https://vision.middlebury.edu/stereo/data/scenes2021/data/traproom1/) dataset from the repository `test_data` folder and visualize the stereo images.

In [None]:
import matplotlib.pyplot as plt
from nptyping import Float32, NDArray, Shape
import numpy as np

from oaf_vision_3d._test_data_paths import TestDataPaths
from oaf_vision_3d.lens_model import LensModel
from oaf_vision_3d.point_cloud_visualization import open3d_visualize_point_cloud
from oaf_vision_3d.transformation_matrix import TransformationMatrix
from oaf_vision_3d._stereo_data_reader import StereoData

data_dir = TestDataPaths.traproom1_dir
stereo_data = StereoData.from_path(data_dir)

fig, ax = plt.subplots(2, 1, figsize=(12, 12))
ax[0].imshow(stereo_data.image_0, cmap="gray", vmin=0, vmax=1)
ax[0].set_title("Image left")
ax[0].axis("off")
ax[1].imshow(stereo_data.image_1, cmap="gray", vmin=0, vmax=1)
ax[1].set_title("Image right")
ax[1].axis("off")
plt.tight_layout()
plt.show()

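This dataset also has a ground truth disparity map, which we can use to triangulate the depth of the scene. The disparity is the horizontal offset between a point's position in the left and right images: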

$$ \text{disparity} = x_{left} - x_{right} $$

In [None]:
if stereo_data.ground_truth_disparity is not None:
    plt.figure(figsize=(16, 8))
    plt.imshow(stereo_data.ground_truth_disparity)
    plt.axis("off")
    plt.colorbar()
    plt.title("Ground truth disparity")
    plt.show()
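For intuition about why a disparity map encodes depth: for an ideal rectified pair of pinhole cameras with focal length $f$ (in pixels) and baseline $B$, depth is $Z = fB / d$. This is a simplifying assumption; the calibrated cameras in this notebook also have lens distortion, which is why we triangulate with the full lens models below instead. The function name and numbers in this sketch are made up for illustration:

```python
import numpy as np


def depth_from_disparity(
    disparity: np.ndarray, focal_length_px: float, baseline: float
) -> np.ndarray:
    """Hypothetical helper: Z = f * B / d for an ideal rectified pinhole pair.

    Non-positive disparities have no finite depth and are returned as NaN.
    The depth has the same unit as the baseline.
    """
    disparity = np.asarray(disparity, dtype=np.float64)
    depth = np.full(disparity.shape, np.nan)
    valid = disparity > 0
    depth[valid] = focal_length_px * baseline / disparity[valid]
    return depth


# Made-up rig: f = 1000 px, B = 0.1 m
depths = depth_from_disparity(np.array([50.0, 100.0, 0.0]), 1000.0, 0.1)
print(depths)  # depths of 2.0 m and 1.0 m, NaN for zero disparity
```

Doubling the disparity halves the depth, so nearby objects produce large disparities and distant objects small ones.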

## Triangulate disparity

We can triangulate the depth of the scene from the disparity map and the camera calibration of the system, using the [`triangulate_points`](../oaf_vision_3d/triangulation.py) function that we implemented in the previous [workshop](05_dual_camera_setups.ipynb).

To triangulate, we need one pixel grid for each image. Corresponding pixels share the same $y$-index, so the right-image grid follows from the disparity by computing its $x$-index as:

$$ x_{right} = x_{left} - \text{disparity} $$
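As a small illustration of the two pixel grids, here is the same construction that `triangulate_disparity` uses below, applied to a made-up $2 \times 3$ disparity map. Note that a negative $x_{right}$ means the corresponding point falls outside the right image:

```python
import numpy as np

# Made-up 2x3 disparity map
toy_disparity = np.array([[1.0, 2.0, 0.5], [0.0, 1.0, 1.5]], dtype=np.float32)

# (x, y) coordinates of every pixel in the left image
y, x = np.indices(toy_disparity.shape, dtype=np.float32)
pixels_0 = np.stack([x, y], axis=-1)

# Same row in the right image, x shifted by the disparity
pixels_1 = np.stack([x - toy_disparity, y], axis=-1)

print(pixels_1[..., 0])  # x_right: [[-1, -1, 1.5], [0, 0, 0.5]]
```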

In [3]:
from oaf_vision_3d.triangulation import triangulate_points


def triangulate_disparity(
    disparity: NDArray[Shape["H, W"], Float32],
    lens_model_0: LensModel,
    lens_model_1: LensModel,
    transformation_matrix: TransformationMatrix,
) -> NDArray[Shape["H, W, 3"], Float32]:
    # Pixel grid of the left image, as (x, y) coordinates
    y, x = np.indices(disparity.shape, dtype=np.float32)
    pixels_0 = np.stack([x, y], axis=-1)
    # Matching pixels in the right image: same row, x shifted by the disparity
    pixels_1 = np.stack([x - disparity, y], axis=-1)

    # Normalize and undistort both pixel sets with each camera's lens model
    undistorted_normalized_pixels_0 = lens_model_0.undistort_pixels(
        normalized_pixels=lens_model_0.normalize_pixels(pixels=pixels_0)
    )
    undistorted_normalized_pixels_1 = lens_model_1.undistort_pixels(
        normalized_pixels=lens_model_1.normalize_pixels(pixels=pixels_1)
    )

    return triangulate_points(
        undistorted_normalized_pixels_0=undistorted_normalized_pixels_0,
        undistorted_normalized_pixels_1=undistorted_normalized_pixels_1,
        transformation_matrix=transformation_matrix,
    )

And since we have the ground truth disparities we can use them to calculate the depth of the scene:

In [None]:
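# Defined up front so the Open3D cell below also works without ground truth
xyz_ground_truth = None
if stereo_data.ground_truth_disparity is not None: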
    xyz_ground_truth = triangulate_disparity(
        disparity=stereo_data.ground_truth_disparity,
        lens_model_0=stereo_data.lens_model_0,
        lens_model_1=stereo_data.lens_model_1,
        transformation_matrix=stereo_data.transformation_matrix,
    )

    if xyz_ground_truth is not None:
        plt.figure(figsize=(16, 8))
        plt.imshow(xyz_ground_truth[..., 2])
        plt.axis("off")
        plt.colorbar()
        plt.title("Ground truth depth")
        plt.show()

### Open3D Pointcloud

I will show off the point cloud using Open3D, which is a great library for 3D visualization but can be hard to work with interactively, especially in Jupyter notebooks. The point cloud opens in a separate window and blocks execution of the code until you close the window.

In [None]:
if xyz_ground_truth is not None:
    open3d_visualize_point_cloud(xyz=xyz_ground_truth, rgb=stereo_data.image_0)

## Stereo matching

For stereo matching we will implement a simple block matching algorithm. A lot of the stereo field has moved to deep learning based methods, and I highly recommend looking at the [Middlebury Stereo Evaluation](https://vision.middlebury.edu/stereo/eval/) for an overview of the state of the art in stereo matching. Many of these networks/methods are also easy to try out, since many of them share their full source code.

### Load data

For stereo matching we will use a different dataset that I have included, which has a smaller disparity range. Feel free to test your solution on other datasets later.

In [None]:
data_dir = TestDataPaths.stereo_data_0_dir
stereo_matching_data = StereoData.from_path(data_dir)

fig, ax = plt.subplots(2, 1, figsize=(12, 12))
ax[0].imshow(stereo_matching_data.image_0, cmap="gray", vmin=0, vmax=1)
ax[0].set_title("Image left")
ax[0].axis("off")
ax[1].imshow(stereo_matching_data.image_1, cmap="gray", vmin=0, vmax=1)
ax[1].set_title("Image right")
ax[1].axis("off")
plt.tight_layout()
plt.show()

This dataset shows a roughly flat piece of paper with an ArUco marker on it. Since the images are fairly high resolution we have a lot of detail to work with, but block matching will still struggle in many areas.

### Block matching

Block matching compares blocks of pixels in the left image with blocks of pixels in the right image. For each pixel, a window of fixed size around it in the left image is compared with equally sized windows along the same row in the right image, and the block with the smallest difference is taken as the best match. We can implement this with NumPy as follows (remember that we assume movement only in the $x$-direction):

- Define a disparity range to search in.
- For each disparity:
    - Shift the right image by the disparity.
    - Calculate a per-pixel error between the left and the shifted right image (I will use the absolute difference).
    - Aggregate the error over the block (the code below averages it with a box filter, which selects the same best match as summing).
- For each pixel, the disparity with the smallest error is the best match.

In [None]:
from scipy.signal import convolve2d


kernel_size = 29

image_0 = stereo_matching_data.image_0
image_1 = stereo_matching_data.image_1

disparities = np.arange(
    stereo_matching_data.expected_disparity[0],
    stereo_matching_data.expected_disparity[1] + 1,
    dtype=np.int32,
)

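# Per-pixel absolute difference (summed over color channels) for each candidate
# disparity, averaged over a kernel_size x kernel_size block with a separable box filter
error = []
for _disparity in disparities:
    # shifted_image_1[:, x] == image_1[:, x - _disparity]
    shifted_image_1 = np.roll(image_1, _disparity, axis=1)
    single_pixel_error = np.abs(image_0 - shifted_image_1).sum(axis=-1)

    convoluted_error = convolve2d(
        convolve2d(
            single_pixel_error, np.ones((1, kernel_size)) / kernel_size, mode="same"
        ),
        np.ones((kernel_size, 1)) / kernel_size,
        mode="same",
    )
    error.append(convoluted_error)

# Cost volume of shape (N, H, W); pick the disparity with the lowest cost per pixel
disparity_error = np.array(error, dtype=np.float32)
disparity = disparities[np.argmin(disparity_error, axis=0)].astype(np.float32)

# Invalidate the image borders, where np.roll wraps pixels around, and pixels
# whose best match lies on the edge of the search range
disparity[:, : int(np.abs(stereo_matching_data.expected_disparity).max())] = np.nan
disparity[:, -int(np.abs(stereo_matching_data.expected_disparity).max()) :] = np.nan
disparity[disparity >= int(stereo_matching_data.expected_disparity.max())] = np.nan
disparity[disparity <= int(stereo_matching_data.expected_disparity.min())] = np.nan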

plt.figure(figsize=(10, 6))
plt.imshow(disparity)
plt.colorbar()
plt.title("Disparity Map")
plt.axis("off")
plt.show()
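To check the steps above on data where the answer is known, here is a minimal single-channel sketch on a synthetic pair in which the right image is the left image shifted by exactly 3 pixels. It is not the notebook's implementation: the helper names are hypothetical, it averages the cost with NumPy's `sliding_window_view` (NumPy 1.20 or newer) instead of `scipy.signal.convolve2d`, and it pads the borders by edge replication rather than with zeros. Since `np.roll` wraps around, the shift is exact everywhere, so every pixel should recover a disparity of 3:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def box_mean(image: np.ndarray, size: int) -> np.ndarray:
    """Mean over a size x size window centred on each pixel (size must be odd)."""
    pad = size // 2
    padded = np.pad(image, pad, mode="edge")
    return sliding_window_view(padded, (size, size)).mean(axis=(-2, -1))


def sad_block_matching(
    left: np.ndarray, right: np.ndarray, disparities: np.ndarray, block_size: int
) -> tuple[np.ndarray, np.ndarray]:
    """Winner-takes-all block matching with an absolute-difference cost.

    The cost is averaged rather than summed over the block; both give the same argmin.
    """
    costs = []
    for d in disparities:
        # shifted[:, x] == right[:, x - d], i.e. the candidate match x_right = x_left - d
        shifted = np.roll(right, d, axis=1)
        costs.append(box_mean(np.abs(left - shifted), block_size))
    cost_volume = np.stack(costs)  # shape (N, H, W)
    return disparities[np.argmin(cost_volume, axis=0)], cost_volume


# Synthetic texture; a point at x_left appears at x_left - 3 in the right image
rng = np.random.default_rng(0)
left = rng.random((40, 60))
right = np.roll(left, -3, axis=1)

disparity_map, cost_volume = sad_block_matching(
    left, right, np.arange(0, 8), block_size=5
)
print(np.unique(disparity_map))  # [3]
```

On real images the match is never exact, so the cost at the true disparity is small rather than zero, and the minimum can be ambiguous where the block has little texture.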

### Evaluation

To evaluate the current results, let's triangulate the depth:

In [None]:
xyz_simple = triangulate_disparity(
    disparity=disparity,
    lens_model_0=stereo_matching_data.lens_model_0,
    lens_model_1=stereo_matching_data.lens_model_1,
    transformation_matrix=stereo_matching_data.transformation_matrix,
)

if xyz_simple is not None:
    plt.figure(figsize=(10, 6))
    plt.imshow(xyz_simple[..., 2])
    plt.colorbar()
    plt.title("Depth Map Simple Block Matching")
    plt.axis("off")
    plt.show()

I ran this with a 29x29 kernel, which takes some time but gives fairly reasonable results on a nearly flat object like this. In the point cloud below (if you have set up Open3D above) you can see the discrete disparity levels:

In [None]:
if xyz_simple is not None:
    open3d_visualize_point_cloud(xyz=xyz_simple, rgb=stereo_matching_data.image_0)
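The discrete levels can be made concrete under a common idealization: a rectified pinhole pair with focal length $f$ in pixels and baseline $B$, where $Z = fB/d$. This is an assumption, not the calibrated model used above, and the numbers below are hypothetical. Neighbouring integer disparities then differ in depth by $fB/(d(d+1)) = Z_d Z_{d+1}/(fB) \approx Z^2/(fB)$, so the step between levels grows quickly with distance:

```python
import numpy as np

# Hypothetical rig: focal length in pixels and baseline in metres
focal_length_px = 1000.0
baseline_m = 0.1

disparities = np.arange(40, 45)  # integer disparity candidates
depths = focal_length_px * baseline_m / disparities
steps = -np.diff(depths)  # depth gap between neighbouring levels

for d, z in zip(disparities, depths):
    print(f"d = {int(d):2d} px -> Z = {z:.4f} m")
print("level spacing (m):", np.round(steps, 4))
```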

### Subdisparity resolution

To avoid discrete depth levels, we can do a subpixel fit on the disparity values. Instead of just picking the disparity with the smallest error, we fit a 2nd degree polynomial to the errors at the index of the smallest error and its two neighbours ($\pm 1$), and then find the minimum of the polynomial.

To do this efficiently, we use $x = [-1, 0, 1]$ relative to that index and the corresponding errors $y = [y_{-1}, y_{0}, y_{1}]$, and fit a 2nd degree polynomial to them:

$$ y = ax^2 + bx + c $$

Evaluating the polynomial at the three points gives a linear system:

$$ \begin{bmatrix} 1 & -1 & 1 \\ 0 & 0 & 1 \\ 1 & 1 & 1 \end{bmatrix} \begin{bmatrix} a \\ b \\ c \end{bmatrix} = \begin{bmatrix} y_{-1} \\ y_{0} \\ y_{1} \end{bmatrix} $$

That gives the following solution:

$$ \begin{align*} a &= 0.5 (y_{-1} + y_{1}) - y_{0} \\ b &= 0.5 (y_{1} - y_{-1}) \\ c &= y_{0} \end{align*} $$

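The minimum of this polynomial is where its derivative is zero. This point is a minimum when $a > 0$, which is guaranteed if $y_0$ is the smallest of the three errors and they are not all equal:

$$ \frac{dy}{dx} = 2ax + b = 0 \quad \Rightarrow \quad x = -\frac{b}{2a} $$

Since we used $x = [-1, 0, 1]$, this $x$ is an offset (more like a $\Delta x$) relative to $d_{\text{idx}}$, the disparity with the smallest error. The subpixel disparity is therefore:

$$ \text{subpixel disparity} = d_{\text{idx}} + x $$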
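A quick numerical check of the derivation, using errors sampled from a made-up parabola whose minimum sits at $x = 0.3$. The closed-form $a$, $b$, $c$ should match a direct solve of the $3 \times 3$ system, and $-b/(2a)$ should recover the offset; `d_min` is a hypothetical best integer disparity:

```python
import numpy as np


def cost(x: float) -> float:
    """Made-up error curve with its minimum at x = 0.3."""
    return 2.0 * (x - 0.3) ** 2 + 1.0


y_m1, y_0, y_p1 = cost(-1.0), cost(0.0), cost(1.0)

# Closed-form coefficients from the derivation above
a = 0.5 * (y_m1 + y_p1) - y_0
b = 0.5 * (y_p1 - y_m1)
c = y_0

# Cross-check by solving the 3x3 system directly (rows are [x^2, x, 1])
A = np.array([[1.0, -1.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
a_s, b_s, c_s = np.linalg.solve(A, np.array([y_m1, y_0, y_p1]))

delta = -b / (2.0 * a)  # a > 0 here, so this is a minimum
d_min = 12  # hypothetical disparity with the smallest error
print(a, b, delta, d_min + delta)
```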

In [None]:
from nptyping import Int32


def find_subpixel_disparities_poly_2(
    disparities: NDArray[Shape["N"], Int32],
    function_value: NDArray[Shape["N, H, W"], Float32],
) -> NDArray[Shape["H, W"], Float32]:
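    # Pixel coordinate grids used to gather values from the (N, H, W) cost volume
    h_idx = np.arange(function_value.shape[1])
    w_idx = np.arange(function_value.shape[2])
    h_idx, w_idx = np.meshgrid(h_idx, w_idx, indexing="ij")

    # Index of the lowest cost, kept away from the ends so idx - 1 and idx + 1 exist
    idx = np.clip(np.argmin(function_value, axis=0), 1, disparities.shape[0] - 2)

    # Costs at the best disparity and its two neighbours (y_{-1}, y_0, y_1)
    f_0 = function_value[idx - 1, h_idx, w_idx]
    f_1 = function_value[idx, h_idx, w_idx]
    f_2 = function_value[idx + 1, h_idx, w_idx]

    # Parabola coefficients from the closed-form solution
    a = 0.5 * (f_0 + f_2) - f_1
    b = 0.5 * (f_2 - f_0)

    # Vertex offset x = -b / (2a); flat parabolas (a == 0) become NaN
    denom = 2 * a
    denom = np.where(denom == 0, np.nan, denom)

    delta = -b / denom
    # Offsets outside [-1, 1] are not supported by the three samples, so reject them
    delta = np.where(np.abs(delta) > 1, np.nan, delta)

    return disparities[idx].astype(np.float32) + delta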


disparity_sub_pixel = find_subpixel_disparities_poly_2(
    disparities=disparities,
    function_value=disparity_error,
)


plt.figure(figsize=(10, 6))
plt.imshow(disparity_sub_pixel)
plt.colorbar()
plt.title("Sub Pixel Disparity Map")
plt.axis("off")
plt.show()

In [None]:
xyz_sub_pixel = triangulate_disparity(
    disparity=disparity_sub_pixel,
    lens_model_0=stereo_matching_data.lens_model_0,
    lens_model_1=stereo_matching_data.lens_model_1,
    transformation_matrix=stereo_matching_data.transformation_matrix,
)

if xyz_sub_pixel is not None:
    plt.figure(figsize=(10, 6))
    plt.imshow(xyz_sub_pixel[..., 2])
    plt.colorbar()
    plt.title("Sub Pixel Depth Map")
    plt.axis("off")
    plt.show()

This removes the discrete disparity levels and gives us sub-disparity resolution. We still see some artifacts, but this is a good start. If you have Open3D set up, you can see the subpixel disparities in the point cloud below:

In [None]:
if xyz_sub_pixel is not None:
    open3d_visualize_point_cloud(xyz=xyz_sub_pixel, rgb=stereo_matching_data.image_0)

## Testing Stereo Matching

Before the next session, I want you to test how you can improve the stereo matching algorithm. Below is a full implementation with a few parameters you can change to test different things:

- `block_size` - The size of the block to compare. What happens if you adjust this?
- `disparity_range` - The range of disparities to search in. What happens if you adjust this?
- `subpixel_fit` - Whether to do a subpixel fit, as we have seen in this workshop.
- `cost_function` - Which cost function to use. So far I have used the absolute difference, summed over all color channels. Can you test a different cost function?
  - You can also try third-party libraries like OpenCV, which have built-in stereo matching functions (e.g. [`cv::stereo::StereoBinaryBM`](https://docs.opencv.org/3.4/dd/d86/group__stereo.html))

In [13]:
from enum import Enum


class CostFunction(Enum):
    SUM_OF_ABSOLUTE_DIFFERENCE = 0
    SUM_OF_SQUARED_DIFFERENCE = 1


def _get_cost(
    image_0: NDArray[Shape["H, W, ..."], Float32],
    image_1: NDArray[Shape["H, W, ..."], Float32],
    cost_function: CostFunction,
) -> NDArray[Shape["H, W"], Float32]:
    match cost_function:
        case CostFunction.SUM_OF_ABSOLUTE_DIFFERENCE:
            return np.abs(image_0 - image_1).sum(axis=-1)
        case CostFunction.SUM_OF_SQUARED_DIFFERENCE:
            return ((image_0 - image_1) ** 2).sum(axis=-1)
        case _:
            raise ValueError("Invalid cost function")


def block_matching(
    image_0: NDArray[Shape["H, W, ..."], Float32],
    image_1: NDArray[Shape["H, W, ..."], Float32],
    disparity_range: NDArray[Shape["2"], Float32],
    block_size: int,
    subpixel_fit: bool,
    cost_function: CostFunction,
) -> NDArray[Shape["H, W"], Float32]:
    # Inclusive range, consistent with the step-by-step implementation above
    disparities = np.arange(
        disparity_range[0], disparity_range[1] + 1, dtype=np.int32
    )
    error = []
    for _disparity in disparities:
        shifted_image_1 = np.roll(image_1, _disparity, axis=1)
        single_pixel_error = _get_cost(image_0, shifted_image_1, cost_function)

        convoluted_error = convolve2d(
            convolve2d(
                single_pixel_error, np.ones((1, block_size)) / block_size, mode="same"
            ),
            np.ones((block_size, 1)) / block_size,
            mode="same",
        )
        error.append(convoluted_error)

    disparity_error = np.array(error, dtype=np.float32)

    if subpixel_fit:
        disparity = find_subpixel_disparities_poly_2(
            disparities=disparities, function_value=disparity_error
        )
    else:
        disparity = disparities[np.argmin(disparity_error, axis=0)].astype(np.float32)

    disparity[:, : int(np.abs(disparities).max())] = np.nan
    disparity[:, -int(np.abs(disparities).max()) :] = np.nan
    disparity[disparity >= disparities.max()] = np.nan
    disparity[disparity <= disparities.min()] = np.nan

    return disparity

This function can be used to test any dataset:

In [14]:
# Choose test set
data_dir = TestDataPaths.stereo_data_0_dir
data = StereoData.from_path(data_dir)

# Adjustable parameters
block_size = 29
disparity_range = data.expected_disparity
subpixel_fit = True
cost_function = CostFunction.SUM_OF_ABSOLUTE_DIFFERENCE

# Run block matching
disparity_new = block_matching(
    image_0=data.image_0,
    image_1=data.image_1,
    disparity_range=disparity_range,
    block_size=block_size,
    subpixel_fit=subpixel_fit,
    cost_function=cost_function,
)

In [None]:
plt.figure(figsize=(10, 6))
plt.imshow(disparity_new)
plt.colorbar()
plt.title("Disparity Map")
plt.axis("off")
plt.show()

In [None]:
xyz_new = triangulate_disparity(
    disparity=disparity_new,
    lens_model_0=data.lens_model_0,
    lens_model_1=data.lens_model_1,
    transformation_matrix=data.transformation_matrix,
)

if xyz_new is not None:
    plt.figure(figsize=(10, 6))
    plt.imshow(xyz_new[..., 2])
    plt.colorbar()
    plt.title("Depth Map")
    plt.axis("off")
    plt.show()

In [None]:
if xyz_new is not None:
    open3d_visualize_point_cloud(xyz=xyz_new, rgb=data.image_0)

### Different kernels and cost functions

Here we test both the absolute difference and the squared difference cost functions, each with block sizes of $11 \times 11$, $17 \times 17$, $23 \times 23$ and $29 \times 29$.

In [None]:
# Choose test set
data_dir = TestDataPaths.stereo_data_0_dir
data = StereoData.from_path(data_dir)

# Adjustable parameters
disparity_range = data.expected_disparity
subpixel_fit = True
kernels = [11, 17, 23, 29]
cost_functions = {
    CostFunction.SUM_OF_ABSOLUTE_DIFFERENCE: "Sum of Absolute Difference",
    CostFunction.SUM_OF_SQUARED_DIFFERENCE: "Sum of Squared Difference",
}

# One 2x2 figure per cost function, one panel per block size
for cost_function, cost_function_name in cost_functions.items():
    fig, axs = plt.subplots(2, 2, figsize=(16, 16))
    for kernel, ax in zip(kernels, axs.ravel()):
        disparity_new = block_matching(
            image_0=data.image_0,
            image_1=data.image_1,
            disparity_range=disparity_range,
            block_size=kernel,
            subpixel_fit=subpixel_fit,
            cost_function=cost_function,
        )

        ax.imshow(disparity_new)
        ax.axis("off")
        ax.set_title(f"Disparity Map (Kernel: {kernel})")
    fig.suptitle(f"Block Matching with {cost_function_name}")
    fig.tight_layout()
plt.show()
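The two cost functions differ mainly in how they weight large differences. In this made-up per-pixel example with three color channels, both residuals have the same sum of absolute differences, but the sum of squared differences penalizes the single large error three times as much. In `block_matching` these per-pixel costs are then averaged over the block before the minimum is picked:

```python
import numpy as np

# Made-up per-channel (RGB) residuals for two candidate matches
residual_spread = np.array([0.2, 0.2, 0.2])  # small error in every channel
residual_outlier = np.array([0.0, 0.0, 0.6])  # one large error

for name, r in [("spread", residual_spread), ("outlier", residual_outlier)]:
    sad = np.abs(r).sum()  # sum of absolute differences
    ssd = (r**2).sum()  # sum of squared differences
    print(f"{name:8s} SAD = {sad:.2f}  SSD = {ssd:.2f}")
# spread: SAD 0.60, SSD 0.12; outlier: SAD 0.60, SSD 0.36
```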