
Attention fusion (SDPA/MHA) broken for BART decoder with/without past 🐛 #2424

Open

Description

@KarelZe

I'm currently experimenting with various BART-based models and noticed that attention fusion (SDPA/MHA) is broken for BART.

Dependencies:

(onnxruntime) markusbilz@Markuss-Mini onnxruntime % uv pip show transformers torch onnxscript onnx-ir 
Name: onnx-ir
Version: 0.1.3
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, onnx, typing-extensions
Required-by: onnxscript
---
Name: onnxscript
Version: 0.3.1
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, onnx, onnx-ir, packaging, typing-extensions
Required-by: model-explorer-onnx
---
Name: torch
Version: 2.7.1
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by:
---
Name: transformers
Version: 4.52.4
Location: /Users/markusbilz/Documents/git/onnxruntime/.venv/lib/python3.11/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by:

MWE:

import itertools

import onnxscript.ir as ir
import torch
from onnxscript.rewriter.ort_fusions import optimize_for_ort
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)


class BartEncoder(torch.nn.Module):
    """A wrapper around the BART encoder for onnx export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


class BartDecoderInit(torch.nn.Module):
    """BART decoder for initial decoding step."""

    def __init__(self, decoder: torch.nn.Module):
        """Init.

        Args:
            decoder (torch.nn.Module): decoder model.
        """
        super().__init__()
        self.decoder = decoder

    def forward(self, encoder_hidden_states: torch.Tensor, decoder_input_ids: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Forward pass.

        Args:
            encoder_hidden_states (torch.Tensor): hidden states of encoder.
            decoder_input_ids (torch.Tensor): input ids for the decoder, i.e., the token IDs of the prompt

        Returns:
            tuple[torch.Tensor,...]: last_hidden_state, self-attention key 0, self-attention value 0, cross-attention key 0, cross-attention value 0, ...
        """
        decoder_output = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=True,
            return_dict=True,
        )
        pkv = decoder_output.past_key_values
        return decoder_output.last_hidden_state, *itertools.chain.from_iterable(pkv)


class BartDecoderWithPast(torch.nn.Module):
    """BartDecoder with past."""

    def __init__(self, decoder: torch.nn.Module):
        """Init.

        Args:
            decoder (torch.nn.Module): decoder model.
        """
        super().__init__()
        self.decoder = decoder

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        past_key_values: tuple[torch.Tensor, ...],
    ) -> tuple[torch.Tensor, ...]:
        """Forward pass.

        Inputs:
            decoder_input_ids
            encoder_hidden_states
            past_key_self_0, past_value_self_0 (for each self-attention layer)
            past_key_cross_0, past_value_cross_0 (for each cross-attention layer)
            ...
        Outputs:
            last_hidden_state,
            present_key_self_0, present_value_self_0 (for each self-attention layer)

        Args:
            decoder_input_ids (torch.Tensor): decoder input ids
            encoder_hidden_states (torch.Tensor): final hidden states of encoder
            past_key_values (tuple[torch.Tensor, ...]): past key values from the previous decoding iteration.

        Returns:
            tuple[torch.Tensor, ...]: last_hidden_state, self-attention key/value tensors
        """
        decoder_layers = self.decoder.config.decoder_layers  # use the wrapped decoder's config rather than the global model

        pkv_in = []
        for i in range(decoder_layers):
            pkv_in.append(
                (
                    past_key_values[i * 4],
                    past_key_values[i * 4 + 1],
                    past_key_values[i * 4 + 2],
                    past_key_values[i * 4 + 3],
                )
            )

        decoder_output = self.decoder(
            input_ids=decoder_input_ids[:, -1:],
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=pkv_in,
            use_cache=True,
            return_dict=True,
        )
        self_att = []
        pkv = decoder_output.past_key_values
        # omit cross-attention keys/values as they remain constant.
        for present_key_self, present_value_self, _, _ in pkv:
            self_att.extend([present_key_self, present_value_self])
        return decoder_output.last_hidden_state, *self_att


model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

model.eval()

encoder = BartEncoder(encoder=model.model.encoder)

# sample input used to trace the encoder.
text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")
input_ids_encoder, attention_mask = inputs["input_ids"], inputs["attention_mask"]

with torch.no_grad():
    dummy_hidden_states = encoder(input_ids_encoder, attention_mask)
    len_hidden_states = dummy_hidden_states.shape[1]

decoder_init_torch = BartDecoderInit(decoder=model.model.decoder)

dummy_input_ids_decoder = torch.randint(0, config.vocab_size, (1, 1), dtype=torch.int32)
with torch.no_grad():
    decoder_outputs = decoder_init_torch(dummy_hidden_states, dummy_input_ids_decoder)
input_names = [
    "encoder_hidden_states",
    "decoder_input_ids",
]
pkv = []
for i in range(model.config.decoder_layers):
    pkv.extend(
        [f"present_key_self_{i}", f"present_value_self_{i}", f"present_key_cross_{i}", f"present_value_cross_{i}"]
    )

# dynamic axis
dynamic_axes = {}
for p in pkv:
    dynamic_axes[p] = {2: "encoder_sequence_length_out" if "cross" in p else "past_decoder_sequence_length+1"}
output_names = ["last_hidden_state", *pkv]

torch.onnx.export(
    decoder_init_torch,
    (
        dummy_hidden_states,
        dummy_input_ids_decoder,
    ),
    "bart_decoder_init.onnx",
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=20,
)

onnx_model = ir.load("bart_decoder_init.onnx")
onnx_model, stats = optimize_for_ort(onnx_model, debug=True)
print(stats)

decoder_with_past_torch = BartDecoderWithPast(decoder=model.model.decoder)

pkv_in, pkv_out = [], []
for i in range(model.config.decoder_layers):
    pkv_in.extend([f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}"])
for i in range(model.config.decoder_layers):
    pkv_out.extend([f"present_key_self_{i}", f"present_value_self_{i}"])

# dynamic axis
dynamic_axes = {}
dynamic_axes["last_hidden_state"] = {1: "decoder_sequence_length"}

for p in pkv_in:
    dynamic_axes[p] = {2: "encoder_sequence_length" if "cross" in p else "past_decoder_sequence_length"}
for p in pkv_out:
    dynamic_axes[p] = {2: "past_decoder_sequence_length+1"}

input_names = ["input_ids", "encoder_hidden_states", *pkv_in]
output_names = ["last_hidden_state", *pkv_out]

dummy_hidden_states = torch.empty((1, len_hidden_states, config.d_model), dtype=torch.float32).uniform_(0, 1)  # hidden size must match the model's d_model
decoder_inputs = (
    dummy_input_ids_decoder,
    dummy_hidden_states,
    decoder_outputs[1:],
)

decoder_with_past_outputs = decoder_with_past_torch(*decoder_inputs)


torch.onnx.export(
    decoder_with_past_torch,
    decoder_inputs,
    "bart_decoder_with_past.onnx",
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=20,
)

onnx_model = ir.load("bart_decoder_with_past.onnx")
onnx_model, stats = optimize_for_ort(onnx_model, debug=True)
print(stats)

Outputs:

...
Graph matching failed: Attribute key_format mismatch: expected BHSd, got BSHd.
Failure at or around nodes/values:
Node: 'node_SDPA_119'
...
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'cos_sin_cache': 0, 'partial_rotary_embedding': 0, 'sdpa': 4, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}
...
------------------------------
Rule: SDPA
--------------------------------------------------------------------------------
Status: CONDITION_FAILED
Graph matching failed due to failing check condition : query_scale is not a scalar.
...
{'erf_gelu': 0, 'rms_normalization': 0, 'skip_layer_normalization': 0, 'skip_rms_normalization': 0, 'rotary_embedding': 0, 'cos_sin_cache': 0, 'partial_rotary_embedding': 0, 'sdpa': 0, 'gqa': 0, 'packed_qkv_for_gqa': 0, 'mha1': 0, 'mha2': 0, 'mha_bias': 0, 'attention': 0, 'gelu': 0, 'bias_gelu': 2}

Prob. 1: For the decoder (initial decoding step), attention is only fused into the temporary SDPA node. The subsequent fusion into MultiHeadAttention (or Attention) fails because of the mismatch in the key_format check. The resulting graph cannot be run with onnxruntime because of the leftover temporary SDPA node.

expected_key_format = "BHSd" if key_transposed else "BSHd"
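
To confirm that the temporary SDPA nodes survive the pass, the leftover nodes and their key_format attribute can be listed from the optimized model. Below is a minimal sketch using the plain onnx API; it assumes the fused model was saved as bart_decoder_init_onnxscript_fused.onnx (the file name from the screenshot below) and that the temporary op type is SDPA, as the node name node_SDPA_119 in the log suggests.

import onnx
from onnx import helper

# Assumed file name: the model saved after optimize_for_ort() on the init decoder.
fused = onnx.load("bart_decoder_init_onnxscript_fused.onnx")

for node in fused.graph.node:
    # The temporary SDPA nodes are not a registered onnxruntime op,
    # which is why the optimized graph cannot be executed.
    if node.op_type == "SDPA":
        attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
        print(node.name, node.domain, attrs.get("key_format"))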

[Screenshot from Netron: bart_decoder_init.onnx]

[Screenshot from Netron: bart_decoder_init_onnxscript_fused.onnx]

To my current understanding, the error is caused by incorrect handling of the key transpose: the two key formats differ only in whether the head and sequence axes are swapped.
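
For reference, a minimal illustration of the two layouts from the error message (sizes are arbitrary):

import torch

B, H, S, d = 1, 4, 6, 8  # batch, num_heads, sequence length, head dim (arbitrary)

key_bshd = torch.randn(B, S, H, d)     # "BSHd": the layout reported as "got"
key_bhsd = key_bshd.transpose(1, 2)    # "BHSd": the layout reported as "expected"
print(key_bshd.shape, key_bhsd.shape)  # torch.Size([1, 6, 4, 8]) torch.Size([1, 4, 6, 8])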

Prob. 2: For the decoder with past, query_scale cannot be retrieved from the graph and is None in

value = _ir_utils.get_singleton_value(scale)
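
Presumably get_singleton_value only succeeds when scale is a constant with exactly one element; if the exporter materializes the 1/sqrt(head_dim) factor as an expanded tensor or as the output of a small subgraph, the lookup yields None and the rule bails out with "query_scale is not a scalar". One possible relaxation, sketched only as an idea (the helper name and the way the constant is obtained are my assumptions, not existing onnxscript API):

import numpy as np

def uniform_scalar(const_array):
    """Return the scalar if const_array holds a single repeated value, else None.

    const_array stands for the numpy array behind the matched `scale` value;
    retrieving it from the IR (e.g. via constant folding) is omitted here.
    """
    if const_array is None:
        return None
    arr = np.asarray(const_array)
    if arr.size == 0:
        return None
    first = arr.reshape(-1)[0]
    return float(first) if np.all(arr == first) else None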

@justinchuby / @gramalingam I could investigate the issues further and potentially provide a fix. So far, I have been able to fuse all attention mechanisms into MultiHeadAttention nodes using custom rewrite rules.
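
For context, those custom rules follow the usual onnxscript.rewriter pattern/replacement recipe. The skeleton below is only a simplified sketch, not the actual rules: the matched subgraph, the hard-coded num_heads, and the missing 3D/4D layout rewiring are all placeholders.

from onnxscript import rewriter
from onnxscript.rewriter import pattern


def sdpa_subgraph(op, query, key_transposed, value, scale):
    # query, key_transposed, value and scale are pattern variables bound during matching.
    scores = op.Div(op.MatMul(query, key_transposed), scale)
    weights = op.Softmax(scores, axis=-1)
    return op.MatMul(weights, value)


def to_mha(op, query, key_transposed, value, scale):
    # Placeholder replacement: com.microsoft MultiHeadAttention expects
    # (batch, sequence, hidden) inputs, so the Transpose/Reshape rewiring and
    # the past/present key-value handling are omitted; num_heads is model specific.
    return op.MultiHeadAttention(
        query,
        key_transposed,
        value,
        num_heads=4,
        _domain="com.microsoft",
    )


mha_rule = pattern.RewriteRule(sdpa_subgraph, to_mha)
# onnx_model = rewriter.rewrite(onnx_model, pattern_rewrite_rules=[mha_rule])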
