
feature request: depthwise separable convolution #1708

Closed
hyqneuron opened this issue Jun 3, 2017 · 53 comments

@hyqneuron

hyqneuron commented Jun 3, 2017

I don't see an implementation for depthwise separable convolution. Currently it is possible with Conv2d by setting groups=out_channels, but this is painfully slow. See the benchmark at the bottom. We need an efficient implementation for this.

I realize torch7's SpatialDepthWiseConvolution is even slower. However, TF seems to have a somewhat optimized implementation, so its depthwise conv is about 3x-8x faster than the normal conv (comparing a 3x3 conv against a 3x3 depthwise conv ONLY, without the pointwise conv), but it is still slow.

Benchmark comparing time (in seconds) for different group sizes (groups=1, 2, 4, 256). For a small number of groups we get a reasonable speedup; for a large number of groups it instead gets much slower.

Sequential (
  (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
testing model of weight size: torch.Size([256, 256, 3, 3])
3.65793013573
Sequential (
  (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
  (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
)
testing model of weight size: torch.Size([256, 128, 3, 3])
1.99519991875
Sequential (
  (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
  (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4)
)
testing model of weight size: torch.Size([256, 64, 3, 3])
1.34896993637
Sequential (
  (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
)
testing model of weight size: torch.Size([256, 1, 3, 3])
5.64783811569
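For reference, here is a minimal sketch of the kind of timing loop that produces numbers like the ones above (the exact script isn't shown in the issue, so the input size, batch size, and iteration count below are assumptions):

import time
import torch
import torch.nn as nn

def time_grouped_conv(groups, channels=256, size=64, batch=16, iters=20):
    # ten stacked 3x3 convolutions with the given number of groups, as in the printouts above
    model = nn.Sequential(*[
        nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=groups)
        for _ in range(10)
    ]).cuda()
    x = torch.randn(batch, channels, size, size, device="cuda")
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()  # wait for all kernels to finish before reading the clock
    return time.time() - start

for g in (1, 2, 4, 256):
    print(g, time_grouped_conv(g))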

@jekbradbury
Contributor

jekbradbury commented Jun 12, 2017

As this underlies https://arxiv.org/abs/1706.03059, the demand for depthwise separable convs will only continue to grow. I believe this is largely a case of cuDNN not providing an optimized implementation; perhaps a new THCUNN kernel is in order -- maybe first for 1D, which should be less complex?

Chainer has a fairly simple implementation here which we should maybe port until there’s an optimized kernel.

@fmassa
Member

fmassa commented Jun 12, 2017

For the record, there is an implementation of SpatialDepthWiseConvolution in THNN/THCUNN, but it mostly just runs a for loop over the number of groups, so it shouldn't be any more efficient than our current implementation using groups=nInputPlane. But we could maybe modify it to use bmm instead of mm, as in the Chainer example?
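To make the bmm idea concrete, here is a rough unfold + bmm sketch of a depthwise forward pass (an illustration of the batched-matmul formulation only, not the Chainer code or a proposed kernel; stride 1 and no dilation are assumed):

import torch
import torch.nn.functional as F

def depthwise_conv2d_bmm(x, weight, padding=1):
    # x: (N, C, H, W); weight: (C, 1, kH, kW), i.e. the weight of nn.Conv2d(C, C, k, groups=C)
    N, C, H, W = x.shape
    kH, kW = weight.shape[2:]
    out_h, out_w = H + 2 * padding - kH + 1, W + 2 * padding - kW + 1
    # im2col: (N, C*kH*kW, L) with L = out_h * out_w
    cols = F.unfold(x, (kH, kW), padding=padding)
    # regroup the patches by channel: (C, kH*kW, N*L)
    cols = cols.view(N, C, kH * kW, -1).permute(1, 2, 0, 3).reshape(C, kH * kW, -1)
    # one batched matmul over channels: (C, 1, kH*kW) x (C, kH*kW, N*L) -> (C, 1, N*L)
    out = torch.bmm(weight.view(C, 1, kH * kW), cols)
    # back to (N, C, out_h, out_w)
    return out.view(C, N, -1).permute(1, 0, 2).reshape(N, C, out_h, out_w)

# quick check against the grouped-conv path
x = torch.randn(2, 8, 16, 16)
w = torch.randn(8, 1, 3, 3)
print(torch.allclose(depthwise_conv2d_bmm(x, w), F.conv2d(x, w, padding=1, groups=8), atol=1e-5))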

@soumith
Member

soumith commented Jul 21, 2017

making this high pri, looks like demand for this is off the charts.

@szagoruyko
Contributor

Ported the Caffe depthwise conv2d here: https://github.com/szagoruyko/pyinn (needs CuPy).

@futurely

futurely commented Aug 15, 2017

cudnn 7 grouped convolutions and cuda 9 hgemm fix
3522d3a

@soumith soumith added this to High Priority in Issue Status Aug 23, 2017
@soumith soumith moved this from High Priority to High Priority_ in Issue Status Aug 23, 2017
@amdcat

amdcat commented Aug 29, 2017

Anything new about this problem? Depthwise convolution is still slow in my pytorch environment.

@WendyShang

Hi, based on this paper (https://arxiv.org/pdf/1608.04337.pdf; I think it is one of the earliest discoveries of this kind of separable convolution, but it was somehow ignored), we could benefit from the flexibility of splitting channels into non-overlapping subgroups and performing a separate convolution on each subgroup, rather than necessarily treating every channel separately and convolving over each individual depth only. There are similar workarounds to achieve such a design with depthwise convolution, though they are also slow.

This adds more flexibility in architecture design, and it would be great if PyTorch could consider a more flexible version than the original Torch depthwise separable convolution :)

@qianguih

Hi,

Just wanted to check the status of this thread. I tested the latest version of pytorch, and it looks like there is no update on separable convolution yet. Please correct me if I'm wrong. :)

@soumith soumith added this to nn / autograd / torch in Issue Categories Sep 11, 2017
@rkaplan
Contributor

rkaplan commented Oct 2, 2017

Looking forward to this feature being implemented!

@soumith
Member

soumith commented Oct 18, 2017

this is now added to master via #3057

@soumith soumith closed this as completed Oct 18, 2017
@qianguih

Hi,
Thanks for your amazing work on this! I just updated my pytorch for the faster depthwise convolution. However, I couldn't find documentation about how to call the new depthwise conv function. Is there any example of this?

@colesbury
Member

colesbury commented Oct 23, 2017

@qianguih use groups=in_channels=out_channels e.g.:

m = nn.Conv2d(128, 128, kernel_size=3, groups=128).cuda()
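A full depthwise separable block then pairs that grouped convolution with a 1x1 pointwise convolution, roughly like this (a sketch, not an official module):

import torch
import torch.nn as nn

class SeparableConv2d(nn.Module):
    """Depthwise 3x3 conv followed by a 1x1 pointwise conv (MobileNet/Xception style)."""
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        # depthwise: one filter per input channel
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   padding=padding, groups=in_channels, bias=False)
        # pointwise: 1x1 conv mixes information across channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

m = SeparableConv2d(128, 256).cuda()
y = m(torch.randn(8, 128, 32, 32, device="cuda"))
print(y.shape)  # torch.Size([8, 256, 32, 32])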

@qianguih

@colesbury I see. Thank you!

@elliothe

Since pytorch and torch share the same THCUNN backend files, could I port this spatial depthwise convolution function back into Torch? The original function in torch has a GPU memory leak problem.

@killeent
Contributor

@qianguih note that we do support having a depthwise multiplier, so groups=in_channels must be true, but out_channels can be any multiple of in_channels, e.g.:

m = nn.Conv2d(128, 256, kernel_size=3, groups=128).cuda()

is also valid.

@killeent
Contributor

@elliothe one potential issue you will likely run into with the existing SpatialDepthwiseConvolution in LuaTorch (the backing implementations are actually removed in this PR) is that the Lua layers use a different format; see e.g. my abandoned PR here: torch/nn#1277. So you would have to handle some of those differences.

@elliothe

@killeent Thanks for your answer! What I did was rebuild torch.cunn against the replaced THCUNN library. I then receive the following errors when I try to call the SpatialDepthwiseConvolution function:

not found: THNN_CudaSpatialDepthWiseConvolution_updateOutput/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaSpatialDepthWiseConvolution_updateOutput	
not found: THNN_CudaSpatialDepthWiseConvolution_updateGradInput/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaSpatialDepthWiseConvolution_updateGradInput	
not found: THNN_CudaSpatialDepthWiseConvolution_accGradParameters/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaSpatialDepthWiseConvolution_accGradParameters	
not found: THNN_CudaDoubleSpatialDepthWiseConvolution_updateOutput/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaDoubleSpatialDepthWiseConvolution_updateOutput	
not found: THNN_CudaDoubleSpatialDepthWiseConvolution_updateGradInput/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaDoubleSpatialDepthWiseConvolution_updateGradInput	
not found: THNN_CudaDoubleSpatialDepthWiseConvolution_accGradParameters/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaDoubleSpatialDepthWiseConvolution_accGradParameters	
not found: THNN_CudaHalfSpatialDepthWiseConvolution_updateOutput/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaHalfSpatialDepthWiseConvolution_updateOutput	
not found: THNN_CudaHalfSpatialDepthWiseConvolution_updateGradInput/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaHalfSpatialDepthWiseConvolution_updateGradInput	
not found: THNN_CudaHalfSpatialDepthWiseConvolution_accGradParameters/home/elliot/torch/install/share/lua/5.1/nn/THNN.lua:108: failed to find function/global THNN_CudaHalfSpatialDepthWiseConvolution_accGradParameters	

I am wondering whether the formatting problem you mentioned leads to this error.

@ngimel
Collaborator

ngimel commented Oct 23, 2017

The problem here is that torch.nn is calling the SpatialDepthWiseConvolution routines that were removed (https://github.com/torch/nn/blob/master/SpatialDepthWiseConvolution.lua#L80); torch.nn has to be fixed to call the SpatialDepthwiseConvolution routines instead.

@tstandley

I'm not sure I buy the "bandwidth bound" explanation either. If you increase the kernel size from 3x3 to, say, 5x5, the operation needs the same amount of non-cache memory but takes much longer.

I don't know how to program cuda, but here's how I see the operation working:
Let's say we run a Conv2d(1024, 1024, kernel_size=3, groups=1024) on a shape 20x20x1024 tensor:

Read a single channel of the input into cache. That's 400 floats (1,600 bytes), so it fits easily.
Read that channel's worth of parameters into cache (that's just 9 floats).

Then do the following 18x18 times:
Multiply those 9 weights by the appropriate 9 floats of the cached channel.
Store the result in the appropriate cell of an 18x18 output matrix in main memory.

Here we have one main-memory read and one main-memory write per channel per spatial location.

If we do the same operation with Conv2d(1024, 1024, kernel_size=5, groups=1024), we should still be doing only one read and one write per spatial location per channel from main memory.

The amount of time this takes should be insignificant next to the pointwise convolution that follows the channel-wise convolution:
Conv2d(1024, 1024, kernel_size=1, groups=1)

Here we need to read about a million floats from main memory and do 400 multiplications by a 1024x1024 matrix (one per spatial location).

The channel-wise operation shouldn't even register next to that.

I really wish we could get this working. We could have blazing fast convolutions with arbitrarily large kernels.
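A rough back-of-the-envelope count under the same assumptions (20x20x1024 input, 3x3 depthwise, 1x1 pointwise), just to put numbers on that ratio:

# shapes from the example above
H = W = 20
C = 1024
k = 3
out_hw = (H - k + 1) * (W - k + 1)              # 18*18 spatial outputs per channel

# depthwise 3x3, groups=C: 9 multiply-adds per output element
depthwise_macs = out_hw * C * k * k             # ~3.0M MACs
depthwise_bytes = 4 * (H * W * C + out_hw * C + k * k * C)   # read input, write output, read weights

# pointwise 1x1, groups=1: a 1024x1024 matrix applied at each spatial location
pointwise_macs = H * W * C * C                  # ~419M MACs
pointwise_bytes = 4 * (2 * H * W * C + C * C)   # read input, write output, read ~1M weight floats

print(f"depthwise: {depthwise_macs / 1e6:.1f}M MACs, {depthwise_bytes / 1e6:.1f}MB traffic")
print(f"pointwise: {pointwise_macs / 1e6:.1f}M MACs, {pointwise_bytes / 1e6:.1f}MB traffic")
print(f"pointwise / depthwise MACs: {pointwise_macs / depthwise_macs:.0f}x")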

@SeungjunNah

SeungjunNah commented Mar 27, 2018

Hi,

Currently, pytorch uses the THNN implementation of depthwise convolution, thnn_conv_depthwise2d, instead of cudnn.

According to recent cudnn 7.1.1 release notes, it seems like cudnn has implemented group convolution for groupCount>1 for all forward & backward algorithms.
Also, current cudnn api reference says the winograd algorithm supports group count greater than 0.

Are there any plans to switch group convolution backend from thnn to cudnn?

@ibmua

ibmua commented Apr 8, 2018

I'm also very interested in this. An NVIDIA engineer emailed me a year ago saying they were going to implement the grouping feature, yet it's still nowhere to be found. This is one of the biggest things in DL in the last several years, yet it hasn't come to fruition. I hope you include it soon, and thanks @SeungjunNah for spotting it in the release notes. Great to know it's at least fully ready at the cudnn level by now. Can't wait to see my GPUs chugging through conv nets at 4x speed.

@colinfang

colinfang commented Apr 24, 2018

To benefit from cudnn 7.1, is it as simple as removing the if (params.is_depthwise(input, weight)) branch in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L337, so that it falls back to cudnn's native implementation, and turning on cudnn.benchmark = True so the Winograd algorithm has a chance to kick in? Sadly, in my case it is still at least 2x slower than Tensorflow's own tailored cuda version.
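For anyone trying this, setting the flag looks like:

import torch
torch.backends.cudnn.benchmark = True  # let cudnn autotune and pick faster algorithms such as Winograd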

@soumith
Member

soumith commented Apr 24, 2018

@colinfang FWIW, cudnn hasn't optimized for depthwise convolution, only for grouped convolution (the general case, of which depthwise is a special instance).

@ngimel
Collaborator

ngimel commented Apr 24, 2018

cudnn has some kernels for depthwise-separable convolutions, but on average they are no better than pytorch's. Feel free to bring Tensorflow's tailored implementation over to pytorch.

@colinfang

colinfang commented Apr 24, 2018

In my case the input is 9x9, the kernel is 5x5 with padding=4, input channels == output channels == groups == 50000, and batch size = 1. I tried both cudnn's native conv2d and conv_transpose2d. The transposed version is slower in the forward pass but faster in the backward pass; overall they have similar performance. And thnn_conv_depthwise2d is similar to the fastest cudnn 7.1 convolution. Perhaps I didn't configure it correctly.

@fmassa
Member

fmassa commented Apr 24, 2018

I haven't checked the TF implementation, but I believe one reason it might be faster in this case is that the number of channels is fairly large, and TF uses the NHWC layout by default?

@ngimel
Collaborator

ngimel commented Apr 24, 2018

Also, I don't know how heavily templated TF's kernel is. In pytorch, 5x5 templates are not instantiated and fall back to the generic case, which might be 2x slower than if they were instantiated.

@colinfang

colinfang commented Apr 24, 2018

I think TF also only templates 3x3 kernels, but it has a special code path for images up to 32x32; not sure if that is relevant. (NHWC vs. NCHW doesn't make much difference in TF.)
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc#L41 The fastest forward algorithm in cudnn 7.1 is conv2d_grouped_direct_kernel (spotted via nvprof); not sure what that is. (It might be CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, but the SDK says that is not implemented in cudnn.)

@ngimel
Collaborator

ngimel commented Apr 24, 2018

It does not look like it's too hard to bring it over to pytorch, so if you want pytorch to have awesome performance on small inputs, that might be your best bet :-)

@assassint2017

Hello everyone, I am really confused now. Can someone summarize in one sentence how to use high-performance depthwise convolution? Which pytorch version, CUDA version, and cudnn version are needed?

@ezyang
Contributor

ezyang commented Jul 14, 2018

PyTorch 0.4 with cuDNN should be sufficient.
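A quick way to check what your environment is actually running (a small sketch):

import torch
print(torch.__version__)                # should be >= 0.4
print(torch.backends.cudnn.version())   # e.g. 7102 for cuDNN 7.1.2
print(torch.backends.cudnn.enabled)     # True if the cuDNN backend is active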

@assassint2017

cuDNN7?

@austingg

The latest cudnn 7 patch supports a depthwise conv path.

@Kongsea
Contributor

Kongsea commented Aug 6, 2018

I have a question.
For example, if my input is (64, 128, 7, 7) and I use depthwise convolution, it will be nn.Conv2d(128, 256, 3, groups=128).
However, in this situation there will be 256 different kernels. What if I want these 256 kernels to all be the same kernel (i.e., sharing kernel weights)? How do I set the parameters in that case?
Thank you.

@wandering007
Contributor

wandering007 commented Aug 6, 2018

@Kongsea

m = nn.Conv2d(128, 256, 3, groups=128)
# make all 256 output filters views of the first filter's 1x3x3 weights
m.weight.data = m.weight.data[0].expand(256, *m.weight.shape[1:])

The expanded kernel weights share the same memory.
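A quick sanity check that the filters really do share storage (continuing from the snippet above; just a sketch):

import torch
print(m.weight.data.stride(0) == 0)                       # True: stride 0 along the output-channel dim
print(torch.equal(m.weight.data[0], m.weight.data[255]))  # True: first and last filter are identical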

@myih

myih commented Aug 14, 2018

@austingg Hi, I can't find anything about depthwise convolution support in NVIDIA's documentation and release notes.

@Kongsea
Contributor

Kongsea commented Aug 14, 2018

@wandering007 Thank you.

@austingg

@myih From the cudnn v7 release notes: "Performance improvements for grouped convolutions when input channels and output channels per group are 1, 2, or 4 for the following algorithms..."

@SeungjunNah

@ibmua According to cudnn v7.3.0 and cudnn v7.6.0 release notes, NVIDIA doesn't seem to be working on NCHW format.
Looks like it will take some time until cudnn group/depthwise convolution gets faster in PyTorch.

@XIEYUNSCUT

@ibmua According to cudnn v7.3.0 and cudnn v7.6.0 release notes, NVIDIA doesn't seem to be working on NCHW format.
Looks like it will take some time until cudnn group/depthwise convolution gets faster in PyTorch.

According to the cudnn v7.6.3 release notes, the latest cudnn seems to support depthwise convolution well. Can you check it? Thank you very much!

@HansBambel

@ibmua According to cudnn v7.3.0 and cudnn v7.6.0 release notes, NVIDIA doesn't seem to be working on NCHW format.
Looks like it will take some time until cudnn group/depthwise convolution gets faster in PyTorch.

Does this mean that with the new PyTorch 1.5 "channels last" memory format we can expect group/depthwise convolutions to speed up?

@tstandley

@HansBambel For what it's worth, I just checked my implementation of Xception using channels last (and apex fp16). It was actually considerably slower, by 50% or so. I also ran the apex imagenet example and didn't see a noticeable speedup for resnet50 (which doesn't have grouped convolutions) with channels last. Both were tested on two Titan RTXs with NVLink.
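For reference, a minimal sketch of how channels_last is switched on (not the exact training script behind the numbers above; the shapes are assumptions):

import torch
import torch.nn as nn

model = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=128).cuda()
model = model.to(memory_format=torch.channels_last)    # store weights in NHWC layout

x = torch.randn(32, 128, 56, 56, device="cuda")
x = x.contiguous(memory_format=torch.channels_last)    # inputs in NHWC layout

out = model(x)
print(out.is_contiguous(memory_format=torch.channels_last))  # True if the op preserved NHWC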

@VitalyFedyunin
Contributor

VitalyFedyunin commented Apr 27, 2020

@tstandley can you please share more details on how you ran the apex imagenet example, so we can identify the problem? We can do it in a separate issue, on Slack, on the forum, or any other way.

@tstandley

tstandley commented Apr 28, 2020

Ok, so I think I've traced the problem to two things. The first is that I have other operations in my network after the encoder portion that are much slower with channels_last. I'm not sure which they are; maybe some of the losses, or ConvTranspose2d. That wasn't tested in the apex imagenet example, but I was using the apex library with O1.

The other "issue" is that I'm using DataParallel and not doing MPSG training. When I run with a single GPU and Xception on imagenet with apex (O1), I do get about a 20% speedup with channels last. That's awesome, but on mnasnet I get a CUDA illegal memory access with channels last, and I don't get that error without the flag. (That's probably an apex bug.)

Also see my other bug report on distributed training.

@ngimel
Collaborator

ngimel commented Apr 28, 2020

ConvTranspose2d should be as fast as a regular conv in channels last; please file an issue if you find that's not the case. You can try using the autograd profiler or pyprof (which unfortunately now lives only in one of the authors' forks, https://github.com/adityaiitb/pyprof2) to narrow down what's getting slower.
Also, please file an issue for your illegal memory access. It's unlikely to be purely an apex bug; more likely something is buggy with channels-last in pytorch core.
Thank you, this is extremely helpful!
