diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 88eb4e17823..ea85a853824 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -16,6 +16,7 @@ from collections import defaultdict from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union +import numpy as np import PIL import PIL.Image import pytest @@ -23,7 +24,7 @@ import torchvision.datasets import torchvision.io from common_utils import get_tmp_dir, disable_console_output - +from torchvision.transforms.functional import get_dimensions __all__ = [ "UsageError", @@ -561,7 +562,9 @@ def test_feature_types(self, config): @test_all_configs def test_num_examples(self, config): with self.create_dataset(config) as (dataset, info): - assert len(dataset) == info["num_examples"] + assert ( + len(dataset) == info["num_examples"] + ), f"The number of examples {len(dataset)} does not match the expected {info['num_examples']}" @test_all_configs def test_transforms(self, config): @@ -931,6 +934,48 @@ def create_random_string(length: int, *digits: str) -> str: return "".join(random.choice(digits) for _ in range(length)) +def shape_test_for_stereo_gt_w_mask( + left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray, valid_mask: np.ndarray +): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, h, w = left_dims + # check that left and right are the same size + assert left_dims == right_dims + # check general shapes + assert c == 3 + assert disparity.ndim == 3 + assert disparity.shape == (1, h, w) + # check that valid mask is the same size as the disparity + _, dh, dw = disparity.shape + mh, mw = valid_mask.shape + assert dh == mh + assert dw == mw + + +def shape_test_for_stereo_gt_no_mask(left: PIL.Image.Image, right: PIL.Image.Image, disparity: np.ndarray): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, h, w = left_dims + # check that left and right are the same size + assert left_dims == right_dims + # check general shapes + assert c == 3 + assert disparity.ndim == 3 + assert disparity.shape == (1, h, w) + + +def shape_test_for_stereo_no_gt(left: PIL.Image.Image, right: PIL.Image.Image, disparity: None): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, _, _ = left_dims + # check that left and right are the same size + assert left_dims == right_dims + # check general shapes + assert c == 3 + assert disparity is None + + def make_fake_pfm_file(h, w, file_name): values = list(range(3 * h * w)) # Note: we pack everything in little endian: -1.0, and "<" diff --git a/test/test_datasets.py b/test/test_datasets.py index a108479aee3..ff1a418fbac 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -13,6 +13,7 @@ import unittest import xml.etree.ElementTree as ET import zipfile +from typing import List, Callable, Tuple import datasets_utils import numpy as np @@ -2671,5 +2672,579 @@ def inject_fake_data(self, tmpdir: str, config): return len(sampled_classes) * num_images_per_class[config["split"]] +class ETH3DStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.ETH3DStereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + @staticmethod + def _create_scene_folder(num_examples: int, root_dir: str) -> List[pathlib.Path]: + # create the scene folder + image_paths = [] + # make the root_dir if it does not exist + os.makedirs(root_dir, 
exist_ok=True) + + for i in range(num_examples): + scene_dir = os.path.join(root_dir, f"scene_{i}") + os.makedirs(scene_dir, exist_ok=True) + # populate with left right images + image_paths.append(datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100))) + image_paths.append(datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100))) + return image_paths + + @staticmethod + def _create_annotation_folder(num_examples: int, root_dir: str) -> None: + paths = [] + # make the root_dir if it does not exits + os.makedirs(root_dir, exist_ok=True) + + # create scene directories + for i in range(num_examples): + scene_dir = os.path.join(root_dir, f"scene_{i}") + os.makedirs(scene_dir, exist_ok=True) + # populate with a random png file for occlusion mask, and a pfm file for disparity + paths.append(datasets_utils.create_image_file(root=scene_dir, name="mask0nocc.png", size=(1, 100, 100))) + pfm_path = os.path.join(scene_dir, "disp0GT.pfm") + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=pfm_path) + + def inject_fake_data(self, tmpdir, config): + eth3d_dir = pathlib.Path(tmpdir) / "ETH3D" + + num_examples = 2 if config["split"] == "train" else 3 + + split_name = "two_view_training" if config["split"] == "train" else "two_view_test" + split_dir = eth3d_dir / split_name + self._create_scene_folder(num_examples, split_dir) + + if config["split"] == "train": + annot_dir = os.path.join(eth3d_dir, "two_view_training_gt") + self._create_annotation_folder(num_examples, annot_dir) + + return num_examples + + def test_training_splits(self): + with self.create_dataset(split="train") as (dataset, _): + assert dataset._images and len(dataset._images) == len( + dataset._disparities + ), "Training images do not match with training disparities" + for left, right, disparity, valid_mask in dataset: + datasets_utils.shape_test_for_stereo_gt_w_mask(left, right, disparity, valid_mask) + + def test_testing_splits(self): + with self.create_dataset(split="test") as (dataset, _): + assert all(d == ("", "") for d in dataset._disparities) + for left, right, disparity, _ in dataset: + datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class CREStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.CREStereo + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None)) + + def inject_fake_data(self, tmpdir, config): + crestereo_dir = pathlib.Path(tmpdir) / "CREStereo" + os.makedirs(crestereo_dir, exist_ok=True) + + examples = {"tree": 2, "shapenet": 3, "reflective": 6, "hole": 5} + + for category_name in ["shapenet", "reflective", "tree", "hole"]: + split_dir = crestereo_dir / category_name + os.makedirs(split_dir, exist_ok=True) + num_examples = examples[category_name] + + for idx in range(num_examples): + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.jpg", size=(100, 100)) + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.jpg", size=(100, 100)) + # these are going to end up being gray scale images + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.disp.png", size=(1, 100, 100)) + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.disp.png", size=(1, 100, 100)) + + return sum(examples.values()) + + def test_splits(self): + with self.create_dataset() as (dataset, _): + for left, 
right, disparity, mask in dataset: + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + +class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Middlebury2014Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "additional"), + calibration=("perfect", "imperfect", "both"), + use_ambient_views=(True, False), + ) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + @staticmethod + def _make_scene_folder(root_dir: str, scene_name: str, split: str) -> None: + calibrations = [None] if split == "test" else ["-perfect", "-imperfect"] + root_dir = pathlib.Path(root_dir) + + for c in calibrations: + scene_dir = root_dir / f"{scene_name}{c}" + os.makedirs(scene_dir, exist_ok=True) + # make normal images first + datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(3, 100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(3, 100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1E.png", size=(3, 100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1L.png", size=(3, 100, 100)) + # these are going to end up being gray scale images + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=os.path.join(scene_dir, "disp0.pfm")) + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=os.path.join(scene_dir, "disp1.pfm")) + + def inject_fake_data(self, tmpdir, config): + split_scene_map = { + "train": ["Adirondack", "Jadeplant", "Motorcycle", "Piano"], + "additional": ["Backpack", "Bicycle1", "Cable", "Classroom1"], + "test": ["Plants", "Classroom2E", "Classroom2", "Australia"], + } + + middlebury_dir = pathlib.Path(tmpdir, "Middlebury2014") + os.makedirs(middlebury_dir, exist_ok=True) + + split_dir = middlebury_dir / config["split"] + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 2, "additional": 3, "test": 4}.get(config["split"], 0) + for idx in range(num_examples): + scene_name = split_scene_map[config["split"]][idx] + self._make_scene_folder(root_dir=split_dir, scene_name=scene_name, split=config["split"]) + + if config["calibration"] == "both": + num_examples *= 2 + return num_examples + + def test_train_splits(self): + for split, calibration in itertools.product(["train", "additional"], ["perfect", "imperfect", "both"]): + with self.create_dataset(split=split, calibration=calibration) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split, calibration=None) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + + def test_augmented_view_usage(self): + with self.create_dataset(split="train", use_ambient_views=True) as (dataset, _): + for left, right, _ in dataset: + left_array = np.array(left) + right_array = np.array(right) + # check that left and right are the same size + assert left_array.shape == right_array.shape + + def test_warnings_train(self): + # train set invalid + split = "train" + calibration = None + with pytest.raises( + ValueError, + match=f"Split '{split}' has calibration settings, however None was provided as an argument." + f"\nSetting calibration to 'perfect' for split '{split}'. 
Available calibration settings are: 'perfect', 'imperfect', 'both'.", + ): + with self.create_dataset(split=split, calibration=calibration): + pass + + def test_warnings_test(self): + # test set invalid + split = "test" + calibration = "perfect" + with pytest.raises( + ValueError, match="Split 'test' has only no calibration settings, please set `calibration=None`." + ): + with self.create_dataset(split=split, calibration=calibration): + pass + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti2012Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = pathlib.Path(tmpdir) / "Kitti2012" + os.makedirs(kitti_dir, exist_ok=True) + + split_dir = kitti_dir / (config["split"] + "ing") + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 3}.get(config["split"], 0) + + datasets_utils.create_image_folder( + root=split_dir, + name="colored_0", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + datasets_utils.create_image_folder( + root=split_dir, + name="colored_1", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + + if config["split"] == "train": + datasets_utils.create_image_folder( + root=split_dir, + name="disp_noc", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2012 uses a single channel image for disparities + size=(1, 100, 200), + ) + + return num_examples + + def test_train_splits(self): + for split in ["train"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti2015Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = pathlib.Path(tmpdir) / "Kitti2015" + os.makedirs(kitti_dir, exist_ok=True) + + split_dir = kitti_dir / (config["split"] + "ing") + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 6}.get(config["split"], 0) + + datasets_utils.create_image_folder( + root=split_dir, + name="image_2", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + datasets_utils.create_image_folder( + root=split_dir, + name="image_3", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + + if config["split"] == "train": + datasets_utils.create_image_folder( + root=split_dir, 
+ name="disp_occ_0", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2015 uses a single channel image for disparities + size=(1, 100, 200), + ) + + datasets_utils.create_image_folder( + root=split_dir, + name="disp_occ_1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2015 uses a single channel image for disparities + size=(1, 100, 200), + ) + + return num_examples + + def test_train_splits(self): + for split in ["train"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo_no_gt(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.SceneFlowStereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final") + ) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + @staticmethod + def _create_pfm_folder( + root: str, name: str, file_name_fn: Callable[..., str], num_examples: int, size: Tuple[int, int] + ) -> None: + root = pathlib.Path(root) / name + os.makedirs(root, exist_ok=True) + + for i in range(num_examples): + datasets_utils.make_fake_pfm_file(size[0], size[1], root / file_name_fn(i)) + + def inject_fake_data(self, tmpdir, config): + scene_flow_dir = pathlib.Path(tmpdir) / "SceneFlow" + os.makedirs(scene_flow_dir, exist_ok=True) + + split_dir = scene_flow_dir / config["split"] + os.makedirs(split_dir, exist_ok=True) + + pass_dir_map = { + "clean": "frames_cleanpass", + "final": "frames_finalpass", + } + + num_examples = 1 + pass_dir_name = pass_dir_map.get(config["pass_name"], None) + + # create pass directories + pass_dir = split_dir / pass_dir_name + disp_dir = split_dir / "disparity" + os.makedirs(pass_dir, exist_ok=True) + os.makedirs(disp_dir, exist_ok=True) + + num_examples = {"FlyingThings3D": 4, "Driving": 6, "Monkaa": 5}.get(config["split"], 0) + + for direction in ["left", "right"]: + for scene_idx in range(num_examples): + os.makedirs(pass_dir / f"scene_{scene_idx:06d}", exist_ok=True) + datasets_utils.create_image_folder( + root=pass_dir / f"scene_{scene_idx:06d}", + name=direction, + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=1, + size=(3, 200, 100), + ) + + os.makedirs(disp_dir / f"scene_{scene_idx:06d}", exist_ok=True) + self._create_pfm_folder( + root=disp_dir / f"scene_{scene_idx:06d}", + name=direction, + file_name_fn=lambda i: f"{i:06d}.pfm", + num_examples=1, + size=(100, 200), + ) + + return num_examples + + def test_splits(self): + for split_name, pass_name in itertools.product(["FlyingThings3D", "Driving", "Monkaa"], ["clean", "final"]): + with self.create_dataset(split=split_name, pass_name=pass_name) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + 
+class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.FallingThingsStereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("single", "mixed")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + @staticmethod + def _make_dummy_depth_map(root: str, name: str, size: Tuple[int, int]): + file = pathlib.Path(root) / name + image = np.ones((size[0], size[1]), dtype=np.uint8) + PIL.Image.fromarray(image).save(file) + + @staticmethod + def _make_scene_folder(root: str, scene_name: str, size: Tuple[int, int]) -> None: + root = pathlib.Path(root) / scene_name + os.makedirs(root, exist_ok=True) + # jpg images + datasets_utils.create_image_file(root, "image1.left.jpg", size=(3, size[1], size[0])) + datasets_utils.create_image_file(root, "image1.right.jpg", size=(3, size[1], size[0])) + # single channel depth maps + FallingThingsStereoTestCase._make_dummy_depth_map(root, "image1.left.depth.png", size=(size[0], size[1])) + FallingThingsStereoTestCase._make_dummy_depth_map(root, "image1.right.depth.png", size=(size[0], size[1])) + # camera settings json. Minimal example for _read_disparity function testing + settings_json = {"camera_settings": [{"intrinsic_settings": {"fx": 1}}]} + with open(root / "_camera_settings.json", "w") as f: + json.dump(settings_json, f) + + def inject_fake_data(self, tmpdir, config): + fallingthings_dir = pathlib.Path(tmpdir) / "FallingThings" + os.makedirs(fallingthings_dir, exist_ok=True) + + split_dir = pathlib.Path(fallingthings_dir) / config["split"] + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"single": 2, "mixed": 3}.get(config["split"], 0) + + for i in range(num_examples): + self._make_scene_folder( + root=split_dir, + scene_name=f"scene_{i:06d}", + size=(100, 200), + ) + + return num_examples + + def test_splits(self): + for split_name in ["single", "mixed"]: + with self.create_dataset(split=split_name) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class SintelStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.SintelStereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(pass_name=("final", "clean", "both")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + sintel_dir = pathlib.Path(tmpdir) / "Sintel" + os.makedirs(sintel_dir, exist_ok=True) + + split_dir = pathlib.Path(sintel_dir) / "training" + os.makedirs(split_dir, exist_ok=True) + + # a single setting, since there are no splits + num_examples = {"final": 2, "clean": 2} + pass_names = { + "final": ["final"], + "clean": ["clean"], + "both": ["final", "clean"], + }.get(config["pass_name"], []) + + for p in pass_names: + for view in [f"{p}_left", f"{p}_right"]: + root = split_dir / view + os.makedirs(root, exist_ok=True) + + datasets_utils.create_image_folder( + root=root, + name="scene1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples[p], + size=(3, 100, 200), + ) + + datasets_utils.create_image_folder( + root=split_dir / "occlusions", + name="scene1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=2, + size=(1, 100, 200), + ) + + datasets_utils.create_image_folder( + root=split_dir / "outofframe", 
+ name="scene1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=2, + size=(1, 100, 200), + ) + + datasets_utils.create_image_folder( + root=split_dir / "disparities", + name="scene1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=2, + size=(3, 100, 200), + ) + + if config["pass_name"] == "both": + num_examples = sum(num_examples.values()) + else: + num_examples = num_examples.get(config["pass_name"], 0) + + return num_examples + + def test_splits(self): + for pass_name in ["final", "clean", "both"]: + with self.create_dataset(pass_name=pass_name) as (dataset, _): + for left, right, disparity, valid_mask in dataset: + datasets_utils.shape_test_for_stereo_gt_w_mask(left, right, disparity, valid_mask) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"): + with self.create_dataset(pass_name="bad"): + pass + + +class InStereo2k(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.InStereo2k + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + + @staticmethod + def _make_scene_folder(root: str, name: str, size: Tuple[int, int]): + root = pathlib.Path(root) / name + os.makedirs(root, exist_ok=True) + + datasets_utils.create_image_file(root=root, name="left.png", size=(3, size[0], size[1])) + datasets_utils.create_image_file(root=root, name="right.png", size=(3, size[0], size[1])) + datasets_utils.create_image_file(root=root, name="left_disp.png", size=(1, size[0], size[1])) + datasets_utils.create_image_file(root=root, name="right_disp.png", size=(1, size[0], size[1])) + + def inject_fake_data(self, tmpdir, config): + in_stereo_dir = pathlib.Path(tmpdir) / "InStereo2k" + os.makedirs(in_stereo_dir, exist_ok=True) + + split_dir = pathlib.Path(in_stereo_dir) / config["split"] + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 5}.get(config["split"], 0) + + for i in range(num_examples): + self._make_scene_folder(split_dir, f"scene_{i:06d}", (100, 200)) + + return num_examples + + def test_splits(self): + for split_name in ["train", "test"]: + with self.create_dataset(split=split_name) as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo_gt_no_mask(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index c76a84f8634..6ff1382010d 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,6 +1,7 @@ import pytest import test_models as TM import torch +import torchvision.prototype.models.depth.stereo.crestereo as crestereo import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo from common_utils import set_rng_seed, cpu_and_gpu @@ -36,3 +37,43 @@ def test_raft_stereo(model_builder, model_mode, dev): # Test against expected file output TM._assert_expected(depth_pred, name=model_builder.__name__, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("model_builder", (crestereo.crestereo_base,)) +@pytest.mark.parametrize("model_mode", ("standard", "scripted")) +@pytest.mark.parametrize("dev", cpu_and_gpu()) +def test_crestereo(model_builder, model_mode, dev): + set_rng_seed(0) + + model = model_builder().eval().to(dev) + + if model_mode == 
"scripted": + model = torch.jit.script(model) + + img1 = torch.rand(1, 3, 256, 256).to(dev) + img2 = torch.rand(1, 3, 256, 256).to(dev) + iterations = 3 + + preds = model(img1, img2, flow_init=None, iterations=iterations) + disparity_pred = preds[-1] + + # all the pyramid levels except the highest res make only half the number of iterations + expected_iterations = (iterations // 2) * (len(model.resolutions) - 1) + expected_iterations += iterations + assert ( + len(preds) == expected_iterations + ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels" + + assert disparity_pred.shape == torch.Size( + [1, 2, 256, 256] + ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}" + + assert all( + d.shape == torch.Size([1, 2, 256, 256]) for d in preds + ), "All predicted disparities are expected to have the same shape" + + # test a backward pass with a dummy loss as well + preds = torch.stack(preds, dim=0) + targets = torch.ones_like(preds, requires_grad=False) + loss = torch.nn.functional.mse_loss(preds, targets) + loss.backward() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 295fe922478..8e0e6f274d1 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,4 +1,15 @@ from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K +from ._stereo_matching import ( + ETH3DStereo, + FallingThingsStereo, + Kitti2012Stereo, + Kitti2015Stereo, + Middlebury2014Stereo, + SceneFlowStereo, + SintelStereo, + CREStereo, + InStereo2k, +) from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -105,4 +116,13 @@ "FGVCAircraft", "EuroSAT", "RenderedSST2", + "StereoETH3D", + "StereoFallingThings", + "StereoKitti2012", + "StereoKitti2015", + "StereoMiddlebury2014", + "StereoSceneFlow", + "StereoSintel", + "CREStereo", + "InStereo2k", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py new file mode 100644 index 00000000000..ff7d0183773 --- /dev/null +++ b/torchvision/datasets/_stereo_matching.py @@ -0,0 +1,1164 @@ +import functools +import json +import os +import random +import shutil +from abc import ABC, abstractmethod +from glob import glob +from pathlib import Path +from typing import Callable, List, Optional, Tuple + +import numpy as np +from PIL import Image + +from .utils import download_and_extract_archive, verify_str_arg, _read_pfm +from .vision import VisionDataset + +__all__ = ( + "CREStereo" + "Middlebury2014Stereo" + "ETH3DStereo" + "Kitti2012Stereo" + "Kitti2015Stereo" + "SintelStereo" + "SceneFlowStereo" + "FallingThingsStereo" + "InStereo2k" +) + +_read_pfm_file = functools.partial(_read_pfm, slice_channels=1) + + +class StereoMatchingDataset(ABC, VisionDataset): + """Base interface for Stereo matching datasets""" + + _has_built_in_disparity_mask = False + + def __init__(self, root: str, transforms: Optional[Callable] = None): + """ + + Args: + root(str): Root directory of the dataset. + transforms(callable, optional): A function/transform that takes in Tuples of + (images, disparities, valid_masks) and returns a transformed version of each of them. 
+ images is a Tuple of (``PIL.Image``, ``PIL.Image``) + disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W) + valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W) + + In some cases, when a dataset does not provide disparities, the ``disparities`` and + ``valid_masks`` can be Tuples containing None values. + + For training splits generally the datasets provide a minimal guarantee of + images: (``PIL.Image``, ``PIL.Image``) + disparities: (``np.ndarray``, ``None``) with shape (1, H, W) + + Optionally, based on the dataset, it can return a ``mask`` as well: + valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W) + + For some test splits, the datasets provides outputs that look like: + imgaes: (``PIL.Image``, ``PIL.Image``) + disparities: (``None``, ``None``) + + Optionally, based on the dataset, it can return a ``mask`` as well: + valid_masks: (``None``, ``None``) + """ + super().__init__(root=root) + self.transforms = transforms + + self._images: List[Tuple[str, str]] = [] + self._disparities: List[Tuple[str, str]] = [] + + def _read_img(self, file_path: str) -> Image.Image: + img = Image.open(file_path) + if img.mode != "RGB": + img = img.convert("RGB") + return img + + def _scan_pairs( + self, paths_left_pattern: str, paths_right_pattern: str, fill_empty: bool = False + ) -> List[Tuple[str, str]]: + left_paths: List[str] = sorted(glob(paths_left_pattern)) + right_paths: List[str] = sorted(glob(paths_right_pattern)) + + # used when dealing with inexistent disparity for the right image + if fill_empty: + right_paths = list("" for _ in left_paths) + + if not left_paths: + raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}") + + if not right_paths: + raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}") + + if len(left_paths) != len(right_paths): + raise ValueError( + f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n " + f"left pattern: {paths_left_pattern}\n" + f"right pattern: {paths_right_pattern}\n" + ) + + images = list((left, right) for left, right in zip(left_paths, right_paths)) + return images + + @abstractmethod + def _read_disparity(self, file_path: str) -> Tuple: + # function that returns a disparity map and an occlusion map + pass + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask`` + can be a numpy boolean mask of shape (H, W) if the dataset provides a file + indicating which disparity pixels are valid. The disparity is a numpy array of + shape (1, H, W) and the images are PIL images. ``disparity`` is None for + datasets on which for ``split="test"`` the authors did not provide annotations. 
+ """ + img_left = self._read_img(self._images[index][0]) + img_right = self._read_img(self._images[index][1]) + + dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0]) + dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1]) + + imgs = (img_left, img_right) + dsp_maps = (dsp_map_left, dsp_map_right) + valid_masks = (valid_mask_left, valid_mask_right) + + if self.transforms is not None: + ( + imgs, + dsp_maps, + valid_masks, + ) = self.transforms(imgs, dsp_maps, valid_masks) + + if self._has_built_in_disparity_mask or valid_masks[0] is not None: + return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] + else: + return imgs[0], imgs[1], dsp_maps[0] + + def __len__(self) -> int: + return len(self._images) + + +class CREStereo(StereoMatchingDataset): + """Synthetic dataset used in training the `CREStereo `_ architecture. + + Dataset details on the official paper `repo `_. + + The dataset is expected to have the following structure: :: + + root + CREStereo + tree + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + img2_left.jpg + img2_right.jpg + img2_left.disp.jpg + img2_right.disp.jpg + ... + shapenet + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + reflective + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + hole + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + + Args: + root (str): Root directory of the dataset. + split (str): The split of the dataset to use. One of ``"tree"``, ``"shapenet"``, ``"reflective"``, ``"hole"`` + or ``"all"``. The ``"all"`` split contains all of the above splits. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + download (bool, optional): If true, downloads the dataset from the internet and puts it in the root directory. + max_disparity (int, optional): Maximum disparity value. Used to compute the valid mask. + """ + + DOWNLOAD_SPACE = 400 * 1024 * 1024 * 1024 + + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + download: bool = False, + max_disparity: float = 256.0, + ): + super().__init__(root, transforms) + self._has_built_in_disparity_mask = True + + root = Path(root) / "CREStereo" + self.max_disparity = max_disparity + + # if the API user requests a dataset download check that the user can download it + if download: + statvfs = os.statvfs(root) + # measured in bytes + available_space = statvfs.f_frsize * statvfs.f_bavail + if available_space - self.DOWNLOAD_SPACE < 0: + raise ValueError( + f"The storage device for {str(root)} is too small to download the dataset), " + f"an additional {self.DOWNLOAD_SPACE - available_space:.2f} GB are required." 
+ ) + self._download_dataset(str(root)) + + dirs = ["shapenet", "reflective", "tree", "hole"] + + for s in dirs: + left_image_pattern = str(root / s / "*_left.jpg") + right_image_pattern = str(root / s / "*_right.jpg") + imgs = self._scan_pairs(left_image_pattern, right_image_pattern) + self._images += imgs + + left_disparity_pattern = str(root / s / "*_left.disp.png") + right_disparity_pattern = str(root / s / "*_right.disp.png") + disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities += disparities + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def _download_dataset(self, root: str) -> None: + dirs = ["tree", "shapenet", "reflective", "hole"] + # create directory subtree for the download + for d in dirs: + d_path = os.path.join(root, d) + if not os.path.exists(d_path): + os.makedirs(d_path) + + for i in range(10): + url = f"https://data.megengine.org.cn/research/crestereo/dataset/{d}/{i}.tar" + download_and_extract_archive(url=url, download_root=d_path, remove_finished=True) + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + """ + return super().__getitem__(index) + + +class Middlebury2014Stereo(StereoMatchingDataset): + """Publicly available scenes from the Middlebury dataset `2014 version `. + + The dataset mostly follows the original format, without containing the ambient subdirectories. : :: + + root + Middlebury2014 + train + scene1-{ ,perfect,imperfect} + calib.txt + im{0,1}.png + im1E.png + im1L.png + disp{0,1}.pfm + disp{0,1}-n.png + disp{0,1}-sd.pfm + disp{0,1}y.pfm + scene2-{ ,perfect,imperfect} + calib.txt + im{0,1}.png + im1E.png + im1L.png + disp{0,1}.pfm + disp{0,1}-n.png + disp{0,1}-sd.pfm + disp{0,1}y.pfm + ... + additional + scene1-{ ,perfect,imperfect} + calib.txt + im{0,1}.png + im1E.png + im1L.png + disp{0,1}.pfm + disp{0,1}-n.png + disp{0,1}-sd.pfm + disp{0,1}y.pfm + ... + test + scene1 + calib.txt + im{0,1}.png + scene2 + calib.txt + im{0,1}.png + ... + + + Args: + root (string): Root directory of the Middleburry 2014 Dataset. + split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional" + use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible. + The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``. + calibration (string, optional): Wether or not to use the calibrated (default) or uncalibrated scenes. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + download (boolean, optional): Wether or not to download the dataset in the ``root`` directory. 
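+
+        A minimal usage sketch (illustrative only; it assumes the scenes are already available under ``root``,
+        e.g. after passing ``download=True``):
+
+            >>> from torchvision.datasets import Middlebury2014Stereo
+            >>> dataset = Middlebury2014Stereo("datasets", split="train", calibration="perfect")
+            >>> img_left, img_right, disparity = dataset[0]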
+ """ + + splits = { + "train": [ + "Adirondack", + "Jadeplant", + "Motorcycle", + "Piano", + "Pipes", + "Playroom", + "Playtable", + "Recycle", + "Shelves", + "Vintage", + ], + "additional": [ + "Backpack", + "Bicycle1", + "Cable", + "Classroom1", + "Couch", + "Flowers", + "Mask", + "Shopvac", + "Sticks", + "Storage", + "Sword1", + "Sword2", + "Umbrella", + ], + "test": [ + "Plants", + "Classroom2E", + "Classroom2", + "Australia", + "DjembeL", + "CrusadeP", + "Crusade", + "Hoops", + "Bicycle2", + "Staircase", + "Newkuba", + "AustraliaP", + "Djembe", + "Livingroom", + "Computer", + ], + } + + def __init__( + self, + root: str, + split: str = "train", + calibration: Optional[str] = "perfect", + use_ambient_views: bool = False, + transforms: Optional[Callable] = None, + download: bool = False, + ): + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test", "additional")) + self.split = split + + if calibration: + verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None)) # type: ignore + if split == "test": + raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.") + else: + if split != "test": + raise ValueError( + f"Split '{split}' has calibration settings, however None was provided as an argument." + f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.", + ) + + if download: + self._download_dataset(root) + + root = Path(root) / "Middlebury2014" + + if not os.path.exists(root / split): + raise FileNotFoundError(f"The {split} directory was not found in the provided root directory") + + split_scenes = self.splits[split] + # check that the provided root folder contains the scene splits + if not any( + # using startswith to account for perfect / imperfect calibrartion + scene.startswith(s) + for scene in os.listdir(root / split) + for s in split_scenes + ): + raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.") + + calibrartion_suffixes = { + None: [""], + "perfect": ["-perfect"], + "imperfect": ["-imperfect"], + "both": ["-perfect", "-imperfect"], + }[calibration] + + for calibration_suffix in calibrartion_suffixes: + scene_pattern = "*" + calibration_suffix + left_img_pattern = str(root / split / scene_pattern / "im0.png") + right_img_pattern = str(root / split / scene_pattern / "im1.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "test": + self._disparities += list(("", "") for _ in self._images) + else: + left_dispartity_pattern = str(root / split / scene_pattern / "disp0.pfm") + right_dispartity_pattern = str(root / split / scene_pattern / "disp1.pfm") + self._disparities += self._scan_pairs(left_dispartity_pattern, right_dispartity_pattern) + + self.use_ambient_views = use_ambient_views + + def _read_img(self, file_path: str) -> Image.Image: + """ + Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True. + When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]`` + as the right image. 
+ """ + if os.path.basename(file_path) == "im1.png" and self.use_ambient_views: + # initialize sampleable container + base_path = os.path.dirname(file_path) + ambient_file_paths = list(os.path.join(base_path, view_name) for view_name in ["im1E.png", "im1L.png"]) + # double check that we're not going to try to read from an invalid file path + ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths)) + # keep the original image as an option as well for uniform sampling between base views + ambient_file_paths.append(file_path) + file_path = random.choice(ambient_file_paths) + return super()._read_img(file_path) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has not disparity maps + if not os.path.exists(file_path): + return None, None + + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def _download_dataset(self, root: str): + base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" + # train and additional splits have 2 different calibration settings + root = Path(root) / "Middlebury2014" + split_name = self.split + + if split_name != "test": + for split_scene in self.splits[split_name]: + split_root = root / split_name + for calibration in ["perfect", "imperfect"]: + scene_name = f"{split_scene}-{calibration}" + scene_url = f"{base_url}/{scene_name}.zip" + print(f"Downloading {scene_url}") + # download the scene only if it doesn't exist + if not os.path.exists(split_root / scene_name): + download_and_extract_archive( + url=scene_url, + filename=f"{scene_name}.zip", + download_root=str(split_root), + remove_finished=True, + ) + else: + os.makedirs(root / "test") + if any(s not in os.listdir(root / "test") for s in self.splits["test"]): + # test split is downloaded from a different location + test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip" + # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF + # we want to move the contents from testF into the directory + download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True) + for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")): + for scene in scene_names: + scene_dst_dir = root / "test" + scene_src_dir = Path(scene_dir) / scene + os.makedirs(scene_dst_dir, exist_ok=True) + shutil.move(str(scene_src_dir), str(scene_dst_dir)) + + # cleanup MiddEval3 directory + shutil.rmtree(str(root / "MiddEval3")) + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) + + +class ETH3DStereo(StereoMatchingDataset): + """ "ETH3D `Low-Res Two-View `_ dataset. + + The dataset is expected to have the following structure: :: + + root + ETH3D + two_view_training + scene1 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + scene2 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + ... + two_view_training_gt + scene1 + disp0GT.pfm + mask0nocc.png + scene2 + disp0GT.pfm + mask0nocc.png + ... 
+ two_view_testing + scene1 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + scene2 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + ... + + Args: + root (string): Root directory of the ETH3D Dataset. + split (string, optional): The dataset split of scenes, either "train" (default) or "test". + calibration (string, optional): Wether or not to use the calibrated (default) or uncalibrated scenes. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + # needed for output consistency, otherwise tests get fussy about + # variable sized FEATURE_TYPES based on dataset split + self._has_built_in_disparity_mask = True + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "ETH3D" + + img_dir = "two_view_training" if split == "train" else "two_view_test" + anot_dir = "two_view_training_gt" + + left_img_pattern = str(root / img_dir / "*" / "im0.png") + right_img_pattern = str(root / img_dir / "*" / "im1.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "test": + self._disparities = list(("", "") for _ in self._images) + else: + disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm") + self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if not os.path.exists(file_path): + return None, None + + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + mask_path = os.path.join(os.path.split(file_path)[0], "mask0nocc.png") + valid_mask = Image.open(mask_path) + valid_mask = np.asarray(valid_mask).astype(np.bool_) + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index) + + +class Kitti2012Stereo(StereoMatchingDataset): + """ "Kitti dataset from the `2012 `_ stereo evaluation benchmark. + Uses the RGB images for consistency with Kitti 2015. + + The dataset is expected to have the following structure: :: + + root + Kitti2012 + testing + colored_0 + colored_1 + training + colored_0 + colored_1 + disp_noc + calib + + Args: + root (string): Root directory where Kitti2012 is located. + split (string, optional): The dataset split of scenes, either "train" (default), test, or "additional" + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + download (boolean, optional): Wether or not to download the dataset in the ``root`` directory. 
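+
+        A minimal usage sketch (illustrative only; it assumes the KITTI 2012 archives are already extracted under ``root``):
+
+            >>> from torchvision.datasets import Kitti2012Stereo
+            >>> dataset = Kitti2012Stereo("datasets", split="train")
+            >>> img_left, img_right, disparity, valid_mask = dataset[0]  # valid_mask is None unless produced by ``transforms``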
+ """ + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + self._has_built_in_disparity_mask = True + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "Kitti2012" / (split + "ing") + + left_img_pattern = str(root / "colored_0" / "*_10.png") + right_img_pattern = str(root / "colored_1" / "*_10.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "train": + disparity_pattern = str(root / "disp_noc" / "*.png") + self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True) + else: + self._disparities = list(("", "") for _ in self._images) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if not os.path.exists(file_path): + return None, None + + disparity_map = np.asarray(Image.open(file_path)) / 256.0 + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index) + + +class Kitti2015Stereo(StereoMatchingDataset): + """ "Kitti dataset from the `2015 `_ stereo evaluation benchmark. + + The dataset is expected to have the following structure: :: + + root + Kitti2015 + testing + image_2 + img1.png + img2.png + ... + image_3 + img1.png + img2.png + ... + training + image_2 + img1.png + img2.png + ... + image_3 + img1.png + img2.png + ... + disp_occ_0 + img1.png + img2.png + ... + disp_occ_1 + img1.png + img2.png + ... + calib + + Args: + root (string): Root directory where Kitti2015 is located. + split (string, optional): The dataset split of scenes, either "train" (default) or test. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + self._has_built_in_disparity_mask = True + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "Kitti2015" / (split + "ing") + left_img_pattern = str(root / "image_2" / "*.png") + right_img_pattern = str(root / "image_3" / "*.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "train": + left_disparity_pattern = str(root / "disp_occ_0" / "*.png") + right_disparity_pattern = str(root / "disp_occ_1" / "*.png") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + else: + self._disparities = list(("", "") for _ in self._images) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if not os.path.exists(file_path): + return None, None + + disparity_map = np.asarray(Image.open(file_path)) / 256.0 + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index) + + +class SintelStereo(StereoMatchingDataset): + """ "Sintel `Stereo Dataset `_. + + The dataset is expected to have the following structure: :: + + root + Sintel + training + final_left + scene1 + img1.png + img2.png + ... + ... + final_right + scene2 + img1.png + img2.png + ... + ... + disparities + scene1 + img1.png + img2.png + ... + ... + occlusions + scene1 + img1.png + img2.png + ... + ... + outofframe + scene1 + img1.png + img2.png + ... + ... + + Args: + root (string): Root directory where Sintel Stereo is located. + pass_name (string): The name of the pass to use, either "final" or "clean". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) + + root = Path(root) / "Sintel" + pass_names = { + "final": ["final"], + "clean": ["clean"], + "both": ["final", "clean"], + }[pass_name] + + for p in pass_names: + left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png") + right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png") + self._disparities += self._scan_pairs(disparity_pattern, "", fill_empty=True) + + def _get_oclussion_mask_paths(self, file_path: str) -> Tuple[str, str]: + path_tokens = file_path.split(os.sep) + rets = None + + for idx in range(len(path_tokens) - 1): + if path_tokens[idx] == "training" and path_tokens[idx + 1] == "disparities": + pre_tokens = path_tokens[: idx + 1] + post_tokens = path_tokens[idx + 2 :] + rets = ( + "/".join(pre_tokens + ["occlusions"] + post_tokens), + "/".join(pre_tokens + ["outofframe"] + post_tokens), + ) + break + + if rets is None: + raise ValueError("Malformed file path: {}".format(file_path)) + + for path in rets: + if not os.path.exists(path): + raise ValueError(f"Could not find file {path}") + + return rets + + def _read_disparity(self, file_path: str) -> Tuple: + if not os.path.exists(file_path): + return None, None + + # disparity decoding as per Sintel instructions in the README provided with the dataset + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + r, g, b = np.split(disparity_map, 3, axis=-1) + disparity_map = r * 4 + g / (2 ** 6) + b / (2 ** 14) + # reshape into (C, H, W) format + disparity_map = np.transpose(disparity_map, (2, 0, 1)) + # find the appropiate file paths + occlued_mask_path, out_of_frame_mask_path = self._get_oclussion_mask_paths(file_path) + # occlusion masks + valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0 + # out of frame masks + off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0 + # combine the masks together + valid_mask = np.logical_and(off_mask, valid_mask) + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) + + +class SceneFlowStereo(StereoMatchingDataset): + """Dataset interface for `Scene Flow `_ datasets. + + The dataset is expected to have the following structre: :: + + root + SceneFlow + Monkaa + frames_cleanpass + scene1 + left + img1.png + img2.png + right + img1.png + img2.png + scene2 + left + img1.png + img2.png + right + img1.png + img2.png + frames_finalpass + scene1 + left + img1.png + img2.png + right + img1.png + img2.png + ... + ... + disparity + scene1 + left + img1.pfm + img2.pfm + right + img1.pfm + img2.pfm + FlyingThings3D + ... + ... + + Args: + root (string): Root directory where SceneFlow is located. 
+ split (string, optional): Which dataset to load, one of "FlyingThings3D" (default), "Driving" or "Monkaa". + pass_name (string, optional): Which rendering pass to use, one of "clean" (default), "final" or "both". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + def __init__( + self, root: str, split: str = "FlyingThings3D", pass_name: str = "clean", transforms: Optional[Callable] = None + ): + super().__init__(root, transforms) + + root = Path(root) / "SceneFlow" + + verify_str_arg(split, "split", valid_values=("FlyingThings3D", "Driving", "Monkaa")) + verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both")) + + passes = { + "clean": ["frames_cleanpass"], + "final": ["frames_finalpass"], + "both": ["frames_cleanpass", "frames_finalpass"], + }[pass_name] + + root = root / split + + for p in passes: + left_img_pattern = str(root / p / "*" / "left" / "*.png") + right_img_pattern = str(root / p / "*" / "right" / "*.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + left_disparity_pattern = str(root / "disparity" / "*" / "left" / "*.pfm") + right_disparity_pattern = str(root / "disparity" / "*" / "right" / "*.pfm") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) + + +class FallingThingsStereo(StereoMatchingDataset): + """FallingThings ``_ dataset + + The dataset is expected to have the following structure: :: + + root + FallingThings + single + scene1 + _object_settings.json + _camera_settings.json + image1.left.depth.png + image1.right.depth.png + image1.left.jpg + image1.right.jpg + image2.left.depth.png + image2.right.depth.png + image2.left.jpg + image2.right.jpg + ... + scene2 + ... + mixed + scene1 + _object_settings.json + _camera_settings.json + image1.left.depth.png + image1.right.depth.png + image1.left.jpg + image1.right.jpg + image2.left.depth.png + image2.right.depth.png + image2.left.jpg + image2.right.jpg + ... + scene2 + ... + + Args: + root (string): Root directory where FallingThings is located. + split (string): Either "single", "mixed", or "both". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
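+
+        A minimal usage sketch (illustrative only; it assumes the Falling Things data is already extracted under ``root``):
+
+            >>> from torchvision.datasets import FallingThingsStereo
+            >>> dataset = FallingThingsStereo("datasets", split="single")
+            >>> img_left, img_right, disparity = dataset[0]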
+ + """ + + def __init__(self, root: str, split: str = "single", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + root = Path(root) / "FallingThings" + + verify_str_arg(split, "split", valid_values=("single", "mixed", "both")) + + splits = { + "single": ["single"], + "mixed": ["mixed"], + "both": ["single", "mixed"], + }[split] + + for s in splits: + left_img_pattern = str(root / s / "*" / "*.left.jpg") + right_img_pattern = str(root / s / "*" / "*.right.jpg") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + left_disparity_pattern = str(root / s / "*" / "*.left.depth.png") + right_disparity_pattern = str(root / s / "*" / "*.right.depth.png") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple: + # (H, W) image + depth = np.asarray(Image.open(file_path)) + # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt + # in order to extract disparity from depth maps + with open(os.path.split(file_path)[0] + "/_camera_settings.json", "r") as f: + # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constatnt) + intrinsics = json.load(f) + focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"] + baseline, pixel_constant = 6.0, 100.0 # pixel constant is inverted + disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32) + # unsqueeze disparity to (C, H, W) + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) + + +class InStereo2k(StereoMatchingDataset): + """InStereo2k ``_ dataset + + The dataset is expected to have the following structre: :: + + root + InStereo2k + train + scene1 + left.png + right.png + left_disp.png + right_disp.png + ... + scene2 + ... + test + scene1 + left.png + right.png + left_disp.png + right_disp.png + ... + scene2 + ... + + Args: + root (string): Root directory where InStereo2k is located. + split (string): Either "train" or "test". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + root = Path(root) / "InStereo2k" / split + + verify_str_arg(split, "split", valid_values=("train", "test")) + + left_img_pattern = str(root / "*" / "left.png") + right_img_pattern = str(root / "*" / "right.png") + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) + + left_disparity_pattern = str(root / "*" / "left_disp.png") + right_disparity_pattern = str(root / "*" / "right_disp.png") + self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + # unsqueeze disparity to (C, H, W) + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) diff --git a/torchvision/prototype/models/depth/stereo/crestereo.py b/torchvision/prototype/models/depth/stereo/crestereo.py new file mode 100644 index 00000000000..92a75d20ce3 --- /dev/null +++ b/torchvision/prototype/models/depth/stereo/crestereo.py @@ -0,0 +1,1007 @@ +import math +from functools import partial +from typing import Iterable, List, Optional, Callable, Tuple, Dict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models.optical_flow.raft as raft +from torch import Tensor +from torchvision.models._api import WeightsEnum +from torchvision.models.optical_flow._utils import make_coords_grid, grid_sample, upsample_flow +from torchvision.ops import Conv2dNormActivation + + +class ResidualBlock(raft.ResidualBlock): + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): + super().__init__(in_channels, out_channels, norm_layer=norm_layer, stride=stride) + + # the CREStereo base architecture changes the number of channels + # even on grids with the same spatial resolution + if in_channels != out_channels: + self.downsample = Conv2dNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + +class FeatureEncoder(raft.FeatureEncoder): + """Base encoder for Feature Encoder and Context Encoder""" + + def __init__( + self, + *, + block: Callable[..., nn.Module] = ResidualBlock, + layers: Tuple[int, int, int, int, int] = (64, 64, 96, 128, 256), + strides: Tuple[int, int, int, int] = (2, 1, 2, 1), + norm_layer: Callable[..., nn.Module] = nn.InstanceNorm2d, + ): + super().__init__(block=block, layers=layers, strides=strides, norm_layer=norm_layer) + for s in strides: + if s not in [1, 2]: + raise ValueError(f"FeatureEncoder unsupported stride size {s}. 
Supported values are one of ``[1, 2]``.") + + self.output_dim = layers[-1] + num_downsamples = len(list(filter(lambda s: s == 2, strides))) + self.downsample_factor = 2 ** num_downsamples + + +class ConvexMaskPredictor(nn.Module): + def __init__( + self, + *, + in_channels: int, + hidden_size: int, + upsample_factor: int, + multiplier: float = 0.25, + ) -> None: + + super().__init__() + self.mask_head = nn.Sequential( + Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3), + nn.Conv2d(hidden_size, upsample_factor ** 2 * 9, 1, padding=0), + ) + + self.multiplier = multiplier + + def forward(self, x: Tensor) -> Tensor: + x = self.mask_head(x) * self.multiplier + return x + + +class AdaptiveGroupCorrelationLayer(nn.Module): + """ + Container for computing various correlation types between a left and right feature map. + This module does not contain any optimisable parameters, it's solely a collection of ops. + We wrap in a nn.Module for torch.jit.script compatibility + + Adaptive Group Correlation operations from: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf + + Canonical reference implementation: https://github.com/megvii-research/CREStereo/blob/master/nets/corr.py + """ + + def __init__( + self, + attention_module: Optional[nn.Module] = None, + groups: int = 4, + search_window_1d: Tuple[int, int] = (1, 9), + search_dilate_1d: Tuple[int, int] = (1, 1), + search_window_2d: Tuple[int, int] = (3, 3), + search_dilate_2d: Tuple[int, int] = (1, 1), + ) -> None: + super().__init__() + self.attention_module = attention_module + + assert np.prod(search_window_1d) == np.prod(search_window_2d), ( + f"The 1D and 2D windows should contain the same number of elements. " + f"1D shape: {search_window_1d} 2D shape: {search_window_2d}" + ) + + assert np.prod(search_window_1d) % 2 == 1, ( + f"Search windows should contain an odd number of elements in them." + f"Window of shape {search_window_1d} has {np.prod(search_window_1d)} elements." + ) + + assert any( + size == 1 for size in search_window_1d + ), f"The 1D search window should have at least one size equal to 1. 1D shape: {search_window_1d}" + + assert all( + size != 1 for size in search_window_2d + ), f"The 2D search window should have all dimensions greater than 1. 
2D shape: {search_window_2d}" + + self.search_window_1d = search_window_1d + self.search_window_2d = search_window_2d + + self.search_dilate_1d = search_dilate_1d + self.search_dilate_2d = search_dilate_2d + + self.groups = groups + + # two selection tables for dealing withh the small_patch argument in the forward function + self.patch_sizes = { + True: [self.search_window_2d for _ in range(self.groups)], + False: [self.search_window_1d for _ in range(self.groups)], + } + + self.dilate_sizes = { + True: [self.search_dilate_2d for _ in range(self.groups)], + False: [self.search_dilate_1d for _ in range(self.groups)], + } + + def forward( + self, + left_features: Tensor, + right_features: Tensor, + flow: torch.Tensor, + extra_offset: Union[torch.Tensor, None], + use_small_patch: bool = False, + iter_mode: bool = False, + ): + if iter_mode or extra_offset is None: + corr = self.iterative_correlation(left_features, right_features, flow, use_small_patch) + else: + corr = self.attention_offset_correlation(left_features, right_features, flow, extra_offset, use_small_patch) # type: ignore + return corr + + def _make_coords(self, feature_map: Tensor) -> Tensor: + return make_coords_grid(feature_map.shape[0], feature_map.shape[2], feature_map.shape[3]).to(feature_map.device) + + def get_correlation( + self, + left_feature: Tensor, + right_feature: Tensor, + window_size: Tuple[int, int] = (3, 3), + dilate: Tuple[int, int] = (1, 1), + ) -> Tensor: + """Function that computes a correlation product between the left and right features. + + The correlation is computed in a sliding window fashion, namely the the left features are fixed + and for each ``(i, j)`` location we compute the correlation with a sliding window anchored in + ``(i, j)`` from the right feature map. The sliding window selects pixels obtained in the range of the sliding + window; i.e ``(i - window_size // 2, i + window_size // 2)`` respectively ``(j - window_size // 2, j + window_size // 2)``. 
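+
+        As an illustrative example of the expected shapes (inferred from the implementation
+        below): for features of shape ``[B, C, H, W]`` and the default ``window_size=(3, 3)``,
+        the returned correlation volume has shape ``[B, 9, H, W]``, one channel per candidate
+        position in the search window.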
+ """ + + B, C, H, W = left_feature.shape + + di_y, di_x = dilate[0], dilate[1] + pad_y, pad_x = window_size[0] // 2 * di_y, window_size[1] // 2 * di_x + + right_padded = F.pad(right_feature, (pad_x, pad_x, pad_y, pad_y), mode="replicate") + right_padded = right_padded.detach() + # in order to vectorize the correlation computation over all pixel candidates + # we create multiple shifted right images which we stack on an extra dimension + right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate) + # torch unfold returns a tensor of shape [B, flattened_values, n_selections] + right_padded = right_padded.permute(0, 2, 1) + # we consider rehsape back into [B, n_views, C, H, W] + right_padded = right_padded.reshape(B, (window_size[0] * window_size[1]), C, H, W) + # we expand the left features for broadcasting + left_feature = left_feature.unsqueeze(1) + # this will compute an element product of between [B, 1, C, H, W] * [B, n_views, C, H, W] + # to obtain correlations over the pixel canditates we perform a mean on the C dimension + correlation = torch.mean(left_feature * right_padded, dim=2, keepdim=False) + # the final correlation tensor shape will be [B, n_views, H, W] + # where on the i-th position of the n_views dimension we will have + # the correlation value between the left pixel + # and the i-th candidate on the right feature map + return correlation + + def iterative_correlation( + self, left_feature: Tensor, right_feature: Tensor, flow: Tensor, use_small_patch: bool = False + ) -> Tensor: + """Function that computes 1 pass of non-offsetted Group-Wise correlation""" + coords = self._make_coords(left_feature) + + # we offset the coordinate grid in the flow direction + coords = coords + flow + coords = coords.permute(0, 2, 3, 1) + # resample right features according to off-setted grid + right_feature = grid_sample(right_feature, coords, mode="bilinear", align_corners=True) + + # use_small_patch is a flag by which we decide on how many axes + # we perform candidate search. See section 3.1 ``Deformable search window`` & Figure 4 in the paper. + patch_size_list = self.patch_sizes[use_small_patch] + dilate_size_list = self.dilate_sizes[use_small_patch] + + # chunking the left and right feature to perform group-wise correlation + # mechanism simillar to GroupNorm. See section 3.1 ``Group-wise correlation``. 
+ left_groups = torch.chunk(left_feature, self.groups, dim=1) + right_groups = torch.chunk(right_feature, self.groups, dim=1) + + correlations = [] + # this boils down to rather than performing the correlation product + # over the entire C dimensions, we use subsets of C to get multiple correlation sets + for i in range(len(patch_size_list)): + correlation = self.get_correlation(left_groups[i], right_groups[i], patch_size_list[i], dilate_size_list[i]) + correlations.append(correlation) + final_correlations = torch.cat(correlations, dim=1) + return final_correlations + + def attention_offset_correlation( + self, + left_feature: Tensor, + right_feature: Tensor, + flow: Tensor, + extra_offset: Tensor, + use_small_patch: bool = False, + ): + """Function that computes 1 pass of offsetted Group-Wise correlation + + If the class was provided with an attention layer, the left and right feature maps + will be passed through a transformer first + """ + B, C, H, W = left_feature.shape + + if self.attention_module is not None: + # prepare for transformer required input shapes + left_feature = left_feature.permute(0, 2, 3, 1).reshape(B, H * W, C) + right_feature = right_feature.permute(0, 2, 3, 1).reshape(B, H * W, C) + # this can be either self attention or cross attention, hence the tupple return + left_feature, right_feature = self.attention_module(left_feature, right_feature) + left_feature = left_feature.reshape(B, H, W, C).permute(0, 3, 1, 2) + right_feature = right_feature.reshape(B, H, W, C).permute(0, 3, 1, 2) + + left_groups = torch.chunk(left_feature, self.groups, dim=1) + right_groups = torch.chunk(right_feature, self.groups, dim=1) + + num_search_candidates = 9 + # for each pixel (i, j) we have a number of search candidates + # thus, for each candidate we should have an X-axis and Y-axis offset value + extra_offset = extra_offset.reshape(B, num_search_candidates, 2, H, W).permute(0, 1, 3, 4, 2) + + # see line 133 for details + patch_size_list = self.patch_sizes[use_small_patch] + dilate_size_list = self.dilate_sizes[use_small_patch] + + group_channels = C // self.groups + correlations = [] + + for i in range(len(patch_size_list)): + left_group, right_group = left_groups[i], right_groups[i] + patch_size, dilate = patch_size_list[i], dilate_size_list[i] + + di_y, di_x = dilate + ps_y, ps_x = patch_size + # define the search based on the window patch shape + ry, rx = ps_y // 2 * di_y, ps_x // 2 * di_x + + # base offsets for search (i.e. where to look on the search index) + x_grid, y_grid = torch.meshgrid( + torch.arange(-rx, rx + 1, di_x), torch.arange(-ry, ry + 1, di_y), indexing="xy" + ) + x_grid, y_grid = x_grid.to(flow.device), y_grid.to(flow.device) + offsets = torch.stack((x_grid, y_grid)) + offsets = offsets.reshape(2, -1).permute(1, 0) + + for d in (0, 2, 3): + offsets = offsets.unsqueeze(d) + # extra offsets for search (i.e. deformed search indexes. 
Simillar concept to deformable convolutions) + offsets = offsets + extra_offset + + coords = self._make_coords(left_feature) + flow + coords = coords.permute(0, 2, 3, 1).unsqueeze(1) + coords = coords + offsets + coords = coords.reshape(B, -1, W, 2) + + right_group = grid_sample(right_group, coords, mode="bilinear", align_corners=True) + # we do not need to perform any window shifting because the grid sample op + # will return a multi-view right based on the num_search_candidates dimension in the offsets + right_group = right_group.reshape(B, -1, group_channels, H, W) + left_group = left_group.reshape(B, -1, group_channels, H, W) + correlation = torch.mean(left_group * right_group, dim=2) + correlations.append(correlation) + + final_correlation = torch.cat(correlations, dim=1) + return final_correlation + + +def elu_feature_map(x: Tensor) -> Tensor: + """Elu feature map operation from: https://arxiv.org/pdf/2006.16236.pdf""" + return F.elu(x) + 1 + + +class LinearAttention(nn.Module): + """ + Linear attention operation from: https://arxiv.org/pdf/2006.16236.pdf + Cannonical implementation reference: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py + LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py + """ + + def __init__(self, eps: float = 1e-6, feature_map_fn: Callable[[Tensor], Tensor] = elu_feature_map) -> None: + super().__init__() + self.eps = eps + self.feature_map_fn = elu_feature_map + + def forward( + self, + queries: Tensor, + keys: Tensor, + values: Tensor, + q_mask: Optional[Tensor] = None, + kv_mask: Optional[Tensor] = None, + ): + """ + Args: + queries (torch.Tensor): [N, S1, H, D] + keys (torch.Tensor): [N, S2, H, D] + values (torch.Tensor): [N, S2, H, D] + q_mask (torch.Tensor): [N, S1] (optional) + kv_mask (torch.Tensor): [N, S2] (optional) + Returns: + queried_values (torch.Tensor): [N, S1, H, D] + """ + queries = self.feature_map_fn(queries) + values = self.feature_map_fn(values) + + if q_mask is not None: + queries = queries * q_mask[:, :, None, None] + if kv_mask is not None: + keys = keys * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + # mitigates fp16 overflows + values_length = values.shape[1] + values = values / values_length + kv = torch.einsum("NSHD, NSHV -> NHDV", keys, values) + z = 1 / (torch.einsum("NLHD, NHD -> NLH", queries, keys.sum(dim=1)) + self.eps) + # rescale at the end to account for fp16 mitigation + queried_values = torch.einsum("NLHD, NHDV, NLH -> NLHV", queries, kv, z) * values_length + return queried_values + + +class SoftmaxAttention(nn.Module): + """ + A simple softmax attention operation + LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py + """ + + def __init__(self, dropout: float = 0.0) -> None: + super().__init__() + self.dropout = nn.Dropout(dropout) if dropout else nn.Identity() + + def forward( + self, + queries: Tensor, + keys: Tensor, + values: Tensor, + q_mask: Optional[Tensor] = None, + kv_mask: Optional[Tensor] = None, + ): + """ + Computes classical softmax full-attention between all queries and keys. 
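+        This is the standard scaled dot-product attention, i.e. (schematically)
+        ``softmax(Q @ K.transpose(-2, -1) / sqrt(D)) @ V`` evaluated per attention head.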
+ + Args: + queries (torch.Tensor): [N, S1, H, D] + keys (torch.Tensor): [N, S2, H, D] + values (torch.Tensor): [N, S2, H, D] + q_mask (torch.Tensor): [N, S1] (optional) + kv_mask (torch.Tensor): [N, S2] (optional) + Returns: + queried_values: [N, S1, H, D] + """ + + scale_factor = 1.0 / queries.shape[3] ** 0.5 # irsqrt(D) scaling + queries = queries * scale_factor + + qk = torch.einsum("NLHD, NSHD -> NLSH", queries, keys) + if kv_mask is not None and q_mask is not None: + qk.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf")) + + attention = torch.softmax(qk, dim=2) + attention = self.dropout(attention) + + queried_values = torch.einsum("NLSH, NSHD -> NLHD", attention, values) + return queried_values + + +class PositionalEncodingSine(nn.Module): + """ + Sinusoidal positonal encodings + + Using the scaling term from https://github.com/megvii-research/CREStereo/blob/master/nets/attention/position_encoding.py + Reference implementation from https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/position_encoding.py#L28-L48 + """ + + def __init__(self, dim_model: int) -> None: + super().__init__() + self.dim_model = dim_model + self.scale_factor = -math.log(10_000) / (dim_model // 2) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: [B, C, H, W] + """ + torch._assert( + len(x.shape) == 4, + f"PositionalEncodingSine requires a 4-D dimensional input. Provided tensor is of shape {x.shape}", + ) + + coords = torch.ones(size=x.shape[2:], dtype=x.dtype, device=x.device) + positions_y = coords.cumsum(0).unsqueeze(0).unsqueeze(-1) + positions_x = coords.cumsum(1).unsqueeze(0).unsqueeze(-1) + + div_term = torch.exp(torch.arange(0, self.dim_model // 2, dtype=x.dtype, device=x.device) * self.scale_factor) + positions_x = positions_x * div_term + positions_y = positions_y * div_term + + positions_x = torch.stack((positions_x[..., 0::2].sin(), positions_x[..., 1::2].cos()), dim=4).flatten(3) + positions_y = torch.stack((positions_y[..., 0::2].sin(), positions_y[..., 1::2].cos()), dim=4).flatten(3) + + positional_embeddings = torch.cat((positions_x, positions_y), dim=3).permute(0, 3, 1, 2) + return x + positional_embeddings + + +class LocalFeatureEncoderLayer(nn.Module): + """ + LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf + Cannonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py + """ + + def __init__( + self, + *, + dim_model: int, + num_heads: int, + attention_type: str = "linear", + ) -> None: + super().__init__() + + if attention_type not in ["linear", "softmax"]: + raise ValueError( + f"Unsuported attention type {attention_type}. 
LocalFeatureEncoderLayer supports one of ``[linear, softmax]``" + ) + + self.dim_head = dim_model // num_heads + self.num_heads = num_heads + + # multi-head attention + self.query_proj = nn.Linear(dim_model, dim_model, bias=False) + self.key_proj = nn.Linear(dim_model, dim_model, bias=False) + self.value_proj = nn.Linear(dim_model, dim_model, bias=False) + self.attention_op = LinearAttention() if attention_type == "linear" else SoftmaxAttention() + self.merge = nn.Linear(dim_model, dim_model, bias=False) + + # feed forward network + self.ffn = nn.Sequential( + nn.Linear(dim_model * 2, dim_model * 2, bias=False), + nn.ReLU(), + nn.Linear(dim_model * 2, dim_model, bias=False), + ) + + # norm layers + self.attention_norm = nn.LayerNorm(dim_model) + self.ffn_norm = nn.LayerNorm(dim_model) + + def forward(self, x: Tensor, source: Tensor, x_mask: Optional[Tensor] = None, source_mask: Optional[Tensor] = None): + """ + Args: + x (torch.Tensor): [B, S1, D] + source (torch.Tensor): [B, S2, D] + x_mask (torch.Tensor): [B, S1] (optional) + source_mask (torch.Tensor): [B, S2] (optional) + """ + B, S, D = x.shape + queries, keys, values = x, source, source + + queries = self.query_proj(queries).reshape(B, S, self.num_heads, self.dim_head) + keys = self.key_proj(keys).reshape(B, S, self.num_heads, self.dim_head) + values = self.value_proj(values).reshape(B, S, self.num_heads, self.dim_head) + + # attention operation + message = self.attention_op(queries, keys, values, x_mask, source_mask) + # concatenating attention heads together before passing throught projection layer + message = self.merge(message.reshape(B, S, D)) + message = self.attention_norm(message) + + # ffn operation + message = self.ffn(torch.cat([x, message], dim=2)) + message = self.attention_norm(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """ + LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf + Cannonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py + """ + + def __init__( + self, + *, + dim_model: int, + num_heads: int, + attention_directions: List[str], + attention_type: str = "linear", + ) -> None: + super(LocalFeatureTransformer, self).__init__() + + self.attention_directions = attention_directions + for direction in attention_directions: + if direction not in ["self", "cross"]: + raise ValueError( + f"Attention direction {direction} unsupported. LocalFeatureTransformer accepts only ``attention_type`` in ``[self, cross]``." + ) + + self.layers = nn.ModuleList( + [ + LocalFeatureEncoderLayer(dim_model=dim_model, num_heads=num_heads, attention_type=attention_type) + for _ in attention_directions + ] + ) + + def forward( + self, + left_features: Tensor, + right_features: Tensor, + left_mask: Optional[Tensor] = None, + right_mask: Optional[Tensor] = None, + ): + """ + Args: + left_features (torch.Tensor): [N, S1, D] + right_features (torch.Tensor): [N, S2, D] + left_mask (torch.Tensor): [N, S1] (optional) + right_mask (torch.Tensor): [N, S2] (optional) + Returns: + left_features (torch.Tensor): [N, S1, D] + right_features (torch.Tensor): [N, S2, D] + """ + + torch._assert( + left_features.shape[2] == right_features.shape[2], + f"left_features and right_features should have the same embedding dimensions. 
left_features: {left_features.shape[2]} right_features: {right_features.shape[2]}", + ) + + for idx, layer in enumerate(self.layers): + attention_direction = self.attention_directions[idx] + # for layer, attention_direction in zip(self.layers, self.attention_directions): + + if attention_direction == "self": + left_features = layer(left_features, left_features, left_mask, left_mask) + right_features = layer(right_features, right_features, right_mask, right_mask) + + elif attention_direction == "cross": + left_features = layer(left_features, right_features, left_mask, right_mask) + right_features = layer(right_features, left_features, right_mask, left_mask) + + return left_features, right_features + + +class PyramidDownsample(nn.Module): + """ + A simple wrapper that return and Avg Pool feature pyramid based on the provided scales. + Implicitly returns the input as well. + """ + + def __init__(self, factors: Iterable[int]) -> None: + super().__init__() + self.factors = factors + + def forward(self, x: torch.Tensor) -> List[Tensor]: + results = [x] + for factor in self.factors: + results.append(F.avg_pool2d(x, kernel_size=factor, stride=factor)) + return results + + +class CREStereo(nn.Module): + """ + CREStereo network from: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf + + Canonical implementation: https://github.com/megvii-research/CREStereo/blob/master/nets/crestereo.py + """ + + def __init__( + self, + *, + feature_encoder: FeatureEncoder, + update_block: raft.UpdateBlock, + flow_head: raft.FlowHead, + self_attn_block: LocalFeatureTransformer, + cross_attn_block: LocalFeatureTransformer, + feature_downsample_rates: Tuple[int, ...] = (2, 4), + correlation_groups: int = 4, + search_window_1d: Tuple[int, int] = (1, 9), + search_dilate_1d: Tuple[int, int] = (1, 1), + search_window_2d: Tuple[int, int] = (3, 3), + search_dilate_2d: Tuple[int, int] = (1, 1), + ) -> None: + super().__init__() + + self.feature_encoder = feature_encoder + self.update_block = update_block + self.flow_head = flow_head + self.self_attn_block = self_attn_block + + # average pooling for the feature encoder outputs + self.downsampling_pyramid = PyramidDownsample(feature_downsample_rates) + self.downsampling_factors: List[int] = [feature_encoder.downsample_factor] + base_downsample_factor: int = self.downsampling_factors[0] + for rate in feature_downsample_rates: + self.downsampling_factors.append(base_downsample_factor * rate) + + # output resolution tracking + self.resolutions: List[str] = [f"1 / {factor}" for factor in self.downsampling_factors] + self.search_pixels = int(np.prod(search_window_1d)) + + # flow convex upsampling mask predictor + self.mask_predictor = ConvexMaskPredictor( + in_channels=feature_encoder.output_dim // 2, + hidden_size=feature_encoder.output_dim, + upsample_factor=4, + multiplier=0.25, + ) + + # offsets modules for offseted feature selection + self.offset_convs = nn.ModuleDict() + self.correlation_layers = nn.ModuleDict() + + offset_conv_layer = partial( + Conv2dNormActivation, + in_channels=feature_encoder.output_dim, + out_channels=self.search_pixels * 2, + norm_layer=None, + activation_layer=None, + ) + + correlation_layer = partial( + AdaptiveGroupCorrelationLayer, + groups=correlation_groups, + search_window_1d=search_window_1d, + search_dilate_1d=search_dilate_1d, + search_window_2d=search_window_2d, + search_dilate_2d=search_dilate_2d, + ) + + # populate the dicts in top 
to bottom order + # useful for iterating through torch.jit.script module given the network forward pass + # + # Ignore the largest resolution. We handle that separately due to torch.jit.script + # not being to able access to runtime generated keys in ModuleDicts. + # This way, we can keep a generic way of processing all pyramid levels but except + # the final one + + for idx, resolution in enumerate(reversed(self.resolutions[1:])): + # the largest resolution does use offset convolutions for sampling grid coords + offset_conv = None if idx == len(self.resolutions) - 1 else offset_conv_layer() + if offset_conv: + self.offset_convs[resolution] = offset_conv + # only the lowest resolution uses the cross attention module when computing correlation scores + self.correlation_layers[resolution] = ( + correlation_layer(attention_module=cross_attn_block) if idx == 0 else correlation_layer() + ) + + # correlation layer for the largest resolution + self.max_res_correlation_layer = correlation_layer() + + # simple 2D Postional Encodings + self.positional_encodings = PositionalEncodingSine(feature_encoder.output_dim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def unfreeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.train() + + def forward(self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], iterations: int = 10): + features = torch.cat([left_image, right_image], dim=0) + features = self.feature_encoder(features) + left_features, right_features = features.chunk(2, dim=0) + + # update block network state and input context are derived from the left feature map + net, ctx = left_features.chunk(2, dim=1) + net = torch.tanh(net) + ctx = torch.relu(ctx) + + # will output lists of tensor. 
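+        # Illustrative note (default config assumed): each call below returns a list of
+        # [encoder_res, encoder_res / 2, encoder_res / 4] tensors for
+        # feature_downsample_rates=(2, 4); with the 1/4 encoder output these correspond to
+        # the 1/4, 1/8 and 1/16 input resolutions tracked in self.resolutions.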
+ l_pyramid = self.downsampling_pyramid(left_features) + r_pyramid = self.downsampling_pyramid(right_features) + net_pyramid = self.downsampling_pyramid(net) + ctx_pyramid = self.downsampling_pyramid(ctx) + + # we store in reversed order because we process the pyramid from top to bottom + l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)} + r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)} + net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)} + ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)} + + # offsets for sampling pixel candidates in the correlation ops + offsets: Dict[str, Tensor] = {} + for resolution, offset_conv in self.offset_convs.items(): + feature_map = l_pyramid[resolution] + offset = offset_conv(feature_map) + offsets[resolution] = (torch.sigmoid(offset) - 0.5) * 2.0 + + # the smallest resolution is prepared for passing through self attention + min_res = self.resolutions[-1] + max_res = self.resolutions[0] + + B, C, MIN_H, MIN_W = l_pyramid[min_res].shape + # add positional encodings + l_pyramid[min_res] = self.positional_encodings(l_pyramid[min_res]) + r_pyramid[min_res] = self.positional_encodings(r_pyramid[min_res]) + # reshaping for transformer + l_pyramid[min_res] = l_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C) + r_pyramid[min_res] = r_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C) + # perform self attention + l_pyramid[min_res], r_pyramid[min_res] = self.self_attn_block(l_pyramid[min_res], r_pyramid[min_res]) + # now we need to reshape back into [B, C, H, W] format + l_pyramid[min_res] = l_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2) + r_pyramid[min_res] = r_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2) + + predictions: List[Tensor] = [] + flow_estimates: Dict[str, Tensor] = {} + # we added this because of torch.script.jit + # also, the predicition prior is always going to have the + # spatial size of the features outputed by the feature encoder + flow_pred_prior: Tensor = torch.empty( + size=(B, 2, left_features.shape[2], left_features.shape[3]), + dtype=l_pyramid[max_res].dtype, + device=l_pyramid[max_res].device, + ) + + if flow_init is not None: + scale = l_pyramid[max_res].shape[2] // flow_init.shape[2] + # in CREStereo implementation they multiply with -scale instead of scale + # upsample_flow multiples with scale, therefor we add the - in front + flow_estimates[max_res] = -upsample_flow(flow_init, up_mask=None, factor=scale) + # when not provided with a flow prior, we construct one using the lower resolution maps + else: + # initialize a zero flow with the smallest resolution + flow = torch.zeros(size=(B, 2, MIN_H, MIN_W), device=left_features.device, dtype=left_features.dtype) + + # flows from coarse resolutions are refined similarly + # we always need to fetch the next pyramid feature map as well + # when updating coarse resolutions, therefore we create a reversed + # view which has its order synced with the ModuleDict keys iterator + coarse_resolutions: List[str] = self.resolutions[::-1] # using slicing because of torch.jit.script + fine_grained_resolution = max_res + + # set the coarsest flow to the zero flow + flow_estimates[coarse_resolutions[0]] = flow + + # the correlation layer ModuleDict will contain layers ordered from coarse to fine resolution + # i.e ["1 / 16", "1 / 8", "1 / 4"] + 
# the correlation layer ModuleDict has layers for all the resolutions except the fine one + # i.e {"1 / 16": Module, "1 / 8": Module} + # for these resolution we perform only half of the number of refinement iterations + for idx, (resolution, correlation_layer) in enumerate(self.correlation_layers.items()): + # compute the scale difference between the first pyramid scale and the current pyramid scale + scale_to_base = l_pyramid[fine_grained_resolution].shape[2] // l_pyramid[resolution].shape[2] + for it in range(iterations // 2): + # set wether or not we want to search on (X, Y) axes for correlation or just on X axis + use_small_search_patch = (it % 2) == 1 + # we consider this a prior, therefor we do not want to back-propagate through it + flow_estimates[resolution] = flow_estimates[resolution].detach() + + # corr_fn = self.get_module_from_module_dict(self.correlation_functions, resolution) + correlations = correlation_layer( + l_pyramid[resolution], # left + r_pyramid[resolution], # right + flow_estimates[resolution], + offsets[resolution], + use_small_search_patch, + ) + + # update the recurrent network state and the flow deltas + net_pyramid[resolution], delta_flow = self.update_block( + net_pyramid[resolution], ctx_pyramid[resolution], correlations, flow_estimates[resolution] + ) + + # the convex upsampling weights are computed w.r.t. + # the recurrent update state + up_mask = self.mask_predictor(net_pyramid[resolution]) + flow_estimates[resolution] = flow_estimates[resolution] + delta_flow + # convex upsampling with the initial feature encoder downsampling rate + flow_pred_prior = upsample_flow( + flow_estimates[resolution], up_mask, factor=self.downsampling_factors[0] + ) + # we then bilinear upsample to the final resolution + # we use a factor that's equivalent to the difference between + # the current downsample resolution and the base downsample resolution + # + # i.e. if a 1 / 16 flow is upsampled by 4 (base downsampling) we get a 1 / 4 flow. + # therefore we have to further upscale it by the difference between + # the current level 1 / 16 and the base level 1 / 4. + flow_pred = -upsample_flow(flow_pred_prior, None, factor=scale_to_base) + predictions.append(flow_pred) + + # when constructing the next resolution prior, we resample w.r.t + # to the scale of the next level in the pyramid + next_resolution = coarse_resolutions[idx + 1] + scale_to_next = l_pyramid[next_resolution].shape[2] / flow_pred_prior.shape[2] + # we use the flow_up_prior because this is a more accurate estimation of the true flow + # due to the convex upsample, which resembles a learned super-resolution module. 
+ # this is not necessarily an upsample, it can be a downsample, based on the provided configuration + flow_estimates[next_resolution] = -scale_to_next * F.interpolate( + input=flow_pred_prior, + size=l_pyramid[next_resolution].shape[2:], + mode="bilinear", + align_corners=True, + ) + + # finally we will be doing a full pass through the fine-grained resolution + # this coincides with the maximum resolution + + # we keep a separate loop here in order to avoid python control flow + # to decide how much iterations should we do based on the current resolution + # further more, if provided with an inital flow, there is no need to generate + # a prior estimate when moving into the final refinement stage + + for it in range(iterations): + use_small_search_patch = (it % 2) == 1 + + flow_estimates[max_res] = flow_estimates[max_res].detach() + # we run the fine-grained resolution correlations in iterative mode + # this means that we are using the fixed window pixel selections + # instead of the deformed ones as with the previous steps + correlations = self.max_res_correlation_layer( + l_pyramid[max_res], + r_pyramid[max_res], + flow_estimates[max_res], + extra_offset=None, + use_small_patch=use_small_search_patch, + iter_mode=True, + ) + + net_pyramid[max_res], delta_flow = self.update_block( + net_pyramid[max_res], ctx_pyramid[max_res], correlations, flow_estimates[max_res] + ) + + up_mask = self.mask_predictor(net_pyramid[max_res]) + flow_estimates[max_res] = flow_estimates[max_res] + delta_flow + # at the final resolution we simply do a convex upsample using the base downsample rate + flow_pred = -upsample_flow(flow_estimates[max_res], up_mask, factor=self.downsampling_factors[0]) + predictions.append(flow_pred) + + return predictions + + +def _crestereo( + *, + weights: Optional[WeightsEnum], + progress: bool, + # Feature Encoder + feature_encoder_layers: Tuple[int, int, int, int, int], + feature_encoder_strides: Tuple[int, int, int, int], + feature_encoder_block: Callable[..., nn.Module], + # Average Pooling Pyramid + feature_downsample_rates: Tuple[int, ...], + # Adaptive Correlation Layer + corr_groups: int, + corr_search_window_2d: Tuple[int, int], + corr_search_dilate_2d: Tuple[int, int], + corr_search_window_1d: Tuple[int, int], + corr_search_dilate_1d: Tuple[int, int], + # Flow head + flow_head_hidden_size: int, + # Recurrent block + recurrent_block_hidden_state_size: int, + recurrent_block_kernel_size: Tuple[Tuple[int, int], Tuple[int, int]], + recurrent_block_padding: Tuple[Tuple[int, int], Tuple[int, int]], + # Motion Encoder + motion_encoder_corr_layers: Tuple[int, int], + motion_encoder_flow_layers: Tuple[int, int], + motion_encoder_out_channels: int, + # Transformer Blocks + num_attention_heads: int, + num_self_attention_layers: int, + num_cross_attention_layers: int, + self_attention_type: str, + cross_attention_type: str, + **kwargs, +) -> CREStereo: + + feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder( + block=feature_encoder_block, + layers=feature_encoder_layers, + strides=feature_encoder_strides, + norm_layer=nn.InstanceNorm2d, + ) + + assert feature_encoder.output_dim % corr_groups == 0, ( + f"Final ``feature_encoder_layers`` size should be divisible by ``corr_groups`` argument." + f"Feature encoder output size : {feature_encoder.output_dim}, Correlation groups: {corr_groups}." 
+ ) + + motion_encoder = kwargs.pop("motion_encoder", None) or raft.MotionEncoder( + in_channels_corr=corr_groups * int(np.prod(corr_search_window_1d)), + corr_layers=motion_encoder_corr_layers, + flow_layers=motion_encoder_flow_layers, + out_channels=motion_encoder_out_channels, + ) + + out_channels_context = feature_encoder_layers[-1] - recurrent_block_hidden_state_size + recurrent_block = kwargs.pop("recurrent_block", None) or raft.RecurrentBlock( + input_size=motion_encoder.out_channels + out_channels_context, + hidden_size=recurrent_block_hidden_state_size, + kernel_size=recurrent_block_kernel_size, + padding=recurrent_block_padding, + ) + + flow_head = kwargs.pop("flow_head", None) or raft.FlowHead( + in_channels=out_channels_context, hidden_size=flow_head_hidden_size + ) + + update_block = raft.UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head) + + self_attn_block = LocalFeatureTransformer( + dim_model=feature_encoder.output_dim, + num_heads=num_attention_heads, + attention_directions=["self"] * num_self_attention_layers, + attention_type=self_attention_type, + ) + + cross_attn_block = LocalFeatureTransformer( + dim_model=feature_encoder.output_dim, + num_heads=num_attention_heads, + attention_directions=["cross"] * num_cross_attention_layers, + attention_type=cross_attention_type, + ) + + model = CREStereo( + feature_encoder=feature_encoder, + update_block=update_block, + flow_head=flow_head, + self_attn_block=self_attn_block, + cross_attn_block=cross_attn_block, + feature_downsample_rates=feature_downsample_rates, + correlation_groups=corr_groups, + search_window_1d=corr_search_window_1d, + search_window_2d=corr_search_window_2d, + search_dilate_1d=corr_search_dilate_1d, + search_dilate_2d=corr_search_dilate_2d, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +def crestereo_base(*, weights: Optional[WeightsEnum] = None, progress=True, **kwargs) -> CREStereo: + return _crestereo( + weights=weights, + progress=progress, + # Feature encoder + feature_encoder_layers=(64, 64, 96, 128, 256), + feature_encoder_strides=(2, 1, 2, 1), + feature_encoder_block=ResidualBlock, + # Average pooling pyramid + feature_downsample_rates=(2, 4), + # Motion encoder + motion_encoder_corr_layers=(256, 192), + motion_encoder_flow_layers=(128, 64), + motion_encoder_out_channels=256, + # Recurrent block + recurrent_block_hidden_state_size=128, + recurrent_block_kernel_size=((1, 5), (5, 1)), + recurrent_block_padding=((0, 2), (2, 0)), + # Flow head + flow_head_hidden_size=256, + # Transformer blocks + num_attention_heads=8, + num_self_attention_layers=1, + num_cross_attention_layers=1, + self_attention_type="linear", + cross_attention_type="linear", + # Adaptive Correlation layer + corr_groups=4, + corr_search_window_2d=(3, 3), + corr_search_dilate_2d=(1, 1), + corr_search_window_1d=(1, 9), + corr_search_dilate_1d=(1, 1), + )
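+
+
+# A minimal usage sketch (illustrative assumption, not part of the canonical reference
+# implementation): random inputs, no pretrained weights, and ``flow_init=None`` as
+# handled in ``CREStereo.forward``:
+#
+#     model = crestereo_base(weights=None)
+#     model.eval()
+#     left = torch.rand(1, 3, 256, 256)
+#     right = torch.rand(1, 3, 256, 256)
+#     with torch.no_grad():
+#         predictions = model(left, right, flow_init=None, iterations=10)
+#
+#     # ``predictions`` is a list of disparity/flow estimates accumulated over the
+#     # refinement iterations; the last element is typically used as the final output.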