
add autograd to biquad filters #1400

Merged
merged 9 commits into pytorch:master from biquad-autograd on Mar 31, 2021

Conversation

@yoyololicon (Collaborator) commented Mar 19, 2021

Implements the feature requested in #1375.

Affected functions:

  • biquad
  • band_biquad
  • bass_biquad
  • allpass_biquad
  • highpass_biquad
  • lowpass_biquad
  • bandpass_biquad
  • bandreject_biquad
  • equalizer_biquad
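
A rough usage sketch of the new behavior (the function and argument names follow the public torchaudio API, but the values are illustrative and not taken from this PR): a filter parameter can now be a differentiable tensor, and gradients flow back to it.

import torch
import torchaudio.functional as F

waveform = torch.randn(1, 16000, dtype=torch.float64)
# The cutoff frequency is an ordinary tensor that tracks gradients.
cutoff_freq = torch.tensor(1000.0, dtype=torch.float64, requires_grad=True)

filtered = F.lowpass_biquad(waveform, sample_rate=16000, cutoff_freq=cutoff_freq)
filtered.sum().backward()
print(cutoff_freq.grad)  # gradient of the summed output w.r.t. the cutoff frequency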

dtype = waveform.dtype
device = waveform.device
central_freq = _float2Tensor(central_freq, dtype, device).squeeze()
Q = _float2Tensor(Q, dtype, device).squeeze()

Contributor

Why the call to squeeze here each time?

Collaborator Author

Oh thanks, I found that it's not necessary; I will remove these lines.
My initial idea was to make sure that the arguments passed to biquad in the return statement are scalar tensors, so that it stays jittable.
But actually, if the user passes a unit-length tensor (like torch.Tensor([1.])) as input and uses torch.jit.script, it will throw an error before biquad is even called.
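
For context, a minimal illustration of that last point (the function here is a toy, not biquad itself): a scripted function whose parameter is annotated as float rejects a Tensor argument at call time, before the body ever runs.

import torch

@torch.jit.script
def needs_float(q: float) -> float:
    return q * 2.0

needs_float(0.707)                  # OK
# needs_float(torch.tensor([1.]))   # RuntimeError: expected a value of type 'float'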

  output_waveform = lfilter(
      waveform,
-     torch.tensor([a0, a1, a2], dtype=dtype, device=device),
-     torch.tensor([b0, b1, b2], dtype=dtype, device=device),
+     torch.cat([a0, a1, a2]),

Contributor

Here you could also use torch.as_tensor and construct the array right away.

>>> torch.as_tensor([4, torch.tensor(2), 3], dtype=torch.float)
tensor([4., 2., 3.])

Collaborator Author

I tried this approach, but it didn't pass gradcheck, so I rolled back to using concatenation.
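
A minimal sketch of the difference being discussed (toy coefficients, not the PR's actual code): torch.cat keeps the coefficients attached to the autograd graph, whereas rebuilding a tensor from the raw values drops it, which is consistent with the gradcheck failure mentioned above.

import torch

a0 = torch.tensor(1.0, requires_grad=True)
a1 = torch.tensor(-0.5, requires_grad=True)
a2 = torch.tensor(0.25, requires_grad=True)

coeffs = torch.cat([a0.view(1), a1.view(1), a2.view(1)])
print(coeffs.requires_grad)   # True: gradients flow back to a0, a1, a2

rebuilt = torch.tensor([a0.item(), a1.item(), a2.item()])
print(rebuilt.requires_grad)  # False: the graph connection is gone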


@cpuhrsch (Contributor) left a comment

Looks fine to me! @mthrok can you take another final look?

@cpuhrsch cpuhrsch requested a review from mthrok March 31, 2021 16:39

@mthrok (Collaborator) left a comment

The change looks good, and it is a very cool addition. Thanks @yoyololicon!

A couple of thoughts (for follow-up):

  1. Should we mention the support of autograd in the documentation for the functions that are tested?
  2. Can we introduce a helper method for the autograd tests that handles the dtype, device and requires_grad properties, so that each test method looks simpler? (A rough sketch follows below.) Example -> https://github.com/pytorch/audio/blob/e4a0bd2ceb6c80ec90194ce519bbd9589926034b/test/torchaudio_unittest/transforms/autograd_test_impl.py
  3. What about the second-order gradient, i.e. gradgradcheck?
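
As a rough sketch of what such a test helper could look like (the name and shape here are hypothetical, not the actual torchaudio test utility): it moves the inputs to the requested dtype and device, enables gradient tracking, and runs the numerical gradient check.

import torch
from torch.autograd import gradcheck

def assert_grad(fn, inputs, dtype=torch.float64, device="cpu"):
    # Normalize dtype/device, enable gradients, then run gradcheck.
    prepared = []
    for x in inputs:
        x = torch.as_tensor(x, dtype=dtype, device=device)
        prepared.append(x.detach().requires_grad_(True))
    assert gradcheck(fn, tuple(prepared))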

@cpuhrsch cpuhrsch merged commit 52decd2 into pytorch:master Mar 31, 2021
@yoyololicon yoyololicon deleted the biquad-autograd branch March 31, 2021 23:50

@yoyololicon (Collaborator Author)

  2. Can we introduce a helper method for the autograd tests that handles the dtype, device and requires_grad properties, so that each test method looks simpler? Example -> https://github.com/pytorch/audio/blob/e4a0bd2ceb6c80ec90194ce519bbd9589926034b/test/torchaudio_unittest/transforms/autograd_test_impl.py

Looks good, will work on it later if I have time.

  3. What about the second-order gradient, i.e. gradgradcheck?

The current implementation can't support higher-order derivatives (it still uses in-place operations internally).
I'm not sure whether there is a need for it; personally, a first-order derivative is sufficient.
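
For reference, a toy illustration of what gradgradcheck would verify (unrelated to the filter code itself): it checks second-order gradients in the same way gradcheck checks first-order ones, and it is what the in-place operations mentioned above currently prevent for lfilter.

import torch
from torch.autograd import gradcheck, gradgradcheck

def f(x):
    return (x ** 3).sum()

x = torch.randn(4, dtype=torch.double, requires_grad=True)
assert gradcheck(f, (x,))       # first-order check, as discussed in this PR
assert gradgradcheck(f, (x,))   # second-order check, the follow-up suggestion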
