
stft does not consistently check window device #30865

Closed
seungwonpark opened this issue Dec 6, 2019 · 5 comments
Assignees
Labels
module: error checking (Bugs related to incorrect/lacking error checking), module: fft, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@seungwonpark
Contributor

seungwonpark commented Dec 6, 2019

🐛 Bug

Consider applying torch.stft to an audio tensor x that lives in CUDA memory, with a window tensor that is still on the CPU.
When n_fft is equal to win_length, this raises RuntimeError: expected device cuda:0 but got device cpu; when they differ, no error is raised.

To Reproduce

See my Google Colab notebook:
https://colab.research.google.com/drive/15ZOc5SnFwXsb3-vgIbzdd2roeY6kOV2H

import torch
import librosa

print(torch.__version__) # 1.3.1
print(librosa.__version__) # 0.6.3

x, sr = librosa.load(librosa.util.example_audio_file(), offset=15.0, duration=5.0)
print(x.shape) # (110250,)
x = torch.from_numpy(x).cuda()
window = torch.hann_window(window_length=400)

torch.stft(x, n_fft=512, hop_length=160, win_length=400, window=window).shape # torch.Size([257, 690, 2])

torch.stft(x, n_fft=400, hop_length=160, win_length=400, window=window).shape # RuntimeError: expected device cuda:0 but got device cpu
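
For reference, the same behavior reproduces without the librosa dependency (a minimal sketch, assuming a CUDA device is available; the random tensor stands in for the audio):

import torch

x = torch.randn(110250, device="cuda")         # stand-in for the audio signal
window = torch.hann_window(window_length=400)  # still on the CPU

# n_fft != win_length: succeeds on torch 1.3.1
torch.stft(x, n_fft=512, hop_length=160, win_length=400, window=window)

# n_fft == win_length: RuntimeError: expected device cuda:0 but got device cpu
torch.stft(x, n_fft=400, hop_length=160, win_length=400, window=window)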

Expected behavior

No error is expected here; both calls should behave consistently regardless of whether n_fft equals win_length.

cc @mruberry @peterbell10

@ssnl
Collaborator

ssnl commented Dec 6, 2019

Ah, the error message is indeed confusing. The issue is that your window is still on the CPU.
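
A quick way to confirm the mismatch is to print both devices before calling torch.stft:

print(x.device)       # cuda:0
print(window.device)  # cpu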

@ssnl ssnl changed the title torch.stft device RuntimeError when n_fft = win_length stft does not consistently check window device Dec 6, 2019
@ssnl ssnl added the module: error checking Bugs related to incorrect/lacking error checking label Dec 6, 2019
@gchanan gchanan added the module: operators and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Dec 6, 2019
@gchanan gchanan assigned gchanan and unassigned gchanan Dec 19, 2019
@NasirKhalid24

NasirKhalid24 commented Jul 22, 2020

Currently facing the same issue; is any fix planned?

The same issue seems to exist in torchaudio.transforms.MelSpectrogram
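
Since torchaudio's transforms are nn.Modules, a plausible workaround there is to move the whole transform to the GPU (a sketch, assuming MelSpectrogram registers its window as a module buffer, so .cuda() moves the window along with it):

import torch
import torchaudio

x = torch.randn(110250, device="cuda")
# Moving the module also moves its registered window buffer to the GPU.
mel = torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_fft=400, hop_length=160).cuda()
print(mel(x).shape)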

@LearnedVector

Bump. Any news on this issue? The same thing happens for me when n_fft == win_length.

@peterbell10
Collaborator

peterbell10 commented Oct 8, 2020

As of #43886, both calls now raise an error:

>>> torch.stft(x, n_fft=512, hop_length=160, win_length=400, window=window).shape
...
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

>>> torch.stft(x, n_fft=400, hop_length=160, win_length=400, window=window).shape
...
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I believe this is because of the diff here: https://github.com/pytorch/pytorch/pull/43886/files#diff-8189529460e7e7e3cbd932342fd61d76L243-R280.

The n_fft-length intermediate window_ was being created on the GPU because it used self.options(), but it now uses window.options() and so is created on the CPU. The error then happens (or doesn't) when the input is multiplied by window_.
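
For illustration, here is a rough Python sketch of that intermediate buffer (not the actual C++ implementation; pad_window is a hypothetical helper, and the centered layout is inferred from the diff above):

import torch

def pad_window(window, n_fft):
    # Center the win_length window inside a zero buffer of length n_fft.
    # Post-#43886 the buffer follows the window's device/dtype (window.options()),
    # so a CPU window yields a CPU buffer even when the input is on CUDA.
    left = (n_fft - window.size(0)) // 2
    window_ = torch.zeros(n_fft, device=window.device, dtype=window.dtype)
    window_[left:left + window.size(0)] = window
    return window_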

I'm not sure if the original device-moving behavior was actually intentional though. As far as I'm aware, most PyTorch functions require inputs to be on the same device. The fix from the user's side would be to pass in window.cuda().
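
For example, either create the window on the input's device or move it over explicitly before calling torch.stft (window.to(x.device) is the device-agnostic spelling):

window = torch.hann_window(window_length=400, device=x.device)
# or move an existing CPU window:
window = window.to(x.device)  # equivalent to window.cuda() here

torch.stft(x, n_fft=400, hop_length=160, win_length=400, window=window)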

@peterbell10 peterbell10 self-assigned this Oct 8, 2020
@mruberry
Collaborator

mruberry commented Oct 8, 2020

Thanks for the update @peterbell10, sounds like this issue has been fixed.

> I'm not sure if the original device-moving behavior was actually intentional though. As far as I'm aware, most PyTorch functions require inputs to be on the same device. The fix from the user's side would be to pass in window.cuda().

Right, we don't like to implicitly move tensors from one device to another. Data movement in PyTorch should always be explicit; as you point out, the user should call window.cuda() themselves.

@mruberry mruberry closed this as completed Oct 8, 2020