From f703ad67cd9a3d35d384bbc6891a68ad42552f74 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 22 Sep 2025 16:38:45 +0200 Subject: [PATCH 1/4] adds speed up --- onnx_diagnostic/torch_models/validate.py | 60 ++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 0e2ff083..4273ccaf 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -841,6 +841,8 @@ def validate_model( ) summary.update(summary_valid) + _compute_final_statistics(summary) + if verbose: print("[validate_model] -- done (final)") if dump_stats: @@ -853,15 +855,24 @@ def validate_model( def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]: """Computes some statistics on the model itself.""" onx = onnx.load(onnx_filename, load_external_data=False) + cache_functions = {(f.domain, f.name): f for f in onx.functions} + local_domains = set(f.domain for f in onx.functions) def node_iter(proto): if isinstance(proto, onnx.ModelProto): - yield from node_iter(proto.graph) for f in proto.functions: yield from node_iter(f) + yield from node_iter(proto.graph) elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)): for node in proto.node: yield node + + # Let's inline the function + key = node.domain, node.op_type + if key in cache_functions: + yield from node_iter(cache_functions[key]) + + # Let's continue for att in node.attribute: if att.type == onnx.AttributeProto.GRAPH: yield from node_iter(att.g) @@ -879,6 +890,11 @@ def node_iter(proto): n_nodes += 1 if proto.op_type != "Constant": n_nodes_nocst += 1 + if proto.domain in local_domains: + key = "n_node_local_function" + if key not in counts: + counts[key] = 0 + counts[key] += 1 else: key = f"n_node_initializer_{proto.data_type}" @@ -1400,7 +1416,7 @@ def call_torch_export_onnx( :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - available = {None, "", "ir", "os_ort"} + available = {None, "", "ir", "os_ort", "ir+default"} assert ( optimization in available ), f"unexpected value for optimization={optimization}, available={available}" @@ -1490,11 +1506,31 @@ def call_torch_export_onnx( print(epo) print("[call_torch_export_onnx] -- End of ONNXProgram") - if optimization in {"ir", "os_ort"}: + if optimization in {"ir", "os_ort", "ir+default"}: if verbose: print(f"[call_torch_export_onnx] starts optimization={optimization!r}...") if optimization == "ir": label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize()) + elif optimization == "ir+default": + import onnxscript + from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions + + def _ir_default_opt(epo): + onnxscript.optimizer.optimize_ir(epo.model) + onx = epo.model_proto + # not very efficient + gr = GraphBuilder( + onx, + infer_shapes_options=True, + optimization_options=OptimizationOptions(patterns="default"), + ) + cont = gr.to_onnx(large_model=True) + epo.model = cont.to_ir() + + label, f_optim = "export_onnx_opt_ir_default", ( + lambda epo=epo: _ir_default_opt(epo) + ) + else: import onnxscript import onnxscript.rewriter.ort_fusions as ort_fusions @@ -1893,3 +1929,21 @@ def run_ort_fusion( f"opt_ort_{model_type}_duration": duration, f"opt_ort_{model_type}_duration_save": d, }, {f"opt_ort_{model_type}": output_path} + + +def _compute_final_statistics(summary: Dict[str, Any]): + """ + Updates inline the list of statistics. It adds: + + - speedup + """ + stats = {} + if ( + "time_run_latency" in summary + and "time_run_onnx_ort_latency" in summary + and summary["time_run_onnx_ort_latency"] > 0 + ): + stats["stat_estimated_speedup_ort"] = ( + summary["time_run_latency"] / summary["time_run_onnx_ort_latency"] + ) + summary.update(stats) From a7b3d4a9bc2a279fd75b8da4c362b876e5929c4a Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 22 Sep 2025 17:20:43 +0200 Subject: [PATCH 2/4] fix none value --- .../torch_models/hghub/model_inputs.py | 2 +- onnx_diagnostic/torch_models/validate.py | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 3ab2ec5f..dde656a1 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -189,7 +189,7 @@ def get_untrained_model_with_inputs( f"subfolder={subfolder!r}" ) model = transformers.AutoModel.from_pretrained( - model_id, subfolder=subfolder, trust_remote_code=True, **mkwargs + model_id, subfolder=subfolder or "", trust_remote_code=True, **mkwargs ) if verbose: print( diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 4273ccaf..b9d808d9 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]: return new_cfg -def _preprocess_model_id(model_id, subfolder): +def _preprocess_model_id( + model_id: str, subfolder: str, same_as_pretrained: bool, use_pretrained: bool +) -> Tuple[str, str, bool, bool]: if subfolder or "//" not in model_id: - return model_id, subfolder + return model_id, subfolder, same_as_pretrained, use_pretrained spl = model_id.split("//") + if spl[-1] == "pretrained": + return _preprocess_model_id("//".join(spl[:-1]), "", True, True) if spl[-1] in {"transformer", "vae"}: # known subfolder - return "//".join(spl[:-1]), spl[-1] - return model_id, subfolder + return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained + return model_id, subfolder, same_as_pretrained, use_pretrained def validate_model( @@ -384,7 +388,12 @@ def validate_model( if ``runtime == 'ref'``, ``orteval10`` increases the verbosity. """ - model_id, subfolder = _preprocess_model_id(model_id, subfolder) + model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id( + model_id, + subfolder, + same_as_pretrained=same_as_pretrained, + use_pretrained=use_pretrained, + ) if isinstance(patch, bool): patch_kwargs = ( dict(patch_transformers=True, patch_diffusers=True, patch=True) From f620a90f85b7166d8d805d243a4bf760c1dc1721 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 22 Sep 2025 17:21:44 +0200 Subject: [PATCH 3/4] changes --- CHANGELOGS.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index a7f634b1..537e7906 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,9 +4,9 @@ Change Logs 0.7.12 ++++++ +* :pr:`227`: better support for ``model_id//pretrained``, adds speed up when running command validate * :pr:`226`: fix input order for models created with modelbuilder - 0.7.11 ++++++ From 0f05f48054e52ceac8544cc5e37bc0deda75c2e4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 22 Sep 2025 17:35:34 +0200 Subject: [PATCH 4/4] fix mypy --- onnx_diagnostic/torch_models/validate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index b9d808d9..e2c5b9ec 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -265,8 +265,8 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]: def _preprocess_model_id( - model_id: str, subfolder: str, same_as_pretrained: bool, use_pretrained: bool -) -> Tuple[str, str, bool, bool]: + model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool +) -> Tuple[str, Optional[str], bool, bool]: if subfolder or "//" not in model_id: return model_id, subfolder, same_as_pretrained, use_pretrained spl = model_id.split("//")