Compilation of <torch/extension.h> error on Windows CUDA 11.5 #69460

Closed
Tracked by #1042
atalman opened this issue Dec 6, 2021 · 15 comments
Labels
has workaround high priority module: cuda Related to torch.cuda, and CUDA support in general module: dependency bug Problem is not caused by us, but caused by an upstream library we use module: docs Related to our documentation, both in docs/ and docblocks module: pybind Related to our Python bindings / interactions with other Python libraries module: windows Windows support for PyTorch triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@atalman
Contributor

atalman commented Dec 6, 2021

We get the following error when compiling with CUDA 11.5 on Windows:

C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\include\pybind11\cast.h(1429): error: too few arguments for template template parameter "Tuple"
          detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\include\pybind11\cast.h(1503): error: too few arguments for template template parameter "Tuple"
          detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

Complete failure log:
https://github.com/pytorch/pytorch/runs/4408796098?check_suite_focus=true

This looks like the same issue as this one:
facebookresearch/pytorch3d#843

Here is the workaround for this issue:
facebookresearch/pytorch3d@cb170ac

cc @ezyang @gchanan @zou3519 @peterjc123 @mszhanyi @skyline75489 @nbcsm @brianjo @mruberry @ngimel @bdhirsh @jbschlosser @malfet @seemethere @pytorch/pytorch-dev-infra

@atalman atalman added module: cuda Related to torch.cuda, and CUDA support in general module: ci Related to continuous integration module: infra Relates to CI infrastructure triage review labels Dec 6, 2021
@seemethere seemethere added the module: windows Windows support for PyTorch label Dec 6, 2021
@seemethere seemethere added this to the 1.11.0 milestone Dec 6, 2021
@seemethere
Member

Adding this to the 1.11.0 milestone, since it will be a blocker for the 1.11.0 release.

@malfet malfet self-assigned this Dec 6, 2021
@malfet
Contributor

malfet commented Dec 6, 2021

Grabbing this for myself, as it probably involves tweaking pybind a bit.

@gchanan gchanan added module: cpp-extensions Related to torch.utils.cpp_extension triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed triage review labels Dec 6, 2021
@malfet malfet added high priority and removed module: cpp-extensions Related to torch.utils.cpp_extension module: ci Related to continuous integration module: infra Relates to CI infrastructure labels Dec 6, 2021
@malfet malfet added the module: pybind Related to our Python bindings / interactions with other Python libraries label Dec 6, 2021
@malfet
Contributor

malfet commented Dec 6, 2021

I can reproduce it with a hello-world example that simply includes pybind11, which looks like a CUDA compiler bug (it happens during the invocation of cicc):

#include <stdio.h>
#include <pybind11/pybind11.h>
__global__ void kernel() {
  printf("Hello World");
}
int main(void) {
  kernel<<<1, 1>>>();
  return cudaDeviceSynchronize();
}
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.5\bin>nvcc "c:\Users\runneruser\Documents\hello.cu" -o c:\Users\runneruser\Documents\a.exe -IC:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include -IC:\Jenkins\Miniconda3\include
hello.cu
...
#$ cicc --microsoft_version=1928 --msvc_target_version=1928 --compiler_bindir "C:/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools/VC/Tools/MSVC/14.28.29333/bin/Hostx64/x64/../../../../../../.." --sdk_dir "C:/Program Files (x86)/Windows Kits/10/" --display_error_number --orig_src_file_name "c:/Users/runneruser/Documents/hello.cu" --orig_src_path_name "c:\Users\runneruser\Documents\hello.cu" --allow_managed  -arch compute_52 -m64 --no-version-ident -ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 --include_file_name "tmpxft_00001bb8_00000000-7_hello.fatbin.c" -tused --gen_module_id_file --module_id_file_name "C:/Users/RUNNER~1/AppData/Local/Temp/2/tmpxft_00001bb8_00000000-8_hello.module_id" --gen_c_file_name "C:/Users/RUNNER~1/AppData/Local/Temp/2/tmpxft_00001bb8_00000000-10_hello.cudafe1.c" --stub_file_name "C:/Users/RUNNER~1/AppData/Local/Temp/2/tmpxft_00001bb8_00000000-10_hello.cudafe1.stub.c" --gen_device_file_name "C:/Users/RUNNER~1/AppData/Local/Temp/2/tmpxft_00001bb8_00000000-10_hello.cudafe1.gpu"  "C:/Users/RUNNER~1/AppData/Local/Temp/2/tmpxft_00001bb8_00000000-13_hello.cpp1.ii" -o "C:/Users/RUNNER~1/AppData/Local/Temp/2/tmpxft_00001bb8_00000000-10_hello.ptx"
C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\detail/common.h(810): warning #1388-D: base class dllexport/dllimport specification differs from that of the derived class

C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\pytypes.h(338): warning #1388-D: base class dllexport/dllimport specification differs from that of the derived class

C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\pytypes.h(387): warning #1394-D: field of class type without a DLL interface used in a class with a DLL interface

C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\pytypes.h(387): warning #1394-D: field of class type without a DLL interface used in a class with a DLL interface

C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\pytypes.h(387): warning #1394-D: field of class type without a DLL interface used in a class with a DLL interface

C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\cast.h(567): error: too few arguments for template template parameter "Tuple"
          detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(648): here

C:\actions-runner\_work\pytorch\pytorch\third_party\pybind11\include\pybind11\cast.h(644): error: too few arguments for template template parameter "Tuple"
          detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(648): here

2 errors detected in the compilation of "c:/Users/runneruser/Documents/hello.cu".
# --error 0x1 --

For debugging: hello.cpp1.ii

@zasdfgbnm
Collaborator

Narrowing it down further:

#include <utility>

// Base implementation for std::tuple and std::pair
template <template<typename...> class Tuple, typename... Ts> class tuple_caster {
    using type = Tuple<Ts...>;
};

template <typename T1, typename T2> class type_caster
    : public tuple_caster<std::pair, T1, T2> {};


__global__ void kernel() {
    printf("Hello World");
}
int main(void) {
    kernel<<<1, 1>>>();
    return cudaDeviceSynchronize();
}
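For comparison, the pattern in this narrowed repro is ordinary standard C++: since C++11 a variadic template template parameter such as `template <typename...> class Tuple` may be bound to `std::pair`, whose two type parameters are matched by the pack. The host-only sketch below (assumption: built with a conforming host compiler such as g++ or clang++, not nvcc 11.5) checks this with `static_assert`. The second half is a hypothetical alternative spelling, `tuple_caster_from_type`, that receives the complete `std::pair<T1, T2>` type and unpacks it with a partial specialization, so no template template parameter appears; it is not pybind11's code, and whether it sidesteps the CUDA 11.5 front-end error has not been verified here.

```cpp
#include <type_traits>
#include <utility>

// Same shape as the narrowed repro: a variadic template template parameter.
// A struct is used so that the nested `type` is publicly inspectable.
template <template <typename...> class Tuple, typename... Ts>
struct tuple_caster {
    using type = Tuple<Ts...>;
};

// Mirrors `type_caster` from the repro: binds Tuple = std::pair.
template <typename T1, typename T2>
struct pair_caster : tuple_caster<std::pair, T1, T2> {};

// Hypothetical alternative: take the complete type and unpack its
// arguments with a partial specialization instead of a template
// template parameter.
template <typename T>
struct tuple_caster_from_type;  // primary template intentionally undefined

template <typename T1, typename T2>
struct tuple_caster_from_type<std::pair<T1, T2>> {
    using type = std::pair<T1, T2>;
    using first_type = T1;
    using second_type = T2;
};

// A conforming compiler accepts both forms and they name the same type.
static_assert(std::is_same<pair_caster<int, double>::type,
                           std::pair<int, double>>::value,
              "tuple_caster<std::pair, ...> should accept std::pair");
static_assert(std::is_same<tuple_caster_from_type<std::pair<int, double>>::type,
                           std::pair<int, double>>::value,
              "alternative spelling should recover the same type");
```

Because both assertions are evaluated at compile time, a conforming compiler accepting this translation unit is the whole check; that the equivalent code is rejected with "too few arguments for template template parameter" is consistent with the compiler-bug diagnosis above.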

@albanD albanD added module: dependency bug Problem is not caused by us, but caused by an upstream library we use and removed triage review labels Dec 13, 2021
@malfet malfet removed their assignment Dec 13, 2021
@malfet malfet added has workaround module: docs Related to our documentation, both in docs/ and docblocks labels Feb 15, 2022
@malfet
Contributor

malfet commented Feb 15, 2022

We need to document this regression in the nvcc compiler from CUDA 11.5 on Windows.

@mszhanyi
Collaborator

mszhanyi commented Feb 16, 2022

@malfet @atalman
I've created a project, "GPU and CUDA regression on Windows", to collect GPU and CUDA regressions.

@MinttHu

MinttHu commented Jun 28, 2022

My system is Windows 10 + CUDA 11.3. When I use "torch.utils.cpp_extension.load_inline" to build a C++/CUDA extension, it fails with: "ninja: build stopped: subcommand failed."

@3a1b2c3
Copy link

3a1b2c3 commented Aug 4, 2022

Any workarounds?

@atalman
Contributor Author

atalman commented Sep 23, 2022

@malfet @ptrblck Now that we have CUDA 11.7, I will rerun the tests on it to see whether this issue is resolved in CUDA 11.7.

@atalman
Contributor Author

atalman commented Oct 14, 2022

It looks like this issue is resolved in CUDA 11.7: I no longer observe the failure on #85966.

pytorchmergebot pushed a commit that referenced this issue Oct 19, 2022
Reenable aot tests on windows for cuda 11.7 and up

Issue: #69460 seems to be mitigated in CUDA 11.7 hence re-enable this test

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm
Pull Request resolved: #87193
Approved by: https://github.com/malfet
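The gate described in this commit message (re-enable on Windows only for CUDA 11.7 and up) can be sketched as a version check. This is a hypothetical C++ illustration, not the check PyTorch actually used: `nvcc_pair_regression_possible` and `kSkipPairSensitiveTests` are invented names, and treating nvcc 11.5 and 11.6 on Windows as affected is an assumption drawn from this thread (11.5 reported broken, 11.7 reported fixed). `__CUDACC_VER_MAJOR__` and `__CUDACC_VER_MINOR__` are the version macros nvcc predefines; host-only compilers leave them undefined, so the flag is false there.

```cpp
// Hypothetical helper, not part of PyTorch or pybind11.
// Assumption: nvcc 11.5 and 11.6 on Windows are treated as affected,
// since 11.7 was reported fixed in this thread.
constexpr bool nvcc_pair_regression_possible(bool is_windows, int major, int minor) {
    return is_windows && major == 11 && minor >= 5 && minor < 7;
}

// Compile-time flag for the current toolchain. Only nvcc on Windows
// defines all three macros; every other compiler takes the #else branch.
#if defined(_WIN32) && defined(__CUDACC_VER_MAJOR__) && defined(__CUDACC_VER_MINOR__)
constexpr bool kSkipPairSensitiveTests =
    nvcc_pair_regression_possible(true, __CUDACC_VER_MAJOR__, __CUDACC_VER_MINOR__);
#else
constexpr bool kSkipPairSensitiveTests = false;
#endif
```

Keeping the decision in a constexpr function makes the version boundaries checkable with ordinary compile-time assertions, independent of which toolchain builds the file.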
atalman added a commit to atalman/pytorch that referenced this issue Oct 19, 2022
Reenable aot tests on windows for cuda 11.7 and up

Issue: pytorch#69460 seems to be mitigated in CUDA 11.7 hence re-enable this test

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm
Pull Request resolved: pytorch#87193
Approved by: https://github.com/malfet
@atalman atalman self-assigned this Oct 20, 2022
malfet pushed a commit that referenced this issue Oct 20, 2022
Reenable aot tests on windows for cuda 11.7 and up

Issue: #69460 seems to be mitigated in CUDA 11.7 hence re-enable this test

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm
Pull Request resolved: #87193
Approved by: https://github.com/malfet
@atalman
Contributor Author

atalman commented Oct 21, 2022

Removing this from the 1.13.0 milestone, since it is not critical for the 1.13 release and is resolved in CUDA 11.7.

@atalman atalman modified the milestones: 1.13.0, 1.14.0 Oct 21, 2022
sgrigory pushed a commit to sgrigory/pytorch that referenced this issue Oct 28, 2022
Reenable aot tests on windows for cuda 11.7 and up

Issue: pytorch#69460 seems to be mitigated in CUDA 11.7 hence re-enable this test

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm
Pull Request resolved: pytorch#87193
Approved by: https://github.com/malfet
malfet added a commit that referenced this issue Jan 17, 2023
As CUDA-11.5 is no longer supported, just remove the check

Fixes #69460
@atalman
Contributor Author

atalman commented Feb 2, 2023

With the deprecation of CUDA 11.6, we can resolve this issue. I will post a PR once 11.6 is removed from CI.

pytorchmergebot pushed a commit that referenced this issue Feb 2, 2023
As CUDA-11.5 is no longer supported, just remove the check

Fixes #69460
Projects
Status: Done

10 participants