fix: pattern match gelu from contrib and onnx ops🐛 #2364
Conversation
Thank you! Some minor comments
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #2364      +/-   ##
==========================================
+ Coverage   70.15%   70.17%   +0.01%
==========================================
  Files         197      197
  Lines       24985    25013      +28
  Branches     2669     2671       +2
==========================================
+ Hits        17529    17552      +23
- Misses       6529     6534       +5
  Partials      927      927
```

☔ View full report in Codecov by Sentry.
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Thanks!
Sorry for putting this on hold! Since we have the same issue with newly introduced ops like Attention and RotaryEmbedding, @gramalingam and I discussed introducing a dispatcher-like mechanism. We have not yet decided how to do that. We will reply to this PR as soon as we have a conclusion.
Sorry, I just realized this PR is only about adding a pattern, not a rewrite. Based on PR microsoft/onnxruntime#19560, it seems we need to specify the attribute `approximate="none"` in the pattern; the `tanh` form is another Gelu variant.
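To illustrate the point about the attribute (a sketch only, assuming the onnxscript rewriter's function-style patterns; the names `erf_gelu_pattern` and `tanh_gelu_pattern` are made up for this example): opset-20 `Gelu` spells both forms with the same op type, so a pattern has to pin down `approximate` to avoid accidentally matching the tanh variant.

```python
# Sketch: two patterns for the ONNX (opset >= 20) Gelu, distinguished by the
# `approximate` attribute, so a rule written for the erf-based Gelu does not
# also match the tanh approximation.
def erf_gelu_pattern(op, x):
    return op.Gelu(x, approximate="none")


def tanh_gelu_pattern(op, x):
    return op.Gelu(x, approximate="tanh")
```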
A final comment and we are ready to merge. Thank you!
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
@KarelZe our CI pipeline is broken. We will fix it soon.
Follow-up to #2364. I noticed that the current `BiasGeluFusion` implementation from #2364 does not check the dimensions of the bias term, which can lead to errors, as the bias input for `BiasGelu(...)` is expected to be 1-D (see [here](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftbiasgelu)).

**Minimal, complete example**

with:

```sh
uv pip install git+https://github.com/microsoft/onnxscript.git --force-reinstall
```

```python
import os

import numpy as np
import onnx_ir as ir
import torch
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import onnxruntime as ort

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "hf-internal-testing/tiny-random-bart"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()


class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for onnx export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids"]
output_names = ["encoder_output"]

onnx_path = "bart_encoder.onnx"

torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)

onnx_model = ir.load(onnx_path)
onnx_model, stats = fuse_xformers(onnx_model)
print(stats)

optimized_path = "optimized_model.onnx"
ir.save(onnx_model, optimized_path)

sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})

sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})

abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)
```

```
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
2025-06-15 20:52:33.994324 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:33.994582 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:34.007963 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008178 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008753 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008944 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.018753 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
...
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
```

with:

```sh
uv pip install git+https://github.com/karelze/onnxscript.git@fix-bias-gelu-shape --force-reinstall
```

```
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
abs_difference 0.0
```

This PR adds:

- additional checks for the dimensionality of the bias (see the sketch after this list)
- additional test cases

Sorry for the inconvenience. @justinchuby @titaiwangms
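For illustration, here is a minimal sketch of the kind of guard described above; the helper name `_bias_is_1d` and the use of `onnx_ir` value shapes are assumptions for this example, not the exact code in the PR:

```python
import onnx_ir as ir


def _bias_is_1d(bias: ir.Value) -> bool:
    """Hypothetical check: only allow the BiasGelu fusion when the bias is known to be 1-D.

    com.microsoft::BiasGelu requires a 1-D bias, so the rewrite should be
    skipped when the rank is unknown or different from 1.
    """
    shape = bias.shape  # None if shape inference has not produced a shape
    return shape is not None and len(shape) == 1
```

A check along these lines would make the fusion bail out on the 3-D bias that triggered the `Input 1 is expected to have 1 dimensions, got 3` error above.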
Previously, the domain for Gelu in the rules implementation was restricted to the contrib ops implementation, so Gelu from the ONNX ops (introduced with opset 20) was not fused.
This PR introduces pattern matching + tests for both variants.
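As a rough sketch of what matching both variants can look like (assuming the rewriter's `_domain` convention for pattern functions; the names are illustrative, not the code in this PR):

```python
# Sketch: the same Bias + Gelu shape, matched once against the com.microsoft
# contrib Gelu and once against the standard ONNX Gelu (opset 20+).
def bias_contrib_gelu_pattern(op, x, bias):
    return op.Gelu(op.Add(x, bias), _domain="com.microsoft")


def bias_onnx_gelu_pattern(op, x, bias):
    return op.Gelu(op.Add(x, bias))
```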
Closes #2362.
@shubhambhokare1 @justinchuby Could you please review? Any feedback is greatly appreciated.