
FP16 inference with Cuda 11.1 returns NaN on Nvidia GTX 1660 #58123

Closed
hamishc opened this issue May 12, 2021 · 15 comments
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: cudnn Related to torch.backends.cudnn, and CuDNN support module: half Related to float16 half-precision floats triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@hamishc

hamishc commented May 12, 2021

🐛 Bug

Half-precision inference returns NaNs for a number of models when run on a 1660 with CUDA 11.1.

To Reproduce

import torch
import urllib.request
from PIL import Image
from torchvision import transforms

model = torch.hub.load('pytorch/vision:v0.9.0', 'mobilenet_v2', pretrained=True)
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)  # download the sample image before opening it
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)

model = model.cuda().float()
input_batch = input_batch.cuda().float()

with torch.no_grad():
    output = model(input_batch)
print("FP32 output:", output)

model = model.cuda().half()
input_batch = input_batch.cuda().half()

with torch.no_grad():
    output = model(input_batch)
print("FP16 output:", output)

Expected behavior

I would expect the FP16 output to be approximately the same as the FP32 output, but the above script produces (truncated for clarity):

FP32 output: tensor([[-1.8283e+00, -1.4972e+00, -1.1716e+00, ..., -5.8914e-01,  9.7267e-01,  1.9510e+00]],
       device='cuda:0')

FP16 output: tensor([[nan, nan, nan, ..., nan, nan, nan, nan, nan]],
       device='cuda:0', dtype=torch.float16)

Testing with other nets produces the same result, e.g.:

import torch

model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
input_tensor = torch.rand((1,3,512,512))

model = model.cuda().half()
input_tensor = input_tensor.cuda().half()
with torch.no_grad():
    output = model(input_tensor)

print("FP16 output:", output)

model = model.cuda().float()
input_tensor = input_tensor.cuda().float()
with torch.no_grad():
    output = model(input_tensor)

print("FP32 output:", output)

Running this same test on other GPUs (e.g. a 1080ti and a 2080ti) produces valid FP16 output on each, while testing on another machine with a 1660 produced the same NaN results as above.

Similarly, the same test on a 1660 with torch 1.8.1 / CUDA 10.2 does not reproduce the issue.
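
For anyone trying to reproduce this, the configuration details that seem to matter (GPU model, compute capability, and the CUDA / cuDNN versions PyTorch actually uses) can be printed with the short sketch below; the full dump for my affected machine follows in the Environment section.

import torch

print("GPU:", torch.cuda.get_device_name(0))
print("Compute capability:", torch.cuda.get_device_capability(0))  # (7, 5) for the GTX 1660
print("PyTorch:", torch.__version__)
print("CUDA used to build PyTorch:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())  # e.g. 8005 for 8.0.5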

Environment

Collecting environment information...
PyTorch version: 1.8.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.2 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.20.2

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1660
Nvidia driver version: 465.19.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.8.1+cu111
[pip3] torchvision==0.9.1+cu111
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.1.74              h6bb024c_0    nvidia
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.2.0           h06a4308_296  
[conda] mkl-service               2.3.0            py39h27cfd23_1  
[conda] mkl_fft                   1.3.0            py39h42c9631_2  
[conda] mkl_random                1.2.1            py39ha9443f7_2  
[conda] numpy                     1.20.1           py39h93e21f0_0  
[conda] numpy-base                1.20.1           py39h7d8b39e_0  
[conda] pytorch                   1.8.1           py3.9_cuda11.1_cudnn8.0.5_0    pytorch
[conda] torchaudio                0.8.1                      py39    pytorch
[conda] torchvision               0.9.1                py39_cu111    pytorch

Pytorch 1.8.1+cu111 was installed with pip / python 3.7 from the following wheel:
https://download.pytorch.org/whl/cu111/torch-1.8.1%2Bcu111-cp37-cp37m-linux_x86_64.whl

Additional context

cc @ngimel @csarofeen @ptrblck @xwang233

@ptrblck
Collaborator

ptrblck commented May 12, 2021

I'm unable to reproduce the NaN outputs on a 2080Ti (same compute capability, sm_75, as your 1660) using the PyTorch 1.8.1 pip wheels with CUDA 11.1, and I get valid outputs over multiple runs.
Are you seeing this issue in every run?
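
In case it helps, one way to check for NaNs across repeated runs (a minimal sketch, reusing the resnet18 snippet from above) would be something like this:

import torch

model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
model = model.cuda().half().eval()

# Repeat the FP16 forward pass a few times and report whether any output contains NaNs.
for i in range(5):
    x = torch.rand((1, 3, 512, 512)).cuda().half()
    with torch.no_grad():
        out = model(x)
    print(f"run {i}: contains NaN = {torch.isnan(out).any().item()}")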

@hamishc
Author

hamishc commented May 12, 2021

Yes, every run, irrespective of the input values. I have also tested with a 2080ti and a 1080ti and found valid outputs; I was only able to produce the NaN output on both a 1660 and a 1660ti.

@hamishc
Author

hamishc commented May 12, 2021

I've managed to put together a minimal reproducible example with just Conv2d operations; see below:

import torch

# Three stacked half-precision convolutions are enough to trigger the NaNs on the 1660.
input_tensor = torch.rand((1,32,64,64)).half().cuda()
conv = torch.nn.Conv2d(32, 64, kernel_size=(3, 3)).half().cuda()
conv2 = torch.nn.Conv2d(64, 128, kernel_size=(3, 3)).half().cuda()
conv3 = torch.nn.Conv2d(128, 256, kernel_size=(3, 3)).half().cuda()
x = conv(input_tensor)
x = conv2(x)
x = conv3(x)
print(x)

produces

tensor([[[[     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          ...,
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan]],

         [[-0.02852, -0.02852, -0.02852,  ..., -0.02852, -0.02852, -0.02852],
          [-0.02852, -0.02852, -0.02852,  ..., -0.02852, -0.02852, -0.02852],
          [-0.02852, -0.02852, -0.02852,  ..., -0.02852, -0.02852, -0.02852],
          ...,
          [-0.02852, -0.02852, -0.02852,  ..., -0.02852, -0.02852, -0.02852],
          [-0.02852, -0.02852, -0.02852,  ..., -0.02852, -0.02852, -0.02852],
          [-0.02852, -0.02852, -0.02852,  ..., -0.02852, -0.02852, -0.02852]],

         [[     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          ...,
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan]],

         ...,

         [[ 0.00100,  0.00100,  0.00100,  ...,  0.00100,  0.00100,  0.00100],
          [ 0.00100,  0.00100,  0.00100,  ...,  0.00100,  0.00100,  0.00100],
          [ 0.00100,  0.00100,  0.00100,  ...,  0.00100,  0.00100,  0.00100],
          ...,
          [ 0.00100,  0.00100,  0.00100,  ...,  0.00100,  0.00100,  0.00100],
          [ 0.00100,  0.00100,  0.00100,  ...,  0.00100,  0.00100,  0.00100],
          [ 0.00100,  0.00100,  0.00100,  ...,  0.00100,  0.00100,  0.00100]],

         [[     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          ...,
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan]],

         [[ 0.00358,  0.00358,  0.00358,  ...,  0.00358,  0.00358,  0.00358],
          [ 0.00358,  0.00358,  0.00358,  ...,  0.00358,  0.00358,  0.00358],
          [ 0.00358,  0.00358,  0.00358,  ...,  0.00358,  0.00358,  0.00358],
          ...,
          [ 0.00358,  0.00358,  0.00358,  ...,  0.00358,  0.00358,  0.00358],
          [ 0.00358,  0.00358,  0.00358,  ...,  0.00358,  0.00358,  0.00358],
          [ 0.00358,  0.00358,  0.00358,  ...,  0.00358,  0.00358,  0.00358]]]], device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
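
As an extra diagnostic (a sketch only; I haven't included its output here), disabling cuDNN and re-running the same three convolutions should show whether the NaNs are specific to the cuDNN code path, since the fallback uses PyTorch's native convolution kernels:

import torch

torch.backends.cudnn.enabled = False  # fall back to the non-cuDNN convolution implementation

input_tensor = torch.rand((1, 32, 64, 64)).half().cuda()
conv = torch.nn.Conv2d(32, 64, kernel_size=(3, 3)).half().cuda()
conv2 = torch.nn.Conv2d(64, 128, kernel_size=(3, 3)).half().cuda()
conv3 = torch.nn.Conv2d(128, 256, kernel_size=(3, 3)).half().cuda()
x = conv3(conv2(conv(input_tensor)))
print("contains NaN with cuDNN disabled:", torch.isnan(x).any().item())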

@ptrblck
Collaborator

ptrblck commented May 12, 2021

Thank you for the follow-up!
I was able to reproduce the invalid outputs on a 1660Ti and will forward the API logs to the cuDNN team.
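
(For anyone who wants to capture a similar cuDNN API log themselves, a rough sketch is below. The environment variables are the ones documented for cuDNN 8.x and have to be set before cuDNN is initialized, i.e. before the first convolution runs.)

import os

# Enable cuDNN API logging before torch initializes cuDNN.
os.environ["CUDNN_LOGINFO_DBG"] = "1"
os.environ["CUDNN_LOGDEST_DBG"] = "cudnn_api.log"  # or "stdout"

import torch

x = torch.rand((1, 32, 64, 64)).half().cuda()
conv = torch.nn.Conv2d(32, 64, kernel_size=(3, 3)).half().cuda()
y = conv(x)  # the cuDNN calls issued here end up in cudnn_api.log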

@mruberry mruberry added module: cuda Related to torch.cuda, and CUDA support in general module: half Related to float16 half-precision floats triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 12, 2021
@ngimel ngimel added the module: cudnn Related to torch.backends.cudnn, and CuDNN support label May 12, 2021
@ptrblck
Collaborator

ptrblck commented Jun 19, 2021

@hamishc the fix will land in the upcoming cuDNN 8.2.2 release (I've just verified it).
Thanks again for reporting.
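
(A quick way to confirm which cuDNN release a given PyTorch build actually loads, once 8.2.2 is available, is the one-liner below; it returns an integer-encoded version.)

import torch

# Integer-encoded cuDNN version, e.g. 8005 == 8.0.5, 8202 == 8.2.2.
print(torch.backends.cudnn.version())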

@illtellyoulater

illtellyoulater commented Mar 23, 2022

@ptrblck I'm on cuDNN v8.3.0.2 and I'm still seeing this.
Please see details here: openai/glide-text2im#31
Any ideas on how this is possible and how to fix it?

@ptrblck
Collaborator

ptrblck commented Mar 23, 2022

@illtellyoulater let me re-verify the fix and try to reproduce your issue to see whether you are running into the same problem or a new one. I'll ping you in the linked issue in case I need help reproducing it.

@monsieurpooh

Hello, I've been debugging an issue full-time for almost a week now: a couple of users (both on GTX 1660 cards) can't run VQGAN-CLIP powered by PyTorch because operations such as nn.Conv2d or nn.MultiheadAttention return NaN in situations that work fine on other machines. Since I used PyInstaller to package the whole thing into a single directory, it's unlikely to be caused by differing library code. I googled "pytorch gtx 1660" in desperation, and this seems to be related.
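
For anyone stuck on a similar hunt, one way to locate which layer first produces NaNs is to register forward hooks on every submodule (a sketch, assuming an arbitrary nn.Module called model and a CUDA half-precision input):

import torch

def report_nan(module, inputs, output):
    # Print the modules whose outputs contain NaNs.
    if isinstance(output, torch.Tensor) and torch.isnan(output).any():
        print("NaN produced by:", module.__class__.__name__)

def add_nan_hooks(model):
    for name, submodule in model.named_modules():
        submodule.register_forward_hook(report_nan)

# Usage (hypothetical model and input):
# add_nan_hooks(model)
# with torch.no_grad():
#     model(input_batch.cuda().half())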

@monsieurpooh

Updating to a cuDNN release above 8.2.2 was sufficient to fix my issue, even with CUDA toolkit 11.3; it was not necessary to downgrade the CUDA toolkit.

@YipKo

YipKo commented May 20, 2022

Same problem here (also the same hardware).
@monsieurpooh
@illtellyoulater
Unlike you, I have tried PyTorch with CUDA 11.5 (whose bundled cuDNN is 8.3.0, i.e. newer than 8.2.2) and have also tried downloading cuDNN from NVIDIA and copying the DLL files into the relevant folder in torch/lib, but the problem could not be solved.

@andreszs

Just to clarify: as of 2023 this topic remains unsolved, and no one can run Stable Diffusion using the dedicated FP16 cores that supposedly replace the tensor cores on this card. Or is there any news in this regard?

These FP16 cores are brand new to Turing Minor, and have not appeared in any past NVIDIA GPU architecture. Their purpose is functionally the same as running FP16 operations through the tensor cores on Turing Major: to allow NVIDIA to dual-issue FP16 operations alongside FP32 or INT32 operations within each SM partition. And because they are just FP16 cores, they are quite small. NVIDIA isn't giving specifics, but going by throughput alone they should be a fraction of the size of the tensor cores they replace.

I'm running SD with the crippling --precision full --no-half parameters as shown here; otherwise only black images are generated. It's well documented that these parameters increase VRAM usage and make SD slower.
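
(For context, as I understand it "--no-half" just keeps the model and inputs in float32 instead of casting them with .half(); a minimal sketch of the difference, using a generic Conv2d rather than SD itself:)

import torch

conv = torch.nn.Conv2d(32, 64, kernel_size=3).cuda()
x = torch.rand(1, 32, 64, 64).cuda()

# "--no-half" style: everything stays in float32 (works on the 1660, but slower and uses more VRAM).
out_fp32 = conv(x)

# Half-precision path: cast model and input to float16 (the path that misbehaves on these cards).
out_fp16 = conv.half()(x.half())
print(out_fp32.dtype, out_fp16.dtype)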

@Zombero

Zombero commented Aug 20, 2023

@ptrblck

I was able to isolate this problem to PyTorch using the following test cases on an NVidia GTX 1660 Super machine:

(1) torch==2.0.1+cu117 torchvision==0.15.2+cu117
(2) torch==1.13.1+cu117 torchvision==0.14.1+cu117

Without "--no-half", scenario (1) produces NaNs and scenario (2) does not produce NaNs.

I've also tried doing scenario (1) with the current version of cudnn (8.9.4.25) but still got NaNs.

I would guess that this issue originates from PyTorch assuming that tensor cores are available whenever half precision is used, since 16xx cards are the only(?) case where that assumption is incorrect. Only a guess, though.
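
If anyone wants to poke at that guess, one crude experiment (a sketch only; I haven't verified it changes anything on a 16xx card) is to toggle cuDNN's benchmark-based algorithm selection and re-run the stacked-conv repro from earlier in this thread:

import torch

def fp16_convs_have_nan():
    x = torch.rand(1, 32, 64, 64).cuda().half()
    convs = torch.nn.Sequential(
        torch.nn.Conv2d(32, 64, 3),
        torch.nn.Conv2d(64, 128, 3),
        torch.nn.Conv2d(128, 256, 3),
    ).cuda().half()
    with torch.no_grad():
        return torch.isnan(convs(x)).any().item()

for benchmark in (False, True):
    torch.backends.cudnn.benchmark = benchmark  # lets cuDNN autotune and pick different kernels
    print(f"cudnn.benchmark={benchmark}: contains NaN = {fp16_convs_have_nan()}")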

@Zombero

Zombero commented Aug 22, 2023

@andreszs
After much research and fiddling, I believe I've arrived at the sad truth about these TU116 cards. Yes, they have dedicated FP16 CUDA cores, but only 128 of the card's 1408 cores are these dedicated FP16 cores.

I have found several ways around the blank images issue when not using "--no-half", but all of them increase the time it takes to generate images by at least 100%. This does have some use, as it reduces VRAM usage by almost 2GB during generation and may better enable you to play games and such during generation, but it's mostly bad. And if the game wants to use your fp16 cores, it may be bad for that, too.

Still, ideally, TU116 cards would be able to run without "--no-half", never produce NaNs in any configuration, and receive a small speed boost and VRAM saving compared to running with "--no-half", but I can see how mixing this small quantity of fp16 cores with a larger quantity of fp32 cores in an efficient way might be difficult for PyTorch and other programs to do...

@AlexeyVeselov

@Zombero

Could you please tell me how you managed to accomplish the fix?

I'm an unhappy 1660 user, and I'm willing to wait an extended period of time until I can afford a good video card if it will at least just work.

@Zombero

Zombero commented Jan 29, 2024

@AlexeyVeselov

Well, to be clear, the avenue I would most recommend is to use "--no-half" and stick to dimensions that the 1660 Super can handle (512x768 with a 2x upscale is an example that should work). But if you're looking to force it to use the 128 FP16 cores, then it's all about having the right torch and torchvision versions and then not using "--no-half". One combination that does this is:

torch==1.13.1+cu117 torchvision==0.14.1+cu117

You may be able to get newer versions of torch/torchvision to work without "--no-half" by manually updating your cudnn version, but you'd have to try it out.

For a more industrious option: I wound up getting a Tesla M40 as a second GPU, along with a cooling fan for it (since it doesn't come with one), for < $150 USD. You'd have to make sure your rig can support it, fit it, power it, etc. It was a pain to set up initially, but I haven't had any problems since then and can even generate images using both cards at once.
