feature request: depthwise separable convolution #1708
Comments
As this underlies https://arxiv.org/abs/1706.03059 the demand for depthwise separable convs will only continue to grow. I believe this is largely a case of cuDNN not providing an optimized implementation; perhaps a new THCUNN kernel is in order -- maybe first for 1D, which should be less complex? Chainer has a fairly simple implementation here which we should maybe port until there's an optimized kernel.
For the record, there is an implementation in THNN/THCUNN of
making this high pri, looks like demand for this is off the charts.
ported caffe depthwise conv2d here https://github.com/szagoruyko/pyinn (needs CuPy)
cudnn 7 grouped convolutions and cuda 9 hgemm fix
Anything new about this problem?
Hi, based on this paper (https://arxiv.org/pdf/1608.04337.pdf, which I think is one of the earliest descriptions of such separable convolution, though it was somehow overlooked), we can benefit from the flexibility of separating channels into non-overlapping subgroups and performing a separate convolution on each subgroup, rather than necessarily treating each channel individually and convolving over each depth on its own. There are similar workarounds to achieve this design with depthwise convolution, though they are also slow. This adds more flexibility in architecture design, and it would be great if PyTorch could consider a more flexible version than the original Torch depthwise separable convolution :)
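For intuition, the weight savings from grouping can be sketched with simple arithmetic (channel counts below are illustrative, not taken from the paper):

```python
def conv2d_params(in_ch, out_ch, k, groups=1):
    # Each group maps in_ch/groups inputs to out_ch/groups outputs,
    # so the weight count shrinks by a factor of `groups` (biases ignored).
    assert in_ch % groups == 0 and out_ch % groups == 0
    return out_ch * (in_ch // groups) * k * k

full      = conv2d_params(256, 256, 3)              # ordinary convolution
grouped   = conv2d_params(256, 256, 3, groups=4)    # 4 non-overlapping subgroups
depthwise = conv2d_params(256, 256, 3, groups=256)  # one filter per channel

print(full, grouped, depthwise)  # 589824 147456 2304
```

The grouped case sits between the two extremes, which is exactly the flexibility the comment above asks for.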
Hi, just want to check the status of this thread. I tested the latest version of PyTorch and it looks like there is no update on separable convolution yet. Please correct me if I'm wrong. :)
Looking forward to this feature being implemented!
this is now added to master via #3057
Hi,
@qianguih use `m = nn.Conv2d(128, 128, kernel_size=3, groups=128).cuda()`
@colesbury I see. Thank you!
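For readers landing here, the suggestion above can be expanded into a runnable sketch (shapes are illustrative; `.cuda()` is omitted so it also runs on CPU):

```python
import torch
import torch.nn as nn

# Depthwise convolution: groups == in_channels, so each of the 128 input
# channels gets its own independent 3x3 filter.
m = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=128)

x = torch.randn(2, 128, 32, 32)
y = m(x)
print(y.shape)          # torch.Size([2, 128, 32, 32])
print(m.weight.shape)   # torch.Size([128, 1, 3, 3]) -- one 3x3 filter per channel
```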
Since PyTorch and Torch share the same backend THCUNN file, could I bring this SpatialDepthwiseConvolution function into Torch? The original function in Torch has a GPU memory leak problem.
@qianguih note that we do support having a depthwise multiplier, so
is also valid.
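A minimal sketch of that depthwise-multiplier case (channel counts are mine, for illustration): with groups == in_channels and out_channels an integer multiple of in_channels, each input channel gets out_channels / in_channels independent filters.

```python
import torch
import torch.nn as nn

# Depthwise conv with channel multiplier 2: each of the 64 input channels
# gets 2 independent 3x3 filters, producing 128 output channels.
m = nn.Conv2d(64, 128, kernel_size=3, padding=1, groups=64)

y = m(torch.randn(1, 64, 16, 16))
print(y.shape)          # torch.Size([1, 128, 16, 16])
print(m.weight.shape)   # torch.Size([128, 1, 3, 3])
```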
@elliothe one potential issue you will likely run into is that the existing |
@killeent Thanks for your answer! What I did is to remake the torch.cunn with the replaced THCUNN library. Then I receive the following errors when I try to call the
I am wondering whether the formatting problem you mentioned leads to this error.
The problem here is that torch.nn is calling SpatialDepthWiseConvolution, which was removed: https://github.com/torch/nn/blob/master/SpatialDepthWiseConvolution.lua#L80. torch.nn has to be fixed to call the SpatialDepthwiseConvolution routines.
I'm not sure I buy the "bandwidth bound" explanation either. If you increase the kernel size from 3x3 to, say, 5x5, the operation needs the same amount of non-cache memory but takes way longer.

I don't know how to program CUDA, but here's how I see the operation working: read a single channel from a single layer into cache. That's 400 floats (1600 bytes), which will fit into cache. Then we do the following 18x18 times:

Here we have one main-memory read and one main-memory write per channel per spatial location. If we do the same operation with Conv2d(1024, 1024, kernel_size=5, groups=1024), we should still be doing only one read and one write per spatial location per channel from main memory.

The amount of time this takes should be insignificant next to the pointwise convolution that follows the channel-wise convolution: there we need to read a million floats from main memory and do 400 1024x1024 matrix multiplies. The channel-wise operation shouldn't even register next to that.

I really wish we could get this working. We could have blazing fast convolutions with arbitrarily large kernels.
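The back-of-the-envelope counts above can be written out explicitly (purely illustrative; it assumes one main-memory read and one write per element and ignores all caching effects):

```python
# Rough traffic/FLOP estimate for a 1024-channel 20x20 input with a valid
# 3x3 depthwise conv, following the counts sketched in the comment above.
C, H, W, K = 1024, 20, 20, 3
OH, OW = H - K + 1, W - K + 1            # 18 x 18 outputs per channel

# Depthwise 3x3: each input element read once, each output written once.
dw_bytes = 4 * C * (H * W + OH * OW)     # float32 traffic
dw_flops = C * OH * OW * (2 * K * K)     # one multiply-add per filter tap

# Pointwise 1x1 (1024 -> 1024): a 1024x1024 matrix multiply per spatial location.
pw_flops = OH * OW * (2 * C * C)

print(dw_bytes, dw_flops, pw_flops)      # 2965504 5971968 679477248
```

Under these assumptions the pointwise stage costs over 100x the FLOPs of the depthwise stage, which is the sense in which the channel-wise pass "shouldn't even register".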
Hi, currently PyTorch uses the THNN implementation of depthwise convolution, `thnn_conv_depthwise2d`, instead of cuDNN. According to the recent cuDNN 7.1.1 release notes, cuDNN has implemented group convolution for groupCount > 1 for all forward and backward algorithms. Are there any plans to switch the group convolution backend from THNN to cuDNN?
I'm also much interested in this. An NVIDIA guy emailed me a year ago saying they were going to implement the grouping feature, yet it's still nowhere to be found. This is one of the biggest things in DL of the last several years, yet it hasn't come to fruition. Hope you guys are going to include it fast, and thanks @SeungjunNah for noticing it in the release notes. Great to know it's at least fully ready at the cuDNN level by now. Can't wait to see my GPUs chugging through conv nets at 4x speed.
To benefit from cudnn 7.1, is it as simple as removing
@colinfang fwiw cudnn hasn't optimized for depthwise convolution, only for group convolution (which is a general case of depthwise)
cudnn has some kernels for depthwise-separable, but on average they are no better than pytorch's. Feel free to bring over TensorFlow's tailored implementation to pytorch.
In my case the input is 9x9, the kernel is 5x5, padding=4, input channels == output channels == 50000 == groups, and batch size = 1. I tried both cudnn's native
I haven't checked the TF implementation, but I believe one reason why it might be faster in his case is that his number of channels is fairly large, and TF uses the NHWC layout by default?
Also, I don't know how heavily templated TF's is. In PyTorch, 5x5 templates are not instantiated and fall back to the generic case, which might be 2 times slower than if they were instantiated.
I think TF also only does 3x3 kernel templates. But it has a special code path for images of size up to 32x32, not sure if that is relevant. (NHWC or NCHW don't make much difference in TF.)
It does not look like it's too hard to bring it over to PyTorch, so if you want PyTorch to have awesome performance on small inputs, that might be your best bet :-)
Hello everyone, I am really confused now. Can someone summarize in one sentence how to use high-performance depthwise convolution? Which PyTorch version? Which CUDA or cuDNN version?
PyTorch 0.4 with cuDNN should be sufficient.
cuDNN 7?
the latest cudnn 7 patch supports the depthwise conv path
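One way to check which cuDNN build your PyTorch is actually using (exact version numbers will vary per install):

```python
import torch

print(torch.__version__)
# Returns an int like 7102 for cuDNN 7.1.2, or None if PyTorch
# was built without cuDNN (e.g. a CPU-only build).
print(torch.backends.cudnn.version())
print(torch.backends.cudnn.enabled)
```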
I have a question.
Expanded kernel weights are shared with the same memory.
@austingg Hi, I can't find anything about depthwise convolution support in NVIDIA's documentation and release notes.
@wandering007 Thank you.
@myih cudnn v7 release notes.
@ibmua According to the cudnn v7.3.0 and cudnn v7.6.0 release notes, NVIDIA doesn't seem to be working on the NCHW format.
According to the cudnn v7.6.3 release notes, the latest cudnn seems to support depthwise convolution well. Can you check it? Thank you very much!
Does this mean that with the new PyTorch 1.5 'channels last' memory format we can expect group/depthwise convolutions to speed up?
@HansBambel For what it's worth, I just checked my implementation of Xception using channels last (and apex fp16). It was actually considerably slower, 50% or so. I also ran the apex imagenet example and didn't see a noticeable speedup for resnet50 (which doesn't have grouped convolutions) with channels last. Both tested on two Titan RTXs with NVLink.
@tstandley can you please share more details on how you ran the apex imagenet example, to help identify the problem? We can do it in a separate issue, Slack, the forum, or any other way.
Ok, so I think I've traced the problem to two things. The first is that I have other operations in my network after the encoder portion that are way slower with channels_last. I'm not sure what they are, maybe some of the losses, or ConvTranspose2d. That wasn't tested in the apex imagenet example, but I was using the apex library with O1. The other "issue" is that I'm using DataParallel and not doing MPSG training. When I run with a single GPU and Xception just on imagenet with apex (O1), I do get about a 20% speedup with channels last. That's awesome, but on mnasnet I get a CUDA illegal memory access with channels last, and I don't get that error without that flag (that's probably an apex bug). Also see my other bug report on distributed training.
ConvTranspose2d should be as fast as a regular conv in channels last; please file an issue if you find that's not the case. You can try using the autograd profiler or pyprof (which unfortunately now lives only in one of the authors' forks: https://github.com/adityaiitb/pyprof2) to narrow down what's getting slower.
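For anyone trying the channels-last path discussed here, a minimal conversion sketch (shapes are illustrative; GPU transfer and AMP are omitted so it also runs on CPU):

```python
import torch
import torch.nn as nn

# Convert both the model's weights and the input tensor to the
# channels-last (NHWC) memory layout introduced in PyTorch 1.5.
model = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32)
model = model.to(memory_format=torch.channels_last)

x = torch.randn(1, 32, 8, 8).to(memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))  # True

y = model(x)
print(y.shape)  # torch.Size([1, 32, 8, 8])
```

Whether this actually speeds up grouped convolutions depends on the backend kernels picked for your hardware, as the mixed results in the comments above show.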
I don't see an implementation for depthwise separable convolution. Currently it is possible with `Conv2d` by setting `groups=out_channels`. However this is painstakingly slow. See benchmark at bottom. We need an efficient implementation for this.

I realize torch7's `SpatialDepthWiseConvolution` is still slower. However TF seems to have a slightly optimized implementation, so their depthwise conv is about 3x-8x faster than the normal conv (comparing 3x3 conv with 3x3 depthwise conv ONLY, without pointwise conv), but still slow.

Benchmark comparing time (in seconds) of different group sizes (groups=1,2,4,256). We can see that for a small number of groups we get a reasonable speedup. For a large number of groups it gets much slower instead.
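For reference, the usual workaround composes the two stages by hand with grouped `Conv2d`. A minimal sketch (the module name and channel counts are mine, not an official API):

```python
import torch
import torch.nn as nn

class DepthwiseSeparableConv2d(nn.Module):
    """A 3x3 depthwise conv (one filter per channel) followed by a
    1x1 pointwise conv that mixes channels."""
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()
        # groups=in_ch -> each input channel gets its own spatial filter
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size,
                                   padding=padding, groups=in_ch)
        # 1x1 conv combines information across channels
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

m = DepthwiseSeparableConv2d(64, 128)
y = m(torch.randn(1, 64, 16, 16))
print(y.shape)  # torch.Size([1, 128, 16, 16])
```

This has far fewer parameters than a full 3x3 conv with the same input/output channels; the speed, as the whole thread shows, depends entirely on how well the grouped-convolution kernel is optimized.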