Add backprop support to lfilter #1310
Conversation
update to new version
solve c++ building issue
Hi @yoyololicon! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
Hi @yoyololicon, thanks for the contribution. This looks like a great improvement. Questions:
It passed the tests listed in
I think a gradient test on a second-order filter might be enough for now, since it's very common and is also the basis of biquad filters.
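As an illustration only (this is not the test added in the PR), a minimal sketch of such a second-order gradient check could use torch.autograd.gradcheck in double precision; the coefficients, shapes, and the use of the clamp flag below are assumptions for the sketch:

import torch
from torch.autograd import gradcheck
import torchaudio.functional as AF

torch.manual_seed(2434)

# Double precision keeps gradcheck's finite-difference comparison tight.
x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
# Biquad-like coefficients; a[0] stays at 1 and the poles are well inside the unit circle.
a = torch.tensor([1.0, -0.2, 0.1], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=torch.float64, requires_grad=True)

# clamp=False so the output clipping does not interfere with the numerical gradients.
gradcheck(lambda x, a, b: AF.lfilter(x, a, b, clamp=False), (x, a, b), eps=1e-6, atol=1e-4)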
A custom autograd function in the Python frontend would break TorchScript support.
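For context, a hedged sketch of the problem being referred to: a Python-side torch.autograd.Function (a trivial identity op here, purely hypothetical) typically cannot be compiled by torch.jit.script, which is why the backward pass ends up registered on the C++ side instead:

import torch

class MyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def run(x: torch.Tensor) -> torch.Tensor:
    return MyIdentity.apply(x)

run(torch.randn(3))        # works eagerly
try:
    torch.jit.script(run)  # scripting the call to .apply typically fails
except Exception as e:
    print("TorchScript error:", e)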
My initial idea was to make the source code more readable, but it seems like it can also bring some performance improvement. Below are the profiling results for this specific part of lfilter:
The new implementation reduces the number of function calls and makes the majority of the computation rely on a single conv1d call.

The profile script:

import torch
import torch.nn.functional as F
import torch.autograd.profiler as profiler
import torch.utils.benchmark as benchmark


def index_matmul(x, b):
    n_channel, n_sample = x.shape
    dtype = x.dtype
    device = x.device
    n_order = b.shape[0]
    # Build one index window per output sample, offset per channel, then gather
    # the windows with torch.take and reduce them with a matmul against b.
    window_idxs = torch.arange(n_sample - n_order + 1, device=device).unsqueeze(
        0) + torch.arange(n_order, device=device).unsqueeze(1)
    window_idxs = window_idxs.repeat(n_channel, 1, 1)
    window_idxs += (
        torch.arange(
            n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_sample
    )
    window_idxs = window_idxs.long()
    input_signal_windows = torch.matmul(
        b, torch.take(x, window_idxs)
    )
    return input_signal_windows


def conv1d(x, b):
    n_channel, n_sample = x.shape
    dtype = x.dtype
    device = x.device
    n_order = b.shape[0]
    # The same windowed dot product expressed as a single conv1d call.
    return F.conv1d(x.unsqueeze(1), b.view(1, 1, n_order)).squeeze(1)


if __name__ == '__main__':
    torch.random.manual_seed(2434)
    b = torch.rand(3)
    x = torch.randn(16, 44100)
    x /= x.abs().max()

    with profiler.profile(profile_memory=True) as prof:
        for _ in range(100):
            index_matmul(x, b)
    print("indexing + matmul")
    print(prof.key_averages().table(sort_by="cpu_time_total"))

    with profiler.profile(profile_memory=True) as prof:
        for _ in range(100):
            conv1d(x, b)
    print("conv1d")
    print(prof.key_averages().table(sort_by="cpu_time_total"))

    # Sanity check that both implementations produce the same output.
    y1 = index_matmul(x, b)
    y2 = conv1d(x, b)
    assert torch.allclose(y1, y2, atol=1e-7)
Thank you for this great contribution @yoyololicon! I'm taking this on to take some workload off of Moto :) There are so many good things in here, I'd suggest we split this into three pieces so we can land more quickly. a) and b) are fairly straightforward and should be quick to land; c) will need a bit more discussion due to numerical stability details etc. While you're doing this we might already have solved Windows support for the C++ extension, but your work is not blocked on it. Thanks!
See comment, please add me as the reviewer :)
Agree, but how do I add reviewers?
Should we create a separate request for each one?
I recommend opening separate PRs for a) and b). Then after a) and b) are done, you can either close this one or update it to c). For now this one has a good discussion and benchmark numbers, so let's leave it open.
Profiling results running on a P620 GPU with the same parameters:
The speed improvement is a lot more obvious.
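As a side note, a hedged sketch of how such a GPU comparison could also be timed with torch.utils.benchmark (imported in the profile script above but not used there); Timer handles warmup and CUDA synchronization, and this assumes the index_matmul and conv1d helpers from that script are in scope:

import torch
import torch.utils.benchmark as benchmark

device = "cuda"  # assumes a CUDA device is available
b = torch.rand(3, device=device)
x = torch.randn(16, 44100, device=device)

for name, fn in (("index_matmul", index_matmul), ("conv1d", conv1d)):
    timer = benchmark.Timer(
        stmt="fn(x, b)",
        globals={"fn": fn, "x": x, "b": b},
        label="lfilter FIR part",
        description=name,
    )
    print(timer.timeit(100))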
update to latest version
@cpuhrsch A demonstration test is added.
Aside from the scipy dependency question for generating the test asset (@mthrok), this looks good to go.
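For illustration only (not the asset generation used in the PR), a hedged sketch of the kind of scipy-based reference check being discussed, assuming double-precision inputs are supported and using made-up coefficients and tolerances:

import numpy as np
import torch
import torchaudio.functional as AF
from scipy import signal

x = np.random.RandomState(2434).randn(1, 1024)
b = np.array([0.4, 0.2, 0.9])
a = np.array([1.0, -0.2, 0.1])

# Reference output from scipy, then the torchaudio result on the same data.
expected = signal.lfilter(b, a, x)
result = AF.lfilter(torch.from_numpy(x), torch.from_numpy(a), torch.from_numpy(b), clamp=False)

assert np.allclose(result.numpy(), expected, atol=1e-8)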
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!
Hi @mthrok, the stability test case was updated. Could you help me review it?
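To make the idea concrete (this is not the test case in the PR), a hedged sketch of a stability-style check: filter with poles close to the unit circle in float32 and compare against a float64 reference of the same computation; the pole radius and signal length below are illustrative assumptions:

import math
import torch
import torchaudio.functional as AF

x = torch.randn(1, 4096)
# Complex-conjugate poles at radius 0.99 -> a long, slowly decaying impulse
# response that stresses the accumulation in the recursive part of the filter.
r, theta = 0.99, 0.3
a = torch.tensor([1.0, -2 * r * math.cos(theta), r * r])
b = torch.tensor([1.0, 0.0, 0.0])

y32 = AF.lfilter(x, a, b, clamp=False)
y64 = AF.lfilter(x.double(), a.double(), b.double(), clamp=False)
print("max abs difference:", (y32.double() - y64).abs().max().item())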
Hi @yoyololicon
Sorry for my late response. I was away from the laptop. Please refer to my comments regarding test readability and maintainability.
@yoyololicon Thanks for this great contribution!
This merge solves issue #704.
It moves the original Python implementation of lfilter into the C++ backend, and registers a custom autograd kernel to support TorchScript, as @vincentqb mentioned in #704. A simple test case is added to test whether the gradient is valid or not.
Notes
Some differences from the old lfilter:
audio/torchaudio/functional/filtering.py, line 881 in e83d557, is replaced by a single conv1d function call:
https://github.com/yoyololicon/audio/blob/4e2ff32b50d56ce168fcee872c95ffc6cde82eaa/torchaudio/csrc/lfilter.cpp#L123