Backprop support for lfilter #704
Hi @FBMachine Can you provide a snippet that demonstrates the issue? I was trying to reproduce it, but could not. Or, are you saying that you would like to compute the gradient for the coefficients of `lfilter`? Here is my attempt:

```python
import torch
import torchaudio.functional as F

class Net(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.a = torch.nn.Parameter(a, requires_grad=True)
        self.b = torch.nn.Parameter(b, requires_grad=True)

    def forward(self, x):
        return F.lfilter(x, self.a, self.b)

def test(device, dtype):
    net = Net(
        a=torch.tensor([0., 0., 0., 1.]),
        b=torch.tensor([1., 0., 0., 0.]),
    ).to(device=device, dtype=dtype)
    x = torch.rand(2, 8000, dtype=dtype, device=device, requires_grad=True)
    y = net(x)
    net.zero_grad()
    y.backward(torch.randn_like(y))
    print('a_grad:', net.a.grad)
    print('b_grad:', net.b.grad)
    print('x_grad:', x.grad)

for device in ['cpu', 'cuda']:
    for dtype in [torch.float32, torch.float64]:
        print(f'Running {device}, {dtype}')
        test(device, dtype)
```

```
$ python foo.py
Running cpu, torch.float32
a_grad: None
b_grad: None
x_grad: tensor([[ 2.2886, 0.6447, -0.0231, ..., -2.0171, 0.3783, 3.4622],
[ 0.0278, -0.1908, -0.8077, ..., 0.9406, -0.0560, -0.6732]])
Running cpu, torch.float64
a_grad: None
b_grad: None
x_grad: tensor([[ 0.3589, -0.4542, 0.2553, ..., 0.0147, -1.3429, 0.7961],
[-1.2413, 1.7650, -0.3808, ..., 1.8582, -1.2257, -0.2102]],
dtype=torch.float64)
Running cuda, torch.float32
a_grad: None
b_grad: None
x_grad: tensor([[-0.8601, 1.1020, 0.7039, ..., 0.7219, -0.0040, -1.4189],
[ 1.6594, -0.5011, 1.3873, ..., 1.1267, 0.8386, 0.9974]],
device='cuda:0')
Running cuda, torch.float64
a_grad: None
b_grad: None
x_grad: tensor([[-0.5385, -1.4356, -0.9297, ..., -1.2368, -0.7705, -0.5666],
[-1.3023, 0.4728, -1.9034, ..., 0.7344, -0.2552, -1.9788]],
device='cuda:0', dtype=torch.float64)
```
Do we know that the derivative is correct though? A check we can do is with `gradcheck`.
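A minimal sketch of such a check (hypothetical coefficient values; assuming the `clamp` keyword of `lfilter`, and expected to fail until the in-place operation is removed):

```python
import torch
import torchaudio.functional as F
from torch.autograd import gradcheck

# Double precision and a short signal keep the finite-difference
# comparison inside gradcheck numerically stable.
x = torch.randn(1, 16, dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, 0.2, 0.1], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.5, 0.3, 0.2], dtype=torch.float64, requires_grad=True)

# clamp=False, since clamping is not differentiable at the boundaries.
gradcheck(lambda x, a, b: F.lfilter(x, a, b, clamp=False), (x, a, b))
```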
@mthrok Here is a small code example showing that you cannot backprop through an `lfilter`-based filter:

```python
import torch
import torchaudio

noise = torch.rand(16000)
fp = torch.tensor(440.0, requires_grad=True)
filtered_noise = torchaudio.functional.lowpass_biquad(noise, sample_rate=16000, cutoff_freq=fp)
dist = torch.mean(torch.abs(filtered_noise - noise))
dist.backward(retain_graph=False)
```

This gives a RuntimeError caused by the in-place operation inside the filter.
Hi @turian Thanks for the snippet. I confirm that I am seeing the same error. I updated my previous snippet to use `lowpass_biquad`:

```python
import torch
import torchaudio.functional as F

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Parameter so the cutoff frequency moves with .to() and is tracked
        self.fp = torch.nn.Parameter(torch.tensor(440.0))

    def forward(self, x):
        return F.lowpass_biquad(x, sample_rate=16000, cutoff_freq=self.fp)

def test(device, dtype):
    net = Net().to(device=device, dtype=dtype)
    x = torch.rand(2, 8000, dtype=dtype, device=device, requires_grad=True)
    y = net(x)
    net.zero_grad()
    y.backward(torch.randn_like(y))
    print('fp_grad:', net.fp.grad)
    print('x_grad:', x.grad)

for device in ['cpu', 'cuda']:
    for dtype in [torch.float32, torch.float64]:
        print(f'Running {device}, {dtype}')
        test(device, dtype)
```
@FBMachine what is the inplace operation in the latest master? Can you please copy and paste the offending lines here?
@FBMachine It appears nnAudio has a lowpass filter that is differentiable:
Hi folks~ I also encountered this issue recently and want to share my solution. Here's my implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import lfilter as torch_lfilter
from torch.autograd import Function, gradcheck

class lfilter(Function):
    @staticmethod
    def forward(ctx, x, a, b) -> torch.Tensor:
        with torch.no_grad():
            # Recursive (all-pole) part: filter x through 1/A(z)
            dummy = torch.zeros_like(a)
            dummy[0] = 1
            xh = torch_lfilter(x, a, dummy, False)
            # FIR (all-zero) part: causal convolution of xh with b
            y = xh.view(-1, 1, xh.shape[-1])
            y = F.pad(y, [b.numel() - 1, 0])
            y = F.conv1d(y, b.flip(0).view(1, 1, -1)).view(*xh.shape)
        ctx.save_for_backward(x, a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        x, a, b, xh = ctx.saved_tensors
        with torch.no_grad():
            # Backprop through the FIR part: correlate dy with b
            dxh = F.conv1d(F.pad(dy.view(-1, 1, dy.shape[-1]), [0, b.numel() - 1]),
                           b.view(1, 1, -1)).view(*dy.shape)
            # Backprop through the recursive part: run the filter time-reversed
            dummy = torch.zeros_like(a)
            dummy[0] = 1
            dx = torch_lfilter(dxh.flip(-1), a, dummy, False).flip(-1)
            # Gradient for the FIR coefficients b
            batch = x.numel() // x.shape[-1]
            db = F.conv1d(F.pad(xh.view(1, -1, xh.shape[-1]), [b.numel() - 1, 0]),
                          dy.view(-1, 1, dy.shape[-1]),
                          groups=batch).sum((0, 1)).flip(0)
            # Gradient for the recursive coefficients a
            dummy[0] = -1
            dxhda = torch_lfilter(F.pad(xh, [b.numel() - 1, 0]), a, dummy, False)
            da = F.conv1d(dxhda.view(1, -1, dxhda.shape[-1]),
                          dxh.view(-1, 1, dy.shape[-1]),
                          groups=batch).sum((0, 1)).flip(0)
        return dx, da, db
```

The filter form I chose is Direct Form II. Some comparisons against a simple for-loop approach, along with gradient checks: it has passed `gradcheck`.
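A quick, hypothetical way to exercise a custom `Function` like the one above (custom autograd functions are invoked through `.apply`):

```python
# Hypothetical smoke test for the custom Function above.
x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, 0.2, 0.1], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.5, 0.3, 0.2], dtype=torch.float64, requires_grad=True)

y = lfilter.apply(x, a, b)   # Function subclasses are called via .apply
y.sum().backward()
print(x.grad.shape, a.grad, b.grad)

# Numerical check of the hand-written backward against finite differences
assert gradcheck(lfilter.apply, (x, a, b))
```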
@FBMachine does it meet your requirement?
Thanks for writing this and sharing it with the community! If torchscriptability is not a concern, then this is a great way to bind the forward and the backward pass :) This is in fact how we (temporarily) bind the prototype RNN transducer here in torchaudio. Such custom autograd functions (both in Python and C++) are not currently supported by torchscript, though, so using this within torchaudio directly in place of the current `lfilter` would break torchscript support.
@vincentqb thanks, I'll take a look.
🚀 Feature
It is currently not possible to backpropagate gradients through an lfilter because of this inplace operation:
https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L661
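For illustration (not from the original report), a minimal example of how an in-place write breaks autograd when the modified tensor is needed for the backward pass:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()         # exp saves its output for use in the backward pass
y[0] = 0.0          # in-place write invalidates that saved output
y.sum().backward()  # RuntimeError: ...modified by an inplace operation
```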
Motivation
Without backprop support, `lfilter` isn't worth the PyTorch overhead (it's much faster when implemented with, e.g., numba). When I saw that it was implemented here, I was hoping to use it instead of my own implementation (a custom RNN), which is honestly too slow.
Pitch
I would love to see that inplace operation replaced with something that would allow supporting backprop. I'm not sure what the most efficient way to do this is.
Alternatives
I implemented transposed direct form II digital filters as custom RNNs, but the performance is pretty poor (which seems to be a problem with the fuser). This is the simplest version I tried, which works, but as I said it's quite slow.
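The snippet itself is not included above; as a rough, hypothetical sketch of the approach, a transposed Direct Form II filter written as an explicit loop might look like this (`tdf2_lfilter` is an illustrative name, not torchaudio API):

```python
import torch

def tdf2_lfilter(x, b, a):
    # Transposed Direct Form II as an explicit loop.
    # Assumes a[0] == 1 and len(a) == len(b) >= 2; 1-D input x.
    n = b.numel()
    state = [torch.zeros((), dtype=x.dtype) for _ in range(n - 1)]
    out = []
    for t in range(x.numel()):
        y = b[0] * x[t] + state[0]
        # Advance the delay line: each element mixes a feedforward tap (b)
        # with a feedback tap (a) of the fresh output sample.
        state = [b[i + 1] * x[t] + state[i + 1] - a[i + 1] * y
                 for i in range(n - 2)] + [b[n - 1] * x[t] - a[n - 1] * y]
        out.append(y)
    return torch.stack(out)
```

Because every sample is a separate node in the autograd graph, this is differentiable with respect to x, b, and a, but it is far slower than a fused kernel.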
Another alternative I've used, when I only need to backprop through the filter but not optimize the actual coefficients, is to take advantage of the fact that tanh is close to linear for very small inputs, and design a standard RNN to be equivalent to the digital filter. Crushing the input, then rescaling the output to keep it linear, gives a result very close to the original filter, but this is obviously quite a hack:
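A hypothetical sketch of that trick (the mapping from filter coefficients to RNN weights is omitted):

```python
import torch

# For |z| << 1, tanh(z) ≈ z, so a tanh RNN driven with a tiny scaled
# input behaves almost linearly. Weights would be set from the filter
# coefficients; that step is omitted here.
eps = 1e-4
rnn = torch.nn.RNN(input_size=1, hidden_size=4, nonlinearity='tanh', batch_first=True)

x = torch.randn(1, 1000, 1)
y, _ = rnn(x * eps)   # crush the input so tanh stays in its linear region
y = y / eps           # rescale the output to undo the crushing
```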