Describe the issue
Thanks for your awesome work on onnxruntime 💯.
I noticed that attention fusion is currently broken for BART when models are exported with a recent Transformers version, due to changes in the SDPA implementation (see here). @tianleiwu recently updated attention fusion for BERT, among others (see #22629). This issue is probably a follow-up for BART, which is still broken.
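As a side note, a possible workaround (untested here; it assumes the eager, non-SDPA attention path still produces the subgraph the existing BART fusion patterns were written against) is to load the model with the eager attention implementation before export:

import torch
from transformers import AutoModelForSeq2SeqLM

# Assumption: attn_implementation="eager" forces the pre-SDPA attention
# subgraph, which the current fusion_bart_attention patterns may still match.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "hf-internal-testing/tiny-random-bart",
    attn_implementation="eager",
)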
To reproduce
Dependencies:
onnx==1.17.0
onnxruntime==1.22.0.dev20250227005
transformers==4.49.0
BART
For BART, attention fusion fails:
import torch
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()


class EncoderWrapper(torch.nn.Module):
    """Wraps the BART encoder so it returns a plain tensor for ONNX export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids", "attention_mask"]
output_names = ["encoder_output"]
onnx_path = "bart_model.onnx"

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=16,
)
print(f"BART encoder exported to {onnx_path}")

optimization_options = FusionOptions("bart")
optimization_options.enable_attention = True

m = optimizer.optimize_model(
    onnx_path,
    model_type="bart",
    num_heads=0,
    hidden_size=0,
    opt_level=2,
    use_gpu=False,
    verbose=True,
    optimization_options=optimization_options,
    only_onnxruntime=False,
)
optimized_path = "bart_encoder_optimized.onnx"
m.save_model_to_file(optimized_path)
print(f"Optimized ONNX model saved to {optimized_path}")
print(m.get_fused_operator_statistics())
Output: no Attention or MultiHeadAttention nodes were fused:

{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 2, 'GemmFastGelu': 0, 'LayerNormalization': 1, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 4, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
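To double-check that this is not just a reporting issue in the statistics, the op types in the optimized graph can be counted directly (a minimal sketch; bart_encoder_optimized.onnx is the file produced above):

import onnx
from collections import Counter

m_opt = onnx.load("bart_encoder_optimized.onnx")
# Count node op types in the optimized graph: the plain MatMul/Softmax
# SDPA subgraph is still present and no com.microsoft Attention node appears.
print(Counter(node.op_type for node in m_opt.graph.node))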
BERT
For BERT, attention fusion is successful:
import torch
from onnxruntime.transformers import optimizer
from transformers import AutoTokenizer, BertModel

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
model.eval()

text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")

input_names = ["input_ids", "attention_mask"]
output_names = ["output"]
onnx_path = "bert_model.onnx"

torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
    verbose=True,
)
print(f"Model exported to {onnx_path}")

m = optimizer.optimize_model(
    onnx_path,
    model_type="bert",
    num_heads=0,
    hidden_size=0,
    opt_level=2,
    use_gpu=False,  # was use_gpu="cpu"; the parameter expects a bool
    verbose=True,
)
m.save_model_to_file("bert_model_optimized.onnx")
print(m.get_fused_operator_statistics())
Output: all five attention layers were fused into Attention nodes:

{'EmbedLayerNormalization': 1, 'Attention': 5, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 5, 'GemmFastGelu': 0, 'LayerNormalization': 0, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 10, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
- Are there any plans to make attention fusion work again for BART, like for BERT, when exported with a recent transformers version?
- If I export the attention mechanism as an ONNX function (EDIT: looks like torch.onnx.export(..., export_modules_as_functions={BartSdpaAttention}) does this), I could try to use onnxscript to perform attention fusion via onnxscript.rewriter (see here); a sketch follows after this list. Is this the preferred/modern way to perform attention fusion? onnxscript has deprecated support for function-based rewrite rules, so this would require matching the subgraph instead (see "Refactor ort specific fusions" onnxscript#2039).
- Happy to contribute a PR. I just have little knowledge about the pattern matching mechanisms in onnxruntime/python/tools/transformers/fusion_bart_attention.py, but I would be willing to learn.
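For reference, a minimal sketch of the pattern-based onnxscript.rewriter API, adapted from the onnxscript tutorial. This is a toy Gelu rewrite, not BART attention fusion; an attention rule would analogously match the SDPA MatMul/Softmax/MatMul subgraph and replace it with a com.microsoft Attention or MultiHeadAttention node:

import math

import onnx
from onnxscript import rewriter
from onnxscript.rewriter import pattern


# Target pattern: the Erf-based Gelu subgraph as exported from PyTorch.
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))


# Replacement: a single fused Gelu node in the com.microsoft domain.
def gelu(op, x):
    return op.Gelu(x, domain="com.microsoft")


rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
model_proto = onnx.load("bart_model.onnx")  # stand-in model from the repro above
rewritten = rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])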
Urgency
Medium.
Platform
Linux
OS Version
Ubuntu 22.04.5 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
onnxruntime==1.22.0.dev20250227005
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response