
invalid device ordinal in pytorch1.8+cuda11.1 #54245

Closed
WZMIAOMIAO opened this issue Mar 18, 2021 · 18 comments
Labels
high priority; module: cuda (Related to torch.cuda, and CUDA support in general); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Milestone
1.8.1

Comments

@WZMIAOMIAO commented Mar 18, 2021

🐛 Bug

To Reproduce

Running test.py:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn


def main():
    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    init_img = torch.zeros((1, 3, 512, 512), device=device)
    target = [{"boxes": torch.as_tensor([[1, 2, 3, 4]], device=device),
               "labels": torch.as_tensor([1], device=device)}]
    model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False).to(device)
    model.train()
    model(init_img, target)


if __name__ == '__main__':
    main()

I get the following output:

using cuda:0 device.
Traceback (most recent call last):
  File "test.py", line 19, in <module>
    main()
  File "test.py", line 15, in main
    model(init_img, target)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py", line 97, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/detection/rpn.py", line 364, in forward
    loss_objectness, loss_rpn_box_reg = self.compute_loss(
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/detection/rpn.py", line 296, in compute_loss
    sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/detection/_utils.py", line 57, in __call__
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal

Expected behavior

There should be no error here.
With PyTorch 1.6/1.7 (GPU builds), the same code runs without error.

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0+cu111
  • OS (e.g., Linux): Centos7
  • How you installed PyTorch (conda, pip, source): pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
  • Build command you used (if compiling from source): No
  • Python version: 3.8
  • CUDA/cuDNN version: 11.1/8
  • GPU models and configuration: Tesla V100
  • Any other relevant information:

In addition, I tried the official docker image (pytorch/pytorch:1.8.0-cuda11.1-cudnn8-runtime), but I still encountered the same problem.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @ngimel

@agolynski added the module: cuda and triaged labels Mar 18, 2021
@ngimel (Collaborator) commented Mar 18, 2021

This is typically a symptom of hardware problems. What does nvidia-smi show? What is the output of torch.cuda.device_count()? Are you able to find any XIDs in dmesg, and could you post them here if so?
cc @ptrblck
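
For reference, a minimal sketch of collecting the torch-side information requested above (nvidia-smi and the dmesg XID check would still be run separately from a shell):

# Hedged diagnostic sketch: print the PyTorch build, the CUDA version it was
# built with, the number of visible devices, and each device's name and
# compute capability.
import torch

print("torch:", torch.__version__, "| CUDA build:", torch.version.cuda)
print("device_count:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i), torch.cuda.get_device_capability(i))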

@iandecks45

I have the same problem; screenshots of the error output and nvidia-smi output are attached below. The output of torch.cuda.device_count() is 1 in my case.

[Screenshot: error output]

[Screenshot: nvidia-smi output]

@ptrblck (Collaborator) commented Mar 18, 2021

I'm able to reproduce this issue using the 1.8.0 binaries.

However, the issue seems to be resolved by @zasdfgbnm's PR that replaces thrust with cub in randperm, which landed on March 12th.
To verify it, I used the nightly pip wheels with CUDA 11.1 from March 12th and compared them to the wheels from March 13th on a TitanV and a 2080Ti, using the code snippet from this issue:

$ pip install torch==1.9.0.dev20210312 torchvision==0.9.0.dev20210312  -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html

$ python -c "import torch; print(torch.__version__); import torchvision; print(torchvision.__version__)"
1.9.0.dev20210312+cu111
0.9.0.dev20210312+cu111

$ CUDA_VISIBLE_DEVICES=0 python lala.py
...
RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal

$ CUDA_VISIBLE_DEVICES=1 python lala.py
...
RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal

$ pip install torch==1.9.0.dev20210313 torchvision==0.9.0.dev20210313  -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html

$ python -c "import torch; print(torch.__version__); import torchvision; print(torchvision.__version__)"
1.9.0.dev20210313+cu111
0.9.0.dev20210313+cu111

$ CUDA_VISIBLE_DEVICES=0 python lala.py
using cuda:0 device.

$ CUDA_VISIBLE_DEVICES=1 python lala.py
using cuda:0 device.

@malfet would it be possible to pick this PR for 1.8.1?

@malfet (Contributor) commented Mar 18, 2021

@ptrblck there are a few more operators that use thrust::sort_by_key. Do they need to be fixed as well?

$ grep thrust::sort_by_key aten -R
aten/src/THC/generic/THCTensorIndex.cu:  thrust::sort_by_key(
aten/src/THC/generic/THCTensorMode.cu:  thrust::sort_by_key(
aten/src/ATen/native/cuda/Unique.cu:      thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
aten/src/ATen/native/cuda/EmbeddingBag.cu:      thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
aten/src/ATen/native/cuda/Embedding.cu:        thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
aten/src/ATen/native/cuda/Indexing.cu:      thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, ThrustLTOp<int64_t>());
aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu:  thrust::sort_by_key(policy,

@ptrblck (Collaborator) commented Mar 18, 2021

I don't know yet. I've run some unit tests on torch.unique with different sizes and couldn't hit an error yet (using the nightly, which still uses thrust for these ops). We can try to isolate the root cause of the initial error to understand whether the other mentioned methods could suffer from the same issue (and/or move them to CUB as well).
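
A rough sketch of that kind of size sweep (illustrative only, not the actual unit tests that were run):

# Illustrative sweep: call torch.unique on CUDA tensors of increasing size and
# synchronize so any asynchronous CUDA error surfaces immediately.
import torch

for n in (10_000, 100_000, 1_000_000):
    x = torch.randint(0, 1000, (n,), device="cuda")
    torch.unique(x, sorted=True)
    torch.cuda.synchronize()
print("torch.unique completed without errors")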

@zasdfgbnm (Collaborator)

I am not sure if this is a thrust problem or a build problem. I manually built PyTorch 1.8 and PyTorch master with my thrust->cub PR (#53841) reverted, and neither of them reproduces the failure. I also tried our 20.12 container and cannot reproduce it there either. I can only reproduce the error with the PyTorch wheels downloaded from the official site.

PS: minimal repro:

import torch
torch.randperm(159826, device='cuda')

@ngimel (Collaborator) commented Mar 19, 2021

I also can't repro with source builds. There's #52663, which looks very similar and reproduces with a master source build, but it also might be related to how extensions are built.

@malfet (Contributor) commented Mar 19, 2021

@zasdfgbnm what version of the CUDA toolkit (including the minor revision) are you using? And are you linking statically or not?

@zasdfgbnm (Collaborator)

@malfet I am using CUDA 11.2, linking dynamically. I also tried CUDA 11.1 in our 20.12 container, with PyTorch dynamically linked, and cannot repro there either.

@malfet (Contributor) commented Mar 19, 2021

@zasdfgbnm, can you paste the output of nvcc -V here?
This is the one used by CI:

# nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0

@zasdfgbnm (Collaborator)

This is what I have

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Jan_28_19:32:09_PST_2021
Cuda compilation tools, release 11.2, V11.2.142
Build cuda_11.2.r11.2/compiler.29558016_0

@SixK commented Mar 22, 2021

Same problem here on a GTX 1660 (I also get the error on a GTX 2080).
randperm works up to 29999 and crashes at 30000 and above.

import torch
zz = torch.randperm(29999, device="cuda:0")  # works
zz = torch.randperm(30000, device="cuda:0")  # raises the radix_sort error

@ngimel removed the triaged label Mar 22, 2021
@0x4f5da2 (Contributor)

I encountered the same issue:

Python 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.randperm(65535)
tensor([23965, 52215, 46436,  ..., 55898, 54460, 32672])
>>> torch.randperm(32, device=0)
tensor([ 4, 17, 19, 11, 16, 20,  5,  3, 21,  2, 31,  9,  6, 29, 27, 28,  0,  7,
        22, 15, 10, 12, 30,  8, 24, 26, 18,  1, 14, 13, 25, 23],
       device='cuda:0')
>>> torch.randperm(65535, device=0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: radix_sort: failed on 1st step: cudaErrorInvalidDevice: invalid device ordinal

Here is the environment:

Collecting environment information...
PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 9.1.85
GPU models and configuration: 
GPU 0: GeForce RTX 3090
GPU 1: GeForce RTX 3090

Nvidia driver version: 455.32.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.8.0
[pip3] torchaudio==0.8.0a0+a751e1d
[pip3] torchvision==0.9.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.1.1               h6406543_8    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2020.4             h726a3e6_304    conda-forge
[conda] mkl-service               2.3.0            py38h1e0a361_2    conda-forge
[conda] mkl_fft                   1.3.0            py38h5c078b8_1    conda-forge
[conda] mkl_random                1.2.0            py38hc5bc63f_1    conda-forge
[conda] numpy                     1.19.2           py38h54aff64_0  
[conda] numpy-base                1.19.2           py38hfa32c7d_0  
[conda] pytorch                   1.8.0           py3.8_cuda11.1_cudnn8.0.5_0    pytorch
[conda] torchaudio                0.8.0                      py38    pytorch
[conda] torchvision               0.9.0                py38_cu111    pytorch

@zou3519 added this to the 1.8.1 milestone Mar 22, 2021
@zou3519 added the triaged label and removed the triage review label Mar 22, 2021
@VoVAllen commented Mar 24, 2021

The DGL project also encountered the same problem before. The solution is for each library to use its own Thrust/CUB namespace.

The root cause is that CUB uses a static variable inside a template function for device-attribute caching (at https://github.com/NVIDIA/cub/blob/f5ef160af684fcd00c76443c42a393cae5653f2e/cub/util_device.cuh). gcc labels these symbols as UNIQUE, which makes every library share the same cache instead of using its own, since the UNIQUE binding breaks the RTLD_LOCAL setting.

 84822: 0000000005688500  1536 OBJECT  UNIQUE DEFAULT   30 _ZZN3cub26GetPerDeviceAttributeCacheINS_18PtxVersionCacheTagEEERNS_23PerDeviceAttributeCacheEvE5cache
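
For anyone who wants to check their own install, a hedged sketch of dumping this symbol from the installed wheel with readelf (the library path and name are assumptions and may differ between builds):

# Assumes a Linux pip/conda install that ships libtorch_cuda.so under torch/lib;
# prints any GetPerDeviceAttributeCache symbol entries, including their binding.
import os
import subprocess
import torch

lib = os.path.join(os.path.dirname(torch.__file__), "lib", "libtorch_cuda.so")
out = subprocess.run(["readelf", "-Ws", lib], capture_output=True, text=True).stdout
for line in out.splitlines():
    if "GetPerDeviceAttributeCache" in line:
        print(line)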

@ngimel (Collaborator) commented Mar 28, 2021

This issue should be fixed in the 1.8.1 release and in the nightly binaries.

@WZMIAOMIAO (Author)

Thank you for your great work. I see that 1.8.1 has been released on the official website.

@zhimengf

Is this really fixed? I still see the same issue with PyTorch 1.9.0.

@zhimengf

fixed here: #52663
