torch.utils.data.random_split crashes without an error message with non CPU Generator object #44714
Root cause: pytorch/aten/src/ATen/native/TensorFactories.cpp, lines 717 to 720 at 7e91728. I suppose that code should be modified to check (and allow?) a CUDA generator before creating the tensor.
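For concreteness, here is a Python-level rendering of the check being proposed; the real change would live in the C++ factory function, so treat randperm_checked as a purely hypothetical helper:

```python
import torch

def randperm_checked(n, generator=None, device='cpu'):
    # Hypothetical helper mirroring the proposed C++-side check: validate the
    # generator's device against the requested device before allocating the
    # result tensor, instead of crashing later without a message.
    device = torch.device(device)
    if generator is not None and generator.device.type != device.type:
        raise RuntimeError(
            f"Expected a '{device.type}' device type for generator "
            f"but found '{generator.device.type}'"
        )
    return torch.randperm(n, generator=generator, device=device)
```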
@ssnl It would be better to do the check inside
@ezyang It depends on whether we want to allow
@ssnl I just understood what your comment here meant. Let me try to elaborate on it for the benefit of @janeyx99. The most basic version of this bug that needs to be fixed is that we allow you to do this (see the sketch after this comment): However, there is a more subtle design consideration: I think that it is reasonable to support this, but I think it would be more complicated to do correctly, and we should just fix the first bug first.
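The snippet quoted in the comment above was not captured in this thread; a guess at the basic case, based on the issue description:

```python
import torch

# Basic version of the bug: a CUDA generator passed to a CPU-side randperm
# (which is what random_split does internally) crashes without a Python
# error message instead of raising a clear RuntimeError.
gen = torch.Generator(device='cuda')
torch.randperm(10, generator=gen)  # result would be a CPU tensor
```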
@ezyang Thank you for the clarification. When you say all other variants of the first case, is it when the generator is on a different device than the one specified, or is it for any
Oh ok, maybe the problem runs deeper than I thought. Some investigation sounds necessary :)
Update: After some investigation, I believe this is the reason the above happens: pytorch/aten/src/ATen/native/cuda/TensorFactories.cu, lines 91 to 95 at 7e91728.
In the CUDA implementation of randperm, small inputs are offloaded to the CPU implementation. To confirm this, the following code works as expected:
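The confirming snippet was also lost in this thread; a plausible reconstruction, assuming the small-input CPU offload and the ~30000-element threshold mentioned below:

```python
import torch

gen = torch.Generator(device='cuda')

# Large input: stays on the CUDA path, so the CUDA generator is usable.
x = torch.randperm(30003, generator=gen, device='cuda')  # works

# Small input: offloaded to the CPU implementation, which cannot consume
# a CUDA generator, so this is the failing case.
y = torch.randperm(3, generator=gen, device='cuda')
```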
So now the question is: do we still want to offload to the CPU for small inputs when a generator is defined?
Oh no... even if we do not offload to the CPU when a generator is defined, this means nondeterminism when the CUDA seed is set...
>>> torch.cuda.manual_seed_all(12)
>>> x = torch.randperm(3, device='cuda')
>>> x
tensor([2, 0, 1], device='cuda:0')
>>> torch.cuda.manual_seed_all(12)
>>> x = torch.randperm(3, device='cuda')
>>> x
tensor([0, 1, 2], device='cuda:0')
>>> torch.cuda.manual_seed_all(12)
>>> x = torch.randperm(30003, device='cuda')
>>> x[:5]
tensor([23025, 28065, 12737, 1352, 2876], device='cuda:0')
>>> torch.cuda.manual_seed_all(12)
>>> x = torch.randperm(30003, device='cuda')
>>> x[:5]
tensor([23025, 28065, 12737, 1352, 2876], device='cuda:0')
Hm, so what would be the right thing to do here? I'll submit a PR for a quick fix to the first issue, but I'm not quite sure how to handle the small-tensor-with-CUDA-generator case.
This seems tricky, and related to #46148. cc @mcarilli. I guess, hypothetically, because we currently store CUDA RNG state on the CPU, we could have some way of using the CUDA state to feed the CPU generator in this case. This would be hard to do if the CUDA RNG state lived entirely on the GPU, which seems like a better long-term end state. But then I think we're just out of luck, short of doing a sync to get the RNG state to the CPU. But maybe this is fine.
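A rough sketch of the "feed the CPU generator from CUDA state" idea, using only the existing public API; note this transfers only the seed, not the full Philox state (counter/offset), so it is not the faithful state transfer the comment is contemplating:

```python
import torch

# Assumption-laden sketch: derive a CPU generator from the CUDA seed so the
# small-input CPU fallback stays coupled to the CUDA RNG stream. This ignores
# the Philox offset entirely, so consecutive calls would not advance in sync.
cpu_gen = torch.Generator()
cpu_gen.manual_seed(torch.cuda.initial_seed())
x = torch.randperm(3, generator=cpu_gen)  # runs on CPU, seeded from CUDA state
```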
Silent offload to CPU for N < 30000 is also annoying for me right this second, because CPU work won't be captured by CUDA graphs...

There are calls to retrieve the relevant Philox state values from CUDA generators, which could seed a CPU generator, but only if the CPU generator also uses Philox. The GPU-state maintenance added by my PR provides the same interface to retrieve Philox state values. It syncs as needed under the hood, so usage doesn't need to change at all in the caller, but the fact that it needs to sync is annoying. It could simultaneously maintain dummy values on the CPU as well, and update them alongside the GPU-side state tensors, for sync-free retrieval of the state values... but that would break under CUDA graphs, because graph capture would elide the CPU-side maintenance of the dummy values.

What's the performance delta between CPU and GPU randperm for, say, 10000 elements? If the delta is negligible, just run on the GPU all the time. Even if the delta is significant, if the runtime is negligible to begin with, it might be worth running on the GPU all the time for simplicity.
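To answer the performance question empirically, a rough timing sketch (not a rigorous benchmark; numbers will vary by hardware):

```python
import time
import torch

def bench_randperm(n, device, iters=100):
    torch.randperm(n, device=device)  # warm-up (and CUDA context init)
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        torch.randperm(n, device=device)
    if device == 'cuda':
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters

for dev in ('cpu', 'cuda'):
    print(f"{dev}: {bench_randperm(10000, dev) * 1e6:.1f} us per call")
```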
Why did this issue get closed? I also ran into it today, specifying a non-CPU generator for random_split. Are there any known workarounds for users, at least? I could not find any in this issue.
@denmerichs What version of PyTorch are you running? If you don't have a mismatch between the generator device and the requested device, this is likely a different bug; please file a new issue for it.
Hi @ezyang, I'm hitting this:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py in __iter__(self)
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

It is worth noting that I had never seen this error before, and I ran the same code successfully before tonight...
@Kae1101 train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False, generator=torch.Generator(device='cuda'))
@NLQVan With torch.set_default_tensor_type(torch.cuda.FloatTensor), does dataiter.next() return a CPU FloatTensor by default? If it does, I think that is why the error was reported... But I'm still confused, because I ran the same code maybe more than 50 times a week before 2021/06/19 and didn't get any errors like this.
My code also ran fine before 19/06/2021; maybe something changed in the torch library without us knowing. I also tried to fix my code by commenting out the line torch.set_default_tensor_type(torch.cuda.FloatTensor), but then my model hit another error, something like "found 2 types of devices: cuda:0 and cpu"; the solution above fixed this error in my code.
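For readers landing here from a search: a reconstruction of the failure mode being discussed, assuming the 1.9.0 behavior described above:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# With a CUDA default tensor type, DataLoader's RandomSampler ends up mixing
# a CPU generator with CUDA-side tensors, tripping the device check added
# for this issue. Either drop this line or pass a matching generator.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

dataset = TensorDataset(torch.arange(10.0))
loader = DataLoader(dataset, batch_size=2, shuffle=True)
next(iter(loader))  # RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
```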
Thanks. I was having the same issue and downgrading to 1.8.1 (March release) from 1.9.0 (June release, current latest) fixes this. |
🐛 Bug
Non-CPU Generator objects cause torch.utils.data.random_split to fail without any error message.

To Reproduce
Steps to reproduce the behavior:
Create a torch.Generator on a non-CPU device and pass it as the generator argument of the torch.utils.data.random_split function; a sketch follows below.
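A minimal repro sketch, assuming a CUDA-capable machine (the original report's code block was not preserved):

```python
import torch
from torch.utils.data import random_split

# random_split forwards the generator to randperm; with a non-CPU generator
# this crashed the process (e.g. the Colab instance) with no error message.
dataset = torch.arange(10)
gen = torch.Generator(device='cuda')
train_set, val_set = random_split(dataset, [8, 2], generator=gen)
```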
Expected behavior

The device type of the Generator object either shouldn't affect
torch.utils.data.random_split
or an error message should be thrown.

Environment
PyTorch version: 1.6.0+cu101
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
Additional context
The above is from Google Colab (the instance crashed when I ran the test code), and I can confirm the issue is present on Windows as well.
cc @ezyang @gchanan @zou3519 @pbelevich