add autograd to biquad filters #1400
Conversation
update to new version
Accommodate changes in lfilter
Rebase to master
torchaudio/functional/filtering.py
Outdated
dtype = waveform.dtype
device = waveform.device
central_freq = _float2Tensor(central_freq, dtype, device).squeeze()
Q = _float2Tensor(Q, dtype, device).squeeze()
Why the call to squeeze here each time?
Oh thanks, I found that it's not necessary; I will remove these lines.
My initial idea was to make sure that the arguments passed to `biquad` in the return statement are scalar tensors, so it stays jittable.
But actually, if the user passes a unit-length tensor (like `torch.Tensor([1.])`) as input and uses `torch.jit.script`, it will throw an error before `biquad` is even called.
output_waveform = lfilter(
    waveform,
    torch.tensor([a0, a1, a2], dtype=dtype, device=device),
    torch.tensor([b0, b1, b2], dtype=dtype, device=device),
    torch.cat([a0, a1, a2]),
Here you could also use torch.as_tensor and construct the array right away.
>>> torch.as_tensor([4, torch.tensor(2), 3], dtype=torch.float)
tensor([4., 2., 3.])
I tried this approach, but it didn't pass gradcheck, so I rolled back to using concatenation.
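The distinction here can be sketched in a few lines: `torch.cat` keeps the autograd graph connected to the coefficient tensors, while re-wrapping the values with `torch.tensor` (or `torch.as_tensor` over a list) creates a new leaf tensor that is detached from them, which is why gradcheck failed. A minimal illustration, using hypothetical one-element coefficient tensors `a0`, `a1`, `a2` standing in for biquad coefficients derived from differentiable parameters:

```python
import torch

# Hypothetical stand-ins for biquad coefficients computed from
# differentiable parameters such as central_freq and Q.
a0 = torch.tensor([1.0], requires_grad=True)
a1 = torch.tensor([0.5], requires_grad=True)
a2 = torch.tensor([0.25], requires_grad=True)

# torch.cat preserves the autograd graph, so gradients flow back
# to a0, a1, a2 through the concatenated coefficient vector.
coeffs = torch.cat([a0, a1, a2])
coeffs.sum().backward()
assert a0.grad is not None  # d(sum)/d(a0) == 1

# By contrast, torch.tensor([...]) copies the values into a fresh
# leaf tensor with no history, so no gradient reaches a0, a1, a2.
```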
Looks fine to me! @mthrok can you take another final look?
The change looks good and this is a very cool change. Thanks @yoyololicon
Couple of thoughts (for follow-up):
- Should we mention the support of autograd in the documentation for the functions that are tested?
- Can we introduce a helper method for the autograd test that handles the dtype, device, and requires_grad properties, so that each test method looks simpler? Example -> https://github.com/pytorch/audio/blob/e4a0bd2ceb6c80ec90194ce519bbd9589926034b/test/torchaudio_unittest/transforms/autograd_test_impl.py
- What about the second-order gradient, i.e. gradgradcheck?
Looks good, will work on it later if I have time.
The current implementation can't support higher-order derivatives (it still uses in-place operations internally).
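For context, a minimal sketch of what `gradcheck` versus `gradgradcheck` exercises, using a hypothetical toy function `scale` rather than the real filters (both checks expect double-precision inputs, per PyTorch's recommendation):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Hypothetical toy stand-in for a filter: multiply the waveform by a gain.
def scale(waveform, gain):
    return waveform * gain

waveform = torch.randn(8, dtype=torch.double, requires_grad=True)
gain = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

# First-order numerical-vs-analytical gradient check.
assert gradcheck(scale, (waveform, gain))

# Second-order check: this is what fails when the backward pass
# relies on in-place operations that break double backward.
assert gradgradcheck(scale, (waveform, gain))
```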
Implements the feature requested in #1375.
Affected functions:
biquad
band_biquad
bass_biquad
allpass_biquad
highpass_biquad
lowpass_biquad
bandpass_biquad
bandreject_biquad
equalizer_biquad
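To make the effect of this change concrete, here is a sketch of a differentiable biquad written in pure PyTorch with only out-of-place operations, so autograd can backpropagate through the recursion. This is an illustrative difference-equation loop, not torchaudio's actual implementation (which routes through `lfilter`):

```python
import torch

# Illustrative direct-form-I biquad; all ops are out-of-place so the
# autograd graph stays intact through the recursive filter state.
def biquad_sketch(x, b0, b1, b2, a0, a1, a2):
    b0, b1, b2 = b0 / a0, b1 / a0, b2 / a0
    a1, a2 = a1 / a0, a2 / a0
    outputs = []
    x1 = x2 = y1 = y2 = torch.zeros((), dtype=x.dtype)
    for n in range(x.shape[-1]):
        xn = x[..., n]
        yn = b0 * xn + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        x2, x1 = x1, xn   # shift delay lines by rebinding, not mutation
        y2, y1 = y1, yn
        outputs.append(yn)
    return torch.stack(outputs, dim=-1)

x = torch.randn(16, dtype=torch.double, requires_grad=True)
coeffs = [torch.tensor(c, dtype=torch.double) for c in (0.2, 0.2, 0.2, 1.0, -0.5, 0.1)]
y = biquad_sketch(x, *coeffs)
y.sum().backward()
assert x.grad is not None  # gradients flow back through the filter
```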