From e94564f14efc5c6fb7090929a21e4eb33b2832a8 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 16 Sep 2025 17:40:21 +0200 Subject: [PATCH 1/3] Add opset to folder name --- onnx_diagnostic/torch_models/validate.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 42d60c1e..cc7a7533 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -113,6 +113,7 @@ def _make_folder_name( dtype: Optional[Union[str, torch.dtype]] = None, device: Optional[Union[str, torch.device]] = None, subfolder: Optional[str] = None, + opset: Optional[int] = None, ) -> str: "Creates a filename unique based on the given options." els = [model_id.replace("/", "_")] @@ -136,6 +137,8 @@ def _make_folder_name( else: raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}") els.append(sdev) + if opset is not None: + els.append(f"op{opset}") return "-".join(els) @@ -412,7 +415,13 @@ def validate_model( folder_name = None if dump_folder: folder_name = _make_folder_name( - model_id, exporter, optimization, dtype=dtype, device=device, subfolder=subfolder + model_id, + exporter, + optimization, + dtype=dtype, + device=device, + subfolder=subfolder, + opset=opset, ) dump_folder = os.path.join(dump_folder, folder_name) if not os.path.exists(dump_folder): From 6e387091430999525ae73b1ac25dabb2a3090b18 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 16 Sep 2025 19:09:08 +0200 Subject: [PATCH 2/3] patch --- CHANGELOGS.rst | 2 ++ .../onnx_export_errors.py | 16 ++++++++++ .../patches/patch_transformers.py | 32 +++++++++++++++++++ onnx_diagnostic/torch_models/validate.py | 2 ++ 4 files changed, 52 insertions(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 9ab5c3b2..a0809cb4 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.7.11 ++++++ +* :pr:`220`: adds a patch for PR `#40791 `_ in transformers + 0.7.10 ++++++ diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index f115718d..fad68b4c 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -426,6 +426,14 @@ def torch_export_patches( patch_transformers_list, verbose=verbose ) + if patch_transformers_list.patch_is_initialized: + if verbose: + print( + "[torch_export_patches] patches " + "transformers.cache_utils.CacheLayerMixin.is_initialized" + ) + patch_transformers_list.apply_patch_for_is_initialized() + if ( masking_utils and patch_transformers_list.patch_masking_utils @@ -689,6 +697,14 @@ def torch_export_patches( "in ALL_MASK_ATTENTION_FUNCTIONS" ) + if patch_transformers_list.patch_is_initialized: + if verbose: + print( + "[torch_export_patches] restores " + "transformers.cache_utils.CacheLayerMixin.is_initialized" + ) + patch_transformers_list.disable_patch_for_is_initialized() + ######## # caches ######## diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 9a96a2b7..13154b5b 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -35,6 +35,38 @@ from ...ext_test_case import has_transformers from ...helpers.torch_helper import is_torchdynamo_exporting +patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99") + + +def _get_is_initialized(self): + return self.keys is not None + + +def _set_is_initialized(self, value): + assert (value and self.keys is not None) or (not value and self.keys is None), ( + f"The patch does not set is_initialized but checks the it is consistent with " + f"``self.keys is not None``, value={value}, " + f"self.keys is not None={self.keys is not None}" + ) + + +def apply_patch_for_is_initialized(): + """ + Fixes export issues introduced by PR `40791 `_. + The attribute is_initialized does not seem to be captured by :func:`torch.export.export`. + """ + if patch_is_initialized: + transformers.cache_utils.CacheLayerMixin.is_initialized = property( + _get_is_initialized, _set_is_initialized + ) + + +def disable_patch_for_is_initialized(): + """Disables the patch applied by function :func:`applies_patch_for_is_initialized`.""" + if patch_is_initialized: + delattr(transformers.cache_utils.CacheLayerMixin, "is_initialized") + + if patch_masking_utils: # Introduced in 4.52 from transformers.masking_utils import ( diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index cc7a7533..d6b3994f 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -1518,6 +1518,8 @@ def call_torch_export_custom( "default+onnxruntime+os_ort", None, } + if optimization == "none": + optimization = "" assert ( optimization in available ), f"unexpected value for optimization={optimization}, available={available}" From 18d125778e3a366a7f6dccc0be8ecf3a250f9f32 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 16 Sep 2025 19:30:07 +0200 Subject: [PATCH 3/3] simple patch --- .../onnx_export_errors.py | 16 ---------- .../patches/patch_transformers.py | 31 ++----------------- 2 files changed, 2 insertions(+), 45 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index fad68b4c..f115718d 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -426,14 +426,6 @@ def torch_export_patches( patch_transformers_list, verbose=verbose ) - if patch_transformers_list.patch_is_initialized: - if verbose: - print( - "[torch_export_patches] patches " - "transformers.cache_utils.CacheLayerMixin.is_initialized" - ) - patch_transformers_list.apply_patch_for_is_initialized() - if ( masking_utils and patch_transformers_list.patch_masking_utils @@ -697,14 +689,6 @@ def torch_export_patches( "in ALL_MASK_ATTENTION_FUNCTIONS" ) - if patch_transformers_list.patch_is_initialized: - if verbose: - print( - "[torch_export_patches] restores " - "transformers.cache_utils.CacheLayerMixin.is_initialized" - ) - patch_transformers_list.disable_patch_for_is_initialized() - ######## # caches ######## diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 13154b5b..e95a0a47 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -38,35 +38,6 @@ patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99") -def _get_is_initialized(self): - return self.keys is not None - - -def _set_is_initialized(self, value): - assert (value and self.keys is not None) or (not value and self.keys is None), ( - f"The patch does not set is_initialized but checks the it is consistent with " - f"``self.keys is not None``, value={value}, " - f"self.keys is not None={self.keys is not None}" - ) - - -def apply_patch_for_is_initialized(): - """ - Fixes export issues introduced by PR `40791 `_. - The attribute is_initialized does not seem to be captured by :func:`torch.export.export`. - """ - if patch_is_initialized: - transformers.cache_utils.CacheLayerMixin.is_initialized = property( - _get_is_initialized, _set_is_initialized - ) - - -def disable_patch_for_is_initialized(): - """Disables the patch applied by function :func:`applies_patch_for_is_initialized`.""" - if patch_is_initialized: - delattr(transformers.cache_utils.CacheLayerMixin, "is_initialized") - - if patch_masking_utils: # Introduced in 4.52 from transformers.masking_utils import ( @@ -245,6 +216,8 @@ def lazy_initialization(self, key_states: torch.Tensor): new_shape[-2] = 0 self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) + if patch_is_initialized: + self.is_initialized = True def _patch_make_causal_mask(