Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/windows conda support #61

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
17 changes: 17 additions & 0 deletions MiDaS/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn


class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.

Args:
path (str): file path
"""
parameters = torch.load(path)

if "optimizer" in parameters:
parameters = parameters["model"]

self.load_state_dict(parameters)
153 changes: 153 additions & 0 deletions MiDaS/blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn


def _make_encoder(features, use_pretrained):
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features)

return pretrained, scratch


def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)

pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4

return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)


def _make_scratch(in_shape, out_shape):
scratch = nn.Module()

scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
return scratch


class Interpolate(nn.Module):
"""Interpolation module.
"""

def __init__(self, scale_factor, mode):
"""Init.

Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()

self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode

def forward(self, x):
"""Forward pass.

Args:
x (tensor): input

Returns:
tensor: interpolated data
"""

x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False
)

return x


class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""

def __init__(self, features):
"""Init.

Args:
features (int): number of features
"""
super().__init__()

self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)

self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)

self.relu = nn.ReLU(inplace=True)

def forward(self, x):
"""Forward pass.

Args:
x (tensor): input

Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)

return out + x


class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""

def __init__(self, features):
"""Init.

Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()

self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)

def forward(self, *xs):
"""Forward pass.

Returns:
tensor: output
"""
output = xs[0]

if len(xs) == 2:
output += self.resConfUnit1(xs[1])

output = self.resConfUnit2(output)

output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)

return output
76 changes: 76 additions & 0 deletions MiDaS/midas_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from MiDaS.base_model import BaseModel
from MiDaS.blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""

def __init__(self, path=None, features=256, non_negative=True):
"""Init.

Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print("Loading weights: ", path)

super(MidasNet, self).__init__()

use_pretrained = False if path is None else True

self.pretrained, self.scratch = _make_encoder(features, use_pretrained)

self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)

self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)

if path:
self.load(path)

def forward(self, x):
"""Forward pass.

Args:
x (tensor): input data (image)

Returns:
tensor: depth
"""

layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)

layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)

path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

out = self.scratch.output_conv(path_1)

return torch.squeeze(out, dim=1)
17 changes: 17 additions & 0 deletions MiDaS/models/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn


class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.

Args:
path (str): file path
"""
parameters = torch.load(path)

if "optimizer" in parameters:
parameters = parameters["model"]

self.load_state_dict(parameters)
Loading