Speed of group convolution #10229

Closed
Johnccl opened this Issue Aug 4, 2018 · 13 comments

Johnccl commented Aug 4, 2018

I'm trying to use grouped convolution in my model, but I find that the grouped convolution layer is very slow. I tested different group settings on a Titan X (Maxwell). The results show that when groups == input_channels, the forward pass is very fast even across different output_channels, but any other number of groups slows the forward pass down heavily. Here is my test code; if I've made a mistake, please correct me.

import time
import math


def count(a, m):
    t0 = time.time()
    for i in range(1000):
        b = m(a)
    return time.time() - t0

def th_test(in_channel=1024, size=19, batch=1):
    import torch as th
    import torch.nn as nn
    x = th.rand((batch, in_channel, size, size))
    x = x.cuda()
    out_channels_list = [256, 512, 1024, 2048]
    for out_channels in out_channels_list:
        n = int(math.log(min(in_channel, out_channels), 2))+1
        for i in range(n):
            g = int(math.pow(2, i))
            m = nn.Conv2d(in_channel, out_channels=out_channels, kernel_size=3, padding=1, groups=g).cuda()
            t = count(x, m)
            print('OutChannel:{}, Group:{}, Time:{}'.format(out_channels, g, t))

def mx_test(in_channel=1024, size=19, batch=1):
    import mxnet as mx
    import mxnet.ndarray as nd
    from mxnet.gluon import nn
    x = nd.uniform(-1, 1, (batch, in_channel, size, size), mx.gpu(0))

        out_channels_list = [256, 512, 1024, 2048]
        for out_channels in out_channels_list:
        n = int(math.log(min(in_channel, out_channels), 2)) + 1
        for i in range(n):
            g = int(math.pow(2, i))
            m = nn.Conv2D(out_channels, kernel_size=3, padding=(1, 1), groups=g)
            m.initialize(ctx=[mx.gpu(0)])
            t = count(x, m)
            print('OutChannel:{}, Group:{}, Time:{}'.format(out_channels, g, t))

if __name__=='__main__':
    print('Pytorch testing:')
    th_test(1024, 19, 1)
    print('Mxnet testing:')
    mx_test(1024, 19, 1)

The input tensor is (1, 1024, 19, 19), and the output is:

Pytorch testing:
OutChannel:256, Group:1, Time:0.33240747451782227
OutChannel:256, Group:2, Time:0.20134592056274414
OutChannel:256, Group:4, Time:0.17717289924621582
OutChannel:256, Group:8, Time:0.22671079635620117
OutChannel:256, Group:16, Time:0.31947803497314453
OutChannel:256, Group:32, Time:0.5367441177368164
OutChannel:256, Group:64, Time:1.0223586559295654
OutChannel:256, Group:128, Time:1.7879347801208496
OutChannel:256, Group:256, Time:3.656572103500366
OutChannel:512, Group:1, Time:0.6438479423522949
OutChannel:512, Group:2, Time:0.39124298095703125
OutChannel:512, Group:4, Time:0.2587141990661621
OutChannel:512, Group:8, Time:0.2279491424560547
OutChannel:512, Group:16, Time:0.3441953659057617
OutChannel:512, Group:32, Time:0.5649154186248779
OutChannel:512, Group:64, Time:0.9861962795257568
OutChannel:512, Group:128, Time:1.8548946380615234
OutChannel:512, Group:256, Time:3.8017823696136475
OutChannel:512, Group:512, Time:7.314914703369141
OutChannel:1024, Group:1, Time:1.4008221626281738
OutChannel:1024, Group:2, Time:0.7930033206939697
OutChannel:1024, Group:4, Time:0.46567654609680176
OutChannel:1024, Group:8, Time:0.34047675132751465
OutChannel:1024, Group:16, Time:0.3591480255126953
OutChannel:1024, Group:32, Time:0.5814201831817627
OutChannel:1024, Group:64, Time:1.0014891624450684
OutChannel:1024, Group:128, Time:1.896963119506836
OutChannel:1024, Group:256, Time:3.7428271770477295
OutChannel:1024, Group:512, Time:7.297173023223877
OutChannel:1024, Group:1024, Time:0.03275704383850098
OutChannel:2048, Group:1, Time:3.1169140338897705
OutChannel:2048, Group:2, Time:1.7458274364471436
OutChannel:2048, Group:4, Time:0.9465909004211426
OutChannel:2048, Group:8, Time:0.5486705303192139
OutChannel:2048, Group:16, Time:0.4627697467803955
OutChannel:2048, Group:32, Time:0.6334977149963379
OutChannel:2048, Group:64, Time:1.0478804111480713
OutChannel:2048, Group:128, Time:1.9274232387542725
OutChannel:2048, Group:256, Time:3.762500524520874
OutChannel:2048, Group:512, Time:7.371256113052368
OutChannel:2048, Group:1024, Time:0.032269954681396484
Mxnet testing:
OutChannel:256, Group:1, Time:0.09648561477661133
OutChannel:256, Group:2, Time:0.09166908264160156
OutChannel:256, Group:4, Time:0.1002199649810791
OutChannel:256, Group:8, Time:0.09164929389953613
OutChannel:256, Group:16, Time:0.09277224540710449
OutChannel:256, Group:32, Time:0.09213089942932129
OutChannel:256, Group:64, Time:0.09251523017883301
OutChannel:256, Group:128, Time:0.09215164184570312
OutChannel:256, Group:256, Time:0.09244847297668457
OutChannel:512, Group:1, Time:0.09291553497314453
OutChannel:512, Group:2, Time:0.0935969352722168
OutChannel:512, Group:4, Time:0.09287786483764648
OutChannel:512, Group:8, Time:0.09386038780212402
OutChannel:512, Group:16, Time:0.09314537048339844
OutChannel:512, Group:32, Time:0.09238576889038086
OutChannel:512, Group:64, Time:0.09345221519470215
OutChannel:512, Group:128, Time:0.09265351295471191
OutChannel:512, Group:256, Time:0.09274530410766602
OutChannel:512, Group:512, Time:0.0929572582244873
OutChannel:1024, Group:1, Time:0.09283661842346191
OutChannel:1024, Group:2, Time:0.0920555591583252
OutChannel:1024, Group:4, Time:0.09308600425720215
OutChannel:1024, Group:8, Time:0.09451031684875488
OutChannel:1024, Group:16, Time:0.09326553344726562
OutChannel:1024, Group:32, Time:0.09737563133239746
OutChannel:1024, Group:64, Time:0.09368896484375
OutChannel:1024, Group:128, Time:0.09325265884399414
OutChannel:1024, Group:256, Time:0.09457755088806152
OutChannel:1024, Group:512, Time:0.0924680233001709
OutChannel:1024, Group:1024, Time:0.09270071983337402
OutChannel:2048, Group:1, Time:0.09429264068603516
OutChannel:2048, Group:2, Time:0.09288167953491211
OutChannel:2048, Group:4, Time:0.09360265731811523
OutChannel:2048, Group:8, Time:0.10918331146240234
OutChannel:2048, Group:16, Time:0.10007429122924805
OutChannel:2048, Group:32, Time:0.09536433219909668
OutChannel:2048, Group:64, Time:0.09516191482543945
OutChannel:2048, Group:128, Time:0.09293746948242188
OutChannel:2048, Group:256, Time:0.09230923652648926
OutChannel:2048, Group:512, Time:0.09552860260009766
OutChannel:2048, Group:1024, Time:0.09382128715515137

I noticed that PyTorch is very sensitive to the number of groups. Why does this happen? Is there anything I missed? Please correct me. Thanks.

soumith (Member) commented Aug 4, 2018

You should either run your script with CUDA_LAUNCH_BLOCKING=1 python yourscript.py, or change count in the following way, to get correct timings:

import torch  # needed at module scope for torch.cuda.synchronize()

def count(a, m):
    torch.cuda.synchronize()  # drain any pending GPU work before timing
    t0 = time.time()
    for i in range(1000):
        b = m(a)
    torch.cuda.synchronize()  # wait for all launched kernels to finish
    return time.time() - t0
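
An alternative is to time with CUDA events, which measure GPU time directly. A minimal sketch, assuming torch >= 0.4 (count_events is a hypothetical name):

import torch

def count_events(a, m, iters=1000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        b = m(a)
    end.record()
    # elapsed_time() is only valid once the end event has completed.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # milliseconds -> seconds
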
Johnccl (Author) commented Aug 5, 2018

Thank you for your reply @soumith. I tried both methods, but the grouped convolutions seem to take even more time.

soumith (Member) commented Aug 5, 2018

Can you update the timings that you found with the new methods?

Also, some information:

  • how did you install pytorch
  • how did you install mxnet
  • what's the output of
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
Johnccl (Author) commented Aug 5, 2018

Here are the timings with the new methods, following your advice @soumith.
My first try was running the script with:

CUDA_LAUNCH_BLOCKING=1 python3 speed_test.py

I also tried changing the timing function:

def count(a, m):
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(1000):
        b = m(a)
    torch.cuda.synchronize()
    return time.time() - t0

The results show that with both methods grouped convolution still takes more time; only when groups == input_channels does it run very fast (the input tensor is (1, 1024, 19, 19)). The output looks like this:

Pytorch testing:
OutChannel:256, Group:1, Time:0.49562525749206543
OutChannel:256, Group:2, Time:0.2684659957885742
OutChannel:256, Group:4, Time:0.19190192222595215
OutChannel:256, Group:8, Time:0.23194408416748047
OutChannel:256, Group:16, Time:0.3203761577606201
OutChannel:256, Group:32, Time:0.5373902320861816
OutChannel:256, Group:64, Time:0.9720823764801025
OutChannel:256, Group:128, Time:1.801313877105713
OutChannel:256, Group:256, Time:3.6272590160369873
OutChannel:512, Group:1, Time:1.0002150535583496
OutChannel:512, Group:2, Time:0.4958508014678955
OutChannel:512, Group:4, Time:0.2823457717895508
OutChannel:512, Group:8, Time:0.24367904663085938
OutChannel:512, Group:16, Time:0.3579528331756592
OutChannel:512, Group:32, Time:0.5675616264343262
OutChannel:512, Group:64, Time:0.9823338985443115
OutChannel:512, Group:128, Time:1.8257930278778076
OutChannel:512, Group:256, Time:3.641608476638794
OutChannel:512, Group:512, Time:7.346875190734863
OutChannel:1024, Group:1, Time:2.176182746887207
OutChannel:1024, Group:2, Time:1.009695291519165
OutChannel:1024, Group:4, Time:0.5244359970092773
OutChannel:1024, Group:8, Time:0.363513708114624
OutChannel:1024, Group:16, Time:0.3781554698944092
OutChannel:1024, Group:32, Time:0.5965301990509033
OutChannel:1024, Group:64, Time:1.01088547706604
OutChannel:1024, Group:128, Time:1.898212194442749
OutChannel:1024, Group:256, Time:3.7377984523773193
OutChannel:1024, Group:512, Time:7.350734233856201
OutChannel:1024, Group:1024, Time:0.044171810150146484
OutChannel:2048, Group:1, Time:4.684430837631226
OutChannel:2048, Group:2, Time:2.1917829513549805
OutChannel:2048, Group:4, Time:1.1301765441894531
OutChannel:2048, Group:8, Time:0.6398894786834717
OutChannel:2048, Group:16, Time:0.5236623287200928
OutChannel:2048, Group:32, Time:0.6961493492126465
OutChannel:2048, Group:64, Time:1.1796467304229736
OutChannel:2048, Group:128, Time:2.0464370250701904
OutChannel:2048, Group:256, Time:3.7772271633148193
OutChannel:2048, Group:512, Time:7.414055824279785
OutChannel:2048, Group:1024, Time:0.09024667739868164

And here is the information about my PyTorch setup:

  • I installed PyTorch with: sudo pip3 install torch
  • MXNet was compiled from source
  • The output of python3 collect_env.py is:
Collecting environment information...
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 8.0.61

OS: Ubuntu 14.04.3 LTS
GCC version: (Ubuntu 4.8.4-2ubuntu1~14.04.1) 4.8.4
CMake version: version 3.5.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 8.0.61
GPU models and configuration: GPU 0: GeForce GTX TITAN X
Nvidia driver version: 375.66
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
/usr/local/MATLAB/R2016b/bin/glnxa64/libcudnn.so.4.0.7
/usr/local/cuda-7.5/lib64/libcudnn.so.5.1.3
/usr/local/cuda-7.5/lib64/libcudnn_static.a
/usr/local/cuda-8.0/lib64/libcudnn.so.5.0.5
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/lib/python2.7/dist-packages/torch/lib/libcudnn-900fef33.so.7.0.5

Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect

I checked the cuDNN version with cat /usr/local/cuda/include/cudnn.h | grep CUDNN_MAJOR -A 2; it is 5.0.5.
Any advice? Thank you.
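
A quick cross-check of the cuDNN version PyTorch itself links against can be done from Python; a minimal sketch:

import torch

# The pip wheel bundles its own cuDNN, which may differ from the headers
# found under /usr/local/cuda (note the libcudnn-*.so.7.0.5 entry above).
print(torch.version.cuda)              # CUDA version PyTorch was built with
print(torch.backends.cudnn.version())  # e.g. 7005 for cuDNN 7.0.5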

@ailzhang ailzhang self-assigned this Aug 7, 2018

ailzhang (Contributor) commented Aug 7, 2018

I can repro similar perf with CUDA 9.0 and cuDNN 7.0.
I profiled both PyTorch and MXNet on a single input (in_channel=1024, size=19, batch=1, kernel_size=3, padding=1), out_channels=64, group=1.
For PyTorch, 92% of compute is spent in maxwell_scudnn_winograd_128x128_ldg1_ldg4_tile148n_nt, with 773 invocations and an average duration of 256us.
For MXNet, 22.4% of compute is spent in maxwell_cgemm_32x64_tn, with only 2 invocations and an average duration of 2.78ms.
In general PyTorch seems to launch many more kernels than MXNet for the same input/output pair.
cc: @ngimel @apaszke do you have any insight on this?
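
A similar kernel-level breakdown can also be collected without nvprof via PyTorch's built-in autograd profiler; a minimal sketch (the API exists as of torch 0.4, though the exact table format varies by version):

import torch
import torch.nn as nn

x = torch.rand(1, 1024, 19, 19).cuda()
m = nn.Conv2d(1024, 64, kernel_size=3, padding=1).cuda()

# use_cuda=True records CUDA kernel times alongside the Python-side ops.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for _ in range(10):
        y = m(x)
    torch.cuda.synchronize()
print(prof)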

ngimel (Contributor) commented Aug 7, 2018

@ailzhang for your convolution in PyTorch on a P100 I see

GPU activities:   93.86%  243.44ms      1001  243.20us  224.80us  259.40us  maxwell_scudnn_winograd_128x128_ldg1_ldg4_tile148n_nt

1001 invocations of a kernel, which is to be expected for 1000 convolutions run in a loop plus one warm-up.
I could not profile MXNet with the given script: it errored out, and the profiles it did produce never had 1000 invocations of the same kernel (which would be expected if the timing loop ran correctly), so I would not trust the MXNet timings. Your results, where MXNet launches just 2 kernels, are also suspicious.
For actual grouped convolutions (groups != 1), @Johnccl is correct that perf is much better for groups == input_channels, because in this case PyTorch's own kernel is called. In all other cases the call is dispatched to cuDNN, and the performance is whatever cuDNN provides.
My script, slightly modified from the original, is below:

import time
import math
import torch
torch.backends.cudnn.benchmark = True
def count(a, m):
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(1000):
        b = m(a)
    torch.cuda.synchronize()
    return time.time() - t0

def th_test(in_channel=1024, size=19, batch=1):
    import torch as th
    import torch.nn as nn
    x = th.rand((batch, in_channel, size, size))
    x = x.cuda()
    out_channels_list = [1024]  # 256, 512, 1024, 2048
    for out_channels in out_channels_list:
        n = int(math.log(min(in_channel, out_channels), 2))+1
        for i in range(n):
            g = int(math.pow(2, i))
            m = nn.Conv2d(in_channel, out_channels=out_channels, kernel_size=3, padding=1, groups=g).cuda()
            b = m(x)  # warm up
            t = count(x, m)
            print('OutChannel:{}, Group:{}, Time:{}'.format(out_channels, g, t))

if __name__=='__main__':
    print('Pytorch testing:')
    th_test(1024, 19, 1)
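
For context, the groups == input_channels fast path mentioned above is the depthwise case, and the weight shapes show how grouped convolution partitions the work. A small illustrative sketch:

import torch.nn as nn

# For Conv2d(in_c, out_c, k, groups=g) the weight has shape
# (out_c, in_c // g, k, k): each of the g groups convolves in_c // g
# input channels into out_c // g output channels.
for g in (1, 8, 1024):
    m = nn.Conv2d(1024, 1024, kernel_size=3, padding=1, groups=g)
    print(g, tuple(m.weight.shape))
# g == 1024 (groups == in_channels) is the depthwise case served by
# PyTorch's own kernel rather than cuDNN, per the comment above.
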
ailzhang (Contributor) commented Aug 7, 2018

Hi @ngimel, thanks for confirming. Yeah, I also saw MXNet error out with some arguments; I just picked one that worked fine and did the profiling.

Here is my script, just in case it helps. I'm taking a look at the MXNet implementation to see whether they have any customized kernel for groups != input_channels, since in those cases MXNet is much faster.

import time
import math


def count(a, m):
    t0 = time.time()
    for i in range(1000):
        b = m(a)
    return time.time() - t0

def th_test(in_channel=1024, size=19, batch=1):
    import torch as th
    import torch.nn as nn
    x = th.rand((batch, in_channel, size, size))
    x = x.cuda()
    # out_channels_list = [256, 512, 1024, 2048]
    out_channels_list = [64]
    for out_channels in out_channels_list:
        n = int(math.log(min(in_channel, out_channels), 2))+1
        for i in range(1):
            g = int(math.pow(2, i))
            m = nn.Conv2d(in_channel, out_channels=out_channels, kernel_size=3, padding=1, groups=g).cuda()
            t = count(x, m)
            print('OutChannel:{}, Group:{}, Time:{}'.format(out_channels, g, t))

def mx_test(in_channel=1024, size=19, batch=1):
    import mxnet as mx
    import mxnet.ndarray as nd
    from mxnet.gluon import nn
    x = nd.uniform(-1, 1, (batch, in_channel, size, size), mx.gpu(0))

    # out_channels_list = [256, 512, 1024, 2048]
    out_channels_list = [64]
    for out_channels in out_channels_list:
        n = int(math.log(min(in_channel, out_channels), 2)) + 1
        for i in range(1):
            g = int(math.pow(2, i))
            m = nn.Conv2D(out_channels, kernel_size=3, padding=(1, 1), groups=g)
            m.initialize(ctx=[mx.gpu(0)])
            t = count(x, m)
            print('OutChannel:{}, Group:{}, Time:{}'.format(out_channels, g, t))

if __name__=='__main__':
    # print('Pytorch testing:')
    # th_test(1024, 19, 1)
    print('Mxnet testing:')
    mx_test(1024, 19, 1)
ngimel (Contributor) commented Aug 7, 2018

MXNet may be faster, but this script is not timing it properly (there are no necessary synchronizations; even with CUDA_LAUNCH_BLOCKING=1 it does not call any convolution kernel a thousand times, which it should), so we have no idea what MXNet's actual performance is.

ailzhang (Contributor) commented Aug 7, 2018

Talked to @ngimel offline; the perf diff should be a false alarm, since MXNet didn't really run the conv kernels with the script above.

For a single shape, if you print the output tensor, that actually forces MXNet to run 1000 invocations of maxwell_scudnn_winograd_128x128_ldg1_ldg4_tile148n_nt as well, with a similar average duration.

I'm not very familiar with MXNet syntax, but adding something like torch.cuda.synchronize() to the MXNet part should close the gap. The way I verified it: before I forced the print of the output tensor, there was simply no conv kernel called in the profiler (and it's very fast; the time stays constant even when you increase the iteration count). After I forced the print, the maxwell_scudnn_winograd_128x128_ldg1_ldg4_tile148n_nt kernel was called 1000 times, just like in PyTorch. This indicates that MXNet returns before the kernel has actually run. @Johnccl, if you have an idea how to synchronize CUDA in MXNet, please let us know. Otherwise, based on the profiling, it's safe to say this is a false alarm IMHO.
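
A properly synchronized MXNet timing loop might look like the following sketch; mx.nd.waitall() blocks until all asynchronously queued operations have finished, analogous to torch.cuda.synchronize() (untested here):

import time
import mxnet as mx

def mx_count(a, m, iters=1000):
    mx.nd.waitall()  # flush pending async work before starting the clock
    t0 = time.time()
    for _ in range(iters):
        b = m(a)
    # MXNet enqueues kernels asynchronously; block until they complete so
    # the elapsed time reflects actual GPU work.
    mx.nd.waitall()
    return time.time() - t0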

Here's a hacky script I used to generate the nvprof result.
[EDIT] Deleted this script since I found the right command to force kernel launches in MXNet. Posting the perf numbers below.

ailzhang (Contributor) commented Aug 7, 2018

Found the right command to force kernel launches (export MXNET_ENGINE_TYPE=NaiveEngine), and here are the updated perf numbers. The numbers are close, especially the kernel times shown in nvprof, although MXNet is still a bit slower in general. @ngimel any idea how this can be further improved? :)

 λ ~ python 10229.py
Pytorch testing:
OutChannel:256, Group:1, Time:0.8669874668121338
OutChannel:256, Group:2, Time:0.6110000610351562
OutChannel:256, Group:4, Time:0.40593385696411133
OutChannel:256, Group:8, Time:0.19748449325561523
OutChannel:256, Group:16, Time:0.30086755752563477
OutChannel:256, Group:32, Time:0.5136899948120117
OutChannel:256, Group:64, Time:0.9273054599761963
OutChannel:256, Group:128, Time:1.7813796997070312
OutChannel:256, Group:256, Time:0.05748271942138672
OutChannel:512, Group:1, Time:0.3776583671569824
OutChannel:512, Group:2, Time:0.6384704113006592
OutChannel:512, Group:4, Time:0.43076610565185547
OutChannel:512, Group:8, Time:0.23347163200378418
OutChannel:512, Group:16, Time:0.3054044246673584
OutChannel:512, Group:32, Time:0.5156643390655518
OutChannel:512, Group:64, Time:0.9284119606018066
OutChannel:512, Group:128, Time:1.7722992897033691
OutChannel:512, Group:256, Time:0.05848836898803711
OutChannel:512, Group:512, Time:0.058203935623168945
OutChannel:1024, Group:1, Time:0.7848114967346191
OutChannel:1024, Group:2, Time:0.7653074264526367
OutChannel:1024, Group:4, Time:0.5440354347229004
OutChannel:1024, Group:8, Time:0.29901933670043945
OutChannel:1024, Group:16, Time:0.30808281898498535
OutChannel:1024, Group:32, Time:0.5171351432800293
OutChannel:1024, Group:64, Time:0.9506747722625732
OutChannel:1024, Group:128, Time:1.800419807434082
OutChannel:1024, Group:256, Time:0.05892205238342285
OutChannel:1024, Group:512, Time:0.05730032920837402
OutChannel:1024, Group:1024, Time:0.031682729721069336
OutChannel:2048, Group:1, Time:1.6538004875183105
OutChannel:2048, Group:2, Time:1.499765157699585
OutChannel:2048, Group:4, Time:0.8753867149353027
OutChannel:2048, Group:8, Time:0.5235047340393066
OutChannel:2048, Group:16, Time:0.36252284049987793
OutChannel:2048, Group:32, Time:0.5182161331176758
OutChannel:2048, Group:64, Time:0.9530549049377441
OutChannel:2048, Group:128, Time:1.825690507888794
OutChannel:2048, Group:256, Time:0.06976127624511719
OutChannel:2048, Group:512, Time:0.06025123596191406
OutChannel:2048, Group:1024, Time:0.03231954574584961
Mxnet testing:
[13:16:18] src/engine/engine.cc:55: MXNet start using engine: NaiveEngine
[13:16:20] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
OutChannel:256, Group:1, Time:0.6017172336578369
OutChannel:256, Group:2, Time:0.414600133895874
OutChannel:256, Group:4, Time:0.37932562828063965
OutChannel:256, Group:8, Time:0.4258706569671631
OutChannel:256, Group:16, Time:0.4455239772796631
OutChannel:256, Group:32, Time:0.5849206447601318
OutChannel:256, Group:64, Time:0.8991494178771973
OutChannel:256, Group:128, Time:1.6287968158721924
OutChannel:256, Group:256, Time:2.8902885913848877
OutChannel:512, Group:1, Time:1.1513350009918213
OutChannel:512, Group:2, Time:0.5768940448760986
OutChannel:512, Group:4, Time:0.41323256492614746
OutChannel:512, Group:8, Time:0.3937795162200928
OutChannel:512, Group:16, Time:0.4500117301940918
OutChannel:512, Group:32, Time:0.5608577728271484
OutChannel:512, Group:64, Time:0.9582664966583252
OutChannel:512, Group:128, Time:1.5262198448181152
OutChannel:512, Group:256, Time:2.826936960220337
OutChannel:512, Group:512, Time:5.390891075134277
OutChannel:1024, Group:1, Time:2.3451201915740967
OutChannel:1024, Group:2, Time:1.2541780471801758
OutChannel:1024, Group:4, Time:0.5795691013336182
OutChannel:1024, Group:8, Time:0.4481794834136963
OutChannel:1024, Group:16, Time:0.46959853172302246
OutChannel:1024, Group:32, Time:0.5889122486114502
OutChannel:1024, Group:64, Time:0.9274506568908691
OutChannel:1024, Group:128, Time:1.717729091644287
OutChannel:1024, Group:256, Time:2.82771897315979
OutChannel:1024, Group:512, Time:5.3689186573028564
OutChannel:1024, Group:1024, Time:0.8016271591186523
OutChannel:2048, Group:1, Time:2.6470489501953125
OutChannel:2048, Group:2, Time:2.0326945781707764
OutChannel:2048, Group:4, Time:0.8291740417480469
OutChannel:2048, Group:8, Time:0.6124267578125
OutChannel:2048, Group:16, Time:0.5176849365234375
OutChannel:2048, Group:32, Time:0.6203253269195557
OutChannel:2048, Group:64, Time:0.9143955707550049
OutChannel:2048, Group:128, Time:1.616405963897705
OutChannel:2048, Group:256, Time:2.9386751651763916
OutChannel:2048, Group:512, Time:5.4872753620147705
OutChannel:2048, Group:1024, Time:10.75217056274414
[13:19:39] src/engine/naive_engine.cc:55: Engine shutdown
ailzhang (Contributor) commented Aug 14, 2018

Closing, please feel free to reopen if you see any further issues.

@ailzhang ailzhang closed this Aug 14, 2018

kampta commented Jan 30, 2019

@ailzhang

I get somewhat different numbers for PyTorch using the script you shared (I'm using a P6000 card, torch 0.4.1).

Script

import time
import math

def count(a, m):
    t0 = time.time()
    for i in range(1000):
        b = m(a)
    return time.time() - t0

def th_test(in_channel=1024, size=19, batch=1):
    import torch as th
    import torch.nn as nn
    x = th.rand((batch, in_channel, size, size))
    x = x.cuda()
    out_channels_list = [256]
    for out_channels in out_channels_list:
        n = int(math.log(min(in_channel, out_channels), 2))+1
        for i in range(9):
            g = int(math.pow(2, i))
            m = nn.Conv2d(in_channel, out_channels=out_channels, kernel_size=3, padding=1, groups=g).cuda()
            t = count(x, m)
            print('OutChannel:{}, Group:{}, Time:{}, Params:{}'.format(out_channels, g, t, next(iter(m.parameters())).size()))

if __name__=='__main__':
    print('Pytorch testing:')
    th_test(1024, 19, 1)

Output

Pytorch testing:
OutChannel:256, Group:1, Time:0.20193028450012207, Params:torch.Size([256, 1024, 3, 3])
OutChannel:256, Group:2, Time:0.09860920906066895, Params:torch.Size([256, 512, 3, 3])
OutChannel:256, Group:4, Time:0.11520171165466309, Params:torch.Size([256, 256, 3, 3])
OutChannel:256, Group:8, Time:0.14992403984069824, Params:torch.Size([256, 128, 3, 3])
OutChannel:256, Group:16, Time:0.21524858474731445, Params:torch.Size([256, 64, 3, 3])
OutChannel:256, Group:32, Time:0.3496212959289551, Params:torch.Size([256, 32, 3, 3])
OutChannel:256, Group:64, Time:0.6045060157775879, Params:torch.Size([256, 16, 3, 3])
OutChannel:256, Group:128, Time:1.1571979522705078, Params:torch.Size([256, 8, 3, 3])
OutChannel:256, Group:256, Time:2.2207155227661133, Params:torch.Size([256, 4, 3, 3])

Is it expected for nn.Conv2d to take more time with a larger number of groups (and fewer parameters!), or am I doing something wrong?

kampta commented Feb 13, 2019

Requesting comment @Johnccl @ailzhang @soumith
