
replication_pad1d raising "CUDA error: invalid configuration argument" on large inputs #49601

Closed
jcaw opened this issue Dec 18, 2020 · 10 comments
Labels: module: cuda (Related to torch.cuda, and CUDA support in general), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

jcaw commented Dec 18, 2020

🐛 Bug

torchaudio.functional.compute_deltas is raising a CUDA error: invalid configuration argument when the batch size is too large. (Edit: the underlying issue comes from replication_pad1d.)

To Reproduce

Steps to reproduce the behavior:

  1. On my GTX 970, calling compute_deltas with a large enough spectrogram and batch size triggers this error, e.g. compute_deltas(torch.rand([64, 2, 1000, 1000], device="cuda")).

Going by other reports of this error, it is usually caused by exceeding a maximum CUDA launch dimension such as the block size (e.g. here). Since this limit varies by GPU, I'm not sure whether it will reproduce on newer cards.
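If a per-launch limit is indeed the cause, one possible workaround (a sketch of my own, not from this report; the function name and chunk size are arbitrary) is to split the batch dimension into smaller pieces before padding and concatenate the results:

```python
import torch
import torch.nn.functional as F

def replicate_pad_chunked(x, pad, chunk=32768):
    """Pad a (batch, channels, time) tensor in batch chunks so each
    kernel launch stays well under any per-dimension grid limit."""
    return torch.cat(
        [F.pad(piece, pad, mode="replicate") for piece in x.split(chunk, dim=0)]
    )

# Matches a single F.pad call on sizes that fit in one launch:
x = torch.rand(100, 2, 10)
out = replicate_pad_chunked(x, (3, 3), chunk=16)
```

Since replication padding is computed independently per row, padding chunk-by-chunk should produce bit-identical results to a single call.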

Examples:

>>> compute_deltas(torch.rand([64, 2, 1000, 1000], device="cuda"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jcaw/opt/miniconda3/envs/fastaudio-dev/lib/python3.7/site-packages/torchaudio/functional.py", line 1623, in compute_deltas
    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
  File "/home/jcaw/opt/miniconda3/envs/fastaudio-dev/lib/python3.7/site-packages/torch/nn/functional.py", line 3561, in _pad
    return torch._C._nn.replication_pad1d(input, pad)
RuntimeError: CUDA error: invalid configuration argument

This appears to be distinct from an outright out-of-memory error. If I push the parameters further:

>>> compute_deltas(torch.rand([64, 2, 2000, 2000], device="cuda"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jcaw/opt/miniconda3/envs/fastaudio-dev/lib/python3.7/site-packages/torchaudio/functional.py", line 1623, in compute_deltas
    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
  File "/home/jcaw/opt/miniconda3/envs/fastaudio-dev/lib/python3.7/site-packages/torch/nn/functional.py", line 3561, in _pad
    return torch._C._nn.replication_pad1d(input, pad)
RuntimeError: CUDA out of memory. Tried to allocate 1.91 GiB (GPU 0; 3.95 GiB total capacity; 2.15 GiB already allocated; 518.06 MiB free; 2.39 GiB reserved in total by PyTorch)

Reducing the size of the input tensor solves the issue:

>>> compute_deltas(torch.rand([32, 2, 1000, 1000], device="cuda"))
tensor(...)
>>> compute_deltas(torch.rand([64, 2, 500, 1000], device="cuda"))
tensor(...)

Expected behavior

compute_deltas should return successfully (or explicitly produce an out-of-memory error, if that's the real issue).

Environment

PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Manjaro Linux (x86_64)
GCC version: (GCC) 10.2.0
Clang version: 11.0.0
CMake version: version 3.19.1

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 970
Nvidia driver version: 455.45.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.4
[pip3] torch==1.7.0
[pip3] torch-audiomentations==0.4.0
[pip3] torchaudio==0.7.0
[pip3] torchvision==0.8.1
[conda] blas                      1.0                         mkl    anaconda
[conda] mkl                       2020.2                      256    anaconda
[conda] mkl-service               2.3.0            py37he8ac12f_0  
[conda] mkl_fft                   1.2.0            py37h23d657b_0  
[conda] mkl_random                1.1.1            py37h0573a6f_0    anaconda
[conda] numpy                     1.19.4                   pypi_0    pypi
[conda] numpy-base                1.19.2           py37hfa32c7d_0  
[conda] torch                     1.7.0                    pypi_0    pypi
[conda] torch-audiomentations     0.4.0                    pypi_0    pypi
[conda] torchaudio                0.7.0                    pypi_0    pypi
[conda] torchvision               0.8.1                    pypi_0    pypi

Additional context

It appears to be CUDA-specific. I can't trigger this error on CPU.
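To illustrate (my own sketch, using the 3D (batch, channels, time) layout that F.pad's replicate mode accepts): the 64 x 2 x 1000 leading dimensions of the failing example flatten to 128000 rows, and a tensor with that same extreme leading dimension pads fine on CPU, which points at the CUDA launch configuration rather than the operator's semantics.

```python
import torch
import torch.nn.functional as F

# 128000 rows in the flattened dimension: fails on CUDA (pre-fix),
# but the CPU path handles it without complaint.
x = torch.rand(1, 128000, 8)
out = F.pad(x, (3, 3), mode="replicate")  # left/right edges replicated
assert out.shape == (1, 128000, 14)
```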

cc @ngimel

mthrok (Contributor) commented Dec 18, 2020

Hi @jcaw

Thanks for the report and the detailed description. It looks like the fix should be applied in PyTorch, so I will move the issue there.

@mthrok mthrok changed the title compute_deltas raising "CUDA error: invalid configuration argument" on large inputs replication_pad1d raising "CUDA error: invalid configuration argument" on large inputs Dec 18, 2020
@mthrok mthrok transferred this issue from pytorch/audio Dec 18, 2020
@VitalyFedyunin VitalyFedyunin added the module: cuda and triaged labels Dec 20, 2020
ngimel (Collaborator) commented Dec 20, 2020

Can you please provide a reproducer with replication_pad1d instead of compute_deltas?

jcaw (Author) commented Jan 7, 2021

Sure. It seems that compute_deltas packs all leading dimensions into the freq dimension, e.g. (batch, channels, freq, time) -> (batch * channels * freq, time). This creates a tensor that's very large in one dimension (rather than just large overall), which seems to be the issue.

This will trigger it on my 970:

>>> torch._C._nn.replication_pad1d(torch.rand([100000, 1000], device="cuda"), 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-9ed6fa8428fe> in <module>
----> 1 torch._C._nn.replication_pad1d(torch.rand([100000, 1000], device="cuda"), 3)

RuntimeError: CUDA error: invalid configuration argument

Pushing it further will eventually cause OOM:

>>> torch._C._nn.replication_pad1d(torch.rand([500000, 1000], device="cuda"), 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-6c910ea7d7c8> in <module>
----> 1 torch._C._nn.replication_pad1d(torch.rand([500000, 1000], device="cuda"), 3)

RuntimeError: CUDA out of memory. Tried to allocate 1.86 GiB (GPU 0; 3.95 GiB total capacity; 381.70 MiB already allocated; 1.17 GiB free; 1.86 GiB reserved in total by PyTorch)

If I balance the dimensions, there's no issue:

>>> torch._C._nn.replication_pad1d(torch.rand([10000, 10000], device="cuda"), 3)
tensor([[0.8374, 0.8374, 0.8374,  ..., 0.9226, 0.9226, 0.9226],
        [0.8571, 0.8571, 0.8571,  ..., 0.2462, 0.2462, 0.2462],
        [0.0252, 0.0252, 0.0252,  ..., 0.7778, 0.7778, 0.7778],
        ...,
        [0.0578, 0.0578, 0.0578,  ..., 0.3262, 0.3262, 0.3262],
        [0.4410, 0.4410, 0.4410,  ..., 0.1778, 0.1778, 0.1778],
        [0.1558, 0.1558, 0.1558,  ..., 0.4674, 0.4674, 0.4674]],
       device='cuda:0')
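The flattening described above can be sketched as follows (a simplified sketch; the exact reshape inside torchaudio's compute_deltas may differ slightly):

```python
import torch

# compute_deltas collapses all leading dimensions before padding,
# so even a modest 4D input becomes extremely long in one dimension:
specgram = torch.rand(64, 2, 1000, 10)             # (batch, channel, freq, time)
flat = specgram.reshape(1, -1, specgram.size(-1))  # time stays, the rest collapses
assert flat.shape == (1, 128000, 10)               # 64 * 2 * 1000 = 128000 rows
```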

jcaw (Author) commented Jan 7, 2021

Ok, a bit more digging. This triggers the error:

>>> torch._C._nn.replication_pad1d(torch.rand([65536, 1], device="cuda"), 3)
...
RuntimeError: CUDA error: invalid configuration argument

But 65535 is fine:

>>> torch._C._nn.replication_pad1d(torch.rand([65535, 1], device="cuda"), 3)
tensor([[0.9838, 0.9838, 0.9838,  ..., 0.9838, 0.9838, 0.9838],
        ...,
        [0.8009, 0.8009, 0.8009,  ..., 0.8009, 0.8009, 0.8009]],
       device='cuda:0')

Is something exceeding a 16 bit limit?

Interestingly, if I change the order of the dimensions things also work fine, with much higher numbers:

>>> torch._C._nn.replication_pad1d(torch.rand([1, 10000000], device="cuda"), 3)
tensor([[0.9373, 0.9373, 0.9373,  ..., 0.4218, 0.4218, 0.4218]],
       device='cuda:0')
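For what it's worth, the 65535/65536 boundary lines up with the CUDA hardware limit on gridDim.y and gridDim.z (65535; the gridDim.x limit is far larger), rather than a block-size or 16-bit arithmetic limit. My reading, not confirmed in this thread, is that the kernel maps the flattened leading dimension onto one of those grid axes:

```python
import torch

CUDA_MAX_GRID_YZ = 65535  # hardware limit on gridDim.y and gridDim.z

def likely_to_trip_launch_limit(x):
    # Hypothesis: the padding kernel puts the leading (flattened
    # batch*channel) dimension on gridDim.y, so more than 65535
    # rows produces an invalid launch configuration.
    return x.size(0) > CUDA_MAX_GRID_YZ

assert likely_to_trip_launch_limit(torch.rand(65536, 1))
assert not likely_to_trip_launch_limit(torch.rand(65535, 1))
```

This would also explain why swapping the dimensions works: the long dimension then lands on the axis with the much higher limit.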

ngimel (Collaborator) commented Jan 7, 2021

Thanks for the reproduction! @xwang233 can you please look into this?

jcaw (Author) commented Jan 14, 2021

I've reproduced the same bug on a Tesla T4 on Google Colab, which might be easier to access.

xwang233 (Collaborator) commented Jan 15, 2021

@ngimel I have submitted a tentative fix at #50565.

As an aside, this CUDA launch configuration problem also exists for reflection padding.

cc @ptrblck

jwyyy commented Apr 21, 2021

@xwang233 I'm having the same issue with reflection padding. Has it been fixed?

xwang233 (Collaborator) commented
@jwyyy Please see #55222 and #56451. It will be fixed once those PRs are merged.

jwyyy commented Apr 21, 2021

@xwang233 Thanks! Waiting for the final merge.
