From 03d60cbfa46ad1da744f3c223b3da4bb4118e14a Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 17 May 2025 11:11:39 +0200 Subject: [PATCH 1/6] check sdpa is working --- onnx_diagnostic/torch_models/hghub/model_inputs.py | 12 ++++++++++++ onnx_diagnostic/torch_models/test_helper.py | 8 ++++++++ 2 files changed, 20 insertions(+) diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 95bf221b..29338dcc 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -69,6 +69,7 @@ def get_untrained_model_with_inputs( subfolder=subfolder, **(model_kwargs or {}), ) + if hasattr(config, "architecture") and config.architecture: archs = [config.architecture] if type(config) is dict: @@ -116,6 +117,17 @@ def get_untrained_model_with_inputs( if mkwargs: update_config(config, mkwargs) + # SDPA + if model_kwargs and "attn_implementation" in model_kwargs: + if hasattr(config, "_attn_implementation_autoset"): + config._attn_implementation_autoset = False + config._attn_implementation = model_kwargs["attn_implementation"] + if verbose: + print( + f"[get_untrained_model_with_inputs] config._attn_implementation=" + f"{config._attn_implementation!r}" + ) + # input kwargs kwargs, fct = random_input_kwargs(config, task) if verbose: diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index a9425ae5..0f733474 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -482,6 +482,7 @@ def validate_model( verbose=verbose, optimization=optimization, do_run=do_run, + dump_folder=dump_folder, ) else: data["inputs_export"] = data["inputs"] @@ -493,6 +494,7 @@ def validate_model( verbose=verbose, optimization=optimization, do_run=do_run, + dump_folder=dump_folder, ) summary.update(summary_export) @@ -618,6 +620,7 @@ def call_exporter( verbose: int = 0, optimization: Optional[str] = None, do_run: bool = False, + dump_folder: Optional[None] = None, ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]: """ Calls an exporter on a model; @@ -629,6 +632,7 @@ def call_exporter( :param verbose: verbosity :param optimization: optimization to do :param do_run: runs and compute discrepancies + :param dump_folder: to dump additional information :return: two dictionaries, one with some metrics, another one with whatever the function produces """ @@ -661,6 +665,7 @@ def call_exporter( quiet=quiet, verbose=verbose, optimization=optimization, + dump_folder=dump_folder, ) return summary, data raise NotImplementedError( @@ -1045,6 +1050,7 @@ def call_torch_export_custom( quiet: bool = False, verbose: int = 0, optimization: Optional[str] = None, + dump_folder: Optional[str] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Exports a model into onnx. @@ -1056,6 +1062,7 @@ def call_torch_export_custom( :param quiet: catch exception or not :param verbose: verbosity :param optimization: optimization to do + :param dump_folder: to store additional information :return: two dictionaries, one with some metrics, another one with whatever the function produces """ @@ -1113,6 +1120,7 @@ def call_torch_export_custom( decomposition_table=( "default" if "-default" in exporter else ("all" if "-all" in exporter else None) ), + save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None), ) options = OptimizationOptions(patterns=optimization) if optimization else None model = data["model"] From 9553a9023147d8729a7085caad586af31aae2b70 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 17 May 2025 11:14:47 +0200 Subject: [PATCH 2/6] fix mypy --- onnx_diagnostic/torch_models/hghub/model_inputs.py | 4 ++-- onnx_diagnostic/torch_models/test_helper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 29338dcc..5e452faf 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -121,11 +121,11 @@ def get_untrained_model_with_inputs( if model_kwargs and "attn_implementation" in model_kwargs: if hasattr(config, "_attn_implementation_autoset"): config._attn_implementation_autoset = False - config._attn_implementation = model_kwargs["attn_implementation"] + config._attn_implementation = model_kwargs["attn_implementation"] # type: ignore[union-attr] if verbose: print( f"[get_untrained_model_with_inputs] config._attn_implementation=" - f"{config._attn_implementation!r}" + f"{config._attn_implementation!r}" # type: ignore[union-attr] ) # input kwargs diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 0f733474..9b4afe7c 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -620,7 +620,7 @@ def call_exporter( verbose: int = 0, optimization: Optional[str] = None, do_run: bool = False, - dump_folder: Optional[None] = None, + dump_folder: Optional[str] = None, ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]: """ Calls an exporter on a model; From facf3295f0d04a90bae5a86f4c8168fa6ed3ff43 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 19 May 2025 11:41:09 +0200 Subject: [PATCH 3/6] more infos --- onnx_diagnostic/torch_models/test_helper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 9b4afe7c..8563e5bd 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -1,6 +1,7 @@ import datetime import inspect import os +import sys from typing import Any, Callable, Dict, List, Optional, Tuple, Union import time import onnx @@ -375,8 +376,11 @@ def validate_model( summary[f"model_{k.replace('_','')}"] = data[k] summary["model_inputs_opionts"] = str(input_options or "") summary["model_inputs"] = string_type(data["inputs"], with_shape=True) - summary["model_shapes"] = string_type(str(data["dynamic_shapes"])) + summary["model_shapes"] = string_type(data["dynamic_shapes"]) summary["model_class"] = data["model"].__class__.__name__ + summary["model_module"] = data["model"].__class__.__module__ + if summary["model_module"] in sys.modules: + summary["model_file"] = sys.modules[summary["model_module"]].__file__ summary["model_config_class"] = data["configuration"].__class__.__name__ summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") summary["model_id"] = model_id From b3e9c119de75fd4b5e20b135b69ff353beb5c97b Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 19 May 2025 11:47:02 +0200 Subject: [PATCH 4/6] mypy --- onnx_diagnostic/torch_models/test_helper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 8563e5bd..210be1c2 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -379,7 +379,9 @@ def validate_model( summary["model_shapes"] = string_type(data["dynamic_shapes"]) summary["model_class"] = data["model"].__class__.__name__ summary["model_module"] = data["model"].__class__.__module__ - if summary["model_module"] in sys.modules: + if summary["model_module"] in sys.modules and isinstance( + sys.modules[summary["model_module"]].__file__, str + ): summary["model_file"] = sys.modules[summary["model_module"]].__file__ summary["model_config_class"] = data["configuration"].__class__.__name__ summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") From 61cb2d1fb49acca8ab85665242c2484e6db81485 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 19 May 2025 11:51:56 +0200 Subject: [PATCH 5/6] fix --- onnx_diagnostic/torch_models/test_helper.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 210be1c2..1bf199e8 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -378,11 +378,9 @@ def validate_model( summary["model_inputs"] = string_type(data["inputs"], with_shape=True) summary["model_shapes"] = string_type(data["dynamic_shapes"]) summary["model_class"] = data["model"].__class__.__name__ - summary["model_module"] = data["model"].__class__.__module__ - if summary["model_module"] in sys.modules and isinstance( - sys.modules[summary["model_module"]].__file__, str - ): - summary["model_file"] = sys.modules[summary["model_module"]].__file__ + summary["model_module"] = str(data["model"].__class__.__module__) + if summary["model_module"] in sys.modules: + summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) summary["model_config_class"] = data["configuration"].__class__.__name__ summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") summary["model_id"] = model_id From 69b76cc18dc692bd22eb8885b55c24f66ed3f58f Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 19 May 2025 11:59:11 +0200 Subject: [PATCH 6/6] mypy --- onnx_diagnostic/torch_models/test_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 1bf199e8..d37371a9 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -380,7 +380,7 @@ def validate_model( summary["model_class"] = data["model"].__class__.__name__ summary["model_module"] = str(data["model"].__class__.__module__) if summary["model_module"] in sys.modules: - summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) + summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index] summary["model_config_class"] = data["configuration"].__class__.__name__ summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") summary["model_id"] = model_id