In [None]:

# find the dataset definition by name, for example dtu_yao (dtu_yao.py)
def find_dataset_def(dataset_name):
    """Resolve the dataset class exported by ``datasets.<dataset_name>``.

    Args:
        dataset_name: name of a module inside the ``datasets`` package, e.g. ``"dtu_yao"``

    Returns:
        the ``MVSDataset`` attribute of the imported module
    """
    dataset_module = importlib.import_module(f"datasets.{dataset_name}")
    return dataset_module.MVSDataset


In [None]:
"""
Implementation of Pytorch layer primitives, such as Conv+BN+ReLU, differentiable warping layers,
and depth regression based upon expectation of an input probability distribution.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBnReLU(nn.Module):
    """2D convolution (no bias) followed by batch normalization and ReLU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        pad: int = 1,
        dilation: int = 1,
    ) -> None:
        """Create the convolution and batch-norm sub-layers.

        Args:
            in_channels: number of channels of the input feature map
            out_channels: number of channels produced by the convolution
            kernel_size: size of the convolution kernel
            stride: stride of the convolution
            pad: zero padding applied by the convolution
            dilation: dilation of the convolution kernel
        """
        super().__init__()
        # The bias is omitted: the batch norm that follows has its own shift term.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalization and ReLU to ``x``."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)


class ConvBnReLU3D(nn.Module):
    """3D convolution (no bias) followed by batch normalization and ReLU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        pad: int = 1,
        dilation: int = 1,
    ) -> None:
        """Create the 3D convolution and batch-norm sub-layers.

        Args:
            in_channels: number of channels of the input volume
            out_channels: number of channels produced by the convolution
            kernel_size: size of the convolution kernel
            stride: stride of the convolution
            pad: zero padding applied by the convolution
            dilation: dilation of the convolution kernel
        """
        super().__init__()
        # The bias is omitted: the batch norm that follows has its own shift term.
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalization and ReLU to ``x``."""
        activated = self.conv(x)
        activated = self.bn(activated)
        return F.relu(activated, inplace=True)


class ConvBnReLU1D(nn.Module):
    """1D convolution (no bias) followed by batch normalization and ReLU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        pad: int = 1,
        dilation: int = 1,
    ) -> None:
        """Create the 1D convolution and batch-norm sub-layers.

        Args:
            in_channels: number of channels of the input sequence
            out_channels: number of channels produced by the convolution
            kernel_size: size of the convolution kernel
            stride: stride of the convolution
            pad: zero padding applied by the convolution
            dilation: dilation of the convolution kernel
        """
        super().__init__()
        # The bias is omitted: the batch norm that follows has its own shift term.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalization and ReLU to ``x``."""
        convolved = self.conv(x)
        return F.relu(self.bn(convolved), inplace=True)


class ConvBn(nn.Module):
    """2D convolution (no bias) followed by batch normalization, without an activation."""

    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, pad: int = 1
    ) -> None:
        """Create the convolution and batch-norm sub-layers.

        Args:
            in_channels: number of channels of the input feature map
            out_channels: number of channels produced by the convolution
            kernel_size: size of the convolution kernel
            stride: stride of the convolution
            pad: zero padding applied by the convolution
        """
        super().__init__()
        # The bias is omitted: the batch norm that follows has its own shift term.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution and batch normalization to ``x``; no ReLU is applied."""
        convolved = self.conv(x)
        return self.bn(convolved)


def differentiable_warping(
    src_fea: torch.Tensor, src_proj: torch.Tensor, ref_proj: torch.Tensor, depth_samples: torch.Tensor
) -> torch.Tensor:
    """Differentiable homography-based warping, implemented in PyTorch.

    Every reference pixel is lifted to each of its depth hypotheses, projected into the
    source view, and the source features are bilinearly sampled at the projected location.
    The sampling grid is computed under ``torch.no_grad()``, so gradients flow only through
    ``src_fea``.

    Args:
        src_fea: [B, C, H, W] source features, for each source view in batch
        src_proj: [B, 4, 4] source camera projection matrix, for each source view in batch
        ref_proj: [B, 4, 4] reference camera projection matrix, for each ref view in batch
        depth_samples: [B, Ndepth, H, W] virtual depth layers
    Returns:
        warped_src_fea: [B, C, Ndepth, H, W] features on depths after perspective transformation
    """

    batch, channels, height, width = src_fea.shape
    num_depth = depth_samples.shape[1]

    with torch.no_grad():
        # relative transform from the reference camera to the source camera
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot = proj[:, :3, :3]  # [B,3,3]
        trans = proj[:, :3, 3:4]  # [B,3,1]

        # Row-major pixel coordinates built by broadcasting. This yields the same values as
        # torch.meshgrid with "ij" indexing, without relying on meshgrid's deprecated default.
        rows = torch.arange(0, height, dtype=torch.float32, device=src_fea.device)
        cols = torch.arange(0, width, dtype=torch.float32, device=src_fea.device)
        y = rows.view(height, 1).expand(height, width).reshape(height * width)
        x = cols.view(1, width).expand(height, width).reshape(height * width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W], homogeneous pixels
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W], batch dimension is broadcast

        # scale every rotated ray by each depth hypothesis (broadcast over the xyz axis)
        rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_samples.view(
            batch, 1, num_depth, height * width
        )  # [B, 3, Ndepth, H*W]
        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]

        # Points at or behind the source camera (z <= 1e-3) are redirected to pixel
        # (width, height), which lies outside the image, so grid_sample's zero padding
        # is returned for them instead of a mirrored or divergent projection.
        negative_depth_mask = proj_xyz[:, 2:3] <= 1e-3
        proj_xyz[:, 0:1].masked_fill_(negative_depth_mask, float(width))
        proj_xyz[:, 1:2].masked_fill_(negative_depth_mask, float(height))
        proj_xyz[:, 2:3].masked_fill_(negative_depth_mask, 1.0)

        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
        # map pixel coordinates to [-1, 1]; matches align_corners=True below
        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1  # [B, Ndepth, H*W]
        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
        grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]

    # depth hypotheses are stacked along the grid's height axis so a single 2D
    # grid_sample call handles all of them
    warped_src_fea = F.grid_sample(
        src_fea,
        grid.view(batch, num_depth * height, width, 2),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )

    return warped_src_fea.view(batch, channels, num_depth, height, width)


def depth_regression(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor:
    """Implements per-pixel depth regression based upon a probability distribution per-pixel.

    The regressed depth value D(p) at pixel p is found as the expectation w.r.t. P of the hypotheses.

    Args:
        p: probability volume [B, D, H, W]
        depth_values: discrete depth values, either [D] (shared by every batch element)
            or [B, D] (one set of hypotheses per batch element)
    Returns:
        result depth: expected value, soft argmin [B, 1, H, W]
    """
    if depth_values.dim() == 1:
        # one set of hypotheses broadcast over the batch
        hypotheses = depth_values.view(1, -1, 1, 1)
    else:
        # per-batch hypotheses; line the D axis up with p's depth axis
        hypotheses = depth_values.view(depth_values.shape[0], -1, 1, 1)

    return torch.sum(p * hypotheses, dim=1).unsqueeze(1)


def is_empty(x: torch.Tensor) -> bool:
    """Return True when ``x`` contains zero elements."""
    element_count = x.numel()
    return element_count == 0



In [None]:
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .module import ConvBnReLU, depth_regression
from .patchmatch import PatchMatch


class FeatureNet(nn.Module):
    """Feature extraction network producing multi-scale features for a single view.

    A strided convolutional encoder is followed by a top-down pathway with lateral
    1x1 convolutions; three feature maps of decreasing resolution are returned.
    """

    def __init__(self):
        """Create the encoder, the lateral connections and the output heads."""

        super().__init__()

        # full resolution, 8 channels
        self.conv0 = ConvBnReLU(in_channels=3, out_channels=8, kernel_size=3, stride=1, pad=1)
        self.conv1 = ConvBnReLU(in_channels=8, out_channels=8, kernel_size=3, stride=1, pad=1)
        # stride-2 step: 16 channels at 1/2 resolution
        self.conv2 = ConvBnReLU(in_channels=8, out_channels=16, kernel_size=5, stride=2, pad=2)
        self.conv3 = ConvBnReLU(in_channels=16, out_channels=16, kernel_size=3, stride=1, pad=1)
        self.conv4 = ConvBnReLU(in_channels=16, out_channels=16, kernel_size=3, stride=1, pad=1)
        # stride-2 step: 32 channels at 1/4 resolution
        self.conv5 = ConvBnReLU(in_channels=16, out_channels=32, kernel_size=5, stride=2, pad=2)
        self.conv6 = ConvBnReLU(in_channels=32, out_channels=32, kernel_size=3, stride=1, pad=1)
        self.conv7 = ConvBnReLU(in_channels=32, out_channels=32, kernel_size=3, stride=1, pad=1)
        # stride-2 step: 64 channels at 1/8 resolution
        self.conv8 = ConvBnReLU(in_channels=32, out_channels=64, kernel_size=5, stride=2, pad=2)
        self.conv9 = ConvBnReLU(in_channels=64, out_channels=64, kernel_size=3, stride=1, pad=1)
        self.conv10 = ConvBnReLU(in_channels=64, out_channels=64, kernel_size=3, stride=1, pad=1)

        # 1x1 heads (output*) and lateral projections (inner*) of the top-down pathway
        self.output1 = nn.Conv2d(64, 64, 1, bias=False)
        self.inner1 = nn.Conv2d(32, 64, 1, bias=True)
        self.inner2 = nn.Conv2d(16, 64, 1, bias=True)
        self.output2 = nn.Conv2d(64, 32, 1, bias=False)
        self.output3 = nn.Conv2d(64, 16, 1, bias=False)

    def forward(self, x: torch.Tensor) -> Dict[int, torch.Tensor]:
        """Extract features at three scales.

        Args:
            x: images from a single view, in the shape of [B, C, H, W]. Generally, C=3

        Returns:
            output_feature: dictionary keyed by stage number:
                3 -> 64 channels at 1/8 resolution,
                2 -> 32 channels at 1/4 resolution,
                1 -> 16 channels at 1/2 resolution
        """
        # bottom-up encoder
        full_res = self.conv1(self.conv0(x))
        half_res = self.conv4(self.conv3(self.conv2(full_res)))
        quarter_res = self.conv7(self.conv6(self.conv5(half_res)))
        eighth_res = self.conv10(self.conv9(self.conv8(quarter_res)))

        stage3 = self.output1(eighth_res)

        # top-down pathway: upsample by 2 and add the laterally projected encoder feature
        # NOTE(review): the additions assume the x2 upsampled size equals the encoder
        # size at that level (e.g. H and W divisible by 8) - confirm against the data loader
        top_down = F.interpolate(
            eighth_res, scale_factor=2.0, mode="bilinear", align_corners=False
        ) + self.inner1(quarter_res)
        # release encoder activations that are no longer needed
        del quarter_res
        del eighth_res

        stage2 = self.output2(top_down)
        top_down = F.interpolate(
            top_down, scale_factor=2.0, mode="bilinear", align_corners=False
        ) + self.inner2(half_res)
        del half_res

        stage1 = self.output3(top_down)
        del top_down

        output_feature: Dict[int, torch.Tensor] = {3: stage3, 2: stage2, 1: stage1}
        return output_feature


class Refinement(nn.Module):
    """Depth map refinement network.

    Upsamples a half-resolution depth map by a factor of two and adds a residual
    predicted from the depth map and the full-resolution reference image.
    """

    def __init__(self):
        """Create the image branch, the depth branch and the residual head."""

        super().__init__()

        # image branch, img: [B,3,H,W]
        self.conv0 = ConvBnReLU(in_channels=3, out_channels=8)
        # depth branch, depth map: [B,1,H/2,W/2]
        self.conv1 = ConvBnReLU(in_channels=1, out_channels=8)
        self.conv2 = ConvBnReLU(in_channels=8, out_channels=8)
        # stride-2 transposed convolution brings the depth features to [H, W]
        self.deconv = nn.ConvTranspose2d(
            in_channels=8,
            out_channels=8,
            kernel_size=3,
            padding=1,
            output_padding=1,
            stride=2,
            bias=False,
        )

        self.bn = nn.BatchNorm2d(8)
        # fuses the concatenated depth (8) and image (8) features
        self.conv3 = ConvBnReLU(in_channels=16, out_channels=8)
        self.res = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1, bias=False)

    def forward(
        self, img: torch.Tensor, depth_0: torch.Tensor, depth_min: torch.Tensor, depth_max: torch.Tensor
    ) -> torch.Tensor:
        """Refine and upsample the depth map.

        Args:
            img: input reference images (B, 3, H, W)
            depth_0: current depth map (B, 1, H//2, W//2)
            depth_min: pre-defined minimum depth (B, )
            depth_max: pre-defined maximum depth (B, )

        Returns:
            depth: refined depth map (B, 1, H, W)
        """
        num_batch = depth_min.size()[0]
        offset = depth_min.view(num_batch, 1, 1, 1)
        span = (depth_max - depth_min).view(num_batch, 1, 1, 1)

        # pre-scale the depth map by the [depth_min, depth_max] range
        normalized = (depth_0 - offset) / span

        image_feat = self.conv0(img)
        depth_feat = self.deconv(self.conv2(self.conv1(normalized)))
        depth_feat = F.relu(self.bn(depth_feat), inplace=True)

        # depth residual predicted from both branches
        residual = self.res(self.conv3(torch.cat((depth_feat, image_feat), dim=1)))
        del image_feat
        del depth_feat

        refined = F.interpolate(normalized, scale_factor=2.0, mode="nearest") + residual
        # convert the normalized depth back to the original range
        return refined * span + offset


class PatchmatchNet(nn.Module):
    """Complete PatchmatchNet: feature extraction, coarse-to-fine patchmatch, and depth refinement.

    Three PatchMatch modules (stages 3, 2 and 1, run in that order) are followed by a
    Refinement module that produces the final depth map.
    """

    def __init__(
        self,
        patchmatch_interval_scale: List[float] = [0.005, 0.0125, 0.025],
        propagation_range: List[int] = [6, 4, 2],
        patchmatch_iteration: List[int] = [1, 2, 2],
        patchmatch_num_sample: List[int] = [8, 8, 16],
        propagate_neighbors: List[int] = [0, 8, 16],
        evaluate_neighbors: List[int] = [9, 9, 9],
    ) -> None:
        """Initialize modules in PatchmatchNet

        Every list argument is indexed by ``stage - 1``, i.e. entry 0 configures stage 1
        and entry 2 configures stage 3.

        NOTE(review): the list defaults are shared objects. They are only read here, but
        ``patchmatch_num_sample``, ``propagate_neighbors`` and ``evaluate_neighbors`` are
        stored on the module by reference, so they must not be mutated afterwards.

        Args:
            patchmatch_interval_scale: depth interval scale in patchmatch module
            propagation_range: propagation range
            patchmatch_iteration: patchmatch iteration number
            patchmatch_num_sample: patchmatch number of samples
            propagate_neighbors: number of propagation neighbors
            evaluate_neighbors: number of propagation neighbors for evaluation
        """
        super(PatchmatchNet, self).__init__()

        # stage 0 is handled by the refinement network, stages 1-3 by patchmatch modules
        self.stages = 4
        self.feature = FeatureNet()
        self.patchmatch_num_sample = patchmatch_num_sample

        # feature channels passed to the patchmatch modules of stages 1, 2 and 3
        num_features = [16, 32, 64]

        self.propagate_neighbors = propagate_neighbors
        self.evaluate_neighbors = evaluate_neighbors
        # number of groups for group-wise correlation
        self.G = [4, 8, 8]

        # registers self.patchmatch_1, self.patchmatch_2 and self.patchmatch_3
        for i in range(self.stages - 1):
            patchmatch = PatchMatch(
                propagation_out_range=propagation_range[i],
                patchmatch_iteration=patchmatch_iteration[i],
                patchmatch_num_sample=patchmatch_num_sample[i],
                patchmatch_interval_scale=patchmatch_interval_scale[i],
                num_feature=num_features[i],
                G=self.G[i],
                propagate_neighbors=self.propagate_neighbors[i],
                evaluate_neighbors=evaluate_neighbors[i],
                stage=i + 1,
            )
            setattr(self, f"patchmatch_{i+1}", patchmatch)

        self.upsample_net = Refinement()

    def forward(
        self,
        images: Dict[str, torch.Tensor],
        proj_matrices: Dict[str, torch.Tensor],
        depth_min: torch.Tensor,
        depth_max: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, List[torch.Tensor]]]:
        """Forward method for PatchMatchNet

        Args:
            images: dictionary of image tensors; only ``"stage_0"`` is read. It is split
                along dimension 1 into per-view images, the first one being the reference
                view (presumably shaped (B, Nview, 3, H, W) - confirm against the data loader)
            proj_matrices: dictionary with keys ``"stage_0"`` to ``"stage_3"``; each tensor is
                split along dimension 1 into per-view projection matrices
                (presumably (B, Nview, 4, 4) - confirm against the data loader)
            depth_min: minimum virtual depth (B, )
            depth_max: maximum virtual depth (B, )

        Returns:
            tuple of (refined depth map, photometric confidence, depth_patchmatch).
            In training mode the photometric confidence is an empty tensor.
            ``depth_patchmatch`` maps the stage number (3, 2, 1) to the list of depth maps
            returned by that stage's patchmatch module.
        """
        # one image per view; index 0 is the reference view
        imgs_0 = torch.unbind(images["stage_0"], 1)
        del images

        ref_image = imgs_0[0]

        # per-stage tuples of per-view projection matrices
        proj_mtx = {
            0: torch.unbind(proj_matrices["stage_0"].float(), 1),
            1: torch.unbind(proj_matrices["stage_1"].float(), 1),
            2: torch.unbind(proj_matrices["stage_2"].float(), 1),
            3: torch.unbind(proj_matrices["stage_3"].float(), 1)
        }
        del proj_matrices

        assert len(imgs_0) == len(proj_mtx[0]), "Different number of images and projection matrices"

        # step 1. Multi-scale feature extraction
        features: List[Dict[int, torch.Tensor]] = []
        for img in imgs_0:
            output_feature = self.feature(img)
            features.append(output_feature)
        del imgs_0
        ref_feature, src_features = features[0], features[1:]

        depth_min = depth_min.float()
        depth_max = depth_max.float()

        # step 2. Learning-based patchmatch
        # empty tensors are passed as "not available yet" inputs to the first (coarsest) stage
        depth = torch.empty(0)
        depths: List[torch.Tensor] = []
        score = torch.empty(0)
        view_weights = torch.empty(0)
        depth_patchmatch: Dict[int, List[torch.Tensor]] = {}

        # coarse to fine: stage 3, then 2, then 1
        for stage in range(self.stages - 1, 0, -1):
            src_features_l = [src_fea[stage] for src_fea in src_features]
            ref_proj, src_projs = proj_mtx[stage][0], proj_mtx[stage][1:]
            # Need conditional since TorchScript only allows "getattr" access with string literals
            if stage == 3:
                depths, _, view_weights = self.patchmatch_3(
                    ref_feature=ref_feature[stage],
                    src_features=src_features_l,
                    ref_proj=ref_proj,
                    src_projs=src_projs,
                    depth_min=depth_min,
                    depth_max=depth_max,
                    depth=depth,
                    view_weights=view_weights,
                )
            elif stage == 2:
                depths, _, view_weights = self.patchmatch_2(
                    ref_feature=ref_feature[stage],
                    src_features=src_features_l,
                    ref_proj=ref_proj,
                    src_projs=src_projs,
                    depth_min=depth_min,
                    depth_max=depth_max,
                    depth=depth,
                    view_weights=view_weights,
                )
            elif stage == 1:
                # only the stage-1 score is kept; it feeds the photometric confidence below
                depths, score, _ = self.patchmatch_1(
                    ref_feature=ref_feature[stage],
                    src_features=src_features_l,
                    ref_proj=ref_proj,
                    src_projs=src_projs,
                    depth_min=depth_min,
                    depth_max=depth_max,
                    depth=depth,
                    view_weights=view_weights,
                )

            depth_patchmatch[stage] = depths
            # the last depth map of this stage, detached, seeds the next stage
            depth = depths[-1].detach()

            if stage > 1:
                # upsampling the depth map and pixel-wise view weight for next stage
                depth = F.interpolate(depth, scale_factor=2.0, mode="nearest")
                view_weights = F.interpolate(view_weights, scale_factor=2.0, mode="nearest")

        del ref_feature
        del src_features

        # step 3. Refinement
        depth = self.upsample_net(ref_image, depth, depth_min, depth_max)

        if self.training:
            # no photometric confidence is computed during training
            return depth, torch.empty(0), depth_patchmatch
        else:
            num_depth = self.patchmatch_num_sample[0]
            # sum of the probabilities in a window of 4 neighbouring depth hypotheses:
            # the depth axis is padded by 1 before and 2 after, then 4 * average over 4 entries
            score_sum4 = 4 * F.avg_pool3d(
                F.pad(score.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0
            ).squeeze(1)
            # [B, 1, H, W]
            # expected hypothesis index, truncated to an integer and clamped to a valid index
            depth_index = depth_regression(
                score, depth_values=torch.arange(num_depth, device=score.device, dtype=torch.float)
            ).long().clamp(0, num_depth - 1)
            # confidence = windowed probability mass at the expected index
            photometric_confidence = torch.gather(score_sum4, 1, depth_index)
            # upsample by 2 to match the refined depth map, then drop the channel axis
            photometric_confidence = F.interpolate(photometric_confidence, scale_factor=2.0, mode="nearest").squeeze(1)

            return depth, photometric_confidence, depth_patchmatch


def patchmatchnet_loss(
    depth_patchmatch: Dict[int, List[torch.Tensor]],
    depth_gt: Dict[str, torch.Tensor],
    mask: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sum of smooth-L1 losses over all predicted depth maps of stages 0 to 3.

    Args:
        depth_patchmatch: predicted depth maps, keyed by stage index 0..3; each entry is a
            list of depth maps for that stage.
            NOTE(review): key 0 is required here; the caller presumably stores the refined
            depth map under it - confirm against the training loop
        depth_gt: ground truth depth maps, keyed by "stage_0" .. "stage_3"
        mask: validity masks, keyed like ``depth_gt``; pixels with mask > 0.5 are used

    Returns:
        loss: accumulated loss value
    """
    total = 0
    for stage_idx in range(4):
        stage_key = f"stage_{stage_idx}"
        valid = mask[stage_key] > 0.5
        target = depth_gt[stage_key][valid]
        for estimate in depth_patchmatch[stage_idx]:
            # mean smooth-L1 over the valid pixels of this depth map
            total = total + F.smooth_l1_loss(estimate[valid], target, reduction="mean")

    return total


In [None]:
"""
PatchmatchNet uses the following main steps:

1. Initialization: generate random hypotheses;
2. Propagation: propagate hypotheses to neighbors;
3. Evaluation: compute the matching costs for all the hypotheses and choose best solutions.
"""
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .module import ConvBnReLU3D, differentiable_warping, is_empty


class DepthInitialization(nn.Module):
    """Initialization stage: generates the depth hypotheses evaluated by patchmatch."""

    def __init__(self, patchmatch_num_sample: int = 1) -> None:
        """Store the number of hypotheses generated around an existing depth map.

        Args:
            patchmatch_num_sample: number of samples used in patchmatch process
        """
        super().__init__()
        self.patchmatch_num_sample = patchmatch_num_sample

    def forward(
        self,
        min_depth: torch.Tensor,
        max_depth: torch.Tensor,
        height: int,
        width: int,
        depth_interval_scale: float,
        device: torch.device,
        depth: torch.Tensor = torch.empty(0),
    ) -> torch.Tensor:
        """Generate depth hypotheses, either randomly or around the current depth.

        Args:
            min_depth: minimum virtual depth, (B, )
            max_depth: maximum virtual depth, (B, )
            height: height of depth map
            width: width of depth map
            depth_interval_scale: depth interval scale
            device: device on which to place tensor
            depth: current depth (B, 1, H, W); an empty tensor means "no estimate yet"

        Returns:
            depth_sample: initialized sample depth map by randomization or local perturbation (B, Ndepth, H, W)
        """
        batch_size = min_depth.size()[0]
        inverse_min_depth = 1.0 / min_depth
        inverse_max_depth = 1.0 / max_depth

        if is_empty(depth):
            # No previous estimate: split the inverse depth range into a fixed number of
            # equal bins and draw one uniformly random sample inside each bin.
            num_bins = 48
            # [B,Ndepth,H,W]: random offset within a bin + index of the bin
            bin_position = torch.rand(
                size=(batch_size, num_bins, height, width), device=device
            ) + torch.arange(start=0, end=num_bins, step=1, device=device).view(1, num_bins, 1, 1)

            near = inverse_min_depth.view(batch_size, 1, 1, 1)
            far = inverse_max_depth.view(batch_size, 1, 1, 1)
            inverse_sample = far + bin_position / num_bins * (near - far)

            return 1.0 / inverse_sample

        if self.patchmatch_num_sample == 1:
            # a single hypothesis: reuse the current depth as it is
            return depth.detach()

        # Local perturbation: hypotheses are placed uniformly in inverse depth around
        # the current estimate.
        sample_offsets = (
            torch.arange(-self.patchmatch_num_sample // 2, self.patchmatch_num_sample // 2, 1, device=device)
            .view(1, self.patchmatch_num_sample, 1, 1)
            .repeat(batch_size, 1, height, width)
            .float()
        )
        inverse_depth_interval = ((inverse_min_depth - inverse_max_depth) * depth_interval_scale).view(
            batch_size, 1, 1, 1
        )

        inverse_sample = 1.0 / depth.detach() + inverse_depth_interval * sample_offsets
        del depth

        # clamp every batch element to its own inverse depth range
        clamped = []
        for k in range(batch_size):
            clamped_k = torch.clamp(inverse_sample[k], min=inverse_max_depth[k], max=inverse_min_depth[k])
            clamped.append(clamped_k.unsqueeze(0))

        return 1.0 / torch.cat(clamped, dim=0)


class Propagation(nn.Module):
    """Propagation step: gathers depth hypotheses from neighbouring pixels."""

    def __init__(self) -> None:
        """Initialize method; the module has no parameters."""
        super().__init__()

    def forward(self, depth_sample: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        """Forward method of adaptive propagation

        Args:
            depth_sample: sample depth map, in shape of [batch, num_depth, height, width],
            grid: 2D grid for bilinear gridding, in shape of [batch, neighbors*H, W, 2]

        Returns:
            propagate depth: sorted propagate depth map [batch, num_depth+num_neighbors, height, width]
        """
        batch, num_depth, height, width = depth_sample.size()
        # the neighbours are stacked along the height axis of the grid
        num_neighbors = grid.size()[1] // height

        # only the middle hypothesis of each pixel is propagated to its neighbours
        center_hypothesis = depth_sample[:, num_depth // 2, :, :].unsqueeze(1)
        neighbor_depths = F.grid_sample(
            center_hypothesis,
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=False,
        )
        neighbor_depths = neighbor_depths.view(batch, num_neighbors, height, width)

        # merge own and propagated hypotheses and sort them along the depth axis
        candidates = torch.cat((depth_sample, neighbor_depths), dim=1)
        sorted_candidates, _ = torch.sort(candidates, dim=1)
        return sorted_candidates


class Evaluation(nn.Module):
    """Evaluation module for adaptive evaluation step in Learning-based Patchmatch
    Used to compute the matching costs for all the hypotheses and choose best solutions.
    """

    def __init__(self, G: int = 8) -> None:
        """Initialize method

        Args:
            G: the feature channels of input will be divided evenly into G groups
        """
        super(Evaluation, self).__init__()

        self.G = G
        self.pixel_wise_net = PixelwiseNet(self.G)
        # log-softmax over the depth axis; forward() exponentiates it to get probabilities
        self.softmax = nn.LogSoftmax(dim=1)
        self.similarity_net = SimilarityNet(self.G)

    def forward(
        self,
        ref_feature: torch.Tensor,
        src_features: List[torch.Tensor],
        ref_proj: torch.Tensor,
        src_projs: List[torch.Tensor],
        depth_sample: torch.Tensor,
        grid: torch.Tensor,
        weight: torch.Tensor,
        view_weights: torch.Tensor = torch.empty(0),
        is_inverse: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward method for adaptive evaluation

        Args:
            ref_feature: feature from reference view, (B, C, H, W)
            src_features: features from (Nview-1) source views, (Nview-1) * (B, C, H, W), where Nview is the number of
                input images (or views) of PatchmatchNet
            ref_proj: projection matrix of reference view, (B, 4, 4)
            src_projs: source matrices of source views, (Nview-1) * (B, 4, 4), where Nview is the number of input
                images (or views) of PatchmatchNet
            depth_sample: sample depth map, (B,Ndepth,H,W)
            grid: grid, (B, evaluate_neighbors*H, W, 2)
            weight: weight, (B,Ndepth,1,H,W)
            view_weights: Tensor to store weights of source views, in shape of (B,Nview-1,H,W),
                Nview-1 represents the number of source views. An empty tensor means the weights
                are not available yet and are computed here by the pixel-wise network.
            is_inverse: Flag for inverse depth regression

        Returns:
            depth_sample: expectation of depth sample, (B,H,W)
            score: probability map, (B,Ndepth,H,W)
            view_weights: detached weights of source views, in shape of (B,Nview-1,H,W),
                Nview-1 represents the number of source views
        """
        batch, feature_channel, height, width = ref_feature.size()
        device = ref_feature.device

        num_depth = depth_sample.size()[1]
        assert (
            len(src_features) == len(src_projs)
        ), "Patchmatch Evaluation: Different number of images and projection matrices"
        if not is_empty(view_weights):
            assert (
                len(src_features) == view_weights.size()[1]
            ), "Patchmatch Evaluation: Different number of images and view weights"

        # start the weight sum at 1e-5 so the division below never divides by zero
        pixel_wise_weight_sum = 1e-5 * torch.ones((batch, 1, 1, height, width), dtype=torch.float32, device=device)
        # split the channels into G groups; the singleton axis broadcasts over the depth hypotheses
        ref_feature = ref_feature.view(batch, self.G, feature_channel // self.G, 1, height, width)
        similarity_sum = torch.zeros((batch, self.G, num_depth, height, width), dtype=torch.float32, device=device)

        # index of the current source view inside view_weights (used only when weights are reused)
        i = 0
        view_weights_list = []
        for src_feature, src_proj in zip(src_features, src_projs):
            warped_feature = differentiable_warping(
                src_feature, src_proj, ref_proj, depth_sample
            ).view(batch, self.G, feature_channel // self.G, num_depth, height, width)
            # group-wise correlation: mean over the channels of each group -> [B, G, Ndepth, H, W]
            similarity = (warped_feature * ref_feature).mean(2)
            # pixel-wise view weight
            if is_empty(view_weights):
                # presumably [B,1,H,W] (PixelwiseNet is defined outside this view - confirm)
                view_weight = self.pixel_wise_net(similarity)
                view_weights_list.append(view_weight)
            else:
                # reuse the pixel-wise view weights passed in by the caller
                view_weight = view_weights[:, i].unsqueeze(1)  # [B,1,H,W]
                i = i + 1

            # weighted accumulation, done in place
            similarity_sum += similarity * view_weight.unsqueeze(1)
            pixel_wise_weight_sum += view_weight.unsqueeze(1)

        # aggregated matching cost across all the source views (in-place weighted average)
        similarity = similarity_sum.div_(pixel_wise_weight_sum)  # [B, G, Ndepth, H, W]
        # adaptive spatial cost aggregation
        # presumably [B, Ndepth, H, W], since it is multiplied with depth_sample below
        # (SimilarityNet is defined outside this view - confirm)
        score = self.similarity_net(similarity, grid, weight)
        # apply softmax to get probability (exp of the log-softmax over the depth axis)
        score = torch.exp(self.softmax(score))

        if is_empty(view_weights):
            # [B,Nview-1,H,W]: one weight map per source view
            view_weights = torch.cat(view_weights_list, dim=1)

        if is_inverse:
            # depth regression: inverse depth regression
            # expected (fractional) hypothesis index per pixel
            depth_index = torch.arange(0, num_depth, 1, device=device).view(1, num_depth, 1, 1)
            depth_index = torch.sum(depth_index * score, dim=1)

            # the last hypothesis is used as the minimum depth and the first as the maximum depth
            inverse_min_depth = 1.0 / depth_sample[:, -1, :, :]
            inverse_max_depth = 1.0 / depth_sample[:, 0, :, :]
            # linear interpolation in inverse depth at the expected index
            depth_sample = inverse_max_depth + depth_index / (num_depth - 1) * (inverse_min_depth - inverse_max_depth)
            depth_sample = 1.0 / depth_sample
        else:
            # depth regression: expectation
            depth_sample = torch.sum(depth_sample * score, dim=1)

        return depth_sample, score, view_weights.detach()


class PatchMatch(nn.Module):
    """Patchmatch module: iterative depth refinement for a single stage.

    Every iteration in ``forward`` runs the sub-modules created in ``__init__`` in a fixed order:
    depth hypothesis initialization / local perturbation (``DepthInitialization``), optional adaptive
    propagation (``Propagation``) and adaptive evaluation (``Evaluation``). The additional 2D sampling
    offsets used by propagation and by the spatial cost aggregation of evaluation are predicted from the
    reference feature map by two zero-initialized convolutions (``propa_conv`` and ``eval_conv``).
    """

    def __init__(
        self,
        propagation_out_range: int = 2,
        patchmatch_iteration: int = 2,
        patchmatch_num_sample: int = 16,
        patchmatch_interval_scale: float = 0.025,
        num_feature: int = 64,
        G: int = 8,
        propagate_neighbors: int = 16,
        evaluate_neighbors: int = 9,
        stage: int = 3,
    ) -> None:
        """Initialize method

        Args:
            propagation_out_range: base pixel distance of the propagation neighbors; stored as ``self.dilation``
                and also used as padding/dilation of both offset convolutions (evaluation grids use
                ``self.dilation - 1``),
            patchmatch_iteration: number of iterations in patchmatch,
            patchmatch_num_sample: number of samples in patchmatch, forwarded to ``DepthInitialization``,
            patchmatch_interval_scale: interval scale,
            num_feature: number of channels of the reference feature map fed to the offset convolutions,
            G: the feature channels of input will be divided evenly into G groups,
            propagate_neighbors: number of neighbors to be sampled in propagation; ``get_grid`` supports 4, 8
                and 16, and ``forward`` skips propagation when the value is not positive,
            evaluate_neighbors: number of neighbors to be sampled in evaluation; ``get_grid`` supports 9 and 17,
            stage: index of the stage this module is used for; stage 1 gets special treatment in ``forward``,
        """
        super(PatchMatch, self).__init__()
        self.patchmatch_iteration = patchmatch_iteration
        self.patchmatch_interval_scale = patchmatch_interval_scale
        self.propa_num_feature = num_feature
        # group wise correlation
        self.G = G
        self.stage = stage
        self.dilation = propagation_out_range
        self.propagate_neighbors = propagate_neighbors
        self.evaluate_neighbors = evaluate_neighbors
        # Using dictionary instead of Enum since TorchScript cannot recognize and export it correctly
        self.grid_type = {"propagation": 1, "evaluation": 2}

        self.depth_initialization = DepthInitialization(patchmatch_num_sample)
        self.propagation = Propagation()
        self.evaluation = Evaluation(self.G)
        # adaptive propagation: last iteration on stage 1 does not have propagation,
        # but we still define this for TorchScript export compatibility
        # (max(..., 1) keeps the convolution valid even when propagate_neighbors is 0)
        self.propa_conv = nn.Conv2d(
            in_channels=self.propa_num_feature,
            out_channels=max(2 * self.propagate_neighbors, 1),
            kernel_size=3,
            stride=1,
            padding=self.dilation,
            dilation=self.dilation,
            bias=True,
        )
        # zero weights and bias: the predicted offsets start at 0, so sampling initially uses
        # only the fixed neighbor pattern built in get_grid
        nn.init.constant_(self.propa_conv.weight, 0.0)
        nn.init.constant_(self.propa_conv.bias, 0.0)

        # adaptive spatial cost aggregation (adaptive evaluation)
        self.eval_conv = nn.Conv2d(
            in_channels=self.propa_num_feature,
            out_channels=2 * self.evaluate_neighbors,
            kernel_size=3,
            stride=1,
            padding=self.dilation,
            dilation=self.dilation,
            bias=True,
        )
        # same zero initialization as propa_conv
        nn.init.constant_(self.eval_conv.weight, 0.0)
        nn.init.constant_(self.eval_conv.bias, 0.0)
        self.feature_weight_net = FeatureWeightNet(self.evaluate_neighbors, self.G)

    def get_grid(
        self, grid_type: int, batch: int, height: int, width: int, offset: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """Compute the sampling grid for adaptive propagation or for spatial cost aggregation in adaptive evaluation

        The grid is the pixel coordinate of every location plus a fixed per-neighbor offset plus the learned
        per-pixel ``offset``, normalized so that pixel index 0 maps to -1 and index (size - 1) maps to +1.

        Args:
            grid_type: type of grid - propagation (1) or evaluation (2), see ``self.grid_type``
            batch: batch size
            height: grid height
            width: grid width
            offset: learned grid offset, indexed as [B, 2*N, H*W] with channel 2*i holding the x offset and
                channel 2*i+1 the y offset of neighbor i
            device: device on which to place tensor

        Returns:
            generated grid: in the shape of [batch, N*H, W, 2], where N is the number of sampled neighbors
                (``propagate_neighbors`` or ``evaluate_neighbors``) and the last axis is (x, y)

        Raises:
            NotImplementedError: for an unknown ``grid_type`` or an unsupported number of neighbors
        """

        if grid_type == self.grid_type["propagation"]:
            if self.propagate_neighbors == 4:  # if 4 neighbors to be sampled in propagation
                original_offset = [[-self.dilation, 0], [0, -self.dilation], [0, self.dilation], [self.dilation, 0]]
            elif self.propagate_neighbors == 8:  # if 8 neighbors to be sampled in propagation
                original_offset = [
                    [-self.dilation, -self.dilation],
                    [-self.dilation, 0],
                    [-self.dilation, self.dilation],
                    [0, -self.dilation],
                    [0, self.dilation],
                    [self.dilation, -self.dilation],
                    [self.dilation, 0],
                    [self.dilation, self.dilation],
                ]
            elif self.propagate_neighbors == 16:  # if 16 neighbors to be sampled in propagation
                original_offset = [
                    [-self.dilation, -self.dilation],
                    [-self.dilation, 0],
                    [-self.dilation, self.dilation],
                    [0, -self.dilation],
                    [0, self.dilation],
                    [self.dilation, -self.dilation],
                    [self.dilation, 0],
                    [self.dilation, self.dilation],
                ]
                # append the same 8 directions at twice the distance; range() is evaluated once,
                # so only the 8 initial entries are visited while the list grows to 16
                for i in range(len(original_offset)):
                    offset_x, offset_y = original_offset[i]
                    original_offset.append([2 * offset_x, 2 * offset_y])
            else:
                raise NotImplementedError
        elif grid_type == self.grid_type["evaluation"]:
            dilation = self.dilation - 1  # dilation of evaluation is a little smaller than propagation
            if self.evaluate_neighbors == 9:  # if 9 neighbors to be sampled in evaluation
                original_offset = [
                    [-dilation, -dilation],
                    [-dilation, 0],
                    [-dilation, dilation],
                    [0, -dilation],
                    [0, 0],
                    [0, dilation],
                    [dilation, -dilation],
                    [dilation, 0],
                    [dilation, dilation],
                ]
            elif self.evaluate_neighbors == 17:  # if 17 neighbors to be sampled in evaluation
                original_offset = [
                    [-dilation, -dilation],
                    [-dilation, 0],
                    [-dilation, dilation],
                    [0, -dilation],
                    [0, 0],
                    [0, dilation],
                    [dilation, -dilation],
                    [dilation, 0],
                    [dilation, dilation],
                ]
                # append the 8 non-center directions at twice the distance (9 + 8 = 17 entries)
                for i in range(len(original_offset)):
                    offset_x, offset_y = original_offset[i]
                    if offset_x != 0 or offset_y != 0:
                        original_offset.append([2 * offset_x, 2 * offset_y])
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        # base pixel coordinates; built without gradient tracking
        with torch.no_grad():
            y_grid, x_grid = torch.meshgrid(
                [
                    torch.arange(0, height, dtype=torch.float32, device=device),
                    torch.arange(0, width, dtype=torch.float32, device=device),
                ]
            )
            y_grid, x_grid = y_grid.contiguous().view(height * width), x_grid.contiguous().view(height * width)
            xy = torch.stack((x_grid, y_grid))  # [2, H*W]
            xy = torch.unsqueeze(xy, 0).repeat(batch, 1, 1)  # [B, 2, H*W]

        xy_list = []
        for i in range(len(original_offset)):
            # each fixed entry is read as (y, x) here
            original_offset_y, original_offset_x = original_offset[i]
            offset_x = original_offset_x + offset[:, 2 * i, :].unsqueeze(1)
            offset_y = original_offset_y + offset[:, 2 * i + 1, :].unsqueeze(1)
            xy_list.append((xy + torch.cat((offset_x, offset_y), dim=1)).unsqueeze(2))

        xy = torch.cat(xy_list, dim=2)  # [B, 2, N, H*W], N = len(original_offset)

        del xy_list
        del x_grid
        del y_grid

        # NOTE(review): this maps index 0 -> -1 and index (size - 1) -> +1, while the grid_sample calls
        # visible in this file pass align_corners=False — confirm the intended convention
        x_normalized = xy[:, 0, :, :] / ((width - 1) / 2) - 1
        y_normalized = xy[:, 1, :, :] / ((height - 1) / 2) - 1
        del xy
        grid = torch.stack((x_normalized, y_normalized), dim=3)  # [B, N, H*W, 2]
        del x_normalized
        del y_normalized
        return grid.view(batch, len(original_offset) * height, width, 2)

    def forward(
        self,
        ref_feature: torch.Tensor,
        src_features: List[torch.Tensor],
        ref_proj: torch.Tensor,
        src_projs: List[torch.Tensor],
        depth_min: torch.Tensor,
        depth_max: torch.Tensor,
        depth: torch.Tensor,
        view_weights: torch.Tensor = torch.empty(0),
    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Forward method for PatchMatch

        Args:
            ref_feature: feature from reference view, (B, C, H, W)
            src_features: features from (Nview-1) source views, (Nview-1) * (B, C, H, W), where Nview is the number of
                input images (or views) of PatchmatchNet
            ref_proj: projection matrix of reference view, (B, 4, 4)
            src_projs: source matrices of source views, (Nview-1) * (B, 4, 4), where Nview is the number of input
                images (or views) of PatchmatchNet
            depth_min: minimum virtual depth, (B,)
            depth_max: maximum virtual depth, (B,)
            depth: current depth map, (B,1,H,W); presumably empty on the very first call (random
                initialization) — TODO confirm against DepthInitialization and the caller
            view_weights: Tensor to store weights of source views, in shape of (B,Nview-1,H,W),
                Nview-1 represents the number of source views; the default is an empty tensor

        Returns:
            depth_samples: list of depth maps from each patchmatch iteration, Niter * (B,1,H,W)
            score: evaluated probabilities of the last iteration, (B,Ndepth,H,W)
            view_weights: Tensor to store weights of source views, in shape of (B,Nview-1,H,W),
                Nview-1 represents the number of source views
        """
        score = torch.empty(0)
        depth_samples = []

        device = ref_feature.device
        batch, _, height, width = ref_feature.size()

        # the learned additional 2D offsets for adaptive propagation
        propa_grid = torch.empty(0)
        if self.propagate_neighbors > 0 and not (self.stage == 1 and self.patchmatch_iteration == 1):
            # last iteration on stage 1 does not have propagation (photometric consistency filtering)
            propa_offset = self.propa_conv(ref_feature).view(batch, 2 * self.propagate_neighbors, height * width)
            propa_grid = self.get_grid(self.grid_type["propagation"], batch, height, width, propa_offset, device)

        # the learned additional 2D offsets for adaptive spatial cost aggregation (adaptive evaluation)
        eval_offset = self.eval_conv(ref_feature).view(batch, 2 * self.evaluate_neighbors, height * width)
        eval_grid = self.get_grid(self.grid_type["evaluation"], batch, height, width, eval_offset, device)

        # [B, evaluate_neighbors, H, W]; computed once and reused by every iteration below
        feature_weight = self.feature_weight_net(ref_feature.detach(), eval_grid)
        depth_sample = depth
        # drop the local alias; depth_sample carries the running estimate from here on
        del depth

        # NOTE(review): the loop variable `iter` shadows the Python builtin within this method
        for iter in range(1, self.patchmatch_iteration + 1):
            # inverse depth regression is only used in the last iteration of stage 1
            is_inverse = self.stage == 1 and iter == self.patchmatch_iteration

            # first iteration on stage 3, random initialization (depth is empty), no adaptive propagation
            # subsequent iterations, local perturbation based on previous result, [B,Ndepth,H,W]
            depth_sample = self.depth_initialization(
                min_depth=depth_min,
                max_depth=depth_max,
                height=height,
                width=width,
                depth_interval_scale=self.patchmatch_interval_scale,
                device=device,
                depth=depth_sample
            )

            # adaptive propagation
            if self.propagate_neighbors > 0 and not (self.stage == 1 and iter == self.patchmatch_iteration):
                # last iteration on stage 1 does not have propagation (photometric consistency filtering)
                depth_sample = self.propagation(depth_sample=depth_sample, grid=propa_grid)

            # weights for adaptive spatial cost aggregation in adaptive evaluation, [B,Ndepth,N_neighbors_eval,H,W]
            weight = depth_weight(
                depth_sample=depth_sample.detach(),
                depth_min=depth_min,
                depth_max=depth_max,
                grid=eval_grid.detach(),
                patchmatch_interval_scale=self.patchmatch_interval_scale,
                neighbors=self.evaluate_neighbors,
            ) * feature_weight.unsqueeze(1)
            # normalize over the neighbor axis; the divisor has shape [B,Ndepth,1,H,W]
            weight = weight / torch.sum(weight, dim=2).unsqueeze(2)

            # evaluation, outputs regressed depth map and pixel-wise view weights which will
            # be used for subsequent iterations
            depth_sample, score, view_weights = self.evaluation(
                ref_feature=ref_feature,
                src_features=src_features,
                ref_proj=ref_proj,
                src_projs=src_projs,
                depth_sample=depth_sample,
                grid=eval_grid,
                weight=weight,
                view_weights=view_weights,
                is_inverse=is_inverse,
            )

            # restore the channel axis so the map can seed the next iteration
            depth_sample = depth_sample.unsqueeze(1)
            depth_samples.append(depth_sample)

        return depth_samples, score, view_weights


class SimilarityNet(nn.Module):
    """Similarity network of the adaptive evaluation step.

    The aggregated multi-view cost [B, G, Ndepth, H, W] is reduced to a single channel by three
    1x1x1 convolutions, then the per-pixel cost is re-sampled at the neighbor positions given by the
    grid and blended with the supplied weights (adaptive spatial cost aggregation).
    """

    def __init__(self, G: int) -> None:
        """Build the 1x1x1 convolution stack.

        Args:
            G: number of groups, i.e. channel count of the incoming cost volume
        """
        super(SimilarityNet, self).__init__()

        self.conv0 = ConvBnReLU3D(G, 16, kernel_size=1, stride=1, pad=0)
        self.conv1 = ConvBnReLU3D(16, 8, kernel_size=1, stride=1, pad=0)
        self.similarity = nn.Conv3d(8, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x1: torch.Tensor, grid: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """Aggregate the matching cost over the sampled neighbors.

        Args:
            x1: aggregated cost among all the source views, [B, G, Ndepth, H, W]
            grid: sampling positions for the spatial aggregation, (B, num_neighbors*H, W, 2)
            weight: blending weights of the sampling points; multiplied with the sampled cost of
                shape [B, Ndepth, num_neighbors, H, W], so it must broadcast against that shape

        Returns:
            final cost in the shape of [B, Ndepth, H, W]
        """
        batch, _, num_depth, height, width = x1.size()
        # the grid stacks the neighbors along its second axis
        num_neighbors = grid.size(1) // height

        # [B, G, Ndepth, H, W] -> [B, Ndepth, H, W]
        cost = self.conv0(x1)
        cost = self.conv1(cost)
        cost = self.similarity(cost).squeeze(1)

        # sample the cost of every neighbor, [B, Ndepth, num_neighbors, H, W]
        sampled = F.grid_sample(cost, grid, mode="bilinear", padding_mode="border", align_corners=False)
        sampled = sampled.view(batch, num_depth, num_neighbors, height, width)

        return (sampled * weight).sum(dim=2)


class FeatureWeightNet(nn.Module):
    """Feature weight network.

    Computes, for every pixel, a weight per sampled neighbor from the group-wise correlation between
    the reference feature at the neighbor position and the reference feature at the pixel itself.
    The result is used as the feature part of the weights in adaptive spatial cost aggregation.
    """

    def __init__(self, neighbors: int = 9, G: int = 8) -> None:
        """Build the 1x1x1 convolution stack and the sigmoid output.

        Args:
            neighbors: number of neighbors to be sampled
            G: the feature channels of input will be divided evenly into G groups
        """
        super(FeatureWeightNet, self).__init__()
        self.neighbors = neighbors
        self.G = G

        self.conv0 = ConvBnReLU3D(G, 16, kernel_size=1, stride=1, pad=0)
        self.conv1 = ConvBnReLU3D(16, 8, kernel_size=1, stride=1, pad=0)
        self.similarity = nn.Conv3d(8, 1, kernel_size=1, stride=1, padding=0)

        self.output = nn.Sigmoid()

    def forward(self, ref_feature: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        """Compute the feature weights of the sampling points.

        Args:
            ref_feature: reference feature map, [B,C,H,W]
            grid: position of sampling points in adaptive spatial cost aggregation, (B, neighbors*H, W, 2)

        Returns:
            weight based on similarity of features of sampling points and center pixel, [B,Neighbor,H,W]
        """
        batch, feature_channel, height, width = ref_feature.size()
        group_channels = feature_channel // self.G

        # features at the sampled neighbor positions, [B,G,C//G,Neighbor,H,W]
        neighbor_feature = F.grid_sample(
            ref_feature, grid, mode="bilinear", padding_mode="border", align_corners=False
        )
        neighbor_feature = neighbor_feature.view(batch, self.G, group_channels, self.neighbors, height, width)

        # features at the center pixel with a singleton neighbor axis, [B,G,C//G,1,H,W]
        center_feature = ref_feature.view(batch, self.G, group_channels, height, width).unsqueeze(3)

        # group-wise correlation, [B,G,Neighbor,H,W]
        correlation = (neighbor_feature * center_feature).mean(2)

        # [B,Neighbor,H,W]
        logits = self.similarity(self.conv1(self.conv0(correlation))).squeeze(1)
        return self.output(logits)


def depth_weight(
    depth_sample: torch.Tensor,
    depth_min: torch.Tensor,
    depth_max: torch.Tensor,
    grid: torch.Tensor,
    patchmatch_interval_scale: float,
    neighbors: int,
) -> torch.Tensor:
    """Compute the depth part of the weights for adaptive spatial cost aggregation.

    The depth hypotheses are converted to inverse depth normalized by the inverse depth range, sampled
    at the neighbor positions, and compared with the value at the center pixel: the closer a neighbor is
    to the center in normalized inverse depth, the larger its weight.

    Args:
        depth_sample: sample depth map, (B,Ndepth,H,W)
        depth_min: minimum virtual depth, (B,)
        depth_max: maximum virtual depth, (B,)
        grid: position of sampling points in adaptive spatial cost aggregation, (B, neighbors*H, W, 2)
        patchmatch_interval_scale: patchmatch interval scale,
        neighbors: number of neighbors to be sampled in evaluation

    Returns:
        depth weight, (B,Ndepth,neighbors,H,W), detached from the autograd graph
    """
    batch, num_depth, height, width = depth_sample.size()
    inverse_near = 1.0 / depth_min
    inverse_far = 1.0 / depth_max

    # normalized inverse depth of every hypothesis
    normalized = 1.0 / depth_sample
    normalized = (normalized - inverse_far.view(batch, 1, 1, 1)) / (inverse_near - inverse_far).view(batch, 1, 1, 1)

    # normalized inverse depth at the sampled neighbor positions, [B,Ndepth,N_neighbors,H,W]
    sampled = F.grid_sample(normalized, grid, mode="bilinear", padding_mode="border", align_corners=False)
    sampled = sampled.view(batch, num_depth, neighbors, height, width)

    # absolute difference to the center pixel, measured in units of the interval scale
    difference = torch.abs(sampled - normalized.unsqueeze(2)) / patchmatch_interval_scale

    # the difference is clamped to [0, 4], so the sigmoid argument stays within [-4, 4]
    clamped = difference.clamp(min=0, max=4)
    return torch.sigmoid(4.0 - 2.0 * clamped).detach()


class PixelwiseNet(nn.Module):
    """Pixel-wise view weight network.

    A stack of 1x1x1 convolutions followed by a sigmoid turns the group-wise similarity of one source
    view into a value between 0 and 1 per depth hypothesis and pixel; the maximum over the depth axis
    is returned as the pixel-wise weight of that view.

    1. The Pixelwise Net is used in adaptive evaluation step
    2. The similarity is calculated by ref_feature and other source_features warped by differentiable_warping
    3. The learned pixel-wise view weight is estimated in the first iteration of Patchmatch and kept fixed in the
    matching cost computation.
    """

    def __init__(self, G: int) -> None:
        """Build the 1x1x1 convolution stack and the sigmoid output.

        Args:
            G: the feature channels of input will be divided evenly into G groups
        """
        super(PixelwiseNet, self).__init__()
        self.conv0 = ConvBnReLU3D(G, 16, kernel_size=1, stride=1, pad=0)
        self.conv1 = ConvBnReLU3D(16, 8, kernel_size=1, stride=1, pad=0)
        self.conv2 = nn.Conv3d(8, 1, kernel_size=1, stride=1, padding=0)
        self.output = nn.Sigmoid()

    def forward(self, x1: torch.Tensor) -> torch.Tensor:
        """Estimate the pixel-wise weight of one source view.

        Args:
            x1: group-wise similarity of the view, [B, G, Ndepth, H, W], where G is the number of groups

        Returns:
            pixel-wise view weight, [B,1,H,W]
        """
        # [B,Ndepth,H,W], every entry in (0, 1)
        logits = self.conv2(self.conv1(self.conv0(x1))).squeeze(1)
        probability = self.output(logits)
        # keep the best value over all depth hypotheses, [B,1,H,W]
        return torch.max(probability, dim=1, keepdim=True)[0]


In [None]:

from typing import Any, Callable, Union, Dict

import numpy as np

import torchvision.utils as vutils
import torch
import torch.utils.tensorboard as tb


def print_args(args: Any) -> None:
    """Print every attribute of an arguments object as a small table.

    Args:
        args: object whose ``__dict__`` holds the arguments to print out; each row shows the
            attribute name, its value and the type of the value
    """
    header = "################################  args  ################################"
    footer = "########################################################################"
    print(header)
    for name, value in vars(args).items():
        print(f"{name: <10}\t{str(value): <30}\t{str(type(value)): <20}")
    print(footer)


def make_nograd_func(func: Callable) -> Callable:
    """Wrap a function so that it always runs with gradient tracking disabled.

    Args:
        func: input function

    Returns:
        wrapper that forwards all positional and keyword arguments to ``func`` inside a
        ``torch.no_grad()`` context and returns its result; the wrapper keeps the name and
        docstring of ``func``
    """
    # local import keeps this block self-contained without touching the file-level imports
    import functools

    @functools.wraps(func)
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            return func(*f_args, **f_kwargs)

    return wrapper


def make_recursive_func(func: Callable) -> Callable:
    """Convert a function into recursive style to handle nested dict/list/tuple variables

    Args:
        func: input function applied to every leaf value

    Returns:
        recursive style function: lists, tuples and dicts are rebuilt with the same container type
        (dict keys are kept as they are) and ``func`` is applied to every other value; the wrapper
        keeps the name and docstring of ``func``
    """
    # local import keeps this block self-contained without touching the file-level imports
    import functools

    @functools.wraps(func)
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        if isinstance(vars, tuple):
            return tuple(wrapper(x) for x in vars)
        if isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        # anything that is not a supported container is a leaf
        return func(vars)

    return wrapper


@make_recursive_func
def tensor2float(vars: Any) -> float:
    """Convert a tensor to a Python float; floats are passed through unchanged.

    Raises:
        NotImplementedError: if the value is neither a float nor a tensor
    """
    if isinstance(vars, torch.Tensor):
        return vars.data.item()
    if isinstance(vars, float):
        return vars
    raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))


@make_recursive_func
def tensor2numpy(vars: Any) -> np.ndarray:
    """Convert a tensor to a numpy array; numpy arrays are passed through unchanged.

    Tensors are detached, moved to the CPU and copied.

    Raises:
        NotImplementedError: if the value is neither a numpy array nor a tensor
    """
    if isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    if isinstance(vars, np.ndarray):
        return vars
    raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))


@make_recursive_func
def tocuda(vars: Any) -> Union[str, torch.Tensor]:
    """Move a tensor to the target device; strings are passed through unchanged.

    NOTE(review): despite the function name, the tensor is moved with ``.cpu()``, i.e. to the CPU —
    presumably a deliberate change for running without a GPU; confirm before relying on the name.

    Raises:
        NotImplementedError: if the value is neither a tensor nor a string
    """
    if isinstance(vars, str):
        return vars
    if isinstance(vars, torch.Tensor):
        return vars.cpu()
    raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))


def save_scalars(logger: tb.SummaryWriter, mode: str, scalar_dict: Dict[str, Any], global_step: int) -> None:
    """Write the values of a scalar dictionary to tensorboard.

    Single values are logged as ``<mode>/<key>``; the items of a list or tuple are logged
    separately as ``<mode>/<key>_<index>``.

    Args:
        logger: tensorboard summary writer
        mode: mode name used in writing summaries
        scalar_dict: python dictionary stores the key and value pairs to be recorded
        global_step: step index where the logger should write
    """
    for key, value in tensor2float(scalar_dict).items():
        if isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                logger.add_scalar("{}/{}_{}".format(mode, key, idx), item, global_step)
        else:
            logger.add_scalar("{}/{}".format(mode, key), value, global_step)


def save_images(logger: tb.SummaryWriter, mode: str, images_dict: Dict[str, Any], global_step: int) -> None:
    """Write the images of an image dictionary to tensorboard.

    Single images are logged as ``<mode>/<key>``; the items of a list or tuple are logged
    separately as ``<mode>/<key>_<index>``. Only the first entry along the leading axis of every
    array is written.

    Args:
        logger: tensorboard summary writer
        mode: mode name used in writing summaries
        images_dict: python dictionary stores the key and image pairs to be recorded
        global_step: step index where the logger should write
    """

    def to_grid(name, img):
        rank = len(img.shape)
        if rank != 3 and rank != 4:
            raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape))
        if rank == 3:
            # insert a singleton axis at position 1 so the array becomes 4-dimensional
            img = img[:, np.newaxis, :, :]
        first = torch.from_numpy(img[:1])
        return vutils.make_grid(first, padding=0, nrow=1, normalize=True, scale_each=True)

    for key, value in tensor2numpy(images_dict).items():
        if isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                name = "{}/{}_{}".format(mode, key, idx)
                logger.add_image(name, to_grid(name, item), global_step)
        else:
            name = "{}/{}".format(mode, key)
            logger.add_image(name, to_grid(name, value), global_step)


class DictAverageMeter:
    """Accumulates dictionaries of float values and reports their per-key average"""

    def __init__(self) -> None:
        """Start with no accumulated data"""
        # running sum per key
        self.data: Dict[Any, float] = {}
        # number of update() calls, used as divisor in mean()
        self.count = 0

    def update(self, new_input: Dict[Any, float]) -> None:
        """Add a new dictionary of values to the running sums

        The first non-empty update defines the set of keys; later updates add to the stored sums.

        Args:
            new_input: new data to update self.data, every value must be a float

        Raises:
            NotImplementedError: if a value is not a float
        """
        self.count += 1
        # decided once per call, before any value is stored
        starting_fresh = len(self.data) == 0
        for key, value in new_input.items():
            if not isinstance(value, float):
                raise NotImplementedError("invalid data {}: {}".format(key, type(value)))
            if starting_fresh:
                self.data[key] = value
            else:
                self.data[key] += value

    def mean(self) -> Any:
        """Return a dictionary with every stored sum divided by the number of updates"""
        return {key: total / self.count for key, total in self.data.items()}


def compute_metrics_for_each_image(metric_func: Callable) -> Callable:
    """A wrapper to compute metrics for each image individually

    Args:
        metric_func: metric called as ``metric_func(depth_est[i], depth_gt[i], mask[i], *args)`` for
            every index ``i`` along the first axis; it must return a tensor

    Returns:
        wrapper with the signature ``(depth_est, depth_gt, mask, *args)`` that returns the mean of the
        per-image metric values; the wrapper keeps the name and docstring of ``metric_func``
    """
    # local import keeps this block self-contained without touching the file-level imports
    import functools

    @functools.wraps(metric_func)
    def wrapper(depth_est, depth_gt, mask, *args):
        # the batch size is taken from the ground truth
        batch_size = depth_gt.shape[0]
        # compute result one by one
        results = [metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) for idx in range(batch_size)]
        return torch.stack(results).mean()

    return wrapper


@make_nograd_func
@compute_metrics_for_each_image
def Thres_metrics(
    depth_est: torch.Tensor, depth_gt: torch.Tensor, mask: torch.Tensor, thres: Union[int, float]
) -> torch.Tensor:
    """Return the fraction of masked pixels whose absolute depth error exceeds a threshold.

    Args:
        depth_est: estimated depth map
        depth_gt: ground truth depth map
        mask: mask selecting the pixels to evaluate
        thres: threshold, must be an int or a float

    Returns:
        error rate: error rate of the depth map
    """
    # only plain numbers are accepted as threshold
    assert isinstance(thres, (int, float))
    abs_error = torch.abs(depth_est[mask] - depth_gt[mask])
    exceeds = abs_error > thres
    return exceeds.float().mean()


# NOTE: please do not use this to build up training loss
@make_nograd_func
@compute_metrics_for_each_image
def AbsDepthError_metrics(depth_est: torch.Tensor, depth_gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute depth error over the masked pixels.

    Args:
        depth_est: estimated depth map
        depth_gt: ground truth depth map
        mask: mask selecting the pixels to evaluate
    """
    difference = depth_est[mask] - depth_gt[mask]
    return difference.abs().mean()


In [None]:
"""Utilities for reading and writing images, depth maps, and auxiliary data (cams, pairs) from/to disk."""

import re
import struct
import sys
from typing import Dict, List, Tuple

import cv2
import numpy as np
from PIL import Image


def scale_to_max_dim(image: np.ndarray, max_dim: int) -> Tuple[np.ndarray, int, int]:
    """Shrink an image so that its larger side equals the given max dimension.

    The image is only resized when the resulting scale factor lies strictly between 0 and 1;
    otherwise (for example with ``max_dim=-1`` or an image that is already small enough) it is
    returned unchanged.

    Args:
        image: the input image in original size, indexed as (height, width, ...)
        max_dim: the max dimension to scale the image down to if smaller than the actual max dimension

    Returns:
        Tuple of scaled image along with original image height and width
    """
    original_height, original_width = image.shape[0], image.shape[1]
    scale = max_dim / max(original_height, original_width)
    if not (0 < scale < 1):
        return image, original_height, original_width

    # cv2.resize takes the target size as (width, height)
    target_size = (int(scale * original_width), int(scale * original_height))
    resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    return resized, original_height, original_width


def read_image(filename: str, max_dim: int = -1) -> Tuple[np.ndarray, int, int]:
    """Read image and rescale to specified max dimension (if exists)

    Args:
        filename: image input file path string
        max_dim: max dimension to scale down the image; keep original size if -1

    Returns:
        Tuple of scaled image along with original image height and width
    """
    # the context manager releases the underlying file handle as soon as the pixels are copied
    with Image.open(filename) as image:
        # scale 0~255 to 0~1
        np_image = np.array(image, dtype=np.float32) / 255.0
    return scale_to_max_dim(np_image, max_dim)


def save_image(filename: str, image: np.ndarray) -> None:
    """Save an image given as binary mask (bool), float (0<= val <= 1), or int (as-is)

    Every input is converted to 8-bit unsigned integers before saving: boolean masks become
    0/255, float32/float64 images are multiplied by 255, all other dtypes are cast directly.

    Args:
        filename: image output file path string
        image: output image array
    """
    is_float = image.dtype == np.float32 or image.dtype == np.float64
    if image.dtype == bool:
        pixels = image.astype(np.uint8) * 255
    elif is_float:
        pixels = (image * 255).astype(np.uint8)
    else:
        pixels = image.astype(np.uint8)
    Image.fromarray(pixels).save(filename)


def read_image_dictionary(filename: str) -> Dict[int, str]:
    """Create image dictionary from file; useful for ETH3D dataset reading and conversion.

    The first line of the file holds the number of entries; each of the following lines holds an
    image id and an image file name separated by a single space.

    Args:
        filename: input dictionary text file path

    Returns:
        Dictionary of image id (int) and corresponding image file name (string)
    """
    with open(filename) as f:
        num_entries = int(f.readline().strip())
        # read exactly the announced number of lines, one entry per line
        rows = (f.readline().strip().split(' ') for _ in range(num_entries))
        return {int(fields[0].strip()): fields[1].strip() for fields in rows}


def read_cam_file(filename: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read camera intrinsics, extrinsics, and depth values (min, max) from text file

    The file is addressed by line index: lines 1-4 hold the 4x4 extrinsics, lines 7-9 the 3x3
    intrinsics and line 11, when present, the depth parameters.

    Args:
        filename: cam text file path string

    Returns:
        Tuple with intrinsics matrix (3x3), extrinsics matrix (4x4), and depth params vector (min and max) if exists
    """
    with open(filename) as f:
        rows = [row.rstrip() for row in f]

    def parse(text: str) -> np.ndarray:
        # whitespace separated numbers as float32 vector
        return np.fromstring(text, dtype=np.float32, sep=' ')

    # extrinsics: line [1,5), 4x4 matrix
    extrinsics = parse(' '.join(rows[1:5])).reshape((4, 4))
    # intrinsics: line [7,10), 3x3 matrix
    intrinsics = parse(' '.join(rows[7:10])).reshape((3, 3))
    # depth min and max: line 11, empty vector when the file is too short
    depth_params = parse(rows[11]) if len(rows) >= 12 else np.empty(0)

    return intrinsics, extrinsics, depth_params


def read_pair_file(filename: str) -> List[Tuple[int, List[int]]]:
    """Read image pairs from text file and output a list of tuples each containing the reference image ID and a list of
    source image IDs

    The first line holds the number of viewpoints. Every viewpoint then takes two lines: the reference
    image ID, followed by a whitespace separated line of which the tokens at odd positions (1, 3, 5, ...)
    are read as source image IDs. Viewpoints without any source ID are skipped.

    Args:
        filename: pair text file path string

    Returns:
        List of tuples with reference ID and list of source IDs
    """
    data: List[Tuple[int, List[int]]] = []
    with open(filename) as f:
        num_viewpoint = int(f.readline())
        for _ in range(num_viewpoint):
            ref_view = int(f.readline().rstrip())
            # token 0 and the other even tokens are ignored — presumably a count and per-view scores
            src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
            if src_views:
                data.append((ref_view, src_views))
    return data


def read_map(path: str, max_dim: int = -1) -> np.ndarray:
    """Load a binary map stored as PFM or Colmap (bin), picked by file extension, then pass it through
    scale_to_max_dim

    Args:
        path: input map file path string; must end with '.pfm' or '.bin'
        max_dim: max dimension forwarded to scale_to_max_dim; -1 is documented as "keep original size"

    Returns:
        Array of map values (the first element of the scale_to_max_dim result)

    Raises:
        Exception: if the path has neither a '.pfm' nor a '.bin' extension
    """
    if path.endswith('.pfm'):
        # read_pfm also returns the scale factor, which is not needed here
        loaded = read_pfm(path)[0]
    elif path.endswith('.bin'):
        loaded = read_bin(path)
    else:
        raise Exception('Invalid input format; only pfm and bin are supported')

    scaled = scale_to_max_dim(loaded, max_dim)
    return scaled[0]


def save_map(path: str, data: np.ndarray) -> None:
    """Write a depth or confidence map as PFM or Colmap (bin), picked by the file extension

    Args:
        path: output map file path string; must end with '.pfm' or '.bin'
        data: map data array

    Raises:
        Exception: if the path has neither a '.pfm' nor a '.bin' extension
    """
    if path.endswith('.pfm'):
        save_pfm(path, data)
        return
    if path.endswith('.bin'):
        save_bin(path, data)
        return
    raise Exception('Invalid input format; only pfm and bin are supported')


def read_bin(path: str) -> np.ndarray:
    """Read a depth map from a Colmap .bin file

    The file starts with an ASCII header "width&height&channels&" which is followed directly by the
    float32 payload stored in column-major (Fortran) order over (width, height, channels).

    Args:
        path: .bin file path string

    Returns:
        data: array of shape (H, W, C) representing loaded depth map

    Raises:
        ValueError: if the file ends before three '&' delimiters are found, or a header field is not an
            integer
    """
    with open(path, 'rb') as fid:
        # Collect the header one byte at a time up to and including the third '&'. Stopping at EOF is
        # essential: a header with fewer than three delimiters must fail instead of spinning forever.
        header = bytearray()
        num_delimiter = 0
        while num_delimiter < 3:
            byte = fid.read(1)
            if not byte:
                raise ValueError('Malformed .bin header; expected "width&height&channels&"')
            if byte == b'&':
                num_delimiter += 1
            header += byte
        width, height, channels = (int(field) for field in header.decode('ascii').split('&')[:3])
        # the payload begins right after the third delimiter
        data = np.fromfile(fid, np.float32)
    data = data.reshape((width, height, channels), order='F')
    # (W, H, C) -> (H, W, C)
    data = np.transpose(data, (1, 0, 2))
    return data


def save_bin(filename: str, data: np.ndarray) -> None:
    """Save a depth map to a Colmap .bin file

    Writes the ASCII header "width&height&channels&" followed by the values as little-endian float32 in
    column-major (Fortran) order over (width, height, channels).

    Args:
        filename: output .bin file path string,
        data: depth map to save, of shape (H,W) or (H,W,C)

    Raises:
        Exception: if data is not float32, or its shape is not H x W, H x W x 1 or H x W x 3
    """
    if data.dtype != np.float32:
        raise Exception('Image data type must be float32.')

    if len(data.shape) == 2:
        height, width = data.shape
        channels = 1
    elif len(data.shape) == 3 and data.shape[2] in (1, 3):
        height, width, channels = data.shape
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    header = '{}&{}&{}&'.format(width, height, channels).encode('ascii')

    # swap the first two axes so width comes first, then flatten column-major
    axes = (1, 0) if len(data.shape) == 2 else (1, 0, 2)
    flat = np.transpose(data, axes).reshape(-1, order='F')
    # force little-endian on any host; a no-op copy-free cast on little-endian machines
    payload = flat.astype('<f4', copy=False).tobytes()

    # a single binary write keeps header and payload in one handle and avoids building a Python list
    with open(filename, 'wb') as fid:
        fid.write(header)
        fid.write(payload)


def read_pfm(filename: str) -> Tuple[np.ndarray, float]:
    """Read a depth map from a .pfm file

    Args:
        filename: .pfm file path string

    Returns:
        data: array of shape (H, W, C) representing loaded depth map (C is 3 for "PF", 1 for "Pf")
        scale: absolute value of the scale stored in the header

    Raises:
        Exception: if the first line is neither "PF" nor "Pf", or the dimension line is malformed
    """
    # the context manager guarantees the handle is released on every raise path below
    with open(filename, "rb") as file:
        header = file.readline().decode("utf-8").rstrip()
        if header == "PF":
            color = True
        elif header == "Pf":  # depth is Pf
            color = False
        else:
            raise Exception("Not a PFM file.")

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8"))
        if not dim_match:
            raise Exception("Malformed PFM header.")
        width, height = map(int, dim_match.groups())

        # the sign of the scale line encodes the byte order of the payload
        scale = float(file.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(file, endian + "f")

    shape = (height, width, 3) if color else (height, width, 1)
    # rows are flipped vertically after loading
    data = np.flipud(np.reshape(data, shape))
    return data, scale


def save_pfm(filename: str, image: np.ndarray, scale: float = 1) -> None:
    """Save a depth map to a .pfm file

    The input is validated before the output file is opened, so a rejected image never creates or
    truncates the destination.

    Args:
        filename: output .pfm file path string,
        image: depth map to save, of shape (H,W) or (H,W,C)
        scale: scale parameter to save; written negated when the data is little-endian

    Raises:
        Exception: if image is not float32, or its shape is not H x W, H x W x 1 or H x W x 3
    """
    if image.dtype.name != "float32":
        raise Exception("Image dtype must be float32.")

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1):  # greyscale
        color = False
    else:
        raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

    # rows are written in vertically flipped order
    image = np.flipud(image)

    # a negative scale in the header marks a little-endian payload
    endian = image.dtype.byteorder
    if endian == "<" or (endian == "=" and sys.byteorder == "little"):
        scale = -scale

    with open(filename, "wb") as file:
        file.write("PF\n".encode("utf-8") if color else "Pf\n".encode("utf-8"))
        file.write("{} {}\n".format(image.shape[1], image.shape[0]).encode("utf-8"))
        file.write(("%f\n" % scale).encode("utf-8"))
        image.tofile(file)


In [None]:
import argparse
import os
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import time
# from datasets import find_dataset_def
# from models import *
# from utils import *
import sys
# from datasets.data_io import read_cam_file, read_pair_file, read_image, read_map, save_image, save_map
import cv2
from plyfile import PlyData, PlyElement

# Module-level setup for the evaluation script: configure cuDNN, declare the command-line interface,
# and parse it into the global `args` that the functions below read.

# let cuDNN benchmark convolution algorithms and pick the fastest one
# NOTE(review): save_depth() places the model on the CPU, where this flag has no effect -- confirm intent
cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse')
parser.add_argument('--model', default='PatchmatchNet', help='select model')

# dataset selection and input location
parser.add_argument('--dataset', default='eth3d', help='select dataset')
parser.add_argument('--testpath', help='testing data path')
parser.add_argument('--testlist', help='testing scan list')
parser.add_argument('--split', default='test', help='select data')

# data loading
parser.add_argument('--batch_size', type=int, default=1, help='testing batch size')
parser.add_argument('--n_views', type=int, default=5, help='num of view')


# checkpoint, output location and optional visualisation
parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
parser.add_argument('--outdir', default='./outputs', help='output dir')
parser.add_argument('--display', action='store_true', help='display depth images and masks')

# PatchmatchNet hyper-parameters; each takes one value per stage (three values by default)
parser.add_argument('--patchmatch_iteration', nargs='+', type=int, default=[1, 2, 2],
                    help='num of iteration of patchmatch on stages 1,2,3')
parser.add_argument('--patchmatch_num_sample', nargs='+', type=int, default=[8, 8, 16],
                    help='num of generated samples in local perturbation on stages 1,2,3')
parser.add_argument('--patchmatch_interval_scale', nargs='+', type=float, default=[0.005, 0.0125, 0.025], 
                    help='normalized interval in inverse depth range to generate samples in local perturbation')
parser.add_argument('--patchmatch_range', nargs='+', type=int, default=[6, 4, 2],
                    help='fixed offset of sampling points for propogation of patchmatch on stages 1,2,3')
parser.add_argument('--propagate_neighbors', nargs='+', type=int, default=[0, 8, 16],
                    help='num of neighbors for adaptive propagation on stages 1,2,3')
parser.add_argument('--evaluate_neighbors', nargs='+', type=int, default=[9, 9, 9],
                    help='num of neighbors for adaptive matching cost aggregation of adaptive evaluation on stages 1,2,3')

# thresholds used by filter_depth() when fusing depth maps into a point cloud
parser.add_argument('--geo_pixel_thres', type=float, default=1,
                    help='pixel threshold for geometric consistency filtering')
parser.add_argument('--geo_depth_thres', type=float, default=0.01,
                    help='depth threshold for geometric consistency filtering')
parser.add_argument('--photo_thres', type=float, default=0.8, help='threshold for photometric consistency filtering')

# parse arguments and check
# NOTE(review): parsing happens at import time and reads sys.argv; print_args is defined outside this chunk
args = parser.parse_args()
print("argv:", sys.argv[1:])
print_args(args)


# run MVS model to save depth maps
def save_depth():
    """Run the MVS model over the test set and save a depth map and a confidence map per sample.

    Reads its whole configuration from the module-level `args`: dataset name and path, number of views,
    batch size, PatchmatchNet hyper-parameters, checkpoint path and output directory. For each sample
    the output paths are built by formatting the sample's "filename" template with ('depth_est', '.pfm')
    and ('confidence', '.pfm') under args.outdir.
    """
    # dataset, dataloader
    mvs_dataset = find_dataset_def(args.dataset)
    test_dataset = mvs_dataset(args.testpath, args.n_views)
    image_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)

    # model
    model = PatchmatchNet(
        patchmatch_interval_scale=args.patchmatch_interval_scale,
        propagation_range=args.patchmatch_range,
        patchmatch_iteration=args.patchmatch_iteration,
        patchmatch_num_sample=args.patchmatch_num_sample,
        propagate_neighbors=args.propagate_neighbors,
        evaluate_neighbors=args.evaluate_neighbors
    )
    model = nn.DataParallel(model)
    # NOTE(review): the model and checkpoint are placed on the CPU while the inputs go through
    # tocuda() below; tocuda is defined outside this chunk -- confirm both end up on the same device
    model.cpu()

    # load checkpoint file specified by args.loadckpt
    print("loading model {}".format(args.loadckpt))
    state_dict = torch.load(args.loadckpt, map_location=torch.device('cpu'))
    # strict=False: keys that do not match between checkpoint and model are ignored rather than raising
    model.load_state_dict(state_dict['model'], strict=False)
    model.eval()

    with torch.no_grad():
        for batch_idx, sample in enumerate(image_loader):
            start_time = time.time()
            sample_cuda = tocuda(sample)
            refined_depth, confidence, _ = model(sample_cuda["imgs"], sample_cuda["proj_matrices"],
                                                 sample_cuda["depth_min"], sample_cuda["depth_max"])
            refined_depth = tensor2numpy(refined_depth)
            confidence = tensor2numpy(confidence)

            # drop the device copy of the batch before the file I/O below
            del sample_cuda
            print('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(image_loader), time.time() - start_time))
            filenames = sample["filename"]

            # save depth maps and confidence maps
            for filename, depth_est, photometric_confidence in zip(filenames, refined_depth, confidence):
                depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm'))
                confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm'))
                for out_path in (depth_filename, confidence_filename):
                    # os.path.dirname handles the platform separator and yields '' for a bare file
                    # name, where splitting on '/' would have returned the file name itself
                    out_dir = os.path.dirname(out_path)
                    if out_dir:
                        os.makedirs(out_dir, exist_ok=True)
                # save depth maps (the leading singleton axis is removed first)
                depth_est = np.squeeze(depth_est, 0)
                save_map(depth_filename, depth_est)
                # save confidence maps
                save_map(confidence_filename, photometric_confidence)
                

# project the reference point cloud into the source view, then project back
def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):
    """Project every reference pixel into the source view, then project it back using the source depth.

    Args:
        depth_ref: reference depth map; its first two axes are taken as (height, width)
        intrinsics_ref: reference intrinsics matrix (inverted and multiplied with 3-row pixel vectors)
        extrinsics_ref: reference extrinsics matrix (inverted and multiplied with 4-row point vectors)
        depth_src: source depth map, sampled bilinearly at the projected positions
        intrinsics_src: source intrinsics matrix
        extrinsics_src: source extrinsics matrix

    Returns:
        Tuple of five (height, width) float32 arrays: the depth re-projected into the reference view,
        the x and y coordinates of the re-projected pixels in the reference view, and the x and y
        coordinates at which each reference pixel lands in the source view
    """
    height, width = depth_ref.shape[0], depth_ref.shape[1]

    # flattened pixel grid of the reference view, plus the row of ones reused for homogeneous stacking
    grid_x, grid_y = np.meshgrid(np.arange(0, width), np.arange(0, height))
    px = grid_x.reshape([-1])
    py = grid_y.reshape([-1])
    ones = np.ones_like(px)

    # step 1: reference pixels -> reference camera space -> source camera space -> source pixels
    ref_pixels = np.vstack((px, py, ones))
    cam_ref = np.linalg.inv(intrinsics_ref) @ (ref_pixels * depth_ref.reshape([-1]))
    ref_to_src = extrinsics_src @ np.linalg.inv(extrinsics_ref)
    cam_src = (ref_to_src @ np.vstack((cam_ref, ones)))[:3]
    projected_src = intrinsics_src @ cam_src
    xy_src = projected_src[:2] / projected_src[2:3]

    # step 2: look up the source depth estimate at the (sub-pixel) projected positions
    x_src = xy_src[0].reshape([height, width]).astype(np.float32)
    y_src = xy_src[1].reshape([height, width]).astype(np.float32)
    sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)

    # step 3: lift with the *sampled source depth* (not the reference depth) and go back to the reference
    lifted_src = np.linalg.inv(intrinsics_src) @ (np.vstack((xy_src, ones)) * sampled_depth_src.reshape([-1]))
    src_to_ref = extrinsics_ref @ np.linalg.inv(extrinsics_src)
    cam_back = (src_to_ref @ np.vstack((lifted_src, ones)))[:3]

    # depth and pixel position of the round-tripped points in the reference view
    depth_reprojected = cam_back[2].reshape([height, width]).astype(np.float32)
    projected_back = intrinsics_ref @ cam_back
    xy_back = projected_back[:2] / projected_back[2:3]
    x_reprojected = xy_back[0].reshape([height, width]).astype(np.float32)
    y_reprojected = xy_back[1].reshape([height, width]).astype(np.float32)

    return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src


def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src,
                                geo_pixel_thres, geo_depth_thres):
    """Mark reference pixels whose round trip through the source view stays within both thresholds.

    Args:
        depth_ref: reference depth map; its first two axes are taken as (height, width)
        intrinsics_ref: reference intrinsics, forwarded to reproject_with_depth
        extrinsics_ref: reference extrinsics, forwarded to reproject_with_depth
        depth_src: source depth map, forwarded to reproject_with_depth
        intrinsics_src: source intrinsics, forwarded to reproject_with_depth
        extrinsics_src: source extrinsics, forwarded to reproject_with_depth
        geo_pixel_thres: exclusive upper bound on the pixel distance between a reference pixel and its
            re-projection
        geo_depth_thres: exclusive upper bound on the depth difference relative to the reference depth

    Returns:
        Tuple of (boolean consistency mask, re-projected depth with inconsistent pixels set to 0,
        x coordinates in the source view, y coordinates in the source view)
    """
    height, width = depth_ref.shape[0], depth_ref.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(0, width), np.arange(0, height))

    reprojection = reproject_with_depth(
        depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src)
    depth_reprojected, x_back, y_back, x2d_src, y2d_src = reprojection

    # how far each pixel lands from where it started
    offset_x = x_back - grid_x
    offset_y = y_back - grid_y
    pixel_error = np.sqrt(offset_x ** 2 + offset_y ** 2)

    # depth disagreement as a fraction of the reference depth
    relative_depth_error = np.abs(depth_reprojected - depth_ref) / depth_ref

    pixel_ok = pixel_error < geo_pixel_thres
    depth_ok = relative_depth_error < geo_depth_thres
    mask = np.logical_and(pixel_ok, depth_ok)

    # inconsistent pixels contribute nothing when the caller sums the re-projected depths
    depth_reprojected[np.logical_not(mask)] = 0

    return mask, depth_reprojected, x2d_src, y2d_src


def _load_view(scan_folder, view_id, img_wh):
    """Load one view's image and camera parameters, rescaling the intrinsics from the original image
    size to the working resolution img_wh."""
    image, original_h, original_w = read_image(
        os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(view_id)), max(img_wh))
    # read_cam_file returns three values; the depth parameters are not needed for filtering
    intrinsics, extrinsics, _ = read_cam_file(
        os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(view_id)))
    # the first two rows of the intrinsics scale with image width and height respectively
    intrinsics[0] *= img_wh[0] / original_w
    intrinsics[1] *= img_wh[1] / original_h
    return image, intrinsics, extrinsics


def filter_depth(
        scan_folder, out_folder, plyfilename, geo_pixel_thres, geo_depth_thres, photo_thres, img_wh, geo_mask_thres):
    """Filter the saved depth maps by photometric and geometric consistency and fuse them into a PLY.

    For every reference view listed in "<scan_folder>/pair.txt" the estimated depth and confidence are
    loaded from out_folder, checked against each source view, averaged over the consistent views, and
    back-projected to world space. The three masks are saved as PNG files under "<out_folder>/mask".
    When the module-level `args.display` is set, intermediate images are shown with OpenCV.

    Args:
        scan_folder: folder containing "pair.txt", "images/" and "cams/"
        out_folder: folder containing "depth_est/" and "confidence/"; also receives "mask/"
        plyfilename: output path of the fused point cloud
        geo_pixel_thres: pixel threshold passed to check_geometric_consistency
        geo_depth_thres: relative depth threshold passed to check_geometric_consistency
        photo_thres: a pixel passes the photometric test when its confidence is strictly above this value
        img_wh: (width, height) of the depth maps; used to rescale the camera intrinsics
        geo_mask_thres: minimum number of source views a pixel must be geometrically consistent with
    """
    # the pair file
    pair_file = os.path.join(scan_folder, "pair.txt")
    # for the final point cloud
    vertexs = []
    vertex_colors = []

    pair_data = read_pair_file(pair_file)
    os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True)

    # for each reference view and the corresponding source views
    for ref_view, src_views in pair_data:
        # load the reference image and camera
        ref_img, ref_intrinsics, ref_extrinsics = _load_view(scan_folder, ref_view, img_wh)

        # load the estimated depth of the reference view and drop its trailing channel axis
        ref_depth_est = read_map(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))
        ref_depth_est = np.squeeze(ref_depth_est, 2)
        # load the photometric mask of the reference view
        confidence = read_map(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))
        photo_mask = np.squeeze(confidence > photo_thres, 2)

        all_srcview_depth_ests = []
        # number of source views each pixel is geometrically consistent with
        geo_mask_sum = 0
        for src_view in src_views:
            # camera parameters of the source view (the image itself is only read for its size)
            _, src_intrinsics, src_extrinsics = _load_view(scan_folder, src_view, img_wh)

            # the estimated depth of the source view
            src_depth_est = read_map(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))

            geo_mask, depth_reprojected, _, _ = check_geometric_consistency(
                ref_depth_est, ref_intrinsics, ref_extrinsics, src_depth_est, src_intrinsics, src_extrinsics,
                geo_pixel_thres, geo_depth_thres)
            geo_mask_sum += geo_mask.astype(np.int32)
            all_srcview_depth_ests.append(depth_reprojected)

        # inconsistent re-projections were zeroed, so this averages the reference depth with the
        # consistent source estimates only
        depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1)
        geo_mask = geo_mask_sum >= geo_mask_thres
        final_mask = np.logical_and(photo_mask, geo_mask)

        save_image(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask)
        save_image(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask)
        save_image(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask)

        print("processing {}, ref-view{:0>2}, geo_mask:{:3f} photo_mask:{:3f} final_mask: {:3f}".format(
            scan_folder, ref_view, geo_mask.mean(), photo_mask.mean(), final_mask.mean()))

        if args.display:
            # reverse the channel axis for display
            cv2.imshow('ref_img', ref_img[:, :, ::-1])
            cv2.imshow('ref_depth', ref_depth_est)
            cv2.imshow('ref_depth * photo_mask', ref_depth_est * photo_mask.astype(np.float32))
            cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32))
            cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32))
            cv2.waitKey(1)

        height, width = depth_est_averaged.shape[:2]
        x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))

        valid_points = final_mask
        x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points]
        color = ref_img[valid_points]

        # pixel -> reference camera space -> world space
        xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), np.vstack((x, y, np.ones_like(x))) * depth)
        xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), np.vstack((xyz_ref, np.ones_like(x))))[:3]
        vertexs.append(xyz_world.transpose((1, 0)))
        vertex_colors.append((color * 255).astype(np.uint8))

    points = np.concatenate(vertexs, axis=0)
    colors = np.concatenate(vertex_colors, axis=0)

    # fill the structured PLY vertex array column by column instead of building one tuple per point
    vertex_all = np.empty(len(points), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
                                              ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
    for column, prop in enumerate(('x', 'y', 'z')):
        vertex_all[prop] = points[:, column]
    for column, prop in enumerate(('red', 'green', 'blue')):
        vertex_all[prop] = colors[:, column]

    el = PlyElement.describe(vertex_all, 'vertex')
    PlyData([el]).write(plyfilename)
    print("saving the final model to", plyfilename)


if __name__ == '__main__':
    # Stage 1: run the network and write every depth map and confidence map under args.outdir
    save_depth()

    # working resolution (width, height) of the network input; may be smaller than the source images
    img_wh = (640, 480)
    # a pixel must be geometrically consistent with at least this many source views to be kept
    geo_mask_thres = 2

    # Stage 2: filter the saved depth maps and fuse the surviving points into one point cloud
    filter_depth(
        scan_folder=args.testpath,
        out_folder=args.outdir,
        plyfilename=os.path.join(args.outdir, 'custom.ply'),
        geo_pixel_thres=args.geo_pixel_thres,
        geo_depth_thres=args.geo_depth_thres,
        photo_thres=args.photo_thres,
        img_wh=img_wh,
        geo_mask_thres=geo_mask_thres,
    )
