
tf.image.resize_images() - weird padding behaviour? #6720

Closed
JoelKronander opened this issue Jan 8, 2017 · 45 comments


@JoelKronander commented Jan 8, 2017

tf.image.resize_images() seems to use a strange padding option, and which one is not clear to me at the moment. I tried to replicate its bilinear interpolation with various padding options in, for example, skimage, but I can't replicate the behaviour.

It would be nice to be able to set the padding option used in tf.image.resize_images(), or at least to document which one is used.

Example code comparing the results of tf.image.resize_images() and skimage.transform.rescale():
It looks like tf.image.resize_images() does some weird asymmetric padding!?
Using tensorflow 0.12.1:

import tensorflow as tf
import numpy as np
import skimage.transform
from scipy.misc import imsave

sess = tf.InteractiveSession()

#create simple test image
imsize = 3
xa, ya = np.ogrid[:imsize, :imsize]
img = np.repeat((xa + ya)[..., np.newaxis], 3, 2) / float(imsize + imsize)

x = tf.placeholder(tf.float32, [1, imsize, imsize, 3])
y = tf.image.resize_images(x,(imsize*3, imsize*3))

sess.run(tf.global_variables_initializer())

upsampled_tf_result = sess.run(y, feed_dict={x: [img]})
upsampled_skimage_result = skimage.transform.rescale(img,
                                     3,
                                     mode='symmetric',
                                     cval=0,
                                     order=1,
                                     preserve_range=False)

print(np.allclose(upsampled_tf_result, upsampled_skimage_result))

imsave('upsampled_tf_result.png', np.squeeze(upsampled_tf_result))
imsave('upsampled_skimage_result.png', upsampled_skimage_result)
@michaelisard (Member) commented Jan 9, 2017

@xmbrst could this be documented better?

@xmbrst xmbrst assigned dr4b and unassigned xmbrst Jan 9, 2017

@girving (Contributor) commented Apr 28, 2017

I think we'd need to understand better what's going wrong before knowing how to fix this. In what way is it weird?

@ppwwyyxx (Contributor) commented Apr 28, 2017

The corner alignment mechanism differs between tf.image.resize* and skimage.transform.rescale, regardless of whether align_corners is True or False.

@girving (Contributor) commented Apr 28, 2017

Yes, but how are they different?

@ppwwyyxx (Contributor) commented Apr 28, 2017

In skimage, the "area" of a pixel is taken into account. In tf.image, it feels like a pixel is considered a "point" without area. This leads to a difference in alignment.

E.g. when up-scaling a 2x2 image to a 4x4 image, the alignment is:
skimage: (0, 0) -> (0.5, 0.5), (1, 1) -> (2.5, 2.5)
tf.image.resize_*(align_corners=True): (0, 0) -> (0, 0), (1, 1) -> (3, 3)
tf.image.resize_*(align_corners=False): (0, 0) -> (0, 0), (1, 1) -> (2, 2)

I'd love to hear comments about which alignment is better for NN training.
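To make the three conventions concrete, here is a small sketch (my own illustrative helper, not code from TensorFlow or skimage; `output_coord` is a hypothetical name) that reproduces the 2x2 -> 4x4 mappings listed above:

```python
def output_coord(i, n, m, convention):
    """Where input pixel index i lands in output coordinates when
    resizing a length-n axis to length m, per convention."""
    if convention == "half_pixel":        # skimage/OpenCV: pixels have area
        return (i + 0.5) * m / n - 0.5
    elif convention == "align_corners":   # tf align_corners=True: corner samples coincide
        return i * (m - 1) / (n - 1)
    elif convention == "legacy_tf":       # tf align_corners=False: pure index scaling
        return i * m / n
    raise ValueError(convention)

for conv in ("half_pixel", "align_corners", "legacy_tf"):
    print(conv, output_coord(0, 2, 4, conv), output_coord(1, 2, 4, conv))
# half_pixel 0.5 2.5
# align_corners 0.0 3.0
# legacy_tf 0.0 2.0
```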

@girving (Contributor) commented Apr 28, 2017

Have you considered using method=ResizeMethod.AREA if you want area based resizing?

@ppwwyyxx (Contributor) commented Apr 28, 2017

The AREA method seems to affect only the interpolation, not the alignment. At least I get strange results with it:

import tensorflow as tf
import numpy as np
from skimage.transform import rescale

arr = np.array(
    [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12]], dtype='float32')

input = tf.constant(arr)
input4D = tf.reshape(input, [1, 3, 4, 1])
resize = tf.image.resize_area(input4D, [6, 8], align_corners=True)[0,:,:,0]
sess = tf.Session()
r1 = sess.run(resize)
r2 = rescale(arr/100.0, 2, mode='edge') * 100

gives

r1=array([[  1.        ,   1.        ,   1.66666651,   2.        ,   2.33333278,   3.        ,   3.        ,   4.00000048],
       [  1.        ,   1.        ,   1.66666651,   2.        ,   2.33333278,   3.        ,   3.        ,   4.00000048],
       [  3.00000024,   3.00000024,   3.66666651,   4.00000048,   4.33333302,   5.00000048,   5.00000095,   6.00000048],
       [  5.        ,   5.        ,   5.66666603,   6.        ,   6.33333206,   7.        ,   7.00000048,   8.00000095],
       [  4.99999952,   4.99999952,   5.66666555,   5.99999952,   6.33333111,   6.99999905,   7.        ,   8.        ],
       [  9.00000191,   9.00000191,   9.66666698,  10.00000191,  10.33333397,  11.00000191,  11.00000191,  12.00000286]], dtype=float32)
r2=array([[  0.99999998,   1.24999997,   1.74999996,   2.24999995,   2.74999994,   3.24999993,   3.74999992,   3.99999991],
       [  2.        ,   2.24999998,   2.74999995,   3.24999994,   3.74999995,   4.24999994,   4.74999991,   4.99999989],
       [  4.00000005,   4.25000001,   4.74999993,   5.24999992,   5.74999998,   6.24999997,   6.74999988,   6.99999984],
       [  6.00000015,   6.25000009,   6.74999999,   7.24999995,   7.74999999,   8.24999996,   8.74999985,   8.9999998 ],
       [  8.00000029,   8.25000023,   8.75000013,   9.25000005,   9.74999999,  10.24999991,  10.74999981,  10.99999975],
       [  9.00000036,   9.25000031,   9.7500002 ,  10.2500001 ,  10.74999999,  11.24999989,  11.74999978,  11.99999973]])
@girving (Contributor) commented Apr 28, 2017

What does that produce? We have a lot of bugs to triage, so it's helpful if people include output along with code.

@girving (Contributor) commented Apr 28, 2017

Ugh, you're right, that's pretty weird. @martinwicke Our tf.image.resize_area function isn't even reflection-equivariant. It would be lovely to fix this, but I'd be worried about breaking old models.

@gpapan commented Apr 29, 2017

There are two separate issues here:
(1) Alignment of the tensor values at the input and output of the resize function.
(2) Interpolation method.

For (1), @ppwwyyxx's comment is exactly right:
When using align_corners=True, we consider the image value as a point sample of a continuous function at the pixel center. When using align_corners=False, we consider the image value as the average of a continuous function over a 1x1 pixel square centered at the pixel center.

Unfortunately, there is a bug in the implementation of nearest neighbor and area interpolation methods when align_corners=True. For nearest neighbor interpolation this has already been fixed internally and will be pushed to github in the next couple of days. We will fix a similar bug for area interpolation very soon.
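Read as destination-to-source coordinate mappings, the two interpretations look roughly like this (my own sketch with a hypothetical helper name, not the actual kernel code):

```python
def source_coord(j, in_size, out_size, align_corners):
    """Source coordinate sampled for output pixel j along one axis."""
    if align_corners:
        # Point samples at pixel centers; the first and last samples
        # of input and output coincide.
        return j * (in_size - 1) / (out_size - 1)
    # Averages over 1x1 pixel squares; input and output grids cover the
    # same extent, so centers are related by a half-pixel shift.
    return (j + 0.5) * in_size / out_size - 0.5

print(source_coord(3, 2, 4, True))    # 1.0
print(source_coord(3, 2, 4, False))   # 1.25
```

Note that, per the input-to-output mapping quoted above for align_corners=False, the legacy kernels effectively used plain `j * in_size / out_size` rather than the half-pixel form, which is part of what this issue is about.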

@gpapan commented Apr 29, 2017

@ppwwyyxx, Regarding your question on "which alignment is better for NN training", multiple approaches are possible as long as you are consistent. Here is my own favorite set of rules that we have followed in our DeepLab semantic image segmentation system:

"DeepLab's Four Alignment Rules":
(1) Use of odd-sized kernels in all convolution and pooling ops.
(2) Use of SAME boundary conditions in all convolution and pooling ops.
(3) Use align_corners=True when upsampling feature maps with bilinear interpolation.
(4) Use of inputs with height/width equal to a multiple of the output_stride, plus one (for example, when the CNN output stride is 8, use height or width equal to 8 * n + 1, for some n, e.g., image HxW set to 321x513).
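A quick sketch of the arithmetic behind rule (4) (my own illustration; `feature_size` is a hypothetical helper): with SAME padding and stride equal to the output_stride, an input of size 8n + 1 yields a feature map of size n + 1, so bilinear upsampling back with align_corners=True uses an exact integer scale and every feature-map sample lands on an input pixel center.

```python
import math

def feature_size(input_size, stride):
    # Spatial size after a stride-`stride` op with SAME padding.
    return math.ceil(input_size / stride)

s, n = 8, 40
in_size = s * n + 1                  # 321, as in the 321x513 example
feat = feature_size(in_size, s)      # 41
scale = (in_size - 1) / (feat - 1)   # align_corners=True scale
print(in_size, feat, scale)          # 321 41 8.0

# With a plain multiple of the stride, the scale is no longer integral:
print(feature_size(320, 8))              # 40
print((320 - 1) / (40 - 1) == 8.0)       # False
```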

@ppwwyyxx (Contributor) commented Apr 30, 2017

Thanks @gpapan! For (4), is it to ensure equal padding on both sides of the image?

@martinwicke (Member) commented Jan 22, 2018

There are several issues with resize_images. It would be good to have a known-good implementation of this, even if we have to hide it behind a flag (correct=False).

@tpet commented Apr 13, 2018

Preferably, the new implementation should follow what is used elsewhere (OpenCV, SciPy, Matlab, ...), which is to align the outer corners of the top-left pixel (-0.5, -0.5) and bottom-right pixel (height - 0.5, width - 0.5) and resample using the corresponding pixel centers.
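For reference, a minimal 1-D bilinear resize under this pixel-center convention might look like the following (my own sketch with a hypothetical function name, ignoring antialiasing for downsampling):

```python
import numpy as np

def resize_bilinear_1d(x, out_size):
    in_size = len(x)
    out = np.empty(out_size)
    for j in range(out_size):
        # Half-pixel-center mapping: output pixel j samples input
        # coordinate (j + 0.5) * in/out - 0.5 in pixel-index space.
        src = (j + 0.5) * in_size / out_size - 0.5
        src = min(max(src, 0.0), in_size - 1)  # clamp at the borders
        i0 = int(np.floor(src))
        i1 = min(i0 + 1, in_size - 1)
        t = src - i0
        out[j] = (1 - t) * x[i0] + t * x[i1]
    return out

print(resize_bilinear_1d(np.array([0.0, 1.0]), 4))  # [0.   0.25 0.75 1.  ]
```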

@rsethur commented May 18, 2018

Many thanks to the TF developers and contributors for their support. I believe this is a critical part of image processing pipelines, so it would be great to have this fixed soon.
Detailed writeup by Oleksandr Savsunenko:
https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35

@rmmal commented Dec 27, 2018

Any news about this? The results from tensorflow resize_bilinear differ from cv2.resize with bilinear interpolation. What should I do to make them match, since the network outputs different results in each case?

@jkyl (Contributor) commented Dec 28, 2018

Any news about this? The results from tensorflow resize_bilinear differ from cv2.resize with bilinear interpolation. What should I do to make them match, since the network outputs different results in each case?

If you don't need to compute its gradient or compute on GPU: use cv2.resize via a tf.py_func wrapper. Otherwise, unless you want to implement a C++ op along with its gradient and submit a PR, I'm afraid we're SOL.

@johnpjf commented Dec 28, 2018

Working on it, we'll have something soon!

@maxfiedler commented Feb 7, 2019

@johnpjf Sorry to bother you as well, but did you already merge something here?
I just posted another issue, and I wonder if it is related to changes that were made w.r.t. this issue.
See #25591

@martinwicke (Member) commented Feb 22, 2019

We're almost there.

@Overdrivr commented Feb 28, 2019

Is it possible that this bug impacts UpSampling2D layers in some way? I'm trying to replace transposed 2D convolution with UpSampling2D + Conv, but suddenly it's completely impossible to get the training to converge appropriately.

@johnpjf commented Mar 1, 2019

@Overdrivr I'm not familiar with the UpSampling layer, but it's unlikely that this bug would cause something as extreme as failing to converge.

@ppwwyyxx (Contributor) commented Mar 1, 2019

Thanks for this fix! I tried some of my cases and the new version seems to work as expected now.

There are in fact some similar issues with tf.image.crop_and_resize. I have been using my own workaround for it, but if the team is interested I can open an issue describing it in more detail.

@martinwicke (Member) commented Mar 1, 2019

@ppwwyyxx I'd be interested in a detailed report.

@Akshunn commented Mar 14, 2019

@ppwwyyxx, Regarding your question on "which alignment is better for NN training", multiple approaches are possible as long as you are consistent. Here is my own favorite set of rules that we have followed in our DeepLab semantic image segmentation system:

"DeepLab's Four Alignment Rules":
(1) Use of odd-sized kernels in all convolution and pooling ops.
(2) Use of SAME boundary conditions in all convolution and pooling ops.
(3) Use align_corners=True when upsampling feature maps with bilinear interpolation.
(4) Use of inputs with height/width equal to a multiple of the output_stride, plus one (for example, when the CNN output stride is 8, use height or width equal to 8 * n + 1, for some n, e.g., image HxW set to 321x513).

@gpapan Why 8n+1 and not 8n?

@DeepBlender commented Mar 14, 2019

I wanted to experiment with the new resizing and discovered, to my surprise, that it is not (anymore?) differentiable. Is this a bug I should report, or is it by design?

@martinwicke (Member) commented Mar 14, 2019

@johnpjf commented Mar 14, 2019

@DeepBlender There certainly should be gradients, thanks for the catch! Looks like I had forgotten to expose gradients for the new kernels (Lanczos etc.). Fix is in review.
Note that, as in TF 1.0, there is no gradient for the area kernel.

@DeepBlender commented Mar 14, 2019

@johnpjf that sounds great, thanks a lot for your work!
I just started to prepare the bug report, but you were obviously faster.

@QuantumInformation commented Jun 18, 2019

Any ideas why we ported align_corners over to TFJS? It doesn't have the legacy issues.

@mohapatras commented Jun 27, 2019

Thanks for this fix! I tried some of my cases and the new version seems to work as expected now.

There are in fact some similar issues with tf.image.crop_and_resize. I have been using my own workaround for it, but if the team is interested I can open an issue describing it in more details.

@ppwwyyxx I can't find the antialias option as implemented in the fix above. Can you please explain how to use it? I can't find the new option in the documentation: https://www.tensorflow.org/api_docs/python/tf/image/resize_images

@johnpjf commented Jun 27, 2019

@mohapatras antialias is only available in the 2.0 versions, which you can use by

import tensorflow.compat.v2 as tf_v2
...
tf_v2.image.resize(..., antialias=True, ...)

@johnpjf commented Jul 12, 2019

Can you provide examples of inputs and outputs?

@biendltb commented Jul 12, 2019

Can you provide examples of inputs and outputs?

Hi John, sorry, I didn't check my code carefully. I've deleted the comment since there was a bug in my code. My results show that tf.image.resize and cv2.resize produce the same results when using bilinear interpolation. However, the speed difference is obvious; this might come from TensorFlow's GPU optimizations. Thank you for your time.

@protossw512 commented Aug 13, 2019

@johnpjf @martinwicke
Thank you for the fix. However, I found that the new implementation still does not match OpenCV under bilinear interpolation. Is that the desired behavior?

Here is my test code:

import tensorflow as tf
import tensorflow.compat.v2 as tf_v2
import numpy as np
import cv2
np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)
resize_shape = (10, 10)

a = np.ones((1, 2, 2, 1), dtype=np.float32)
a[0, 0, 0, 0] = 5.0
a[0, 1, 1, 0] = 5.0

b = tf.constant(a, dtype=tf.float32)
# c = tf.image.resize_bilinear(b, resize_shape)
c = tf_v2.image.resize(b, resize_shape,
                       method='bilinear',
                       antialias=True)
d = tf_v2.image.resize(c, (5, 5),
                       method='bilinear',
                       antialias=True)

with tf.Session() as sess:
    np_c = sess.run(c)
    np_d = sess.run(d)

temp = cv2.resize(a[0], resize_shape, interpolation=cv2.INTER_LINEAR)
temp2 = cv2.resize(np_c[0, :, :, 0], (5,5), interpolation=cv2.INTER_LINEAR)

print("Tensorflow:")
print(np_c[0, :, :, 0])
print("OpenCV:")
print(temp)
print("Tensorflow:")
print(np_d[0, :, :, 0])
print("OpenCV:")
print(temp2)
print("Upsample difference:")
print(np_c[0, :, :, 0] - temp)
print("Downsample difference:")
print(np_d[0, :, :, 0] - temp2)

Here is my output with tensorflow 1.14.0. The upsampling looks correct, but there seems to be some issue with the downsampling.

Tensorflow:
[[5.   5.   5.   4.2  3.4  2.6  1.8  1.   1.   1.  ]
 [5.   5.   5.   4.2  3.4  2.6  1.8  1.   1.   1.  ]
 [5.   5.   5.   4.2  3.4  2.6  1.8  1.   1.   1.  ]
 [4.2  4.2  4.2  3.72 3.24 2.76 2.28 1.8  1.8  1.8 ]
 [3.4  3.4  3.4  3.24 3.08 2.92 2.76 2.6  2.6  2.6 ]
 [2.6  2.6  2.6  2.76 2.92 3.08 3.24 3.4  3.4  3.4 ]
 [1.8  1.8  1.8  2.28 2.76 3.24 3.72 4.2  4.2  4.2 ]
 [1.   1.   1.   1.8  2.6  3.4  4.2  5.   5.   5.  ]
 [1.   1.   1.   1.8  2.6  3.4  4.2  5.   5.   5.  ]
 [1.   1.   1.   1.8  2.6  3.4  4.2  5.   5.   5.  ]]
OpenCV:
[[5.   5.   5.   4.2  3.4  2.6  1.8  1.   1.   1.  ]
 [5.   5.   5.   4.2  3.4  2.6  1.8  1.   1.   1.  ]
 [5.   5.   5.   4.2  3.4  2.6  1.8  1.   1.   1.  ]
 [4.2  4.2  4.2  3.72 3.24 2.76 2.28 1.8  1.8  1.8 ]
 [3.4  3.4  3.4  3.24 3.08 2.92 2.76 2.6  2.6  2.6 ]
 [2.6  2.6  2.6  2.76 2.92 3.08 3.24 3.4  3.4  3.4 ]
 [1.8  1.8  1.8  2.28 2.76 3.24 3.72 4.2  4.2  4.2 ]
 [1.   1.   1.   1.8  2.6  3.4  4.2  5.   5.   5.  ]
 [1.   1.   1.   1.8  2.6  3.4  4.2  5.   5.   5.  ]
 [1.   1.   1.   1.8  2.6  3.4  4.2  5.   5.   5.  ]]
Tensorflow:
[[5.    4.5   3.    1.5   1.   ]
 [4.5   4.125 3.    1.875 1.5  ]
 [3.    3.    3.    3.    3.   ]
 [1.5   1.875 3.    4.125 4.5  ]
 [1.    1.5   3.    4.5   5.   ]]
OpenCV:
[[5.   4.6  3.   1.4  1.  ]
 [4.6  4.28 3.   1.72 1.4 ]
 [3.   3.   3.   3.   3.  ]
 [1.4  1.72 3.   4.28 4.6 ]
 [1.   1.4  3.   4.6  5.  ]]
Upsample difference:
[[ 0.  0.  0.  0. -0.  0. -0.  0.  0.  0.]
 [ 0.  0.  0.  0. -0.  0. -0.  0.  0.  0.]
 [ 0.  0.  0.  0. -0.  0. -0.  0.  0.  0.]
 [ 0.  0.  0.  0. -0. -0. -0.  0.  0.  0.]
 [-0. -0. -0.  0. -0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [-0. -0. -0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]]
Downsample difference:
[[ 0.    -0.1    0.     0.1    0.   ]
 [-0.1   -0.155 -0.     0.155  0.1  ]
 [ 0.     0.     0.     0.     0.   ]
 [ 0.1    0.155  0.    -0.155 -0.1  ]
 [ 0.     0.1    0.    -0.1    0.   ]]

If I change bilinear to bicubic, there are even more inconsistent values in the outputs:

Tensorflow:
[[ 6.03   5.833  5.462  4.696  3.598  2.402  1.304  0.538  0.167 -0.03 ]
 [ 5.833  5.649  5.302  4.586  3.559  2.441  1.414  0.698  0.351  0.167]
 [ 5.462  5.302  5.     4.378  3.486  2.514  1.622  1.     0.698  0.538]
 [ 4.696  4.586  4.378  3.949  3.335  2.665  2.051  1.622  1.414  1.304]
 [ 3.598  3.559  3.486  3.335  3.118  2.882  2.665  2.514  2.441  2.402]
 [ 2.402  2.441  2.514  2.665  2.882  3.118  3.335  3.486  3.559  3.598]
 [ 1.304  1.414  1.622  2.051  2.665  3.335  3.949  4.378  4.586  4.696]
 [ 0.538  0.698  1.     1.622  2.514  3.486  4.378  5.     5.302  5.462]
 [ 0.167  0.351  0.698  1.414  2.441  3.559  4.586  5.302  5.649  5.833]
 [-0.03   0.167  0.538  1.304  2.402  3.598  4.696  5.462  5.833  6.03 ]]
OpenCV:
[[5.957 5.899 5.432 4.576 3.545 2.455 1.424 0.568 0.101 0.043]
 [5.899 5.842 5.384 4.545 3.534 2.466 1.455 0.616 0.158 0.101]
 [5.432 5.384 5.    4.296 3.448 2.552 1.704 1.    0.616 0.568]
 [4.576 4.545 4.296 3.84  3.29  2.71  2.16  1.704 1.455 1.424]
 [3.545 3.534 3.448 3.29  3.1   2.9   2.71  2.552 2.466 2.455]
 [2.455 2.466 2.552 2.71  2.9   3.1   3.29  3.448 3.534 3.545]
 [1.424 1.455 1.704 2.16  2.71  3.29  3.84  4.296 4.545 4.576]
 [0.568 0.616 1.    1.704 2.552 3.448 4.296 5.    5.384 5.432]
 [0.101 0.158 0.616 1.455 2.466 3.534 4.545 5.384 5.842 5.899]
 [0.043 0.101 0.568 1.424 2.455 3.545 4.576 5.432 5.899 5.957]]
Tensorflow:
[[5.873 5.046 3.    0.954 0.127]
 [5.046 4.457 3.    1.543 0.954]
 [3.    3.    3.    3.    3.   ]
 [0.954 1.543 3.    4.457 5.046]
 [0.127 0.954 3.    5.046 5.873]]
OpenCV:
[[5.904 5.102 3.    0.898 0.096]
 [5.102 4.521 3.    1.479 0.898]
 [3.    3.    3.    3.    3.   ]
 [0.898 1.479 3.    4.521 5.102]
 [0.096 0.898 3.    5.102 5.904]]
Upsample difference:
[[ 0.072 -0.066  0.03   0.12   0.053 -0.053 -0.12  -0.03   0.066 -0.072]
 [-0.066 -0.192 -0.082  0.041  0.025 -0.025 -0.041  0.082  0.192  0.066]
 [ 0.03  -0.082  0.     0.082  0.038 -0.038 -0.082  0.     0.082 -0.03 ]
 [ 0.12   0.041  0.082  0.109  0.044 -0.044 -0.109 -0.082 -0.041 -0.12 ]
 [ 0.053  0.025  0.038  0.044  0.018 -0.018 -0.044 -0.038 -0.025 -0.053]
 [-0.053 -0.025 -0.038 -0.044 -0.018  0.018  0.044  0.038  0.025  0.053]
 [-0.12  -0.041 -0.082 -0.109 -0.044  0.044  0.109  0.082  0.041  0.12 ]
 [-0.03   0.082  0.    -0.082 -0.038  0.038  0.082  0.    -0.082  0.03 ]
 [ 0.066  0.192  0.082 -0.041 -0.025  0.025  0.041 -0.082 -0.192 -0.066]
 [-0.072  0.066 -0.03  -0.12  -0.053  0.053  0.12   0.03  -0.066  0.072]]
Downsample difference:
[[-0.031 -0.056  0.     0.056  0.031]
 [-0.056 -0.064  0.     0.064  0.056]
 [ 0.     0.     0.     0.     0.   ]
 [ 0.056  0.064 -0.    -0.064 -0.056]
 [ 0.031  0.056  0.    -0.056 -0.031]]

Am I missing something?

@johnpjf commented Aug 13, 2019

For bilinear, you are using antialias=True for TensorFlow, which enlarges the kernel when downsampling in order to antialias; this is why your downsampled version is different in TF.
For bicubic you also have antialias=True, but the upsampled version is different as well. This looks like it comes from a different choice of the parameter in the bicubic kernel: https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
TF uses a=0.5, while opencv uses a=0.75.
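For reference, here is a sketch (my own code, not from either library) of the cubic convolution kernel from that article, parameterized by a; in the usual sign convention the two choices are written a = -0.5 and a = -0.75:

```python
def cubic_kernel(x, a):
    """Keys cubic convolution kernel with free parameter a."""
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

# Sampling weights for a source position midway between two pixels:
for a in (-0.5, -0.75):
    print(a, [cubic_kernel(d, a) for d in (-1.5, -0.5, 0.5, 1.5)])
# -0.5 [-0.0625, 0.5625, 0.5625, -0.0625]
# -0.75 [-0.09375, 0.59375, 0.59375, -0.09375]
```

The parameter changes every weight (while the weights still sum to 1), which is consistent with the small everywhere-nonzero differences in the bicubic comparison above.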
