torchaudio.transforms.GriffinLim does not allow backpropagation #729

Closed
pplantinga opened this issue Jun 18, 2020 · 0 comments · Fixed by #730
🐛 Bug

I cannot backpropagate through GriffinLim due to in-place operations being used.

To Reproduce

Minimal example:

import torch
import torchaudio.transforms

# Random waveform and its power spectrogram (squared real/imag parts summed over the last dim)
x = torch.rand(1, 16000)
spec = torch.stft(x, n_fft=512, hop_length=160, win_length=320)
mag = spec.pow(2).sum(-1).transpose(1, -1)

# Simple model operating on the 257 frequency bins
model = torch.nn.Linear(257, 257)
out = model(mag).transpose(1, -1)

# Resynthesize a waveform from the model output with Griffin-Lim
resyn = torchaudio.transforms.GriffinLim(n_fft=512, hop_length=160, win_length=320)
wav = resyn(out)

# Backpropagate a reconstruction loss through GriffinLim back to the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss = torch.nn.functional.mse_loss(wav, x)
loss.backward()
optimizer.step()

Results in:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 257, 101, 2]], which is output 0 of SubBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

With the suggested torch.autograd.set_detect_anomaly(True) enabled, the last few frames of the call stack are:

  File "/network/home/plantinp/pytorch/lib/python3.7/site-packages/torchaudio/transforms.py", line 161, in forward
    self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
  File "/network/home/plantinp/pytorch/lib/python3.7/site-packages/torchaudio/functional.py", line 373, in griffinlim
    angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))
  File "/network/home/plantinp/pytorch/lib/python3.7/site-packages/torchaudio/functional.py", line 581, in complex_norm
    return torch.norm(complex_tensor, 2, -1)
  File "/network/home/plantinp/pytorch/lib/python3.7/site-packages/torch/functional.py", line 882, in norm
    return _VF.norm(input, p, _dim, keepdim=keepdim)
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:60)
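
For context, the anomaly trace points at the in-place div_/add_ calls in griffinlim: autograd saves intermediate tensors for the backward pass, and mutating one of them in place bumps its version counter, so backward() finds a stale version. A minimal sketch of the same failure mode, independent of torchaudio:

import torch

a = torch.rand(3, requires_grad=True)
b = a.exp()         # exp() saves its output for use in the backward pass
b.div_(2)           # in-place op bumps the version counter of that saved output
b.sum().backward()  # RuntimeError: ... modified by an inplace operation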

Expected behavior

Should allow backprop. Replacing the in-place ops on line 373 of functional.py with their out-of-place counterparts seems to do the trick:

angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles))
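
A rough way to confirm the fix (assuming a patched functional.py is installed) is to re-run the minimal example above and check that the gradient actually reaches the model:

# With the patched functional.py, backward() should complete and populate
# the Linear layer's gradients instead of raising.
loss = torch.nn.functional.mse_loss(wav, x)
loss.backward()
assert model.weight.grad is not None  # gradient flowed through GriffinLim
optimizer.step()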

Environment

  • PyTorch Version: 1.5
  • OS: Linux
  • How you installed PyTorch: pip
  • Python version: 3.8
  • CUDA/cuDNN version: 10.1