
[ONNX] stft export fails with dynamo_export #113067

Open
justinchuby opened this issue Nov 6, 2023 · 34 comments
@justinchuby
Collaborator

Can anyone produce a working example where `torch.onnx.dynamo_export` successfully exports a `torch.stft` op?

Here is a simple MWE, with a setup common to audio signal processing models:

import torch


class STFTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        x = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            return_complex=True,  # doesn't affect errors
            window=self._window,
            pad_mode="constant",  # aten.reflection_pad1d unsupported op
        )
        return x


m = STFTModel()

# Shape [B, T] audio signals
input_signals = torch.randn([2, 16000])

args = (input_signals,)
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
torch.onnx.dynamo_export(
    m,
    *args,
    export_options=export_options,
)

Here are the short versions of error messages:

Without dynamic shapes (not useful to anyone using stft):

torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.transpose.int']}. 

With dynamic shapes (as the example shows):

torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default

Exporting within the context of torch.inference_mode(), the output is slightly different (prims vs. aten):

torch._dynamo.exc.Unsupported: unsupported operator: prims.fft_r2c.default
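For reference, the inference-mode variant is just the same export wrapped in the context manager (a minimal sketch reusing m, args, and export_options from the snippet above):

with torch.inference_mode():
    torch.onnx.dynamo_export(
        m,
        *args,
        export_options=export_options,
    )
# fails with prims.fft_r2c.default instead of aten._fft_r2c.default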

Relevant context (should be the latest in everything):

$ pip freeze | egrep '(torch|onnx)'
onnx==1.15.0
onnxscript==0.1.0.dev20231106
pytorch-triton==2.1.0+6e4932cda8
torch==2.2.0.dev20231106+cu121
torchaudio==2.2.0.dev20231106+cu121
torchvision==0.17.0.dev20231106+cu121

Originally posted by @shanecarroll-smarsh in #81075 (comment)

@justinchuby
Collaborator Author

@titaiwangms Turns out transpose.int was not matched when the input is complex. Would it make sense for us to fall back onto the non-complex-specific implementations when we don't have complex functions?

@justinchuby justinchuby added the module: onnx Related to torch.onnx label Nov 6, 2023
@justinchuby justinchuby added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 6, 2023
@justinchuby justinchuby self-assigned this Nov 6, 2023
@titaiwangms
Collaborator

In this case, would the non-complex function be the right choice? But yes, we can fall back with an in-flight diagnostic message.

@justinchuby
Collaborator Author

I think so. For transpose it should be the same function (as long as dim is not negative)
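To illustrate why negative dims are unsafe once a complex tensor is lowered to a real view with a trailing component axis, here is a minimal sketch (illustrative only, not the exporter's actual lowering):

import torch

x = torch.randn(2, 3, dtype=torch.complex64)
real = torch.view_as_real(x)  # shape [2, 3, 2]: extra trailing real/imag axis

# Non-negative dims line up between the complex tensor and its real view:
assert torch.equal(torch.view_as_real(x.transpose(0, 1)), real.transpose(0, 1))

# A negative dim does not: on the complex tensor, dim=-1 is the last data axis,
# while on the real view, dim=-1 is the real/imag component axis.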

@justinchuby
Collaborator Author

justinchuby commented Nov 7, 2023

@titaiwangms microsoft/onnxscript#1134. I changed my mind. Seems better to fail explicitly than implicitly (e.g. when dim is negative the real implementation will fail)

@justinchuby
Collaborator Author

[image]

Please test with tomorrow's onnxscript dev release

@shanecarroll-smarsh

Thanks, the issue seems resolved. The ONNX graph exports with static shapes, but does not run (the next issues seem unrelated to this one).

@justinchuby
Collaborator Author

Can you share the error?

@justinchuby justinchuby reopened this Nov 8, 2023
@shanecarroll-smarsh

Sure, here's an extension of the snippet with failures under various configurations. I don't have much experience with dynamo so maybe there's something I'm doing wrong.

In the first two conditions, the model exports but fails during execution in onnxruntime.

In the third and fourth conditions, the model fails to export due to unsupported operators.

Relevant environment (today's dev versions)
$ pip freeze | egrep '(torch|onnx)'
onnx==1.15.0
onnxruntime==1.16.1
onnxscript==0.1.0.dev20231108
pytorch-triton==2.1.0+6e4932cda8
torch==2.2.0.dev20231108+cu121
torchaudio==2.2.0.dev20231108+cu121
torchvision==0.17.0.dev20231108+cu121
Code snippet
import onnx
import torch
import numpy as np
import onnxruntime as ort


class STFTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        x = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            return_complex=True,
            window=self._window,
            pad_mode="constant",  # aten.reflection_pad1d unsupported op
        )
        return x


m = STFTModel()
m.eval()

batch_size = 2
signal_length = 16000

# Export
# Shape [B, T] audio signals
input_signals = torch.randn([batch_size, signal_length])
args = (input_signals,)
# Note: static dims
export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
exported_model = torch.onnx.dynamo_export(
    m,
    *args,
    export_options=export_options,
)
exported_model.save("tmp.onnx")

# Load and attempt to run
onnx_model = onnx.load("tmp.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX check ok")
print("Instantiate session")
session: ort.InferenceSession = ort.InferenceSession(
    "tmp.onnx", providers=["CPUExecutionProvider"]
)
# Need static shape... use same as exported
np_signals = np.random.random(size=[batch_size, signal_length]).astype(np.float32)
print(f"Run ONNX graph with signals of shape {np_signals.shape}")
# Exporter also gives parameter a weird name: signals -> l_signals_
outputs = session.run(None, {"l_signals_": np_signals})
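
As an aside, the generated input names do not need to be hard-coded; they can be read back from the session with the standard onnxruntime API:

for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)  # e.g. l_signals_ [2, 16000] tensor(float)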
Model graph
Graph of the model exported with the snippet as shown, with input shape [2, 16000]:
>>> print(onnx.helper.printable_graph(onnx_model.graph))
graph main_graph (
  %l_signals_[FLOAT, 2x16000]
) initializers (
  %_tensor_constant0[FLOAT, 320]
) {
  %_val_1 = Constant[value = <Tensor>]()
  %view = aten_view(%l_signals_, %_val_1)
  %_val_3 = Constant[value = <Tensor>]()
  %constant_pad_nd = aten_constant_pad_nd[value = 0](%view, %_val_3)
  %_val_5 = Constant[value = <Tensor>]()
  %view_1 = aten_view(%constant_pad_nd, %_val_5)
  %_val_8 = Constant[value = <Tensor>]()
  %constant_pad_nd_1 = aten_constant_pad_nd[value = 0](%_tensor_constant0, %_val_8)
  %unfold = _aten_unfold_onnx[dim = -1, perm = [1, 2, 0], size = 512, step = 160, target_end = 101](%view_1)
  %mul = aten_mul(%unfold, %constant_pad_nd_1)
  %_val_12 = Constant[value = <Tensor>]()
  %_val_13 = Unsqueeze(%mul, %_val_12)
  %_val_14 = Constant[value = <Tensor>]()
  %_val_15 = Unsqueeze(%_val_13, %_val_14)
  %_val_16 = DFT[axis = 3, inverse = 0, onesided = 1](%_val_15)
  %_val_17 = Constant[value = <Tensor>]()
  %_val_18 = Squeeze(%_val_16, %_val_17)
  %_fft_r2c = _fftn_onnx_normalization[dims = [2], forward = 1, normalization = 0](%_val_13, %_val_18)
  %transpose = Transpose[perm = [0, 2, 1, 3]](%_fft_r2c)
  return %transpose
}
Condition 1: batch size greater than 1
This condition is the unedited snippet.
...
ONNX check ok
Instantiate session
Run ONNX graph with signals of shape (2, 16000)
2023-11-08 15:39:26.024634307 [W:onnxruntime:, execution_frame.cc:857 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {1} for output _inline__aten_unfold_onnxcond_out

[above error repeated a lot]

2023-11-08 15:39:26.037632225 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Mul node. Name:'_inline_aten_muln1' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_w
ise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 2 by 512

Traceback (most recent call last):
  File "/xxx/test_export_stft.py", line 68, in <module>
    outputs = session.run(None, {"l_signals_": np_signals})
  File "/xxx/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'_inline_aten_muln1' Status Message: /onnxruntime_src/onnxruntime/core/providers/
cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 2 by 512
Condition 2: batch size is 1
...
ONNX check ok
Instantiate session
Run ONNX graph with signals of shape (1, 16000)
2023-11-08 15:46:37.534616222 [W:onnxruntime:, execution_frame.cc:857 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {1} for output _inline__aten_unfold_onnxcond_out

[above error repeated a lot]

2023-11-08 15:46:38.806341791 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Identity node. Name:'_inline__fftn_onnx_normalizationn0' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,101,257,2} Requested shape:{512,101,257,2}

2023-11-08 15:46:38.806420922 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn2' Status Message: Non-zero status code returned while running Identity node. Name:'_inline__fftn_onnx_normalizationn0' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,101,257,2} Requested shape:{512,101,257,2}

2023-11-08 15:46:38.806466312 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn3' Status Message: Non-zero status code returned while running If
node. Name:'_inline__fftn_onnx_normalizationn2' Status Message: Non-zero status code returned while running Identity node. Name:'_inline__fftn_onnx_normalizationn0' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,101,257,2} Requested shape:{512,101,257,2}

2023-11-08 15:46:38.806507746 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn8' Status Message: Non-zero status code returned while running If
node. Name:'_inline__fftn_onnx_normalizationn3' Status Message: Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn2' Status Message: Non-zero status code returned while running Identity node. Name:'_inline__fftn_onnx_normalizationn0' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,101,257,2} Requested shape:{512,101,257,2}

Traceback (most recent call last):
  File "/xxx/test_export_stft.py", line 70, in <module>
    outputs = session.run(None, {"l_signals_": np_signals})
  File "/xxx/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn8' Status Message: Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn3' Status Message: Non-zero status code returned while running If node. Name:'_inline__fftn_onnx_normalizationn2' Status Message: Non-zero status code returned while running Identity node. Name:'_inline__fftn_onnx_normalizationn0' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,101,257,2} Requested shape:{512,101,257,2}
Condition 3: pad mode is reflect (stft's default)
Traceback (most recent call last):
  File "/xxx/site-packages/torch/onnx/_internal/exporter.py", line 1227, in dynamo_export
    return Exporter(
  File "/xxx/site-packages/torch/onnx/_internal/exporter.py", line 978, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/xxx/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 216, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
  File "/xxx/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 226, in pre_export_passes
    return exporter.common_pre_export_passes(
  File "/xxx/site-packages/torch/onnx/_internal/exporter.py", line 1286, in common_pre_export_passes
    analysis.UnsupportedFxNodesAnalysis(
  File "/xxx/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 74, in analyze
    self._lint(analysis_result, diagnostic_level)
  File "/xxx/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 38, in _lint
    self.diagnostic_context.log_and_raise_if_error(diagnostic)
  File "/xxx/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 367, in log_and_raise_if_error
    raise RuntimeErrorWithDiagnostic(diagnostic)
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.reflection_pad1d.default']}.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/xxx/test_export_stft.py", line 52, in <module>
    exported_model = torch.onnx.dynamo_export(
  File "/xxx/site-packages/torch/onnx/_internal/exporter.py", line 1243, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues
Condition 4: `dynamic_shapes=True`
Same configuration as Condition 1, but with dynamic shapes in the export options.
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default 

@justinchuby
Collaborator Author

Thanks for the detailed info. I will look into this.

@justinchuby

This comment was marked as outdated.

@justinchuby
Collaborator Author

class <lambda>(torch.nn.Module):
    def forward(self, arg0: f32[16000]):
        # File: /home/justinchu/dev/onnx-script/test.py:13, code: x = signals.stft(
        view: f32[1, 1, 16000] = torch.ops.aten.view.default(arg0, [1, 1, 16000]);  arg0 = None
        constant_pad_nd: f32[1, 1, 16512] = torch.ops.aten.constant_pad_nd.default(view, [256, 256], 0.0);  view = None
        view_1: f32[16512] = torch.ops.aten.view.default(constant_pad_nd, [16512]);  constant_pad_nd = None
        unsqueeze: f32[1, 16512] = torch.ops.aten.unsqueeze.default(view_1, 0);  view_1 = None
        _tensor_constant0 = self._tensor_constant0
        constant_pad_nd_1: f32[512] = torch.ops.aten.constant_pad_nd.default(_tensor_constant0, [96, 96]);  _tensor_constant0 = None
        unfold: f32[1, 101, 512] = torch.ops.aten.unfold.default(unsqueeze, -1, 512, 160);  unsqueeze = None
        mul: f32[1, 101, 512] = torch.ops.aten.mul.Tensor(unfold, constant_pad_nd_1);  unfold = constant_pad_nd_1 = None
        _fft_r2c: c64[1, 101, 257] = torch.ops.aten._fft_r2c.default(mul, [2], 0, True);  mul = None
        transpose: c64[1, 257, 101] = torch.ops.aten.transpose.int(_fft_r2c, 1, 2);  _fft_r2c = None
        squeeze: c64[257, 101] = torch.ops.aten.squeeze.dim(transpose, 0);  transpose = None
        return [squeeze]
        

@justinchuby
Collaborator Author

Looks like the error is in:
[image]

@justinchuby
Collaborator Author

This identity node:
[image]

@justinchuby
Collaborator Author

Inlined models show the same errors as well. From @gramalingam:

I think that somewhere earlier, a tensor of shape 512x... is being produced (by runtime), while a tensor of shape Batchx... is expected (by inference). When this hits the Mul, it fails when Batch is 2, but broadcasting happens (unintended) when Batch is 1 and the failure shows up later in Identity.

It could still be an error in an onnxscript function where we miscalculated some axis/broadcasting effect, or it could be a shape-inference error (e.g., in DFT), or it could be an error in ORT. I still think an ORT bug is the least likely possibility.

I will debug further.
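
The unintended broadcast is easy to reproduce in isolation with numpy (hypothetical shapes chosen to match the logs above):

import numpy as np

# Suppose a bug produces a [512, 101, 1] tensor where [Batch, 101, 512] is expected.
bad = np.ones((512, 101, 1), dtype=np.float32)

frames_batch2 = np.ones((2, 101, 512), dtype=np.float32)
try:
    bad * frames_batch2  # Batch=2: 512 vs 2 cannot broadcast, so Mul fails here
except ValueError as e:
    print("broadcast error:", e)

frames_batch1 = np.ones((1, 101, 512), dtype=np.float32)
out = bad * frames_batch1  # Batch=1: silently broadcasts to (512, 101, 512)
print(out.shape)  # the wrong shape only surfaces later (e.g. in Identity)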

@justinchuby
Collaborator Author

justinchuby commented Nov 11, 2023

This should (finally) be fixed after microsoft/onnxscript#1146 is merged.

Thanks @gramalingam for identifying the bug: microsoft/onnxscript#1145

@justinchuby
Collaborator Author

This is the code I used to verify:

import onnx
import torch
import numpy as np
import onnxruntime as ort

import onnx.inliner
import onnx.reference


class STFTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        x = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            return_complex=True,
            window=self._window,
            pad_mode="constant",  # aten.reflection_pad1d unsupported op
        )
        return x


m = STFTModel()
m.eval()

# NOTE: Change batch_size to 1, 2 to see different errors
batch_size = 1
signal_length = 16000

# Export
# Shape [T] audio signal (no batch dim in this repro)
input_signals = torch.randn([signal_length])
args = (input_signals,)
# Note: static dims
export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
exported_model = torch.onnx.dynamo_export(
    m,
    *args,
    export_options=export_options,
)

print("output shape", m(input_signals).shape)

# Load and attempt to run
# NOTE: Start from here to load the model and reproduce error
onnx_model = exported_model.model_proto
print("ONNX check ok")
# onnx_model = onnx.inliner.inline_local_functions(onnx_model)
# onnx.shape_inference.infer_shapes(onnx_model, check_type=True, strict_mode=True, data_prop=True)
onnx.save_model(onnx_model, f"stft_inlined_batch_{batch_size}.onnx")
print("Instantiate session")
session: ort.InferenceSession = ort.InferenceSession(
    onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
# Need static shape... use same as exported
np_signals = input_signals.numpy()
print(f"Run ONNX graph with signals of shape {np_signals.shape}")
# Exporter also gives the parameter a generated name: signals -> arg0
outputs = session.run(None, {"arg0": np_signals})
expected = torch.view_as_real(m(input_signals)).numpy()

print(outputs)
print(expected)

np.testing.assert_allclose(outputs[0], expected)

# session = onnx.reference.ReferenceEvaluator(onnx_model, verbose=10)
# session.run(None, {"arg0": np_signals})

@justinchuby
Collaborator Author

@BowenBao We are still blocked by dynamo for the dynamic shape support.

@justinchuby
Collaborator Author

justinchuby commented Nov 11, 2023

@BowenBao We are still blocked by dynamo for the dynamic shape support.

That said, it is possible to modify the ONNX model to replace the concrete shapes with symbols. It should enable dynamic shapes on the model.
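
A minimal sketch of that workaround, assuming the batch dimension is the leading dimension of every graph input and output (internal value_info shapes may also need to be dropped so they can be re-inferred):

import onnx

model = onnx.load("tmp.onnx")

# Rename the concrete leading dimension of each graph input/output to a symbol.
for value in list(model.graph.input) + list(model.graph.output):
    dim0 = value.type.tensor_type.shape.dim[0]
    dim0.dim_param = "batch"  # setting dim_param replaces the concrete dim_value

# Drop stale intermediate shapes and re-run shape inference.
del model.graph.value_info[:]
model = onnx.shape_inference.infer_shapes(model)

onnx.save(model, "tmp_dynamic.onnx")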

@BowenBao
Collaborator

What is the issue with dynamic shapes, a missing op?

@titaiwangms
Collaborator

titaiwangms commented Nov 11, 2023

I am guessing it's that symbolic shapes with complex dtypes can't be used in the ONNX graph.

@justinchuby
Collaborator Author

Dynamo. I will post the error on Monday.

@justinchuby
Collaborator Author

justinchuby commented Nov 29, 2023

@BowenBao

Traceback

/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
Traceback (most recent call last):
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1572, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_tensor.py", line 792, in stft
    return torch.stft(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py", line 660, in stft
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_refs/__init__.py", line 3298, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1649, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 499, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_r2c.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1485, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1026, in wrap_fake_exception
    return fn()
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1486, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1591, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1572, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_tensor.py", line 792, in stft
    return torch.stft(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/functional.py", line 660, in stft
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_refs/__init__.py", line 3298, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1649, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 499, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
RuntimeError: Failed running call_method stft(*(FakeTensor(..., size=(s0, s1)),), **{'n_fft': 512, 'hop_length': 160, 'win_length': 320, 'return_complex': True, 'window': FakeTensor(..., size=(320,)), 'pad_mode': 'constant'}):
aten._fft_r2c.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py", line 1385, in dynamo_export
    ).export()
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py", line 1128, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 199, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1342, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 154, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs))
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 645, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2123, in run
    super().run()
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 643, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 749, in call_method
    return wrap_fx_proxy(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1283, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1368, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1506, in get_fake_value
    unimplemented(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

from user code:
   File "/home/justinchu/dev/onnx-script/test.py", line 10, in forward
    x = signals.stft(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/justinchu/dev/onnx-script/test.py", line 28, in <module>
    torch.onnx.dynamo_export(
  File "<@beartype(torch.onnx._internal.exporter.dynamo_export) at 0x7f8d3fe92680>", line 53, in dynamo_export
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py", line 1396, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues

SARIF

{
 "runs":[
  {
   "tool":{
    "driver":{
     "name":"torch.onnx.dynamo_export",
     "contents":[
      "localizedData",
      "nonLocalizedData"
     ],
     "language":"en-US",
     "rules":[],
     "version":"2.2.0.dev20231128+cpu"
    }
   },
   "language":"en-US",
   "newlineSequences":[
    "\r\n",
    "\n"
   ],
   "results":[]
  }
 ],
 "version":"2.1.0",
 "schemaUri":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
}

@mush42

mush42 commented Dec 12, 2023

@justinchuby

Using the same model/export script as in #115509, I now face the following errors:

  • Exporting with dynamic_shapes=False fails in onnxruntime:
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_token_29n0' Status Message: C:\a\_work\1\s\onnxruntime\core/providers/cpu/math/element_wise_ops.h:560 onnxruntime::BroadcastIterator::Append axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 513 by 1024
  • Exporting with dynamic_shapes=True fails in torch export:
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_c2r.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

  File "/home/mush42/projects/vocos/vocos/spectral_ops.py", line 57, in forward
    ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")

@mush42

mush42 commented Dec 12, 2023

Any workarounds for now?

@justinchuby
Collaborator Author

Is your onnxscript up to date? Could you share the version numbers?

@mush42

mush42 commented Dec 12, 2023

@justinchuby

$ pip freeze | egrep "onnx"
onnx==1.15.0
onnxscript==0.1.0.dev20231209

@justinchuby
Collaborator Author

Could you attach the onnx model?

@titaiwangms
Collaborator

@mush42 Could you try the torch.export.export API (https://pytorch.org/docs/stable/export.html) to export your model to an ExportedProgram? If you still observe torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_c2r.default, then it's an issue in dynamo and we should file an issue under torch.export.export. But if it works, passing that ExportedProgram into torch.onnx.dynamo_export should work.
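
A sketch of that flow, reusing the STFTModel from the snippets above (whether dynamo_export accepts an ExportedProgram directly may depend on the nightly in use):

import torch

m = STFTModel().eval()  # model from the snippets above
args = (torch.randn(2, 16000),)

ep = torch.export.export(m, args)  # ExportedProgram; if this raises
print(ep)                          # Unsupported, the bug is in dynamo/export

onnx_program = torch.onnx.dynamo_export(ep, *args)
onnx_program.save("stft_from_exported_program.onnx")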

@justinchuby
Collaborator Author

justinchuby commented Dec 12, 2023

@titaiwangms Good point, although I don't think we support c2r yet in onnx?

@titaiwangms
Collaborator

titaiwangms commented Dec 12, 2023

Ah, maybe you are right. I only recall being blocked by r2c.

@mush42

mush42 commented Dec 12, 2023

@titaiwangms
The export succeeds using torch.export.export and passing the result to torch.onnx.dynamo_export.

But running the resulting model using onnxruntime fails with the following error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (_inline_aten_mul_token_258n0) Op (Mul) [ShapeInferenceError] Incompatible dimensions

@mush42

mush42 commented Dec 12, 2023

@justinchuby the model is too large to attach here. I'll upload to drive and send the link.

@nabil6391

Hi people, is there any update on this issue, or any workarounds?

@mush42

mush42 commented Jan 15, 2024

There are two possible solutions for this. They may or may not work depending on your model's architecture:

  1. Use a CNN-based STFT implementation that is exportable to ONNX (like this one), though you need to change it a little bit to make it exportable.
  2. Perform STFT and ISTFT outside of your model during inference. Currently I have a Rust implementation of STFT and ISTFT: I pass magnitude and phase as model inputs, receive the inferred magnitude and phase in the output, then run ISTFT on them (see the sketch below).
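
A Python sketch of the second workaround (the hypothetical MagPhaseModel stands in for the real network; the STFT settings mirror the ones used earlier in the thread):

import torch


class MagPhaseModel(torch.nn.Module):
    """Hypothetical network that operates on magnitude/phase instead of raw audio."""

    def forward(self, mag: torch.Tensor, phase: torch.Tensor):
        return mag, phase  # placeholder processing


n_fft, hop_length, win_length = 512, 160, 320
window = torch.hann_window(win_length)
signal = torch.randn(2, 16000)

# STFT outside the model at inference time
spec = torch.stft(
    signal,
    n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=window,
    pad_mode="constant",
    return_complex=True,
)
mag, phase = spec.abs(), spec.angle()

# Only this part needs to be exported to ONNX
out_mag, out_phase = MagPhaseModel()(mag, phase)

# ISTFT on the model outputs
recon = torch.istft(
    torch.polar(out_mag, out_phase),
    n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=window,
)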

Best
Musharraf
