Commit
0 parents, commit 65ea152
Showing 42 changed files with 740 additions and 0 deletions.
@@ -0,0 +1,78 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function

import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models.image_reconstruction import ImageReconstruction

parser = argparse.ArgumentParser(description='Differentiable PatchMatch')
parser.add_argument('--base_dir', default='./',
                    help='path of the base directory where images are stored')
parser.add_argument('--save_dir', default='./',
                    help='directory where reconstructions are saved')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

args = parser.parse_args()
torch.backends.cudnn.benchmark = True
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed all RNGs so runs are reproducible (on CPU as well as GPU).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

model = ImageReconstruction()

if args.cuda:
    model.cuda()


def main():
    base_dir = args.base_dir
    # Pair up images from image_1/ and image_2/ after sorting by filename.
    for file1, file2 in zip(sorted(os.listdir(base_dir + '/image_1')), sorted(os.listdir(base_dir + '/image_2'))):

        image_1_image_path = base_dir + '/image_1/' + file1
        image_2_image_path = base_dir + '/image_2/' + file2

        image_1 = np.asarray(Image.open(image_1_image_path).convert('RGB'))
        image_2 = np.asarray(Image.open(image_2_image_path).convert('RGB'))

        # ToTensor converts HWC uint8 arrays to CHW float tensors in [0, 1].
        image_1 = transforms.ToTensor()(image_1).unsqueeze(0).float()
        image_2 = transforms.ToTensor()(image_2).unsqueeze(0).float()
        if args.cuda:
            image_1 = image_1.cuda()
            image_2 = image_2.cuda()

        reconstruction = model(image_1, image_2)

        # Scale the [0, 1] reconstruction by 256 and cast to uint16 before saving.
        plt.imsave(os.path.join(args.save_dir, image_1_image_path.split('/')[-1]),
                   np.asarray(reconstruction[0].permute(1, 2, 0).data.cpu() * 256).astype('uint16'))


if __name__ == '__main__':
    with torch.no_grad():
        main()
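For reference, a minimal invocation sketch, assuming the script above is saved as reconstruct.py (its filename is not shown in this view) and that --base_dir points at a directory with paired image_1/ and image_2/ subfolders:

python reconstruct.py --base_dir ./data --save_dir ./output

Images are matched across the two folders after sorting by filename, and each reconstruction is written to --save_dir under the same filename as its image_1/ input.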
@@ -0,0 +1 @@
*.pyc
Empty file.
@@ -0,0 +1,48 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function


class obj(object):
    """Recursively wraps a dict so that its keys become attributes."""
    def __init__(self, d):
        for key, value in d.items():
            if isinstance(value, (list, tuple)):
                setattr(self, key, [obj(x) if isinstance(x, dict) else x for x in value])
            else:
                setattr(self, key, obj(value) if isinstance(value, dict) else value)


config = {
    "patch_match_args": {
        # sample_count refers to the random-sampling stage of generalized PatchMatch.
        # Number of random samples generated: (sample_count + 1) * (sample_count + 1).
        # We generate (sample_count + 1) samples in the x direction and (sample_count + 1)
        # samples in the y direction, then perform a meshgrid-like operation to obtain
        # (sample_count + 1) * (sample_count + 1) samples.
        "sample_count": 1,

        "iteration_count": 21,
        "propagation_filter_size": 3,
        "propagation_type": "faster_filter_3_propagation",  # set to None for the more readable, generic propagation code
        "softmax_temperature": 10000000000,  # softmax temperature for evaluation; a larger temperature leads to sharper output
        "random_search_window_size": [100, 100],  # search range around the evaluated offsets after every iteration
        "evaluation_type": "softmax"
    },

    "feature_extractor_filter_size": 7
}

config = obj(config)
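As a quick illustration (a sketch, not part of the commit), the obj wrapper exposes the nested dict through attribute access, which is how image_reconstruction.py consumes it:

from models.config import config as args

print(args.patch_match_args.sample_count)               # 1
print(args.patch_match_args.propagation_filter_size)    # 3
print(args.feature_extractor_filter_size)               # 7
print(args.patch_match_args.random_search_window_size)  # [100, 100] (lists stay lists)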
@@ -0,0 +1,69 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F


class feature_extractor(nn.Module):
    def __init__(self, filter_size):
        super(feature_extractor, self).__init__()

        self.filter_size = filter_size

    def forward(self, left_input, right_input):
        """
        Feature Extractor
        Description: Aggregates the RGB values from the neighbouring pixels in a
                     (filter_size * filter_size) window around each pixel.
                     No weights are learnt by this feature extractor.
        Args:
            :param left_input: Left Image
            :param right_input: Right Image
        Returns:
            :left_features: Left image features
            :right_features: Right image features
            :one_hot_filter: Convolution filter used to aggregate neighbouring RGB features
                             to the center pixel.
                             one_hot_filter.shape = (filter_size * filter_size, 1, 1, filter_size, filter_size)
        """

        device = left_input.device

        # label[c, 0, 0, i, j] = i * filter_size + j for every output channel c.
        label = torch.arange(0, self.filter_size * self.filter_size, device=device).repeat(
            self.filter_size * self.filter_size).view(
            self.filter_size * self.filter_size, 1, 1, self.filter_size, self.filter_size)

        # Scatter a 1 into channel i * filter_size + j at window position (i, j):
        # channel c of the filter picks out exactly the c-th position of the window.
        one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()

        # Treat RGB as a depth dimension; each window position lands in its own channel.
        left_features = F.conv3d(left_input.unsqueeze(1), one_hot_filter,
                                 padding=(0, self.filter_size // 2, self.filter_size // 2))
        right_features = F.conv3d(right_input.unsqueeze(1), one_hot_filter,
                                  padding=(0, self.filter_size // 2, self.filter_size // 2))

        # Flatten (window position, RGB) into one channel dimension:
        # (B, filter_size**2, 3, H, W) -> (B, 3 * filter_size**2, H, W).
        left_features = left_features.view(left_features.size()[0],
                                           left_features.size()[1] * left_features.size()[2],
                                           left_features.size()[3],
                                           left_features.size()[4])

        right_features = right_features.view(right_features.size()[0],
                                             right_features.size()[1] * right_features.size()[2],
                                             right_features.size()[3],
                                             right_features.size()[4])

        return left_features, right_features, one_hot_filter
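As a sanity check (an illustrative sketch, not part of the commit), the extractor maps a (B, 3, H, W) image to (B, 3 * filter_size**2, H, W) features, one RGB triplet per window position, with no learned weights:

import torch
from models.feature_extractor import feature_extractor

extractor = feature_extractor(filter_size=7)
left = torch.rand(1, 3, 64, 96)
right = torch.rand(1, 3, 64, 96)
left_feat, right_feat, one_hot = extractor(left, right)
print(left_feat.shape)  # torch.Size([1, 147, 64, 96]); 147 = 3 * 7 * 7
print(one_hot.shape)    # torch.Size([49, 1, 1, 7, 7])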
DifferentiablePatchMatch/models/image_reconstruction.py (124 additions, 0 deletions)
@@ -0,0 +1,124 @@
# ---------------------------------------------------------------------------
# DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
#
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Uber Non-Commercial License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at the root directory of this project.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Written by Shivam Duggal
# ---------------------------------------------------------------------------

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.patch_match import PatchMatch
from models.feature_extractor import feature_extractor
from models.config import config as args


class Reconstruct(nn.Module):
    def __init__(self, filter_size):
        super(Reconstruct, self).__init__()
        self.filter_size = filter_size

    def forward(self, right_input, offset_x, offset_y, x_coordinate, y_coordinate, neighbour_extraction_filter):
        """
        Reconstructs the left image using the NNF (represented by the offsets and the xy coordinates).
        We perform patch voting on the offset field before reconstruction in order to
        generate a smooth reconstruction.
        Args:
            :right_input: Right Image
            :offset_x: horizontal offsets of the NNF
            :offset_y: vertical offsets of the NNF
            :x_coordinate: X coordinate
            :y_coordinate: Y coordinate
            :neighbour_extraction_filter: one-hot filter used to gather neighbouring offsets
        Returns:
            :reconstruction: reconstructed left image
        """

        pad_size = self.filter_size // 2
        smooth_offset_x = nn.ReflectionPad2d(
            (pad_size, pad_size, pad_size, pad_size))(offset_x)
        smooth_offset_y = nn.ReflectionPad2d(
            (pad_size, pad_size, pad_size, pad_size))(offset_y)

        # Gather, for every pixel, the offsets proposed by its (filter_size * filter_size)
        # neighbours: one output channel per window position.
        smooth_offset_x = F.conv2d(smooth_offset_x,
                                   neighbour_extraction_filter,
                                   padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size]

        smooth_offset_y = F.conv2d(smooth_offset_y,
                                   neighbour_extraction_filter,
                                   padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size]

        # Absolute sampling coordinates, clamped to the image bounds.
        coord_x = torch.clamp(
            x_coordinate - smooth_offset_x,
            min=0,
            max=smooth_offset_x.size()[3] - 1)

        coord_y = torch.clamp(
            y_coordinate - smooth_offset_y,
            min=0,
            max=smooth_offset_x.size()[2] - 1)

        # Normalize pixel coordinates to [-1, 1], as expected by F.grid_sample.
        coord_x -= coord_x.size()[3] / 2
        coord_x /= (coord_x.size()[3] / 2)

        coord_y -= coord_y.size()[2] / 2
        coord_y /= (coord_y.size()[2] / 2)

        grid = torch.cat((coord_x.unsqueeze(4), coord_y.unsqueeze(4)), dim=4)
        grid = grid.view(grid.size()[0] * grid.size()[1], grid.size()[2], grid.size()[3], grid.size()[4])
        reconstruction = F.grid_sample(right_input.repeat(grid.size()[0], 1, 1, 1), grid)
        # Patch voting: average the filter_size**2 candidate reconstructions.
        reconstruction = torch.mean(reconstruction, dim=0).unsqueeze(0)

        return reconstruction


class ImageReconstruction(nn.Module):
    def __init__(self):
        super(ImageReconstruction, self).__init__()

        self.patch_match = PatchMatch(args.patch_match_args)

        filter_size = args.feature_extractor_filter_size
        self.feature_extractor = feature_extractor(filter_size)
        self.reconstruct = Reconstruct(filter_size)

    def forward(self, left_input, right_input):
        """
        ImageReconstruction:
        Description: This class reconstructs the left image from the data of the other image,
                     by finding correspondences (an NNF) between the two.
                     The images can be any pair of images with some overlap between the two
                     to assist the correspondence matching.
                     For the feature extractor, we just use the RGB features of a
                     (self.filter_size * self.filter_size) patch around each pixel.
                     For finding the correspondences, we use the Differentiable PatchMatch.
                     ** Note: There is no assumption of rectification between the two images. **
                     ** Note: The words 'left' and 'right' do not have any significance. **
        Args:
            :left_input: Left Image (Image 1)
            :right_input: Right Image (Image 2)
        Returns:
            :reconstruction: Reconstructed left image.
        """

        left_features, right_features, neighbour_extraction_filter = self.feature_extractor(left_input, right_input)
        offset_x, offset_y, x_coordinate, y_coordinate = self.patch_match(left_features, right_features)

        reconstruction = self.reconstruct(right_input,
                                          offset_x, offset_y,
                                          x_coordinate, y_coordinate,
                                          neighbour_extraction_filter.squeeze(1))

        return reconstruction
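A small worked example (not from the commit) of the coordinate normalization Reconstruct applies before F.grid_sample: a pixel coordinate x in [0, W-1] is mapped to (x - W/2) / (W/2), placing the image center at 0 and the borders near -1 and +1:

import torch

W = 8
coord_x = torch.arange(W).float()  # pixel coordinates 0..7
coord_x -= W / 2                   # center the coordinates at 0
coord_x /= (W / 2)                 # scale into grid_sample's [-1, 1] range
print(coord_x)  # tensor([-1.0000, -0.7500, -0.5000, -0.2500,  0.0000,  0.2500,  0.5000,  0.7500])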