Backprop support for lfilter #704

Closed
daniel-p-gonzalez opened this issue Jun 7, 2020 · 10 comments

Comments

@daniel-p-gonzalez

🚀 Feature

It is currently not possible to backpropagate gradients through an lfilter because of this inplace operation:
https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L661

Motivation

Without backprop support, lfilter isn't worth the PyTorch overhead (it's much faster when implemented using e.g. numba). When I saw that it was implemented here, I was hoping to use it instead of my own implementation (a custom RNN), since my version is honestly too slow.

Pitch

I would love to see that inplace operation replaced with something that would allow supporting backprop. I'm not sure what the most efficient way to do this is.

Alternatives

I implemented transposed direct form II digital filters as custom RNNs, but the performance is pretty poor (which seems to be a problem with the fuser). This is the simplest version I tried, which works, but as I said it's quite slow.

class DigitalFilterModel(jit.ScriptModule):
  def __init__(self):
    super(DigitalFilterModel, self).__init__()

  @jit.script_method
  def forward(self, x, coeffs, v1, v2, v3):
    # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
    seq_len = x.shape[1]
    output = torch.jit.annotate(List[Tensor], [])
    x = x.unbind(1)
    coeffs = coeffs.unbind(1)
    for i in range(seq_len):
      sample = x[i]
      out = coeffs[0] * sample + v1
      output.append(out)

      v1 = coeffs[1] * sample - coeffs[4] * out + v2
      v2 = coeffs[2] * sample - coeffs[5] * out + v3
      v3 = coeffs[3] * sample - coeffs[6] * out

    return torch.stack(output, 1), v1, v2, v3
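
As a sanity check on the recurrence above, here is a minimal plain-Python sketch (not part of the original post) that runs the same third-order transposed direct form II update and compares it with the textbook difference equation. The function names and the example coefficients are made up for illustration, and a0 is assumed to be normalized to 1, as in the model above.

```python
def tdf2_filter(x, b, a):
    """Third-order transposed direct form II; b = [b0..b3], a = [a1..a3] (a0 assumed 1)."""
    v1 = v2 = v3 = 0.0
    y = []
    for sample in x:
        out = b[0] * sample + v1
        y.append(out)
        # Same state updates as in DigitalFilterModel.forward.
        v1 = b[1] * sample - a[0] * out + v2
        v2 = b[2] * sample - a[1] * out + v3
        v3 = b[3] * sample - a[2] * out
    return y


def direct_form_filter(x, b, a):
    """Reference: y[n] = sum_k b[k] x[n-k] - sum_{k>=1} a[k] y[n-k]."""
    y = []
    for n in range(len(x)):
        acc = sum(b[k] * x[n - k] for k in range(len(b)) if n - k >= 0)
        acc -= sum(a[k - 1] * y[n - k] for k in range(1, len(a) + 1) if n - k >= 0)
        y.append(acc)
    return y


# Arbitrary illustrative coefficients (hypothetical, not from the issue).
b = [0.2, 0.3, 0.1, 0.05]
a = [-0.5, 0.25, -0.1]
impulse = [1.0] + [0.0] * 7

y_tdf2 = tdf2_filter(impulse, b, a)
y_ref = direct_form_filter(impulse, b, a)
max_diff = max(abs(p - q) for p, q in zip(y_tdf2, y_ref))
print(max_diff)
```

Both forms give the same impulse response up to rounding, so the TorchScript loop can be validated against a reference like this before worrying about speed.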


Another alternative I've used when I only need to backprop through the filter, but not optimize the actual coefficients, is to take advantage of the fact that tanh is close to linear for very small inputs and design a standard RNN to be equivalent to the digital filter. Crushing the input, then rescaling the output to keep it linear gives a result very close to the original filter, but this is obviously quite a hack:

class RNNTDFWrapper(nn.Module):
  def __init__(self, eps=0.000000001):
    super(RNNTDFWrapper, self).__init__()
    self.eps = eps
    self.rnn = nn.RNN(1, 4, 1, False, True)

  def set_coefficients(self, coeffs):
    self.rnn.weight_ih_l0.data[:,:] = torch.tensor(coeffs[:4]).view(-1,1)
    self.rnn.weight_hh_l0.data[:,:] = 0.0
    self.rnn.weight_hh_l0.data[0,1] = 1.0
    self.rnn.weight_hh_l0.data[1,2] = 1.0
    self.rnn.weight_hh_l0.data[2,3] = 1.0
    self.rnn.weight_hh_l0.data[:3,0] = -1.0 * torch.tensor(coeffs[4:])

  def forward(self, x):
    batch_size = x.shape[0]
    x = self.eps * x.view(batch_size, -1, 1)
    x, _ = self.rnn.forward(x)
    x = (1.0/self.eps) * x[:,:,0]
    return x
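
The "crush, then rescale" trick relies on tanh(z) being approximately z for tiny z. A small plain-Python sketch (not from the original post) of that approximation, using the same eps as the wrapper's default; the sample values are arbitrary:

```python
import math

eps = 1e-9  # same default as RNNTDFWrapper above
samples = [-1.0, -0.25, 0.0, 0.5, 1.0]

# Crush the input, apply tanh, then rescale by 1/eps.
recovered = [math.tanh(eps * s) / eps for s in samples]
max_err = max(abs(r - s) for r, s in zip(recovered, samples))

# Without the crushing step, tanh visibly distorts the same inputs.
distorted_err = max(abs(math.tanh(s) - s) for s in samples)
print(max_err, distorted_err)
```

This runs in double precision. nn.RNN uses float32 by default, so in the wrapper the round trip is limited by float32 rounding rather than by the tanh nonlinearity.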
@mthrok
Collaborator

mthrok commented Jun 9, 2020

Hi @FBMachine

Can you provide a snippet that demonstrates lfilter not supporting backward?

I was trying to reproduce the issue, but putting lfilter into an nn.Module and backpropagating through it seems to work fine. I might be doing something wrong, though.

Or are you saying that you would like to compute gradients for the coefficients of lfilter?

import torch
import torchaudio.functional as F


class Net(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.a = torch.nn.Parameter(a, requires_grad=True)
        self.b = torch.nn.Parameter(b, requires_grad=True)

    def forward(self, x):
        return F.lfilter(x, self.a, self.b)


def test(device, dtype):
    net = Net(
        a=torch.tensor([0., 0., 0., 1.]),
        b=torch.tensor([1., 0., 0., 0.]),
    ).to(device=device, dtype=dtype)
    x = torch.rand(2, 8000, dtype=dtype, device=device, requires_grad=True)
    y = net(x)
    net.zero_grad()
    y.backward(torch.randn_like(y))
    print('a_grad:', net.a.grad)
    print('b_grad:', net.b.grad)
    print('x_grad:', x.grad)

for device in ['cpu', 'cuda']:
    for dtype in [torch.float32, torch.float64]:
        print(f'Running {device}, {dtype}')
        test(device, dtype)
$ python foo.py
Running cpu, torch.float32
a_grad: None
b_grad: None
x_grad: tensor([[ 2.2886,  0.6447, -0.0231,  ..., -2.0171,  0.3783,  3.4622],
        [ 0.0278, -0.1908, -0.8077,  ...,  0.9406, -0.0560, -0.6732]])
Running cpu, torch.float64
a_grad: None
b_grad: None
x_grad: tensor([[ 0.3589, -0.4542,  0.2553,  ...,  0.0147, -1.3429,  0.7961],
        [-1.2413,  1.7650, -0.3808,  ...,  1.8582, -1.2257, -0.2102]],
       dtype=torch.float64)
Running cuda, torch.float32
a_grad: None
b_grad: None
x_grad: tensor([[-0.8601,  1.1020,  0.7039,  ...,  0.7219, -0.0040, -1.4189],
        [ 1.6594, -0.5011,  1.3873,  ...,  1.1267,  0.8386,  0.9974]],
       device='cuda:0')
Running cuda, torch.float64
a_grad: None
b_grad: None
x_grad: tensor([[-0.5385, -1.4356, -0.9297,  ..., -1.2368, -0.7705, -0.5666],
        [-1.3023,  0.4728, -1.9034,  ...,  0.7344, -0.2552, -1.9788]],
       device='cuda:0', dtype=torch.float64)

@vincentqb
Contributor

Do we know that the derivative is correct though? A check we can do is with gradcheck.
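
A minimal sketch of what such a check could look like. To keep it self-contained it uses a hypothetical out-of-place loop filter (iir_filter, not torchaudio code) instead of torchaudio.functional.lfilter; the same gradcheck call should apply to any function that takes double-precision inputs with requires_grad=True.

```python
import torch
from torch.autograd import gradcheck


def iir_filter(x, a, b):
    """Out-of-place direct form: a[0]*y[n] = sum_k b[k] x[n-k] - sum_{k>=1} a[k] y[n-k]."""
    outputs = []
    for n in range(x.shape[-1]):
        acc = torch.zeros_like(x[..., 0])
        for k in range(b.numel()):
            if n - k >= 0:
                acc = acc + b[k] * x[..., n - k]
        for k in range(1, a.numel()):
            if n - k >= 0:
                acc = acc - a[k] * outputs[n - k]
        outputs.append(acc / a[0])
    return torch.stack(outputs, dim=-1)


torch.manual_seed(0)
# gradcheck compares analytical and finite-difference Jacobians, so use float64.
x = torch.randn(2, 16, dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, -0.5, 0.2], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.4, 0.3, 0.1], dtype=torch.float64, requires_grad=True)

ok = gradcheck(iir_filter, (x, a, b), eps=1e-6, atol=1e-4)
print("gradcheck passed:", ok)
```

gradcheck raises an error by default when the Jacobians disagree, so a silent pass covers x, a and b together.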

@turian

turian commented Oct 26, 2020

@mthrok Here is a small code example showing that you cannot backprop through an lfilter parameter:

import torch
import torchaudio
noise = torch.rand(16000)
fp = torch.tensor((440.0), requires_grad=True)
filtered_noise = torchaudio.functional.lowpass_biquad(noise, sample_rate=16000, cutoff_freq=fp)
dist = torch.mean(torch.abs(filtered_noise - noise))
dist.backward(retain_graph=False)

gives

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-c658d88b5d27> in <module>
----> 1 dist.backward(retain_graph=False)
      2 print(fp.grad)

/usr/local/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    183                 products. Defaults to ``False``.
    184         """
--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    186 
    187     def register_hook(self, hook):

/usr/local/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    123         retain_graph = create_graph
    124 
--> 125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
    127         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
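
The error only says that dist has no grad_fn, i.e. the graph from fp to the output was cut somewhere. One way that can happen, and this is an assumption rather than something confirmed in this thread, is when the filter coefficients are computed with Python's math module, which works on plain floats. A sketch contrasting that with coefficients built from torch ops (the helper name is hypothetical; the formulas are the standard RBJ cookbook low-pass biquad, with Q=0.707 as an assumed default):

```python
import math
import torch


def lowpass_biquad_coeffs(cutoff_freq, sample_rate, Q=0.707):
    """RBJ cookbook low-pass coefficients built only from torch ops (hypothetical helper)."""
    w0 = 2 * math.pi * cutoff_freq / sample_rate
    alpha = torch.sin(w0) / (2 * Q)
    cos_w0 = torch.cos(w0)
    b = torch.stack([(1 - cos_w0) / 2, 1 - cos_w0, (1 - cos_w0) / 2])
    a = torch.stack([1 + alpha, -2 * cos_w0, 1 - alpha])
    return a, b


fp = torch.tensor(440.0, requires_grad=True)
a, b = lowpass_biquad_coeffs(fp, sample_rate=16000)

# Contrast: math functions operate on plain floats, so the value leaves
# the autograd graph and nothing downstream can reach fp.
detached = math.sin((2 * math.pi * fp / 16000).item())

(a.sum() + b.sum()).backward()
print(type(detached).__name__, a.grad_fn is not None, fp.grad)
```

With torch ops the coefficients carry a grad_fn and fp receives a gradient, while the math-based value is an ordinary float.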

@mthrok
Collaborator

mthrok commented Oct 26, 2020

Hi @turian

Thanks for the snippet. I confirm that I am seeing the same error.

I updated my previous snippet to use the lowpass_biquad and realized that even the forward function does not work.

import torch
import torchaudio.functional as F


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fp = torch.tensor((440.0), requires_grad=True)

    def forward(self, x):
        return F.lowpass_biquad(x, sample_rate=16000, cutoff_freq=self.fp)


def test(device, dtype):
    net = Net().to(device=device, dtype=dtype)
    x = torch.rand(2, 8000, dtype=dtype, device=device, requires_grad=True)
    y = net(x)
    net.zero_grad()
    y.backward(torch.randn_like(y))
    print('a_grad:', net.a.grad)
    print('b_grad:', net.b.grad)
    print('x_grad:', x.grad)

for device in ['cpu', 'cuda']:
    for dtype in [torch.float32, torch.float64]:
        print(f'Running {device}, {dtype}')
        test(device, dtype)
Traceback (most recent call last):
  File "bar.py", line 27, in <module>
    test(device, dtype)
  File "bar.py", line 17, in test
    y = net(x)
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "bar.py", line 11, in forward
    return F.lowpass_biquad(x, sample_rate=16000, cutoff_freq=self.fp)
  File "/scratch/moto/torchaudio/torchaudio/functional.py", line 703, in lowpass_biquad
    return biquad(waveform, b0, b1, b2, a0, a1, a2)
  File "/scratch/moto/torchaudio/torchaudio/functional.py", line 636, in biquad
    output_waveform = lfilter(
  File "/scratch/moto/torchaudio/torchaudio/functional.py", line 594, in lfilter
    o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
RuntimeError: Output 0 of UnbindBackward is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
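
The same class of error can be reproduced without torchaudio. A minimal sketch (tensor names are made up): unbind returns several views of one tensor, and autograd refuses in-place writes to such views when the tensor is part of the graph. Building new tensors instead keeps the graph intact.

```python
import torch

coeff = torch.ones(3, requires_grad=True)
frames = coeff * torch.ones(2, 3)   # non-leaf tensor that is part of the graph
rows = frames.unbind(0)             # unbind returns several views of `frames`

try:
    rows[0].add_(1.0)               # in-place write to one of those views
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place failed:", str(err)[:60])

# Out-of-place alternative: build new tensors and stack them.
updated = torch.stack([rows[0] + 1.0, rows[1]], dim=0)
updated.sum().backward()
print(coeff.grad)
```

This is only an illustration of the mechanism; the actual fix in lfilter would need to replace the addmv_ call shown in the traceback with an out-of-place equivalent.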

@turian

turian commented Dec 4, 2020

@FBMachine what is the inplace operation in the latest master?

https://github.com/pytorch/audio/blob/5e54c770b41bbdb7b228fe511b364f3f2aa96a88/torchaudio/functional/__init__.py

Can you please copy and paste the offending lines here?

@turian

turian commented Dec 5, 2020

@FBMachine It appears nnAudio has a lowpass filter that is differentiable:

import nnAudio.utils
import torch
from torch.nn.functional import conv1d, fold
lowpass_filter = torch.tensor(nnAudio.utils.create_lowpass_filter(
                                                    band_center = 0.5,
                                                    kernelLength=256,
                                                    transitionBandwidth=0.001
                                                    )
                             )
lowpass_filter = lowpass_filter[None,None,:]
x = torch.rand(10000)[None,None,:]
y = conv1d(x,lowpass_filter,stride=1, padding=(lowpass_filter.shape[-1]-1)//2)
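
The snippet above builds the kernel outside of autograd, so the cutoff itself is not trainable there. As a rough sketch of the same FIR idea with a trainable cutoff, the kernel can be built from torch ops. This uses a Hann-windowed sinc design as a stand-in; it is not nnAudio's filter design, and the function name and parameter values are hypothetical.

```python
import torch
from torch.nn.functional import conv1d


def windowed_sinc_lowpass(cutoff, kernel_length=257):
    """Hann-windowed sinc low-pass kernel; `cutoff` is a fraction of the sampling rate (0 to 0.5)."""
    n = torch.arange(kernel_length, dtype=cutoff.dtype) - (kernel_length - 1) / 2
    ideal = 2 * cutoff * torch.sinc(2 * cutoff * n)  # ideal low-pass impulse response
    window = torch.hann_window(kernel_length, periodic=False, dtype=cutoff.dtype)
    return ideal * window


cutoff = torch.tensor(0.1, requires_grad=True)
kernel = windowed_sinc_lowpass(cutoff)[None, None, :]
x = torch.rand(1, 1, 10000)
y = conv1d(x, kernel, padding=(kernel.shape[-1] - 1) // 2)

loss = torch.mean(torch.abs(y - x))
loss.backward()
print(y.shape, cutoff.grad)
```

An odd kernel length keeps the output the same length as the input with symmetric padding. Note that this is an FIR approximation, so it does not replace a differentiable IIR lfilter.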

@yoyololicon
Collaborator

yoyololicon commented Jan 16, 2021

Hi folks~ I also ran into this issue recently and want to share my solution.
The approach I chose is to implement a custom autograd function for lfilter.

Here's my implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import lfilter as torch_lfilter

from torch.autograd import Function, gradcheck

class lfilter(Function):

    @staticmethod
    def forward(ctx, x, a, b) -> torch.Tensor:
        with torch.no_grad():
            dummy = torch.zeros_like(a)
            dummy[0] = 1

            xh = torch_lfilter(x, a, dummy, False)

            y = xh.view(-1, 1, xh.shape[-1])
            y = F.pad(y, [b.numel() - 1, 0])
            y = F.conv1d(y, b.flip(0).view(1, 1, -1)).view(*xh.shape)

        ctx.save_for_backward(x, a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        x, a, b, xh = ctx.saved_tensors
        with torch.no_grad():
            dxh = F.conv1d(F.pad(dy.view(-1, 1, dy.shape[-1]), [0, b.numel() - 1]),
                           b.view(1, 1, -1)).view(*dy.shape)

            dummy = torch.zeros_like(a)
            dummy[0] = 1
            dx = torch_lfilter(dxh.flip(-1), a, dummy, False).flip(-1)

            batch = x.numel() // x.shape[-1]
            db = F.conv1d(F.pad(xh.view(1, -1, xh.shape[-1]), [b.numel() - 1, 0]),
                          dy.view(-1, 1, dy.shape[-1]),
                          groups=batch).sum((0, 1)).flip(0)
            dummy[0] = -1
            dxhda = torch_lfilter(F.pad(xh, [b.numel() - 1, 0]), a, dummy, False)
            da = F.conv1d(dxhda.view(1, -1, dxhda.shape[-1]),
                          dxh.view(-1, 1, dy.shape[-1]),
                          groups=batch).sum((0, 1)).flip(0)

        return dx, da, db

The filter form I chose is Direct-Form-II.
I just wrap torchaudio.functional.lfilter inside the custom function, so no extra dependency is needed.

Some comparisons with a simple for-loop approach, plus gradient checks:
https://gist.github.com/yoyololicon/f63f601d62187562070a61377cec9bf8

It has passed gradcheck with a simple second-order filter model, and I'm planning to run more tests on higher-order models.
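
The dx line in the backward pass relies on an identity worth spelling out: for a linear time-invariant filter, the gradient with respect to the input is the same filter applied to the time-reversed upstream gradient, reversed again. A small sketch checking this against autograd, using a hypothetical out-of-place all-pole loop (all_pole) rather than torchaudio's lfilter, with a[0] assumed to be 1:

```python
import torch


def all_pole(x, a):
    """Out-of-place all-pole filter: y[n] = x[n] - sum_{k>=1} a[k] * y[n-k], a[0] assumed 1."""
    y = []
    for n in range(x.shape[-1]):
        acc = x[..., n]
        for k in range(1, a.numel()):
            if n - k >= 0:
                acc = acc - a[k] * y[n - k]
        y.append(acc)
    return torch.stack(y, dim=-1)


torch.manual_seed(0)
a = torch.tensor([1.0, -0.6, 0.2], dtype=torch.float64)
x = torch.randn(3, 32, dtype=torch.float64, requires_grad=True)
dy = torch.randn(3, 32, dtype=torch.float64)

# Gradient w.r.t. x as computed by autograd through the loop.
(dx_autograd,) = torch.autograd.grad(all_pole(x, a), x, grad_outputs=dy)

# Same gradient from filtering the time-reversed upstream gradient.
dx_filtered = all_pole(dy.flip(-1), a).flip(-1)

print(torch.allclose(dx_autograd, dx_filtered))
```

The identity holds because the filter acts as a lower-triangular Toeplitz matrix on the signal, and flipping both sides of such a matrix gives its transpose.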

@yoyololicon
Collaborator

@FBMachine does it meet your requirement?

@vincentqb
Contributor

Thanks for writing this and sharing it with the community! If TorchScript compatibility is not a concern, then this is a great way to bind the forward and the backward pass :) This is in fact how we (temporarily) bind the prototype RNN transducer here in torchaudio.

Such custom autograd functions (both in Python and C++) are not currently supported by TorchScript, though. Using this within torchaudio directly in place of the current lfilter (which is torchscriptable) would unfortunately break backward compatibility. In the long term, we'll need to register the backward pass with autograd. Here's a tutorial for how to do this in a torchscriptable manner.

@yoyololicon
Collaborator

@vincentqb thanks, I'll take a look.
