### About
- This notebook is an implementation of the following paper.
    - [Depth Map Prediction from a Single Image using a Multi-Scale Deep Network](https://arxiv.org/pdf/1406.2283.pdf) by Eigen et al., NIPS 2014
- Dataset
    - [NYU Depth Dataset V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)

In [7]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Function
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

### Device configuration

In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

### Hyperparameters

In [4]:
# Training schedule.
num_epochs = 20  # not specified in paper
batch_size = 32
momentum = 0.9

# SGD learning rates, one per layer group (group meaning inferred from the
# names: coarse conv layers, coarse fully-connected layers, fine layers 1 and 3,
# fine layer 2 -- confirm against the training cells).
lr_coarse_conv, lr_coarse_fc = 0.001, 0.1
lr_fine_1_3, lr_fine_2 = 0.001, 0.01

### Dataset loader

### Network definition

In [21]:
class GlobalCoarseNet(nn.Module):
    """Coarse-scale network: five conv stages followed by two fully-connected layers.

    Maps an RGB batch of shape (N, 3, H, W) to a flat depth prediction of shape
    (N, 74 * 55), i.e. one value per pixel of a 55 x 74 (H x W) depth map.

    NOTE(review): the layer sizes assume a 228 x 304 (H x W) input, the input
    size recalled from the paper; the dataset loader is not visible here, so
    confirm. With that input the conv stack yields a 256 x 6 x 8 map, which is
    what ``coarse6`` (in_features = 256 * 8 * 6) expects. The padding on
    coarse2-4 and the stride on coarse5 were chosen to produce exactly that
    size; without them the same input gives 256 x 5 x 10 and the first linear
    layer cannot be applied.

    Hidden layers use ReLU and dropout is applied after ``coarse6``; the output
    layer ``coarse7`` is left linear because it regresses depth directly.
    """

    def __init__(self):
        super(GlobalCoarseNet, self).__init__()

        # The Conv2d stays at index 0 of every Sequential so parameter names
        # (e.g. ``coarse1.0.weight``) do not change.
        self.coarse1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
                                     nn.ReLU(inplace=True),
                                     nn.MaxPool2d(kernel_size=2))
        self.coarse2 = nn.Sequential(nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
                                     nn.ReLU(inplace=True),
                                     nn.MaxPool2d(kernel_size=2))
        self.coarse3 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
                                     nn.ReLU(inplace=True))
        self.coarse4 = nn.Sequential(nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
                                     nn.ReLU(inplace=True))
        self.coarse5 = nn.Sequential(nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=2),
                                     nn.ReLU(inplace=True))
        self.coarse6 = nn.Linear(in_features=256 * 8 * 6, out_features=4096)
        self.dropout = nn.Dropout(p=0.5)
        self.coarse7 = nn.Linear(in_features=4096, out_features=74 * 55)

    def forward(self, x):
        """Return the flat coarse depth prediction, shape (N, 74 * 55), for image batch ``x``."""
        x = self.coarse1(x)
        x = self.coarse2(x)
        x = self.coarse3(x)
        x = self.coarse4(x)
        x = self.coarse5(x)

        # nn.Linear acts on the last dimension only, so the (N, C, H, W) feature
        # map must be collapsed to (N, C * H * W) before the dense layers.
        x = x.view(x.size(0), -1)
        x = self.dropout(torch.relu(self.coarse6(x)))
        x = self.coarse7(x)

        return x

class LocalFineNet(nn.Module):
    """Fine-scale network that refines the coarse depth map using local image detail.

    ``fine1`` extracts 63 feature maps from the RGB input; the coarse prediction
    is concatenated as a 64th channel (the "fine2" step); ``fine3`` and
    ``fine4`` then produce a single-channel depth map.

    The 5x5 convolutions are zero-padded (padding=2) so the output keeps the
    spatial size of the coarse map; unpadded, each one would trim 4 pixels per
    dimension and the result would no longer align with the coarse prediction.
    Hidden layers use ReLU; ``fine4`` is linear because it regresses depth.

    NOTE(review): for a 228 x 304 (H x W) input ``fine1`` yields 55 x 74 maps,
    matching the 74 * 55 coarse output -- confirm the input size against the
    dataset loader, which is not visible here.
    """

    def __init__(self):
        super(LocalFineNet, self).__init__()

        self.fine1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=63, kernel_size=9, stride=2),
                                   nn.ReLU(inplace=True),
                                   nn.MaxPool2d(kernel_size=2))

        # fine2 (concatenation with the coarse output) is implemented in forward

        self.fine3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, padding=2)
        self.fine4 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=5, padding=2)

    def forward(self, x, global_output_batch):
        """Refine a coarse depth prediction.

        Args:
            x: image batch, shape (N, 3, H, W).
            global_output_batch: coarse prediction for the same batch, either
                already shaped (N, 1, h, w) or in any layout with N * h * w
                elements (e.g. flat (N, h * w)), where (h, w) is the spatial
                size of the ``fine1`` output.

        Returns:
            Tensor of shape (N, 1, h, w).
        """
        x = self.fine1(x)

        # A coarse map that is not 4-D (e.g. the flat output of a linear layer)
        # cannot be concatenated on the channel axis; reshape it to one channel
        # at the fine1 resolution. reshape raises if the element count differs.
        if global_output_batch.dim() != 4:
            global_output_batch = global_output_batch.reshape(x.size(0), 1, x.size(2), x.size(3))

        x = torch.cat((x, global_output_batch), dim=1)
        x = torch.relu(self.fine3(x))
        x = self.fine4(x)

        return x

# Build the coarse network first, then the fine one, and place each on `device`.
global_model, local_model = (net_cls().to(device)
                             for net_cls in (GlobalCoarseNet, LocalFineNet))

### Loss and optimizer

In [24]:
class ScaleInvariantLoss(Function):
    """Scale-invariant log-depth error with lambda = 0.5.

    With d_i = log(pred_i) - log(target_i) over all n elements of the inputs:

        loss = mean(d_i ** 2) - lambda / n ** 2 * (sum(d_i)) ** 2

    Use as ``ScaleInvariantLoss.apply(pred, target)``; the result is a 0-dim
    tensor on the same device as the inputs.

    The computation stays in torch (no NumPy round trip), so it works on CUDA
    tensors and on tensors that require grad, and ``backward`` returns the
    analytic gradient of the expression above.

    NOTE(review): all elements of the batch are pooled into one sum, as in the
    original formula here; confirm whether a per-image reduction is wanted.
    Both inputs must be strictly positive, otherwise the logs are not finite.
    """

    # the lambda parameter is set to 0.5
    _LAMBDA = 0.5

    @staticmethod
    def forward(ctx, pred, target):
        """Return the scalar loss for same-shaped ``pred`` and ``target``.

        Raises:
            ValueError: if the shapes differ (silent broadcasting would compute
                the loss over unintended element pairs).
        """
        if pred.shape != target.shape:
            raise ValueError('pred and target must have the same shape, got {} and {}'
                             .format(tuple(pred.shape), tuple(target.shape)))

        ctx.save_for_backward(pred, target)

        dist = torch.log(pred) - torch.log(target)
        n = dist.numel()

        return (dist ** 2).sum() / n - ScaleInvariantLoss._LAMBDA * dist.sum() ** 2 / (n ** 2)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients of the loss w.r.t. ``pred`` and ``target``."""
        pred, target = ctx.saved_tensors

        dist = torch.log(pred) - torch.log(target)
        n = dist.numel()

        # d(loss)/d(d_i) = 2 * d_i / n - 2 * lambda * sum(d) / n ** 2
        grad_dist = grad_output * (2.0 * dist / n
                                   - 2.0 * ScaleInvariantLoss._LAMBDA * dist.sum() / (n ** 2))

        # Chain rule through the logs: d(d_i)/d(pred_i) = 1 / pred_i and
        # d(d_i)/d(target_i) = -1 / target_i. Skip inputs that need no gradient.
        grad_pred = grad_dist / pred if ctx.needs_input_grad[0] else None
        grad_target = -grad_dist / target if ctx.needs_input_grad[1] else None

        return grad_pred, grad_target

# try this lr first: one learning rate (lr_coarse_conv) for every coarse-net
# parameter. Momentum comes from the `momentum` hyperparameter defined with the
# other hyperparameters instead of a duplicated literal, so changing it there
# takes effect here.
global_optimizer = torch.optim.SGD(params=global_model.parameters(),
                                   lr=lr_coarse_conv, momentum=momentum)

### Train global model