
Accelerated 3D depthwise convolution implementation #40801

Closed

Conversation

linziyi96 (Contributor):

This is another attempt to resolve the slowness of cuDNN's 3D depthwise (a.k.a. channel-wise) convolution (previously attempted in #31885). 3D depthwise convolutions are seeing increasing use in recent work (e.g. https://arxiv.org/abs/2004.04730, https://arxiv.org/abs/1904.02811), but cuDNN's 3D depthwise convolution is currently often even slower than a regular dense convolution, making such models very time-consuming to train in practice.

I implemented a CUDA kernel and found that it brings a substantial performance gain. It takes the 2D implementations as references (PyTorch: https://github.com/pytorch/pytorch/blob/master/aten/src/THCUNN/SpatialDepthwiseConvolution.cu; TensorFlow: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/depthwise_conv_op_gpu.h), along with the existing attempt in #31885, and was further tuned with the NVIDIA profiler. In my tests, the forward+backward speedup is at least 2.5x and in many cases 5x–20x. Timing results are attached at the end. They are intended to cover the most common use cases, but given limited time and access to different devices, they may still be some way from complete. Further tests with different configurations or devices are certainly welcome.

Although it has already been used internally for some time, I'm aware that it might still need refinement to handle some less common cases and to better follow general PyTorch coding principles. I've listed the TODO points I've come up with so far, some of which include questions I'm not yet clear about:

  1. Test cases. Since this changes the implementation of an existing algorithm, are the existing test cases sufficient, or do I need to add new test cases specific to this implementation?
  2. Handling of extreme tensor sizes/configurations. Currently, any input/output/weight tensor with more than 2^31 - 1 elements is rejected, and very large padding/dilation values may not yet be handled carefully. Are there general rules about how far we should tolerate these cases? (I guess an 8 GB tensor is actually not that rare; a single sample, i.e. a c * t * h * w volume, of more than 8 GB is less common; and a padding of 2 billion is hardly useful. A sketch of this kind of guard is shown after this list.)
  3. Some functionality (e.g. ROCm support, double backward) has not been considered yet.
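
As a concrete illustration of the guard described in point 2, here is a minimal sketch; the function name, argument list, and error messages are illustrative, not the exact code in this PR:

```cpp
#include <cstdint>
#include <limits>

#include <ATen/ATen.h>

// Reject tensors whose element count cannot be addressed by a signed
// 32-bit index, since the kernel computes flat offsets in int arithmetic.
void check_32bit_indexable(const at::Tensor& input,
                           const at::Tensor& weight,
                           const at::Tensor& output) {
  constexpr int64_t int_max = std::numeric_limits<int32_t>::max();  // 2^31 - 1
  TORCH_CHECK(input.numel() <= int_max,
              "input tensor is too large for 32-bit indexing");
  TORCH_CHECK(weight.numel() <= int_max,
              "weight tensor is too large for 32-bit indexing");
  TORCH_CHECK(output.numel() <= int_max,
              "output tensor is too large for 32-bit indexing");
}
```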

Any other suggestions are also appreciated. I can likely spend some more time on this PR; feel free to tag it as WIP if necessary.

Below are the timing results, in ms. The kernel size is always (3, 3, 3). The test code does a 50-iteration warm-up followed by another 50 iterations for timing; the timing code is after the tables.

| GPU / precision | input (N, C, T, H, W) | stride (T, H, W) | cuDNN 8.0.0 (ms) | cuDNN 7.6.5 (ms) | cuDNN 7.6.3 (ms) | ours (ms) | acceleration vs. 8.0.0 | vs. 7.6.5 | vs. 7.6.3 |
|---|---|---|---|---|---|---|---|---|---|
| 1080ti-fp32 | 8, 64, 32, 56, 56 | 1, 1, 1 | 211.766 | 369.786 | 393.02 | 24.028 | 8.8133011 | 15.389795 | 16.35675 |
| | 8, 128, 32, 28, 28 | 1, 1, 1 | 110.243 | 196.788 | 195.203 | 12.98 | 8.4932974 | 15.160863 | 15.038752 |
| | 8, 256, 32, 14, 14 | 1, 1, 1 | 54.098 | 82.065 | 75.956 | 7.67 | 7.0531943 | 10.699478 | 9.9029987 |
| | 8, 512, 32, 7, 7 | 1, 1, 1 | 27.191 | 54.001 | 56.504 | 5.693 | 4.7762164 | 9.4855085 | 9.9251713 |
| | 8, 54, 13, 40, 40 | 1, 1, 1 | 38.147 | 68.127 | 78.681 | 4.27 | 8.9337237 | 15.954801 | 18.426464 |
| | 8, 108, 13, 20, 20 | 1, 1, 1 | 19.541 | 35.042 | 35.33 | 2.38 | 8.2105042 | 14.723529 | 14.844538 |
| | 8, 216, 13, 10, 10 | 1, 1, 1 | 10.105 | 17.961 | 18.492 | 1.673 | 6.0400478 | 10.735804 | 11.053198 |
| | 8, 432, 13, 5, 5 | 1, 1, 1 | 8.575 | 19.642 | 20.937 | 1.433 | 5.9839498 | 13.706909 | 14.610607 |
| | 8, 54, 16, 56, 56 | 1, 1, 1 | 91.91 | 166.328 | 187.786 | 10.285 | 8.936315 | 16.171901 | 18.25824 |
| | 8, 108, 16, 28, 28 | 1, 1, 1 | 47.853 | 85.573 | 84.938 | 5.515 | 8.6768812 | 15.51641 | 15.401269 |
| | 8, 216, 16, 14, 14 | 1, 1, 1 | 23.441 | 35.784 | 36.467 | 3.279 | 7.1488259 | 10.913083 | 11.121378 |
| | 8, 432, 16, 7, 7 | 1, 1, 1 | 13.715 | 33.371 | 36.416 | 2.496 | 5.4947917 | 13.369792 | 14.589744 |
| | 8, 54, 16, 78, 78 | 1, 1, 1 | 356.101 | 319.873 | 372.093 | 19.842 | 17.94683 | 16.121006 | 18.752797 |
| | 8, 108, 16, 39, 39 | 1, 1, 1 | 91.41 | 156.656 | 172.973 | 10.101 | 9.0495991 | 15.50896 | 17.124344 |
| | 8, 216, 16, 20, 20 | 1, 1, 1 | 48.876 | 70.286 | 71.632 | 5.844 | 8.3634497 | 12.027036 | 12.257358 |
| | 8, 432, 16, 10, 10 | 1, 1, 1 | 25.06 | 46.684 | 49.695 | 4.068 | 6.1602753 | 11.47591 | 12.216077 |
| | 8, 64, 32, 56, 56 | 1, 2, 2 | 56.704 | 185.966 | 189.643 | 15.74 | 3.6025413 | 11.814867 | 12.048475 |
| | 8, 128, 32, 28, 28 | 1, 2, 2 | 28.973 | 94.386 | 95.99 | 8.28 | 3.4991546 | 11.399275 | 11.592995 |
| | 8, 256, 32, 14, 14 | 1, 2, 2 | 14.337 | 48.659 | 49.301 | 5.041 | 2.8440786 | 9.6526483 | 9.780004 |
| | 8, 512, 32, 7, 7 | 1, 2, 2 | 12.321 | 41.55 | 43.374 | 4.602 | 2.6773142 | 9.0286832 | 9.4250326 |
| | 8, 54, 13, 40, 40 | 1, 2, 2 | 10.231 | 33.715 | 34.126 | 2.807 | 3.6448165 | 12.011044 | 12.157463 |
| | 8, 108, 13, 20, 20 | 1, 2, 2 | 5.381 | 17.291 | 17.544 | 1.636 | 3.2891198 | 10.569071 | 10.723716 |
| | 8, 216, 13, 10, 10 | 1, 2, 2 | 4.555 | 9.572 | 10.596 | 1.126 | 4.0452931 | 8.5008881 | 9.410302 |
| | 8, 432, 13, 5, 5 | 1, 2, 2 | 8.127 | 13.533 | 15.155 | 1.533 | 5.3013699 | 8.8277887 | 9.8858447 |
| | 8, 54, 16, 56, 56 | 1, 2, 2 | 24.532 | 80.964 | 82.475 | 6.766 | 3.6257759 | 11.966302 | 12.189625 |
| | 8, 108, 16, 28, 28 | 1, 2, 2 | 12.524 | 41.091 | 42.633 | 3.545 | 3.5328632 | 11.591255 | 12.026234 |
| | 8, 216, 13, 10, 10 | 1, 2, 2 | 6.927 | 24.788 | 26.079 | 2.21 | 3.1343891 | 11.21629 | 11.800452 |
| | 8, 432, 16, 7, 7 | 1, 2, 2 | 8.082 | 18.693 | 20.099 | 2.038 | 3.9656526 | 9.1722277 | 9.8621197 |
| | 8, 54, 16, 78, 78 | 1, 2, 2 | 48.927 | 164.085 | 167.044 | 13.031 | 3.754662 | 12.591896 | 12.81897 |
| | 8, 108, 16, 39, 39 | 1, 2, 2 | 24.772 | 81.949 | 84.012 | 6.722 | 3.6852127 | 12.191163 | 12.498066 |
| | 8, 216, 16, 20, 20 | 1, 2, 2 | 12.78 | 51.02 | 46.551 | 3.953 | 3.2329876 | 12.906653 | 11.776119 |
| | 8, 432, 16, 10, 10 | 1, 2, 2 | 9.666 | 35.868 | 34.658 | 2.685 | 3.6 | 13.358659 | 12.908007 |

V100 results (these machines are not fully under my control, so I can't upgrade their drivers; 7.6.3 is the highest cuDNN version available with an official build):

| GPU / precision | input (N, C, T, H, W) | stride (T, H, W) | cuDNN 7.6.3 (ms) | ours (ms) | acceleration vs. 7.6.3 |
|---|---|---|---|---|---|
| v100-fp32 | 8, 64, 32, 56, 56 | 1, 1, 1 | 133.224 | 14.221 | 9.3681176 |
| | 8, 128, 32, 28, 28 | 1, 1, 1 | 65.583 | 7.206 | 9.1011657 |
| | 8, 256, 32, 14, 14 | 1, 1, 1 | 39.154 | 3.994 | 9.8032048 |
| | 8, 512, 32, 7, 7 | 1, 1, 1 | 46.064 | 2.739 | 16.817817 |
| | 8, 54, 13, 40, 40 | 1, 1, 1 | 19.906 | 2.575 | 7.7304854 |
| | 8, 108, 13, 20, 20 | 1, 1, 1 | 14.82 | 1.317 | 11.252847 |
| | 8, 216, 13, 10, 10 | 1, 1, 1 | 18.46 | 0.835 | 22.107784 |
| | 8, 432, 13, 5, 5 | 1, 1, 1 | 22.707 | 0.669 | 33.941704 |
| | 8, 54, 16, 56, 56 | 1, 1, 1 | 54.22 | 6.228 | 8.7058446 |
| | 8, 108, 16, 28, 28 | 1, 1, 1 | 26.819 | 3.087 | 8.6877227 |
| | 8, 216, 13, 10, 10 | 1, 1, 1 | 23.091 | 1.718 | 13.440629 |
| | 8, 432, 16, 7, 7 | 1, 1, 1 | 33.519 | 1.206 | 27.793532 |
| | 8, 54, 16, 78, 78 | 1, 1, 1 | 110.881 | 12.216 | 9.0767027 |
| | 8, 108, 16, 39, 39 | 1, 1, 1 | 54.419 | 5.787 | 9.4036634 |
| | 8, 216, 16, 20, 20 | 1, 1, 1 | 33.329 | 3.192 | 10.441416 |
| | 8, 432, 16, 10, 10 | 1, 1, 1 | 39.03 | 1.998 | 19.534535 |
| | 8, 64, 32, 56, 56 | 1, 2, 2 | 75.172 | 11.596 | 6.4825802 |
| | 8, 128, 32, 28, 28 | 1, 2, 2 | 42.357 | 5.946 | 7.1236125 |
| | 8, 256, 32, 14, 14 | 1, 2, 2 | 29.265 | 3.331 | 8.78565 |
| | 8, 512, 32, 7, 7 | 1, 2, 2 | 50.152 | 2.44 | 20.554098 |
| | 8, 54, 13, 40, 40 | 1, 2, 2 | 15.439 | 2.1 | 7.3519048 |
| | 8, 108, 13, 20, 20 | 1, 2, 2 | 13.267 | 1.118 | 11.866726 |
| | 8, 216, 13, 10, 10 | 1, 2, 2 | 7.29 | 0.69 | 10.565217 |
| | 8, 432, 13, 5, 5 | 1, 2, 2 | 13.66 | 0.712 | 19.185393 |
| | 8, 54, 16, 56, 56 | 1, 2, 2 | 31.91 | 4.997 | 6.3858315 |
| | 8, 108, 16, 28, 28 | 1, 2, 2 | 21.3 | 2.555 | 8.3365949 |
| | 8, 216, 13, 10, 10 | 1, 2, 2 | 13.554 | 1.439 | 9.419041 |
| | 8, 432, 16, 7, 7 | 1, 2, 2 | 13.84 | 1.078 | 12.83859 |
| | 8, 54, 16, 78, 78 | 1, 2, 2 | 61.626 | 12.163 | 5.0666776 |
| | 8, 108, 16, 39, 39 | 1, 2, 2 | 35.228 | 4.865 | 7.24111 |
| | 8, 216, 16, 20, 20 | 1, 2, 2 | 25.296 | 2.696 | 9.3827893 |
| | 8, 432, 16, 10, 10 | 1, 2, 2 | 15.249 | 1.643 | 9.2811929 |
| v100-fp16 | 8, 64, 32, 56, 56 | 1, 1, 1 | 113.862 | 15.68 | 7.2616071 |
| | 8, 128, 32, 28, 28 | 1, 1, 1 | 55.282 | 8.139 | 6.7922349 |
| | 8, 256, 32, 14, 14 | 1, 1, 1 | 37.496 | 4.534 | 8.2699603 |
| | 8, 512, 32, 7, 7 | 1, 1, 1 | 43.379 | 6.245 | 6.946197 |
| | 8, 54, 13, 40, 40 | 1, 1, 1 | 19.6 | 2.885 | 6.7937608 |
| | 8, 108, 13, 20, 20 | 1, 1, 1 | 14.869 | 1.498 | 9.9259012 |
| | 8, 216, 13, 10, 10 | 1, 1, 1 | 18.018 | 0.937 | 19.229456 |
| | 8, 432, 13, 5, 5 | 1, 1, 1 | 22.801 | 0.725 | 31.449655 |
| | 8, 54, 16, 56, 56 | 1, 1, 1 | 46.43 | 6.916 | 6.7134182 |
| | 8, 108, 16, 28, 28 | 1, 1, 1 | 25.588 | 3.512 | 7.285877 |
| | 8, 216, 13, 10, 10 | 1, 1, 1 | 22.608 | 1.95 | 11.593846 |
| | 8, 432, 16, 7, 7 | 1, 1, 1 | 35.111 | 1.336 | 26.280689 |
| | 8, 54, 16, 78, 78 | 1, 1, 1 | 91.14 | 13.196 | 6.9066384 |
| | 8, 108, 16, 39, 39 | 1, 1, 1 | 45.84 | 6.481 | 7.0729826 |
| | 8, 216, 16, 20, 20 | 1, 1, 1 | 32.115 | 3.612 | 8.891196 |
| | 8, 432, 16, 10, 10 | 1, 1, 1 | 42.731 | 2.252 | 18.974689 |
| | 8, 64, 32, 56, 56 | 1, 2, 2 | 71.966 | 12.484 | 5.7646588 |
| | 8, 128, 32, 28, 28 | 1, 2, 2 | 41.712 | 6.518 | 6.3995091 |
| | 8, 256, 32, 14, 14 | 1, 2, 2 | 34.252 | 3.67 | 9.33297 |
| | 8, 512, 32, 7, 7 | 1, 2, 2 | 52.537 | 2.642 | 19.885314 |
| | 8, 54, 13, 40, 40 | 1, 2, 2 | 15.432 | 2.282 | 6.762489 |
| | 8, 108, 13, 20, 20 | 1, 2, 2 | 13.468 | 1.235 | 10.905263 |
| | 8, 216, 13, 10, 10 | 1, 2, 2 | 13.817 | 0.754 | 18.324934 |
| | 8, 432, 13, 5, 5 | 1, 2, 2 | 14.453 | 0.759 | 19.042161 |
| | 8, 54, 16, 56, 56 | 1, 2, 2 | 31.771 | 5.424 | 5.8574853 |
| | 8, 108, 16, 28, 28 | 1, 2, 2 | 21.357 | 2.8 | 7.6275 |
| | 8, 216, 13, 10, 10 | 1, 2, 2 | 22.819 | 1.585 | 14.396845 |
| | 8, 432, 16, 7, 7 | 1, 2, 2 | 21.691 | 1.167 | 18.586975 |
| | 8, 54, 16, 78, 78 | 1, 2, 2 | 65.296 | 10.269 | 6.3585549 |
| | 8, 108, 16, 39, 39 | 1, 2, 2 | 35.04 | 5.276 | 6.641395 |
| | 8, 216, 16, 20, 20 | 1, 2, 2 | 29.13 | 2.978 | 9.7817327 |
| | 8, 432, 16, 10, 10 | 1, 2, 2 | 33.362 | 1.795 | 18.586072 |

Timing code:

```python
#!/usr/bin/env python

import os, sys
import time
import torch
print(torch)
torch.backends.cudnn.benchmark = True
print(torch.backends.cudnn.version())
import torch.nn as nn

feat_size = [
    (8, 64, 32, 56, 56),
    (8, 128, 32, 28, 28),
    (8, 256, 32, 14, 14),
    (8, 512, 32, 7, 7),
    (8, 54, 13, 40, 40),
    (8, 108, 13, 20, 20),
    (8, 216, 13, 10, 10),
    (8, 432, 13, 5, 5),
    (8, 54, 16, 56, 56),
    (8, 108, 16, 28, 28),
    (8, 216, 16, 14, 14),
    (8, 432, 16, 7, 7),
    (8, 54, 16, 78, 78),
    (8, 108, 16, 39, 39),
    (8, 216, 16, 20, 20),
    (8, 432, 16, 10, 10),
    ]
max_iter = 50

print('Running %d times for each config.' % max_iter)
for tcase, size in enumerate(feat_size):
  data = torch.randn(size, requires_grad=True, device='cuda')
  conv = nn.Conv3d(data.size(1), data.size(1),
      kernel_size=(3, 3, 3),
      padding=(1, 1, 1),
      #stride=(1, 2, 2), # may use stride 1 or stride 2
      stride=(1, 1, 1),
      groups=data.size(1),
      bias=False)
  nn.init.kaiming_normal_(conv.weight)
  conv.cuda()

  #data = data.half() # may or may not use fp16.
  #conv.half()

  with torch.backends.cudnn.flags(enabled=True):
    for i in range(50):
      conv(data).sum().backward() # warmup
    torch.cuda.synchronize()
    time_st = time.time()
    for i in range(max_iter):
      conv(data).sum().backward()
    torch.cuda.synchronize()
  time_ed = time.time()
  print('Input size: %s, total_time: %.6f s, avg_time: %.6f s' % (
    str(size), time_ed - time_st, (time_ed - time_st) / max_iter))
```

linziyi96 changed the title from "3d depthwise conv init" to "Accelerated 3D depthwise convolution implementation" Jun 30, 2020

dr-ci bot commented Jun 30, 2020

💊 CI failures summary and remediations

As of commit 4a03290 (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 2/4 non-CircleCI failure(s)

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_python_doc_push (1/2)

Step: "Doc Build and Push" (full log | diagnosis details | 🔁 rerun)

```
Jun 30 20:54:29 /var/lib/jenkins/workspace/vision/torchvision/csrc/cpu/image/readpng_cpu.cpp:43:17: error: 'struct decodePNG(const at::Tensor&)::Reader' has no member named 'ptr'
Jun 30 20:54:29 /var/lib/jenkins/workspace/vision/torchvision/csrc/cpu/image/readpng_cpu.cpp:37:37: error: 'png_const_bytep' was not declared in this scope
Jun 30 20:54:29    reader.ptr = png_const_bytep(datap) + 8;
Jun 30 20:54:29                                      ^
Jun 30 20:54:29 /var/lib/jenkins/workspace/vision/torchvision/csrc/cpu/image/readpng_cpu.cpp: In lambda function:
Jun 30 20:54:29 /var/lib/jenkins/workspace/vision/torchvision/csrc/cpu/image/readpng_cpu.cpp:42:27: error: 'struct decodePNG(const at::Tensor&)::Reader' has no member named 'ptr'
Jun 30 20:54:29          std::copy(reader->ptr, reader->ptr + bytes, output);
Jun 30 20:54:29                            ^
Jun 30 20:54:29 /var/lib/jenkins/workspace/vision/torchvision/csrc/cpu/image/readpng_cpu.cpp:42:40: error: 'struct decodePNG(const at::Tensor&)::Reader' has no member named 'ptr'
Jun 30 20:54:29          std::copy(reader->ptr, reader->ptr + bytes, output);
Jun 30 20:54:29                                         ^
Jun 30 20:54:29 /var/lib/jenkins/workspace/vision/torchvision/csrc/cpu/image/readpng_cpu.cpp:43:17: error: 'struct decodePNG(const at::Tensor&)::Reader' has no member named 'ptr'
Jun 30 20:54:29          reader->ptr += bytes;
Jun 30 20:54:29                  ^
Jun 30 20:54:29 error: command 'gcc' failed with exit status 1
```

See CircleCI build pytorch_macos_10_13_py3_test (2/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

```
Jun 30 15:08:20 [E request_callback_impl.cpp:168] Received error while processing request type 2: PickleError: ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. Mixed serialization of script and non-script modules is not supported. For purely script modules use my_script_module.save() instead.
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize
Jun 30 15:08:20
Jun 30 15:08:20 [E request_callback_impl.cpp:168] Received error while processing request type 2: PickleError: ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. Mixed serialization of script and non-script modules is not supported. For purely script modules use my_script_module.save(<filename>) instead.
Jun 30 15:08:20
Jun 30 15:08:20 At:
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/jit/__init__.py(2082): __getstate__
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize
Jun 30 15:08:20
Jun 30 15:08:20 [E request_callback_impl.cpp:168] Received error while processing request type 2: PickleError: ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. Mixed serialization of script and non-script modules is not supported. For purely script modules use my_script_module.save(<filename>) instead.
Jun 30 15:08:20
Jun 30 15:08:20 At:
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/jit/__init__.py(2082): __getstate__
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize
Jun 30 15:08:20   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize
Jun 30 15:08:20
Jun 30 15:08:20 ok (1.301s)
Jun 30 15:08:22   test_unexepected_kwarg_is_specified (__main__.JitRpcTestWithSpawn) ... ok (1.363s)
Jun 30 15:08:23   test_user_rrefs_confirmed (__main__.JitRpcTestWithSpawn) ... ok (1.323s)
Jun 30 15:08:24   test_user_rrefs_confirmed_remote (__main__.JitRpcTestWithSpawn) ... ok (1.343s)
```

ci.pytorch.org: 2 failed




bjuncek commented Jun 30, 2020

Cheers @linziyi96 and many thanks for the PR.

Just a quick question: can you reuse the test cases from my old PR #31885? They should be a solid sanity check for correctness, and IIRC they also allow checking for memory issues :)

Otherwise, thanks again for the PR. I'll try it out on the internal codebase as well to double-check that everything runs smoothly, but I'll leave the thorough review to someone more qualified.

pbelevich requested a review from ngimel July 2, 2020 20:56
ngimel added the "triaged" label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jul 7, 2020
yztongzhan:

Hi @linziyi96, I tried your 3D depthwise convolution implementation with X3D (https://arxiv.org/abs/2004.04730), but saw no obvious acceleration. If it's convenient, I'd like to discuss this problem with you. Could you leave your email, or contact me at tongzhan@smail.nju.edu.cn? Thanks.

facebook-github-bot (Contributor) left a comment:

@lly-zero-one has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

```cpp
const int in_row_start = out_row * dH - pH;
const int in_frame_start = out_frame * dT - pT;

scalar_t sum = (scalar_t)0;
```
Collaborator:
for fp16 sum should be accscalar_t (float), not scalar_t.
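
A minimal sketch of the suggested pattern (not the PR's actual kernel; on CUDA, at::acc_type maps at::Half to float):

```cpp
#include <ATen/AccumulateType.h>

// Sketch: accumulate in accscalar_t (float when scalar_t is at::Half),
// narrowing back to scalar_t only on the final store.
template <typename scalar_t>
__global__ void dot_kernel(const scalar_t* a, const scalar_t* b,
                           scalar_t* out, int n) {
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
  accscalar_t sum = accscalar_t(0);
  for (int i = 0; i < n; ++i) {
    sum += static_cast<accscalar_t>(a[i]) * static_cast<accscalar_t>(b[i]);
  }
  *out = static_cast<scalar_t>(sum);
}
```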

```cpp
const int out_frame_end = in_frame + pT;

const scalar_t* kernel_ptr = kernel[in_channel * channel_multiplier].data();
scalar_t sum = (scalar_t)0;
```
Collaborator:
same comment about accscalar_t

```cpp
}

template <int dim>
std::vector<int64_t> get_output_size(
```
Collaborator:
there's already a function conv_output_size in ConvUtils.h, no need to have a separate one
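
For reference, the per-dimension arithmetic conv_output_size performs is the standard convolution shape formula; a sketch (not the literal ATen code):

```cpp
#include <cstdint>

// Standard convolution output-size arithmetic for one spatial dimension.
int64_t conv_out_dim(int64_t in, int64_t kernel, int64_t pad,
                     int64_t stride, int64_t dilation) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}
// e.g. in=56, kernel=3, pad=1, stride=2, dilation=1 -> 28
```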

```cpp
TORCH_CHECK(dilation.size() == dim,
    "dilation length should be ", dim, ", but got ", dilation.size());

TORCH_CHECK(input.defined(),
```
Collaborator:
usually we don't check that inputs are defined, as there's no way to get undefined inputs from Python, and they would soon error out anyway (e.g. on a .size() call)

```cpp
  output_ = output.unsqueeze(0);
}
Tensor weight_ = weight.contiguous();
Tensor bias_ = bias.defined() ? bias.contiguous() : bias;
```
Collaborator:
no call for input_.contiguous()?

```cpp
  TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= int_max,
      "Padded input tensor is too large.");
}
TORCH_CHECK(grad_output_.size(0) * grad_output_.size(2) < int_max - block / C10_WARP_SIZE &&
```
Collaborator:
comments here explaining why the conditions look the way they do would be helpful

```cpp
const int warpid = threadIdx.x / C10_WARP_SIZE;
const int nwarps = blockDim.x / C10_WARP_SIZE;

scalar_t grad = (scalar_t)0;
```
Collaborator:
accumulation should be in accscalar_t

```cpp
scalar_t grad = (scalar_t)0;
int batch = warpid / oT;
int gout_frame = warpid - batch * oT;
for (int outer_pos = warpid; outer_pos < input.size(0) * oT;
```
Collaborator:
a comment on the strategy would be useful here (each warp accumulates over the 2D image and loops over batches plus the 3rd dimension)
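
One way the suggested comment could read, attached to a skeleton of the quoted loop (a sketch using the names from the hunk, not the PR's exact code):

```cpp
// Each warp handles one (batch, output-frame) pair at a time: warpid maps
// to (batch, gout_frame), and the warp's lanes accumulate the weight
// gradient over that frame's 2D H x W plane. When a pair is done, the
// warp strides forward by nwarps until all input.size(0) * oT pairs are
// covered.
for (int outer_pos = warpid; outer_pos < input.size(0) * oT;
     outer_pos += nwarps) {
  const int batch = outer_pos / oT;
  const int gout_frame = outer_pos - batch * oT;
  // ... accumulate this frame's contribution to grad ...
}
```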

```cpp
sdata[threadIdx.x] = grad;
__syncthreads();

assert(__popc(blockDim.x) == 1);
```
Collaborator:
if you really need this assert, use CUDA_KERNEL_ASSERT; however, the check can also be done on the host side.
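
A sketch of the host-side version of that check, assuming `block` is the launch block size computed on the host (the power-of-two requirement presumably comes from the shared-memory tree reduction over sdata):

```cpp
// Host-side replacement for the in-kernel assert(__popc(blockDim.x) == 1):
// verify the block size is a power of two before launching the kernel.
TORCH_CHECK(block > 0 && (block & (block - 1)) == 0,
            "block size must be a power of two, got ", block);
// Or, if the check must stay in the kernel:
// CUDA_KERNEL_ASSERT(__popc(blockDim.x) == 1);
```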

lly-zero-one (Contributor):

@linziyi96 Thanks for the contribution. Could you try to resolve the comments and get the PR merged into PyTorch?


gurkirt commented Nov 30, 2020

@linziyi96 any plans to merge it? How can I use this patch without it being merged?

facebook-github-bot pushed a commit that referenced this pull request Feb 14, 2021
Summary:
Because this pull request (#40801) has become an important part of recent 3D models, brings a significant speed improvement, and has been open for a while, I decided to resolve the previous review comments and modify it a bit so that it can be merged into the latest version of PyTorch.

Pull Request resolved: #51027

Reviewed By: albanD

Differential Revision: D26414116

Pulled By: ngimel

fbshipit-source-id: 562c099f4d7f6d603a9c2f2e2a518bc577b0d8ee
dzabraev:

Are there any plans to merge this pull request?


ngimel commented Feb 19, 2021

#51027, which provides this functionality, has been merged.

ngimel closed this Feb 19, 2021
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021