New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement tf.nn.atrous_conv2d_transpose() #4668
Comments
I am not aware of anyone actively working on this. Marking as contributions welcome, and feel free to send us a PR! |
I want to work on it ! |
Hi @Fenugreek Could you offer some documents with the atrous_conv2d_transpose? |
Yes. Here's code that worked for me. It requires three variables that need to be implemented, placed in < >, that I'll say more about below: def atrous_conv2d_transpose(value, filters, rate, padding, name=None):
value = array_ops.space_to_batch(input=value,
paddings=<batch_to_space_crop>,
block_size=rate)
value = tf.nn.conv2d_transpose(value, filters,
<output_shape>, [1, 1, 1, 1],
padding='VALID', name=name)
value = array_ops.batch_to_space(input=value,
crops=<space_to_batch_pad>,
block_size=rate)
return value You'll notice that the steps above are the steps in Also: I see in the code for Thanks. |
great! |
#5300 |
@gpapan I didn't realize that, though I'd read that guide (to convolution arithmetic) you linked to before -- I looked at it again just now. Yes, if you set The one complication I see is for the most common case when the original convolution has So maybe worth implementing anyway. I don't know if this approach (just calling atrous_conv2d with transposed filters, stripping any resulting zero-padding) is different from and/or faster than what @guotong1988 did in #5300. Thanks. |
@gpapan Can you please provide the exact |
@guotong1988 I will look into it in detail and get back to you later this week. |
@guotong1988 -- |
@Fenugreek Could you provide your hard-coded exact code? In fact I'm not sure that my test case can cover that much. Thank you . Here is my only test case.
|
@Fenugreek I confirm my commit by write two more examples here . |
@guotong1988 I tried running your code and got an output shape I was not expecting. Maybe the code is missing the trimming of the zero-padding. See my comment on #5300. |
@Fenugreek I get your point . When the padding is SAME . I should cut the surrounding pixels. |
@guotong1988 OK, I ran your code after your fix, and got reasonable correct looking results this time (I trained something on MNIST and got convergence). This is with One minor thing I saw was that you take |
@Fenugreek I think you are right. I remove the parameter. |
tensorflow#5300. Change: 140759688
Can we have a
atrous_conv2d_transpose()
function, just like the existingconv2d_transpose()
function? Or is there some simple way to get what I am looking for using other existing functions?I had a look at the
conv2d_transpose()
code, and it seems shouldn't be too difficult to adapt it to get aatrous_conv2d_transpose()
.Thanks.
The text was updated successfully, but these errors were encountered: