In [1]:
import sys
sys.path.append('/HighResMDE/src')

from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import tqdm
import csv
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torchvision import transforms

from model import Model, ModelConfig
from dataloader.BaseDataloader import BaseImageDataset
from dataloader.NYUDataloader import NYUImageData
from layers.DN_to_distance import DN_to_distance
from layers.depth_to_normal import Depth2Normal
from loss import silog_loss, get_metrics
from segmentation import compute_seg, get_smooth_ND

torch.manual_seed(42)

<torch._C.Generator at 0x7f94b85e7510>

In [2]:
BATCH_SIZE = 2
local_rank = "cuda:0"

train_dataset = BaseImageDataset('train', NYUImageData, '/scratchdata/nyu_data/', '/HighResMDE/src/nyu_train.csv')
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True)

In [3]:
config =  ModelConfig("tiny07")
config.batch_size = BATCH_SIZE
config.height = 480//4
config.width = 640//4
model = Model(config).to(local_rank)
model.backbone.backbone.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

Swinv2Backbone(
  (embeddings): Swinv2Embeddings(
    (patch_embeddings): Swinv2PatchEmbeddings(
      (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    )
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Swinv2Encoder(
    (layers): ModuleList(
      (0): Swinv2Stage(
        (blocks): ModuleList(
          (0-1): 2 x Swinv2Layer(
            (attention): Swinv2Attention(
              (self): Swinv2SelfAttention(
                (continuous_position_bias_mlp): Sequential(
                  (0): Linear(in_features=2, out_features=512, bias=True)
                  (1): ReLU(inplace=True)
                  (2): Linear(in_features=512, out_features=3, bias=False)
                )
                (query): Linear(in_features=96, out_features=96, bias=True)
                (key): Linear(in_features=96, out_features=96, bias=False)
                (value): Linear(in_features=96, out_features=96, bi

In [None]:
silog_criterion = silog_loss(variance_focus=0.85)
dn_to_distance = DN_to_distance(config.batch_size, config.height * 4, config.width * 4).to(local_rank)
normal_estimation = Depth2Normal()
blur = transforms.GaussianBlur(kernel_size=5)
loop = tqdm.tqdm(train_dataloader, desc=f"Epoch {0+1}", unit="batch")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

store = []

for itr, x in enumerate(loop):
    if itr<96: continue
    optimizer.zero_grad()
    for k in x.keys():
        x[k] = x[k].to(local_rank)

    d1_list, u1, d2_list, u2, norm_est, dist_est = model(x)
    
    print

    # Estimate GT normal and distance

    depth_gt = x["depth_values"] * x["max_depth"].view(-1, 1, 1, 1)
    normal_gt, x["mask"] = normal_estimation(depth_gt, x["camera_intrinsics"], x["mask"], 1.0) # TODO: Figure out what scale does
    normal_gt = torch.stack([blur(each_normal) for each_normal in normal_gt])
    normal_gt = F.normalize(normal_gt, dim=1, p=2)
    dist_gt = dn_to_distance(depth_gt, normal_gt, x["camera_intrinsics_inverted"])

    # Depth Loss

    loss_depth1_0 = silog_criterion(d1_list[0], depth_gt, x["mask"])
    loss_depth2_0 = silog_criterion(d2_list[0], depth_gt, x["mask"])

    loss_depth1 = 0
    loss_depth2 = 0
    weights_sum = 0
    for i in range(len(d1_list) - 1):
        loss_depth1 += (0.85**(len(d1_list)-i-2)) * silog_criterion(d1_list[i + 1], depth_gt, x["mask"])
        loss_depth2 += (0.85**(len(d2_list)-i-2)) * silog_criterion(d2_list[i + 1], depth_gt, x["mask"])
        weights_sum += 0.85**(len(d1_list)-i-2)
    
    loss_depth = 10 * ((loss_depth1 + loss_depth2) / weights_sum + loss_depth1_0 + loss_depth2_0 )
    
    # Uncertainty Loss

    uncer1_gt = torch.exp(-5 * torch.abs(depth_gt - d1_list[0].detach()) / (depth_gt + d1_list[0].detach() + 1e-7))
    uncer2_gt = torch.exp(-5 * torch.abs(depth_gt - d2_list[0].detach()) / (depth_gt + d2_list[0].detach() + 1e-7))
    
    loss_uncer1 = torch.abs(u1-uncer1_gt)[x["mask"]].mean()
    loss_uncer2 = torch.abs(u2-uncer2_gt)[x["mask"]].mean()

    loss_uncer = loss_uncer1 + loss_uncer2

    loss_normal = 5 * ((1 - (normal_gt * norm_est).sum(1, keepdim=True))[x["mask"]]).mean() #* x["mask"]).sum() / (x["mask"] + 1e-7).sum()
    loss_distance = 0.25 * torch.abs(dist_gt- dist_est)[x["mask"]].mean()

    # Segmentation Loss
    #segment, planar_mask, dissimilarity_map = compute_seg(x["pixel_values"], norm_est, dist_est[:, 0])
    #loss_grad_normal, loss_grad_distance = get_smooth_ND(norm_est, dist_est, planar_mask)

    #loss_seg = 0.01 * (loss_grad_distance + loss_grad_normal)

    loss = loss_depth + loss_uncer + loss_normal + loss_distance #+ loss_seg
    loss = loss.mean()

    if loss.isnan():
        print("Loss is NaN")
        store.append(x, print(len(x)))
        continue

Epoch 1:   0%|          | 0/25344 [00:00<?, ?batch/s]

Epoch 1:  27%|██▋       | 6729/25344 [20:10<57:28,  5.40batch/s]  