
Error when using Resampler on GPU with torchaudio 0.9 #1619

Closed
mravanelli opened this issue Jun 29, 2021 · 5 comments

Comments

@mravanelli

Hi,
it looks like there is a small issue when using the resampler with torchaudio 0.9.
I got the error below when running the following code:

import torch
import torchaudio
resampler = torchaudio.transforms.Resample(16000, 8000)
x = torch.rand(4, 32000, device='cuda:0')
resampler(x)
  File "<stdin>", line 1, in <module>
  File "/home/mila/r/ravanelm/anaconda3/envs/py19/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mila/r/ravanelm/anaconda3/envs/py19/lib/python3.9/site-packages/torchaudio/transforms.py", line 716, in forward
    return _apply_sinc_resample_kernel(
  File "/home/mila/r/ravanelm/anaconda3/envs/py19/lib/python3.9/site-packages/torchaudio/functional/functional.py", line 1387, in _apply_sinc_resample_kernel
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Note that everything is fine with cpu and with torchaudio 0.8.1

@yoyololicon
Collaborator

yoyololicon commented Jun 30, 2021

@mravanelli
You didn't move your resampler to the same device as the input.
The Resample module in 0.9 now caches the kernel weights internally. Using resampler = torchaudio.transforms.Resample(16000, 8000).to('cuda:0') should fix the problem.

@mravanelli
Author

Hi @yoyololicon,
thank you for your reply.
The issue is that with torchaudio 0.8 this error doesn't show up, so the change may break pre-existing code written against torchaudio 0.8. Wouldn't it be reasonable to put the kernel on the same device as the input signal (as was presumably done before)?

@yoyololicon
Collaborator

yoyololicon commented Jun 30, 2021

Hi @mravanelli
Yes, it was meant to be a BC-breaking change according to #1514, I'll just paste some of the statements here.

To maintain the original behavior would require moving the cached kernel to the device and dtype of the input waveform every time transforms.Resample is called, and this could result in unintended side effects: running CPU transforms on CUDA inputs would incur the overhead of constantly moving a cached CPU kernel to CUDA. The expectation is that the user will manually move the transform to the correct device and dtype (e.g. resample = transforms.Resample(...); resample = resample.to(device=torch.device('cuda'), dtype=torch.float16)). This PR removes the per-call move of the cached kernel to the input's device and dtype, so an error is now thrown if a user calls the transform on a CUDA waveform without first moving the transform to CUDA.

@mravanelli
Author

mravanelli commented Jun 30, 2021 via email

@mthrok
Collaborator

mthrok commented Jun 30, 2021

We added a kernel cache mechanism to transforms.Resample in 0.9.
The kernel is pre-computed when the transform is instantiated and has to be moved to the proper device explicitly. This follows the usual semantics of torch.nn.Module.
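As an illustration of that nn.Module convention, here is a toy module (hypothetical CachedKernel, not torchaudio code) whose registered buffer follows .to() the same way Resample's cached kernel does:

```python
import torch

class CachedKernel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers, like parameters, are moved by Module.to()/.cuda(),
        # which is how transforms are expected to track device/dtype.
        self.register_buffer("kernel", torch.hann_window(8))

    def forward(self, x):
        # Raises a device/dtype mismatch error if x and kernel disagree.
        return x * self.kernel

m = CachedKernel().to(torch.float64)
print(m.kernel.dtype)  # the buffer followed the module: torch.float64
```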

It looks like this was missing from the 0.9 release notes; I will fix them.

cc @carolineechen

@mthrok mthrok closed this as completed Jul 2, 2021