
Support higher order derivatives for F.lfilter #1441

Merged
merged 20 commits into pytorch:master from yoyololicon:lfilter-higher-order-gradient on May 6, 2021

Conversation

yoyololicon
Collaborator

@yoyololicon yoyololicon commented Apr 8, 2021

Purpose

Update (and partially refactor) the implementation of lfilter so that it can compute higher-order derivatives; the current implementation only supports first order.

This PR is a little complex, so I'll briefly explain its main idea.

Details

The current lfilter disables gradient tracking during the backward pass so that it can use in-place operations to compute the gradient efficiently:

at::AutoNonVariableTypeMode g;

But in order to support higher-order derivatives, the backward pass itself must be differentiable so that a new graph can be created from it.

Decomposition of lfilter

We can break the computation of lfilter down into two parts: a non-autoregressive part (called FIR below) and an autoregressive part (called IIR below).

The FIR part involves the parameter b_coeffs and can be done with a single conv1d call, which natively supports autograd; the IIR part involves the parameter a_coeffs and uses in-place operations to speed up the computation, so a custom autograd function is needed.
We will focus on the latter part and leave the former to PyTorch's native autograd mechanism.
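
To make the decomposition concrete, here is a minimal Python sketch of the two parts (the function names fir_part/iir_part/lfilter_sketch are illustrative only; the actual PR implements the IIR loop in C++ with in-place updates):

import torch
import torch.nn.functional as F

def fir_part(waveform, b_coeffs):
    # Non-autoregressive part: a plain causal convolution with b_coeffs.
    # conv1d natively supports autograd, so no custom function is needed here.
    n = b_coeffs.size(0)
    x = F.pad(waveform.unsqueeze(1), (n - 1, 0))
    return F.conv1d(x, b_coeffs.flip(0).view(1, 1, -1)).squeeze(1)

def iir_part(waveform, a_coeffs):
    # Autoregressive part: each output sample depends on previous outputs.
    # (The C++ kernel does this in-place for speed; building a list and
    # stacking keeps this sketch straightforwardly differentiable.)
    n = a_coeffs.size(0)
    outs = []
    for t in range(waveform.size(1)):
        acc = waveform[:, t]
        for k in range(1, min(n, t + 1)):
            acc = acc - a_coeffs[k] * outs[t - k]
        outs.append(acc)
    return torch.stack(outs, dim=1)

def lfilter_sketch(waveform, a_coeffs, b_coeffs):
    # Normalize by a_coeffs[0], run the FIR first, then the IIR.
    return iir_part(fir_part(waveform, b_coeffs / a_coeffs[0]),
                    a_coeffs / a_coeffs[0])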

Some Facts About Gradient of IIR

  1. The gradient with respect to the IIR input (waveform) equals the gradient with respect to the IIR output filtered by the same IIR (same a_coeffs) in the reversed time direction (verified numerically in the sketch after this list).
  2. The derivative of the IIR output with respect to a_coeffs equals the negated IIR output filtered by the same IIR (same a_coeffs).
  3. The gradient with respect to a_coeffs equals the result of (2) convolved with the gradient with respect to the IIR output, which can be done with conv1d.
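
Fact 1 is easy to check numerically; here is a sketch reusing the hypothetical iir_part above (facts 2 and 3 can be checked the same way):

import torch

T = 64
x = torch.randn(1, T, dtype=torch.double, requires_grad=True)
a = torch.tensor([1.0, -0.5, 0.2], dtype=torch.double)

y = iir_part(x, a)
grad_y = torch.randn_like(y)
(grad_x,) = torch.autograd.grad(y, x, grad_y)

# Fact 1: run the same IIR on grad_y reversed in time, then reverse back.
ref = iir_part(grad_y.flip(1), a).flip(1)
print(torch.allclose(grad_x, ref))  # True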

Recursive Differentiable Backward Pass

How do we make the backward pass of the IIR differentiable? Based on the previous part, the backward pass consists of two calls to the IIR and one call to conv1d. Because the IIR is now differentiable, its backward pass can simply call the IIR (i.e. itself) two more times with gradient tracking enabled. In this way the computational graph of the backward pass can be created.
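
Schematically, the idea looks like the Python sketch below. The real implementation is a C++ torch::autograd::Function; DifferentiableIIR and iir_kernel here are illustrative names, a_coeffs is assumed to be already normalized (a_coeffs[0] == 1), and iir_kernel stands in for the fast non-differentiable loop.

import torch
import torch.nn.functional as F

iir_kernel = iir_part  # stand-in for the fast in-place C++ loop

class DifferentiableIIR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, a_coeffs):
        with torch.no_grad():
            output = iir_kernel(waveform, a_coeffs)
        ctx.save_for_backward(a_coeffs, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        a_coeffs, output = ctx.saved_tensors
        n = a_coeffs.size(0)

        # Fact 1 -- gradient w.r.t. the input: filter grad_output with the
        # same IIR in reversed time.  Calling DifferentiableIIR.apply here
        # (rather than the raw kernel) is what makes this backward pass
        # itself differentiable and therefore enables double backward.
        grad_waveform = DifferentiableIIR.apply(
            grad_output.flip(1), a_coeffs).flip(1)

        # Facts 2 + 3 -- gradient w.r.t. a_coeffs: filter the negated output
        # with the same IIR, then correlate the result with grad_output.
        z = F.pad(DifferentiableIIR.apply(-output, a_coeffs), (n - 1, 0))
        grad_a = torch.stack(
            [torch.zeros((), dtype=z.dtype, device=z.device)]   # a[0] fixed to 1
            + [(grad_output * z[:, n - 1 - j: z.size(1) - j]).sum()
               for j in range(1, n)])
        return grad_waveform, grad_a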

Additional context

To verify the above changes, I added a gradgradcheck to the unit tests, and also lowered the size of the test inputs because the CI runtime has been getting longer and longer.
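
For reference, such a check can be written roughly like this (a sketch, not the exact test added in the PR; double precision and small inputs keep gradcheck fast and numerically stable):

import torch
from torchaudio.functional import lfilter

x = torch.randn(2, 32, dtype=torch.double, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=torch.double, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=torch.double, requires_grad=True)

fn = lambda x, a, b: lfilter(x, a, b, False)        # clamp=False keeps it smooth
assert torch.autograd.gradcheck(fn, (x, a, b))      # first-order derivatives
assert torch.autograd.gradgradcheck(fn, (x, a, b))  # second-order derivatives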

@mthrok
Collaborator

mthrok commented Apr 8, 2021

Hi @yoyololicon

Thanks for the PR. This is another nice addition. We really appreciate it!

Couple of thoughts at a glance

  1. Can we extract the change that shortens the test? That part can be merged immediately.
  2. n00b question: Does this change have any effect on the lower-precision gradient in the dtype=float32, long-input case? i.e. test_lfilter_9th_order_filter_stability

@mthrok
Collaborator

mthrok commented Apr 8, 2021

@cpuhrsch CR please.

@mthrok mthrok requested a review from cpuhrsch April 8, 2021 13:45
@yoyololicon
Collaborator Author

Hi @yoyololicon

Thanks for the PR. This is another nice addition. We really appreciate it!

Couple of thoughts at a glance

  1. Can we extract the change that shortens the test? That part can be merged immediately.

Sure.

  2. n00b question: Does this change have any effect on the lower-precision gradient in the dtype=float32, long-input case? i.e. test_lfilter_9th_order_filter_stability

This change doesn't address the stability issue; that should be a separate topic.

@mthrok
Collaborator

mthrok commented Apr 13, 2021

Hi @yoyololicon

What is the performance implication of this change?

I suspect that you made this PR because I asked about second-order autograd in the other PR.
We did not have a stance on autograd support, and we had an internal discussion
about how much of it we want to support.
We agreed that second-order differentiability is nice to have, but it is not something we enforce.

From the description, it sounds like this PR gets rid of the trick used for efficient computation.
If the performance degradation is very small, we can adopt this change,
but if it slows down the operation a lot, our preference is performance.

Sorry that you made this PR in good faith, probably because of what I said,
but this is uncharted territory for the domain library, so we do not have a well-established standard yet.

@yoyololicon
Collaborator Author

yoyololicon commented Apr 13, 2021

From the description, it sounds like this PR gets rid of the trick used for efficient computation.
If the performance degradation is very small, we can adopt this change,
but if it slows down the operation a lot, our preference is performance.

@mthrok
Oh, no need to worry: the computation is almost the same as before; it just reuses part of the function and still benefits from the efficient in-place operations (both forward and backward). What I did is mostly refactoring, with no extra cost.

Also, no need to apologize; I understand your situation.
In fact, I want to thank you: because of your request, I found some redundant parts of the code and was able to improve them in this PR (e.g. moving conv1d outside the custom function).

@mthrok
Collaborator

mthrok commented Apr 14, 2021

@yoyololicon

Also, no need to apologize; I understand your situation.
In fact, I want to thank you: because of your request, I found some redundant parts of the code and was able to improve them in this PR (e.g. moving conv1d outside the custom function).

Thank you for the nice comments.

Oh, no need to worry: the computation is almost the same as before; it just reuses part of the function and still benefits from the efficient in-place operations (both forward and backward). What I did is mostly refactoring, with no extra cost.

Can you provide a simple benchmark of the change?

@yoyololicon
Collaborator Author

@mthrok

Can you provide a simple benchmark of the change?

I benchmarked the forward and backward passes separately with different input sizes.

lfilter from master branch:

[-------------- IIR filter --------------]
                   |  forward  |  backward
1 threads: -------------------------------
      [32, 256]    |    323.4  |    986.2 
      [32, 1024]   |    737.8  |   2486.5 
      [32, 4096]   |   2533.6  |   9395.5 
      [64, 256]    |    467.3  |   1494.0 
      [64, 1024]   |   1230.0  |   4423.8 
      [64, 4096]   |   9838.7  |  26206.4 
      [128, 256]   |    712.9  |   2423.9 
      [128, 1024]  |   2516.4  |   8967.6 
      [128, 4096]  |  24567.0  |  53481.6 
2 threads: -------------------------------
      [32, 256]    |    284.4  |    837.2 
      [32, 1024]   |    594.7  |   1817.5 
      [32, 4096]   |   1802.1  |   6398.4 
      [64, 256]    |    394.1  |   1170.2 
      [64, 1024]   |    905.2  |   2993.2 
      [64, 4096]   |   9959.3  |  19893.5 
      [128, 256]   |    602.8  |   1831.6 
      [128, 1024]  |   1946.0  |   5930.6 
      [128, 4096]  |  24935.3  |  45665.2 
4 threads: -------------------------------
      [32, 256]    |    264.8  |    763.2 
      [32, 1024]   |    533.6  |   1536.4 
      [32, 4096]   |   1494.6  |   5097.3 
      [64, 256]    |    349.7  |    960.4 
      [64, 1024]   |    793.1  |   2269.9 
      [64, 4096]   |   8180.1  |  17459.9 
      [128, 256]   |    517.3  |   1461.8 
      [128, 1024]  |   1486.9  |   4418.9 
      [128, 4096]  |  25818.2  |  42742.5 

Times are in microseconds (us).

lfilter from this branch:

[-------------- IIR filter --------------]
                   |  forward  |  backward
1 threads: -------------------------------
      [32, 256]    |    352.5  |    1376.6
      [32, 1024]   |    724.9  |    4743.0
      [32, 4096]   |   2552.3  |   22188.3
      [64, 256]    |    502.5  |    2137.0
      [64, 1024]   |   1283.9  |    7077.9
      [64, 4096]   |   9073.3  |   54575.9
      [128, 256]   |    761.6  |    3629.3
      [128, 1024]  |   2492.8  |   17605.2
      [128, 4096]  |  24075.4  |  116302.5
2 threads: -------------------------------
      [32, 256]    |    314.3  |    1121.7
      [32, 1024]   |    598.7  |    3099.3
      [32, 4096]   |   1958.6  |   17688.8
      [64, 256]    |    427.2  |    1595.2
      [64, 1024]   |    932.7  |    5429.6
      [64, 4096]   |   8516.9  |   44660.9
      [128, 256]   |    628.8  |    2506.6
      [128, 1024]  |   1854.6  |   14419.8
      [128, 4096]  |  24809.8  |   95872.9
4 threads: -------------------------------
      [32, 256]    |    290.0  |     948.8
      [32, 1024]   |    527.4  |    2797.9
      [32, 4096]   |   1619.1  |   13063.2
      [64, 256]    |    373.6  |    1266.3
      [64, 1024]   |    814.0  |    3485.2
      [64, 4096]   |   8311.8  |   37644.0
      [128, 256]   |    540.6  |    1889.7
      [128, 1024]  |   1534.5  |   12953.7
      [128, 4096]  |  23322.9  |   89640.3

Times are in microseconds (us).

I think the difference in the forward pass is negligible; the extra overhead in the backward pass, I believe, comes from the padding operation at L136.

script
import torch
import torch.utils.benchmark as benchmark
from itertools import product


batch = 8
samples = 1024

a = torch.tensor([0.7, 0.2, 0.6], requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], requires_grad=True)

results = []
batches = [32, 64, 128]
samples = [256, 1024, 4096]
for batch, n in product(batches, samples):
    label = 'IIR filter'
    sub_label = f'[{batch}, {n}]'
    x = torch.randn(batch, n, requires_grad=True)

    for num_threads in [1, 2, 4]:
        results.append(benchmark.Timer(
            stmt='y = lfilter(x, a, b, False)',
            setup='from torchaudio.functional import lfilter',
            globals={'x': x, 'a': a, 'b': b},
            label=label,
            num_threads=num_threads,
            sub_label=sub_label,
            description='forward',
        ).blocked_autorange(min_run_time=1))

        results.append(benchmark.Timer(
            stmt='loss.backward(retain_graph=True)',
            setup='''\
            from torchaudio.functional import lfilter
            loss = lfilter(x, a, b, False).sum()
            ''',
            globals={'x': x, 'a': a, 'b': b},
            label=label,
            num_threads=num_threads,
            sub_label=sub_label,
            description='backward',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

@cpuhrsch
Contributor

@yoyololicon - sorry for the long response time. It looks like this PR roughly doubles the runtime of the backward pass. I think it's useful to have higher-order derivatives, but we should think of a mechanism to guard enabling them explicitly. I suppose we could add a flag to the forward pass that lets the user enable this, but that isn't the most principled approach. I wish we had an autograd guard or setting that restricts differentiation to a certain order. Let me talk to a few people and get back to you on how best to resolve this, but for now I don't think it's worth landing the performance regression in the current form.

@albanD

albanD commented Apr 29, 2021

Actually, such a guard already exists: you can use it during the backward pass to choose between the fast, non-differentiable implementation and the differentiable one.
You can use torch::autograd::compute_requires_grad(input_tensor1, input_tensor2, ...), passing all the inputs that can require gradients. It will tell you whether you need to do your computations in a differentiable manner.

You can look at this PR, which does a similar thing: https://github.com/pytorch/pytorch/pull/57189/files
It is different in that it makes the choice of preparing for backward during the forward pass, whereas here you want to make the choice during the backward pass to prepare for double backward, but the structure will look the same.
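
In Python terms, the pattern looks roughly like the toy sketch below (torch::autograd::compute_requires_grad is a C++ helper; torch.is_grad_enabled() plus the requires_grad checks only play its role here, and MySquare is just a stand-in op, not part of the PR):

import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Rough Python analogue of torch::autograd::compute_requires_grad(...):
        # only take the differentiable path when a graph is actually being
        # built, i.e. backward()/grad() was called with create_graph=True.
        if torch.is_grad_enabled() and (x.requires_grad or grad_output.requires_grad):
            return 2 * x * grad_output          # differentiable path
        with torch.no_grad():
            return grad_output.mul(2).mul_(x)   # fast path, in-place is fine here

x = torch.randn(4, dtype=torch.double, requires_grad=True)
y = MySquare.apply(x).sum()
(g,) = torch.autograd.grad(y, x, create_graph=True)  # takes the differentiable path
g.sum().backward()
print(x.grad)  # second derivative of x**2: all 2.0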

@cpuhrsch
Contributor

Thanks for the note @albanD!

@yoyololicon, looks like this way we can have both! Do you want to give this a try?

@yoyololicon
Collaborator Author

@albanD Thanks for the information.
I'll try it and see whether it works for our case.

@yoyololicon
Collaborator Author

yoyololicon commented Apr 30, 2021

@cpuhrsch
Well, I found that the overhead actually comes from the FIR part...

If the FIR is also written as a custom function (just like the previous implementation did), the runtime becomes comparable.

[-------------- IIR filter --------------]
                   |  forward  |  backward
1 threads: -------------------------------
      [32, 256]    |    336.0  |   1087.4 
      [32, 1024]   |    709.5  |   2529.8 
      [32, 4096]   |   2516.6  |   9510.0 
      [64, 256]    |    476.0  |   1588.5 
      [64, 1024]   |   1219.3  |   4523.5 
      [64, 4096]   |   8257.7  |  23168.5 
      [128, 256]   |    714.8  |   2519.6 
      [128, 1024]  |   2403.4  |   9680.2 
      [128, 4096]  |  22925.3  |  53460.4 
2 threads: -------------------------------
      [32, 256]    |    289.1  |    915.1 
      [32, 1024]   |    583.7  |   1903.2 
      [32, 4096]   |   1873.6  |   6947.1 
      [64, 256]    |    411.6  |   1284.4 
      [64, 1024]   |    917.1  |   3062.7 
      [64, 4096]   |   7129.9  |  17641.1 
      [128, 256]   |    586.6  |   1877.0 
      [128, 1024]  |   1742.2  |   6908.3 
      [128, 4096]  |  23213.0  |  44310.6 
4 threads: -------------------------------
      [32, 256]    |    271.1  |    850.2 
      [32, 1024]   |    509.5  |   1592.7 
      [32, 4096]   |   1690.7  |   5599.8 
      [64, 256]    |    346.3  |   1083.4 
      [64, 1024]   |    787.5  |   2426.1 
      [64, 4096]   |   6957.4  |  16782.5 
      [128, 256]   |    514.2  |   1573.1 
      [128, 1024]  |   1435.7  |   5741.8 
      [128, 4096]  |  24481.0  |  42195.3 

Times are in microseconds (us).

The lfilter function would look like this:

torch::Tensor lfilter_core(
    const torch::Tensor& waveform,
    const torch::Tensor& a_coeffs,
    const torch::Tensor& b_coeffs) {
  TORCH_CHECK(waveform.device() == a_coeffs.device());
  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));

  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);

  int64_t n_order = b_coeffs.size(0);

  TORCH_INTERNAL_ASSERT(n_order > 0);

  auto filtered_waveform = DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);

  auto output =
      DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
  return output;
}

But this looks a little redundant to me, because the FIR is really just F.conv1d and doesn't need to be written as a custom function.
I prefer to let PyTorch handle its own autograd as much as possible.

@cpuhrsch
Contributor

cpuhrsch commented May 3, 2021

@albanD - Could I ask you to take a quick look again?


@albanD albanD left a comment

Looks good.

at::AutoNonVariableTypeMode should not be used outside of internal code though... so this is not great. But if that's the current state, we can move the discussion about removing it to a separate issue.

Contributor

@cpuhrsch cpuhrsch left a comment

Looks good to me! Let's wait for another review from Moto and for the unit tests to run green (the macOS failures are unrelated and we might even defer them for this merge).

@cpuhrsch cpuhrsch requested a review from mthrok May 5, 2021 04:27
@mthrok
Collaborator

mthrok commented May 5, 2021

I am learning from this PR more than reviewing it, and nothing looks suspicious, so I think it's good.

@mthrok
Collaborator

mthrok commented May 5, 2021

Let me see what I can do about the macOS situation.

@mthrok
Collaborator

mthrok commented May 5, 2021

@yoyololicon I fixed the macOS build issue in #1485. Can you rebase? Some tests are still failing on macOS, but as long as they are not related to lfilter we can safely merge this.

@mthrok mthrok merged commit 723e9a5 into pytorch:master May 6, 2021
@mthrok
Collaborator

mthrok commented May 6, 2021

@yoyololicon Thank you so much for the contribution!

@yoyololicon yoyololicon deleted the lfilter-higher-order-gradient branch May 6, 2021 23:38