From 604527b5b8c998a276a0ee628b59b1264d686df7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Nov 2025 17:39:06 +0000 Subject: [PATCH 01/12] Fix minor bugs in onnx_plug --- _unittests/ut_tasks/try_export.py | 19 +- onnx_diagnostic/export/control_flow.py | 6 +- onnx_diagnostic/export/onnx_plug.py | 35 +++- .../patches/_patch_transformers_qwen2_5.py | 195 +++++++++++++++--- 4 files changed, 201 insertions(+), 54 deletions(-) diff --git a/_unittests/ut_tasks/try_export.py b/_unittests/ut_tasks/try_export.py index c86efacb..d759ac15 100644 --- a/_unittests/ut_tasks/try_export.py +++ b/_unittests/ut_tasks/try_export.py @@ -45,6 +45,9 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self): exporter = os.environ.get("EXPORTER", "custom") from transformers import AutoModel, AutoProcessor + from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( + PLUGS, + ) model_id = "Qwen/Qwen2.5-VL-7B-Instruct" # model_id = "Qwen/Qwen2.5-VL-3B-Instruct" @@ -82,11 +85,17 @@ def _config_reduction(config, task): processor = AutoProcessor.from_pretrained(model_id, use_fast=True) print(f"-- processor={type(processor)}") + big_inputs = dict( + hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device), + grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device), + ) + print("-- save inputs") inputs = dict( hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device), grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device), ) print("-- save inputs") + torch.save(big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt")) torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt")) print(f"-- inputs: {self.string_type(inputs, with_shape=True)}") @@ -115,15 +124,6 @@ def _config_reduction(config, task): verbose=1, stop_if_static=2, ): - if exporter == "onnx-dynamo": - # The exported program in ONNXProgram cannot be restored. 
-                ep2 = torch.export.export(
-                    model.visual,
-                    (),
-                    kwargs=export_inputs,
-                    dynamic_shapes=self.use_dyn_not_str(dynamic_shapes),
-                )
-                torch.export.save(ep2, f"{fileep}.backup.pt2")
             to_onnx(
                 model.visual,
                 kwargs=export_inputs,
@@ -134,6 +134,7 @@ def _config_reduction(config, task):
                 save_ep=(fileep, 2**35),
                 target_opset=22,
                 optimize=True,
+                onnx_plugs=PLUGS,
             )
 
     pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
diff --git a/onnx_diagnostic/export/control_flow.py b/onnx_diagnostic/export/control_flow.py
index 7e2b7298..21814084 100644
--- a/onnx_diagnostic/export/control_flow.py
+++ b/onnx_diagnostic/export/control_flow.py
@@ -134,7 +134,11 @@ def make_custom_loop_for(
     assert body_outputs is not None, "body_outputs cannot be None"
     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
-    full_name = body_fn.__qualname__.replace("<locals>", "L").replace(".", "_")
+    full_name = (
+        body_fn.__qualname__.replace("<locals>", "L")
+        .replace("<lambda>", "l")
+        .replace(".", "_")
+    )
     name = f"loop_for_{full_name}_{srank}_{sred}"
     if name in _REGISTERED_SCHEMA:
         return name, _REGISTERED_SCHEMA[name][0]
diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py
index b2d72252..60a34e0a 100644
--- a/onnx_diagnostic/export/onnx_plug.py
+++ b/onnx_diagnostic/export/onnx_plug.py
@@ -1,6 +1,6 @@
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import torch
 from ..helpers import max_diff
@@ -49,6 +49,7 @@ class EagerDirectReplacementWithOnnx:
     :param n_outputs: same for the number of outputs,
         only tensors must be counted
     :param name: the name of the custom op, the function name if not specified
+    :param kwargs: constants
 
     Here is an example:
 
@@ -141,6 +142,7 @@ def __init__(
         n_inputs: Optional[int] = None,
         n_outputs: Optional[int] = None,
         name: Optional[str] = None,
+        kwargs: Optional[Dict[str, Union[int, float]]] = None,
     ):
         assert isinstance(
             function_proto, onnx.FunctionProto
@@ -152,7 +154,14 @@ def __init__(
         self.function_proto = function_proto
         self.n_inputs = n_inputs
         self.n_outputs = n_outputs
-        self.name = name or eager_fn.__name__
+        self.name = name or eager_fn.__qualname__.replace("", "L").replace(
+            "", "l"
+        ).replace(".", "_")
+        self.kwargs = kwargs
+        assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), (
+            f"Only int or floats are allowed for kwargs={kwargs}, one of them "
+            f"does not respect that constraint."
+ ) sig = inspect.signature(self.eager_fn) params = list(sig.parameters) assert ( @@ -190,7 +199,7 @@ def torch_op(self) -> Callable: def __call__(self, *args): """Calls eager_fn or shape_fn if the model is being exported.""" if is_exporting(): - return self.shape_fn(*args) + return self.torch_op(*args) return self.eager_fn(*args) def _registers(self): @@ -266,10 +275,16 @@ def converter( outputs: List[str], *args, ) -> Any: - if not g.has_local_function(self.name, self.domain): + if not g.has_local_function( + self.function_proto.name, domain=self.function_proto.domain + ): g.add_function(self.function_proto) res = g.make_node( - self.name, args, outputs, domain=self.domain, name=self.target_name + self.function_proto.name, + args, + outputs, + domain=self.function_proto.domain, + name=self.target_name, ) if not sts: new_shapes = self.shape_fn(*args) @@ -290,8 +305,8 @@ def onnx_dynamo_converter(self) -> Callable: """ import onnxscript - onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1) - schema = onnx_plug_op[self.name] + onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1) + schema = onnx_plug_op[self.function_proto.name] if schema is None: all_types = [ "tensor(float)", @@ -307,8 +322,8 @@ def onnx_dynamo_converter(self) -> Callable: for i in range(self.n_outputs): type_constraints.append((f"U{i}", all_types, "")) schema = onnx.defs.OpSchema( - self.name, - self.domain, + self.function_proto.name, + self.function_proto.domain, 1, inputs=[ onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}") @@ -321,7 +336,7 @@ def onnx_dynamo_converter(self) -> Callable: type_constraints=type_constraints, ) onnx.defs.register_schema(schema) - op = onnxscript.values.Op(onnx_plug_op, self.name, schema) + op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema) def converter(*cargs): return op(*cargs, n_outputs=self.n_outputs) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index c9758225..3bf2d5ea 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -1,7 +1,9 @@ import os from typing import Callable, Optional +import onnx import torch import torch.nn.functional as F +from ...export.onnx_plug import EagerDirectReplacementWithOnnx from .patch_helper import _is_torchdynamo_exporting from ._patch_transformers_attention import patched_sdpa_attention_forward @@ -13,9 +15,145 @@ except ImportError: patch_qwen2_5 = False -strategy_for_attention_in_qwen_2_5 = os.environ.get("QWEN25ATTENTION", "BIGMASK") +PLUGS = [] +strategy_for_attention_in_qwen_2_5 = os.environ.get("QWEN25ATTENTION", "PACKED") if patch_qwen2_5: + import onnxscript + + onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1) + op = onnxscript.opset22 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + @onnxscript.script(opset=onnx_plugs_op) + def LoopMHAAttention( + query_states, key_states, value_states, cu_seqlens, scale: float, num_heads: int + ): + to_3d_shape = op.Constant(value_ints=[0, 0, -1]) + query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3]) + output_shape = op.Shape(query_transposed) + query_3d = op.Reshape(query_transposed, to_3d_shape) + value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape) + key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape) + 
num_patches = op.Size(cu_seqlens) - 1 + seq_axis = op.Constant(value_ints=[1]) + seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32) + attn_output = op.Slice(value_3d, [0], [0], seq_axis) + for i in range(num_patches): + i_1d = op.Reshape(i, [1]) + i_plus_1_1d = i_1d + 1 + start = op.Gather(cu_seqlens, i_1d, axis=0) + end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0) + query_i = op.Slice(query_3d, start, end, seq_axis_int32) + key_i = op.Slice(key_3d, start, end, seq_axis_int32) + value_i = op.Slice(value_3d, start, end, seq_axis_int32) + mha_output = msft_op.MultiHeadAttention( + query_i, + key_i, + value_i, + num_heads=num_heads, + scale=scale, + ) + attn_output = op.Concat(attn_output, mha_output, axis=1) + attn_output_4d = op.Reshape(attn_output, output_shape) + return attn_output_4d + + @onnxscript.script(opset=onnx_plugs_op) + def PackedAttention( + query, + key, + value, + cu_seqlens, + scale: float = 0.11180339887498948, + num_heads: int = 16, + ): + num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1 + starts = op.Slice(cu_seqlens, [0], [-1], [0]) + ends = op.Slice(cu_seqlens, [1], [9223372036854775807], [0]) + lengths = ends - starts + max_length = op.ReduceMax(lengths, [0], keepdims=0) # max_seqlen + rows = op.Range(0, num_patches, 1) + rows_2d = op.Unsqueeze(rows, [1]) + cols = op.Range(0, max_length, 1) + cols_2d = op.Unsqueeze(cols, [0]) + + position_matrix = op.Cast(rows_2d, to=onnx.TensorProto.INT32) * op.Cast( + max_length, to=onnx.TensorProto.INT32 + ) + op.Cast(cols_2d, to=onnx.TensorProto.INT32) + position_matrix_shape = op.Shape(position_matrix) + token_mask = cols_2d < op.Unsqueeze(lengths, [1]) + token_mask_1d = op.Reshape(token_mask, [-1]) + padded_mask_1d = op.Not(token_mask_1d) + valid_token_positions = op.Compress(position_matrix, token_mask) + padded_token_positions = op.Compress(position_matrix, padded_mask_1d) + token_offset_1d = op.Concat(valid_token_positions, padded_token_positions, axis=0) + token_offset = op.Reshape(token_offset_1d, position_matrix_shape) + + query_3d = op.Transpose(op.Squeeze(query, [0]), perm=[1, 0, 2]) + shape_3d = op.Shape(query_3d) + query_2d = op.Reshape(query_3d, [0, -1]) + key_2d = op.Reshape(op.Transpose(op.Squeeze(key, [0]), perm=[1, 0, 2]), [0, -1]) + value_2d = op.Reshape(op.Transpose(op.Squeeze(value, [0]), perm=[1, 0, 2]), [0, -1]) + + packed_attn_output_2d = msft_op.PackedMultiHeadAttention( + query_2d, + key_2d, + value_2d, + None, + op.Cast(token_offset, to=onnx.TensorProto.INT32), + op.Cast(cu_seqlens, to=onnx.TensorProto.INT32), + scale=scale, + num_heads=num_heads, + ) + packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d) + return op.Unsqueeze(packed_attn_output_3d, [0]) + + def qwen_sdpa_attention( + query_states: torch.Tensor, # F10s1x16xs47x80 + key_states: torch.Tensor, # F10s1x16xs47x80 + value_states: torch.Tensor, # F10s1x16xs47x80 + cu_seqlens: torch.Tensor, # F7su19 + scaling: float = 0, + ) -> torch.Tensor: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) + for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + patched_sdpa_attention_forward( + None, + q, + k, + v, + attention_mask=None, + scaling=scaling, + dropout=0.0, + is_causal=False, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + return attn_output + + # not ideal + qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx( + lambda qs, ks, vs, cuseq: qwen_sdpa_attention( + qs, 
ks, vs, cuseq, scaling=0.11180339887498948 + ), + lambda qs, *args: torch.empty( + (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]), + dtype=qs.dtype, + device=qs.device, + ), + PackedAttention.to_function_proto(), + n_inputs=4, + n_outputs=1, + kwargs=dict(scaling=0.11180339887498948, num_heads=16), + name="qwen_sdpa_attention", + ) + PLUGS.append(qwen_sdpa_attention_versatile) class patched_Qwen2_5_VLForConditionalGeneration: _PATCHES_ = ["prepare_inputs_for_generation"] @@ -39,7 +177,7 @@ def prepare_inputs_for_generation( second_per_grid_ts=None, **kwargs, ): - # Overwritten -- in specific circumstances we don't want to f + # Overwritten -- in specific circumstances we don't want to # forward image inputs to the model from transformers.generation import GenerationMixin @@ -346,7 +484,23 @@ def forward( self.config._attn_implementation ] - if _is_torchdynamo_exporting(): + if ( + attention_interface + is transformers.integrations.sdpa_attention.sdpa_attention_forward + or attention_interface is patched_sdpa_attention_forward + ) and strategy_for_attention_in_qwen_2_5 == "PACKED": + torch._check( + qwen_sdpa_attention_versatile.kwargs["scaling"] == self.scaling, + lambda: f"Not implemented for scaling={self.scaling}", + ) + torch._check( + qwen_sdpa_attention_versatile.kwargs["num_heads"] == self.num_heads, + lambda: f"Not implemented for num_heads={self.num_heads}", + ) + attn_output = qwen_sdpa_attention_versatile( + query_states, key_states, value_states, cu_seqlens + ) + elif _is_torchdynamo_exporting(): if self.config._attn_implementation == "flash_attention_2": max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attn_output = torch.onnx.ops.symbolic( @@ -373,8 +527,8 @@ def forward( elif ( attention_interface is transformers.integrations.sdpa_attention.sdpa_attention_forward - and strategy_for_attention_in_qwen_2_5 == "LOOPMHA" - ): + or attention_interface is patched_sdpa_attention_forward + ) and strategy_for_attention_in_qwen_2_5 == "LOOPMHA": def _iteration(start_end, query_states, key_states, value_states): return patched_Qwen2_5_VLVisionAttentionOneIteration.forward( @@ -409,8 +563,8 @@ def _iteration(start_end, query_states, key_states, value_states): elif ( attention_interface is transformers.integrations.sdpa_attention.sdpa_attention_forward - and strategy_for_attention_in_qwen_2_5 == "BIGMASK" - ): + or attention_interface is patched_sdpa_attention_forward + ) and strategy_for_attention_in_qwen_2_5 == "BIGMASK": # make square mask indices = torch.arange( cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device @@ -437,33 +591,6 @@ def _iteration(start_end, query_states, key_states, value_states): is_causal=False, **kwargs, ) - elif ( - attention_interface - is transformers.integrations.sdpa_attention.sdpa_attention_forward - and strategy_for_attention_in_qwen_2_5 == "PACKED" - ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - attn_output = torch.onnx.ops.symbolic( - "custom::qwen25_packed_attention", - ( - query_states, - key_states, - value_states, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - torch.tensor(self.scaling, dtype=torch.float32), - ), - dtype=query_states.dtype, - shape=( - query_states.shape[0], # batch_size - query_states.shape[2], # sequence_length (total patches) - query_states.shape[1], # num_heads - query_states.shape[3], # head_size - ), - version=1, - ) else: raise NotImplementedError( f"Not export strategy for strategy_for_attention_in_qwen_2_5=" From b4be4df76fb094ccabd728527600391cfd02fa64 Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Nov 2025 18:59:29 +0100 Subject: [PATCH 02/12] fix --- onnx_diagnostic/export/onnx_plug.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 60a34e0a..17212943 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -154,7 +154,7 @@ def __init__( self.function_proto = function_proto self.n_inputs = n_inputs self.n_outputs = n_outputs - self.name = name or eager_fn.__qualname__.replace("", "L").replace( + self.name = name or eager_fn.__qualname__.replace("", "L").replace( "", "l" ).replace(".", "_") self.kwargs = kwargs From 621e81e1c0df7d644dd0896ec6e7f63e283c482d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Nov 2025 23:01:25 +0100 Subject: [PATCH 03/12] fix --- _unittests/ut_export/test_onnx_plug.py | 6 ++++-- onnx_diagnostic/export/onnx_plug.py | 23 ++++++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/_unittests/ut_export/test_onnx_plug.py b/_unittests/ut_export/test_onnx_plug.py index 97f101f7..da2e9741 100644 --- a/_unittests/ut_export/test_onnx_plug.py +++ b/_unittests/ut_export/test_onnx_plug.py @@ -1,7 +1,7 @@ import unittest import onnx.helper as oh import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch +from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch, hide_stdout, ignore_warnings from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx from onnx_diagnostic.export.api import to_onnx @@ -36,6 +36,8 @@ def make_function_proto(): self.assertEqual(len(res.diffs), 1) self.assertEqual(res.diffs[0]["abs"], 0) + @hide_stdout() + @ignore_warnings(FutureWarning) def test_onnx_plug_export(self): def _test_customsub(x, y): return x - y @@ -61,7 +63,7 @@ def forward(self, x): replacements = [ EagerDirectReplacementWithOnnx( - _test_customsub, _test_customsub_shape, make_function_proto(), 2, 1 + _test_customsub, _test_customsub_shape, make_function_proto(), 2, 1, verbose=1 ) ] diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 17212943..8abe3911 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -50,6 +50,7 @@ class EagerDirectReplacementWithOnnx: only tensors must be counted :param name: the name of the custom op, the function name if not specified :param kwargs: constants + :param verbose: verbose level Here is an example: @@ -143,6 +144,7 @@ def __init__( n_outputs: Optional[int] = None, name: Optional[str] = None, kwargs: Optional[Dict[str, Union[int, float]]] = None, + verbose: int = 0, ): assert isinstance( function_proto, onnx.FunctionProto @@ -154,9 +156,13 @@ def __init__( self.function_proto = function_proto self.n_inputs = n_inputs self.n_outputs = n_outputs - self.name = name or eager_fn.__qualname__.replace("", "L").replace( - "", "l" - ).replace(".", "_") + self.name = name or ( + eager_fn.__name__ + if "<" not in eager_fn.__name__ + else eager_fn.__qualname__.replace("", "L") + .replace("", "l") + .replace(".", "_") + ) self.kwargs = kwargs assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), ( f"Only int or floats are allowed for kwargs={kwargs}, one of them " @@ -179,7 +185,8 @@ def __init__( function_proto.domain == self.domain ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}" self.arg_names = params - self.custom_op = 
self._registers() + self.verbose = verbose + self.custom_op = self._register() @property def domain(self) -> str: @@ -202,12 +209,18 @@ def __call__(self, *args): return self.torch_op(*args) return self.eager_fn(*args) - def _registers(self): + def _register(self): """Registers the custom op.""" inputs = ", ".join([f"Tensor {p}" for p in self.arg_names]) schema = f"({inputs}) -> Tensor" if self.n_outputs > 1: schema += "[]" + if self.verbose: + print( + f"[EagerDirectReplacementWithOnnx._register] " + f"'torch.ops.{self.domain}.{self.name}" + ) + print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}") custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn) custom_def.register_kernel(None)(self.eager_fn) custom_def._abstract_fn = self.shape_fn From 40b7b23286613878f03755d8e57edbc4416fa3c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Nov 2025 11:08:05 +0100 Subject: [PATCH 04/12] fix --- _unittests/ut_torch_export_patches/test_patch_transformers.py | 3 +++ .../torch_export_patches/patches/patch_transformers.py | 1 + 2 files changed, 4 insertions(+) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 4f29873b..0004d2a0 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -384,6 +384,7 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self): ) from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( patched_Qwen2_5_VLVisionAttention, + PLUGS_Qwen25, ) config = get_cached_configuration("Qwen/Qwen2.5-VL-7B-Instruct") @@ -456,6 +457,8 @@ def forward( dynamic_shapes=ds, exporter=exporter, filename=filename, + onnx_plugs=PLUGS_Qwen25, + target_opset=22, ) # exporter_kwargs={"report":True} if exporter != "custom" else {} self.assert_onnx_disc( diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 86dbba35..b20e8bea 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -76,6 +76,7 @@ patched_Qwen2_5_VisionTransformerPretrainedModel, patched_Qwen2_5_VLVisionAttentionOneIteration, patched_Qwen2_5_VLVisionAttention, + PLUGS as PLUGS_Qwen25, ) from ._patch_transformers_qwen3 import patch_qwen3 From 2d27c1a090caa7af3f854617817e736627984e23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Nov 2025 13:45:50 +0100 Subject: [PATCH 05/12] fix unittest --- .../test_patch_transformers.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 0004d2a0..e7263839 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -461,14 +461,16 @@ def forward( target_opset=22, ) # exporter_kwargs={"report":True} if exporter != "custom" else {} - self.assert_onnx_disc( - f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}", - onnx.load(filename), - instance, - inputs, - atol=1e-3, - rtol=1, - ) + if torch.cuda.is_available(): + self.assert_onnx_disc( + f"test_patched_qwen2_5_vl_vision_attention_forward-{exporter}", + onnx.load(filename), + instance, + inputs, 
+ atol=1e-3, + rtol=1, + providers=["CUDAExecutionProvider"], + ) self.clean_dump() @requires_transformers("4.99") From b5aa71bbf480913ea1759950420b7bbd3fe3e4de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Nov 2025 15:08:53 +0100 Subject: [PATCH 06/12] fix --- _unittests/ut_export/test_onnx_plug.py | 95 ++++++++++++++++++- onnx_diagnostic/export/onnx_plug.py | 28 ++++-- .../patches/_patch_transformers_qwen2_5.py | 34 +++---- 3 files changed, 131 insertions(+), 26 deletions(-) diff --git a/_unittests/ut_export/test_onnx_plug.py b/_unittests/ut_export/test_onnx_plug.py index da2e9741..bd405f89 100644 --- a/_unittests/ut_export/test_onnx_plug.py +++ b/_unittests/ut_export/test_onnx_plug.py @@ -1,4 +1,5 @@ import unittest +import onnx import onnx.helper as oh import torch from onnx_diagnostic.ext_test_case import ExtTestCase, has_torch, hide_stdout, ignore_warnings @@ -38,7 +39,7 @@ def make_function_proto(): @hide_stdout() @ignore_warnings(FutureWarning) - def test_onnx_plug_export(self): + def test_onnx_plug_export_nokwargs(self): def _test_customsub(x, y): return x - y @@ -85,7 +86,95 @@ def forward(self, x): onnx_plugs=replacements, target_opset=22, ) - self.assert_onnx_disc("test_onnx_plug_export_custom", onx.model_proto, model, (x,)) + self.assert_onnx_disc( + "test_onnx_plug_export_nokwargs_custom", onx.model_proto, model, (x,) + ) + + if not has_torch("2.9"): + raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8") + with self.subTest(exporter="onnx-dynamo"): + onx = to_onnx( + model, + (x,), + dynamic_shapes=ds, + exporter="onnx-dynamo", + onnx_plugs=replacements, + target_opset=22, + ) + self.assert_onnx_disc( + "test_onnx_plug_export_nokwargs_onnx_dynamo", onx.model_proto, model, (x,) + ) + + @unittest.skip("not ready yet") + @hide_stdout() + @ignore_warnings(FutureWarning) + def test_onnx_plug_export_kwargs(self): + def _test_customdiv(x, y, epsilon: float = 1e-5): + return x / (y + epsilon) + + def _test_customdiv_shape(x, y, *args, **kwargs): + return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype) + + def make_function_proto(): + f = oh.make_function( + "onnx_plug", + "_test_customdiv", + ["x", "y"], + ["z"], + [ + oh.make_node("Constant", [], ["eps"]), + oh.make_node("Add", ["y", "eps"], ["yeps"]), + oh.make_node("Div", ["x", "yeps"], ["z"]), + ], + opset_imports=[oh.make_opsetid("", 22)], + attributes=["epsilon"], + ) + att = onnx.AttributeProto() + att.name = "value_float" + att.ref_attr_name = "epsilon" + att.type = onnx.AttributeProto.FLOAT + f.node[0].attribute.append(att) + return f + + class Model(torch.nn.Module): + def forward(self, x): + y = x.sum(axis=1, keepdim=True) + d = torch.ops.onnx_plug._test_customdiv(x, y, epsilon=3.5) + return torch.abs(d) + + replacements = [ + EagerDirectReplacementWithOnnx( + _test_customdiv, + _test_customdiv_shape, + make_function_proto(), + 2, + 1, + kwargs=dict(epsilon=1e-5), + verbose=1, + ) + ] + + x = torch.randn((3, 4), dtype=torch.float32) + model = Model() + expected = model(x) + ds = ({0: "d1", 1: "d2"},) + ep = torch.export.export(model, (x,), dynamic_shapes=self.use_dyn_not_str(ds)) + self.assertIn("torch.ops.onnx_plug._test_customdiv.default", str(ep)) + got = ep.module()(x) + self.assertEqualArray(expected, got) + + with self.subTest(exporter="custom"): + onx = to_onnx( + model, + (x,), + dynamic_shapes=ds, + exporter="custom", + onnx_plugs=replacements, + target_opset=22, + ) + self.assert_onnx_disc( + "test_onnx_plug_export_kwargs_custom", 
onx.model_proto, model, (x,) + ) if not has_torch("2.9"): raise unittest.SkipTest("onnx-dynamo + custom op not fully working on 2.8") @@ -99,7 +188,7 @@ def forward(self, x): target_opset=22, ) self.assert_onnx_disc( - "test_onnx_plug_export_onnx_dynamo", onx.model_proto, model, (x,) + "test_onnx_plug_export_kwargs_onnx_dynamo", onx.model_proto, model, (x,) ) diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 8abe3911..3cbf7e0c 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -49,7 +49,7 @@ class EagerDirectReplacementWithOnnx: :param n_outputs: same for the number of outputs, only tensors must be counted :param name: the name of the custom op, the function name if not specified - :param kwargs: constants + :param kwargs: constants parameters with their default values :param verbose: verbose level Here is an example: @@ -163,8 +163,8 @@ def __init__( .replace("", "l") .replace(".", "_") ) - self.kwargs = kwargs - assert kwargs is None or all(isinstance(v, (int, float)) for v in kwargs.values()), ( + self.kwargs = kwargs or {} + assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), ( f"Only int or floats are allowed for kwargs={kwargs}, one of them " f"does not respect that constraint." ) @@ -184,7 +184,8 @@ def __init__( assert ( function_proto.domain == self.domain ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}" - self.arg_names = params + self.args_name = [p for p in params if p not in self.kwargs] + self.kwargs_name = [p for p in params if p in self.kwargs] self.verbose = verbose self.custom_op = self._register() @@ -211,7 +212,19 @@ def __call__(self, *args): def _register(self): """Registers the custom op.""" - inputs = ", ".join([f"Tensor {p}" for p in self.arg_names]) + input_args = [f"Tensor {p}" for p in self.args_name] + for p in self.kwargs_name: + val = self.kwargs[p] + if isinstance(val, int): + input_args.append(f"int {p}={val}") + elif isinstance(val, float): + input_args.append(f"float {p}={val}") + else: + raise NotImplementedError( + f"kwargs {p!r} has a default value of unsupported type {type(val)}" + ) + + inputs = ", ".join(input_args) schema = f"({inputs}) -> Tensor" if self.n_outputs > 1: schema += "[]" @@ -292,12 +305,15 @@ def converter( self.function_proto.name, domain=self.function_proto.domain ): g.add_function(self.function_proto) + ags = args[: len(self.args_name)] + kws = dict(zip(self.kwargs_name, args[len(self.args_name) :])) res = g.make_node( self.function_proto.name, - args, + ags, outputs, domain=self.function_proto.domain, name=self.target_name, + **kws, ) if not sts: new_shapes = self.shape_fn(*args) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index 3bf2d5ea..e53570c0 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -27,7 +27,12 @@ @onnxscript.script(opset=onnx_plugs_op) def LoopMHAAttention( - query_states, key_states, value_states, cu_seqlens, scale: float, num_heads: int + query_states, + key_states, + value_states, + cu_seqlens, + scaling: float = 0.11180339887498948, + num_heads: int = 16, ): to_3d_shape = op.Constant(value_ints=[0, 0, -1]) query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3]) @@ -52,7 +57,7 @@ def LoopMHAAttention( key_i, value_i, 
num_heads=num_heads, - scale=scale, + scale=scaling, ) attn_output = op.Concat(attn_output, mha_output, axis=1) attn_output_4d = op.Reshape(attn_output, output_shape) @@ -64,7 +69,7 @@ def PackedAttention( key, value, cu_seqlens, - scale: float = 0.11180339887498948, + scaling: float = 0.11180339887498948, num_heads: int = 16, ): num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1 @@ -102,7 +107,7 @@ def PackedAttention( None, op.Cast(token_offset, to=onnx.TensorProto.INT32), op.Cast(cu_seqlens, to=onnx.TensorProto.INT32), - scale=scale, + scale=scaling, num_heads=num_heads, ) packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d) @@ -139,10 +144,8 @@ def qwen_sdpa_attention( # not ideal qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx( - lambda qs, ks, vs, cuseq: qwen_sdpa_attention( - qs, ks, vs, cuseq, scaling=0.11180339887498948 - ), - lambda qs, *args: torch.empty( + qwen_sdpa_attention, + lambda qs, *args, **kwargs: torch.empty( (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]), dtype=qs.dtype, device=qs.device, @@ -489,16 +492,13 @@ def forward( is transformers.integrations.sdpa_attention.sdpa_attention_forward or attention_interface is patched_sdpa_attention_forward ) and strategy_for_attention_in_qwen_2_5 == "PACKED": - torch._check( - qwen_sdpa_attention_versatile.kwargs["scaling"] == self.scaling, - lambda: f"Not implemented for scaling={self.scaling}", - ) - torch._check( - qwen_sdpa_attention_versatile.kwargs["num_heads"] == self.num_heads, - lambda: f"Not implemented for num_heads={self.num_heads}", - ) attn_output = qwen_sdpa_attention_versatile( - query_states, key_states, value_states, cu_seqlens + query_states, + key_states, + value_states, + cu_seqlens, + scaling=self.scaling, + num_heads=self.num_heads, ) elif _is_torchdynamo_exporting(): if self.config._attn_implementation == "flash_attention_2": From 7747de33b4d19abf4966ef190cbf646cf4aba109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Nov 2025 15:20:19 +0100 Subject: [PATCH 07/12] fix --- onnx_diagnostic/export/onnx_plug.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 3cbf7e0c..4fcb1586 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -204,11 +204,11 @@ def torch_op(self) -> Callable: "Returns ``torch.ops.onny_plug." return getattr(getattr(torch.ops, self.domain), self.name).default - def __call__(self, *args): + def __call__(self, *args, **kwargs): """Calls eager_fn or shape_fn if the model is being exported.""" if is_exporting(): return self.torch_op(*args) - return self.eager_fn(*args) + return self.eager_fn(*args, **kwargs) def _register(self): """Registers the custom op.""" From 4f6603f48d15ae4eb93a95ca92392e258dd74f70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Nov 2025 15:00:09 +0000 Subject: [PATCH 08/12] update script --- _doc/technical/plot_gemm_or_matmul_add.py | 74 ++++++++++++++++------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/_doc/technical/plot_gemm_or_matmul_add.py b/_doc/technical/plot_gemm_or_matmul_add.py index ea5d7483..100164f1 100644 --- a/_doc/technical/plot_gemm_or_matmul_add.py +++ b/_doc/technical/plot_gemm_or_matmul_add.py @@ -145,7 +145,7 @@ def matrix_diff(tensors): # a lot higher. 
B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384 -labels = ["linear", *[o.name for o in model.graph.output], "a @ x + b"] +labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"] all_results = {} for itype, dtype, device in [ @@ -187,28 +187,58 @@ def matrix_diff(tensors): # bias value vs discrepancies # =========================== # -# Let's compare GemmOnly (so bias is included) and Gemm+Add. - -i, j = 1, -1 -labs = labels[i], labels[j] - -fig, ax = plt.subplots(len(all_results), 2, figsize=(8, 2.5 * len(results))) -for pos, ((device, dtype), results) in enumerate(all_results.items()): - m1, m2 = results[i], results[j] - diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0] - print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}") - expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2 - ax[pos, 0].plot(B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), ".") - ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}") - - corr = matrix_diff(results) - ax[pos, 1].imshow(corr, cmap="Blues", vmin=0, vmax=corr.max()) - # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}') - ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45) - ax[pos, 1].set_yticks(range(len(labels)), labels) - ax[pos, 1].set_title(f"max={diff.max()}") +# Let's compare torch linear with GemmOnly. + + +def make_figure_axis(all_results, i, j): + labs = labels[i], labels[j] + fig, ax = plt.subplots(len(all_results), 2, figsize=(12, 4 * len(all_results))) + for pos, ((device, dtype), results) in enumerate(all_results.items()): + m1, m2 = results[i], results[j] + diff = torch.abs(m1.to(torch.float32) - m2.to(torch.float32)).max(dim=0)[0] + print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}") + expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2 + ax[pos, 0].plot( + B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), "." + ) + ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10) + + corr = matrix_diff(results) + ax[pos, 1].imshow(corr, cmap="Wistia", vmin=0, vmax=corr.max()) + # ax[pos,1].colorbar(label=f'Discrepancies {device}/{dtype}') + ax[pos, 1].set_xticks(range(len(labels)), labels, rotation=45, ha="right", fontsize=10) + ax[pos, 1].set_yticks(range(len(labels)), labels, fontsize=10) + ax[pos, 1].set_title(f"max={diff.max():1.2g}", fontsize=10) + for _i in range(corr.shape[0]): + for _j in range(corr.shape[1]): + ax[pos, 1].text( + _j, + _i, + f"{corr[_i, _j]:1.1g}", + ha="center", + va="center", + color="black", + fontsize=8, + ) + fig.suptitle( + f"Left column: discrepancies {labs[0]} VS {labs[1]}\n" + f"Right column: max absolute error, accross all configuration\n" + f"white is good, orange is not" + ) + return fig, ax + + +fig, ax = make_figure_axis(all_results, 0, 1) +fig.tight_layout() +fig.savefig("plot_gemm_or_matmul_add1.png") + +# %% +# Let's compare with ``a @ x + b``. + +fig, ax = make_figure_axis(all_results, -1, 1) fig.tight_layout() -fig.savefig("plot_gemm_or_matmul_add.png") +fig.savefig("plot_gemm_or_matmul_add2.png") + # %% # Discrepancies do not happen all the time but it is very likely to happen. 
From 179c9cfde1312e60d58544e0ae21e865dc728c58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Nov 2025 15:53:47 +0000 Subject: [PATCH 09/12] fix --- _doc/technical/plot_gemm_or_matmul_add.py | 24 +++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/_doc/technical/plot_gemm_or_matmul_add.py b/_doc/technical/plot_gemm_or_matmul_add.py index 100164f1..ea27e9b5 100644 --- a/_doc/technical/plot_gemm_or_matmul_add.py +++ b/_doc/technical/plot_gemm_or_matmul_add.py @@ -53,6 +53,14 @@ def make_model_gemm(itype: int) -> onnx.ModelProto: oh.make_node("Add", ["mm", "B"], ["MatMulAdd"]), oh.make_node("FusedMatMul", ["A", "X"], ["fmm"], domain="com.microsoft"), oh.make_node("Add", ["fmm", "B"], ["FusedMatMulAdd"]), + oh.make_node("Cast", ["A"], ["Afloat"], to=onnx.TensorProto.FLOAT), + oh.make_node("Cast", ["B"], ["Bfloat"], to=onnx.TensorProto.FLOAT), + oh.make_node("Cast", ["X"], ["Xfloat"], to=onnx.TensorProto.FLOAT), + oh.make_node("Gemm", ["Afloat", "Xfloat"], ["gmmfloat"]), + oh.make_node("Add", ["gmmfloat", "Bfloat"], ["gemmaddfloat"]), + oh.make_node("Cast", ["gemmaddfloat"], ["CastGemmAddCast"], to=itype), + oh.make_node("Gemm", ["Afloat", "Xfloat", "Bfloat"], ["GemmOnlyfloat"]), + oh.make_node("Cast", ["GemmOnlyfloat"], ["CastGemmOnlyCast"], to=itype), ], "test", [ @@ -65,6 +73,8 @@ def make_model_gemm(itype: int) -> onnx.ModelProto: oh.make_tensor_value_info("GemmAdd", itype, ["a", "c"]), oh.make_tensor_value_info("FusedMatMulAdd", itype, ["a", "c"]), oh.make_tensor_value_info("MatMulAdd", itype, ["a", "c"]), + oh.make_tensor_value_info("CastGemmAddCast", itype, ["a", "c"]), + oh.make_tensor_value_info("CastGemmOnlyCast", itype, ["a", "c"]), ], ), opset_imports=[oh.make_opsetid("", 22)], @@ -85,7 +95,7 @@ def matrix_diff(tensors): dtype = np.float16 model = make_model_gemm(itype) -A = np.random.randn(512, 256).astype(dtype) +A = np.random.randn(1280, 256).astype(dtype) X = np.random.randn(256, 256).astype(dtype) B = np.random.randn(256).astype(dtype) feeds = dict(A=A, X=X, B=B) @@ -112,9 +122,9 @@ def matrix_diff(tensors): # %% # Let's try with CUDA and float32 if it is available. -A = torch.randn((512, 512), dtype=torch.float32) -X = torch.randn((512, 512), dtype=torch.float32) -B = torch.randn((512), dtype=torch.float32) +A = torch.randn((1280, 1280), dtype=torch.float32) +X = torch.randn((1280, 1280), dtype=torch.float32) +B = torch.randn((1280), dtype=torch.float32) for itype, dtype, device in [ (onnx.TensorProto.FLOAT16, torch.float16, "cpu"), @@ -144,7 +154,9 @@ def matrix_diff(tensors): # are similar to the others coefficients. What if we make them # a lot higher. -B = (torch.arange(512, dtype=torch.float32) + 1) / 512 * 16384 +A = A / A.max() +X = X / X.max() +B = (torch.arange(1280, dtype=torch.float32) + 1) / 1280 * 16 labels = ["F.linear", *[o.name for o in model.graph.output], "a @ x + b"] all_results = {} @@ -199,7 +211,7 @@ def make_figure_axis(all_results, i, j): print(f"labels={labs}, {device}/{dtype}: max(diff)={diff.max()}") expand = 0.5 if diff.max() >= 1 else diff.max().detach().cpu() / 2 ax[pos, 0].plot( - B.tolist(), (diff.detach().cpu() + torch.rand(512) * expand).tolist(), "." + B.tolist(), (diff.detach().cpu() + torch.rand(1280) * expand).tolist(), "." 
) ax[pos, 0].set_title(f"{labs[0]}-{labs[1]} {device}/{dtype}", fontsize=10) From 262f9707418a9be0c1df74d6fa8844e3af9fb1a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 20 Nov 2025 10:19:15 +0100 Subject: [PATCH 10/12] fix test --- _doc/technical/plot_gemm_or_matmul_add.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_doc/technical/plot_gemm_or_matmul_add.py b/_doc/technical/plot_gemm_or_matmul_add.py index ea27e9b5..3720df22 100644 --- a/_doc/technical/plot_gemm_or_matmul_add.py +++ b/_doc/technical/plot_gemm_or_matmul_add.py @@ -245,7 +245,7 @@ def make_figure_axis(all_results, i, j): fig.savefig("plot_gemm_or_matmul_add1.png") # %% -# Let's compare with ``a @ x + b``. +# Let's compare with ``A @ X + B``. fig, ax = make_figure_axis(all_results, -1, 1) fig.tight_layout() From e559a5690d449af45954ecff15ae8d7e3aadac3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 20 Nov 2025 12:48:10 +0100 Subject: [PATCH 11/12] fix --- .../ut_torch_export_patches/test_patch_transformers.py | 2 +- onnx_diagnostic/export/onnx_plug.py | 2 ++ .../patches/_patch_transformers_qwen2_5.py | 5 +++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index e7263839..be7aac37 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -407,7 +407,7 @@ def test_patched_qwen2_5_vl_vision_attention_forward(self): _is_torchdynamo_exporting() ), f"exporting is not set to true? {torch.compiler.is_exporting_flag}" got = patched_Qwen2_5_VLVisionAttention.forward(instance, **inputs) - self.assertEqualArray(expected, got, atol=1e-5) + self.assertEqualArray(expected, got, atol=1e-2) class Model(patched_class): def forward( diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py index 4fcb1586..ac57830f 100644 --- a/onnx_diagnostic/export/onnx_plug.py +++ b/onnx_diagnostic/export/onnx_plug.py @@ -300,6 +300,7 @@ def converter( sts: Optional[Dict[str, Any]], outputs: List[str], *args, + **kwargs, ) -> Any: if not g.has_local_function( self.function_proto.name, domain=self.function_proto.domain @@ -307,6 +308,7 @@ def converter( g.add_function(self.function_proto) ags = args[: len(self.args_name)] kws = dict(zip(self.kwargs_name, args[len(self.args_name) :])) + kws.update(kwargs) res = g.make_node( self.function_proto.name, ags, diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py index e53570c0..a68824de 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py @@ -119,6 +119,7 @@ def qwen_sdpa_attention( value_states: torch.Tensor, # F10s1x16xs47x80 cu_seqlens: torch.Tensor, # F7su19 scaling: float = 0, + num_heads: int = 16, ) -> torch.Tensor: lengths = cu_seqlens[1:] - cu_seqlens[:-1] splits = [ @@ -497,8 +498,8 @@ def forward( key_states, value_states, cu_seqlens, - scaling=self.scaling, - num_heads=self.num_heads, + self.scaling, + self.num_heads, ) elif _is_torchdynamo_exporting(): if self.config._attn_implementation == "flash_attention_2": From 1234be362e42c040ddfc93e89d8c72ba59886fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 20 Nov 2025 
13:18:42 +0100
Subject: [PATCH 12/12] fix documentation

---
 _doc/technical/plot_gemm_or_matmul_add.py | 2 +-
 onnx_diagnostic/export/onnx_plug.py       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/_doc/technical/plot_gemm_or_matmul_add.py b/_doc/technical/plot_gemm_or_matmul_add.py
index 3720df22..0e574552 100644
--- a/_doc/technical/plot_gemm_or_matmul_add.py
+++ b/_doc/technical/plot_gemm_or_matmul_add.py
@@ -234,7 +234,7 @@ def make_figure_axis(all_results, i, j):
     )
     fig.suptitle(
         f"Left column: discrepancies {labs[0]} VS {labs[1]}\n"
-        f"Right column: max absolute error, accross all configuration\n"
+        f"Right column: max absolute error, across all configurations\n"
         f"white is good, orange is not"
     )
     return fig, ax
diff --git a/onnx_diagnostic/export/onnx_plug.py b/onnx_diagnostic/export/onnx_plug.py
index ac57830f..86e69092 100644
--- a/onnx_diagnostic/export/onnx_plug.py
+++ b/onnx_diagnostic/export/onnx_plug.py
@@ -201,7 +201,7 @@ def target_name(self) -> str:
 
     @property
     def torch_op(self) -> Callable:
-        "Returns ``torch.ops.onny_plug."
+        "Returns ``torch.ops.onnx_plug.``."
         return getattr(getattr(torch.ops, self.domain), self.name).default
 
     def __call__(self, *args, **kwargs):