
CUFFT_INTERNAL_ERROR on RTX 4090 #88038

Closed
Yujia-Yan opened this issue Oct 29, 2022 · 17 comments
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: fft module: third_party triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module

Comments

@Yujia-Yan

Yujia-Yan commented Oct 29, 2022

🐛 Describe the bug

>>> import torch
>>> torch.fft.rfft(torch.randn(1000).cuda())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR

There is a discussion on https://forums.developer.nvidia.com/t/bug-ubuntu-on-wsl2-rtx4090-related-cufft-runtime-error/230883/7 .

Versions

Using pytorch installed with
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

cc @ezyang @gchanan @zou3519 @peterjc123 @mszhanyi @skyline75489 @nbcsm @ngimel @mruberry @peterbell10

@soulitzer soulitzer added module: cuda, triaged, and module: fft labels Oct 31, 2022
@mruberry mruberry added high priority and removed triaged labels Oct 31, 2022
@malfet malfet added the module: windows label Nov 7, 2022
@malfet
Contributor

malfet commented Nov 7, 2022

Just to clarify, does this happen only in Windows Subsystem for Linux, or elsewhere as well?

@cpuhrsch cpuhrsch added the needs reproduction label Nov 7, 2022
@malfet
Contributor

malfet commented Nov 7, 2022

@ptrblck can you please confirm whether this indeed happens with a 4090 on Linux, or only in the WSL config?

@cpuhrsch cpuhrsch added triaged and removed high priority, triage review labels Nov 7, 2022
@Yujia-Yan
Author

Just to clarify, does this happen only in Windows Subsystem for Linux, or elsewhere as well?

Actually, I am using Ubuntu Server 22.04.
Update:
I later compiled PyTorch with CUDA 11.8 and the problem disappeared.

@ngimel
Collaborator

ngimel commented Nov 7, 2022

So in this case it looks like the cuFFT library doesn't honor the forward-compatibility guarantee (code compiled with an older toolkit version should run as long as the driver on the system supports the new hardware). cc @ptrblck, and we should start producing 11.8 nightlies.
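The compatibility gap described above can be summarized in a small sketch. The helper below is hypothetical (not part of PyTorch); it only encodes this thread's finding that cuFFT from toolkits before 11.8 fails on Ada-generation GPUs (compute capability 8.9, e.g. the RTX 4090):

```python
# Hypothetical helper illustrating the forward-compatibility gap discussed
# above: wheels built against CUDA < 11.8 ship a cuFFT that raises
# CUFFT_INTERNAL_ERROR on Ada (sm_89) GPUs, while 11.8+ builds work.

def cufft_likely_broken(toolkit_version: str, compute_capability: tuple) -> bool:
    """Return True if this thread suggests cuFFT will fail for the given
    torch.version.cuda string and torch.cuda.get_device_capability() tuple."""
    major, minor = (int(p) for p in toolkit_version.split(".")[:2])
    is_ada_or_newer = compute_capability >= (8, 9)   # RTX 4090 reports (8, 9)
    has_fixed_cufft = (major, minor) >= (11, 8)
    return is_ada_or_newer and not has_fixed_cufft

# On a real machine one would pass:
#   cufft_likely_broken(torch.version.cuda, torch.cuda.get_device_capability())
print(cufft_likely_broken("11.7", (8, 9)))  # True  -> expect the error
print(cufft_likely_broken("11.8", (8, 9)))  # False -> fixed build
print(cufft_likely_broken("11.7", (8, 6)))  # False -> e.g. an RTX 3080
```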

@Blackhex
Collaborator

Blackhex commented Nov 7, 2022

I don't have a 4090 available, so I can only add that it is not reproducible on Windows 11 or Ubuntu WSL with a 3080.

@ptrblck
Collaborator

ptrblck commented Nov 8, 2022

can you please confirm whether this indeed happens with a 4090 on Linux, or only in the WSL config?

Yes, this is a cuFFT error which is also visible on Linux.

and we should start producing 11.8 nightlies.

Also yes, and I've already started on its bringup, e.g. in pytorch/builder#1186

@malfet malfet added high priority and removed needs reproduction, module: windows labels Nov 14, 2022
@malfet malfet added this to the 1.13.1 milestone Nov 14, 2022
@malfet
Contributor

malfet commented Nov 14, 2022

Some updates:

  • This seems to be a bug in cuFFT in CUDA 11.7 that happens on both Linux and Windows, but appears to be fixed in 11.8
  • It is worth trying (and I think some investigation has already been done) to use cuFFT from 11.8 in the 11.7 build to see whether the fix can be deployed/verified in the nightlies first
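One hedged way to trial the 11.8 cuFFT against a cu117 wheel is to preload it ahead of the bundled library. This is an illustrative sketch only, not an official procedure; the toolkit path is an assumption for a default Linux install:

```shell
# Illustrative only: the path below assumes a default CUDA 11.8 toolkit
# install. Preload its cuFFT ahead of the one bundled with the cu117
# wheel, then re-run the repro from this issue.
export LD_PRELOAD=/usr/local/cuda-11.8/lib64/libcufft.so.10
python -c "import torch; print(torch.fft.rfft(torch.randn(1000).cuda()).sum())"
```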

@malfet
Contributor

malfet commented Nov 28, 2022

Let's create a frankenbuild for cuda-11.7 for the nightlies and see what happens

@malfet
Contributor

malfet commented Nov 29, 2022

The first thing that worries me a lot is the 2x binary size increase for CUDA 11.8, and nvprune does not help much:

$ ls -lah 11.7/libcufft/lib64/libcufft.so.10.7.2.50 11.8/libcufft/lib64/libcufft.so.10.9.0.58 
-rwxr-xr-x 1 nshulga nshulga 131M Nov 29 21:48 11.7/libcufft/lib64/libcufft.so.10.7.2.50
-rwxr-xr-x 1 nshulga nshulga 267M Nov 29 21:49 11.8/libcufft/lib64/libcufft.so.10.9.0.58

Considering that, I'm not sure it would be safe to include it as an update in 1.13.1

@malfet
Contributor

malfet commented Dec 19, 2022

Removing the high priority label, as this is a bug in a third-party library and there were big changes in cuFFT between 11.7 and 11.8

razarmehr pushed a commit to kulinseth/pytorch that referenced this issue Jan 4, 2023
This PR adds more nvidia pypi dependencies for cuda 11.7 wheel. Additionally, it pins cufft version to 10.9.0.58 to resolve pytorch#88038

Depends on: pytorch/builder#1196

Pull Request resolved: pytorch#89944
Approved by: https://github.com/atalman
@pranavmalikk

pranavmalikk commented Feb 4, 2023

Still getting this error on an RTX 4090 with CUDA 11.7 on Ubuntu 22.04. Any recommendations?

@ptrblck
Collaborator

ptrblck commented Feb 4, 2023

@pranavmalikk Yes, please use the nightly binaries with CUDA 11.8.

@pranavmalikk

pranavmalikk commented Feb 4, 2023

@pranavmalikk Yes, please use the nightly binaries with CUDA 11.8.

I'm still getting this on CUDA 11.8:

----> 1 torch.fft.rfft(torch.randn(1000).cuda())

RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR

@ptrblck
Collaborator

ptrblck commented Feb 5, 2023

Could you post the output of torch.__version__ as well as python -m torch.utils.collect_env, please? I cannot reproduce the error with the CUDA 11.8 nightly binaries anymore:

import torch
print(torch.__version__)
out = torch.fft.rfft(torch.randn(1000).cuda())
print(out.sum())

with 11.7 it fails as reported:

python tmp.pt 
2.0.0.dev20230204+cu117
Traceback (most recent call last):
  File "tmp.pt", line 3, in <module>
    torch.fft.rfft(torch.randn(1000).cuda())
RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR

Update the nightlies to the 11.8 build:

pip uninstall torch -y
Found existing installation: torch 2.0.0.dev20230204+cu117
Uninstalling torch-2.0.0.dev20230204+cu117:
  Successfully uninstalled torch-2.0.0.dev20230204+cu117
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
...
Successfully installed torch-2.0.0.dev20230204+cu118

and it works:

python tmp.py 
2.0.0.dev20230204+cu118
tensor(670.6870+11.1756j, device='cuda:0')

@pranavmalikk

pranavmalikk commented Feb 5, 2023

Could you post the output of torch.__version__ as well as python -m torch.utils.collect_env, please, as I cannot reproduce the error in the CUDA 11.8 nightly binaries anymore [...]

Thank you for the help; it works now. I had mistakenly run `pip install torchaudio`, which downgraded me to an older version of torch. I fixed this by installing the nightly version of torchaudio as well.

@bensonbs

Due to package dependency issues, I am limited to PyTorch versions below 2.0.0. I understand that PyTorch 1.13.1 supports up to CUDA 11.7. Could you kindly advise whether there are any alternative solutions apart from upgrading to CUDA 11.8?
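For users stuck on a CUDA 11.7 build, one possible stopgap (a sketch, not an official recommendation from this thread) is to fall back to the CPU FFT when the cuFFT path raises, at the cost of a device round-trip:

```python
import torch

def rfft_with_cpu_fallback(x: torch.Tensor) -> torch.Tensor:
    """Sketch of a workaround for the cuFFT bug in CUDA 11.7 builds:
    if the GPU FFT raises (e.g. CUFFT_INTERNAL_ERROR), compute the FFT
    on the CPU and move the result back to the original device."""
    try:
        return torch.fft.rfft(x)
    except RuntimeError:
        # cuFFT failed; fall back to the (slower) CPU path.
        return torch.fft.rfft(x.cpu()).to(x.device)

out = rfft_with_cpu_fallback(torch.randn(1000))
print(out.shape)  # torch.Size([501])
```

This keeps code paths working but gives up the GPU speedup for the FFT itself, so it is only worth it where the FFT is not the bottleneck.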
