
Windows CI intermittent error: C2993: 'Derived': illegal type for non-type template parameter '__formal #25393

Closed
ezyang opened this issue Aug 29, 2019 · 48 comments · Fixed by pytorch/builder#448
Labels
high priority module: build Build system issues module: ci Related to continuous integration module: flaky-tests Problem is a flaky test in CI module: windows Windows support for PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ezyang
Contributor

ezyang commented Aug 29, 2019

Kind of similar to #25389

Sometimes our Eigen build fails this way:

C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(366): error C2993: 'Derived': illegal type for non-type template parameter '__formal'
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(366): error C2065: 'RhsType': undeclared identifier
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(366): error C2923: 'std::_Select<__formal>::_Apply': 'RhsType' is not a valid template type argument for parameter '<unnamed-symbol>'
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(366): error C4430: missing type specifier - int assumed. Note: C++ does not support default-int
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(366): error C2144: syntax error: 'unknown-type' should be preceded by ')'
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(366): error C2144: syntax error: 'unknown-type' should be preceded by ';'
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(367): error C2062: type 'bool' unexpected
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(369): error C2059: syntax error: ')'
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(371): error C2143: syntax error: missing ';' before '{'
C:\Jenkins\workspace\caffe2-builds\py2-cuda9.0-cudnn7-windows-build\third_party\eigen\Eigen\src/LU/InverseImpl.h(371): error C2447: '{': missing function header (old-style formal list?)

It doesn't seem to reliably repro.

cc @ezyang @gchanan @zou3519 @seemethere @peterjc123

@pytorchbot pytorchbot added the module: windows Windows support for PyTorch label Aug 29, 2019
@VitalyFedyunin VitalyFedyunin added module: build Build system issues triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Aug 29, 2019
@ezyang
Contributor Author

ezyang commented Sep 3, 2019

I got a somewhat similar error, this time persistently, on a test branch:

14:06:51 L1Cost.cu
14:06:51 C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/detail/allocator/allocator_traits.inl(101): error C2993: 'T': illegal type for non-type template parameter '__formal'
14:06:51 C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/detail/allocator/allocator_traits.inl(101): note: see reference to class template instantiation 'thrust::detail::allocator_traits_detail::has_member_construct2_impl_has_member<T,Result(Arg1,Arg2)>' being compiled
14:06:51 C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/detail/allocator/allocator_traits.inl(101): error C2065: 't': undeclared identifier
14:06:51 C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/detail/allocator/allocator_traits.inl(101): error C2923: 'std::_Select<__formal>::_Apply': 't' is not a valid template type argument for parameter '<unnamed-symbol>'
14:06:51 C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/detail/allocator/allocator_traits.inl(101): error C2062: type 'unknown-type' unexpected
14:06:51 -- Removing C:/Jenkins/workspace/caffe2-builds/py2-cuda9.0-cudnn7-windows-build/build/caffe2/CMakeFiles/torch.dir/__/aten/src/THCUNN/./torch_generated_L1Cost.cu.obj

@ezyang
Contributor Author

ezyang commented Sep 3, 2019

Not that persistently; a PR stacked on the failing one succeeded. So there is something nondeterministic going on here.

@ezyang ezyang changed the title Windows CI error: C2993: 'Derived': illegal type for non-type template parameter '__formal Windows CI intermittent error: C2993: 'Derived': illegal type for non-type template parameter '__formal Sep 3, 2019
@ezyang ezyang added module: ci Related to continuous integration module: flaky-tests Problem is a flaky test in CI labels Sep 3, 2019
@smessmer
Contributor

smessmer commented Sep 3, 2019

Another instance: CI on #23888

@peterjc123
Collaborator

Looks like an NVCC bug. The workaround is to protect the Eigen headers with #ifndef __CUDACC__. Can we extract a minimal code example and compile it with NVCC?
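
For illustration, a minimal sketch of that guard pattern (the header and helper below are hypothetical; only the __CUDACC__ macro, which NVCC defines while compiling CUDA sources, is real):

// host_math.h -- hypothetical header; keeps Eigen out of anything NVCC parses.
#ifndef __CUDACC__
#include <Eigen/LU>

// Host-only helper that needs Eigen; the CUDA front end never sees it.
inline Eigen::Matrix3d invert3x3(const Eigen::Matrix3d& m) {
  return m.inverse();
}
#endif  // __CUDACC__

A .cu file that includes such a header would then compile only the non-Eigen parts.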

@peterjc123
Collaborator

Another one in https://app.circleci.com/jobs/github/pytorch/pytorch/3919969.

unique_ops.cu
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/unique_by_key.h(211): error C2993: 'KeyInputIt': illegal type for non-type template parameter '__formal'
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/unique_by_key.h(246): note: see reference to class template instantiation 'thrust::cuda_cub::__unique_by_key::UniqueByKeyAgent<KeyInputIt,ValInputIt,KeyOutputIt,ValOutputIt,BinaryPred,Size,NumSelectedOutIt>::PtxPlan<Arch>' being compiled
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/unique_by_key.h(589): note: see reference to class template instantiation 'thrust::cuda_cub::__unique_by_key::UniqueByKeyAgent<KeyInputIt,ValInputIt,KeyOutputIt,ValOutputIt,BinaryPred,Size,NumSelectedOutIt>' being compiled
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/unique_by_key.h(211): error C2065: 'IS_FIRST_TILE': undeclared identifier
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1/include\thrust/system/cuda/detail/unique_by_key.h(211): error C2993: 'KeyInputIt': illegal type for non-type template parameter '__formal'

@gottbrath
Contributor

Can you try with 10.2?

@gottbrath
Contributor

Mingbo, can you put general information about the issue you are seeing into this issue?

@gottbrath
Contributor

The latest word on this, from Mingbo:

"from what I can tell, it's a flaky one. It might depend on the VS version. I hit the problem before, and it was fixed (or at least seemed to be) after upgrading VS 2019 to the latest version" and "the same error showed up once last week, and re-running the job finished without a problem."

@jmlundberg

Hi. Did anyone have this issue persistently, i.e., such that a rebuild did not "fix" it? We've had this issue in another (closed-source) project. Exposing CUDA to fewer includes (such as Boost) seems to help, but it is not predictable.
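
One hedged sketch of that separation (all names here are hypothetical): keep the heavy host-only headers in a plain .cpp translation unit, and let the .cu files include only a thin declaration header, so NVCC never parses Boost or Eigen at all.

// reduce_interface.h -- thin header that .cu files may include; no Boost/Eigen here.
#pragma once
#include <cstddef>

double host_side_reduce(const double* data, std::size_t n);

// reduce_impl.cpp -- host-only translation unit built by cl, not nvcc;
// heavy headers (Boost, Eigen, ...) can be included here safely.
#include <numeric>
#include "reduce_interface.h"

double host_side_reduce(const double* data, std::size_t n) {
  return std::accumulate(data, data + n, 0.0);
}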

@ezyang
Contributor Author

ezyang commented Feb 4, 2020

It's never been persistent for us, but this is mostly from our Windows CI builds where we blow everything away and rebuild from scratch, which may help.

@peterjc123
Collaborator

peterjc123 commented Feb 10, 2020

Yet another occurrence in the forum: https://discuss.pytorch.org/t/how-do-i-get-my-older-gpu-that-supports-cuda-to-work-with-pytorch-1-4/69005/13. It seems to be fixed by set MAX_JOBS=1, which implies there may be a race condition in reading the include files or generating the intermediate source files.

facebook-github-bot pushed a commit that referenced this issue Feb 10, 2020
Summary:
## Several flags
`/MP[M]`: a flag for the compiler `cl`. It enables object-level multiprocessing. If M is omitted, it defaults to the number of cores on the PC.
`/maxcpucount:[M]`: a flag for the generator `msbuild`. It enables project-level multiprocessing. If M is omitted, it defaults to the number of cores on the PC.
`/p:CL_MPCount=[M]`: a flag for the generator `msbuild`. It makes the generator pass `/MP[M]` to the compiler.
`/j[M]`: a flag for the generator `ninja`. It enables object-level multiprocessing. If M is omitted, it defaults to the number of cores on the PC.

## Reason for the change
1. Object-level multiprocessing is preferred over project-level multiprocessing.
2. ~For ninja, we don't need to set `/MP`, otherwise M * M processes will be spawned.~ Actually, this is not correct: in the ninja configs there is only one source file per command, so the `/MP` switch should have no effect.
3. For msbuild, if it is called through the Python configuration scripts, then `/p:CL_MPCount=[M]` will be added; otherwise, we add `/MP` to `CMAKE_CXX_FLAGS`.
4. ~It may be a possible fix for #28271, #27463 and #25393, because `/MP` is also passed to `nvcc`.~ This is probably not true: `/MP` should have no effect given there is only one source file per command.

## Reference
1. https://docs.microsoft.com/en-us/cpp/build/reference/mp-build-with-multiple-processes?view=vs-2019
2. https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows
3. https://blog.kitware.com/cmake-building-with-all-your-cores/
Pull Request resolved: #33120

Differential Revision: D19817227

Pulled By: ezyang

fbshipit-source-id: f8d01f835016971729c7a8d8a0d1cb8a8c2c6a5f
@peterjc123
Collaborator

I was able to reproduce the issue with more verbose output using #33693: https://app.circleci.com/jobs/github/pytorch/pytorch/3919969. I then inspected the variables carefully and found that the VC environment is activated twice. We should really avoid that.

@ezyang
Contributor Author

ezyang commented Feb 25, 2020

You are SO COOL!!!

facebook-github-bot pushed a commit that referenced this issue Feb 25, 2020
Summary:
Possibly get rid of #28271, #27463 and #25393.
Pull Request resolved: #33700

Differential Revision: D20089251

Pulled By: ezyang

fbshipit-source-id: 0cfe62b869fb874e25f06894aa76fadc44cf6817
@peterjc123
Collaborator

I tried to build from scratch several times with this PR and could not reproduce the issue anymore. Let's assume it's fixed.

@peterjc123
Collaborator

Looks like it was not resolved. https://circleci.com/gh/pytorch/pytorch/4599162?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link

@mstfbl
Collaborator

mstfbl commented Mar 9, 2021

I have also hit this error with CUDA 10.1/cuDNN 7.6.4 on Windows Server 2019.

[4153/5198] Building NVCC (Device) object caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorTopKHalf.cu.obj
THCTensorTopKHalf.cu
THCTensorTopKHalf.cu
[4154/5198] Building NVCC (Device) object caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorMathReduceLong.cu.obj
FAILED: caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorMathReduceLong.cu.obj 
cmd.exe /C "cd /D F:\agent\_work\1\s\build\caffe2\CMakeFiles\torch_cuda.dir\__\aten\src\THC\generated && C:\Miniconda\envs\windows_2019_py_37\Library\bin\cmake.exe -E make_directory F:/agent/_work/1/s/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/. && C:\Miniconda\envs\windows_2019_py_37\Library\bin\cmake.exe -D verbose:BOOL=OFF -D build_configuration:STRING=Release -D generated_file:STRING=F:/agent/_work/1/s/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMathReduceLong.cu.obj -D generated_cubin_file:STRING=F:/agent/_work/1/s/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMathReduceLong.cu.obj.cubin.txt -P F:/agent/_work/1/s/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/torch_cuda_generated_THCTensorMathReduceLong.cu.obj.Release.cmake"
THCTensorMathReduceLong.cu
THCTensorMathReduceLong.cu
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(889): error C2993: 'T': is not a valid type for non-type template parameter '__formal'
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(909): note: see reference to class template instantiation 'thrust::cuda_cub::cub::BinaryOpHasIdxParam<T,BinaryOp>' being compiled
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(889): error C2065: '__T3': undeclared identifier
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include\thrust\system\cuda\detail\cub\util_type.cuh(889): error C2923: 'std::_Select<__formal>::_Apply': '__T3' is not a valid template type argument for parameter '<unnamed-symbol>'
CMake Error at torch_cuda_generated_THCTensorMathReduceLong.cu.obj.Release.cmake:281 (message):
  Error generating file
  F:/agent/_work/1/s/build/caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/THC/generated/./torch_cuda_generated_THCTensorMathReduceLong.cu.obj

@leezu
Contributor

leezu commented Mar 9, 2021

@mstfbl this is a CUDA bug. Please see the discussion above and upgrade to CUDA 11.

@mstfbl
Collaborator

mstfbl commented Mar 11, 2021

@leezu I fixed the error by installing sccache for use with CUDA 10.1 and CUDA 10.2, as done here: https://github.com/pytorch/builder/blob/993e8b275e313641796db8a0b2869d5f3dd13828/windows/build_pytorch.bat#L98-L124

@leezu
Contributor

leezu commented Mar 11, 2021

@mstfbl it's an intermittent error, and the probability of occurrence varies depending on your system. So it's not really fixed, but you may have found a way to reduce its occurrence on your system, which may be sufficient in your case :)

@peterbell10
Collaborator

Another instance, this time inside protobuf:

roi_align_op.cu
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2993: 'T': is not a valid type for non-type template parameter '__formal'
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(459): note: see reference to class template instantiation 'google::protobuf::Arena::InternalHelper<T>' being compiled
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2065: 'Second': undeclared identifier
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2923: 'std::_Select<__formal>::_Apply': 'Second' is not a valid template type argument for parameter '<unnamed-symbol>'
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C4430: missing type specifier - int assumed. Note: C++ does not support default-int
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2144: syntax error: 'unknown-type' should be preceded by ')'
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2144: syntax error: 'unknown-type' should be preceded by ';'
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2238: unexpected token(s) preceding ';'
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2059: syntax error: ')'
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2988: unrecognizable template declaration/definition
C:/Users/circleci/project/third_party/protobuf/src\google/protobuf/arena.h(427): error C2059: syntax error: '<end Parse>'

https://app.circleci.com/pipelines/github/pytorch/pytorch/317648/workflows/504abdce-7887-47a1-bf6a-de417da17503/jobs/13229319

Running nvcc with --verbose --keep, I found that the syntax error is introduced by cudafe++, the CUDA compiler front end. cudafe++ operates on the output of the compiler's pre-processor, and that output barely changed between the failing and passing builds: only line-number metadata and some extra whitespace differ in the pre-processed output.

Here is the diff of the two inputs, one of which failed and the other didn't:

$ diff nvcc-output/roi_align_op.cpp4.ii nvcc-output-2/roi_align_op.cpp4.ii
353964d353963
< #line 229 "C:/Users/circleci/project\\c10/macros/Macros.h"
353966c353965,353966
< #line 231 "C:/Users/circleci/project\\c10/macros/Macros.h"
---
> 
> 

And the corresponding change in the output:

$ diff nvcc-output/roi_align_op.compute_75.cudafe1.cpp nvcc-output-2/roi_align_op.compute_75.cudafe1.cpp
124779c124779
< template< class U> static char ArenaConstructable(const typename std::_Select< T> ::template _Apply< U, Second> ::InternalArenaConstructable_ *); 
---
> template< class U> static char ArenaConstructable(const typename U::InternalArenaConstructable_ *); 
162181c162181
< namespace _GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_c6fef629 { }; using namespace ::at::_GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_c6fef629; namespace _GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_c6fef629 { 
---
> namespace _GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_6d8ec415 { }; using namespace ::at::_GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_6d8ec415; namespace _GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_6d8ec415 { 
201406c201406
< #define _NV_ANON_NAMESPACE _GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_c6fef629
---
> #define _NV_ANON_NAMESPACE _GLOBAL__N__31_roi_align_op_compute_75_cpp1_ii_6d8ec415

This failure reproduced reliably on the CI machine: running nvcc manually multiple times, it always fails or passes given the same inputs. I think it's a deterministic failure, but one that is annoyingly sensitive to any change whatsoever in the input file.
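
For context, the construct being mangled here is a classic member-type detection idiom. A minimal, self-contained sketch of the same shape (not the protobuf code, and not claimed to reproduce the failure on its own; the nested-type name just mirrors the one in the diff above):

template <typename T>
class HasInternalTag {
  // Chosen only if U declares a nested type InternalArenaConstructable_.
  template <typename U>
  static char Check(const typename U::InternalArenaConstructable_*);

  // Fallback for all other types.
  template <typename U>
  static double Check(...);

 public:
  static constexpr bool value = sizeof(Check<T>(nullptr)) == sizeof(char);
};

struct WithTag { typedef void InternalArenaConstructable_; };
struct WithoutTag {};

static_assert(HasInternalTag<WithTag>::value, "detects the nested type");
static_assert(!HasInternalTag<WithoutTag>::value, "rejects types without it");

In the failing cudafe1.cpp output above, the parameter of the first overload appears to be re-emitted in terms of MSVC's internal std::_Select helper, which MSVC itself then refuses to parse.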

tcojean added a commit to ginkgo-project/ginkgo that referenced this issue Aug 4, 2021
There is an issue with the CUDA implementation which prevents proper execution. See pytorch/pytorch#25393. Tweaking the compiler settings reduces the number of errors, but it seems impossible to prevent them altogether.
tcojean added a commit to ginkgo-project/ginkgo that referenced this issue Aug 4, 2021
There is an issue with the CUDA implementation which prevents proper execution. See pytorch/pytorch#25393. Tweaking the compiler settings reduces the number of errors, but it seems impossible to prevent them altogether.
github-actions bot pushed a commit to ginkgo-project/ginkgo that referenced this issue Aug 4, 2021
There is an issue with the CUDA implementation which prevents proper execution. See pytorch/pytorch#25393. Tweaking the compiler settings reduces the number of errors, but it seems impossible to prevent them altogether.
tcojean added a commit to ginkgo-project/ginkgo that referenced this issue Aug 5, 2021
Fix Readme and disable MSVC-CUDA 10.2

+ Update to the new package status. Simplify the HIP-related INSTALL.md section.
+ Disable the MSVC-CUDA 10.2 job.

There is an issue with the CUDA implementation which prevents proper execution. See pytorch/pytorch#25393 and NVIDIA/thrust#1090. Tweaking the compiler settings reduces the number of errors, but it seems impossible to prevent them altogether.

Related PR: #852
tcojean added a commit to ginkgo-project/ginkgo that referenced this issue Aug 20, 2021
There is an issue with the CUDA implementation which prevents proper execution. See pytorch/pytorch#25393. Tweaking the compiler settings reduces the number of errors, but it seems impossible to prevent them altogether.
MarcelKoch pushed a commit to ginkgo-project/ginkgo that referenced this issue Sep 10, 2021
There is an issue with the CUDA implementation which prevents proper execution. See pytorch/pytorch#25393. Tweaking the compiler settings reduces the number of errors, but it seems impossible to prevent them altogether.
malfet added a commit to malfet/pytorch that referenced this issue Sep 29, 2021
facebook-github-bot pushed a commit that referenced this issue Sep 29, 2021
Summary:
See #65612 and #25393

Fixes #65648

Pull Request resolved: #65649

Reviewed By: janeyx99

Differential Revision: D31189692

Pulled By: malfet

fbshipit-source-id: 6ec0548d5833f3428d882071d26c357d89b0a9ba
@ngimel
Collaborator

ngimel commented Feb 7, 2022

closing due to age

@ngimel ngimel closed this as completed Feb 7, 2022
@ezyang
Contributor Author

ezyang commented Feb 10, 2022

Wait, do we think this is fixed? It's not clear to me it is...

@seemethere
Member

I think this was related to CUDA 10.2 builds on Windows, which we don't actually support anymore, so it is probably safe to keep this closed.
