
Atrous convolution does not preserve tensor shape #4742

Closed
ethereon opened this issue Oct 4, 2016 · 24 comments

@ethereon

ethereon commented Oct 4, 2016

For an input with an undefined batch size, atrous_conv2d emits tensors where all except the final dimension are undefined:

input = tf.placeholder(tf.float32, (None, 256, 256, 3))

conv = tf.nn.conv2d(input, tf.zeros((3, 3, 3, 16)), strides=[1, 1, 1, 1], padding='SAME')
print(conv.get_shape()) # Correctly displays (?, 256, 256, 16)

dilated = tf.nn.atrous_conv2d(input, tf.zeros((3, 3, 3, 16)), rate=2, padding='SAME')
print(dilated.get_shape()) # Displays (?, ?, ?, 16)

(For concrete batch sizes, everything works as expected.)

Tested on 0.10.0rc0

@tatatodd
Contributor

tatatodd commented Oct 4, 2016

Indeed I have reproduced the problem. I believe @gpapan implemented atrous_conv2d, and might have thoughts on how easy this would be to fix.

@tatatodd added the stat:awaiting tensorflower label Oct 4, 2016
@jkiske

jkiske commented Oct 10, 2016

👍

1 similar comment
@gerbenvanveenendaal

👍

@AnishShah
Contributor

Can I work on this? I think I can solve it.

@vrv

vrv commented Dec 20, 2016

That would be great @AnishShah, can you describe what the problem is?

@AnishShah
Contributor

@vrv The problem is in this line. It uses ShapeOp to compute the paddings for SpaceToBatchOp, which is why the output shape cannot be inferred statically. I tried a few things, but I was unsuccessful. What do you suggest?
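A minimal sketch of the failure mode (not the actual atrous_conv2d code): once a padding amount is computed from tf.shape at graph-construction time, the shape-inference pass can no longer track the affected dimensions, even when they are statically known:

import tensorflow as tf

x = tf.placeholder(tf.float32, (None, 256, 256, 3))

# Padding by graph-time constants keeps static shape inference intact:
y = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]])
print(y.get_shape())  # (?, 258, 258, 3)

# Padding by values derived from tf.shape does not: the paddings tensor is
# no longer a compile-time constant, so the output dimensions are lost.
pad = tf.shape(x)[1] % 2
z = tf.pad(x, [[0, 0], [0, pad], [0, pad], [0, 0]])
print(z.get_shape())  # (?, ?, ?, ?)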

@aselle removed the stat:awaiting tensorflower label Feb 10, 2017
@aselle
Contributor

aselle commented Feb 15, 2017

@AnishShah, do you have any updates?

@aselle added the stat:awaiting response and type:bug labels Feb 15, 2017
@AnishShah
Contributor

AnishShah commented Feb 15, 2017 via email

@aselle removed the stat:awaiting response label Mar 1, 2017
@ahundt
Contributor

ahundt commented Mar 12, 2017

I think this may also shrink the dimensions? I believe a source of confusion here is that the definition of atrous convolution varies with the paper being read. Basically, some papers defined atrous convolution incorrectly when they really meant dilated convolution. This is explained in Multi-Scale Context Aggregation by Dilated Convolutions, with the authors' implementation at https://github.com/fyu/dilation.

Also see the related issue #3492.

I think what people are hoping for is a new function, or perhaps simply an additional parameter to the atrous function: the ability to specify a constant scale for the output, so that it behaves as in these papers where the output dimensions equal the input dimensions. This is particularly useful for semantic segmentation.

I believe this is implemented in tensorflow/models/slim/.../resnet_utils.py in the function conv2d_same. It may even be simple enough to migrate that option directly upstream. @warmspringwinds is also very familiar with this and may be able to verify that everything I've said here is correct, or perhaps contribute some additional information.

@vrv or @tatatodd, regarding the TensorFlow API design: if this version of dilated convolution with constant input/output dimensions is supported directly in tf.nn, should it be exposed via:

  1. atrous_conv2d with the SAME padding flag
  2. atrous_conv2d with a separate parameter
  3. a totally separate function

@vrv

vrv commented Mar 12, 2017

I'm going to delegate to @gpapan on this one, who knows more about semantic segmentation and atrous conv :). I would suggest we'd need a totally separate function because of potential confusion between atrous and dilated.

That being said, perhaps someone could just post a good implementation of it here for now instead of adding it to the API? (Usually a good sniff test for adding something to the API is whether it's used or fundamental in a state-of-the-art model for an important problem; otherwise everything under the sun gets added to the core API and our team can't support it all.)

@jbms

jbms commented Mar 13, 2017

@ahundt In tensorflow, "atrous convolution" and "dilated convolution" are used as synonyms to mean "dilated convolutions" as in the Multi-Scale Context Aggregation by Dilated Convolutions paper you cited.

@AnishShah tf.nn.convolution now provides a more generic interface for atrous convolution for any number of dimensions, and I believe it has slightly more complete shape inference, but there are still cases where it does not infer some of the output shape dimensions even when it could. If you are going to add better shape inference code, I suggest adding it to tf.nn.convolution, as there is separate work underway (see #7545) to make atrous_conv2d simply forward to tf.nn.convolution.

To fix it you will need to use the set_shape function on the output tensors to attach the additional shape information. I think it would be possible to do this inside with_space_to_batch, specifically on the input_converted tensor and then again on the result_converted tensor. You will unfortunately have to duplicate some of the work done in calculating the shapes for space_to_batch_nd and batch_to_space_nd. The reason is that a tensor can be either constant or non-constant, but not partially constant.
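In the meantime, a minimal caller-side workaround sketch (assuming SAME padding, where the spatial dimensions are known to be preserved) is to re-attach the known dimensions yourself:

import tensorflow as tf

input = tf.placeholder(tf.float32, (None, 256, 256, 1))
dilated = tf.nn.atrous_conv2d(input, tf.zeros((3, 3, 1, 1)), rate=2, padding='SAME')
print(dilated.get_shape())  # (?, ?, ?, 1) on affected versions

# With SAME padding the spatial dims are preserved, so this assertion is safe:
dilated.set_shape([None, 256, 256, 1])
print(dilated.get_shape())  # (?, 256, 256, 1)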

@ahundt
Contributor

ahundt commented Mar 13, 2017

@jbms Thanks for your comment. Does the current code in with_space_to_batch or #7545 have a mode where the output tensor has the same dimensions as the input tensor?

This is the case for conv2d_same in tensorflow/models:

def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
  """[snip...]
  Args:
    inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
  [snip...]
  Returns:
    output: A 4-D tensor of size [batch, height_out, width_out, channels] with
      the convolution output.
  """

I think it would be very productive to add a note in with_space_to_batch explaining what the output dimensions would be relative to the given input dimensions as they vary by configuration.

Regarding your comment on atrous vs. dilated convolutions, the following is quoted from a footnote in Multi-Scale Context Aggregation by Dilated Convolutions:

Some recent work mistakenly referred to the dilated convolution operator itself as the algorithme a trous. This is incorrect. The algorithme a trous applies a filter at multiple scales to produce a signal decomposition. The algorithm uses dilated convolutions, but is not equivalent to the dilated convolution operator itself.

Perhaps this is a bit pedantic, but if the paper is stating this correctly, wouldn't it mean TensorFlow is mistaken in using atrous and dilated as synonyms? It seems to imply that what is described as the atrous algorithm only dilates the filter, while the dilated version can be configured so the output is the same size as the input.

@ahundt
Contributor

ahundt commented Mar 13, 2017

Okay, I answered my own question. Yes, both tf.nn.atrous_conv2d and tf.nn.convolution produce the same output dimensions with the SAME flag. I was mixing up the effect of filter size on output dimensions; sorry about that.

I wrote this test and ran it on TF 1.0, which confirms the original issue with None values:

import tensorflow as tf
import numpy as np

input_img_np = np.random.random((1, 256, 256, 1)).astype(np.float32)
kernel = np.random.random((6, 6, 1, 1)).astype(np.float32)

with tf.Session() as sess:
    concrete_input_op = tf.constant(input_img_np)
    concrete_output_op = tf.nn.convolution(concrete_input_op, kernel, padding='SAME', dilation_rate=np.array([2, 2]))
    concrete_output = sess.run(concrete_output_op)
    
    print('convolution + CONCRETE + SAME')
    print('concrete_input_op: ', concrete_input_op.get_shape())
    print('concrete_output_op: ', concrete_output_op.get_shape())
    print('concrete_output:', concrete_output.shape)
    assert(concrete_input_op.get_shape() == concrete_output_op.get_shape())


    undef_input_op = tf.placeholder(tf.float32, shape=(None, 256, 256, 1))
    undef_output_op = tf.nn.convolution(undef_input_op, kernel, padding='SAME', dilation_rate=np.array([2, 2]))
    undef_output = sess.run(undef_output_op, feed_dict={undef_input_op: input_img_np})
    
    print('convolution + UNDEF + SAME')
    print('undef_input_op: ', undef_input_op.get_shape())
    print('undef_output_op: ', undef_output_op.get_shape())
    print('undef_output:', undef_output.shape)
    # This assert will correctly fail even though the shapes are ok because shapes are only partially known
    # assert(undef_input_op.get_shape() == undef_output_op.get_shape())

    valid_concrete_input_op = tf.constant(input_img_np)
    valid_concrete_output_op = tf.nn.convolution(valid_concrete_input_op, kernel, padding='VALID', dilation_rate=np.array([2, 2]))
    valid_concrete_output = sess.run(valid_concrete_output_op)
    
    print('convolution + CONCRETE + VALID')
    print('valid_concrete_input_op: ', valid_concrete_input_op.get_shape())
    print('valid_concrete_output_op: ', valid_concrete_output_op.get_shape())
    print('valid_concrete_output:', valid_concrete_output.shape)


    valid_undef_input_op = tf.placeholder(tf.float32, shape=(None, 256, 256, 1))
    valid_undef_output_op = tf.nn.convolution(valid_undef_input_op, kernel, padding='VALID', dilation_rate=np.array([2, 2]))
    valid_undef_output = sess.run(valid_undef_output_op, feed_dict={valid_undef_input_op: input_img_np})
    
    print('convolution + UNDEF + VALID')
    print('valid_undef_input_op: ',  valid_undef_input_op.get_shape())
    print('valid_undef_output_op: ', valid_undef_output_op.get_shape())
    print('valid_undef_output:', valid_undef_output.shape)
    # This assert will correctly fail even though the shapes are ok because shapes are only partially known
    # assert(undef_input_op.get_shape() == undef_output_op.get_shape())
    ############################################################################
    # Now atrous
    concrete_input_op = tf.constant(input_img_np)
    concrete_output_op = tf.nn.atrous_conv2d(concrete_input_op, kernel, padding='SAME', rate=2)
    concrete_output = sess.run(concrete_output_op)
    
    print('atrous_conv2d + CONCRETE + SAME')
    print('concrete_input_op: ', concrete_input_op.get_shape())
    print('concrete_output_op: ', concrete_output_op.get_shape())
    print('concrete_output:', concrete_output.shape)
    assert(concrete_input_op.get_shape() == concrete_output_op.get_shape())


    undef_input_op = tf.placeholder(tf.float32, shape=(None, 256, 256, 1))
    undef_output_op = tf.nn.atrous_conv2d(undef_input_op, kernel, padding='SAME', rate=2)
    undef_output = sess.run(undef_output_op, feed_dict={undef_input_op: input_img_np})
    
    print('atrous_conv2d + UNDEF + SAME')
    print('undef_input_op: ', undef_input_op.get_shape())
    print('undef_output_op: ', undef_output_op.get_shape())
    print('undef_output:', undef_output.shape)
    # This assert will correctly fail even though the shapes are ok because shapes are only partially known
    # assert(undef_input_op.get_shape() == undef_output_op.get_shape())

    valid_concrete_input_op = tf.constant(input_img_np)
    valid_concrete_output_op = tf.nn.atrous_conv2d(valid_concrete_input_op, kernel, padding='VALID', rate=2)
    valid_concrete_output = sess.run(valid_concrete_output_op)
    
    print('atrous_conv2d + CONCRETE + VALID')
    print('valid_concrete_input_op: ', valid_concrete_input_op.get_shape())
    print('valid_concrete_output_op: ', valid_concrete_output_op.get_shape())
    print('valid_concrete_output:', valid_concrete_output.shape)


    valid_undef_input_op = tf.placeholder(tf.float32, shape=(None, 256, 256, 1))
    valid_undef_output_op = tf.nn.atrous_conv2d(valid_undef_input_op, kernel, padding='VALID', rate=2)
    valid_undef_output = sess.run(valid_undef_output_op, feed_dict={valid_undef_input_op: input_img_np})
    
    print('atrous_conv2d + UNDEF + VALID')
    print('valid_undef_input_op: ',  valid_undef_input_op.get_shape())
    print('valid_undef_output_op: ', valid_undef_output_op.get_shape())
    print('valid_undef_output:', valid_undef_output.shape)
    # This assert will correctly fail even though the shapes are ok because shapes are only partially known
    # assert(undef_input_op.get_shape() == undef_output_op.get_shape())

This produces the following output, with the additional None values appearing in the atrous_conv2d UNDEF printouts:

convolution + CONCRETE + SAME
('concrete_input_op: ', TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)]))
('concrete_output_op: ', TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)]))
('concrete_output:', (1, 256, 256, 1))
convolution + UNDEF + SAME
('undef_input_op: ', TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)]))
('undef_output_op: ', TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)]))
('undef_output:', (1, 256, 256, 1))
convolution + CONCRETE + VALID
('valid_concrete_input_op: ', TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)]))
('valid_concrete_output_op: ', TensorShape([Dimension(1), Dimension(246), Dimension(246), Dimension(1)]))
('valid_concrete_output:', (1, 246, 246, 1))
convolution + UNDEF + VALID
('valid_undef_input_op: ', TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)]))
('valid_undef_output_op: ', TensorShape([Dimension(None), Dimension(246), Dimension(246), Dimension(1)]))
('valid_undef_output:', (1, 246, 246, 1))
atrous_conv2d + CONCRETE + SAME
('concrete_input_op: ', TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)]))
('concrete_output_op: ', TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)]))
('concrete_output:', (1, 256, 256, 1))
atrous_conv2d + UNDEF + SAME
('undef_input_op: ', TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)]))
('undef_output_op: ', TensorShape([Dimension(None), Dimension(None), Dimension(None), Dimension(1)]))
('undef_output:', (1, 256, 256, 1))
atrous_conv2d + CONCRETE + VALID
('valid_concrete_input_op: ', TensorShape([Dimension(1), Dimension(256), Dimension(256), Dimension(1)]))
('valid_concrete_output_op: ', TensorShape([Dimension(1), Dimension(246), Dimension(246), Dimension(1)]))
('valid_concrete_output:', (1, 246, 246, 1))
atrous_conv2d + UNDEF + VALID
('valid_undef_input_op: ', TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)]))
('valid_undef_output_op: ', TensorShape([Dimension(None), Dimension(None), Dimension(None), Dimension(1)]))
('valid_undef_output:', (1, 246, 246, 1))

@warmspringwinds

@ahundt, in TF atrous and dilated convolution mean the same thing. One of the parameters they accept is rate, which specifies the dilation rate. The definition of rate is consistent between the DeepLab paper and the paper you cited. I think the naming confusion is similar to the case of deconvolution, which many people use to mean fractionally strided convolution, while deconvolution simultaneously refers to a completely different operation in the signal processing field.

There are different ways to implement dilated convolution. TF implements it by subsampling the input feature map, as described in the DeepLab paper. The piece of code you refer to actually uses this implementation under the hood.
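Roughly, tf.nn.atrous_conv2d is a space_to_batch / conv2d / batch_to_space sandwich (this is a sketch; paddings and crops stand for the values the library computes from the input shape and rate):

net = tf.space_to_batch(value, paddings=paddings, block_size=rate)
net = tf.nn.conv2d(net, filters, strides=[1, 1, 1, 1], padding='VALID')
net = tf.batch_to_space(net, crops=crops, block_size=rate)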

At the same time, dilated convolution is itself an ordinary convolution -- meaning that if you apply it with SAME padding, it produces an output with the same spatial dimensions.
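The VALID numbers in the test output above also follow directly from the effective kernel size of a dilated filter:

# Effective extent of a k x k filter dilated by `rate`:
k, rate, size = 6, 2, 256
k_eff = k + (k - 1) * (rate - 1)  # 6 + 5 * 1 = 11
print(size - k_eff + 1)           # 246, matching the VALID shapes above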

In the case of image segmentation, dilated convolution makes it possible to reuse weights from image classification networks after reducing their in-network downsampling (by removing the layers responsible for downsampling or setting their stride to 1). This yields a prediction map that is downsampled by a smaller factor (most image classification models have a downsampling factor of 32; following the approach described in these papers, it can be reduced to 8, for example). After that, you can use bilinear upsampling, or learn the upsampling kernel during training, to get a prediction map of the same size as the input image. You can find an example of adapting ResNet-101 for image segmentation with this approach here.

@ahundt
Contributor

ahundt commented Mar 13, 2017

@warmspringwinds Thanks! Got it now.

All, sorry I ended up hijacking the issue due to the mismatch between my mental model and the design. At least a test script came out of it and I learned something; thanks for the clarifications. :-)

@jbms

jbms commented Mar 14, 2017

@ahundt The precise definition of with_space_to_batch is given in its docstring. However, the actual output dimensions depend entirely on the behavior of the underlying op that is passed in as an argument, so I don't think the output dimensions can be specified in a particularly concise way. It isn't normally intended to be used directly; rather, it is intended for defining new dilated operations.
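For anyone who does want to build on it, here is a minimal usage sketch (assuming the TF 1.x callback convention, where op receives the space-to-batch-converted input, the number of spatial dimensions, and the padding to use):

import tensorflow as tf

x = tf.placeholder(tf.float32, (None, 256, 256, 1))
kernel = tf.zeros((3, 3, 1, 1))

def op(converted_input, num_spatial_dims, padding):
    # Run the plain, undilated op on the converted input; the wrapper's
    # space_to_batch/batch_to_space plumbing makes it effectively dilated.
    return tf.nn.conv2d(converted_input, kernel,
                        strides=[1, 1, 1, 1], padding=padding)

out = tf.nn.with_space_to_batch(x, dilation_rate=[2, 2], padding='SAME', op=op)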

@jbms

jbms commented Mar 22, 2017

#8411 doesn't actually fix this issue -- it adds some documentation, but does not improve the static tensor shape information, which is what this issue is about.

@martinwicke
Member

Sorry about that.

@martinwicke reopened this Mar 22, 2017
@ahundt
Contributor

ahundt commented Mar 23, 2017

Sorry, that was actually my fault! I meant to write that it resolves a point of confusion discussed in this issue.

@martinwicke
Member

No, it's my fault: I edited your description to make the PR close this issue. That was a little optimistic.

@ahundt
Contributor

ahundt commented Mar 23, 2017

Aha, I didn't realize that. Well, at least things are set correctly now.

@PhilJd
Contributor

PhilJd commented May 23, 2017

This seems to be fixed as of at least TensorFlow 1.1.0-rc2!

@tensorflowbutler
Member

It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@michaelisard

Closing since this seems obsolete, but please reopen if it needs attention.
