csarofeen
Contributor

Regrouped the CUDA 9 fixes and cleaned up the branch so the PR no longer shows unrelated commits.
Also added cuDNN 7 grouped convolution support and an hgemm fix that CUDA 9 needs on pre-Maxwell hardware.

Contributor

@apaszke apaszke left a comment

I reviewed the parts I understood. I'll need to skim through the cuDNN 7 and nccl2 docs to check the rest.

test/test_nn.py Outdated
output.backward(grad_output)
types = (torch.FloatTensor,)
if TEST_CUDA:
    types += (torch.cuda.FloatTensor,)

This comment was marked as off-topic.

@@ -238,11 +242,11 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,

This comment was marked as off-topic.

  if (groupIdx > 0) {
    long size = 1;
    for (int i = dim; i < tensor->nDimension; ++i) {
      size *= tensor->size[i];
    }
    ptr += elementSize * size * groupIdx / groups;
  }
}
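
For readers following the pointer arithmetic in the snippet above, here is a minimal standalone sketch of the same offset computation. The function name `group_byte_offset` and the `std::vector` of sizes are illustrative assumptions, not the PR's code; the sketch also assumes a contiguous tensor whose extent along `dim` divides evenly by `groups`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the PR): byte offset of group `group_idx`
// in a contiguous tensor whose dimension `dim` is split into `groups`
// equal parts. It mirrors the arithmetic of the reviewed snippet.
std::size_t group_byte_offset(const std::vector<long>& sizes, int dim,
                              int group_idx, int groups,
                              std::size_t element_size) {
  assert(groups > 0 && group_idx >= 0 && group_idx < groups);
  assert(dim >= 0 && static_cast<std::size_t>(dim) <= sizes.size());
  if (group_idx == 0) return 0;  // the first group starts at the base pointer

  // Number of elements spanned by dimensions dim..end.
  long size = 1;
  for (std::size_t i = static_cast<std::size_t>(dim); i < sizes.size(); ++i) {
    size *= sizes[i];
  }
  // Each group covers size / groups elements; skip group_idx of them.
  return element_size * static_cast<std::size_t>(size) * group_idx / groups;
}
```

As an example, for a float tensor of sizes {8, 6, 4, 4} split along dim 1 into 3 groups, group 1 starts 128 bytes in: 2 channels of 4 x 4 floats at 4 bytes each.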

This comment was marked as off-topic.

tpankaj commented Aug 7, 2017

I got an error while compiling your pull request. It was at the very end, something about not finding HalfTensor. I can rerun it and get the full error message if this isn't a known issue.

@csarofeen
Contributor Author

@tpankaj Please run python setup.py clean and rm -rf build, then try again. If the error persists, please post it here. Thanks!

tpankaj commented Aug 10, 2017

The error persists. Here it is, compiled on an Ubuntu 16.04 system with CUDA 9 and CuDNN 7.

Grabbing  src/nccl.h                          > /home/ubuntu/pytorch/torch/lib/build/nccl/include/nccl.h
Compiling src/libwrap.cu                      > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/libwrap.o
Compiling src/core.cu                         > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/core.o
Compiling src/all_gather.cu                   > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_gather.o
Compiling src/all_reduce.cu                   > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_reduce.o
src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

6 errors detected in the compilation of "/tmp/tmpxft_00003042_00000000-11_all_gather.compute_61.cpp1.ii".

Collaborator

ngimel commented Aug 10, 2017

The nccl subtree needs to be updated, cc @soumith, @apaszke. Alternatively, you can install nccl on your system (you need the dev packages so that you have the nccl header); the nccl subtree then won't be compiled.

TeslasGhost commented Aug 10, 2017

Hi,

Not sure if this is the same problem as the one mentioned above, but I attempted to build from source using CUDA 9 and cuDNN 7. It gets pretty far in the build process, and then towards the end I get...

Scanning dependencies of target nccl
[100%] Generating lib/libnccl.so
Grabbing  src/nccl.h                          > /home/ubuntu/pytorch/torch/lib/build/nccl/include/nccl.h
Compiling src/libwrap.cu                      > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/libwrap.o
Compiling src/core.cu                         > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/core.o
Compiling src/all_gather.cu                   > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_gather.o
Compiling src/all_reduce.cu                   > /home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_reduce.o
src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(42): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/common_kernel.h(55): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

src/copy_kernel.h(28): error: class "__half" has no member "x"

6 errors detected in the compilation of "/tmp/tmpxft_00004d18_00000000-11_all_gather.compute_61.cpp1.ii".
Makefile:121: recipe for target '/home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_gather.o' failed
make[3]: *** [/home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_gather.o] Error 1
make[3]: *** Waiting for unfinished jobs....
6 errors detected in the compilation of "/tmp/tmpxft_00004d29_00000000-11_all_reduce.compute_61.cpp1.ii".
Makefile:121: recipe for target '/home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_reduce.o' failed
make[3]: *** [/home/ubuntu/pytorch/torch/lib/build/nccl/obj/all_reduce.o] Error 1
ptxas warning : Too big maxrregcount value specified 96, will be ignored
ptxas warning : Too big maxrregcount value specified 96, will be ignored
CMakeFiles/nccl.dir/build.make:60: recipe for target 'lib/libnccl.so' failed
make[2]: *** [lib/libnccl.so] Error 2
CMakeFiles/Makefile2:67: recipe for target 'CMakeFiles/nccl.dir/all' failed
make[1]: *** [CMakeFiles/nccl.dir/all] Error 2
Makefile:127: recipe for target 'all' failed
make: *** [all] Error 2

Can someone kindly confirm that this is also a nccl subtree issue?
[EDIT]: It appears that it is. I have no idea how to install the NCCL header on my 'system', but I will look into it now. Any other help would be greatly appreciated. :)

tpankaj commented Aug 10, 2017

@TeslasGhost That looks like the exact error I got.

tpankaj commented Aug 10, 2017

@ngimel I'm still getting an error now that I've installed the nccl and nccl dev packages for CUDA 9:

CMake Error at /home/ubuntu/anaconda3/share/cmake-3.6/Modules/ExternalProject.cmake:1924 (message):
  No download info given for 'nccl_external' and its source directory:

   /home/ubuntu/pytorch/torch/lib/gloo/third-party/nccl

  is not an existing non-empty directory.  Please specify one of:

   * SOURCE_DIR with an existing non-empty directory
   * URL
   * GIT_REPOSITORY
   * HG_REPOSITORY
   * CVS_REPOSITORY and CVS_MODULE
   * SVN_REVISION
   * DOWNLOAD_COMMAND
Call Stack (most recent call first):
  /home/ubuntu/anaconda3/share/cmake-3.6/Modules/ExternalProject.cmake:2473 (_ep_add_download_command)
  cmake/External/nccl.cmake:16 (ExternalProject_Add)
  cmake/Dependencies.cmake:53 (include)
  CMakeLists.txt:49 (include)


-- Configuring incomplete, errors occurred!
See also "/home/ubuntu/pytorch/torch/lib/build/gloo/CMakeFiles/CMakeOutput.log".

Is there a source directory for that version of nccl that I need to point it to?

Collaborator

ngimel commented Aug 10, 2017

Apparently Findnccl.cmake in the gloo subtree is not finding your install of nccl. https://github.com/pytorch/pytorch/blob/master/torch/lib/gloo/cmake/Modules/Findnccl.cmake

@csarofeen
Contributor Author

I opened an issue describing the steps required to use a user-installed nccl; please see #2375.

@futurely

When will this be updated and merged? cuDNN 7 grouped convolution support resolves the high-priority issue #1708 and is independent of the CUDA 9 updates. If the two feature sets were separated, both would be easier to merge.

Collaborator

ngimel commented Aug 15, 2017

Unfortunately it does not. For depthwise-separable convolutions, https://github.com/szagoruyko/pyinn is much better, and for other grouped convolutions the current cudnn version provides only modest improvements.

@soumith soumith force-pushed the cuda9 branch 2 times, most recently from cbf60dd to e024f41 Compare August 25, 2017 11:25
@soumith soumith merged commit 0d7d79a into pytorch:master Aug 25, 2017
@csarofeen csarofeen deleted the cuda9 branch February 12, 2020 13:32
IvanYashchuk pushed a commit to IvanYashchuk/pytorch that referenced this pull request Jan 5, 2023