
Add support of depth-wise convolution #1098

Closed · wants to merge 16 commits

Conversation

@stooloveu (Contributor) commented Jan 8, 2017

I added support for 'SpatialDepthWiseConvolution', which is similar to TensorFlow's 'depthwise_conv2d'. I think it might be useful for some users.

Please do not merge the changes to 'README.md' and 'install.sh'; I edited those for my own convenience. However, a description of this module can be found in 'README.md', and it can be added to the docs later if useful.
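
Below is a minimal usage sketch of the module. It assumes the constructor mirrors nn.SpatialConvolution, with the second argument acting as a per-input-channel multiplier (as discussed later in this thread); treat the exact signature as an assumption, not documentation.

```lua
-- Hedged usage sketch: assumes a SpatialConvolution-like constructor where
-- the second argument is the number of filters applied to EACH input channel.
require 'nn'

local nInputPlane = 3   -- input channels
local multiplier  = 2   -- filters per input channel
local kW, kH      = 3, 3

local conv = nn.SpatialDepthWiseConvolution(nInputPlane, multiplier, kW, kH)

local input  = torch.rand(nInputPlane, 32, 32)
local output = conv:forward(input)
print(output:size())  -- (multiplier * nInputPlane) x 30 x 30, i.e. 6 x 30 x 30
```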

@stooloveu (Contributor, Author) commented Jan 9, 2017

@soumith If this is useful, I will continue with writing the tests for TDD. Thank you!

@stooloveu (Contributor, Author):

@Atcold Thanks for the advice. Will work on it.

@soumith (Member) commented Jan 9, 2017

Are you planning to work on a CUDA implementation for this? If not, consider putting this in https://github.com/clementfarabet/lua---nnx, which we reserve for experimental modules.

@stooloveu (Contributor, Author):

@soumith Yes, in fact I have already written a CUDA implementation for this (still testing its speed). See https://github.com/stooloveu/cunn

@Atcold (Contributor) commented Jan 16, 2017

@stooloveu, let me know when your code passes the test I gave you, so that I can give it a final look. OK?

@stooloveu (Contributor, Author):

@Atcold As I said in the email, the dimensions are correct. I also modified the test script to set the weight and bias and compare the result; it is identical to the output of your test code.

@Atcold (Contributor) commented Jan 24, 2017

@stooloveu, sweet! Could you please profile your code vs. mine?
What about the CUDA version? Is that one correct as well?

@nicholas-leonard (Member):

@stooloveu This PR is starting to look quite good. All that remains is some unit tests and documentation.

@Atcold What do you mean your code? Do you already have a version of this?

@stooloveu (Contributor, Author):

@nicholas-leonard @Atcold Thank you! The code for both nn and cunn is done. I was busy with exams last week; I will work on the tests and documentation this weekend.

I think what he meant by "his code" is the test code he wrote to compare the results in preliminary small-size tests.

@Atcold (Contributor) commented Feb 17, 2017

@nicholas-leonard, Torch "can" already perform depth-wise convolutions; it is simply inefficient. Hence, I delegated @stooloveu to take care of the optimised version, for both CPU and GPU, which has to produce the same numerical results as the inefficient version, only faster.
Right now I'm waiting for some profiling and testing, but I'm confident that @stooloveu has done a great job.
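
For reference, here is a hedged sketch of the kind of inefficient baseline being described: depth-wise convolution assembled from stock nn modules by looping over input channels. This is an illustration only, not @Atcold's actual reference code.

```lua
-- Illustrative (not the actual reference code): depth-wise convolution
-- built from existing nn modules, convolving each input channel separately.
require 'nn'

local function naiveDepthWise(input, multiplier, kW, kH)
  local nInputPlane = input:size(1)
  local outputs = {}
  for c = 1, nInputPlane do
    -- `multiplier` filters applied to this single channel
    local conv = nn.SpatialConvolution(1, multiplier, kW, kH)
    local channel = input:narrow(1, c, 1)
    table.insert(outputs, conv:forward(channel))
  end
  -- stack all per-channel maps: (multiplier * nInputPlane) output planes
  return torch.cat(outputs, 1)
end

local out = naiveDepthWise(torch.rand(3, 32, 32), 2, 3, 3)
print(out:size())  -- 6 x 30 x 30
```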

@stooloveu (Contributor, Author) commented Mar 20, 2017

@nicholas-leonard @soumith The depth-wise convolution functionality for the 'nn' package is ready to merge:

It has passed the test written by @Atcold (see above). The test code and profiling information can be found at: https://github.com/stooloveu/SpatialDepthWiseConvolution_profiling

The documentation has been added to doc/convolution.md.

I will also submit a pull request for the GPU version under 'torch/cunn'.
(torch/cunn#452)

Please let me know if there's any problem.

Many thanks!

EDIT:

  1. I don't know why the checks fail after I merged with upstream today:
     /home/travis/torch/install/bin/luajit: /home/travis/torch/install/share/lua/5.1/nn/THNN.lua:82: declaration specifier expected near 'real' at line 757
  2. I think I also need to add the test code to test.lua, but I'm not sure how to do that (see the sketch below). I would appreciate it a lot if anyone could help me here.
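
A hedged sketch of how a test for this module might be added to torch/nn's test.lua, following the pattern of the existing tests there (a function on the nntest table, numerical Jacobian checks, and the shared jac/mytester/precision globals). This is an illustration, not the test that was eventually merged.

```lua
-- Hedged sketch, modelled on existing torch/nn tests: register a function
-- on the nntest table and check the module's Jacobian numerically.
-- `jac`, `mytester`, and `precision` are globals test.lua already defines.
function nntest.SpatialDepthWiseConvolution()
  local nInputPlane, multiplier = 3, 2
  local module = nn.SpatialDepthWiseConvolution(nInputPlane, multiplier, 3, 3)
  local input  = torch.rand(nInputPlane, 8, 8)

  -- gradient w.r.t. the input
  local err = jac.testJacobian(module, input)
  mytester:assertlt(err, precision, 'error on state')

  -- gradient w.r.t. the weights
  local errW = jac.testJacobianParameters(module, input,
                                          module.weight, module.gradWeight)
  mytester:assertlt(errW, precision, 'error on weight')
end
```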

@Atcold (Contributor) commented Mar 20, 2017

Yeah, it looks fine to me (yup, tests are needed).
@soumith, what do you think?

@nicholas-leonard (Member) commented May 12, 2017

Added @Atcold's unit test, rebased to one commit (that was hell), and merged in 5f1f7a2. Thanks @stooloveu @Atcold!

@killeent commented Aug 1, 2017

@Atcold @nicholas-leonard @stooloveu why is bias a 2D Tensor with shape (nOutputPlane, nInputPlane)?

cc @soumith

@soumith (Member) commented Aug 1, 2017

Looks like a bug to me; the bias code is wrong. The bias has to be 1D, of size nOutputPlane.

@stooloveu (Contributor, Author):

@killeent @soumith If the bias has to be 1D, then its size should be (nOutputPlane * nInputPlane), right?

@soumith (Member) commented Aug 1, 2017

You have a single bias value for each output channel, so the bias tensor will just be 1D of size nOutputPlane.

@stooloveu (Contributor, Author):

I think, since the output will be a 3D tensor of size (nOutputPlane * nInputPlane) x oheight x owidth, the bias should be a 1D tensor of size (nOutputPlane * nInputPlane). Then it will match TensorFlow's implementation:
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/nn/depthwise_conv2d
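
To make the shape argument concrete, here is a small worked example under the PR's naming, in which nOutputPlane is the per-channel multiplier rather than the conventional output-channel count (the numbers are illustrative):

```lua
-- Shape arithmetic under the PR's naming (nOutputPlane = per-channel
-- multiplier); values are illustrative only.
local nInputPlane     = 3        -- input channels
local nOutputPlane    = 2        -- filters per input channel (multiplier)
local oheight, owidth = 30, 30

-- output: (nOutputPlane * nInputPlane) x oheight x owidth = 6 x 30 x 30
-- one bias value per output map, so a 1D bias would have size:
print(nOutputPlane * nInputPlane)  -- 6
```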

@soumith (Member) commented Aug 1, 2017

I think there's a terminology fix needed here:

nOutputPlane = nInputPlane * channel_multiplier.

@soumith (Member) commented Aug 1, 2017

The bias will be 1D of size nOutputPlane.

Trevor, when channel_multiplier = 1, it's equivalent to a convolution with groups = nInputPlane.

The fix we need to make is to turn the bias into a 1D tensor.
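
A hedged sketch of the groups equivalence mentioned above, assuming the cudnn.torch bindings, where cudnn.SpatialConvolution takes a group count as its last constructor argument (treat that argument as an assumption if your version differs):

```lua
-- Hedged sketch: with channel_multiplier = 1, depth-wise convolution is a
-- grouped convolution with groups = nInputPlane (one filter per channel,
-- no cross-channel mixing). Assumes cudnn.torch's optional groups argument.
require 'cudnn'

local nInputPlane = 32
local kW, kH = 3, 3

local depthwise = cudnn.SpatialConvolution(
  nInputPlane, nInputPlane,  -- channel_multiplier = 1, so out == in
  kW, kH, 1, 1, 1, 1,
  nInputPlane)               -- groups = nInputPlane

print(depthwise.bias:nElement())  -- 32: 1D bias, one value per output plane
```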

@stooloveu (Contributor, Author):

Yeah, that's a good point. I used nOutputPlane to be consistent with normal conv. Maybe I should address the terminology better.

Thanks!

@fmassa (Contributor) commented Aug 1, 2017

I think we should benchmark these functions against what we currently have in PyTorch, which just performs a for loop over the number of channels. An alternative implementation would be to follow something similar to Caffe, which has dedicated kernels for each operation and so should be faster.
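
For anyone picking this up, a minimal timing sketch for the Torch side, using torch.Timer to time the merged module's forward pass (sizes and iteration counts are arbitrary; this is not a full benchmark):

```lua
-- Minimal timing sketch; sizes/iterations are arbitrary, and this measures
-- only the CPU forward pass.
require 'nn'

local nInputPlane, multiplier = 64, 1
local input = torch.rand(nInputPlane, 128, 128)
local conv  = nn.SpatialDepthWiseConvolution(nInputPlane, multiplier, 3, 3)

conv:forward(input)            -- warm-up
local timer = torch.Timer()
for i = 1, 10 do
  conv:forward(input)
end
print(('10 forwards: %.3f s'):format(timer:time().real))
```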
