Description
I'm currently experimenting with various BART-based models and noticed that attention fusion (SDPA/MHA) is broken for BART.
Dependencies:
(onnxruntime) markusbilz@Markuss-Mini onnxruntime % uv pip show transformers torch onnxscript onnx-ir
Name: onnx-ir
Version: 0.1.3
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, onnx, typing-extensions
Required-by: onnxscript
---
Name: onnxscript
Version: 0.3.1
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, onnx, onnx-ir, packaging, typing-extensions
Required-by: model-explorer-onnx
---
Name: torch
Version: 2.7.1
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by:
---
Name: transformers
Version: 4.52.4
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by:
MWE:
import itertools
import onnxscript.ir as ir
import torch
from onnxscript.rewriter.ort_fusions import optimize_for_ort
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
class BartEncoder(torch.nn.Module):
    """A wrapper around the BART encoder for ONNX export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]
class BartDecoderInit(torch.nn.Module):
    """BART decoder for the initial decoding step."""

    def __init__(self, decoder: torch.nn.Module):
        """Init.

        Args:
            decoder (torch.nn.Module): decoder model.
        """
        super().__init__()
        self.decoder = decoder

    def forward(self, encoder_hidden_states: torch.Tensor, decoder_input_ids: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Forward pass.

        Args:
            encoder_hidden_states (torch.Tensor): hidden states of the encoder.
            decoder_input_ids (torch.Tensor): input ids for the decoder, i.e. the token IDs of the prompt.

        Returns:
            tuple[torch.Tensor, ...]: last_hidden_state, self-attention key 0, self-attention value 0, cross-attention key 0, cross-attention value 0, ...
        """
        decoder_output = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=True,
            return_dict=True,
        )
        pkv = decoder_output.past_key_values
        return decoder_output.last_hidden_state, *itertools.chain.from_iterable(pkv)
class BartDecoderWithPast(torch.nn.Module):
    """BART decoder with past key values."""

    def __init__(self, decoder: torch.nn.Module):
        """Init.

        Args:
            decoder (torch.nn.Module): decoder model.
        """
        super().__init__()
        self.decoder = decoder

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        past_key_values: tuple[torch.Tensor, ...],
    ) -> tuple[torch.Tensor, ...]:
        """Forward pass.

        Requires inputs:
            decoder_input_ids
            encoder_hidden_states
            past_key_self_0, past_value_self_0 (for each self-attention layer)
            past_key_cross_0, past_value_cross_0 (for each cross-attention layer)
            ...

        Outputs:
            last_hidden_state,
            present_key_self_0, present_value_self_0 (for each self-attention layer)

        Args:
            decoder_input_ids (torch.Tensor): decoder input ids.
            encoder_hidden_states (torch.Tensor): final hidden states of the encoder.
            past_key_values (tuple[torch.Tensor, ...]): past key values from the previous decoding iteration.

        Returns:
            tuple[torch.Tensor, ...]: last_hidden_state, self-attention tensors.
        """
        # number of decoder layers; read from the globally defined model config
        decoder_layers = model.config.decoder_layers
        # regroup the flat input tuple into one 4-tuple (self k/v, cross k/v) per layer
        pkv_in = []
        for i in range(decoder_layers):
            pkv_in.append(
                (
                    past_key_values[i * 4],
                    past_key_values[i * 4 + 1],
                    past_key_values[i * 4 + 2],
                    past_key_values[i * 4 + 3],
                )
            )
        decoder_output = self.decoder(
            input_ids=decoder_input_ids[:, -1:],
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=pkv_in,
            use_cache=True,
            return_dict=True,
        )
        # omit cross-attention keys/values as they remain constant across decoding steps
        self_att = []
        pkv = decoder_output.past_key_values
        for present_key_self, present_value_self, _, _ in pkv:
            self_att.extend([present_key_self, present_value_self])
        return decoder_output.last_hidden_state, *self_att
model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model.eval()
encoder = BartEncoder(encoder=model.model.encoder)
# same as above.
text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")
input_ids_encoder, attention_mask = inputs["input_ids"], inputs["attention_mask"]
with torch.no_grad():
    dummy_hidden_states = encoder(input_ids_encoder, attention_mask)
len_hidden_states = dummy_hidden_states.shape[1]
decoder_init_torch = BartDecoderInit(decoder=model.model.decoder)
dummy_input_ids_decoder = torch.randint(0, config.vocab_size, (1, 1), dtype=torch.int32)
with torch.no_grad():
    decoder_outputs = decoder_init_torch(dummy_hidden_states, dummy_input_ids_decoder)
input_names = [
    "encoder_hidden_states",
    "decoder_input_ids",
]
pkv = []
for i in range(model.config.decoder_layers):
    pkv.extend(
        [f"present_key_self_{i}", f"present_value_self_{i}", f"present_key_cross_{i}", f"present_value_cross_{i}"]
    )
# dynamic axes
dynamic_axes = {}
for p in pkv:
    dynamic_axes[p] = {2: "encoder_sequence_length_out" if "cross" in p else "past_decoder_sequence_length+1"}
output_names = ["last_hidden_state", *pkv]
torch.onnx.export(
    decoder_init_torch,
    (
        dummy_hidden_states,
        dummy_input_ids_decoder,
    ),
    "bart_decoder_init.onnx",
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=20,
)
onnx_model = ir.load("bart_decoder_init.onnx")
onnx_model, stats = optimize_for_ort(onnx_model, debug=True)
print(stats)
decoder_with_past_torch = BartDecoderWithPast(decoder=model.model.decoder)
pkv_in, pkv_out = [], []
for i in range(model.config.decoder_layers):
    pkv_in.extend([f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}"])
for i in range(model.config.decoder_layers):
    pkv_out.extend([f"present_key_self_{i}", f"present_value_self_{i}"])
# dynamic axes
dynamic_axes = {}
dynamic_axes["last_hidden_state"] = {1: "decoder_sequence_length"}
for p in pkv_in:
    dynamic_axes[p] = {2: "encoder_sequence_length" if "cross" in p else "past_decoder_sequence_length"}
for p in pkv_out:
    dynamic_axes[p] = {2: "past_decoder_sequence_length+1"}
input_names = ["input_ids", "encoder_hidden_states", *pkv_in]
output_names = ["last_hidden_state", *pkv_out]
# random encoder hidden states with the model's hidden size (config.d_model)
dummy_hidden_states = torch.empty((1, len_hidden_states, config.d_model), dtype=torch.float32).uniform_(0, 1)
decoder_inputs = (
    dummy_input_ids_decoder,
    dummy_hidden_states,
    decoder_outputs[1:],
)
decoder_with_past_outputs = decoder_with_past_torch(*decoder_inputs)
torch.onnx.export(
    decoder_with_past_torch,
    decoder_inputs,
    "bart_decoder_with_past.onnx",
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=20,
)
onnx_model = ir.load("bart_decoder_with_past.onnx")
onnx_model, stats = optimize_for_ort(onnx_model, debug=True)
print(stats)
Outputs:
...
Graph matching failed: Attribute key_format mismatch: expected BHSd, got BSHd.
Failure at or around nodes/values:
Node: 'node_SDPA_119'
...
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'cos_sin_cache': 0, 'partial_rotary_embedding': 0, 'sdpa': 4, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
...
------------------------------
Rule: SDPA
--------------------------------------------------------------------------------
Status: CONDITION_FAILED
Graph matching failed due to failing check condition : query_scale is not a scalar.
...
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'cos_sin_cache': 0, 'partial_rotary_embedding': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
Prob. 1: For the decoder (initial decoding step), attention is only fused into the temporary SDPA node. The subsequent fusion into MultiHeadAttention (or Attention) fails because of the key_format mismatch reported above. The resulting graph cannot be executed with onnxruntime because of the inserted temporary SDPA node.
screenshot from netron for bart_decoder_init.onnx
screenshot from netron for bart_decoder_init_onnxscript_fused.onnx
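To make the failure concrete, here is a minimal sketch of what I mean by "cannot be executed": save the partially fused decoder-init model and try to create an onnxruntime session. The output file name and the generic exception handling are placeholders of mine.
import onnxruntime as ort
import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions import optimize_for_ort

fused_model, _ = optimize_for_ort(ir.load("bart_decoder_init.onnx"))
ir.save(fused_model, "bart_decoder_init_fused.onnx")
try:
    # expected to fail: the graph still contains the temporary SDPA node,
    # which is not an operator onnxruntime implements
    ort.InferenceSession("bart_decoder_init_fused.onnx", providers=["CPUExecutionProvider"])
except Exception as exc:
    print(exc)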
To my current understanding, the error is caused by incorrect handling of the key transpose.
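For reference, a minimal illustration of the two key layouts the check distinguishes, as I read the error message (the dimension names B/S/H/d for batch, sequence, heads, and head dimension are my interpretation, not taken from the onnxscript source):
import torch

B, S, H, d = 1, 8, 4, 16  # batch, sequence length, attention heads, head dim
key_BSHd = torch.randn(B, S, H, d)   # "BSHd": heads after the sequence axis
key_BHSd = key_BSHd.transpose(1, 2)  # "BHSd": heads before the sequence axis
assert key_BHSd.shape == (B, H, S, d)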
Prob. 2: For the decoder with past, query_scale cannot be retrieved from the graph and is None in the SDPA rule's check condition ("query_scale is not a scalar"), so the fusion is rejected.
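For context, this is the scalar query scale the SDPA pattern presumably looks for; a plain PyTorch illustration in which the tensor shapes and the placement of the scale on the query are assumptions of mine, not the onnxscript internals:
import math

import torch

head_dim = 16
query = torch.randn(1, 4, 8, head_dim)  # (batch, heads, seq, head_dim)
key = torch.randn(1, 4, 8, head_dim)
value = torch.randn(1, 4, 8, head_dim)

# the fusion expects a scalar constant such as 1/sqrt(head_dim); in the
# decoder-with-past graph this scalar apparently cannot be recovered,
# so query_scale ends up as None
scale = 1.0 / math.sqrt(head_dim)
attn = torch.softmax((query * scale) @ key.transpose(-1, -2), dim=-1) @ value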
@justinchuby / @gramalingam I could investigate the issues further and potentially provide a fix. So far, I was able to fuse all attention mechanisms into MultiHeadAttention nodes using custom rewrite rules.