<a href="https://colab.research.google.com/github/yujunyoung1107/GLPDepth_colab/blob/main/GLPDepth_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorboardX
!pip install mmcv
!pip install timm
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
!git clone https://github.com/yujunyoung1107/GLPDepth_new.git

In [None]:
%cd /content/GLPDepth_new/datasets/
!unzip /content/gdrive/MyDrive/nyu_depth_v2.zip

In [None]:
%cd /content/GLPDepth_new/code/models/weights/
!cp /content/gdrive/MyDrive/mit_b4.pth mit_bt.pth
%cd /content/GLPDepth_new/code/

In [None]:
import os
import cv2
import numpy as np
from collections import OrderedDict
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

import utils.logging as logging
import utils.metrics as metrics
from models.model import GLPDepth
from dataset.base_dataset import get_dataset

In [11]:
# Evaluation configuration.
max_dep = 10.0      # max depth: passed to GLPDepth and used as the upper bound when cropping for metrics
min_dep = 1e-3      # lower depth bound used when cropping for metrics
ckpt_dir = '/content/gdrive/MyDrive/' + 'best_model_nyu.ckpt'  # path to the trained checkpoint file
data_name = 'nyudepthv2'                                        # dataset name passed to get_dataset / metrics
result_path = '/content/gdrive/MyDrive/result'                  # where per-image outputs are written
heat_map = True     # save a colour-mapped visualisation of each prediction
depth_map = False   # save raw 16-bit depth PNGs (the evaluation loop only does this when heat_map is False)

# Make sure every output directory exists; result_path may be a single path or a list of paths.
paths = result_path if isinstance(result_path, list) else [result_path]
for path in paths:
    os.makedirs(path, exist_ok=True)

In [None]:
%cd /content/GLPDepth_new/
metric_name = ['d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
               'log10', 'silog']

result_metrics = {}
for metric in metric_name:
    result_metrics[metric] = 0.0

if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.benchmark = True
else:
    device = torch.device('cpu')

model = GLPDepth(max_depth=max_dep, is_train=False).to(device)
model_weight = torch.load(ckpt_dir, map_location=torch.device('cpu'))
if 'module' in next(iter(model_weight.items()))[0]:
    model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
model.load_state_dict(model_weight)
model.eval()

dataset_kwargs = {'data_path': '/content/GLPDepth_new/datasets/', 'dataset_name': data_name, 'is_train': False}

test_dataset = get_dataset(**dataset_kwargs)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=True)

In [None]:
# Evaluate on the test split: accumulate per-image depth metrics, average them, and
# optionally write predictions to result_path (16-bit depth PNGs or colour heat maps).
if depth_map and heat_map:
    # The two save modes below are mutually exclusive; say so instead of silently saving nothing.
    print("Warning: both depth_map and heat_map are True; no prediction images will be saved.")

# Reset the accumulators here so re-running this cell does not add to a previous run's sums.
result_metrics = {metric: 0.0 for metric in metric_name}
num_batches = 0

with tqdm(test_loader, unit='batch') as pbar:
    pbar.set_description('Test')
    for batch in pbar:
        input_RGB = batch['image'].to(device)
        filename = batch['filename']

        with torch.no_grad():
            pred = model(input_RGB)
        pred_d = pred['pred_d']

        depth_gt = batch['depth'].to(device)
        pred_d, depth_gt = pred_d.squeeze(), depth_gt.squeeze()
        pred_crop, gt_crop = metrics.cropping_img(pred_d, min_dep, max_dep, depth_gt, data_name)
        computed_result = metrics.eval_depth(pred_crop, gt_crop)
        for metric in metric_name:
            result_metrics[metric] += computed_result[metric]
        num_batches += 1

        # The loader uses batch_size=1, so filename[0] is this batch's only image name.
        save_path = os.path.join(result_path, filename[0])

        if depth_map and not heat_map:
            # Change only the file extension (.jpg -> .png), not "jpg" elsewhere in the path.
            root, ext = os.path.splitext(save_path)
            if ext == '.jpg':
                save_path = root + '.png'
            # NYU is scaled by 1000 (presumably metres -> millimetres); other datasets by 256.
            scale = 1000.0 if data_name == 'nyudepthv2' else 256.0
            depth_np = pred_d.squeeze().cpu().numpy() * scale
            cv2.imwrite(save_path, depth_np.astype(np.uint16),
                        [cv2.IMWRITE_PNG_COMPRESSION, 0])

        if heat_map and not depth_map:
            pred_d_numpy = pred_d.squeeze().cpu().numpy()
            max_val = pred_d_numpy.max()
            # Normalise to [0, 255]; guard against an all-zero map (division by zero) and
            # clip so any out-of-range value cannot wrap around when cast to uint8.
            if max_val > 0:
                scaled = pred_d_numpy / max_val * 255.0
            else:
                scaled = np.zeros_like(pred_d_numpy)
            pred_d_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)
            pred_d_color = cv2.applyColorMap(pred_d_uint8, cv2.COLORMAP_RAINBOW)
            cv2.imwrite(save_path, pred_d_color)

    if num_batches == 0:
        raise RuntimeError("The test loader produced no batches; check the dataset path and name.")

    for key in result_metrics.keys():
        result_metrics[key] = result_metrics[key] / num_batches
    display_result = logging.display_result(result_metrics)
    print(display_result)

    print("Done")