Differentiable PatchMatch birth
ShivamDuggal4 committed Oct 23, 2019
0 parents commit 65ea152
Showing 42 changed files with 740 additions and 0 deletions.
78 changes: 78 additions & 0 deletions DifferentiablePatchMatch/demo_script.py
@@ -0,0 +1,78 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function
from PIL import Image
import torch
import random
import skimage
import numpy as np
from models.image_reconstruction import ImageReconstruction
import os
import torchvision.transforms as transforms
import argparse
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser(description='Differentiable PatchMatch')
parser.add_argument('--base_dir', default='./',
                    help='path of the base directory where images are stored.')
parser.add_argument('--save_dir', default='./',
                    help='save directory')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.backends.cudnn.benchmark = True

# Seed all RNGs for reproducibility; the CUDA-specific seeding applies
# only when a GPU is actually used.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

model = ImageReconstruction()

if args.cuda:
    model.cuda()


def main():
    base_dir = args.base_dir
    device = torch.device('cuda' if args.cuda else 'cpu')

    for file1, file2 in zip(sorted(os.listdir(base_dir + '/image_1')),
                            sorted(os.listdir(base_dir + '/image_2'))):

        image_1_image_path = base_dir + '/image_1/' + file1
        image_2_image_path = base_dir + '/image_2/' + file2

        image_1 = np.asarray(Image.open(image_1_image_path).convert('RGB'))
        image_2 = np.asarray(Image.open(image_2_image_path).convert('RGB'))

        # Move tensors to the selected device instead of calling .cuda()
        # unconditionally, so the demo also runs when --no-cuda is set.
        image_1 = transforms.ToTensor()(image_1).unsqueeze(0).to(device).float()
        image_2 = transforms.ToTensor()(image_2).unsqueeze(0).to(device).float()

        reconstruction = model(image_1, image_2)

        # Scale the [0, 1] output to 8-bit; matplotlib's imsave expects uint8
        # (or floats in [0, 1]) for RGB arrays.
        plt.imsave(os.path.join(args.save_dir, image_1_image_path.split('/')[-1]),
                   np.asarray(reconstruction[0].permute(1, 2, 0).data.cpu() * 255).astype('uint8'))


if __name__ == '__main__':
    with torch.no_grad():
        main()
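
For context (not part of the commit): a minimal smoke test of the demo, assuming the image_1/ and image_2/ directory layout the script expects. The data/ and results/ paths and the pair_0.png file name are made up for illustration, and this assumes the full repository (including models/patch_match.py) runs on CPU.

# Hypothetical smoke test for demo_script.py (paths are illustrative).
import os
import subprocess
import numpy as np
from PIL import Image

for sub in ('image_1', 'image_2'):
    os.makedirs(os.path.join('data', sub), exist_ok=True)
os.makedirs('results', exist_ok=True)

# One synthetic image pair with heavy overlap (a horizontally shifted copy),
# so that PatchMatch has real correspondences to find.
rgb = (np.random.rand(128, 256, 3) * 255).astype('uint8')
Image.fromarray(rgb).save('data/image_1/pair_0.png')
Image.fromarray(np.roll(rgb, 40, axis=1)).save('data/image_2/pair_0.png')

subprocess.run(['python', 'demo_script.py', '--no-cuda',
                '--base_dir', './data', '--save_dir', './results'], check=True)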
1 change: 1 addition & 0 deletions DifferentiablePatchMatch/models/.gitignore
@@ -0,0 +1 @@
*.pyc
Empty file.
48 changes: 48 additions & 0 deletions DifferentiablePatchMatch/models/config.py
@@ -0,0 +1,48 @@

# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function

class obj(object):
    def __init__(self, d):
        for key, value in d.items():
            if isinstance(value, (list, tuple)):
                setattr(self, key, [obj(x) if isinstance(x, dict) else x for x in value])
            else:
                setattr(self, key, obj(value) if isinstance(value, dict) else value)

config = {
    "patch_match_args": {
        # sample_count refers to the random sampling stage of generalized PatchMatch.
        # Number of random samples generated: (sample_count + 1) * (sample_count + 1).
        # We generate (sample_count + 1) samples in the x direction and (sample_count + 1)
        # samples in the y direction, then perform a meshgrid-like operation to generate
        # (sample_count + 1) * (sample_count + 1) samples.
        "sample_count": 1,

        "iteration_count": 21,
        "propagation_filter_size": 3,
        "propagation_type": "faster_filter_3_propagation",  # set to None for the more readable (but slower) PM propagation
        "softmax_temperature": 10000000000,  # softmax temperature for evaluation. A larger temperature leads to sharper output.
        "random_search_window_size": [100, 100],  # search range around the evaluated offsets after every iteration.
        "evaluation_type": "softmax"
    },

    "feature_extractor_filter_size": 7
}


config = obj(config)
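
To make the sample_count comment above concrete, here is a standalone sketch (not from the repository) of the meshgrid-like expansion it describes, with sample_count = 1 and a hypothetical ±50px window:

import torch

sample_count = 1

# (sample_count + 1) random offsets along each axis, drawn from a +/-50px window.
offsets_x = torch.rand(sample_count + 1) * 100 - 50
offsets_y = torch.rand(sample_count + 1) * 100 - 50

# Meshgrid-like operation: pair every x offset with every y offset, yielding
# (sample_count + 1) ** 2 = 4 candidate (x, y) offsets per pixel location.
grid_y, grid_x = torch.meshgrid(offsets_y, offsets_x)
candidates = torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1)
print(candidates.shape)  # torch.Size([4, 2])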
69 changes: 69 additions & 0 deletions DifferentiablePatchMatch/models/feature_extractor.py
@@ -0,0 +1,69 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F


class feature_extractor(nn.Module):
    def __init__(self, filter_size):
        super(feature_extractor, self).__init__()

        self.filter_size = filter_size

    def forward(self, left_input, right_input):
        """
        Feature Extractor
        Description: Aggregates the RGB values from the neighbouring pixels in a
                     (filter_size * filter_size) window around each pixel.
                     No weights are learnt for this feature extractor.
        Args:
            :param left_input: Left Image
            :param right_input: Right Image
        Returns:
            :left_features: Left Image features
            :right_features: Right Image features
            :one_hot_filter: Convolution filter used to aggregate neighbour RGB features at the center pixel.
                             one_hot_filter.shape = (filter_size * filter_size, 1, 1, filter_size, filter_size)
        """

        # left_input.device (rather than get_device(), which returns -1 on CPU)
        # keeps the module usable without CUDA.
        device = left_input.device

        label = torch.arange(0, self.filter_size * self.filter_size, device=device).repeat(
            self.filter_size * self.filter_size).view(
            self.filter_size * self.filter_size, 1, 1, self.filter_size, self.filter_size)

        one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()

        left_features = F.conv3d(left_input.unsqueeze(1), one_hot_filter,
                                 padding=(0, self.filter_size // 2, self.filter_size // 2))
        right_features = F.conv3d(right_input.unsqueeze(1), one_hot_filter,
                                  padding=(0, self.filter_size // 2, self.filter_size // 2))

        left_features = left_features.view(left_features.size()[0],
                                           left_features.size()[1] * left_features.size()[2],
                                           left_features.size()[3],
                                           left_features.size()[4])

        right_features = right_features.view(right_features.size()[0],
                                             right_features.size()[1] * right_features.size()[2],
                                             right_features.size()[3],
                                             right_features.size()[4])

        return left_features, right_features, one_hot_filter
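
As a sanity check (not part of the commit), the one-hot conv3d above is equivalent to unfolding each filter_size x filter_size neighbourhood into the channel dimension, up to axis ordering. The comparison below assumes a 3x3 filter on a random CPU tensor:

import torch
import torch.nn.functional as F
from models.feature_extractor import feature_extractor

torch.manual_seed(0)
image = torch.rand(1, 3, 8, 10)

extractor = feature_extractor(filter_size=3)
features, _, _ = extractor(image, image)  # (1, 9 * 3, 8, 10)

# F.unfold gathers the same zero-padded 3x3 neighbourhoods, but orders the
# channels as (C, k*k) instead of (k*k, C), hence the transpose before comparing.
unfolded = F.unfold(image, kernel_size=3, padding=1).view(1, 3, 9, 8, 10)
unfolded = unfolded.transpose(1, 2).reshape(1, 27, 8, 10)

print(torch.allclose(features, unfolded))  # True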
124 changes: 124 additions & 0 deletions DifferentiablePatchMatch/models/image_reconstruction.py
@@ -0,0 +1,124 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.patch_match import PatchMatch
from models.feature_extractor import feature_extractor
from models.config import config as args


class Reconstruct(nn.Module):
    def __init__(self, filter_size):
        super(Reconstruct, self).__init__()
        self.filter_size = filter_size

    def forward(self, right_input, offset_x, offset_y, x_coordinate, y_coordinate, neighbour_extraction_filter):
        """
        Reconstructs the left image using the NNF (represented by the offsets and the xy coordinates).
        We perform patch voting on the offset field before reconstruction in order to
        generate a smooth reconstruction.
        Args:
            :right_input: Right Image
            :offset_x: horizontal offsets of the NNF.
            :offset_y: vertical offsets of the NNF.
            :x_coordinate: X coordinate
            :y_coordinate: Y coordinate
            :neighbour_extraction_filter: one-hot filter used for patch voting over the offsets.
        Returns:
            :reconstruction: Reconstructed left image.
        """

        pad_size = self.filter_size // 2
        smooth_offset_x = nn.ReflectionPad2d(
            (pad_size, pad_size, pad_size, pad_size))(offset_x)
        smooth_offset_y = nn.ReflectionPad2d(
            (pad_size, pad_size, pad_size, pad_size))(offset_y)

        smooth_offset_x = F.conv2d(smooth_offset_x,
                                   neighbour_extraction_filter,
                                   padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size]

        smooth_offset_y = F.conv2d(smooth_offset_y,
                                   neighbour_extraction_filter,
                                   padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size]

        coord_x = torch.clamp(
            x_coordinate - smooth_offset_x,
            min=0,
            max=smooth_offset_x.size()[3] - 1)

        coord_y = torch.clamp(
            y_coordinate - smooth_offset_y,
            min=0,
            max=smooth_offset_x.size()[2] - 1)

        # Normalize coordinates to the [-1, 1] range expected by F.grid_sample.
        coord_x -= coord_x.size()[3] / 2
        coord_x /= (coord_x.size()[3] / 2)

        coord_y -= coord_y.size()[2] / 2
        coord_y /= (coord_y.size()[2] / 2)

        grid = torch.cat((coord_x.unsqueeze(4), coord_y.unsqueeze(4)), dim=4)
        grid = grid.view(grid.size()[0] * grid.size()[1], grid.size()[2], grid.size()[3], grid.size()[4])
        reconstruction = F.grid_sample(right_input.repeat(grid.size()[0], 1, 1, 1), grid)
        reconstruction = torch.mean(reconstruction, dim=0).unsqueeze(0)

        return reconstruction


class ImageReconstruction(nn.Module):
    def __init__(self):
        super(ImageReconstruction, self).__init__()

        self.patch_match = PatchMatch(args.patch_match_args)

        filter_size = args.feature_extractor_filter_size
        self.feature_extractor = feature_extractor(filter_size)
        self.reconstruct = Reconstruct(filter_size)

    def forward(self, left_input, right_input):
        """
        ImageReconstruction:
        Description: This class reconstructs the left image from the right image,
                     by finding correspondences (an NNF) between the two images.
                     The images can be any pair of images with some overlap between
                     the two to assist the correspondence matching.
                     For the feature_extractor, we just use the RGB features of a
                     (self.filter_size * self.filter_size) patch around each pixel.
                     For finding the correspondences, we use the Differentiable PatchMatch.
                     ** Note: There is no assumption of rectification between the two images. **
                     ** Note: The words 'left' and 'right' do not have any significance. **
        Args:
            :left_input: Left Image (Image 1)
            :right_input: Right Image (Image 2)
        Returns:
            :reconstruction: Reconstructed left image.
        """

        left_features, right_features, neighbour_extraction_filter = self.feature_extractor(left_input, right_input)
        offset_x, offset_y, x_coordinate, y_coordinate = self.patch_match(left_features, right_features)

        reconstruction = self.reconstruct(right_input,
                                          offset_x, offset_y,
                                          x_coordinate, y_coordinate,
                                          neighbour_extraction_filter.squeeze(1))

        return reconstruction
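
A minimal end-to-end sketch (not from the repository): running the module on random CPU tensors, under torch.no_grad() as in the demo script. Shapes are illustrative; real image pairs with shared content give meaningful reconstructions.

import torch
from models.image_reconstruction import ImageReconstruction

model = ImageReconstruction()

# Any two images with some overlap work; no rectification is assumed.
left = torch.rand(1, 3, 64, 96)
right = torch.rand(1, 3, 64, 96)

with torch.no_grad():
    reconstruction = model(left, right)

print(reconstruction.shape)  # expected: torch.Size([1, 3, 64, 96])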
