fix: pattern match gelu from contrib and onnx ops 🐛 #2364


Merged: 8 commits merged into microsoft:main on Jun 14, 2025

Conversation


@KarelZe KarelZe commented Jun 4, 2025

Previously, the domain for Gelu in the rules implementation was restricted to the contrib ops implementation, so Gelu from onnx ops (introduced with opset 20) was not fused.

This PR introduces pattern matching and tests for both variants.
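
For reference, a minimal function-style sketch of the two target patterns (assuming the `onnxscript.rewriter.pattern` API; the actual implementation in this PR may be structured differently):

```python
from onnxscript.rewriter import pattern


def bias_gelu_contrib(op, x, bias):
    # Gelu from the com.microsoft contrib ops domain.
    return op.Gelu(op.Add(x, bias), _domain="com.microsoft")


def bias_gelu_onnx(op, x, bias):
    # Gelu from the standard onnx domain, available since opset 20.
    return op.Gelu(op.Add(x, bias))


def bias_gelu(op, x, bias):
    # Both variants rewrite to the fused contrib op.
    return op.BiasGelu(x, bias, _domain="com.microsoft")


bias_gelu_rules = pattern.RewriteRuleSet(
    [
        pattern.RewriteRule(bias_gelu_contrib, bias_gelu),
        pattern.RewriteRule(bias_gelu_onnx, bias_gelu),
    ]
)
```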

Closes #2362.

@shubhambhokare1 @justinchuby Could you please review? Any feedback is greatly appreciated.


@justinchuby justinchuby left a comment


Thank you! Some minor comments


codecov bot commented Jun 4, 2025

Codecov Report

Attention: Patch coverage is 81.39535% with 8 lines in your changes missing coverage. Please review.

Project coverage is 70.17%. Comparing base (321cb41) to head (01718b8).
Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| onnxscript/rewriter/ort_fusions/bias_gelu_test.py | 72.41% | 8 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #2364      +/-   ##
==========================================
+ Coverage   70.15%   70.17%   +0.01%     
==========================================
  Files         197      197              
  Lines       24985    25013      +28     
  Branches     2669     2671       +2     
==========================================
+ Hits        17529    17552      +23     
- Misses       6529     6534       +5     
  Partials      927      927              
```


KarelZe and others added 2 commits June 5, 2025 09:07
@KarelZe KarelZe requested a review from justinchuby June 5, 2025 07:43

@justinchuby justinchuby left a comment


Thanks!

@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Jun 5, 2025
@justinchuby justinchuby requested a review from titaiwangms June 5, 2025 15:05
@titaiwangms titaiwangms added the "hold on merging" (Don't merge yet) label Jun 5, 2025
@titaiwangms titaiwangms requested a review from gramalingam June 5, 2025 16:32
@titaiwangms

Sorry for putting this on hold! As we also have the same issue with newly introduced ops like Attention and RotaryEmbedding, @gramalingam and I had a discussion about introducing a dispatcher-like mechanism. We have not decided yet how to do that. We will reply to this PR as soon as we have a conclusion.

@titaiwangms

Sorry, I just realized this PR is only about adding a pattern, not a rewrite. Based on microsoft/onnxruntime#19560, it seems we need to specify the `approximate="none"` attribute in the pattern; `approximate="tanh"` is a different Gelu variant.
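
A hedged sketch of that constraint on the onnx-domain pattern (attribute keywords in pattern functions act as match constraints; the exact semantics of the pattern API may differ):

```python
def bias_gelu_onnx(op, x, bias):
    # Match only the exact (erf-based) Gelu. approximate="tanh" computes a
    # different approximation and must not be rewritten to BiasGelu.
    return op.Gelu(op.Add(x, bias), approximate="none")
```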

@justinchuby

This comment was marked as outdated.

@titaiwangms titaiwangms removed the "hold on merging" (Don't merge yet) label Jun 5, 2025
@KarelZe KarelZe requested a review from justinchuby June 13, 2025 08:38

@justinchuby justinchuby left a comment


A final comment and we are ready to merge. Thank you!

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby enabled auto-merge (squash) June 13, 2025 15:15
@justinchuby

@KarelZe our CI pipeline is broken. We will fix it soon.

@justinchuby justinchuby merged commit ccaefc6 into microsoft:main Jun 14, 2025
28 of 32 checks passed
justinchuby pushed a commit that referenced this pull request Jun 16, 2025
Follow-up to #2364.

I noticed that the current `BiasGeluFusion` implementation from #2364
does not check the dimensions of the bias term, which can lead to
errors, as the bias input for `BiasGelu(...)` is expected to be 1-D (see
[here](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftbiasgelu)).
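
A minimal sketch of the added validity check (the hook name, signature, and shape access are illustrative, not the exact merged code):

```python
def check(context, x, bias, **_) -> bool:
    # BiasGelu expects its bias input to be 1-D; reject the fusion when the
    # bias rank is unknown or not equal to 1.
    if bias.shape is None or len(bias.shape) != 1:
        return False
    return True
```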

**minimal, complete example**
with:
```sh
uv pip install git+https://github.com/microsoft/onnxscript.git --force-reinstall
```
```python
import os

import numpy as np
import onnx_ir as ir
import torch
from onnxscript.rewriter.ort_fusions._core import fuse_xformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import onnxruntime as ort

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.eval()


class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for onnx export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids"]
output_names = ["encoder_output"]

onnx_path = "bart_encoder.onnx"

torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)

# Load the exported model into the IR and apply the ORT fusion passes.
onnx_model = ir.load(onnx_path)
onnx_model, stats = fuse_xformers(onnx_model)
print(stats)

optimized_path = "optimized_model.onnx"
ir.save(onnx_model, optimized_path)

# Compare outputs of the original and the optimized model.
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})

sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})

abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)
```

```
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
2025-06-15 20:52:33.994324 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:33.994582 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/activation_fn/Gelu_output_0' source:{4} target:{-1,-1,4}. Falling back to lenient merge.
2025-06-15 20:52:34.007963 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008178 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/MatMul_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008753 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.0/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.008944 [W:onnxruntime:, graph.cc:118 MergeShapeInfo] Error merging shape info for output. '/encoder/layers.1/fc2/Add_output_0' source:{16} target:{-1,-1,16}. Falling back to lenient merge.
2025-06-15 20:52:34.018753 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
...
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running BiasGelu node. Name:'node_BiasGelu_26' Status Message: Input 1 is expected to have 1 dimensions, got 3
```

with:
```sh
uv pip install git+https://github.com/karelze/onnxscript.git@fix-bias-gelu-shape --force-reinstall
```
```
Applied 1 of general pattern rewrite rules.
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'partial_rotary_embedding': 0, 'cos_sin_cache': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
abs_difference 0.0
```


This PR adds:
- additional checks for the dimensionality of the bias
- additional test cases

Sorry for the inconvenience.

@justinchuby @titaiwangms

Successfully merging this pull request may close these issues.

Domain of Gelu in bias_gelu_rules set