In [5]:
import os
import glob
import math
import torch
import numpy as np
import torchnet as tnt
from importlib import import_module
from torch.nn.functional import interpolate

def root_mean_sqrt_error(im_pred, im_true, border=6, mul_ratio=100, is_train=False):
    b, c, h, w = im_true.size()
    if not is_train:
        im_pred, im_true = im_pred.reshape(im_pred.size(0), -1), im_true.reshape(im_true.size(0), -1)
        img_min, _ = torch.min(im_true, dim=1, keepdim=True)
        img_max, _ = torch.max(im_true, dim=1, keepdim=True)
        im_pred = im_pred * (img_max - img_min) + img_min
    if border != 0:
        im_pred = im_pred.reshape(b, c, h, w)[:, :, border: -border, border: -border]
        im_true = im_true.reshape(b, c, h, w)[:, :, border: -border, border: -border]

    return round(torch.sqrt(torch.mean(((im_true * mul_ratio) - (im_pred * mul_ratio)) ** 2)).item(), 5), im_pred

def get_lowers(im_np, factor, mode='last'):
    """
    mode: 'bicubic', 'bilinear', 'nearest', 'last', 'center'
    """
    if im_np.ndim == 3:
        im_np = im_np.transpose(1, 2, 0)

    h0, w0 = im_np.shape[:2]
    h, w = int(math.ceil(h0 / float(factor))), int(math.ceil(w0 / float(factor)))

    if h0 != h * factor or w0 != w * factor:
        im_np = resize(im_np, (h * factor, w * factor), order=1, mode='reflect', clip=False, preserve_range=True,
                       anti_aliasing=True)

    if mode in ('last', 'center'):
        if mode == 'last':
            idxs = (slice(factor - 1, None, factor),) * 2
        else:
            assert mode == 'center'
            idxs = (slice(int((factor - 1) // 2), None, factor),) * 2
        lowers = im_np[idxs].copy()
    else:
        if len(im_np.shape) == 3:
            im_np = im_np[:, :, 0]
            lowers = np.expand_dims(np.array(Image.fromarray(im_np).resize((w, h), Image.BICUBIC)), 2)
        else:
            lowers = np.array(Image.fromarray(im_np).resize((w, h),Image.BICUBIC))

    if lowers.ndim == 3:
        lowers = lowers.transpose((2, 0, 1))

    return lowers

class args:
    norm = 'None'
    act = 'PReLU'
    num_pyramid = 3
    num_features = 32
    guide_channels = 3
    filter_size = 3
    kernel_norm = 'TGASS'
device = torch.device('cuda')
module = import_module('models.dagf')
model = module.make_model(args).to(device)
model = torch.nn.parallel.DataParallel(model, device_ids=list(range(1)))
checkpoint = torch.load('./net_101.pth', map_location=lambda storage, loc: storage.cuda())
model.load_state_dict(checkpoint['state'])

gt_list = sorted(glob.glob('/data/zhwzhong/Data/Depth/NYU/npy/nyu/test/gt/*.npy'))

scale = 16
test_rmse = tnt.meter.AverageValueMeter()
module = max(int(math.pow(2, 1 + args.num_pyramid)), scale)

for dep_name in gt_list:
    dep_img = np.load(dep_name)
    rgb_img = np.load(dep_name.replace('gt', 'rgb'))
    tmp_gt = (dep_img - np.min(dep_img)) / (np.max(dep_img) - np.min(dep_img))
    lr_img = get_lowers(tmp_gt, factor=scale, mode='last')  # nearest
    
    lr_img = torch.from_numpy(lr_img).float().cuda().unsqueeze(0).unsqueeze(0)
    dep_img = torch.from_numpy(dep_img).float().cuda().unsqueeze(0).unsqueeze(0)
    
    rgb_img = torch.from_numpy(np.transpose(rgb_img.astype(np.float32), (2, 0, 1)) / 255).float().cuda().unsqueeze(0)
    lr_up = interpolate(lr_img, scale_factor=scale, mode='bicubic', align_corners=False)
    with torch.no_grad():
        out_img = model(lr=lr_img, rgb=rgb_img, lr_up=lr_up)[-1]
    
    rmse = root_mean_sqrt_error(im_pred=lr_up, im_true=dep_img, border=6, mul_ratio=100)[0]
    test_rmse.add(rmse)
print(test_rmse.value()[0])


22.69759367483295
