
Attention fusion broken for BART 🤖 #23864

Open
@KarelZe

Description

Describe the issue

Thanks for your awesome work on onnxruntime 💯.

I noticed that attention fusion is currently broken for BART when models are exported with a recent Transformers version, due to changes in the SDPA implementation (see here). @tianleiwu recently updated attention fusion for BERT among others (see #22629). This issue is probably a follow-up for BART, which is still broken.
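One data point worth checking (an untested assumption on my side): forcing the eager attention implementation at load time should restore the pre-SDPA subgraph that the fuser was written against:

from transformers import AutoModelForSeq2SeqLM

# Untested assumption: loading with the eager (pre-SDPA) attention
# implementation restores the old subgraph that the BART fuser matches.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "hf-internal-testing/tiny-random-bart",
    attn_implementation="eager",  # the default in transformers 4.49 is "sdpa"
)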

To reproduce

Dependencies:

onnx==1.17.0
onnxruntime==1.22.0.dev20250227005
transformers==4.49.0

BART

For BART, attention fusion fails:

import torch
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.eval()


class EncoderWrapper(torch.nn.Module):
    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids", "attention_mask"]
output_names = ["encoder_output"]

onnx_path = "bart_model.onnx"

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=16,
)
print(f"BART encoder exported to {onnx_path}")

optimization_options = FusionOptions("bart")
optimization_options.enable_attention = True


m = optimizer.optimize_model(
    onnx_path,
    model_type="bart",
    num_heads=0,
    hidden_size=0,
    opt_level=2,
    use_gpu=False,
    verbose=True,
    optimization_options=optimization_options,
    only_onnxruntime=False,
)

optimized_path = "bart_encoder_optimized.onnx"
m.save_model_to_file(optimized_path)

print(f"Optimized ONNX model saved to {optimized_path}")
print(m.get_fused_operator_statistics())
Output, with no Attention or MultiHeadAttention nodes fused:

{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 2, 'GemmFastGelu': 0, 'LayerNormalization': 1, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 4, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
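A quick way to confirm that attention stays decomposed is to diff the op-type counts of the exported and optimized models (file names as in the script above):

import onnx
from collections import Counter

# With fusion broken, the optimized graph should still contain the decomposed
# SDPA ops (MatMul/Softmax) rather than a single fused Attention node.
for path in ("bart_model.onnx", "bart_encoder_optimized.onnx"):
    counts = Counter(node.op_type for node in onnx.load(path).graph.node)
    print(path, {op: counts[op] for op in ("MatMul", "Softmax", "Attention")})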

BERT

For BERT, attention fusion succeeds:

import torch
from onnxruntime.transformers import optimizer
from transformers import AutoTokenizer, BertModel

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

model.eval()

text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")
input_names = ["input_ids", "attention_mask"]
output_names = ["output"]

onnx_path = "bert_model.onnx"

torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    onnx_path,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
    verbose=True,
)
print(f"Model exported to {onnx_path}")

m = optimizer.optimize_model(
    onnx_path,
    model_type="bert",
    num_heads=0,
    hidden_size=0,
    opt_level=2,
    use_gpu="cpu",
    verbose=True,
)

m.save_model_to_file("bert_model_optimized.onnx")
print(m.get_fused_operator_statistics())
Output, with five fused Attention nodes:

{'EmbedLayerNormalization': 1, 'Attention': 5, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 5, 'GemmFastGelu': 0, 'LayerNormalization': 0, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 10, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
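For completeness, the fused model can be checked for numerical parity against the plain export (a minimal sketch reusing feeds from the script above):

import numpy as np
import onnxruntime as ort

feeds = {
    "input_ids": inputs["input_ids"].numpy(),
    "attention_mask": inputs["attention_mask"].numpy(),
}

# Run the unoptimized and optimized models on CPU and compare outputs.
ref = ort.InferenceSession("bert_model.onnx", providers=["CPUExecutionProvider"])
opt = ort.InferenceSession("bert_model_optimized.onnx", providers=["CPUExecutionProvider"])
np.testing.assert_allclose(
    ref.run(None, feeds)[0], opt.run(None, feeds)[0], rtol=1e-3, atol=1e-4
)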
  1. Are there any plans to make attention fusion work again for BART, as it now does for BERT, when models are exported with a recent transformers version?
  2. If I export the attention modules as ONNX functions via torch.onnx.export(..., export_modules_as_functions={BartSdpaAttention}), I could try to perform attention fusion with onnxscript.rewriter (see here); a sketch of this export follows below. Is this the preferred/modern way to perform attention fusion? EDIT: It looks like onnxscript has deprecated support for function-based rewrite rules, which would instead require matching the subgraph (see Refactor ort specific fusions onnxscript#2039).
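For reference, the function-preserving export from question 2 would look roughly like this (a sketch; export_modules_as_functions requires opset >= 15, and I have not verified the import path of BartSdpaAttention in transformers 4.49 or that the resulting function body is stable enough to match on):

import torch
from transformers.models.bart.modeling_bart import BartSdpaAttention

# Sketch: export every BartSdpaAttention module as a local ONNX function, so a
# rewriter could match the single function call instead of the decomposed
# subgraph. `model`, `input_ids`, and `attention_mask` as in the repro above.
torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "bart_model_with_functions.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["encoder_output"],
    opset_version=16,
    export_modules_as_functions={BartSdpaAttention},
)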

Happy to contribute a PR. I have little knowledge of the pattern-matching mechanisms in onnxruntime/python/tools/transformers/fusion_bart_attention.py, but I am willing to learn.

Urgency

Medium.

Platform

Linux

OS Version

Ubuntu 22.04.5 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnxruntime==1.22.0.dev20250227005

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response
