
Implement tf.nn.atrous_conv2d_transpose() #4668

Closed
Fenugreek opened this issue Sep 30, 2016 · 16 comments
Labels
stat:contribution welcome Status - Contributions welcome

Comments

@Fenugreek
Contributor

Can we have an atrous_conv2d_transpose() function, analogous to the existing conv2d_transpose() function? Or is there some simple way to get what I'm looking for using other existing functions?

I had a look at the conv2d_transpose() code, and it seems it shouldn't be too difficult to adapt it into an atrous_conv2d_transpose().

Thanks.

@concretevitamin concretevitamin added the stat:contribution welcome Status - Contributions welcome label Oct 3, 2016
@concretevitamin
Contributor

I am not aware of anyone actively working on this. Marking as contributions welcome, and feel free to send us a PR!

@guotong1988
Contributor

I want to work on it !

@burness
Contributor

burness commented Oct 24, 2016

Hi @Fenugreek, could you provide some documentation or references for atrous_conv2d_transpose?

@Fenugreek
Contributor Author

Yes. Here's code that worked for me. It requires three variables, placed in < >, that still need to be implemented; I'll say more about them below:

def atrous_conv2d_transpose(value, filters, rate, padding, name=None):
    value = array_ops.space_to_batch(input=value,
                                     paddings=<batch_to_space_crop>,
                                     block_size=rate)

    value = tf.nn.conv2d_transpose(value, filters,
                                   <output_shape>, [1, 1, 1, 1],
                                   padding='VALID', name=name)

    value = array_ops.batch_to_space(input=value,
                                     crops=<space_to_batch_pad>,
                                     block_size=rate)
    return value

You'll notice that the steps above are the steps of nn_ops.atrous_conv2d() in reverse. In that code, the variables batch_to_space_crop and space_to_batch_pad are constructed; we can use the same construction (I did, and it worked). The remaining variable in the excerpt above, output_shape, needs a new calculation. I hard-coded it to work in my specific case, so I have no code to offer for the general case.
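The round trip above can be sketched in plain NumPy for a single channel. This is an illustrative sketch, not TensorFlow code: the explicit kernel dilation here stands in for the space_to_batch/batch_to_space round trip, and the function names are hypothetical. It also shows the VALID output_shape: each spatial dimension grows by (k - 1) * rate.

```python
import numpy as np

def dilate_kernel(k, rate):
    """Insert (rate - 1) zeros between kernel taps (atrous dilation)."""
    kh, kw = k.shape
    out = np.zeros(((kh - 1) * rate + 1, (kw - 1) * rate + 1), dtype=k.dtype)
    out[::rate, ::rate] = k
    return out

def atrous_transpose_valid(x, k, rate):
    """Stride-1 transposed convolution of x with the rate-dilated kernel.

    Each input value is scattered through the dilated kernel; the VALID
    output shape is (H + (kh - 1) * rate, W + (kw - 1) * rate).
    """
    kd = dilate_kernel(k, rate)
    H, W = x.shape
    kh, kw = kd.shape
    out = np.zeros((H + kh - 1, W + kw - 1), dtype=x.dtype)
    for i in range(H):
        for j in range(W):
            out[i:i + kh, j:j + kw] += x[i, j] * kd
    return out
```

For a 1x1 input, a 3x3 kernel, and rate=2, this yields a 5x5 output with the kernel taps on every other row and column.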

Also: I see in the code for nn_ops.conv2d_transpose() a call to gen_nn_ops.conv2d_backprop_input(...). I'm not sure whether it's better to implement/use something like that rather than the approach above.

Thanks.

@guotong1988
Contributor

great!

@gpapan

gpapan commented Nov 3, 2016

#5300
I am curious what the motivation/use for adding atrous_conv2d_transpose(inputs, filters) is in the first place.
Since atrous_conv2d() always has stride=1, its transpose atrous_conv2d_transpose(inputs, filters) is equivalent to atrous_conv2d(inputs, mirrored(filters)) -- see for example sec. 4.3 and 4.4 of https://arxiv.org/pdf/1603.07285.pdf.
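A quick single-channel NumPy check of that equivalence for stride 1 (an illustrative sketch, not TensorFlow code; here mirrored means a spatial flip of the kernel):

```python
import numpy as np

def correlate_valid(x, k):
    """Plain VALID cross-correlation (what a stride-1 conv computes)."""
    H, W = x.shape
    kh, kw = k.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

def transpose_valid(x, k):
    """Stride-1 transposed convolution: scatter each input value through k."""
    H, W = x.shape
    kh, kw = k.shape
    out = np.zeros((H + kh - 1, W + kw - 1))
    for i in range(H):
        for j in range(W):
            out[i:i + kh, j:j + kw] += x[i, j] * k
    return out

# With stride 1, the transpose equals an ordinary correlation of the
# zero-padded input with the spatially flipped ("mirrored") kernel.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4))
k = rng.standard_normal((3, 3))
lhs = transpose_valid(x, k)
rhs = correlate_valid(np.pad(x, 2), k[::-1, ::-1])  # pad by kh - 1 = 2
```

Since atrous_conv2d() is this same stride-1 picture applied to a dilated kernel, the identity carries over directly.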

@Fenugreek
Contributor Author

@gpapan I didn't realize that, though I'd read that guide (to convolution arithmetic) you linked to before -- I looked at it again just now. Yes, if you set tf.transpose(filters, perm=[1, 0]) as the kernel you can achieve the transpose.

The one complication I see is the most common case, where the original convolution has padding="SAME" (rather than padding="VALID"): after the transpose you'll have values that contain some zero-padding. This zero-padding is easily stripped using array slicing, which I think has recently been implemented (including gradients), but I'm not confident enough to write it into my code without testing that I used the right slice indices.

So maybe worth implementing anyway. I don't know if this approach (just calling atrous_conv2d with transposed filters, stripping any resulting zero-padding) is different from and/or faster than what @guotong1988 did in #5300. Thanks.
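One hedged sketch of that slicing step (a hypothetical helper, assuming TF's SAME convention of padding (ke - 1) // 2 on the top/left, where ke is the dilated kernel extent; the transpose then crops the same amounts):

```python
import numpy as np

def crop_to_same(full, in_h, in_w, kh, kw, rate):
    """Crop a VALID transpose output back to the input's spatial size.

    Hypothetical helper: assumes the forward SAME conv pads
    (ke - 1) // 2 on top/left for effective (dilated) kernel extent
    ke = (k - 1) * rate + 1, so the transpose crops those amounts.
    """
    ke_h = (kh - 1) * rate + 1
    ke_w = (kw - 1) * rate + 1
    top = (ke_h - 1) // 2
    left = (ke_w - 1) // 2
    return full[top:top + in_h, left:left + in_w]

# e.g. a 5x5 VALID output from a 1x1 input, 3x3 kernel, rate 2
full = np.arange(25.0).reshape(5, 5)
same = crop_to_same(full, 1, 1, 3, 3, 2)
```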

@guotong1988
Contributor

@gpapan Can you please provide the exact mirrored function? Thank you!

@gpapan

gpapan commented Nov 29, 2016

@guotong1988 I will look into it in detail and get back to you later this week.

@Fenugreek
Contributor Author

@guotong1988 -- mirrored(filters) should just be tf.transpose(filters, perm=[1, 0]). But @gpapan can confirm.

@guotong1988
Contributor

@Fenugreek Could you provide your exact hard-coded code? I'm not sure my test case covers enough. Thank you. Here is my only test case:

# Input, output: [batch, height, width, depth]
x_image = tf.placeholder(tf.float32, shape=[1])
x = tf.reshape(x_image, [1, 1, 1, 1])

# Filter W: [kernel_height, kernel_width, output_depth, input_depth]
W_cpu = np.array([[1, -1, 1], [1, 1, 1], [-1, 1, -1]], dtype=np.float32)
W = tf.Variable(W_cpu)
W = tf.reshape(W, [3, 3, 1, 1])

strides = [1, 1, 1, 1]
padding = 'VALID'

y = tf.nn.atrous_conv2d_transpose(x, W, [1, 5, 5, 1], 2, strides, padding)

x_data = np.array([1], dtype=np.float32)
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    x = sess.run(x, feed_dict={x_image: x_data})
    W = sess.run(W, feed_dict={x_image: x_data})
    y = sess.run(y, feed_dict={x_image: x_data})

    print "The shape of x:\t", x.shape, ",\t and x.reshape(1) is:"
    print x.reshape(1)
    print ""

    print "The shape of W:\t", W.shape, ",\t and W.reshape(3,3) is:"
    print W.reshape(3, 3)
    print ""

    print "The shape of y:\t", y.shape, ",\t and y.reshape(5,5) is:"
    print y.reshape(5, 5)
    print ""

@guotong1988
Contributor

guotong1988 commented Dec 1, 2016

@Fenugreek I confirmed my commit by writing two more examples here.
But I'm afraid there may still be some bugs in my commit.

@Fenugreek
Contributor Author

@guotong1988 I tried running your code and got an output shape I was not expecting. Maybe the code is missing the trimming of the zero-padding. See my comment on #5300.

@guotong1988
Contributor

guotong1988 commented Dec 1, 2016

@Fenugreek I get your point: when the padding is SAME, I should cut the surrounding pixels.
Or, when the padding is SAME, I could use atrous_conv2d(inputs, mirrored(filters)).

@Fenugreek
Contributor Author

@guotong1988 OK, I ran your code after your fix, and got reasonable, correct-looking results this time (I trained something on MNIST and got convergence).

This is with rate=2 and filters with shape [5, 5, 1, 8] and padding='SAME'.

One minor thing I noticed: you take strides as an argument but don't use it for padding='SAME' (and possibly misuse it for padding='VALID'). Maybe simply don't support it, since atrous_conv2d doesn't?

@guotong1988
Contributor

@Fenugreek I think you are right. I removed the parameter.

benoitsteiner pushed a commit to benoitsteiner/tensorflow that referenced this issue Dec 3, 2016
@vrv vrv closed this as completed Dec 7, 2016