[ONNX] stft export fails with dynamo_export #113067
Comments
@titaiwangms Turns out transpose.int was not matched when the input is complex. Would it make sense for us to fall back onto the non-complex-specific implementations when we don't have complex functions? |
In this case, would the non-complex function be the right choice? But yes, we can fall back with an in-flight diagnostic message. |
I think so. For transpose it should be the same function (as long as dim is not negative) |
@titaiwangms microsoft/onnxscript#1134. I changed my mind. Seems better to fail explicitly than implicitly (e.g. when dim is negative the real implementation will fail) |
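The transpose discussion above can be checked directly: for complex tensors, transpose is the same operation as for real tensors once negative dims are normalized. A quick sanity check (my own, not from the thread):

```python
import torch

# Sanity check of the discussion above: transpose behaves identically for
# complex tensors, so a real-valued kernel could in principle serve as a
# fallback -- provided negative dims are normalized (dim + ndim) first.
x = torch.randn(2, 3, 4, dtype=torch.complex64)

# A negative dim is equivalent to its normalized counterpart.
assert torch.equal(x.transpose(-2, -1), x.transpose(1, 2))

# Real and imaginary parts transpose independently, which is exactly the
# "apply the real op componentwise" fallback strategy being debated.
y = x.transpose(0, 2)
assert torch.equal(y.real, x.real.transpose(0, 2))
assert torch.equal(y.imag, x.imag.transpose(0, 2))
```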
Thanks... the issue seems resolved. The ONNX graph exports with static shapes, but does not run (the next issues seem unrelated to this one). |
Can you share the error? |
Sure, here's an extension of the snippet with failures under various configurations. I don't have much experience with dynamo, so maybe there's something I'm doing wrong. In the first two conditions, the model exports but fails during execution. In the third and fourth conditions, the model fails to export due to unsupported operators.

Relevant environment (today's dev versions):

$ pip freeze | egrep '(torch|onnx)'
onnx==1.15.0
onnxruntime==1.16.1
onnxscript==0.1.0.dev20231108
pytorch-triton==2.1.0+6e4932cda8
torch==2.2.0.dev20231108+cu121
torchaudio==2.2.0.dev20231108+cu121
torchvision==0.17.0.dev20231108+cu121

Code snippet:

import onnx
import torch
import numpy as np
import onnxruntime as ort
class STFTModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._window = torch.hann_window(window_length=320)
def forward(self, signals: torch.Tensor) -> torch.Tensor:
x = signals.stft(
n_fft=512,
hop_length=160,
win_length=320,
return_complex=True,
window=self._window,
pad_mode="constant", # aten.reflection_pad1d unsupported op
)
return x
m = STFTModel()
m.eval()
batch_size = 2
signal_length = 16000
# Export
# Shape [B, T] audio signals
input_signals = torch.randn([batch_size, signal_length])
args = (input_signals,)
# Note: static dims
export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
exported_model = torch.onnx.dynamo_export(
m,
*args,
export_options=export_options,
)
exported_model.save("tmp.onnx")
# Load and attempt to run
onnx_model = onnx.load("tmp.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX check ok")
print("Instantiate session")
session: ort.InferenceSession = ort.InferenceSession(
"tmp.onnx", providers=["CPUExecutionProvider"]
)
# Need static shape... use same as exported
np_signals = np.random.random(size=[batch_size, signal_length]).astype(np.float32)
print(f"Run ONNX graph with signals of shape {np_signals.shape}")
# Exporter also gives parameter a weird name: signals -> l_signals_
outputs = session.run(None, {"l_signals_": np_signals})

Model graph

Graph of the model exported with the snippet as shown, with input shape [2, 16000]:

Condition 1: batch size greater than 1

This condition is the unedited snippet.

Condition 2: batch size is 1

Condition 3: pad mode is reflect (stft's default)

Condition 4: `dynamic_shapes=True`

Same configuration as Condition 1, but with dynamic shapes in the export options.
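Since Condition 3 fails on aten.reflection_pad1d, one hypothetical workaround (my own sketch, not something proposed in the thread) is to build the reflect padding from slicing and concatenation, then call stft with center=False so it skips its internal padding:

```python
import torch
import torch.nn.functional as F

def reflect_pad_1d(x: torch.Tensor, pad: int) -> torch.Tensor:
    # Reflect-pad the last dim of a [B, T] tensor using slice/flip/cat,
    # sidestepping aten.reflection_pad1d entirely.
    left = x[:, 1:pad + 1].flip(-1)
    right = x[:, -pad - 1:-1].flip(-1)
    return torch.cat([left, x, right], dim=-1)

# Matches F.pad's reflect mode (F.pad needs a channel dim for 1-D padding).
x = torch.randn(2, 16000)
expected = F.pad(x.unsqueeze(1), (256, 256), mode="reflect").squeeze(1)
assert torch.equal(reflect_pad_1d(x, 256), expected)
```

The padded signal can then be passed to stft(..., center=False); with pad of n_fft // 2 this should reproduce center=True with pad_mode="reflect".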
|
Thanks for the detailed info. I will look into this. |
Inlined models show the same errors as well. From @gramalingam:
I will debug further. |
This should (finally) be fixed after microsoft/onnxscript#1146 is merged. Thanks @gramalingam for identifying the bug: microsoft/onnxscript#1145 |
This is the code I used to verify:

import onnx
import torch
import numpy as np
import onnxruntime as ort
import onnx.inliner
import onnx.reference
class STFTModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._window = torch.hann_window(window_length=320)
def forward(self, signals: torch.Tensor) -> torch.Tensor:
x = signals.stft(
n_fft=512,
hop_length=160,
win_length=320,
return_complex=True,
window=self._window,
pad_mode="constant", # aten.reflection_pad1d unsupported op
)
return x
m = STFTModel()
m.eval()
# NOTE: Change batch_size to 1, 2 to see different errors
batch_size = 1
signal_length = 16000
# Export
# Shape [B, T] audio signals
input_signals = torch.randn([batch_size, signal_length])
args = (input_signals,)
# Note: static dims
export_options = torch.onnx.ExportOptions(dynamic_shapes=False)
exported_model = torch.onnx.dynamo_export(
m,
*args,
export_options=export_options,
)
print("output shape", m(input_signals).shape)
# Load and attempt to run
# NOTE: Start from here to load the model and reproduce error
onnx_model = exported_model.model_proto
print("ONNX check ok")
# onnx_model = onnx.inliner.inline_local_functions(onnx_model)
# onnx.shape_inference.infer_shapes(onnx_model, check_type=True, strict_mode=True, data_prop=True)
onnx.save_model(onnx_model, f"stft_inlined_batch_{batch_size}.onnx")
print("Instantiate session")
session: ort.InferenceSession = ort.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
# Need static shape... use same as exported
np_signals = input_signals.numpy()
print(f"Run ONNX graph with signals of shape {np_signals.shape}")
# In this export the input parameter is named arg0 rather than l_signals_
outputs = session.run(None, {"arg0": np_signals})
expected = torch.view_as_real(m(input_signals)).numpy()
print(outputs)
print(expected)
np.testing.assert_allclose(outputs[0], expected)
# session = onnx.reference.ReferenceEvaluator(onnx_model, verbose=10)
# session.run(None, {"arg0": np_signals}) |
@BowenBao We are still blocked by dynamo for the dynamic shape support. |
That said, it is possible to modify the ONNX model to replace the concrete shapes with symbols. It should enable dynamic shapes on the model. |
What is the issue with dynamic shape, missing op? |
I am guessing it's that complex-dtype symbolic shapes can't be used in the ONNX graph.
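For context on that guess: as I understand it, complex tensors are carried through export as real tensors with a trailing real/imag dimension of 2, i.e. the view_as_real representation, which is where the extra dimension in symbolic shapes would come from. A quick illustration:

```python
import torch

# A complex tensor and its real-valued representation: the same data viewed
# with a trailing dimension of 2 holding (real, imag) pairs.
z = torch.randn(2, 257, 101, dtype=torch.complex64)
r = torch.view_as_real(z)
assert r.shape == (2, 257, 101, 2)

# The mapping is lossless in both directions.
assert torch.equal(torch.view_as_complex(r), z)
```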
Dynamo. I will post the error on Monday. |
Traceback

SARIF:

{
"runs":[
{
"tool":{
"driver":{
"name":"torch.onnx.dynamo_export",
"contents":[
"localizedData",
"nonLocalizedData"
],
"language":"en-US",
"rules":[],
"version":"2.2.0.dev20231128+cpu"
}
},
"language":"en-US",
"newlineSequences":[
"\r\n",
"\n"
],
"results":[]
}
],
"version":"2.1.0",
"schemaUri":"https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json"
} |
Using the same model/export script as of #115509
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_token_29n0' Status Message: C:\a\_work\1\s\onnxruntime\core/providers/cpu/math/element_wise_ops.h:560 onnxruntime::BroadcastIterator::Append axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 513 by 1024
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_c2r.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)
File "/home/mush42/projects/vocos/vocos/spectral_ops.py", line 57, in forward
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") |
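As a stopgap for the unsupported aten._fft_c2r, irfft can be expressed with plain matmuls against a precomputed inverse real-DFT basis, which uses only export-friendly ops. This is my own sketch (function name and layout are hypothetical), operating on the view_as_real layout:

```python
import torch

def irfft_via_matmul(spec_ri: torch.Tensor, n: int) -> torch.Tensor:
    # spec_ri: [..., n//2 + 1, 2] real/imag pairs (e.g. from torch.view_as_real).
    # Reconstructs the length-n real signal, matching torch.fft.irfft
    # with the default norm="backward".
    k = torch.arange(n // 2 + 1, dtype=torch.float32)
    t = torch.arange(n, dtype=torch.float32)
    angles = 2 * torch.pi * t[:, None] * k[None, :] / n
    # Hermitian symmetry: interior bins contribute twice; DC (and Nyquist,
    # when n is even) contribute once.
    scale = torch.full((n // 2 + 1,), 2.0)
    scale[0] = 1.0
    if n % 2 == 0:
        scale[-1] = 1.0
    cos_basis = torch.cos(angles) * scale / n   # [n, n//2 + 1]
    sin_basis = -torch.sin(angles) * scale / n
    real, imag = spec_ri[..., 0], spec_ri[..., 1]
    return real @ cos_basis.T + imag @ sin_basis.T
```

For large n this is O(n^2) rather than O(n log n), so it is a workaround for exportability, not a performance-equivalent replacement.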
Any workarounds for now? |
Is your onnxscript up to date? Could you share the version numbers? |
$ pip freeze | egrep "onnx"
onnx==1.15.0
onnxscript==0.1.0.dev20231209 |
Could you attach the onnx model? |
@mush42 Could you try with |
@titaiwangms Good point, although I don't think we support c2r yet in onnx? |
Ah, maybe you are right. I only recall I was blocked by r2c. |
@titaiwangms But running the resulting model using onnxruntime fails with the following error: onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (_inline_aten_mul_token_258n0) Op (Mul) [ShapeInferenceError] Incompatible dimensions |
@justinchuby the model is too large to attach here. I'll upload to drive and send the link. |
Hi people, is there any update to this issue or any workarounds? |
There are two possible solutions for this. They may or may not work depending on your model's architecture:
Best |
Here is a simple MWE, with a setup common to audio signal processing models:
Here are the short versions of error messages:
Without dynamic shapes (not useful to anyone using stft):

With dynamic shapes (as the example shows):

Exporting within the context of torch.inference_mode(), output is slightly different (prims vs. aten):

Relevant context (should be the latest in everything):
Originally posted by @shanecarroll-smarsh in #81075 (comment)