
Conversation

@ngimel (Collaborator) commented Oct 24, 2017

| batch size | input channels | Height, Width | kH, kW | stride | current time (s) | #3057 time (s) | speed-up |
|---|---|---|---|---|---|---|---|
| 1 | 32 | 112 | 3 | 1 | 0.0003635788 | 0.000360384 | 0.9912128843 |
| 1 | 64 | 112 | 3 | 2 | 0.000231576 | 0.0003566647 | 1.5401626686 |
| 1 | 128 | 56 | 3 | 1 | 0.0002005291 | 0.000383029 | 1.9100917868 |
| 1 | 128 | 56 | 3 | 2 | 0.000126195 | 0.0002591753 | 2.053769129 |
| 1 | 256 | 28 | 3 | 1 | 0.0001298428 | 0.0003213358 | 2.4748071979 |
| 1 | 256 | 28 | 3 | 2 | 0.0001278973 | 0.000162158 | 1.2678771158 |
| 1 | 512 | 14 | 3 | 1 | 0.0001259661 | 0.0001772976 | 1.4075027444 |
| 1 | 512 | 14 | 3 | 2 | 0.0001271725 | 0.0001322794 | 1.0401574803 |
| 1 | 1024 | 7 | 3 | 1 | 0.0001308775 | 0.0001342106 | 1.0254672642 |
| 64 | 32 | 112 | 3 | 1 | 0.0076541996 | 0.0177096987 | 2.3137231327 |
| 64 | 64 | 112 | 3 | 2 | 0.0076271296 | 0.0168428278 | 2.2082787077 |
| 64 | 128 | 56 | 3 | 1 | 0.0070373774 | 0.0168711519 | 2.3973635443 |
| 64 | 128 | 56 | 3 | 2 | 0.003865943 | 0.0091258287 | 2.3605699435 |
| 64 | 256 | 28 | 3 | 1 | 0.0035696888 | 0.0085421801 | 2.392976124 |
| 64 | 256 | 28 | 3 | 2 | 0.0021048355 | 0.005251503 | 2.4949707306 |
| 64 | 512 | 14 | 3 | 1 | 0.0019623423 | 0.0044877005 | 2.2869101627 |
| 64 | 512 | 14 | 3 | 2 | 0.0011884212 | 0.0028495121 | 2.3977290053 |
| 64 | 1024 | 7 | 3 | 1 | 0.0012278891 | 0.002658639 | 2.1652110428 |
| 128 | 32 | 112 | 3 | 1 | 0.0144340229 | 0.0354276562 | 2.4544547567 |
| 128 | 64 | 112 | 3 | 2 | 0.0154968691 | 0.0339261246 | 2.1892244415 |
| 128 | 128 | 56 | 3 | 1 | 0.014062891 | 0.0335231686 | 2.3838034831 |
| 128 | 128 | 56 | 3 | 2 | 0.0080575609 | 0.018141408 | 2.2514763643 |
| 128 | 256 | 28 | 3 | 1 | 0.0070493364 | 0.0169120741 | 2.3991015678 |
| 128 | 256 | 28 | 3 | 2 | 0.0041349077 | 0.0103336859 | 2.4991333709 |
| 128 | 512 | 14 | 3 | 1 | 0.003833847 | 0.0086662102 | 2.2604475533 |
| 128 | 512 | 14 | 3 | 2 | 0.0022731161 | 0.0053928566 | 2.3724510024 |
| 128 | 1024 | 7 | 3 | 1 | 0.0023054218 | 0.0047499275 | 2.0603290298 |

The biggest performance improvements come from templating the kernels. The benchmarks above compare against #3057. I've taken the sizes from https://github.com/marvis/pytorch-mobilenet/blob/master/benchmark.py#L19-L46, so some differ slightly from what was listed in #3057. Benchmarks run for 50 iterations, and times are given per iteration.
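For reference, the sketch below shows how a per-iteration timing of this kind can be collected for a single configuration. It is a hypothetical reproduction written against the current `nn.Conv2d` API, not the actual benchmark script used for the table above.

```python
# Hypothetical timing sketch (not the benchmark script used for the table):
# time a depthwise nn.Conv2d for 50 iterations and report seconds/iteration.
import torch
import torch.nn as nn

def time_depthwise(batch, channels, hw, k=3, stride=1, iters=50):
    conv = nn.Conv2d(channels, channels, kernel_size=k, stride=stride,
                     padding=k // 2, groups=channels).cuda()
    x = torch.randn(batch, channels, hw, hw, device="cuda")
    conv(x)                       # one warm-up pass so setup cost is not timed
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        conv(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters   # elapsed_time is in ms

# e.g. the batch=64, channels=256, 28x28, stride=1 row of the table
print(time_depthwise(64, 256, 28, k=3, stride=1))
```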

@killeent (Contributor) left a comment

@ngimel looks great! Can you explain the changes with the loops / warps? Everything else is straightforward.

```cuda
const int nwarps = blockDim.x / WARP_SIZE;
const int imageElements = outputWidth * outputHeight;
// use warp per item
for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps){
```

```cuda
indtmp1 = indtmp2;
indtmp2 = indtmp1/inputChannels;
const int c = indtmp1 - indtmp2 * inputChannels;
const int n = indtmp2;
```

```diff
@@ -155,7 +189,7 @@ __global__ void spatialDepthwiseConvolutionAccGradParameters(
   int bidx = blockIdx.x;
   int kW = bidx % kernelWidth;
   int kH = (bidx / kernelWidth) % kernelHeight;
-  int ch = (bidx / channelStride) % kernelChannels;
+  int ch = (bidx / channelStride);
```

```cuda
template <typename T, typename AccT, typename IndexType>
```

```cuda
const int WARP_SIZE = 32;
const int MAX_BLOCK_SIZE = 256;
```

```cuda
const IndexType offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;
#pragma unroll
for (int kH = 0; kH < KH_LIMIT; ++kH) {
#pragma unroll
```

```cuda
const int laneId = threadIdx.x % WARP_SIZE;
const int batch = threadIdx.x / WARP_SIZE;
const int nwarps = blockDim.x / WARP_SIZE;
const int imageElements = outputWidth * outputHeight;
```

```diff
@@ -164,25 +198,31 @@ __global__ void spatialDepthwiseConvolutionAccGradParameters(
   AccT grad = ScalarConvert<float, AccT>::to(0.0);

   // Block-stride loop over the number of elements we need to reduce
   for (IndexType idx = threadIdx.x; idx < blockElements; idx += blockDim.x) {
   const int laneId = threadIdx.x % WARP_SIZE;
```

@ezyang merged commit f9d002d into pytorch:master on Oct 24, 2017
@ngimel deleted the depthwise branch on October 26, 2017 00:54
@elliothe commented Nov 5, 2017

In order to use this new depthwise conv, do I have to install cuDNN v7?

@soumith (Member) commented Nov 6, 2017

@elliothe no.

@elliothe commented Nov 6, 2017

@soumith In order to get the fastest depthwise convolution in PyTorch, what should I do now? I have installed the latest PyTorch from source, with CUDA 8.0 and cuDNN v6. However, when I try to replace the normal spatial conv by combining depthwise conv and summation, it runs more than ten times slower. Is this speed degradation normal?

@soumith (Member) commented Nov 6, 2017

@elliothe how did you actually "replace" normal spatial conv with depthwise conv?

All you need to do to use depthwise conv is use the `groups` parameter of nn.Conv2d.
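A minimal illustration of that (channel count arbitrary): setting `groups` equal to `in_channels` makes nn.Conv2d apply one filter per input channel, i.e. a depthwise convolution.

```python
import torch.nn as nn

channels = 256
# groups == in_channels => depthwise convolution: one 3x3 filter per channel,
# which is the shape that dispatches to the depthwise kernels from this PR.
depthwise = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
```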

@elliothe commented Nov 6, 2017

@soumith Normal spatial convolution is nn.Conv2d(input_channels, output_channels), which I rewrite as:

```python
nn.Conv2d(input_channels, input_channels*output_channels, groups=input_channels)
nn.ReLU()   # want to test the intermediate activation function
nn.view(-1, input_channels, output_channels, Hin, Win).sum(1)
```

When the input_channels and output_channels are large, the speed degrades a lot.
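For reference, a runnable reading of that decomposition might look like the sketch below; the concrete sizes and the fact that `.view(...)` is applied to the convolution's output tensor are my assumptions about what the pseudocode above means.

```python
import torch
import torch.nn as nn

in_ch, out_ch, Hin, Win = 32, 64, 28, 28

# Emulate a dense conv as a depthwise conv with depth multiplier out_ch,
# followed by a sum over the input-channel dimension.
conv = nn.Conv2d(in_ch, in_ch * out_ch, kernel_size=3, padding=1, groups=in_ch)
relu = nn.ReLU()

x = torch.randn(8, in_ch, Hin, Win)
y = relu(conv(x))                               # (8, in_ch * out_ch, Hin, Win)
y = y.view(-1, in_ch, out_ch, Hin, Win).sum(1)  # (8, out_ch, Hin, Win)
```

Note that the intermediate activation here has in_ch times as many channels as the output of a plain nn.Conv2d(in_ch, out_ch, 3), so a large slowdown is expected from the extra memory traffic and compute alone.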

@csarofeen (Contributor) commented

Wouldn't you also see this "degradation" with nn.Conv2d(input_channels, input_channels*output_channels)?

@ngimel (Collaborator, Author) commented Nov 6, 2017

That's a fair point. Currently, when deciding between the cuDNN and depthwise-separable kernels, the only thing that's checked is that number_of_groups = input_channels. However, when the depth multiplier is large (that is, the number of output channels is <large number> * the number of input channels), using cuDNN and just dispatching a number of regular convolution kernels might be faster. Are layers like this encountered in practice?
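In nn.Conv2d terms, a large-depth-multiplier layer is something like the hypothetical example below (still groups == in_channels, but each input channel produces many output channels):

```python
import torch.nn as nn

in_ch, multiplier = 64, 8
# groups == in_channels, but out_channels == multiplier * in_channels;
# for large multipliers, falling back to cuDNN and running a regular
# convolution per group might be the faster choice.
conv = nn.Conv2d(in_ch, in_ch * multiplier, kernel_size=3, padding=1, groups=in_ch)
```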

@elliothe commented Nov 6, 2017

@ngimel I believe this is going to be a research direction, since depthwise-separable kernels have become popular due to some recent state-of-the-art NNs like MobileNets and Xception. It would be very helpful for researchers if PyTorch could support this kind of function efficiently.

@ngimel (Collaborator, Author) commented Nov 6, 2017

In MobileNet and Xception the number of output channels is equal to the number of input channels (maybe in some Xception layers it's 2*input_channels, I don't remember off the top of my head). For those layers the kernels in PyTorch provide much better performance than what used to be available (the benchmarks for MobileNet comparing against the previous cuDNN implementation are in #3057; this PR slightly improves on #3057).
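For context, a MobileNet-style depthwise-separable block in PyTorch looks roughly like the sketch below; the exact BatchNorm/ReLU placement is illustrative, but the depthwise 3x3 has out_channels == in_channels, which is the shape these kernels target.

```python
import torch.nn as nn

def depthwise_separable(in_ch, out_ch, stride=1):
    # Depthwise 3x3 (groups == in_channels, so it hits the depthwise kernels),
    # followed by a pointwise 1x1 convolution that mixes channels.
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=stride, padding=1,
                  groups=in_ch, bias=False),
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )
```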

@KeCh96 commented Feb 3, 2018

I have upgraded my PyTorch to 0.3.0, but I found that m = nn.Conv2d(128, 256, kernel_size=3, groups=128) is still 2 times slower than m = nn.Conv2d(128, 256, kernel_size=3). I am really confused by this problem; should I upgrade PyTorch to another version? Do I need CUDA 9? @ngimel

@fmassa (Member) commented Feb 3, 2018

@KeCh96 I replied on your other post. Please refrain from posting the same question in multiple different places.

facebook-github-bot pushed a commit that referenced this pull request Jul 9, 2019
Summary:
This PR activates faster depthwise convolution kernels for Volta and Turing GPUs using cudnn >= 7600.
The script to benchmark the current PyTorch master branch and this PR branch can be found [here](https://gist.github.com/ptrblck/4590cf20721d8f43296c9903abd4a774).
(50 warmup iterations, 1000 iterations for timing)

I've used #3265 to create a similar benchmark and added a few additional setups.
Since the results are quite long, I've uploaded them in a spreadsheet [here](https://docs.google.com/spreadsheets/d/13ByXcqg7LQUr3DVG3XpLwnJ-CXg3GUZJ3puyTMw9n2I/edit?usp=sharing).
Times are given in ms per iteration.
We've benchmarked this PR on a DGX1 using V100 GPUs.

The current workload check in `check_cudnn_depthwise_workload` is quite long and can be moved to another file, if wanted.

CC ngimel (Thanks for the support while benchmarking it ;) )
Pull Request resolved: #22302

Differential Revision: D16115057

Pulled By: ezyang

fbshipit-source-id: bad184658518e73b4d6b849d77e408f5a7a757de
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 9, 2019