
[ONNX] Support opset 17 operators #81075

Closed
3 of 33 tasks
justinchuby opened this issue Jul 8, 2022 · 29 comments
Assignees
Labels
good first issue module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team OSS contribution wanted PR from open source contributors welcome to solve this issue. triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@justinchuby
Collaborator

justinchuby commented Jul 8, 2022

Tracking issue on the onnx opset 17 support. Contributions welcome.

Closes #80834

@cpuhrsch cpuhrsch added module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jul 8, 2022
@titaiwangms titaiwangms added the onnx-triaged triaged by ONNX team label Jul 11, 2022
@justinchuby justinchuby self-assigned this Jul 14, 2022
@iver56

iver56 commented Jul 22, 2022

Justin, you listed istft here, but I can't find istft in ONNX's list of operators. Is the idea to implement istft in terms of the available onnx operators?

Edit: This PR also lists ISTFT in its description, but there's no mention of ISTFT in the code changes. Now I'm confused. Does ONNX opset 17 support ISTFT or not?

@justinchuby
Collaborator Author

Hi @iver56, I can confirm that ISTFT is not in opset 17. However, ONNX is open to including it as a new op (please feel free to open an issue in the onnx/onnx repo if one does not already exist). We plan to support exporting istft from PyTorch despite the missing ISTFT op.

@stonelazy

If I may add: even though ISTFT is not available, both the forward and inverse DFT operations are supported by ONNX (Reference); this should help users implement ISTFT on their own.
Would it be possible to include support for FFT as well?
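The suggestion above can be sketched in plain Python: an ISTFT is an inverse DFT per frame followed by overlap-add. This is an illustrative, O(n²) sketch under simplifying assumptions (rectangular window, no centering); the function names `idft` and `istft` are hypothetical, and this is not ONNX or PyTorch code.

```python
import cmath

def idft(spectrum):
    # Naive inverse DFT of one frame (O(n^2)); the ONNX DFT op with
    # inverse=1 plays this role in an exported graph.
    n = len(spectrum)
    return [
        sum(spectrum[k] * cmath.exp(2j * cmath.pi * k * t / n)
            for k in range(n)).real / n
        for t in range(n)
    ]

def istft(frames, hop):
    # Overlap-add reconstruction with an implicit rectangular window.
    n_fft = len(frames[0])
    out = [0.0] * (hop * (len(frames) - 1) + n_fft)
    norm = [0.0] * len(out)
    for i, frame in enumerate(frames):
        for t, v in enumerate(idft(frame)):
            out[i * hop + t] += v
            norm[i * hop + t] += 1.0  # window-overlap count at this sample
    return [o / max(w, 1e-12) for o, w in zip(out, norm)]
```

Round-tripping frames produced by a matching forward DFT recovers the original signal, which is the property a user-side ISTFT built from ONNX's DFT op would rely on.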

@justinchuby
Collaborator Author

@stonelazy I think so. Do you have a list of related torch functions you think would be fitting?

@justinchuby justinchuby added good first issue OSS contribution wanted PR from open source contributors welcome to solve this issue. labels Jul 26, 2022
@stonelazy

> @stonelazy i think so. Do you have a list of related torch functions you think could be fitting?

I hope I am not wrong; below are the torch equivalents in ONNX:
- Torch's FFT is equivalent to ONNX's DFT.
- Torch's IFFT is ONNX's DFT with the inverse option selected.
- Torch's MelScale is equivalent to ONNX's MelWeightMatrix.
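As a concrete illustration of the FFT/IFFT rows above: ONNX's DFT op selects the inverse transform via an attribute rather than a separate op. Below is a pure-Python, O(n²) sketch of that forward/inverse switch; it is illustrative only, not the ONNX implementation.

```python
import cmath

def dft(x, inverse=False):
    # Mirrors the forward/inverse switch that ONNX's DFT op exposes
    # through its `inverse` attribute (naive O(n^2) version).
    n = len(x)
    sign = 1 if inverse else -1
    out = [
        sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]
    # The inverse transform carries the 1/n normalization.
    return [v / n for v in out] if inverse else out
```

Applying the forward transform and then the inverse one recovers the input, which is why a single parameterized op suffices for both FFT and IFFT.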

@jonas-doevenspeck

Since LayerNorm is also part of opset 17, does it make sense to include it in this list as well?
https://github.com/onnx/onnx/blob/main/docs/Operators.md#layernormalization

@FrankFundel

Will opset17 be available only if all operators are implemented?

@justinchuby
Collaborator Author

opset17 is enabled in torch-nightly. Support for new ops will be added gradually.

@FrankFundel

Thanks :) Hoping for stft and window functions to come soon.

@mravanelli

That's a very important feature for SpeechBrain as well. The feature-extraction part relies on fft/stft, and we currently cannot fully export/import it in ONNX (@fpaissan).

@frankiedrake

frankiedrake commented Oct 17, 2022

> opset17 is enabled in torch-nightly. support for new ops will be added gradually.

Hello, I've installed torch from the nightly releases (version 1.14.0.dev20221017+cu117), but I'm still getting an error:
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator '::stft' to ONNX opset version 17 is not supported
Did I install the wrong version, or has this feature not been released yet?

@justinchuby
Collaborator Author

Hi! Even though we can export to the opset17 format, support for stft has unfortunately not been added yet.

pytorchmergebot pushed a commit that referenced this issue Mar 10, 2023
This PR addresses issue [#81075](#81075), making `torch.stft` compatible with ONNX Opset 17's STFT operator.

The conversion works for _most_ of `torch.stft` functionality:

- Batched or unbatched inputs
- Normalization
- Pre-computed windows
- Rectangular windows
- One-sided returns
- Window centering (implicitly supported)

What is currently _not_ supported is **complex types**, due to the lack of conversion functionality between PyTorch and ONNX (#86746).

Regardless, this is easy to bypass by setting `return_complex=False` when using `torch.stft`.

Note that there is already a draft PR addressing this (#83944), but it is currently closed and only partially implements the conversion (most of `torch.stft` functionality is missing, and there are no unit tests).
Pull Request resolved: #92087
Approved by: https://github.com/justinchuby
cyyever pushed a commit to cyyever/pytorch_private that referenced this issue Mar 12, 2023
@david-macleod

Are there any plans to implement fft_fft / fft_rfft or would you accept pull requests?

@justinchuby
Collaborator Author

We will accept pull requests for sure!

@david-macleod

Looks like it is blocked on support for complex types in onnx export 😔

@Alexey-Kamenev

If you only need ONNX support for real-valued FFT ops (e.g. rfft/irfft, rfft2/irfft2, and so on), you can use ONNX Contrib ops such as Rfft and Irfft. You can take a look at my project, which provides ONNX export + TensorRT plugins for models using FFT ops. That project provides only a simple example of how Contrib ops can be used in ONNX export; for a complete implementation, refer to my code in the NVIDIA Modulus project.

Note that if you plan to use ONNX models with those ops at runtime, you might need a custom build of ORT due to a bug in IRFFT. The alternative is to use TRT plugins, that is, do a PyTorch -> ONNX -> TRT conversion (TRT plugins do not use ORT, so they are not affected by the bug).

@sammlapp

sammlapp commented Apr 29, 2023

> If you only need an ONNX support for real-valued FFT ops (e.g. rfft/irfft, rfft2/irfft2 and so on), you can use ONNX Contrib ops, such as Rfft and Irfft. [...]

@Alexey-Kamenev this sounds like it could potentially be extremely useful, but I don't really understand what I would need to do. For instance, could I create a PyTorch model that creates spectrograms from audio, and export this model to ONNX? If so, could you describe the steps necessary?

@Alexey-Kamenev

@sammlapp

> For instance, could I create a PyTorch model that creates spectrograms from audio, and export this model to ONNX?

Yes, you could do that, if in your PyTorch code you use something like torch.fft.rfft.
There is also an example of using ONNX-friendly FFT functions to implement AFNO (Adaptive Fourier Neural Operator): code, paper.

> If so, could you describe the steps necessary?

Unfortunately, due to the lack of complex-type support in the ONNX exporter, the process is not exactly trivial: you would need some sort of wrapper that uses the standard PyTorch FFT functions when not in ONNX export mode, and custom symbolic functions that emit the proper ONNX operators (e.g. Rfft or Irfft) during export. You would also need to change your code to avoid complex-typed tensors, essentially treating every complex tensor as if after torch.view_as_real.

We already did all of that in Modulus (see the AFNO example I mentioned above), but if you don't want to take a dependency on Modulus, you can implement a simpler version of the code yourself by checking the tests.

I know it does not sound particularly simple, but that's the consequence of not having complex-type support in the ONNX exporter. Otherwise, it would have been as simple as calling register_custom_op_symbolic.
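The "treat complex tensors as real pairs" part of the recipe above can be sketched without torch at all. Here `rfft_as_real` is a hypothetical helper that returns a one-sided real-input FFT as stacked (re, im) pairs, the same layout `torch.view_as_real` produces; it is an illustrative sketch, not the Modulus code.

```python
import cmath

def rfft_as_real(x):
    # One-sided DFT of a real signal, returned as [re, im] pairs instead
    # of complex numbers -- the layout an ONNX-exportable model must use
    # while the exporter lacks a complex dtype (cf. torch.view_as_real).
    n = len(x)
    out = []
    for k in range(n // 2 + 1):  # bins 0 .. n/2 (one-sided spectrum)
        z = sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        out.append([z.real, z.imag])
    return out
```

Downstream code then indexes the last "dimension" for real and imaginary parts instead of using complex arithmetic, which is exactly the rewrite the wrapper approach requires.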

@justinchuby
Collaborator Author

More support on the way in our new torch.onnx.dynamo_export exporter. cc @titaiwangms

@shingjan
Contributor

shingjan commented May 8, 2023

There is an aten::l1_loss not supported error for the model Super_SloMo in torchbench. Issue reported here. I wonder if this op will be supported in the future? Although it is odd that a training-related op is included for inference.

@justinchuby
Collaborator Author

l1_loss will be supported

@TechInterMezzo

What is the current workaround to export models with mel spectrograms? And is support for stft the only thing that is missing to make it work?

@justinchuby
Collaborator Author

justinchuby commented Sep 8, 2023

STFT is supported by the torch.onnx.dynamo_export export API. I have not looked at melspectrogram; supposedly ONNX has a related op that we have not made use of.
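The ONNX op alluded to here is presumably MelWeightMatrix, which builds the linear-frequency-to-mel mapping used by mel spectrograms. The conversion underneath is the standard HTK mel formula; below is a minimal sketch, where `hz_to_mel`/`mel_to_hz` are illustrative helper names.

```python
import math

def hz_to_mel(f_hz):
    # HTK-style mel scale, the convention commonly used for mel filterbanks.
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)

def mel_to_hz(m):
    # Exact inverse of hz_to_mel.
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)
```

A mel filterbank places triangular filters at points evenly spaced in mel units and maps them back to Hz with `mel_to_hz`; the two functions above are the only nonlinearity involved.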

@justinchuby
Collaborator Author

justinchuby commented Nov 1, 2023

It should now be supported by torch.onnx.dynamo_export. Note that the dynamo exporter for ONNX is in beta.

@sammlapp

sammlapp commented Nov 1, 2023

@justinchuby does this mean that there's now a simpler method for exporting pytorch models containing fft than the approach described by @Alexey-Kamenev above? Or is that approach still required?
Thanks

@justinchuby
Collaborator Author

justinchuby commented Nov 1, 2023

Yes. The new torch.onnx.dynamo_export should allow a simpler experience for exporting FFT operators. The approach described by @Alexey-Kamenev will be required for the torch.onnx.export API.

@shanecarroll-smarsh

shanecarroll-smarsh commented Nov 6, 2023

Can anyone produce a working example where torch.onnx.dynamo_export successfully exports a torch.stft op?

Here is a simple MWE, with a setup common to audio signal processing models:

import torch


class STFTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        x = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            return_complex=True,  # doesn't affect errors
            window=self._window,
            pad_mode="constant",  # aten.reflection_pad1d unsupported op
        )
        return x


m = STFTModel()

# Shape [B, T] audio signals
input_signals = torch.randn([2, 16000])

args = (input_signals,)
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
torch.onnx.dynamo_export(
    m,
    *args,
    export_options=export_options,
)

Here are the short versions of error messages:

Without dynamic shapes (not useful to anyone using stft):

torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.transpose.int']}. 

With dynamic shapes (as the example shows):

torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default

Exporting within the context of torch.inference_mode(), output is slightly different (prims vs. aten):

torch._dynamo.exc.Unsupported: unsupported operator: prims.fft_r2c.default

Relevant context (everything should be on the latest versions):

$ pip freeze | egrep '(torch|onnx)'
onnx==1.15.0
onnxscript==0.1.0.dev20231106
pytorch-triton==2.1.0+6e4932cda8
torch==2.2.0.dev20231106+cu121
torchaudio==2.2.0.dev20231106+cu121
torchvision==0.17.0.dev20231106+cu121

@justinchuby
Collaborator Author

> Can anyone produce a working example where torch.onnx.dynamo_export successfully exports a torch.stft op? [...]

Thanks for sharing the info. I will reproduce this and get back to you.

@justinchuby
Collaborator Author

We will continue in #113067 @shanecarroll-smarsh

Projects
Status: Done