
Evaluate on NYUV2 test set? #34


Closed
oljike opened this issue May 8, 2020 · 13 comments

@oljike

oljike commented May 8, 2020

Hi! Could you please provide the script you used to evaluate on the NYU v2 test set? I am trying to reproduce the results reported in the paper but can't.

@sniklaus
Owner

Sorry for not having uploaded it; I noticed some inconsistencies that I wanted to look into further but haven't had time to do so yet. Does the following script work for you, and does it replicate the results presented in the paper? Thanks!

#!/usr/bin/env python

import torch
import torchvision

import base64
import cupy
import cv2
import flask
import getopt
import gevent
import gevent.pywsgi
import glob
import h5py
import io
import math
import moviepy
import moviepy.editor
import numpy
import os
import random
import re
import scipy
import scipy.io
import shutil
import sys
import tempfile
import time
import urllib
import zipfile

##########################################################

assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 12) # requires at least pytorch version 1.2.0

torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance

torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance

##########################################################

objCommon = {}

exec(open('./common.py', 'r').read())

exec(open('./models/disparity-estimation.py', 'r').read())
exec(open('./models/disparity-adjustment.py', 'r').read())
exec(open('./models/disparity-refinement.py', 'r').read())
exec(open('./models/pointcloud-inpainting.py', 'r').read())

##########################################################

print('this script first downloads the labeled nyu depth data, which may take a while')
print('it performs the evaluation using the train / test split from Nathan Silberman')
print('it crops the input by 16 pixels, like https://github.com/princeton-vl/relative_depth')

##########################################################

fltAbsrel = []
fltLogten = []
fltSqrel = []
fltRmse = []
fltThr1 = []
fltThr2 = []
fltThr3 = []

##########################################################

if os.path.isfile('./benchmark-nyu-splits.mat') == False:
	urllib.request.urlretrieve('http://horatio.cs.nyu.edu/mit/silberman/indoor_seg_sup/splits.mat', './benchmark-nyu-splits.mat')
# end

intTests = [ intTest - 1 for intTest in scipy.io.loadmat('./benchmark-nyu-splits.mat')['testNdxs'].flatten().tolist() ]

assert(len(intTests) == 654)

##########################################################

if os.path.isfile('./benchmark-nyu-data.mat') == False:
	urllib.request.urlretrieve('http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat', './benchmark-nyu-data.mat')
# end

objData = h5py.File('./benchmark-nyu-data.mat', 'r')

npyImages = numpy.array(objData['images'], numpy.uint8)
npyDepths = numpy.array(objData['depths'], numpy.float32)[:, None, :, :]

objData.close()

assert(npyImages.shape[0] == 1449 and npyImages.shape[2] == 640 and npyImages.shape[3] == 480)
assert(npyDepths.shape[0] == 1449 and npyDepths.shape[2] == 640 and npyDepths.shape[3] == 480)

for intTest in intTests:
	print(intTest)

	# crop 16 pixels from each border and rearrange from channels x width x height to height x width x channels
	npyImage = npyImages[intTest, :, 16:-16, 16:-16].transpose(2, 1, 0)
	npyReference = npyDepths[intTest, :, 16:-16, 16:-16].transpose(2, 1, 0)[:, :, 0]

	tenImage = torch.FloatTensor(numpy.ascontiguousarray(npyImage.copy().transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
	tenDisparity = disparity_estimation(tenImage)
	tenDisparity = disparity_refinement(torch.nn.functional.interpolate(input=tenImage, size=(tenDisparity.shape[2] * 4, tenDisparity.shape[3] * 4), mode='bilinear', align_corners=False), tenDisparity)
	tenDisparity = torch.nn.functional.interpolate(input=tenDisparity, size=(tenImage.shape[2], tenImage.shape[3]), mode='bilinear', align_corners=False) * (max(tenImage.shape[2], tenImage.shape[3]) / 256.0)
	tenDepth = 1.0 / tenDisparity

	# align the unitless depth estimate to the metric ground truth via a least squares fit of a per-image scale and bias
	npyEstimate = tenDepth[0, 0, :, :].cpu().numpy()
	npyLstsqa = numpy.stack([npyEstimate.flatten(), numpy.full(npyEstimate.flatten().shape, 1.0, numpy.float32)], 1)
	npyLstsqb = npyReference.flatten()
	npyScalebias = numpy.linalg.lstsq(npyLstsqa, npyLstsqb, None)[0]
	npyEstimate = (npyEstimate * npyScalebias[0]) + npyScalebias[1]

	fltAbsrel.append(((npyEstimate - npyReference).__abs__() / npyReference).mean().item())
	fltLogten.append((numpy.log10(npyEstimate) - numpy.log10(npyReference)).__abs__().mean().item())
	fltSqrel.append(((npyEstimate - npyReference).__pow__(2.0) / npyReference).mean().item())
	fltRmse.append(numpy.sqrt((npyEstimate - npyReference).__pow__(2.0).mean()).item())
	fltThr1.append((numpy.maximum((npyEstimate / npyReference), (npyReference / npyEstimate)) < 1.25 ** 1).mean().item())
	fltThr2.append((numpy.maximum((npyEstimate / npyReference), (npyReference / npyEstimate)) < 1.25 ** 2).mean().item())
	fltThr3.append((numpy.maximum((npyEstimate / npyReference), (npyReference / npyEstimate)) < 1.25 ** 3).mean().item())
# end

##########################################################

print('abs_rel = ', sum(fltAbsrel) / len(fltAbsrel))
print('log10   = ', sum(fltLogten) / len(fltLogten))
print('sq_rel  = ', sum(fltSqrel) / len(fltSqrel))
print('rms     = ', sum(fltRmse) / len(fltRmse))
print('thr1    = ', sum(fltThr1) / len(fltThr1))
print('thr2    = ', sum(fltThr2) / len(fltThr2))
print('thr3    = ', sum(fltThr3) / len(fltThr3))

@oljike
Author

oljike commented May 11, 2020

Great! It works just as mentioned in the paper! However, I have a question. You are solving a least squares problem to map the prediction to the target domain. Does that mean the output of the model is not in meters?

@sniklaus
Owner

I am happy to hear that you can replicate the results stated in our paper! And you are correct, the prediction does not have a unit. By the way, I think that this line:

npyImage = npyImages[intTest, :, 16:-16, 16:-16].transpose(2, 1, 0)

should be changed to reverse the channel ordering and convert RGB to BGR:

npyImage = npyImages[intTest, :, 16:-16, 16:-16].transpose(2, 1, 0)[:, :, ::-1]

In other words, it seems that I accidentally ran the evaluation with an incorrect channel ordering. As such, our paper reports NYU error metrics that are higher than they should be.

@oljike
Author

oljike commented May 11, 2020

Yes! I made the changes and got much better results.
The old results are:
abs_rel = 0.07661436491899352
log10 = 0.0323186856466673
sq_rel = 0.041415154813063994
rms = 0.298233926677877
thr1 = 0.9360783415947422
thr2 = 0.9871355444319663
thr3 = 0.9969128232439605

Results after changes:
abs_rel = 0.054298472797230386
log10 = 0.023021275363435303
sq_rel = 0.0257079871146212
rms = 0.23403965844131938
thr1 = 0.9644042508435656
thr2 = 0.9918034294272099
thr3 = 0.9978733982407989

@sniklaus
Owner

I get pretty much the same results on my end; it just seems a little too good to be true when comparing to the other methods.

@oljike
Author

oljike commented May 11, 2020

Do you mean a visual comparison? I have not seen better single-image results yet... Also, you mentioned that you are planning to release the dataset. Do you have any updates?

@oljike
Author

oljike commented May 11, 2020

Another question is about the prediction transformation. Is it fair to use the true targets during evaluation? You solve the least squares problem X = (A.T A)^-1 (A.T B), which already has B (the true targets) in it, to transform your outputs.
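
For reference, a minimal sketch of the alignment being discussed; the helper name align_scale_bias is made up for illustration and simply restates what the posted script does with numpy.linalg.lstsq, written out here via the normal equations:

import numpy

def align_scale_bias(npyEstimate, npyReference):
	# design matrix A = [ estimate, 1 ] and target b = reference, both flattened
	npyA = numpy.stack([npyEstimate.flatten(), numpy.ones_like(npyEstimate.flatten())], 1)
	npyB = npyReference.flatten()

	# normal equations X = (A.T A)^-1 (A.T B), equivalent to numpy.linalg.lstsq(npyA, npyB, None)[0]
	fltScale, fltBias = numpy.linalg.solve(npyA.T @ npyA, npyA.T @ npyB)

	return (npyEstimate * fltScale) + fltBias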

@sniklaus
Owner

sniklaus commented May 12, 2020

I just re-ran the NYU evaluation with the corrected color channel order. As such, the part that corresponds to NYU in Table 1 of our paper should be corrected as follows.

method             data                rel    log    rms    sig1   sig2   sig3
DIW                DIW                 0.25   0.10   0.75   0.62   0.88   0.96
DIW                DIW + NYU           0.18   0.08   0.57   0.75   0.94   0.98
DeepLens           iPhone              0.27   0.10   0.81   0.59   0.86   0.95
MegaDepth          Mega                0.23   0.09   0.69   0.65   0.89   0.97
MegaDepth          Mega + DIW          0.19   0.08   0.60   0.71   0.93   0.98
Ours               Mega + NYU + Ours   0.05   0.02   0.23   0.96   0.99   1.00
Ours + Refinement  Mega + NYU + Ours   0.05   0.02   0.23   0.96   0.99   1.00

I am just always a little skeptical when results are much better than competing methods (on a side note, MiDaS is a great single-image depth estimation method as well, but we have not compared to it in our paper). We have a great dataset and spent a lot of time optimizing our network, but there is always the worry about having made a mistake in the evaluation. If anyone spots an issue with the evaluation code I posted above then please do not hesitate to let me know.

As for the dataset, please see #25 for any updates. We are currently still discussing with our legal team but hope to have an update soon.

Regarding your question about the transform: yes, it is fair since we (1) use the same transform optimization for all methods we compare to, and (2) it only solves for a scale (and bias) that is impossible to predict otherwise.

@oljike
Author

oljike commented May 14, 2020

Great! Thanks for reply!

@dfrumkin

dfrumkin commented May 29, 2020

Hello Simon!

Thank you very much for the script! Actually, I would like to echo @oljike regarding the use of least squares. The question is what the fair and correct way is to evaluate a model that returns scale- and shift-invariant disparity against ground truth that is absolute depth.

It seems that exploiting the fact that the ground truth lies within a certain range (0 to 10) would not be fair, and you do not do so. At the same time, however, this could implicitly be part of the model (e.g. you actually trained on NYU, whereas other models, such as MegaDepth, did not), or even explicitly, if somebody forces their model to return values in the target range (e.g. by putting sigmoid * 10 at the end).

Another important issue is outliers. A minimal amount of clamping or a more robust mapping (median, MAD, etc.) in addition to or instead of least squares could have a huge impact on the result (speaking in general, not specifically about your model); see the sketch below. Conversely, a typical disparity visualization that just affinely maps the values to 0..1 could look much worse, so a manual inspection of the output could seem pretty bad relative to the benchmark results.
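
As an illustration of the robust alternatives mentioned above, a minimal sketch of a median/MAD-style alignment; this is not what the posted evaluation uses, and the helper name is made up, but it shows the kind of scale-and-shift fit that is less sensitive to outliers than least squares:

import numpy

def align_median_mad(npyEstimate, npyReference):
	# shift by the median and scale by the mean absolute deviation of each map
	fltShiftEst = numpy.median(npyEstimate)
	fltShiftRef = numpy.median(npyReference)
	fltScaleEst = numpy.mean(numpy.abs(npyEstimate - fltShiftEst))
	fltScaleRef = numpy.mean(numpy.abs(npyReference - fltShiftRef))

	# map the estimate into the range of the reference using the robust statistics
	return ((npyEstimate - fltShiftEst) / fltScaleEst) * fltScaleRef + fltShiftRef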

Also, what is the reason for removing 16 pixels from each edge? For example, you could compute depth on the whole image but then skip the border when comparing, because you are not sure of the final extrapolation/interpolation in that area. Instead, you start by removing the border from the image and the ground truth depth, and then compute and compare on the whole resulting image and depth map.

Would be happy to hear your considerations. And, once again, thank you very much for sharing your code!

@sniklaus
Owner

sniklaus commented May 29, 2020

You are very welcome to change the error minimization metric and redo the evaluation. 🙂

I maintain that our evaluation is fair; we used the same error minimization for all baselines. If you disagree with evaluating MegaDepth in this way because it has not been trained on NYU, then ignore those results and focus on the ones from DIW; they use the same network architecture after all. Also, DIW themselves use a similar error minimization metric in their paper, but instead of doing it on a per-sample basis they do it on the entire set (see the sketch below). If you do some literature review, you will find that there are a few different approaches for dealing with the scale ambiguity in the evaluation, none of them fully satisfying. The important thing is to be consistent and to treat all methods in the same way, which we do.
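
A minimal sketch of the distinction being described, assuming Python lists of per-image estimates and references; this is only one way to implement a set-level fit and is not necessarily the exact procedure used by DIW:

import numpy

def align_per_set(npyEstimates, npyReferences):
	# concatenate all images and fit a single scale and bias for the entire set instead of per sample
	npyEst = numpy.concatenate([ npyEstimate.flatten() for npyEstimate in npyEstimates ])
	npyRef = numpy.concatenate([ npyReference.flatten() for npyReference in npyReferences ])

	npyA = numpy.stack([npyEst, numpy.ones_like(npyEst)], 1)
	fltScale, fltBias = numpy.linalg.lstsq(npyA, npyRef, None)[0]

	return [ (npyEstimate * fltScale) + fltBias for npyEstimate in npyEstimates ]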

I am afraid that I do not understand your concerns about the disparity visualization. We do not compare visualizations of depth or disparity maps in our paper. We only do so in the supplementary material since one reviewer asked us for such a comparison. However, we believe that visualizations of depth maps have little meaning. Side-by-side comparisons, for example, make it impossible to judge how well the depth edges are aligned with image edges at depth discontinuities.

I strongly believe that cropping the input image by 16 pixels is more appropriate than cropping the depth estimate, since the input images are subject to a white boundary, as shown below. Besides, evaluating the cropped depth estimate when using the unaltered input yields insignificant changes in the error measurement, as also shown below. And again, the most important thing is to be consistent and to treat all methods in the same way, which we do.

[Screenshot: an NYU input image illustrating the white boundary along the image borders]

abs_rel =  0.060580624187712644
log10   =  0.025731676240346255
sq_rel  =  0.03038643302785059
rms     =  0.25608104297361517
thr1    =  0.958100431114699
thr2    =  0.9910660395254196
thr3    =  0.9975169521423967
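
For clarity, a minimal sketch of this variant based on the posted script (not necessarily the exact code used for the numbers above): the network runs on the uncropped 640x480 frame and only the comparison region is cropped afterwards, with the corrected BGR channel ordering included.

	# run the network on the uncropped frame instead of cropping the input
	npyImage = npyImages[intTest, :, :, :].transpose(2, 1, 0)[:, :, ::-1]
	npyReference = npyDepths[intTest, :, :, :].transpose(2, 1, 0)[:, :, 0]

	tenImage = torch.FloatTensor(numpy.ascontiguousarray(npyImage.copy().transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
	tenDisparity = disparity_estimation(tenImage)
	tenDisparity = disparity_refinement(torch.nn.functional.interpolate(input=tenImage, size=(tenDisparity.shape[2] * 4, tenDisparity.shape[3] * 4), mode='bilinear', align_corners=False), tenDisparity)
	tenDisparity = torch.nn.functional.interpolate(input=tenDisparity, size=(tenImage.shape[2], tenImage.shape[3]), mode='bilinear', align_corners=False) * (max(tenImage.shape[2], tenImage.shape[3]) / 256.0)
	tenDepth = 1.0 / tenDisparity

	# crop the estimate and the reference by 16 pixels before the scale-bias fit and the metrics, which proceed as in the posted loop
	npyEstimate = tenDepth[0, 0, 16:-16, 16:-16].cpu().numpy()
	npyReference = npyReference[16:-16, 16:-16]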

@sniklaus
Owner

I just added benchmark-nyu.py and am hence closing this issue. Thanks again to everyone involved in this discussion!

@YvanYin

YvanYin commented Apr 28, 2023

Hi Sniklaus,
I have a question about the unit of your depth maps. When I load the depth, the value range is around [100, 65000]. Are the depths saved in millimeters? Do depths over 50000 correspond to sky regions?
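
A minimal, heavily hedged sketch for inspecting the raw values, assuming the depth maps are stored as 16-bit images that OpenCV can read unmodified; the file name is hypothetical, and whether the values are millimeters is exactly the open question here, so the conversion at the end is only an assumption:

import cv2
import numpy

# hypothetical file name; read the depth map unmodified and inspect its raw value range
npyDepth = cv2.imread('./depth/00001.png', cv2.IMREAD_UNCHANGED)

print('dtype:', npyDepth.dtype, 'min:', npyDepth.min(), 'max:', npyDepth.max())

# if the values were millimeters, this would convert them to meters - an assumption, not an answer from this thread
npyMeters = npyDepth.astype(numpy.float32) * 0.001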
