
Transpose convolution layer for tensorflow (was deconvolution) #256

Closed
panmari opened this issue Nov 17, 2015 · 20 comments

@panmari (Contributor) commented Nov 17, 2015

Has anyone already started implementing a deconvolutional layer for tensorflow as it's used e. g. here https://github.com/stokasto/caffe ? Is anyone else interested in such a functionality or is there any trivial way to implement this using the existing tensors?

@panmari changed the title from "Deconvolutionlayer for tensorflow" to "Deconvolution layer for tensorflow" Nov 17, 2015
@vincentvanhoucke (Member) commented Nov 17, 2015

It's here:

def deconv2d(value, filter, output_shape, strides, padding="SAME",

Not sure why it's not exposed in the documentation yet, possibly because the API isn't considered stable.
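As a rough illustration of how `output_shape`, `strides` and `padding` interact, here is a pure-Python sketch of the per-axis size arithmetic. It assumes the op follows the same output-size rules as `conv2d` (SAME: `ceil(in / stride)`, VALID: `ceil((in - k + 1) / stride)`), so a requested output size is only consistent if a forward `conv2d` with the same filter size, stride and padding maps it back to the input size. The helper names are hypothetical and not part of TensorFlow.

```python
import math


def conv_out_size(in_size, k, stride, padding):
    """Output length of a forward conv along one spatial axis."""
    if padding == "SAME":
        return math.ceil(in_size / stride)
    if padding == "VALID":
        return math.ceil((in_size - k + 1) / stride)
    raise ValueError("padding must be 'SAME' or 'VALID', got %r" % (padding,))


def transpose_out_sizes(in_size, k, stride, padding):
    """Every output length that a forward conv with the same k/stride/padding
    maps back to in_size, i.e. every consistent output_shape entry for this axis."""
    # A forward conv shrinks by roughly `stride`, so in_size * stride + k
    # is a safe upper bound for the search.
    upper = in_size * stride + k
    return [n for n in range(1, upper + 1)
            if conv_out_size(n, k, stride, padding) == in_size]


# A 4-pixel-wide input, 3-wide filter, stride 2:
print(transpose_out_sizes(4, 3, 2, "SAME"))   # [7, 8]
print(transpose_out_sizes(4, 3, 2, "VALID"))  # [9, 10]
```

With stride 2, two different output sizes map back to the same input size, which is one plausible reason the signature takes an explicit `output_shape` instead of inferring it.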

@panmari (Contributor, Author) commented Nov 17, 2015

Cool, thanks for pointing me there!

@vrv (Contributor) commented Nov 17, 2015

… (gradient) of `conv2d`, not an actual deconvolution.
… suggests it isn't agreed that the op is doing deconvolution, which is why it wasn't made public yet.
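A small pure-Python sketch of the "transpose, not inverse" distinction, in 1-D for brevity (the real op works on 4-D NHWC tensors): write a VALID, stride-1 cross-correlation as a dense matrix `C`, then apply `C^T`. The result has the input's shape but not the input's values, so applying the transpose is not deconvolving. All names here are illustrative.

```python
def conv_matrix(n, kernel, stride=1):
    """Dense matrix C such that C @ x is the VALID, strided 1-D cross-correlation of x."""
    k = len(kernel)
    n_out = (n - k) // stride + 1
    rows = []
    for i in range(n_out):
        row = [0.0] * n
        for j, w in enumerate(kernel):
            row[i * stride + j] = w
        rows.append(row)
    return rows


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
kernel = [1.0, 2.0, 1.0]
C = conv_matrix(len(x), kernel)
y = matvec(C, x)                  # forward conv: 5 samples -> 3 samples
x_back = matvec(transpose(C), y)  # conv transpose: 3 samples -> 5 samples
print(y)       # [8.0, 12.0, 16.0]
print(x_back)  # [8.0, 28.0, 48.0, 44.0, 16.0], same shape as x, different values
```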

@vrv reopened this Nov 17, 2015
@suhangpro commented Nov 17, 2015

"deconvolution" has been used as a synonym for transposed convolution in recent computer vision literature, especially in ConvNet related work. I believe this "deconv2d" is exactly the op most people will be using for either upsampling or visualization purposes.

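To illustrate the upsampling use mentioned above, here is a minimal 1-D sketch of a strided transposed convolution in scatter-add form (pure Python, names illustrative): each input sample adds a scaled copy of the kernel to the output at offset `i * stride`, so with stride 2 the output is roughly twice as long as the input. Output sizing assumes VALID-style padding, `(n - 1) * stride + k`.

```python
def conv1d_transpose(y, kernel, stride):
    """Scatter-add form of a strided 1-D transposed convolution (VALID-style sizing)."""
    k = len(kernel)
    out = [0.0] * ((len(y) - 1) * stride + k)
    for i, v in enumerate(y):
        # Each input sample "paints" a scaled copy of the kernel at offset i * stride.
        for j, w in enumerate(kernel):
            out[i * stride + j] += v * w
    return out


print(conv1d_transpose([1.0, 2.0, 3.0], [1.0, 1.0], stride=2))
# [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
```

With the box kernel `[1, 1]` this reduces to nearest-neighbour-style repetition; in ConvNets the kernel is typically a learned weight instead.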
@girving (Contributor) commented Dec 5, 2015

@vincentvanhoucke: I'm happy to expose it in the public API as long as we're fine with the misleading name.

@vincentvanhoucke (Member) commented Dec 5, 2015

@girving I remember being very, very confused when the term 'deconvolution' started popping up in the literature for this operator. I blame Matt (Zeiler) ;-) I had to go back to his papers to convince myself of what they were doing and that it had nothing to do with actually deconvolving the input.
Matt Zeiler calls the entire stack Deconvolution Networks:
http://www.matthewzeiler.com/pubs/iccv2011/iccv2011.pdf
but only ever refers to this operator as a 'projection operator' as far as I can tell.
Caffe calls it Deconvolution (@Yangqing):
http://caffe.berkeleyvision.org/doxygen/classcaffe_1_1DeconvolutionLayer.html

Since 'deconv' is apparently here to stay, should we consider making the name a bit more explicit? On the table from my POV:
1- transpose_deconv2d()
2- project_deconv2d()
3- deconv2d()

Note that I actually don't know what an alternative 'deconv2d' might look like. The process of deconvolving an input is ill-posed and there are many ways to go about it, so it's possible that option 3 is just fine as long as there is an abundance of documentation. @shlens and @josh11b might have opinions on the matter.
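A tiny pure-Python illustration of why deconvolution is ill-posed, using a hand-picked finite-difference kernel as an assumption: two different inputs produce exactly the same convolved output, so no operator can recover the input from the output alone. Actual deconvolution methods have to add extra assumptions (priors, regularization) to pick one answer.

```python
def conv1d_valid(x, kernel):
    """VALID 1-D cross-correlation (no kernel flip), stride 1."""
    k = len(kernel)
    return [sum(kernel[j] * x[i + j] for j in range(k))
            for i in range(len(x) - k + 1)]


kernel = [1.0, -1.0]           # a finite-difference filter
a = [1.0, 2.0, 4.0, 7.0]
b = [v + 10.0 for v in a]      # same signal shifted by a constant
print(conv1d_valid(a, kernel))  # [-1.0, -2.0, -3.0]
print(conv1d_valid(b, kernel))  # [-1.0, -2.0, -3.0], identical output
```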

@suhangpro commented Dec 5, 2015

Another option is 'convt' (short for convolution transpose), which is adopted by matconvnet.

@girving (Contributor) commented Dec 6, 2015

I think transpose_conv2d or conv2d_transpose are the cleanest names. transpose_deconv2d implies it is the inverse transpose of conv2d. convt is clever but very easy to mistake for conv.

@vincentvanhoucke (Member) commented Dec 6, 2015

conv2d_transpose SGTM; it would put it just after conv2d in alphabetical ordering. We should emphasize in the doc that it's what's often referred to as 'deconv' or 'deconvolution', so that it shows up if anyone searches the docs for those terms.

@girving added the enhancement label Dec 8, 2015
@jeffschecter commented Dec 10, 2015

The current implementation crashes the Python kernel when given an output shape with a -1 in the first dimension. It would be helpful to throw a Python error and document this limitation, as ops produced by similar functions (e.g. conv2d) handle the scenario just fine.
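One way to get the behaviour requested here is to validate `output_shape` in Python before the op is built. The sketch below is hypothetical (it is not TensorFlow's actual fix) and assumes the 4-entry NHWC `[batch, height, width, channels]` layout that `conv2d` uses.

```python
def check_output_shape(output_shape):
    """Hypothetical Python-side check run before building the op, so a bad
    output_shape raises ValueError instead of reaching the kernel."""
    shape = list(output_shape)
    if len(shape) != 4:
        raise ValueError(
            "output_shape must have 4 entries [batch, height, width, channels], "
            "got %r" % (shape,))
    for axis, dim in enumerate(shape):
        if not isinstance(dim, int) or dim < 0:
            raise ValueError(
                "output_shape[%d] = %r: expected a non-negative int; unknown (-1) "
                "dimensions are not supported" % (axis, dim))
    return shape


try:
    check_output_shape([-1, 8, 8, 3])
except ValueError as e:
    print(e)  # a readable Python error instead of a crashed process
```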

@girving (Contributor) commented Dec 10, 2015

@jeffschecter: What kind of error do you get with the kernel crash? This may be the same thing as #449, in which case it's fixed in 5de9085 and will be part of 0.6.0 soon. That is, it should still generate exceptions, but will no longer crash the process.

@girving (Contributor) commented Dec 10, 2015

@jeffschecter: Oops, no, it's an independent bug. I'll fix that now. Thanks for the catch.

@futurely commented Dec 10, 2015

There are many algorithms to deconvolve an image. For easier extensibility in the future and to avoid being locked in a specific implementation, the public API should add a parameter to reflect which method is used and call the corresponding internal function.
def deconv2d(value, filter, output_shape, strides, padding='SAME', method='transpose', name=None)
def _deconv2d_transpose(value, filter, output_shape, strides, padding='SAME', name=None)

https://reference.wolfram.com/language/ref/ImageDeconvolve.html
http://mathworld.wolfram.com/Deconvolution.html
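For concreteness, the proposal above could be sketched as a small dispatch table. This is a hypothetical, pure-Python, 1-D illustration (not a TensorFlow API); only a 'transpose' backend is registered, and unknown methods raise a `ValueError` rather than failing silently.

```python
def _deconv1d_transpose(value, kernel, stride):
    """Strided 1-D transposed convolution in scatter-add form (VALID-style sizing)."""
    out = [0.0] * ((len(value) - 1) * stride + len(kernel))
    for i, v in enumerate(value):
        for j, w in enumerate(kernel):
            out[i * stride + j] += v * w
    return out


# Registry of available methods; a 'blind' backend would be added here.
_METHODS = {"transpose": _deconv1d_transpose}


def deconv1d(value, kernel, stride=1, method="transpose"):
    """Hypothetical front end that dispatches on `method`."""
    try:
        impl = _METHODS[method]
    except KeyError:
        raise ValueError("unknown method %r; available: %s"
                         % (method, sorted(_METHODS))) from None
    return impl(value, kernel, stride)


print(deconv1d([1.0, 2.0], [1.0, 1.0], stride=2))  # [1.0, 1.0, 2.0, 2.0]
```

The registry keeps the front end stable as methods are added, at the cost of putting mathematically unrelated operations behind one name, which is the objection raised in the replies.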

@girving (Contributor) commented Dec 10, 2015

@futurely: I'm not sure what you mean by those references, since they are to actual deconvolution ops. There are many ways to perform deconvolution, but the transpose of convolution is not one of them.

@futurely commented Dec 10, 2015

For example, if blind deconvolution is implemented too, users can switch between different deconvolution algorithms with the following code. It's a stylistic difference, but it may be a bit easier to use than changing potentially multiple occurrences of deconv2d_transpose and deconv2d_blind.

method = 'blind'
# method = 'transpose'
deconv2d(..., method=method)

@girving (Contributor) commented Dec 10, 2015

@futurely: Feature requests about deconvolution should go in a separate issue. This thread is about the transpose of convolution, which is unrelated.

@josh11b (Member) commented Dec 10, 2015

The transpose-convolution operator already exists in TF; I think it is one of the conv2d_backprop_*() functions. If we were to give it another name as part of exposing it in the API, I'd prefer conv_2d_transpose or some such, with documentation noting that some sources mistakenly refer to that op as deconvolution. I think we should not contribute to the misuse of the deconvolution term; it leads to things like @futurely's confusion.
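The relationship to the backprop functions can be checked numerically. This pure-Python, 1-D sketch (illustrative names; strided VALID cross-correlation assumed) computes the gradient of `L(x) = sum(y * conv(x))` with respect to `x` by central finite differences and compares it with a scatter-add transposed convolution of `y`. Because `L` is linear in `x`, the finite differences are exact for these small integer inputs. To my understanding, this is the input-gradient side of `conv2d` (`conv2d_backprop_input` in TensorFlow's generated ops).

```python
def conv1d_valid(x, kernel, stride=1):
    """VALID, strided 1-D cross-correlation."""
    k = len(kernel)
    n_out = (len(x) - k) // stride + 1
    return [sum(kernel[j] * x[i * stride + j] for j in range(k))
            for i in range(n_out)]


def conv1d_transpose(y, kernel, n, stride=1):
    """Scatter-add transposed convolution back to length n."""
    out = [0.0] * n
    for i, v in enumerate(y):
        for j, w in enumerate(kernel):
            out[i * stride + j] += v * w
    return out


def grad_wrt_input(x, kernel, y, stride=1, h=1.0):
    """Central finite differences of L(x) = sum(y * conv(x)) with respect to x."""
    def loss(v):
        return sum(a * b for a, b in zip(y, conv1d_valid(v, kernel, stride)))

    g = []
    for i in range(len(x)):
        up = list(x)
        up[i] += h
        dn = list(x)
        dn[i] -= h
        g.append((loss(up) - loss(dn)) / (2 * h))
    return g


x = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]
kernel = [2.0, -1.0, 1.0]
y = [1.0, 2.0]  # upstream gradient, one value per conv output (stride 2 -> 2 outputs)
print(grad_wrt_input(x, kernel, y, stride=2))         # [2.0, -1.0, 5.0, -2.0, 2.0, 0.0]
print(conv1d_transpose(y, kernel, len(x), stride=2))  # [2.0, -1.0, 5.0, -2.0, 2.0, 0.0]
```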

@girving (Contributor) commented Dec 10, 2015

Yep, Vincent and I agreed on conv2d_transpose earlier in the thread.

@girving changed the title from "Deconvolution layer for tensorflow" to "Transpose convolution layer for tensorflow (was deconvolution)" Dec 10, 2015
@ry (Contributor) commented Dec 10, 2015

How do you use deconv2d for upsampling?

@girving (Contributor) commented Dec 10, 2015

@ry: Questions like that are better suited to Stack Overflow. This thread should stay focused on the missing conv2d_transpose op. (However, it isn't related to upsampling; you may be looking for tf.image.resize_bilinear and friends.)
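For readers following that pointer, here is a rough 1-D analogue of a bilinear resize in pure Python. It assumes "align corners" sampling (the first and last samples map onto each other); TensorFlow's resize ops have their own conventions for sample placement, so treat this as an illustration of linear interpolation, not a reimplementation of `resize_bilinear`.

```python
def resize_linear_1d(x, new_len):
    """Linear-interpolation resize, mapping first/last samples onto each other."""
    if new_len == 1 or len(x) == 1:
        return [x[0]] * new_len
    scale = (len(x) - 1) / (new_len - 1)
    out = []
    for i in range(new_len):
        pos = i * scale                    # position of output sample i in input coordinates
        lo = min(int(pos), len(x) - 2)     # left neighbour, clamped so lo + 1 stays in range
        frac = pos - lo                    # weight of the right neighbour
        out.append(x[lo] * (1.0 - frac) + x[lo + 1] * frac)
    return out


print(resize_linear_1d([0.0, 10.0, 20.0], 5))  # [0.0, 5.0, 10.0, 15.0, 20.0]
```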

10 participants