
fix: BART attention fusion for key with bias 🐛 #25046


Open · wants to merge 2 commits into main

Conversation


@KarelZe commented on Jun 13, 2025

Description

With #24857, attention fusion for Whisper (and BART) was revamped. 💯 This PR extends that work and adds support for attention fusion in BART encoders whose key projection carries a bias term.
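
For context, a minimal sketch of the pattern involved (assuming the standard Hugging Face `BartAttention` layout, where the key projection is a `Linear` with `bias=True`): the exported ONNX graph then contains a MatMul followed by a bias Add on the key path, which the fusion pattern must match.

```python
import torch

# Assumption: standard Hugging Face BartAttention layout, where the key
# projection is a Linear with bias=True. In the ONNX export this becomes
# a MatMul followed by a bias Add on the key path.
embed_dim = 16
k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True)

hidden_states = torch.randn(1, 4, embed_dim)
keys = k_proj(hidden_states)  # MatMul + Add in the exported graph
```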

Minimal reproducible example:

```
(onnxruntime) markusbilz@Markuss-Mini git % uv pip show transformers
Using Python 3.11.10 environment at onnxruntime/.venv
Name: transformers
Version: 4.52.4
```
```python
import os

import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import onnxruntime as ort
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.eval()


class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for ONNX export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids"]
output_names = ["encoder_output"]

onnx_path = "bart_model.onnx"

# Export the wrapped encoder to ONNX with dynamic batch and sequence axes.
torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)
print(f"BART encoder exported to {onnx_path}")

optimization_options = FusionOptions("bart")
optimization_options.enable_attention = True

# num_heads=0 and hidden_size=0 let the optimizer infer both from the graph.
m = optimizer.optimize_model(
    onnx_path,
    model_type="bart",
    num_heads=0,
    hidden_size=0,
    opt_level=2,
    use_gpu=False,
    verbose=True,
    optimization_options=optimization_options,
    only_onnxruntime=False,
)

optimized_path = "bart_encoder_optimized.onnx"
m.save_model_to_file(optimized_path)

print(f"Optimized ONNX model saved to {optimized_path}")
print(m.get_fused_operator_statistics())

# Compare the original and the optimized model on the same input.
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})

sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})

abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)
```

Output after this PR:

```
Please specify parameters of num_heads and hidden_size for model_type bart
Optimized ONNX model saved to bart_encoder_optimized.onnx
{'EmbedLayerNormalization': 1, 'Attention': 2, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 2, 'GemmFastGelu': 0, 'LayerNormalization': 0, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 4, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
abs_difference 2.3841858e-07
```
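
The statistics above show both encoder self-attention blocks fused into `Attention` nodes. As an independent check, one can also count operator types in the optimized graph directly; a small sketch using the `onnx` package (the file name is taken from the script above):

```python
from collections import Counter

import onnx

# Count operator types in the optimized graph; after this PR the two
# encoder self-attention blocks should appear as fused Attention nodes.
graph = onnx.load("bart_encoder_optimized.onnx").graph
print(Counter(node.op_type for node in graph.node))
```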

Motivation and Context

Extends #24857. Closes #23864.

@kunal-vaishnavi @justinchuby Could you please review? I'd also like to add a test case. Could you provide some guidance on where it should go, e.g. adding modelling code to onnxruntime/test/python/transformers/test_bart.py? Any feedback is greatly appreciated.

@kunal-vaishnavi (Contributor) commented

Thanks for your contribution!

> I'd also like to add a test case. Could you provide some guidance on where it should go, e.g. adding modelling code to onnxruntime/test/python/transformers/test_bart.py? Any feedback is greatly appreciated.

Yes, you can add your test cases there. You can use test_whisper.py as an example of how to set it up.
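
A hypothetical sketch of what such a parity test could look like, condensing the repro script above into a unittest; the class and method names are illustrative and the structure is not taken from test_whisper.py:

```python
# Hypothetical sketch of a fusion parity test; names and structure are
# illustrative, not the actual layout of test_whisper.py.
import unittest

import numpy as np
import onnxruntime as ort
from onnxruntime.transformers import optimizer


class TestBartAttentionFusion(unittest.TestCase):
    def test_encoder_attention_fusion_with_key_bias(self):
        # Assumes bart_model.onnx was exported as in the repro script above.
        onnx_path = "bart_model.onnx"
        m = optimizer.optimize_model(onnx_path, model_type="bart", num_heads=0, hidden_size=0)
        optimized_path = "bart_encoder_optimized.onnx"
        m.save_model_to_file(optimized_path)

        # Both encoder self-attention blocks should be fused.
        self.assertEqual(m.get_fused_operator_statistics()["Attention"], 2)

        # Outputs of the original and the optimized model should agree closely.
        # Token ids are illustrative; any ids within the tiny model's vocab work.
        input_ids = np.array([[0, 5, 7, 2]], dtype=np.int64)
        sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        sess_opt = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
        (out,) = sess.run(None, {"input_ids": input_ids})
        (out_opt,) = sess_opt.run(None, {"input_ids": input_ids})
        np.testing.assert_allclose(out, out_opt, atol=1e-5)


if __name__ == "__main__":
    unittest.main()
```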

Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Development

Successfully merging this pull request may close these issues.

Attention fusion broken for BART 🤖