# **Monocular Depth Estimation with Adaptive Geometric Attention**

This recently published work proposes an encoder-decoder model that exploits the similarity between the depth map and the corresponding single RGB image at geometric edges. Its attention module is built on cosine similarity - a concept borrowed from the field of Natural Language Processing.

# **Setup And Imports**

In [1]:
!pip install albumentations==0.4.6

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import os 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import cv2 
import PIL 

import torch
from torch import nn 
from torch import optim as O
from torch.nn import Module
from torch.nn import functional as f
from torch.nn.modules.activation import Sigmoid
from torch.nn.modules.upsampling import UpsamplingBilinear2d
from torch.nn.modules.conv import Conv2d
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader

import torchvision
from torchvision import utils as Vision_utils
from torchvision import models
from torchvision.transforms import functional as TF

from torchsummary import summary
  
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm 

from warnings import filterwarnings as IF

In [3]:
IF("ignore")

# **Loading the Dataset**

The dataset we are currently using is the NYU Depth V2 dataset, which was downloaded from Kaggle and stored on an external hard drive. There is a .csv file that has the path to each RGB image and its corresponding grayscale depth map.

In [4]:
## Mounting Google Drive inside the Colab runtime so the dataset paths below resolve
from google.colab import drive 
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
## Loading the .csv file 
## `names` supplies the two column labels, so every row of the file is read as data
## (column 0 = RGB image path, column 1 = depth-map path).
df = pd.read_csv("/content/drive/MyDrive/NYU-depthV2/nyu_data/data/nyu2_train.csv", names = ["image", "depth"])
df.head()

Unnamed: 0,image,depth
0,data/nyu2_train/living_room_0038_out/37.jpg,data/nyu2_train/living_room_0038_out/37.png
1,data/nyu2_train/living_room_0038_out/115.jpg,data/nyu2_train/living_room_0038_out/115.png
2,data/nyu2_train/living_room_0038_out/6.jpg,data/nyu2_train/living_room_0038_out/6.png
3,data/nyu2_train/living_room_0038_out/49.jpg,data/nyu2_train/living_room_0038_out/49.png
4,data/nyu2_train/living_room_0038_out/152.jpg,data/nyu2_train/living_room_0038_out/152.png


In [6]:
## Setting up the Path root 
## The paths stored in the .csv are relative, so the Drive folder is prepended
## to both columns to turn them into absolute paths.
path_root = "/content/drive/MyDrive/NYU-depthV2/nyu_data/"
df["image"] = df["image"].apply(lambda x : path_root + str(x))
df["depth"] = df["depth"].apply(lambda x : path_root + str(x))
df.head()

Unnamed: 0,image,depth
0,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...
1,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...
2,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...
3,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...
4,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...,/content/drive/MyDrive/NYU-depthV2/nyu_data/da...


In [7]:
## Splitting the training and validation dataset
## The validation slice is re-indexed from 0: a plain slice keeps the original
## row labels (40550, 40551, ...), so a label-based lookup such as `series[0]`
## would raise a KeyError on it.
train_df = df[0 : 40550]
val_df = df[40550 : ].reset_index(drop = True)
del df

# **Building the Data Pipeline**

1. The Pipeline takes the path dataframe and returns RGB images and corresponding Depth maps one batch at a time.
2. It will reshape and resize the images according to our hyperparameters as well.
3. To implement Image Augmentations we are going to use the Albumentations library which will automatically take care of the resizing part of the problem. 

In [8]:
class DepthDataset(Dataset) :
    """Dataset of (RGB image, depth map) pairs described by a table of file paths.

    Parameters
    ----------
    df : table with an "image" column (RGB file paths) and a "depth" column
        (depth-map file paths). It may be a slice of a larger dataframe; rows
        are always addressed by position.
    transform : optional callable invoked as ``transform(image=..., mask=...)``
        that returns a mapping with "image" and "mask" entries.

    Each item is a tuple ``(image, depth)``. Before the transform, the image is
    a float32 RGB array of shape (H, W, 3) and the depth a float32 array of
    shape (H, W, 1).
    """

    def __init__(self, df, transform = None) :

        # Re-index from 0 so that positional access works even when `df` is a
        # slice whose row labels do not start at 0 (label lookup would raise
        # KeyError for such a frame).
        self.img_dir = pd.Series(df["image"]).reset_index(drop = True)
        self.depth_dir = pd.Series(df["depth"]).reset_index(drop = True)
        self.transform = transform 

    def __len__(self) :
        return len(self.img_dir)

    def __getitem__(self, index) :
        img_path = self.img_dir.iloc[index]
        depth_path = self.depth_dir.iloc[index]

        image_ = cv2.imread(img_path)
        if image_ is None :
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read RGB image: {img_path}")
        image_ = cv2.cvtColor(image_, cv2.COLOR_BGR2RGB)
        image_ = np.array(image_, dtype = np.float32)

        depth_ = cv2.imread(depth_path)
        if depth_ is None :
            raise FileNotFoundError(f"Could not read depth map: {depth_path}")
        depth_ = cv2.cvtColor(depth_, cv2.COLOR_BGR2GRAY)
        depth_ = np.array(depth_, dtype = np.float32)
        # Add an explicit trailing channel axis: (H, W) -> (H, W, 1)
        depth_ = depth_[:, :, np.newaxis]

        if self.transform is not None :
            augmentations = self.transform(image = image_, mask = depth_)
            image_ = augmentations["image"]
            depth_ = augmentations["mask"]

        return image_, depth_



In [9]:
## Now we are going to use the DataLoader Functionality in the Pytorch Library 
## So we will define a function get_loaders that will return the Dataloader object 
## For both training and Validation dataset according to the Hyperparameters

def get_loaders(
    train_df, val_df, batch_size, train_transform, val_transform, num_workers = 4, pin_memory = True
) :
    """Build the training and validation DataLoaders.

    Both loaders share `batch_size`, `num_workers` and `pin_memory`; only the
    training loader shuffles its samples.

    Returns
    -------
    (train_loader, val_loader)
    """
    shared_options = dict(
        batch_size = batch_size, num_workers = num_workers, pin_memory = pin_memory
    )

    train_loader = DataLoader(
        DepthDataset(train_df, train_transform), shuffle = True, **shared_options
    )
    val_loader = DataLoader(
        DepthDataset(val_df, val_transform), shuffle = False, **shared_options
    )

    return train_loader, val_loader

# **Other Utility Functions**

In this section we implemented other important Utility Functions like saving model checkpoint, loading model checkpoint, saving predictions as images etc. 

In [10]:
def save_checkpoint(state, filename = "my_checkpoint.pth.tar") :
    """Announce the save on stdout, then serialize `state` to `filename` with torch.save."""
    print("=> Saving Checkpoint")
    torch.save(obj = state, f = filename)

def load_checkpoint(checkpoint, model) :
    """Announce the load on stdout, then copy `checkpoint["state_dict"]` into `model`."""
    print("=> Loading Checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

## To check accuracy we are going to use the DICE score

def check_accuracy(loader, model, device = "cuda") :
    """Print exact-match pixel accuracy and the mean per-batch Dice score.

    Parameters
    ----------
    loader : iterable of ``(inputs, targets)`` batches.
    model : module evaluated as ``model(inputs)``; it is switched to eval mode
        for the evaluation and back to train mode afterwards.
    device : device the batches are moved to.

    Notes
    -----
    * A target holding as many elements as the prediction is reshaped to the
      prediction's shape (this covers both ``[N, H, W]`` and channel-last
      ``[N, H, W, 1]`` targets against ``[N, 1, H, W]`` predictions). Any other
      target gets a channel axis inserted at position 1.
    * An empty loader reports zeros instead of dividing by zero.
    * NOTE(review): exact equality and Dice are segmentation-style metrics;
      for continuous depth values the exact-match accuracy is expected to be
      close to 0 - confirm these are the metrics you want.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0
    model.eval()

    try :
        with torch.no_grad() :
            for x, y in loader :
                x = x.to(device)
                y = y.to(device)
                preds = model(x)

                if y.numel() == preds.numel() :
                    y = y.reshape(preds.shape)
                else :
                    y = y.unsqueeze(1)

                num_correct += (preds == y).sum().item()
                num_pixels += torch.numel(preds)
                dice_total += (
                    (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
                ).item()
                num_batches += 1
    finally :
        # Always hand the model back in training mode, even if evaluation failed
        model.train()

    accuracy = num_correct / num_pixels * 100 if num_pixels else 0.0
    mean_dice = dice_total / num_batches if num_batches else 0.0

    print(
        f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}"
    )
    print(f"Dice score: {mean_dice}")
    return


"""

def save_predictions_as_imgs(loader, model, folder = "saved_images/", device = "cuda") :

    model.eval()
    for idx, (x, y) in enumerate(loader) :
        x = x.to(device = device)
        with torch.no_grad() :

"""


'\n\ndef save_predictions_as_imgs(loader, model, folder = "saved_images/", device = "cuda") :\n\n    model.eval()\n    for idx, (x, y) in enumerate(loader) :\n        x = x.to(device = device)\n        with torch.no_grad() :\n\n'

# **Preparing the Hyperparameters**

In [11]:
learning_rate = 3e-4  # learning rate handed to the Adam optimizer
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible
batch_size = 16  # samples per batch for both data loaders
num_epochs = 500  # number of passes over the training loader
Image_height = 224  # images and depth maps are resized to this height
Image_width = 224  # images and depth maps are resized to this width
pin_memory = True  # forwarded to the DataLoaders
load_model = False  # when True, training starts from "my_checkpoint.pth.tar"
num_workers = 2  # DataLoader worker processes

# **Building the Model**

Now we begin to create the model. The research paper defines a new module named Adaptive Geometric Attention (AGA), which will be created from scratch. For the encoder we will use a ResNeXt-50 (32x4d) pre-trained on ImageNet. We will also declare the Atrous Spatial Pyramid Pooling (ASPP) and Dilated Residual Block (DRB) modules and arrange the necessary skip connections to finally complete the model.

The Next Three Cells are used to declare the three special modules namely the Adaptive Geometric Attention (AGA Module), Atrous Spatial Pyramid Pooling (ASPP) and the Dilated Residual Block (DRB) 

In [12]:
class AGA_Block(Module) :
    """Adaptive Geometric Attention (AGA) block.

    Fuses a low-level and a high-level feature map of identical shape
    ``[N, C, H, W]`` using

    * a channel-wise attention vector ``[N, C, 1, 1]`` computed from both maps,
    * two spatial attention maps ``[N, 1, H, W]``, each the absolute cosine
      similarity (over channels) between 1x1-projected versions of the inputs.

    Output: ``f2(SA2) * CA * low_x + f1(SA1) * high_x`` with the sensitivity
    functions ``f1(x) = x`` and ``f2(x) = x * exp(x)``.

    Parameters
    ----------
    C : number of channels of each input map.
    H, W : spatial size of the inputs; used as the average-pooling window, so
        the channel attention is ``[N, C, 1, 1]`` only for inputs of exactly
        this size.

    NOTE(review): the high-level branch previously *added* SA1 to the features
    (``sa_1 + high_x``). Since f1 is the identity weighting and the low-level
    branch multiplies by its attention map, the branch now multiplies as well.
    Confirm against the paper's fusion equation.
    """

    def __init__(self, C, H, W) :
        super(AGA_Block, self).__init__()
        self.channel_wise_attention = nn.Sequential(
            nn.AvgPool2d((H, W), 1), 
            nn.Conv2d(2*C, C, (1, 1), (1, 1), padding = "same", bias = False), 
            nn.ReLU(),
            nn.Conv2d(C, C, (1, 1), (1, 1), padding = "same", bias = False), 
            nn.Sigmoid()
        )

        self.convLow_1 = nn.Conv2d(C, C, (1, 1), (1, 1), padding = "same", bias = False)
        self.convHigh_1 = nn.Conv2d(C, C, (1, 1), (1, 1), padding = "same", bias = False)

        self.convLow_2 = nn.Conv2d(C, C, (1, 1), (1, 1), padding = "same", bias = False)
        self.convHigh_2 = nn.Conv2d(C, C, (1, 1), (1, 1), padding = "same", bias = False)

    @staticmethod
    def _cosine_attention(low_proj, high_proj) :
        """Return |cosine similarity| over the channel axis as a [N, 1, H, W] map."""
        low_unit = f.normalize(low_proj, p = 2, dim = 1)
        high_unit = f.normalize(high_proj, p = 2, dim = 1)
        similarity = torch.sum(low_unit * high_unit, dim = 1, keepdim = True)
        return torch.abs(similarity)

    def forward(self, low_x, high_x) :

        ## Section 1 : Channel-wise attention, [N, 2C, H, W] -> [N, C, 1, 1]
        x_ca = self.channel_wise_attention(torch.cat((low_x, high_x), dim = 1))

        ## Section 2 and 3 : the two spatial attention maps, each [N, 1, H, W]
        sa_1 = self._cosine_attention(self.convLow_1(low_x), self.convHigh_1(high_x))
        sa_2 = self._cosine_attention(self.convLow_2(low_x), self.convHigh_2(high_x))

        ## Low-level branch : f2(x) = x * (e ^ x) enhances the sensitivity of SA-2
        final_l = sa_2 * torch.exp(sa_2) * low_x * x_ca

        ## High-level branch : f1(x) = x, so SA-1 weights the features directly
        final_h = sa_1 * high_x

        ## Final Summation
        return final_l + final_h

In [13]:
class ASPP_Block(Module) :
    """Atrous Spatial Pyramid Pooling over a ``[N, C, H, W]`` feature map.

    Five parallel branches, each producing ``C // 5`` channels, are
    concatenated and fused back to ``C`` channels by a 3x3 convolution:
    a 1x1 convolution, three 3x3 convolutions with dilation 3, 6 and 9, and an
    image-level branch (average pool with an (H, W) window, 1x1 convolution,
    bilinear upsampling back to (H, W)).
    """

    def __init__(self, C, H, W) :
        super(ASPP_Block, self).__init__()
        branch_channels = C // 5

        self.ImageLevelPooling = nn.Sequential(
            nn.AvgPool2d(kernel_size = (H, W), stride = 1),
            nn.Conv2d(C, branch_channels, kernel_size = (1, 1), stride = (1, 1), padding = "same", bias = False),
            nn.UpsamplingBilinear2d(size = (H, W)),
        )

        self.conv1 = nn.Conv2d(C, branch_channels, kernel_size = (1, 1), stride = (1, 1), padding = "same", bias = False)
        self.conv2 = nn.Conv2d(C, branch_channels, kernel_size = (3, 3), stride = (1, 1), padding = "same", dilation = 3, bias = False)
        self.conv3 = nn.Conv2d(C, branch_channels, kernel_size = (3, 3), stride = (1, 1), padding = "same", dilation = 6, bias = False)
        self.conv4 = nn.Conv2d(C, branch_channels, kernel_size = (3, 3), stride = (1, 1), padding = "same", dilation = 9, bias = False)
        self.final_conv = nn.Conv2d(branch_channels * 5, C, kernel_size = (3, 3), stride = (1, 1), padding = "same", bias = False)

    def forward(self, x) :
        # Channel order of the concatenation: conv1, conv2, conv3, conv4, image-level
        branches = [
            self.conv1(x),
            self.conv2(x),
            self.conv3(x),
            self.conv4(x),
            self.ImageLevelPooling(x),
        ]
        fused = torch.cat(branches, dim = 1)
        return self.final_conv(fused)
    


In [14]:
class DRB(Module) :
    """Dilated Residual Block (DRB) as used by this notebook.

    Five convolutions are applied one after another (a 1x1 convolution, then
    3x3 convolutions with dilation 2, 4, 8 and 16), every intermediate output
    is kept, and the five outputs are concatenated and fused back to ``C``
    channels with a 1x1 convolution. ``H`` and ``W`` are accepted for a
    uniform constructor signature but are not used.
    """

    def __init__(self, C, H, W) :
        super(DRB, self).__init__()
        # (kernel size, dilation) of conv1 .. conv5, in application order
        stage_specs = [((1, 1), 1), ((3, 3), 2), ((3, 3), 4), ((3, 3), 8), ((3, 3), 16)]
        for number, (kernel, rate) in enumerate(stage_specs, start = 1) :
            stage = nn.Conv2d(C, C, kernel, (1, 1), padding = "same", bias = False, dilation = rate)
            setattr(self, f"conv{number}", stage)

        self.final_conv = nn.Conv2d(5*C, C, (1, 1), (1, 1), padding = "same", bias = False)

    def forward(self, x) :
        stage_outputs = []
        current = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5) :
            current = stage(current)
            stage_outputs.append(current)

        stacked = torch.cat(stage_outputs, dim = 1)
        return self.final_conv(stacked)

In [15]:
class Decoder_Layer(Module) :
    """One decoder stage: AGA fusion, DRB, 1x1 projection and bilinear upsampling.

    ``forward(low_x, high_x)`` fuses two ``[N, C_in, H_in, W_in]`` maps with
    the AGA block and returns a map with ``C_out`` channels resized to
    ``(H_out, W_out)``.
    """

    def __init__(self, C_in, H_in, W_in, C_out, H_out, W_out) :
        super(Decoder_Layer, self).__init__()
        self.AGA = AGA_Block(C_in, H_in, W_in)

        refine = DRB(C_in, H_in, W_in)
        project = nn.Conv2d(C_in, C_out, (1, 1), (1, 1), padding = "same", bias = False)
        resize = nn.UpsamplingBilinear2d((H_out, W_out))
        self.block = nn.Sequential(refine, project, resize)

    def forward(self, low_x, high_x) :
        fused = self.AGA(low_x, high_x)
        return self.block(fused)


Now, we will create the final Model that will use the ResNeXt50 (32 X 4d) pre trained on the ImageNet Dataset as the encoder. This encoder is implemented and can be imported from the torchvision library.

In [16]:
## First, load the pretrained encoder and print it to explore the layers and the
## dimensions of the tensors they output.
## The `pretrained = True` argument initializes the network with pretrained weights
## (ImageNet weights, per the torchvision model zoo).
## NOTE(review): newer torchvision releases replace `pretrained=` with `weights=` - confirm
## against the installed version.

model = models.resnext50_32x4d(pretrained = True)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1

In [17]:
model._modules.keys()

odict_keys(['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc'])

In [18]:
## Names of the encoder blocks whose outputs are routed, through a DRB each,
## to the decoder as skip connections (listed in encoder order)
skip_connections = ['maxpool','layer1', 'layer2', 'layer3', 'layer4']

In [19]:
## In this cell, we store the output size of each layer that provides a residual / skip connection in a list.
## This is necessary for declaring the DRB and the AGA modules.
## A dummy batch is pushed through maxpool and layer1..layer4 only; its shape
## [5, 64, 112, 112] presumably stands for the stem output of a 224 x 224 input - TODO confirm.

x = torch.randn([5, 64, 112, 112])
a = model._modules["maxpool"](x)
b = model._modules["layer1"](a)
c = model._modules["layer2"](b)
d = model._modules["layer3"](c)
e = model._modules["layer4"](d)
## sizes_skip[0] is the dummy input itself; sizes_skip[1:] follow the skip_connections order
sizes_skip = [x.shape, a.shape, b.shape, c.shape, d.shape, e.shape]
## The probe tensors and the temporary encoder are no longer needed
del a, b, c, d, e, model 

In [20]:
class DepthEstimationModel(Module) :
    """Encoder-decoder depth estimator with AGA-based decoder stages.

    Encoder : torchvision ResNeXt-50 (32x4d) with frozen weights. The outputs
    of the blocks named in `skip_connections` are kept as skip features.
    Decoder : every skip feature goes through its own DRB, the deepest encoder
    output goes through ASPP, and five `Decoder_Layer` stages fuse them from
    the deepest to the shallowest level. A final section upsamples to
    224 x 224 and maps to one channel.

    Parameters
    ----------
    skip_connections : names of the encoder blocks to tap, in encoder order.
    sizes_skip : six 4-D shapes; entry 0 is the shape fed to the first tapped
        block, entries 1..5 are the output shapes of the tapped blocks.

    Fixes relative to the previous revision
    ---------------------------------------
    * `skip_DRB` and `Decoders` are `nn.ModuleList`s. As plain Python lists
      their parameters were invisible to `parameters()`, `state_dict()` and
      `.to()`, i.e. they were neither optimised nor saved. Checkpoints written
      before this change do not contain these weights.
    * The output activation is a sigmoid. A softmax over the single output
      channel yields 1.0 for every pixel and passes no gradient.
      NOTE(review): predictions now lie in (0, 1); the depth targets should be
      scaled to the same range - confirm.

    NOTE(review): the frozen encoder still follows the model's train/eval
    mode, so its BatchNorm running statistics keep updating during training -
    confirm whether that is intended.
    """

    def __init__(self, skip_connections, sizes_skip) :
        super(DepthEstimationModel, self).__init__()

        ## Constructor Instructions for the Encoder 
        self.pretrained_encoder = models.resnext50_32x4d(pretrained = True)

        ## The pre-trained encoder should not update its weights during back propagation
        for param in self.pretrained_encoder.parameters() :
            param.requires_grad = False

        self.skip_connections = skip_connections
        self.encoder_layers = self.pretrained_encoder._modules

        ## Constructor Instructions for the Decoder
        self.ASPP = ASPP_Block(2048, 7, 7)
        self.skip_DRB = nn.ModuleList(
            [DRB(sizes_skip[i + 1][1], sizes_skip[i + 1][2], sizes_skip[i + 1][3]) for i in range(5)]
        )
        ## Decoder i maps the shape of level (5 - i) to the shape of level (4 - i)
        self.Decoders = nn.ModuleList(
            [Decoder_Layer(*sizes_skip[i][1 : ], *sizes_skip[i - 1][1 :]) for i in range(5, 0, -1)]
        )

        ## Constructor of the Final Section 
        self.Final_Section = nn.Sequential(
            nn.UpsamplingBilinear2d((224, 224)), 
            nn.Conv2d(64, 1, (3, 3), (1, 1), padding = "same", bias = False),
            nn.Sigmoid()
        )

        ## Every sub-module is registered now, so a single move covers the whole model
        self.to(device)

    def forward(self, x) :
        skip_outputs = list()
        num_skips = len(self.skip_connections)

        ## Forward Propagation of the Encoder and saving the skip connections
        for key, layer in self.encoder_layers.items() :
            x = layer(x)

            if key == self.skip_connections[len(skip_outputs)] :
                skip_outputs.append(x)

                ## ResNeXt is a classifier; the layers after the last tapped
                ## block (pooling, fc) are not needed here
                if len(skip_outputs) >= num_skips :
                    break

        ## Passing the Skip connections through the DRB Module
        skip_outputs = [drb(feature) for drb, feature in zip(self.skip_DRB, skip_outputs)]

        ## Decoder : deepest level first, the skip feature is the "low" input
        x = self.ASPP(x)
        for decoder, feature in zip(self.Decoders, reversed(skip_outputs)) :
            x = decoder(feature, x)

        ## The Final Output 
        return self.Final_Section(x)

# **Building the Loss Functions**

The paper describes a three-term loss function: the Virtual Normal Loss, the Weighted Cross Entropy loss and finally the L2 loss. We are going to use the same code used by the authors of "Enforcing Geometric Constraints of Virtual Normal for Depth Prediction". The following is the GitHub link to the repository containing that code.

https://github.com/YvanYin/VNL_Monocular_Depth_Prediction/blob/master/lib/models/VNL_loss.py

In [21]:
class VNL_Loss(nn.Module):
    """
    Virtual Normal Loss Function.

    Depth maps are back-projected to 3D points with a pinhole model, random
    triplets of points are sampled, and the loss is the L1 distance between
    the unit normals of the triangles spanned by ground-truth points and by
    predicted points. Triplets that are nearly collinear, too close together
    or contain invalid depth are discarded using the ground-truth points.

    :param focal_x: focal length along x
    :param focal_y: focal length along y
    :param input_size: (height, width) of the depth maps
    :param delta_cos: |cosine| above which two triangle edges count as collinear
    :param delta_diff_x: stored "too close" threshold along x
    :param delta_diff_y: stored "too close" threshold along y
    :param delta_diff_z: stored "too close" threshold along z
    :param delta_z: minimum depth for a point to count as valid
    :param sample_ratio: number of sampled triplets as a fraction of H * W

    The constant tensors are registered as non-persistent buffers: they follow
    `.to()` / `.cuda()` and are not written to `state_dict()`. They are created
    on the GPU when one is available (as before) and on the CPU otherwise.

    NOTE(review): `select_points_groups` passes fixed thresholds
    (0.867 / 0.005) to `filter_mask`; the `delta_cos` / `delta_diff_*` values
    given to the constructor are stored but not used there.
    """
    def __init__(self, focal_x, focal_y, input_size,
                 delta_cos=0.867, delta_diff_x=0.01,
                 delta_diff_y=0.01, delta_diff_z=0.01,
                 delta_z=0.0001, sample_ratio=0.15):
        super(VNL_Loss, self).__init__()
        start_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.register_buffer(
            "fx", torch.tensor([focal_x], dtype=torch.float32, device=start_device), persistent=False)
        self.register_buffer(
            "fy", torch.tensor([focal_y], dtype=torch.float32, device=start_device), persistent=False)
        self.input_size = input_size
        # Principal point: the image centre
        self.register_buffer(
            "u0", torch.tensor(input_size[1] // 2, dtype=torch.float32, device=start_device), persistent=False)
        self.register_buffer(
            "v0", torch.tensor(input_size[0] // 2, dtype=torch.float32, device=start_device), persistent=False)
        self.init_image_coor()
        self.delta_cos = delta_cos
        self.delta_diff_x = delta_diff_x
        self.delta_diff_y = delta_diff_y
        self.delta_diff_z = delta_diff_z
        self.delta_z = delta_z
        self.sample_ratio = sample_ratio

    def init_image_coor(self):
        """Precompute the pixel offsets from the principal point, each of shape [1, H, W]."""
        x_row = np.arange(0, self.input_size[1])
        x = np.tile(x_row, (self.input_size[0], 1))
        x = x[np.newaxis, :, :]
        x = x.astype(np.float32)
        x = torch.from_numpy(x.copy()).to(self.u0.device)
        self.register_buffer("u_u0", x - self.u0, persistent=False)

        y_col = np.arange(0, self.input_size[0])  # y_col = np.arange(0, height)
        y = np.tile(y_col, (self.input_size[1], 1)).T
        y = y[np.newaxis, :, :]
        y = y.astype(np.float32)
        y = torch.from_numpy(y.copy()).to(self.v0.device)
        self.register_buffer("v_v0", y - self.v0, persistent=False)

    def transfer_xyz(self, depth):
        """Back-project a depth map [B, 1, H, W] to 3D points [B, H, W, 3]."""
        x = self.u_u0 * torch.abs(depth) / self.fx
        y = self.v_v0 * torch.abs(depth) / self.fy
        z = depth
        pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
        return pw

    def select_index(self):
        """Sample three sets of random pixel positions and return their x / y index arrays."""
        valid_width = self.input_size[1]
        valid_height = self.input_size[0]
        num = valid_width * valid_height
        p1 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
        np.random.shuffle(p1)
        p2 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
        np.random.shuffle(p2)
        p3 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
        np.random.shuffle(p3)

        # Flat index -> (column, row). Floor division keeps integer arrays;
        # the former `.astype(np.int)` fails on NumPy >= 1.24 where the
        # `np.int` alias no longer exists.
        p1_x = p1 % self.input_size[1]
        p1_y = p1 // self.input_size[1]

        p2_x = p2 % self.input_size[1]
        p2_y = p2 // self.input_size[1]

        p3_x = p3 % self.input_size[1]
        p3_y = p3 // self.input_size[1]
        p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y}
        return p123

    def form_pw_groups(self, p123, pw):
        """
        Form 3D points groups, with 3 points in each group.
        :param p123: points index
        :param pw: 3D points
        :return: tensor [B, N, 3(x,y,z), 3(p1,p2,p3)]
        """
        p1_x = p123['p1_x']
        p1_y = p123['p1_y']
        p2_x = p123['p2_x']
        p2_y = p123['p2_y']
        p3_x = p123['p3_x']
        p3_y = p123['p3_y']

        pw1 = pw[:, p1_y, p1_x, :]
        pw2 = pw[:, p2_y, p2_x, :]
        pw3 = pw[:, p3_y, p3_x, :]
        # [B, N, 3(x,y,z), 3(p1,p2,p3)]
        pw_groups = torch.cat([pw1[:, :, :, np.newaxis], pw2[:, :, :, np.newaxis], pw3[:, :, :, np.newaxis]], 3)
        return pw_groups

    def filter_mask(self, p123, gt_xyz, delta_cos=0.867,
                    delta_diff_x=0.005,
                    delta_diff_y=0.005,
                    delta_diff_z=0.005):
        """Return (mask [B, N] of usable point groups, ground-truth point groups [B, N, 3, 3])."""
        pw = self.form_pw_groups(p123, gt_xyz)
        pw12 = pw[:, :, :, 1] - pw[:, :, :, 0]
        pw13 = pw[:, :, :, 2] - pw[:, :, :, 0]
        pw23 = pw[:, :, :, 2] - pw[:, :, :, 1]
        ### ignore (nearly) collinear groups
        pw_diff = torch.cat([pw12[:, :, :, np.newaxis], pw13[:, :, :, np.newaxis], pw23[:, :, :, np.newaxis]],
                            3)  # [b, n, 3, 3]
        m_batchsize, groups, coords, index = pw_diff.shape
        proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1)  # (B* X CX(3)) [bn, 3(p123), 3(xyz)]
        proj_key = pw_diff.view(m_batchsize * groups, -1, index)  # B X  (3)*C [bn, 3(xyz), 3(p123)]
        q_norm = proj_query.norm(2, dim=2)
        nm = torch.bmm(q_norm.view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index)) #[]
        energy = torch.bmm(proj_query, proj_key)  # transpose check [bn, 3(p123), 3(p123)]
        norm_energy = energy / (nm + 1e-8)
        norm_energy = norm_energy.view(m_batchsize * groups, -1)
        # The 3 diagonal entries are always 1, so more than 3 hits means two
        # different edges are (anti-)parallel within the tolerance -> ignore
        mask_cos = torch.sum((norm_energy > delta_cos) + (norm_energy < -delta_cos), 1) > 3  # ignore
        mask_cos = mask_cos.view(m_batchsize, groups)
        ## ignore padding and invalid depth
        mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3

        ### ignore groups whose points are too close to each other
        mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0
        mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0
        mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0

        mask_ignore = (mask_x & mask_y & mask_z) | mask_cos
        mask_near = ~mask_ignore
        mask = mask_pad & mask_near

        return mask, pw

    def select_points_groups(self, gt_depth, pred_depth):
        """Sample point groups and return the usable ones as (gt, pred), each [1, M, 3, 3]."""
        pw_gt = self.transfer_xyz(gt_depth)
        pw_pred = self.transfer_xyz(pred_depth)
        B, C, H, W = gt_depth.shape
        p123 = self.select_index()
        # mask:[b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)]
        mask, pw_groups_gt = self.filter_mask(p123, pw_gt,
                                              delta_cos=0.867,
                                              delta_diff_x=0.005,
                                              delta_diff_y=0.005,
                                              delta_diff_z=0.005)

        # [b, n, 3, 3]
        pw_groups_pred = self.form_pw_groups(p123, pw_pred)
        # Replace exact-zero predicted depth by a small positive value
        pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] = 0.0001
        mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2)
        pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3)
        pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3)

        return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore

    def forward(self, gt_depth, pred_depth, select=True):
        """
        Virtual normal loss.
        :param gt_depth: ground truth depth, [B, 1, H, W]; it decides which point groups are valid
        :param pred_depth: predicted depth map, [B, 1, H, W]
        :param select: when True, the 25% smallest per-group losses are dropped before averaging
        :return: scalar loss tensor
        """
        gt_points, dt_points = self.select_points_groups(gt_depth, pred_depth)

        gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0]
        gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0]
        dt_p12 = dt_points[:, :, :, 1] - dt_points[:, :, :, 0]
        dt_p13 = dt_points[:, :, :, 2] - dt_points[:, :, :, 0]

        gt_normal = torch.cross(gt_p12, gt_p13, dim=2)
        dt_normal = torch.cross(dt_p12, dt_p13, dim=2)
        dt_norm = torch.norm(dt_normal, 2, dim=2, keepdim=True)
        gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True)
        # Zero-length normals get a norm of 0.01 to avoid a division by zero
        dt_mask = dt_norm == 0.0
        gt_mask = gt_norm == 0.0
        dt_mask = dt_mask.to(torch.float32)
        gt_mask = gt_mask.to(torch.float32)
        dt_mask *= 0.01
        gt_mask *= 0.01
        gt_norm = gt_norm + gt_mask
        dt_norm = dt_norm + dt_mask
        gt_normal = gt_normal / gt_norm
        dt_normal = dt_normal / dt_norm
        loss = torch.abs(gt_normal - dt_normal)
        loss = torch.sum(torch.sum(loss, dim=2), dim=0)
        if select:
            loss, indices = torch.sort(loss, dim=0, descending=False)
            loss = loss[int(loss.size(0) * 0.25):]
        loss = torch.mean(loss)
        return loss

Combining the Loss Functions and developing a class 

In [22]:
class TrainingLoss(Module) :
    """Combined training loss: ``CE + alpha * VNL + beta * L2``.

    Parameters
    ----------
    alpha : weight of the Virtual Normal Loss term.
    beta : weight of the L2 (mean squared error) term.
    focal_x, focal_y, input_size : forwarded to `VNL_Loss`; the defaults
        reproduce the previously hard-coded ``VNL_Loss(1.0, 1.0, (224, 224))``.

    ``forward(pred, target)`` expects both tensors with the same shape and
    returns a scalar tensor.

    NOTE(review): `nn.CrossEntropyLoss` is applied directly to the depth
    prediction. With a single-channel prediction the class softmax is constant,
    so this term is expected to contribute (close to) nothing - confirm and
    consider a real weighted cross-entropy over depth bins.
    """

    def __init__(self, alpha = 6, beta = 25, focal_x = 1.0, focal_y = 1.0, input_size = (224, 224)) :

        super(TrainingLoss, self).__init__()
        self.alpha = alpha 
        self.beta = beta
        self.VirtualNormalLoss = VNL_Loss(focal_x, focal_y, input_size)
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.l2_loss = nn.MSELoss()

    def forward(self, pred, target) :
        # VNL_Loss.forward takes (gt_depth, pred_depth): the ground truth must
        # come first because it decides which point groups are valid.
        virtual_normal = self.VirtualNormalLoss(target, pred)
        cross_entropy = self.CrossEntropyLoss(pred, target)
        l2 = self.l2_loss(pred, target)

        # The terms already live on the device of their inputs; no transfer needed
        return cross_entropy + self.alpha * virtual_normal + self.beta * l2


# **Defining the Main Image Augmentations and specific training functions**

In [23]:
## Image Augmentation using Albumentations

## Training pipeline: resize, random geometric augmentation, scaling, tensor conversion.
## The same geometric transforms are applied to the image and to the depth map passed as `mask`.
train_transform = A.Compose(
    [
        A.Resize(Image_height, Image_width),
        A.Rotate(limit = 35, p = 1.0),  ## always rotate, by a random angle within +/- 35 degrees
        A.HorizontalFlip(p = 0.5),
        A.VerticalFlip(p = 0.1),
        ## mean 0 / std 1 with max_pixel_value 255 simply rescales pixel values by 1/255
        ## NOTE(review): Normalize presumably affects the image only, leaving the depth map
        ## in its original value range - confirm against the albumentations docs
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2(),  
    ]
)

## Validation pipeline: deterministic - resize, scaling and tensor conversion only
val_transforms = A.Compose(
    [
        A.Resize(Image_height, Image_width), 
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2(),
    ]
)

In [24]:
## Defining the train Function

def train_fn(loader, model, optimizer, loss_fn, scaler) :
    """Run one training epoch with automatic mixed precision.

    Parameters
    ----------
    loader : iterable of ``(data, targets)`` batches.
    model : network called as ``model(data)``.
    optimizer : optimizer stepped once per batch through `scaler`.
    loss_fn : callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
    scaler : gradient scaler providing ``scale`` / ``step`` / ``update``.

    Returns
    -------
    float : mean loss over the batches (0.0 for an empty loader).

    Targets are brought to ``[N, 1, H, W]``. Only the singleton channel axis of
    a 4-D target is removed (trailing axis for channel-last targets, otherwise
    axis 1); a bare ``squeeze()`` would also drop the batch axis of a batch
    that holds a single sample.
    """
    loop = tqdm(loader)
    running_loss = 0.0
    num_batches = 0

    for data, targets in loop :
        data = data.to(device)

        targets = targets.float()
        if targets.dim() == 4 :
            channel_axis = -1 if targets.shape[-1] == 1 else 1
            targets = targets.squeeze(channel_axis)
        targets = targets.unsqueeze(1).to(device)

        ## Forward
        with torch.cuda.amp.autocast() :
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        ## Backwards 
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        ## Update the tqdm loader
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss = batch_loss)

    return running_loss / num_batches if num_batches else 0.0



In [25]:
class TrainingModel() :
    """Bundles model, loss, optimizer, LR scheduler and data loaders for training.

    Parameters
    ----------
    skip_connections, sizes_skip : forwarded to `DepthEstimationModel`.
    learning_rate : Adam learning rate.
    num_epochs : number of epochs run by `Training`.
    alpha, beta : loss weights forwarded to `TrainingLoss`.
    device : device the model and the loss are moved to.

    The data loaders are built from the notebook-level `train_df`, `val_df`,
    `batch_size`, `train_transform`, `val_transforms`, `num_workers` and
    `pin_memory`.
    """

    def __init__(
        self, skip_connections, sizes_skip, learning_rate, num_epochs, alpha = 6, beta = 25, device = "cuda"
        ) :

        self.model = DepthEstimationModel(skip_connections, sizes_skip).to(device)
        self.loss_fn = TrainingLoss(alpha, beta).to(device)
        self.optimizer = O.Adam(self.model.parameters(), lr = learning_rate)
        ## Linear warm-up of the learning rate from 0.33 * lr up to lr
        self.scheduler = LinearLR(self.optimizer, start_factor = 0.33, end_factor = 1)

        self.train_loader, self.val_loader = get_loaders(
            train_df, val_df, batch_size, train_transform, val_transforms,
            num_workers, pin_memory 
        )
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.device = device

    def Training(self, load_model) :
        """Train for `num_epochs` epochs, writing a checkpoint after every epoch.

        When `load_model` is true, "my_checkpoint.pth.tar" is loaded first:
        the model weights always, and the optimizer state when the checkpoint
        contains one (checkpoints written by this class do).
        """

        if load_model :
            ## map_location lets a checkpoint saved on a GPU be resumed on another device
            checkpoint = torch.load("my_checkpoint.pth.tar", map_location = self.device)
            load_checkpoint(checkpoint, self.model)
            if "optimizer" in checkpoint :
                self.optimizer.load_state_dict(checkpoint["optimizer"])

        scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.num_epochs) :
            train_fn(self.train_loader, self.model, self.optimizer, self.loss_fn, scaler)
            self.scheduler.step()

            ## Saving the Model
            checkpoint = {
                "state_dict" : self.model.state_dict(), 
                "optimizer" : self.optimizer.state_dict()   
            }

            save_checkpoint(checkpoint)

            ## Validation is currently disabled; re-enable when needed:
            #check_accuracy(self.val_loader, self.model, self.device) 

In [None]:
## Final Training Call
## Builds the model, loss, optimizer and data loaders from the hyperparameters above,
## then trains; `load_model` decides whether to start from the saved checkpoint.
Trainer = TrainingModel(skip_connections, sizes_skip, learning_rate, num_epochs, device = device)
Trainer.Training(load_model)

  3%|▎         | 78/2535 [12:12<5:09:24,  7.56s/it, loss=2.11e+5]