From 4bec7375916f3096363899a2809af27430f3d6f4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 11:54:07 +0200 Subject: [PATCH 1/8] Patches eager_mode for whisper-tiny --- .../test_patch_models.py | 23 ++++++++++ .../onnx_export_errors.py | 42 +++++++++++++++++-- .../patches/patch_transformers.py | 34 +++++++++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 _unittests/ut_torch_export_patches/test_patch_models.py diff --git a/_unittests/ut_torch_export_patches/test_patch_models.py b/_unittests/ut_torch_export_patches/test_patch_models.py new file mode 100644 index 00000000..8b09e249 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_models.py @@ -0,0 +1,23 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy +from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs +from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str + + +class TestHuggingFaceHubModel(ExtTestCase): + @hide_stdout() + @requires_transformers("4.51") + def test_patch_eager_mask_open_whisper_tiny(self): + mid = "openai/whisper-tiny" + data = get_untrained_model_with_inputs(mid, verbose=1) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + model(**torch_deepcopy(inputs)) + with torch_export_patches(patch_transformers=True, verbose=1): + torch.export.export(model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 058041e5..79a90074 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ 
b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -429,6 +429,23 @@ def torch_export_patches( f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv + if masking_utils and hasattr(masking_utils, "eager_mask"): + if verbose: + print( + "[torch_export_patches] patches " + "transformers.masking_utils.eager_mask" + ) + f_transformers_eager_mask = masking_utils.eager_mask + masking_utils.eager_mask = patch_transformers_list.patched_eager_mask + if ( + "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] + == f_transformers_eager_mask + ): + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = ( + patch_transformers_list.patched_eager_mask + ) + if custom_patches: if verbose: print("[torch_export_patches] applies custom patches") @@ -511,7 +528,7 @@ def torch_export_patches( if custom_patches: if verbose: - print("[torch_export_patches] unpatch custom patches") + print("[torch_export_patches] unpatches custom patches") unpatch_module_or_classes( custom_patches, revert_custom_patches_info, verbose=verbose ) @@ -526,18 +543,35 @@ def torch_export_patches( except ImportError: masking_utils = None if verbose: - print("[torch_export_patches] unpatch transformers") + print("[torch_export_patches] unpatches transformers") unpatch_module_or_classes( patch_transformers_list, revert_patches_info, verbose=verbose ) if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"): + masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv if verbose: print( - "[torch_export_patches] unpatch " + "[torch_export_patches] restored " "transformers.masking_utils._vmap_for_bhqkv" ) - masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv + + if masking_utils and hasattr(masking_utils, "eager_mask"): + f_transformers_eager_mask = masking_utils.eager_mask + masking_utils.eager_mask = f_transformers_eager_mask + if ( + 
"eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] + == patch_transformers_list.patched_eager_mask + ): + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = ( + f_transformers_eager_mask + ) + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.eager_mask" + ) ######## # caches diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index d06939ac..08cf3490 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -7,6 +7,7 @@ import transformers from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.cache_utils import StaticCache, Cache, DynamicCache +from transformers.masking_utils import causal_mask_function, sdpa_mask from ...ext_test_case import has_transformers from ...helpers.torch_helper import is_torchdynamo_exporting @@ -1046,3 +1047,36 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value + + +def patched_eager_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + **kwargs, +) -> torch.Tensor: + """manual patch for function ``transformers.masking_utils.eager_mask``.""" + # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf + _ = kwargs.pop("allow_is_causal_skip", None) + mask = sdpa_mask( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_function, + attention_mask=attention_mask, + allow_is_causal_skip=False, + allow_torch_fix=False, + **kwargs, + ) + min_dtype = torch.finfo(dtype).min + # The patched line. 
+ # we need 0s where the tokens should be taken into account, + # and -inf otherwise (mask is already of boolean type) + # mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) + mask = (~mask).to(dtype) * min_dtype + return mask From 907ee41377e4a206cc7266baaf75445fd2c9b606 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 12:20:45 +0200 Subject: [PATCH 2/8] fix --- .../onnx_export_errors.py | 24 ++- .../patches/patch_transformers.py | 178 ++++++++++-------- 2 files changed, 116 insertions(+), 86 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 79a90074..ed8ece0b 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -420,7 +420,11 @@ def torch_export_patches( patch_transformers_list, verbose=verbose ) - if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"): + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "_vmap_for_bhqkv") + ): if verbose: print( "[torch_export_patches] patches " @@ -429,7 +433,11 @@ def torch_export_patches( f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv - if masking_utils and hasattr(masking_utils, "eager_mask"): + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "eager_mask") + ): if verbose: print( "[torch_export_patches] patches " @@ -548,7 +556,11 @@ def torch_export_patches( patch_transformers_list, revert_patches_info, verbose=verbose ) - if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"): + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "_vmap_for_bhqkv") + ): masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv if verbose: 
print( @@ -556,7 +568,11 @@ def torch_export_patches( "transformers.masking_utils._vmap_for_bhqkv" ) - if masking_utils and hasattr(masking_utils, "eager_mask"): + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "eager_mask") + ): f_transformers_eager_mask = masking_utils.eager_mask masking_utils.eager_mask = f_transformers_eager_mask if ( diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 08cf3490..3ebc9e8f 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -7,60 +7,107 @@ import transformers from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.cache_utils import StaticCache, Cache, DynamicCache -from transformers.masking_utils import causal_mask_function, sdpa_mask + +try: + import transformers.masking_utils + + patch_masking_utils = True +except ImportError: + patch_masking_utils = False + from ...ext_test_case import has_transformers from ...helpers.torch_helper import is_torchdynamo_exporting -def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: - """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" - from ...helpers import string_type - - dimensions: List[Tuple[Optional[int], ...]] = [ - (None, None, None, 0), - (None, None, 0, None), - ] - if bh_indices: - dimensions.extend([(None, 0, None, None), (0, None, None, None)]) - # reshape - dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions] - dimensions = tuple(reversed(dimensions)) - indices = tuple(shape.index(-1) for shape in dimensions) - - # unsqueeze - udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions] - - def vector_mask_function( - *args, mask_function=mask_function, 
dimensions=dimensions, indices=indices - ): - assert len(args) == len(dimensions) == len(udimensions), ( - f"Mismatch between args={string_type(args)} and dimensions={dimensions} " - f"and udimensions={udimensions}." - ) - assert len(indices) == len(args), ( - f"Mismatch between args={string_type(args)} and indices={indices}, " - f"they should have the same length." +if patch_masking_utils: + # Introduced in 4.52 + from transformers.masking_utils import causal_mask_function, sdpa_mask + + def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: + """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" + from ...helpers import string_type + + dimensions: List[Tuple[Optional[int], ...]] = [ + (None, None, None, 0), + (None, None, 0, None), + ] + if bh_indices: + dimensions.extend([(None, 0, None, None), (0, None, None, None)]) + # reshape + dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions] + dimensions = tuple(reversed(dimensions)) + indices = tuple(shape.index(-1) for shape in dimensions) + + # unsqueeze + udimensions = [ + tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions + ] + + def vector_mask_function( + *args, mask_function=mask_function, dimensions=dimensions, indices=indices + ): + assert len(args) == len(dimensions) == len(udimensions), ( + f"Mismatch between args={string_type(args)} and dimensions={dimensions} " + f"and udimensions={udimensions}." + ) + assert len(indices) == len(args), ( + f"Mismatch between args={string_type(args)} and indices={indices}, " + f"they should have the same length." 
+ ) + for a in args: + assert ( + a.ndim == 1 + ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}" + torch._check(a.shape[0] > 0) + + new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)] + # new_args = [ + # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2]) + # for a, dims in zip(args, udimensions) + # ] + max_shape = tuple(args[i].shape[0] for i in indices) + # if is_torchdynamo_exporting(): + # for a in args: + # # The exporter should export with a dimension > 1 + # # to make sure it is dynamic. + # torch._check(a.shape[0] > 1) + expanded_args = [a.expand(max_shape) for a in new_args] + return mask_function(*expanded_args) + + return vector_mask_function + + def patched_eager_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + **kwargs, + ) -> torch.Tensor: + """manual patch for function ``transformers.masking_utils.eager_mask``.""" + # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf + _ = kwargs.pop("allow_is_causal_skip", None) + mask = sdpa_mask( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_function, + attention_mask=attention_mask, + allow_is_causal_skip=False, + allow_torch_fix=False, + **kwargs, ) - for a in args: - assert ( - a.ndim == 1 - ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}" - torch._check(a.shape[0] > 0) - - new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)] - # new_args = [ - # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2]) - # for a, dims in zip(args, udimensions) - # ] - max_shape = tuple(args[i].shape[0] for i in indices) - # if is_torchdynamo_exporting(): - # for a in args: - # # The exporter should export with a dimension > 1 to 
make sure it is dynamic. - # torch._check(a.shape[0] > 1) - expanded_args = [a.expand(max_shape) for a in new_args] - return mask_function(*expanded_args) - - return vector_mask_function + min_dtype = torch.finfo(dtype).min + # The patched line. + # we need 0s where the tokens should be taken into account, + # and -inf otherwise (mask is already of boolean type) + # mask = + # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) + mask = (~mask).to(dtype) * min_dtype + return mask def _patch_make_causal_mask( @@ -1047,36 +1094,3 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value - - -def patched_eager_mask( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, - kv_offset: int = 0, - mask_function: Callable = causal_mask_function, - attention_mask: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - **kwargs, -) -> torch.Tensor: - """manual patch for function ``transformers.masking_utils.eager_mask``.""" - # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf - _ = kwargs.pop("allow_is_causal_skip", None) - mask = sdpa_mask( - batch_size=batch_size, - cache_position=cache_position, - kv_length=kv_length, - kv_offset=kv_offset, - mask_function=mask_function, - attention_mask=attention_mask, - allow_is_causal_skip=False, - allow_torch_fix=False, - **kwargs, - ) - min_dtype = torch.finfo(dtype).min - # The patched line. 
- # we need 0s where the tokens should be taken into account, - # and -inf otherwise (mask is already of boolean type) - # mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) - mask = (~mask).to(dtype) * min_dtype - return mask From 7a2f3863a46d899e56ac0c7caa69263f23888bda Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 12:52:13 +0200 Subject: [PATCH 3/8] fix ut --- .../test_patch_torch.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 99c2465a..ffc7c683 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -2,15 +2,14 @@ from typing import Callable import torch from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch -from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap -from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( - patched__vmap_for_bhqkv as _vmap_for_bhqkv2, -) +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers class TestPatchPatchTorch(ExtTestCase): + @requires_transformers("4.52") def test_vmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + f = lambda x, y: x * y + 1 # noqa: E731 x = torch.tensor([1.0, 2.0, 3.0]) y = torch.tensor([0.1, 0.2, 0.3]) @@ -32,7 +31,10 @@ def forward(self, x, y): self.assertEqualArray(Model()(x, y), ep.module()(x, y)) @requires_torch("2.8") + @requires_transformers("4.52") def test_export_patched_vmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + class Model(torch.nn.Module): def forward(self, x, y): f = lambda x, y: x * y + 1 # noqa: E731 @@ -43,14 +45,20 @@ def 
forward(self, x, y): ep = torch.export.export(Model(), (x, y)) self.assertEqualArray(Model()(x, y), ep.module()(x, y)) + @requires_transformers("4.52") def test_vmap_outdim(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + f = lambda x: x**2 # noqa: E731 x = torch.randn(2, 5) expected = torch.vmap(f, out_dims=1)(x) got = patched_vmap(f, out_dims=1)(x) self.assertEqualArray(expected, got) + @requires_transformers("4.52") def test_vmap_dict(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + f = lambda d: torch.dot(d["x"], d["y"]) # noqa: E731 x, y = torch.randn(2, 5), torch.randn(5) input = {"x": x, "y": y} @@ -60,13 +68,19 @@ def test_vmap_dict(self): ) # self.assertEqualArray(_expected, got) + @requires_transformers("4.52") def test_vmap_tuple(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + x, y = torch.randn(2, 5), torch.randn(5) expected = torch.vmap(torch.dot, in_dims=(0, None))(x, y) got = patched_vmap(torch.dot, in_dims=(0, None))(x, y) self.assertEqualArray(expected, got, atol=1e-5) + @requires_transformers("4.52") def test_vmap_transformers_scenario_vmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + def padding_mask_function(padding_mask: torch.Tensor) -> Callable: def inner_mask(batch_idx, head_idx, q_idx, kv_idx): return padding_mask[batch_idx, kv_idx] @@ -140,7 +154,12 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange): self.assertEqualArray(causal_mask, ep.moule(*inputs)) @requires_torch("2.8") + @requires_transformers("4.52") def test_vmap_transformers_scenario_novmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( + patched__vmap_for_bhqkv as _vmap_for_bhqkv2, + ) + def padding_mask_function(padding_mask: torch.Tensor) -> Callable: def inner_mask(batch_idx, head_idx, q_idx, kv_idx): return padding_mask[batch_idx, 
kv_idx] From 84eda70e9b3037263af1408c90b1cbb4840d84b8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 13:06:27 +0200 Subject: [PATCH 4/8] ut --- _scripts/test_backend_onnxruntime.py | 21 ++++++++++--------- .../test_backend_onnxruntime_evaluator.py | 4 ++-- .../test_patch_torch.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/_scripts/test_backend_onnxruntime.py b/_scripts/test_backend_onnxruntime.py index 48bb1777..9ae959f2 100644 --- a/_scripts/test_backend_onnxruntime.py +++ b/_scripts/test_backend_onnxruntime.py @@ -26,12 +26,13 @@ def run(self, inputs, **kwargs): if isinstance(inputs, numpy.ndarray): inputs = [inputs] if isinstance(inputs, list): - if len(inputs) == len(self._session.input_names): - feeds = dict(zip(self._session.input_names, inputs)) + if len(inputs) == len(self._session.get_inputs()): + feeds = dict(zip([i.name for i in self._session.get_inputs()], inputs)) else: + input_names = [i.name for i in self._session.get_inputs()] feeds = {} pos_inputs = 0 - for inp, tshape in zip(self._session.input_names, self._session.input_types): + for inp, tshape in zip(input_names, self._session.input_types): shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim) if shape == inputs[pos_inputs].shape: feeds[inp] = inputs[pos_inputs] @@ -54,20 +55,20 @@ def is_compatible(cls, model) -> bool: @classmethod def supports_device(cls, device: str) -> bool: d = Device(device) - if d == DeviceType.CPU: + if d.type == DeviceType.CPU: return True - if d == DeviceType.CUDA: - import torch - - return torch.cuda.is_available() + # if d.type == DeviceType.CUDA: + # import torch + # + # return torch.cuda.is_available() return False @classmethod def create_inference_session(cls, model, device): d = Device(device) - if d == DeviceType.CUDA: + if d.type == DeviceType.CUDA: providers = ["CUDAExecutionProvider"] - elif d == DeviceType.CPU: + elif d.type == DeviceType.CPU: providers = ["CPUExecutionProvider"] else: raise 
ValueError(f"Unrecognized device {device!r} or {d!r}") diff --git a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py index 9f596c4f..169a4e06 100644 --- a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py +++ b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py @@ -50,9 +50,9 @@ def is_compatible(cls, model) -> bool: @classmethod def supports_device(cls, device: str) -> bool: d = Device(device) - if d == DeviceType.CPU: + if d.type == DeviceType.CPU: return True - if d == DeviceType.CUDA: + if d.type == DeviceType.CUDA: import torch return torch.cuda.is_available() diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index ffc7c683..7ed791d1 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -154,7 +154,7 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange): self.assertEqualArray(causal_mask, ep.moule(*inputs)) @requires_torch("2.8") - @requires_transformers("4.52") + @requires_transformers("4.53") def test_vmap_transformers_scenario_novmap(self): from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( patched__vmap_for_bhqkv as _vmap_for_bhqkv2, From 3ff259e39fcdf2449819dfae1b75c562c3d58889 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 13:18:41 +0200 Subject: [PATCH 5/8] fix --- CHANGELOGS.rst | 3 ++- _unittests/ut_reference/test_backend_onnxruntime_evaluator.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c3778de4..3ead0e8a 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,8 @@ Change Logs 0.7.4 +++++ -* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs +* :pr:`178`: add a patch for eager_mask to handle ``assert len(flat_dynamic_shapes) == 
num_placeholders - num_lifted_inputs`` +* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs 0.7.3 +++++ diff --git a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py index 169a4e06..521aab6f 100644 --- a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py +++ b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py @@ -61,9 +61,9 @@ def supports_device(cls, device: str) -> bool: @classmethod def create_inference_session(cls, model, device): d = Device(device) - if d == DeviceType.CUDA: + if d.type == DeviceType.CUDA: providers = ["CUDAExecutionProvider"] - elif d == DeviceType.CPU: + elif d.type == DeviceType.CPU: providers = ["CPUExecutionProvider"] else: raise ValueError(f"Unrecognized device {device!r} or {d!r}") From b2a2bbb75bc5c5ad51170e50d102528df0f9ebb6 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 13:41:20 +0200 Subject: [PATCH 6/8] Change the meaning of inputs2, add_second_input --- onnx_diagnostic/_command_lines_parser.py | 20 ++++++++++++---- .../tasks/automatic_speech_recognition.py | 4 ++-- onnx_diagnostic/tasks/feature_extraction.py | 4 ++-- onnx_diagnostic/tasks/fill_mask.py | 4 ++-- onnx_diagnostic/tasks/image_classification.py | 6 ++--- onnx_diagnostic/tasks/image_text_to_text.py | 6 ++--- onnx_diagnostic/tasks/mixture_of_expert.py | 2 +- onnx_diagnostic/tasks/object_detection.py | 6 ++--- onnx_diagnostic/tasks/sentence_similarity.py | 4 ++-- onnx_diagnostic/tasks/summarization.py | 6 ++--- onnx_diagnostic/tasks/text2text_generation.py | 6 ++--- onnx_diagnostic/tasks/text_classification.py | 4 ++-- onnx_diagnostic/tasks/text_generation.py | 6 ++--- onnx_diagnostic/tasks/text_to_image.py | 4 ++-- .../tasks/zero_shot_image_classification.py | 4 ++-- .../torch_models/hghub/model_inputs.py | 2 +- onnx_diagnostic/torch_models/validate.py | 24 +++++++++++++++---- 17 files changed, 69 insertions(+), 43 
deletions(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 6170d34c..46a765fe 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -349,6 +349,15 @@ def get_parser_validate() -> ArgumentParser: python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\ --dtype float16 --device cuda --export modelbuilder + + position_ids is usually not needed, they can be removed by adding: + + --drop position_ids + + The behaviour may be modified compare the original configuration, + the following argument can be rope_scaling to dynamic: + + --mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"" """ ), formatter_class=RawTextHelpFormatter, @@ -403,10 +412,12 @@ def get_parser_validate() -> ArgumentParser: ) parser.add_argument( "--inputs2", - default=True, - action=BooleanOptionalAction, + default=1, + type=int, help="Validates the model on a second set of inputs\n" - "to check the exported model supports dynamism.", + "to check the exported model supports dynamism. The values is used " + "as an increment to the first set of inputs. 
A high value may trick " + "a different behavior in the model and missed by the exporter.", ) parser.add_argument( "--runtime", @@ -422,7 +433,8 @@ def get_parser_validate() -> ArgumentParser: parser.add_argument( "--drop", help="Drops the following inputs names, it should be a list\n" - "with comma separated values.", + "with comma separated values, example:\n" + "--drop position_ids", ) parser.add_argument( "--opset", diff --git a/onnx_diagnostic/tasks/automatic_speech_recognition.py b/onnx_diagnostic/tasks/automatic_speech_recognition.py index f1b4ae6b..8b1410ba 100644 --- a/onnx_diagnostic/tasks/automatic_speech_recognition.py +++ b/onnx_diagnostic/tasks/automatic_speech_recognition.py @@ -33,7 +33,7 @@ def get_inputs( head_dim: int, batch_size: int = 2, sequence_length: int = 30, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -144,7 +144,7 @@ def get_inputs( decoder_layers=decoder_layers, head_dim=head_dim, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, + sequence_length=sequence_length + add_second_input, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/feature_extraction.py b/onnx_diagnostic/tasks/feature_extraction.py index 4bac2aed..1875e8b6 100644 --- a/onnx_diagnostic/tasks/feature_extraction.py +++ b/onnx_diagnostic/tasks/feature_extraction.py @@ -22,7 +22,7 @@ def get_inputs( batch_size: int, sequence_length: int, dummy_max_token_id: int, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -56,7 +56,7 @@ def get_inputs( model=model, config=config, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, + sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, **kwargs, )["inputs"] diff --git a/onnx_diagnostic/tasks/fill_mask.py b/onnx_diagnostic/tasks/fill_mask.py index 63a05811..8b8bdb3d 100644 --- a/onnx_diagnostic/tasks/fill_mask.py +++ b/onnx_diagnostic/tasks/fill_mask.py @@ -22,7 
+22,7 @@ def get_inputs( batch_size: int, sequence_length: int, dummy_max_token_id: int, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -58,7 +58,7 @@ def get_inputs( model=model, config=config, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, + sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, **kwargs, )["inputs"] diff --git a/onnx_diagnostic/tasks/image_classification.py b/onnx_diagnostic/tasks/image_classification.py index cc14e4a3..9f88c8d6 100644 --- a/onnx_diagnostic/tasks/image_classification.py +++ b/onnx_diagnostic/tasks/image_classification.py @@ -34,7 +34,7 @@ def get_inputs( input_channels: int, batch_size: int = 2, dynamic_rope: bool = False, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -78,8 +78,8 @@ def get_inputs( res["inputs2"] = get_inputs( model=model, config=config, - input_width=input_width + 1, - input_height=input_height + 1, + input_width=input_width + add_second_input, + input_height=input_height + add_second_input, input_channels=input_channels, batch_size=batch_size + 1, dynamic_rope=dynamic_rope, diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 4400b772..b482eec4 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -32,7 +32,7 @@ def get_inputs( sequence_length2: int = 3, n_images: int = 2, dynamic_rope: bool = False, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -116,8 +116,8 @@ def get_inputs( height=height, num_channels=num_channels, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, - sequence_length2=sequence_length2 + 1, + sequence_length=sequence_length + add_second_input, + sequence_length2=sequence_length2 + add_second_input, n_images=n_images + 1, dynamic_rope=dynamic_rope, **kwargs, diff --git 
a/onnx_diagnostic/tasks/mixture_of_expert.py b/onnx_diagnostic/tasks/mixture_of_expert.py index be6b7828..1376ade2 100644 --- a/onnx_diagnostic/tasks/mixture_of_expert.py +++ b/onnx_diagnostic/tasks/mixture_of_expert.py @@ -41,7 +41,7 @@ def get_inputs( sequence_length2: int = 3, n_images: int = 2, dynamic_rope: bool = False, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ diff --git a/onnx_diagnostic/tasks/object_detection.py b/onnx_diagnostic/tasks/object_detection.py index d8ce8073..6f7b2e8c 100644 --- a/onnx_diagnostic/tasks/object_detection.py +++ b/onnx_diagnostic/tasks/object_detection.py @@ -27,7 +27,7 @@ def get_inputs( input_channels: int, batch_size: int = 2, dynamic_rope: bool = False, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -68,8 +68,8 @@ def get_inputs( res["inputs2"] = get_inputs( model=model, config=config, - input_width=input_width + 1, - input_height=input_height + 1, + input_width=input_width + add_second_input, + input_height=input_height + add_second_input, input_channels=input_channels, batch_size=batch_size + 1, dynamic_rope=dynamic_rope, diff --git a/onnx_diagnostic/tasks/sentence_similarity.py b/onnx_diagnostic/tasks/sentence_similarity.py index 4e304c47..c79428cd 100644 --- a/onnx_diagnostic/tasks/sentence_similarity.py +++ b/onnx_diagnostic/tasks/sentence_similarity.py @@ -22,7 +22,7 @@ def get_inputs( batch_size: int, sequence_length: int, dummy_max_token_id: int, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -58,7 +58,7 @@ def get_inputs( model=model, config=config, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, + sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, **kwargs, )["inputs"] diff --git a/onnx_diagnostic/tasks/summarization.py b/onnx_diagnostic/tasks/summarization.py index 3b2231a1..551541cc 100644 --- 
a/onnx_diagnostic/tasks/summarization.py +++ b/onnx_diagnostic/tasks/summarization.py @@ -29,7 +29,7 @@ def get_inputs( batch_size: int = 2, sequence_length: int = 30, sequence_length2: int = 3, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -154,8 +154,8 @@ def get_inputs( head_dim_encoder=head_dim_encoder, head_dim_decoder=head_dim_decoder, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, - sequence_length2=sequence_length2 + 1, + sequence_length=sequence_length + add_second_input, + sequence_length2=sequence_length2 + add_second_input, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index 6dd0e3b6..29e9676c 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -30,7 +30,7 @@ def get_inputs( batch_size: int = 2, sequence_length: int = 30, sequence_length2: int = 3, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -160,8 +160,8 @@ def get_inputs( head_dim_decoder=head_dim_decoder, encoder_dim=encoder_dim, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, - sequence_length2=sequence_length2 + 1, + sequence_length=sequence_length + add_second_input, + sequence_length2=sequence_length2 + add_second_input, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/text_classification.py b/onnx_diagnostic/tasks/text_classification.py index e3a1d727..ba30a75c 100644 --- a/onnx_diagnostic/tasks/text_classification.py +++ b/onnx_diagnostic/tasks/text_classification.py @@ -22,7 +22,7 @@ def get_inputs( batch_size: int, sequence_length: int, dummy_max_token_id: int, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -58,7 +58,7 @@ def get_inputs( model=model, config=config, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, + 
sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, **kwargs, )["inputs"] diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 873fa4fc..c2df7be4 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -72,7 +72,7 @@ def get_inputs( num_key_value_heads: Optional[int] = None, head_dim: Optional[int] = None, cls_cache: Optional[Union[type, str]] = None, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -261,8 +261,8 @@ def get_inputs( dummy_max_token_id=dummy_max_token_id, num_hidden_layers=num_hidden_layers, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, - sequence_length2=sequence_length2 + 1, + sequence_length=sequence_length + add_second_input, + sequence_length2=sequence_length2 + add_second_input, dynamic_rope=dynamic_rope, num_key_value_heads=num_key_value_heads, head_dim=head_dim, diff --git a/onnx_diagnostic/tasks/text_to_image.py b/onnx_diagnostic/tasks/text_to_image.py index 983d9bec..7426f48a 100644 --- a/onnx_diagnostic/tasks/text_to_image.py +++ b/onnx_diagnostic/tasks/text_to_image.py @@ -25,7 +25,7 @@ def get_inputs( in_channels: int, sample_size: int, cross_attention_dim: int, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -63,7 +63,7 @@ def get_inputs( config=config, batch_size=batch_size + 1, sequence_length=sequence_length, - cache_length=cache_length + 1, + cache_length=cache_length + add_second_input, in_channels=in_channels, sample_size=sample_size, cross_attention_dim=cross_attention_dim, diff --git a/onnx_diagnostic/tasks/zero_shot_image_classification.py b/onnx_diagnostic/tasks/zero_shot_image_classification.py index 83163552..80bc9ff4 100644 --- a/onnx_diagnostic/tasks/zero_shot_image_classification.py +++ b/onnx_diagnostic/tasks/zero_shot_image_classification.py @@ -34,7 +34,7 @@ def get_inputs( 
input_height: int = 224, input_channels: int = 3, batch_size_image=3, - add_second_input: bool = False, + add_second_input: int = 1, **kwargs, # unused ): """ @@ -92,7 +92,7 @@ def get_inputs( config=config, dummy_max_token_id=dummy_max_token_id, batch_size=batch_size + 1, - sequence_length=sequence_length + 1, + sequence_length=sequence_length + add_second_input, input_width=input_width, input_height=input_height, input_channels=input_channels, diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 1961e049..74531560 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -26,7 +26,7 @@ def get_untrained_model_with_inputs( use_pretrained: bool = False, same_as_pretrained: bool = False, use_preinstalled: bool = True, - add_second_input: bool = False, + add_second_input: int = 1, subfolder: Optional[str] = None, use_only_preinstalled: bool = False, ) -> Dict[str, Any]: diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 179b1f05..989a8101 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -156,6 +156,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]: "version_torch": torch.__version__, "version_numpy": numpy.__version__, } + try: + import scipy + + summary["version_scipy"] = getattr(scipy, "__version__", "?") + except ImportError: + pass try: import transformers @@ -180,6 +186,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]: summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?") except ImportError: pass + try: + import onnx_ir + + summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?") + except ImportError: + pass import onnx_diagnostic summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__ @@ -275,7 +287,7 @@ def validate_model( runtime: str = "onnxruntime", 
repeat: int = 1, warmup: int = 0, - inputs2: bool = True, + inputs2: int = 1, ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]: """ Validates a model. @@ -324,7 +336,8 @@ def validate_model( :param repeat: number of times to measure the model :param warmup: warmup the model first :param inputs2: checks that the second set of inputs is running as well, - this ensures that the model does support dynamism + this ensures that the model does support dynamism, the value is used + as an increment to the first set of values (added to dimensions) :return: two dictionaries, one with some metrics, another one with whatever the function produces @@ -1053,7 +1066,7 @@ def validate_onnx_model( runtime: str = "onnxruntime", repeat: int = 1, warmup: int = 0, - inputs2: bool = True, + inputs2: int = 1, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Verifies that an onnx model produces the same @@ -1069,8 +1082,9 @@ def validate_onnx_model( :param runtime: onnx runtime to use, onnxruntime or torch :param repeat: run that number of times the model :param warmup: warmup the model - :param inputs: to validate the model on the second input set - to make sure the exported model supports dynamism + :param inputs2: to validate the model on the second input set + to make sure the exported model supports dynamism, the value is + used as an increment added to the first set of inputs (added to dimensions) :return: two dictionaries, one with some metrics, another one with whatever the function produces """ From 50f72f1a83fdb3770224612c01170f5777bd3dad Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 13:59:03 +0200 Subject: [PATCH 7/8] mechanism for inputs2 --- onnx_diagnostic/tasks/automatic_speech_recognition.py | 4 ++++ onnx_diagnostic/tasks/feature_extraction.py | 4 ++++ onnx_diagnostic/tasks/fill_mask.py | 4 ++++ onnx_diagnostic/tasks/image_classification.py | 4 ++++ onnx_diagnostic/tasks/image_text_to_text.py | 6 +++++- onnx_diagnostic/tasks/object_detection.py |
4 ++++ onnx_diagnostic/tasks/sentence_similarity.py | 4 ++++ onnx_diagnostic/tasks/summarization.py | 6 +++++- onnx_diagnostic/tasks/text2text_generation.py | 6 +++++- onnx_diagnostic/tasks/text_classification.py | 4 ++++ onnx_diagnostic/tasks/text_generation.py | 8 +++++--- onnx_diagnostic/tasks/text_to_image.py | 4 ++++ onnx_diagnostic/tasks/zero_shot_image_classification.py | 4 ++++ 13 files changed, 56 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/tasks/automatic_speech_recognition.py b/onnx_diagnostic/tasks/automatic_speech_recognition.py index 8b1410ba..b6da7e7a 100644 --- a/onnx_diagnostic/tasks/automatic_speech_recognition.py +++ b/onnx_diagnostic/tasks/automatic_speech_recognition.py @@ -132,6 +132,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." res["inputs2"] = get_inputs( model=model, config=config, @@ -145,6 +148,7 @@ def get_inputs( head_dim=head_dim, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/feature_extraction.py b/onnx_diagnostic/tasks/feature_extraction.py index 1875e8b6..1d49147c 100644 --- a/onnx_diagnostic/tasks/feature_extraction.py +++ b/onnx_diagnostic/tasks/feature_extraction.py @@ -52,12 +52,16 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." 
res["inputs2"] = get_inputs( model=model, config=config, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/fill_mask.py b/onnx_diagnostic/tasks/fill_mask.py index 8b8bdb3d..167993d2 100644 --- a/onnx_diagnostic/tasks/fill_mask.py +++ b/onnx_diagnostic/tasks/fill_mask.py @@ -54,12 +54,16 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." res["inputs2"] = get_inputs( model=model, config=config, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/image_classification.py b/onnx_diagnostic/tasks/image_classification.py index 9f88c8d6..3a993399 100644 --- a/onnx_diagnostic/tasks/image_classification.py +++ b/onnx_diagnostic/tasks/image_classification.py @@ -75,6 +75,9 @@ def get_inputs( shapes["interpolate_pos_encoding"] = None # type: ignore[assignment] res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." 
res["inputs2"] = get_inputs( model=model, config=config, @@ -83,6 +86,7 @@ def get_inputs( input_channels=input_channels, batch_size=batch_size + 1, dynamic_rope=dynamic_rope, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index b482eec4..e7b17a17 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -105,6 +105,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." res["inputs2"] = get_inputs( model=model, config=config, @@ -117,9 +120,10 @@ def get_inputs( num_channels=num_channels, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, - sequence_length2=sequence_length2 + add_second_input, + sequence_length2=sequence_length2 + 1, n_images=n_images + 1, dynamic_rope=dynamic_rope, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/object_detection.py b/onnx_diagnostic/tasks/object_detection.py index 6f7b2e8c..e85e6355 100644 --- a/onnx_diagnostic/tasks/object_detection.py +++ b/onnx_diagnostic/tasks/object_detection.py @@ -65,6 +65,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." 
res["inputs2"] = get_inputs( model=model, config=config, @@ -73,6 +76,7 @@ def get_inputs( input_channels=input_channels, batch_size=batch_size + 1, dynamic_rope=dynamic_rope, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/sentence_similarity.py b/onnx_diagnostic/tasks/sentence_similarity.py index c79428cd..5c7b7b04 100644 --- a/onnx_diagnostic/tasks/sentence_similarity.py +++ b/onnx_diagnostic/tasks/sentence_similarity.py @@ -54,12 +54,16 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." res["inputs2"] = get_inputs( model=model, config=config, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/summarization.py b/onnx_diagnostic/tasks/summarization.py index 551541cc..4384f29d 100644 --- a/onnx_diagnostic/tasks/summarization.py +++ b/onnx_diagnostic/tasks/summarization.py @@ -144,6 +144,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." 
res["inputs2"] = get_inputs( model=model, config=config, @@ -155,7 +158,8 @@ def get_inputs( head_dim_decoder=head_dim_decoder, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, - sequence_length2=sequence_length2 + add_second_input, + sequence_length2=sequence_length2 + 1, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index 29e9676c..989782f5 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -149,6 +149,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." res["inputs2"] = get_inputs( model=model, config=config, @@ -161,7 +164,8 @@ def get_inputs( encoder_dim=encoder_dim, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, - sequence_length2=sequence_length2 + add_second_input, + sequence_length2=sequence_length2 + 1, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/text_classification.py b/onnx_diagnostic/tasks/text_classification.py index ba30a75c..14866f7c 100644 --- a/onnx_diagnostic/tasks/text_classification.py +++ b/onnx_diagnostic/tasks/text_classification.py @@ -54,12 +54,16 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." 
res["inputs2"] = get_inputs( model=model, config=config, batch_size=batch_size + 1, sequence_length=sequence_length + add_second_input, dummy_max_token_id=dummy_max_token_id, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index c2df7be4..599062bc 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -260,13 +260,15 @@ def get_inputs( config=config, dummy_max_token_id=dummy_max_token_id, num_hidden_layers=num_hidden_layers, - batch_size=batch_size + 1, - sequence_length=sequence_length + add_second_input, - sequence_length2=sequence_length2 + add_second_input, + batch_size=(batch_size + 1) if add_second_input > 0 else 1, + sequence_length=sequence_length + 1, + sequence_length2=sequence_length2 + + (add_second_input if add_second_input > 0 else -add_second_input), dynamic_rope=dynamic_rope, num_key_value_heads=num_key_value_heads, head_dim=head_dim, cls_cache=cls_cache, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/text_to_image.py b/onnx_diagnostic/tasks/text_to_image.py index 7426f48a..fd49fe5d 100644 --- a/onnx_diagnostic/tasks/text_to_image.py +++ b/onnx_diagnostic/tasks/text_to_image.py @@ -58,6 +58,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." 
res["inputs2"] = get_inputs( model=model, config=config, @@ -67,6 +70,7 @@ def get_inputs( in_channels=in_channels, sample_size=sample_size, cross_attention_dim=cross_attention_dim, + add_second_input=0, **kwargs, )["inputs"] return res diff --git a/onnx_diagnostic/tasks/zero_shot_image_classification.py b/onnx_diagnostic/tasks/zero_shot_image_classification.py index 80bc9ff4..61fee29e 100644 --- a/onnx_diagnostic/tasks/zero_shot_image_classification.py +++ b/onnx_diagnostic/tasks/zero_shot_image_classification.py @@ -87,6 +87,9 @@ def get_inputs( ) res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + assert ( + add_second_input > 0 + ), f"Not implemented for add_second_input={add_second_input}." res["inputs2"] = get_inputs( model=model, config=config, @@ -97,6 +100,7 @@ def get_inputs( input_height=input_height, input_channels=input_channels, batch_size_image=batch_size_image + 1, + add_second_input=0, **kwargs, )["inputs"] return res From f3549c61ded9fbd657eced2e3aa35c93abaaf4a0 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 4 Jul 2025 17:10:06 +0200 Subject: [PATCH 8/8] fix ut --- _unittests/ut_torch_models/test_hghub_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 1b051a50..39657e8b 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -17,7 +17,7 @@ class TestHuggingFaceHubModel(ExtTestCase): @hide_stdout() def test_get_untrained_model_with_inputs_tiny_llm(self): mid = "arnir0/Tiny-LLM" - data = get_untrained_model_with_inputs(mid, verbose=1) + data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=0) self.assertEqual( set(data), {