diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 95bf221b..5e452faf 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -69,6 +69,7 @@ def get_untrained_model_with_inputs( subfolder=subfolder, **(model_kwargs or {}), ) + if hasattr(config, "architecture") and config.architecture: archs = [config.architecture] if type(config) is dict: @@ -116,6 +117,17 @@ def get_untrained_model_with_inputs( if mkwargs: update_config(config, mkwargs) + # SDPA + if model_kwargs and "attn_implementation" in model_kwargs: + if hasattr(config, "_attn_implementation_autoset"): + config._attn_implementation_autoset = False + config._attn_implementation = model_kwargs["attn_implementation"] # type: ignore[union-attr] + if verbose: + print( + f"[get_untrained_model_with_inputs] config._attn_implementation=" + f"{config._attn_implementation!r}" # type: ignore[union-attr] + ) + # input kwargs kwargs, fct = random_input_kwargs(config, task) if verbose: diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index a9425ae5..d37371a9 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -1,6 +1,7 @@ import datetime import inspect import os +import sys from typing import Any, Callable, Dict, List, Optional, Tuple, Union import time import onnx @@ -375,8 +376,11 @@ def validate_model( summary[f"model_{k.replace('_','')}"] = data[k] summary["model_inputs_opionts"] = str(input_options or "") summary["model_inputs"] = string_type(data["inputs"], with_shape=True) - summary["model_shapes"] = string_type(str(data["dynamic_shapes"])) + summary["model_shapes"] = string_type(data["dynamic_shapes"]) summary["model_class"] = data["model"].__class__.__name__ + summary["model_module"] = str(data["model"].__class__.__module__) + if summary["model_module"] in sys.modules: + summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index] summary["model_config_class"] = data["configuration"].__class__.__name__ summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") summary["model_id"] = model_id @@ -482,6 +486,7 @@ def validate_model( verbose=verbose, optimization=optimization, do_run=do_run, + dump_folder=dump_folder, ) else: data["inputs_export"] = data["inputs"] @@ -493,6 +498,7 @@ def validate_model( verbose=verbose, optimization=optimization, do_run=do_run, + dump_folder=dump_folder, ) summary.update(summary_export) @@ -618,6 +624,7 @@ def call_exporter( verbose: int = 0, optimization: Optional[str] = None, do_run: bool = False, + dump_folder: Optional[str] = None, ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]: """ Calls an exporter on a model; @@ -629,6 +636,7 @@ def call_exporter( :param verbose: verbosity :param optimization: optimization to do :param do_run: runs and compute discrepancies + :param dump_folder: to dump additional information :return: two dictionaries, one with some metrics, another one with whatever the function produces """ @@ -661,6 +669,7 @@ def call_exporter( quiet=quiet, verbose=verbose, optimization=optimization, + dump_folder=dump_folder, ) return summary, data raise NotImplementedError( @@ -1045,6 +1054,7 @@ def call_torch_export_custom( quiet: bool = False, verbose: int = 0, optimization: Optional[str] = None, + dump_folder: Optional[str] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Exports a model into onnx. @@ -1056,6 +1066,7 @@ def call_torch_export_custom( :param quiet: catch exception or not :param verbose: verbosity :param optimization: optimization to do + :param dump_folder: to store additional information :return: two dictionaries, one with some metrics, another one with whatever the function produces """ @@ -1113,6 +1124,7 @@ def call_torch_export_custom( decomposition_table=( "default" if "-default" in exporter else ("all" if "-all" in exporter else None) ), + save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None), ) options = OptimizationOptions(patterns=optimization) if optimization else None model = data["model"]