From a6b9d871fc4b9885103634614ae856865d650d36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 21 Oct 2025 15:12:59 +0200
Subject: [PATCH 1/8] add function to generate dummy code

---
 _doc/api/torch_models/index.rst                |   1 +
 .../ut_torch_models/test_code_sample.py        |  41 +++++
 onnx_diagnostic/_command_lines_parser.py       |   9 +-
 onnx_diagnostic/torch_models/code_sample.py    | 166 ++++++++++++++++++
 onnx_diagnostic/torch_models/validate.py       |   3 +
 5 files changed, 219 insertions(+), 1 deletion(-)
 create mode 100644 _unittests/ut_torch_models/test_code_sample.py
 create mode 100644 onnx_diagnostic/torch_models/code_sample.py

diff --git a/_doc/api/torch_models/index.rst b/_doc/api/torch_models/index.rst
index 3279e918..9a472cec 100644
--- a/_doc/api/torch_models/index.rst
+++ b/_doc/api/torch_models/index.rst
@@ -5,6 +5,7 @@ onnx_diagnostic.torch_models
     :maxdepth: 1
     :caption: submodules

+    code_sample
     hghub/index
     llms
     validate
diff --git a/_unittests/ut_torch_models/test_code_sample.py b/_unittests/ut_torch_models/test_code_sample.py
new file mode 100644
index 00000000..91f197b9
--- /dev/null
+++ b/_unittests/ut_torch_models/test_code_sample.py
@@ -0,0 +1,41 @@
+import unittest
+import subprocess
+import sys
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_torch,
+    requires_experimental,
+    requires_transformers,
+)
+from onnx_diagnostic.torch_models.code_sample import code_sample
+
+
+class TestCodeSample(ExtTestCase):
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    @requires_experimental()
+    # @hide_stdout()
+    def test_code_sample_tiny_llm(self):
+        code = code_sample(
+            "arnir0/Tiny-LLM",
+            verbose=2,
+            exporter="custom",
+            patch=True,
+            dump_folder="dump_test/validate_tiny_llm",
+            dtype="float16",
+            device="cuda",
+        )
+        filename = self.get_dump_file("test_code_sample_tiny_llm.py")
+        with open(filename, "w") as f:
+            f.write(code)
+        cmds = [sys.executable, "-u", filename]
+        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        res = p.communicate()
+        _out, err = res
+        st = err.decode("ascii", errors="ignore")
+        self.assertNotIn("Traceback", st)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index cf7cc863..6010ec96 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -557,7 +557,13 @@ def get_parser_validate() -> ArgumentParser:
         "--quiet-input-sets",
         default="",
         help="Avoids raising an exception when an input sets does not work with "
-        "the exported model, example: --quiet-input-sets=inputs,inputs22",
+        "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
+    )
+    parser.add_argument(
+        "--sample-code",
+        default="",
+        help="Generates a sample code to export a model without "
+        "this package.\nExample: --sample-code=export_sample",
     )
     return parser
@@ -624,6 +630,7 @@ def _cmd_validate(argv: List[Any]):
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
+        sample_code=args.sample_code,
     )
     print("")
     print("-- summary --")
diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
new file mode 100644
index 00000000..9abf5b1e
--- /dev/null
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -0,0 +1,166 @@
+import textwrap
+import torch
+from typing import Any, Dict, List, Optional, Union
+
+
+CODE_SAMPLES = {
+    "imports": "from typing import Any\nimport torch",
+    "get_model_with_inputs": textwrap.dedent(
+        """
+        def get_model_with_inputs(
+            model_id: str,
+            subfolder: str | None = None,
+            dtype: str | torch.dtype | None = None,
+            device: str | torch.device | None = None,
+            same_as_pretrained: bool = False,
+            use_pretrained: bool = False,
+            input_options: dict[str, Any] | None = None,
+            model_options: dict[str, Any] | None = None,
+        ) -> dict[str, Any]:
+            if use_pretrained:
+                import transformers
+                assert same_as_pretrained, (
+                    "same_as_pretrained must be True if use_pretrained is True"
+                )
+                # tokenizer = AutoTokenizer.from_pretrained(model_path)
+                model = transformers.AutoModel.from_pretrained(
+                    model_id,
+                    trust_remote_code=True,
+                    subfolder=subfolder,
+                    dtype=dtype,
+                    device=device,
+                )
+                data = {"model": model}
+                assert not input_options, f"Not implemented yet with input_options={input_options}"
+                assert not model_options, f"Not implemented yet with model_options={model_options}"
+            else:
+                from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+                data = get_untrained_model_with_inputs(
+                    model_id,
+                    use_pretrained=use_pretrained,
+                    same_as_pretrained=same_as_pretrained,
+                    inputs_kwargs=input_options,
+                    model_kwargs=model_options,
+                    subfolder=subfolder,
+                    add_second_input=False,
+                )
+            if dtype:
+                data["model"] = data["model"].to(
+                    getattr(torch, dtype) if isinstance(dtype, str) else dtype
+                )
+            if device:
+                data["model"] = data["model"].to(device)
+            return data["model"]
+        """
+    ),
+}
+
+
+def code_sample(
+    model_id: str,
+    task: Optional[str] = None,
+    do_run: bool = False,
+    exporter: Optional[str] = None,
+    do_same: bool = False,
+    verbose: int = 0,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    same_as_pretrained: bool = False,
+    use_pretrained: bool = False,
+    optimization: Optional[str] = None,
+    quiet: bool = False,
+    patch: Union[bool, str, Dict[str, bool]] = False,
+    rewrite: bool = False,
+    stop_if_static: int = 1,
+    dump_folder: Optional[str] = None,
+    drop_inputs: Optional[List[str]] = None,
+    input_options: Optional[Dict[str, Any]] = None,
+    model_options: Optional[Dict[str, Any]] = None,
+    subfolder: Optional[str] = None,
+    opset: Optional[int] = None,
+    runtime: str = "onnxruntime",
+    output_names: Optional[List[str]] = None,
+) -> str:
+    """
+    Generates the code to export a model with the proper settings.
+
+    :param model_id: model id to validate
+    :param task: task used to generate the necessary inputs,
+        can be left empty to use the default task for this model
+        if it can be determined
+    :param do_run: checks the model works with the defined inputs
+    :param exporter: exports the model with this exporter,
+        available list: ``export-strict``, ``export-nostrict``, ...,
+        see below
+    :param do_same: checks the discrepancies of the exported model
+    :param verbose: verbosity level
+    :param dtype: uses this dtype to check the model
+    :param device: does the verification on this device
+    :param same_as_pretrained: uses a model equivalent to the trained one,
+        this is not always possible
+    :param use_pretrained: uses the trained model, not the untrained one
+    :param optimization: optimization to apply to the exported model,
+        depends on the exporter
+    :param quiet: if quiet, catches exceptions if any issue occurs
+    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
+        if True before exporting,
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
+        a string can be used to specify only one of them
+    :param rewrite: applies known rewritings (``patch_transformers=True``) before exporting,
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+    :param stop_if_static: stops if a dynamic dimension becomes static,
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+    :param dump_folder: dumps everything in a subfolder of this one
+    :param drop_inputs: drops this list of inputs (given their names)
+    :param input_options: additional options to define the dummy inputs
+        used to export
+    :param model_options: additional options when creating the model such as
+        ``num_hidden_layers`` or ``attn_implementation``
+    :param subfolder: version or subfolder to use when retrieving a model id
+    :param opset: onnx opset to use for the conversion
+    :param runtime: onnx runtime to use to check for discrepancies,
+        possible values ``onnxruntime``, ``torch``, ``orteval``,
+        ``orteval10``, ``ref``, only used if ``do_run`` is true
+    :param output_names: output names the onnx exporter should use
+    :return: the generated code
+
+    .. runpython::
+        :showcode:
+
+        from onnx_diagnostic.torch_models.code_sample import code_sample
+
+        print(code_sample("arnir0/Tiny-LLM"))
+    """
+    args = [f"{model_id!r}"]
+    if subfolder:
+        args.append(f"subfolder={subfolder!r}")
+    if dtype:
+        args.append(f"dtype={dtype!r}")
+    if device:
+        args.append(f"device={device!r}")
+    if same_as_pretrained:
+        args.append(f"same_as_pretrained={same_as_pretrained!r}")
+    if use_pretrained:
+        args.append(f"use_pretrained={use_pretrained!r}")
+    if input_options:
+        args.append(f"input_options={input_options!r}")
+    if model_options:
+        args.append(f"model_options={model_options!r}")
+    model_args = ", ".join(args)
+    pieces = [
+        CODE_SAMPLES["imports"],
+        CODE_SAMPLES["get_model_with_inputs"],
+        textwrap.dedent(
+            f"""
+            model = get_model_with_inputs({model_args})
+            """
+        ),
+    ]
+    code = "\n".join(pieces)
+    try:
+        import black
+    except ImportError:
+        # No black formatting.
+        return code
+
+    return black.format_str(code, mode=black.Mode())
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index fca1cfbd..5db8515a 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -323,6 +323,7 @@ def validate_model(
     output_names: Optional[List[str]] = None,
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
+    sample_code: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -379,6 +380,8 @@ def validate_model(
     :param ort_logs: increases onnxruntime verbosity when creating the session
     :param quiet_input_sets: avoid raising an exception if the inputs belongs to
         that set even if quiet is False
+    :param sample_code: if specified, the function generates a code
+        which exports this model id without this package.
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

From a29f02c8798d147d6c052707a41a1c825cb71571 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 21 Oct 2025 17:18:55 +0200
Subject: [PATCH 2/8] changes

---
 .../ut_torch_models/test_code_sample.py     |  37 ++++-
 onnx_diagnostic/torch_models/code_sample.py | 154 +++++++++++++++++-
 onnx_diagnostic/torch_models/validate.py    |  46 ++++--
 3 files changed, 216 insertions(+), 21 deletions(-)

diff --git a/_unittests/ut_torch_models/test_code_sample.py b/_unittests/ut_torch_models/test_code_sample.py
index 91f197b9..fe57819d 100644
--- a/_unittests/ut_torch_models/test_code_sample.py
+++ b/_unittests/ut_torch_models/test_code_sample.py
@@ -1,6 +1,7 @@
 import unittest
 import subprocess
 import sys
+import torch
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     hide_stdout,
@@ -8,7 +9,8 @@
     requires_experimental,
     requires_transformers,
 )
-from onnx_diagnostic.torch_models.code_sample import code_sample
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.torch_models.code_sample import code_sample, make_code_for_inputs


 class TestCodeSample(ExtTestCase):
@@ -24,7 +26,7 @@ def test_code_sample_tiny_llm(self):
             patch=True,
             dump_folder="dump_test/validate_tiny_llm",
             dtype="float16",
-            device="cuda",
+            device="cpu",
         )
         filename = self.get_dump_file("test_code_sample_tiny_llm.py")
         with open(filename, "w") as f:
@@ -36,6 +38,37 @@ def test_code_sample_tiny_llm(self):
         st = err.decode("ascii", errors="ignore")
         self.assertNotIn("Traceback", st)

+    def test_make_code_for_inputs(self):
+        values = [
+            ("dict(a=True)", dict(a=True)),
+            ("dict(a=1)", dict(a=1)),
+            (
+                "dict(a=torch.randint(3, size=(2,), dtype=torch.int64))",
+                dict(a=torch.tensor([2, 3], dtype=torch.int64)),
+            ),
+            (
+                "dict(a=torch.rand((2,), dtype=torch.float16))",
+                dict(a=torch.tensor([2, 3], dtype=torch.float16)),
+            ),
+        ]
+        for res, inputs in values:
+            self.assertEqual(res, make_code_for_inputs(inputs))
+
+        res = make_code_for_inputs(
+            dict(
+                cc=make_dynamic_cache(
+                    [(torch.randn(2, 2, 2, 2), torch.randn(2, 2, 2, 2)) for i in range(2)]
+                )
+            )
+        )
+        self.assertEqual(
+            "dict(cc=make_dynamic_cache([(torch.rand((2, 2, 2, 2), "
+            "dtype=torch.float32),torch.rand((2, 2, 2, 2), dtype=torch.float32)), "
+            "(torch.rand((2, 2, 2, 2), dtype=torch.float32),"
+            "torch.rand((2, 2, 2, 2), dtype=torch.float32))]))",
+            res,
+        )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
index 9abf5b1e..f5d7aafe 100644
--- a/onnx_diagnostic/torch_models/code_sample.py
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -1,6 +1,12 @@
+import os
 import textwrap
 import torch
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
+from ..helpers import flatten_object
+from ..helpers.torch_helper import to_any
+from .hghub.model_inputs import _preprocess_model_id
+from .hghub import get_untrained_model_with_inputs
+from .validate import filter_inputs, make_patch_kwargs


 CODE_SAMPLES = {
@@ -56,6 +62,94 @@ def get_model_with_inputs(
 }


+def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
+    """
+    Creates code to generate random inputs.
+
+    :param inputs: dictionary
+    :return: code
+    """
+    codes = []
+    for k, v in inputs.items():
+        if isinstance(v, (int, bool, float)):
+            code = f"{k}={v}"
+        elif isinstance(v, torch.Tensor):
+            shape = tuple(map(int, v.shape))
+            if v.dtype in (torch.int32, torch.int64):
+                code = f"{k}=torch.randint({v.max()}, size={shape}, dtype={v.dtype})"
+            elif v.dtype in (torch.float32, torch.float16, torch.bfloat16):
+                code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
+            else:
+                raise ValueError(f"Unexpeted dtype = {v.dtype} for k={k!r}")
+        elif v.__class__.__name__ == "DynamicCache":
+            obj = flatten_object(v)
+            cc = [f"torch.rand({tuple(map(int,_.shape))}, dtype={_.dtype})" for _ in obj]
+            va = [f"({a},{b})" for a, b in zip(cc[: len(cc) // 2], cc[len(cc) // 2 :])]
+            vas = ", ".join(va)
+            code = f"{k}=make_dynamic_cache([{vas}])"
+        else:
+            raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
+        codes.append(code)
+    st = ", ".join(codes)
+    return f"dict({st})"
+
+
+def make_export_code(
+    exporter: str,
+    optimization: Optional[str] = None,
+    patch_kwargs: Optional[Dict[str, Any]] = None,
+    stop_if_static: int = 0,
+    dump_folder: Optional[str] = None,
+    opset: int = 18,
+    dynamic_shapes: Optional[Dict[str, Any]] = None,
+    output_names: Optional[List[str]] = None,
+    verbose: int = 0,
+) -> Tuple[str, str]:
+    args = [f"dynamic_shapes={dynamic_shapes}"]
+    if output_names:
+        args.append(f"output_names={output_names}")
+    if dump_folder:
+        filename = os.path.join(dump_folder, "model.onnx")
+    if exporter == "custom":
+        args.append(f"target_opset={opset}")
+        if optimization:
+            args.append(f"options=OptimizationOptions(pattern={optimization!r})")
+        args.append(f"large_model=True, filename={filename!r}")
+        sargs = ", ".join(args)
+        imports = [
+            "from experimental_experiment.torch_interpreter import to_onnx",
+            "from experimental_experiment.xbuilder import OptimizationOptions",
+        ]
+        code = [f"onx = to_onnx(model, inputs, {sargs})"]
+    elif exporter == "onnx-dynamo":
+        args.append(f"opset={opset}")
+        if optimization:
+            args.append("options=OptimizationOptions(pattern={optimization!r})")
+        sargs = ", ".join(args)
+        imports = []
+        code = [f"epo = torch.onnx.export(model, (), inputs, {sargs})"]
+        if optimization:
+            imports.append("import onnxscript")
+            code.extend(
+                [
+                    "ir_model = epo.to_ir()",
+                    "onnxscript.optimizer.optimize_ir(ir_model)",
+                    "ir_optimized = ort_fusions.optimize_for_ort(ir_model)",
+                    "epo.model = ir_optimized",
+                ]
+            )
+        if dump_folder:
+            code.append("epo.save({filename!r}")
+    else:
+        raise ValueError(f"Unexpected exporter {exporter!r}")
+    if not patch_kwargs:
+        return "\n".join(imports), "\n".join(code)
+
+    imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
+    code = [f"with torch_export_patches(**{patch_kwargs}):", *["    " + _ for _ in code]]
+    return "\n".join(imports), "\n".join(code)
+
+
 def code_sample(
     model_id: str,
     task: Optional[str] = None,
     do_run: bool = False,
@@ -131,6 +225,44 @@ def code_sample(

         print(code_sample("arnir0/Tiny-LLM"))
     """
+    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+        model_id,
+        subfolder,
+        same_as_pretrained=same_as_pretrained,
+        use_pretrained=use_pretrained,
+    )
+    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
+
+    iop = input_options or {}
+    mop = model_options or {}
+    data = get_untrained_model_with_inputs(
+        model_id,
+        verbose=verbose,
+        task=task,
+        use_pretrained=use_pretrained,
+        same_as_pretrained=same_as_pretrained,
+        inputs_kwargs=iop,
+        model_kwargs=mop,
+        subfolder=subfolder,
+        add_second_input=False,
+    )
+    if drop_inputs:
+        update = {}
+        for k in data:
+            if k.startswith("inputs"):
+                update[k], ds = filter_inputs(
+                    data[k],
+                    drop_names=drop_inputs,
+                    model=data["model"],
+                    dynamic_shapes=data["dynamic_shapes"],
+                )
+        update["dynamic_shapes"] = ds
+        data.update(update)
+
+    for k in data:
+        if k.startswith("inputs"):
+            update[k] = to_any(data[k], dtype)
+
     args = [f"{model_id!r}"]
     if subfolder:
         args.append(f"subfolder={subfolder!r}")
@@ -147,14 +279,34 @@ def code_sample(
     if model_options:
         args.append(f"model_options={model_options!r}")
     model_args = ", ".join(args)
+    imports, exporter_code = make_export_code(
+        exporter=exporter,
+        patch_kwargs=patch_kwargs,
+        verbose=verbose,
+        optimization=optimization,
+        stop_if_static=stop_if_static,
+        dump_folder=dump_folder,
+        opset=opset,
+    )
+    input_code = make_code_for_inputs(data["inputs"])
+    cache_import = (
+        "from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache"
+        if "dynamic_cache" in input_code
+        else ""
+    )
+
     pieces = [
         CODE_SAMPLES["imports"],
+        imports,
+        cache_import,
         CODE_SAMPLES["get_model_with_inputs"],
         textwrap.dedent(
             f"""
             model = get_model_with_inputs({model_args})
             """
         ),
+        f"inputs = {input_code}",
+        exporter_code,
     ]
     code = "\n".join(pieces)
     try:
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index 5db8515a..b430b668 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -293,6 +293,33 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return new_cfg


+def make_patch_kwargs(
+    patch: Union[bool, str, Dict[str, bool]] = False,
+    rewrite: bool = False,
+) -> Dict[str, Any]:
+    """Creates patch arguments."""
+    default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
+    if isinstance(patch, bool):
+        patch_kwargs = default_patch if patch else dict(patch=False)
+    elif isinstance(patch, str):
+        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
+    else:
+        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
+        patch_kwargs = patch.copy()
+        if "patch" not in patch_kwargs:
+            if any(patch_kwargs.values()):
+                patch_kwargs["patch"] = True
+            elif len(patch) == 1 and patch.get("patch", False):
+                patch_kwargs.update(default_patch)
+
+    assert not rewrite or patch_kwargs.get("patch", False), (
+        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
+        f"patch must be True to enable rewriting, "
+        f"if --patch=0 was specified on the command line, rewrites are disabled."
+ ) + return patch_kwargs + + def validate_model( model_id: str, task: Optional[str] = None, @@ -423,25 +450,8 @@ def validate_model( use_pretrained=use_pretrained, ) time_preprocess_model_id = time.perf_counter() - main_validation_begin - default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True) - if isinstance(patch, bool): - patch_kwargs = default_patch if patch else dict(patch=False) - elif isinstance(patch, str): - patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420 - else: - assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}" - patch_kwargs = patch.copy() - if "patch" not in patch_kwargs: - if any(patch_kwargs.values()): - patch_kwargs["patch"] = True - elif len(patch) == 1 and patch.get("patch", False): - patch_kwargs.update(default_patch) + patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite) - assert not rewrite or patch_kwargs.get("patch", False), ( - f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} " - f"patch must be True to enable rewriting, " - f"if --patch=0 was specified on the command line, rewrites are disabled." - ) summary = version_summary() summary.update( dict( From 3f2f6125ea975f587257404c4e6fdf1ef1cd3374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 22 Oct 2025 00:51:12 +0200 Subject: [PATCH 3/8] fix dtype --- _unittests/ut_torch_models/test_code_sample.py | 2 +- onnx_diagnostic/torch_models/code_sample.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_torch_models/test_code_sample.py b/_unittests/ut_torch_models/test_code_sample.py index fe57819d..8a8066fe 100644 --- a/_unittests/ut_torch_models/test_code_sample.py +++ b/_unittests/ut_torch_models/test_code_sample.py @@ -17,7 +17,7 @@ class TestCodeSample(ExtTestCase): @requires_transformers("4.53") @requires_torch("2.7.99") @requires_experimental() - # @hide_stdout() + @hide_stdout() def test_code_sample_tiny_llm(self): code = code_sample( "arnir0/Tiny-LLM", diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py index f5d7aafe..818b4316 100644 --- a/onnx_diagnostic/torch_models/code_sample.py +++ b/onnx_diagnostic/torch_models/code_sample.py @@ -259,9 +259,18 @@ def code_sample( update["dynamic_shapes"] = ds data.update(update) + update = {} for k in data: if k.startswith("inputs"): - update[k] = to_any(data[]) to_any + v = data[k] + if dtype: + update[k] = v = to_any( + v, getattr(torch, dtype) if isinstance(dtype, str) else dtype + ) + if device: + update[k] = v = to_any(v, device) + if update: + data.update(update) args = [f"{model_id!r}"] if subfolder: From b52e798ad2deed96b492c67391cfff0aec62aff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 22 Oct 2025 02:01:48 +0200 Subject: [PATCH 4/8] fix dynamo --- .../ut_torch_models/test_code_sample.py | 32 +++++++++++++++-- onnx_diagnostic/torch_models/code_sample.py | 34 +++++++++++-------- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/_unittests/ut_torch_models/test_code_sample.py b/_unittests/ut_torch_models/test_code_sample.py index 8a8066fe..ff742099 100644 --- a/_unittests/ut_torch_models/test_code_sample.py +++ b/_unittests/ut_torch_models/test_code_sample.py @@ -18,17 +18,43 @@ class TestCodeSample(ExtTestCase): @requires_torch("2.7.99") @requires_experimental() @hide_stdout() - def test_code_sample_tiny_llm(self): + def test_code_sample_tiny_llm_custom(self): code = code_sample( "arnir0/Tiny-LLM", 
             verbose=2,
             exporter="custom",
             patch=True,
-            dump_folder="dump_test/validate_tiny_llm",
+            dump_folder="dump_test/validate_tiny_llm_custom",
             dtype="float16",
             device="cpu",
+            optimization="default",
         )
-        filename = self.get_dump_file("test_code_sample_tiny_llm.py")
+        filename = self.get_dump_file("test_code_sample_tiny_llm_custom.py")
         with open(filename, "w") as f:
             f.write(code)
         cmds = [sys.executable, "-u", filename]
         p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         res = p.communicate()
         _out, err = res
         st = err.decode("ascii", errors="ignore")
         self.assertNotIn("Traceback", st)

+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    @requires_experimental()
+    @hide_stdout()
+    def test_code_sample_tiny_llm_dynamo(self):
+        code = code_sample(
+            "arnir0/Tiny-LLM",
+            verbose=2,
+            exporter="onnx-dynamo",
+            patch=True,
+            dump_folder="dump_test/validate_tiny_llm_dynamo",
+            dtype="float16",
+            device="cpu",
+            optimization="ir",
+        )
+        filename = self.get_dump_file("test_code_sample_tiny_llm_dynamo.py")
+        with open(filename, "w") as f:
+            f.write(code)
+        cmds = [sys.executable, "-u", filename]
+        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        res = p.communicate()
+        _out, err = res
+        st = err.decode("ascii", errors="ignore")
+        self.assertNotIn("Traceback", st)
+
diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
index 818b4316..1368c9af 100644
--- a/onnx_diagnostic/torch_models/code_sample.py
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -122,31 +122,29 @@ def make_export_code(
         ]
         code = [f"onx = to_onnx(model, inputs, {sargs})"]
     elif exporter == "onnx-dynamo":
-        args.append(f"opset={opset}")
-        if optimization:
-            args.append("options=OptimizationOptions(pattern={optimization!r})")
+        args.append(f"opset_version={opset}")
         sargs = ", ".join(args)
         imports = []
-        code = [f"epo = torch.onnx.export(model, (), inputs, {sargs})"]
+        code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})"])
         if optimization:
             imports.append("import onnxscript")
+            code.extend(["onnxscript.optimizer.optimize_ir(epo.model)"])
+            if "os_ort" in optimization:
+                imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
+                code.extend(["ort_fusions.optimize_for_ort(epo.model)"])
+        if dump_folder:
+            imports.insert(0, "import os")
             code.extend(
-                [
-                    "ir_model = epo.to_ir()",
-                    "onnxscript.optimizer.optimize_ir(ir_model)",
-                    "ir_optimized = ort_fusions.optimize_for_ort(ir_model)",
-                    "epo.model = ir_optimized",
-                ]
+                [f"os.makedirs({dump_folder!r}, exist_ok=True)", f"epo.save({filename!r})"]
             )
-        if dump_folder:
-            code.append("epo.save({filename!r}")
     else:
         raise ValueError(f"Unexpected exporter {exporter!r}")
     if not patch_kwargs:
         return "\n".join(imports), "\n".join(code)

     imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
-    code = [f"with torch_export_patches(**{patch_kwargs}):", *["    " + _ for _ in code]]
+    sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
+    code = [f"with torch_export_patches({sargs}):", *["    " + _ for _ in code]]
     return "\n".join(imports), "\n".join(code)
@@ -223,7 +221,14 @@ def code_sample(

         from onnx_diagnostic.torch_models.code_sample import code_sample

-        print(code_sample("arnir0/Tiny-LLM"))
+        print(
+            code_sample(
+                "arnir0/Tiny-LLM",
+                exporter="onnx-dynamo",
+                optimization="ir",
+                patch=True,
+            )
+        )
     """
     model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
@@ -296,6 +301,7 @@ def code_sample(
         stop_if_static=stop_if_static,
         dump_folder=dump_folder,
         opset=opset,
+        dynamic_shapes=data["dynamic_shapes"],
     )
     input_code = make_code_for_inputs(data["inputs"])
     cache_import = (

From a9a1fd90059cac2a79b57315b126c7473beab254 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Wed, 22 Oct 2025 13:55:48 +0200
Subject: [PATCH 5/8] fix issues

---
 onnx_diagnostic/torch_models/code_sample.py | 34 +++++++++++--------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
index 1368c9af..89f365d4 100644
--- a/onnx_diagnostic/torch_models/code_sample.py
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -100,7 +100,7 @@ def make_export_code(
     patch_kwargs: Optional[Dict[str, Any]] = None,
     stop_if_static: int = 0,
     dump_folder: Optional[str] = None,
-    opset: int = 18,
+    opset: Optional[int] = None,
     dynamic_shapes: Optional[Dict[str, Any]] = None,
     output_names: Optional[List[str]] = None,
     verbose: int = 0,
@@ -111,9 +111,10 @@ def make_export_code(
     if dump_folder:
         filename = os.path.join(dump_folder, "model.onnx")
     if exporter == "custom":
-        args.append(f"target_opset={opset}")
+        if opset:
+            args.append(f"target_opset={opset}")
         if optimization:
-            args.append(f"options=OptimizationOptions(pattern={optimization!r})")
+            args.append(f"options=OptimizationOptions(patterns={optimization!r})")
         args.append(f"large_model=True, filename={filename!r}")
         sargs = ", ".join(args)
         imports = [
@@ -122,7 +123,8 @@ def make_export_code(
         ]
         code = [f"onx = to_onnx(model, inputs, {sargs})"]
     elif exporter == "onnx-dynamo":
-        args.append(f"opset_version={opset}")
+        if opset:
+            args.append(f"opset_version={opset}")
         sargs = ", ".join(args)
         imports = []
         code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})"])
@@ -143,6 +145,8 @@ def make_export_code(
         return "\n".join(imports), "\n".join(code)

     imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
+    if stop_if_static:
+        patch_kwargs["patch_kwargs"] = stop_if_static
     sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
     code = [f"with torch_export_patches({sargs}):", *["    " + _ for _ in code]]
     return "\n".join(imports), "\n".join(code)
@@ -293,15 +297,19 @@ def code_sample(
     if model_options:
         args.append(f"model_options={model_options!r}")
     model_args = ", ".join(args)
-    imports, exporter_code = make_export_code(
-        exporter=exporter,
-        patch_kwargs=patch_kwargs,
-        verbose=verbose,
-        optimization=optimization,
-        stop_if_static=stop_if_static,
-        dump_folder=dump_folder,
-        opset=opset,
-        dynamic_shapes=data["dynamic_shapes"],
+    imports, exporter_code = (
+        make_export_code(
+            exporter=exporter,
+            patch_kwargs=patch_kwargs,
+            verbose=verbose,
+            optimization=optimization,
+            stop_if_static=stop_if_static,
+            dump_folder=dump_folder,
+            opset=opset,
+            dynamic_shapes=data["dynamic_shapes"],
+        )
+        if exporter is not None
+        else ([], [])
     )
     input_code = make_code_for_inputs(data["inputs"])
     cache_import = (

From 34b798c23d250a3775a9e00d5f3b3a1b4309f03f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Wed, 22 Oct 2025 14:16:27 +0200
Subject: [PATCH 6/8] fixes

---
 onnx_diagnostic/_command_lines_parser.py    | 115 +++++++++++++++++---
 onnx_diagnostic/torch_models/code_sample.py |  10 +-
 onnx_diagnostic/torch_models/validate.py    |   3 -
 3 files changed, 105 insertions(+), 23 deletions(-)

diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index 6010ec96..f4c2eca5 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -559,12 +559,6 @@ def get_parser_validate() -> ArgumentParser:
         help="Avoids raising an exception when an input sets does not work with "
         "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
     )
-    parser.add_argument(
-        "--sample-code",
-        default="",
-        help="Generates a sample code to export a model without "
-        "this package.\nExample: --sample-code=export_sample",
-    )
     return parser
@@ -630,7 +624,6 @@ def _cmd_validate(argv: List[Any]):
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
-        sample_code=args.sample_code,
     )
     print("")
     print("-- summary --")
@@ -638,6 +631,94 @@ def _cmd_validate(argv: List[Any]):
         print(f":{k},{v};")


+def _cmd_export_sample(argv: List[Any]):
+    from .helpers import string_type
+    from .torch_models.validate import get_inputs_for_task, _make_folder_name
+    from .torch_models.code_sample import code_sample
+    from .tasks import supported_tasks
+
+    parser = get_parser_validate()
+    args = parser.parse_args(argv[1:])
+    if not args.task and not args.mid:
+        print("-- list of supported tasks:")
+        print("\n".join(supported_tasks()))
+    elif not args.mid:
+        data = get_inputs_for_task(args.task)
+        if args.verbose:
+            print(f"task: {args.task}")
+        max_length = max(len(k) for k in data["inputs"]) + 1
+        print("-- inputs")
+        for k, v in data["inputs"].items():
+            print(f"  + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
+        print("-- dynamic_shapes")
+        for k, v in data["dynamic_shapes"].items():
+            print(f"  + {k.ljust(max_length)}: {string_type(v)}")
+    else:
+        # Let's skip any invalid combination if known to be unsupported
+        if (
+            "onnx" not in (args.export or "")
+            and "custom" not in (args.export or "")
+            and (args.opt or "")
+        ):
+            print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
+            return
+        patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
+        code = code_sample(
+            model_id=args.mid,
+            task=args.task,
+            do_run=args.run,
+            verbose=args.verbose,
+            quiet=args.quiet,
+            same_as_pretrained=args.same_as_trained,
+            use_pretrained=args.trained,
+            dtype=args.dtype,
+            device=args.device,
+            patch=patch_dict,
+            rewrite=args.rewrite and patch_dict.get("patch", True),
+            stop_if_static=args.stop_if_static,
+            optimization=args.opt,
+            exporter=args.export,
+            dump_folder=args.dump_folder,
+            drop_inputs=None if not args.drop else args.drop.split(","),
+            input_options=args.iop,
+            model_options=args.mop,
+            subfolder=args.subfolder,
+            opset=args.opset,
+            runtime=args.runtime,
+            output_names=(
+                None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
+            ),
+        )
+        if args.dump_folder:
+            os.makedirs(args.dump_folder, exist_ok=True)
+            name = (
+                _make_folder_name(
+                    model_id=args.mid,
+                    exporter=args.export,
+                    optimization=args.opt,
+                    dtype=args.dtype,
+                    device=args.device,
+                    subfolder=args.subfolder,
+                    opset=args.opset,
+                    drop_inputs=None if not args.drop else args.drop.split(","),
+                    same_as_pretrained=args.same_as_trained,
+                    use_pretrained=args.trained,
+                    task=args.task,
+                ).replace("/", "-")
+                + ".py"
+            )
+            fullname = os.path.join(args.dump_folder, name)
+            if args.verbose:
+                print(f"-- prints code in {fullname!r}")
+                print("--")
+            with open(fullname, "w") as f:
+                f.write(code)
+            if args.verbose:
+                print("-- done")
+        else:
+            print(code)
+
+
 def get_parser_stats() -> ArgumentParser:
     parser = ArgumentParser(
         prog="stats",
@@ -967,14 +1048,15 @@ def get_main_parser() -> ArgumentParser:
     Type 'python -m onnx_diagnostic --help' to get help for a specific command.
-    agg        - aggregates statistics from multiple files
-    config     - prints a configuration for a model id
-    find       - find node consuming or producing a result
-    lighten    - makes an onnx model lighter by removing the weights,
-    print      - prints the model on standard output
-    stats      - produces statistics on a model
-    unlighten  - restores an onnx model produces by the previous experiment
-    validate   - validate a model
+    agg          - aggregates statistics from multiple files
+    config       - prints a configuration for a model id
+    exportsample - produces code to export a model
+    find         - find node consuming or producing a result
+    lighten      - makes an onnx model lighter by removing the weights,
+    print        - prints the model on standard output
+    stats        - produces statistics on a model
+    unlighten    - restores an onnx model produced by the previous experiment
+    validate     - validate a model
     """
         ),
     )
@@ -983,6 +1065,7 @@ def get_main_parser() -> ArgumentParser:
         choices=[
             "agg",
             "config",
+            "exportsample",
             "find",
             "lighten",
             "print",
@@ -1005,6 +1088,7 @@ def main(argv: Optional[List[Any]] = None):
         validate=_cmd_validate,
         stats=_cmd_stats,
         agg=_cmd_agg,
+        exportsample=_cmd_export_sample,
     )

     if argv is None:
@@ -1027,6 +1111,7 @@ def main(argv: Optional[List[Any]] = None):
         validate=get_parser_validate,
         stats=get_parser_stats,
         agg=get_parser_agg,
+        exportsample=get_parser_validate,
     )
     cmd = argv[0]
     if cmd not in parsers:
diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
index 89f365d4..56543ae0 100644
--- a/onnx_diagnostic/torch_models/code_sample.py
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -80,13 +80,13 @@ def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
             elif v.dtype in (torch.float32, torch.float16, torch.bfloat16):
                 code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
             else:
-                raise ValueError(f"Unexpeted dtype = {v.dtype} for k={k!r}")
+                raise ValueError(f"Unexpected dtype = {v.dtype} for k={k!r}")
         elif v.__class__.__name__ == "DynamicCache":
             obj = flatten_object(v)
             cc = [f"torch.rand({tuple(map(int,_.shape))}, dtype={_.dtype})" for _ in obj]
             va = [f"({a},{b})" for a, b in zip(cc[: len(cc) // 2], cc[len(cc) // 2 :])]
-            vas = ", ".join(va)
-            code = f"{k}=make_dynamic_cache([{vas}])"
+            va2 = ", ".join(va)
+            code = f"{k}=make_dynamic_cache([{va2}])"
         else:
             raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
         codes.append(code)
@@ -146,7 +146,7 @@ def make_export_code(

     imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
     if stop_if_static:
-        patch_kwargs["patch_kwargs"] = stop_if_static
+        patch_kwargs["stop_if_static"] = stop_if_static
     sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
     code = [f"with torch_export_patches({sargs}):", *["    " + _ for _ in code]]
     return "\n".join(imports), "\n".join(code)
@@ -331,7 +331,7 @@ def code_sample(
         f"inputs = {input_code}",
         exporter_code,
     ]
-    code = "\n".join(pieces)
+    code = "\n".join(pieces)  # type: ignore[arg-type]
     try:
         import black
     except ImportError:
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index b430b668..e9706f4f 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -350,7 +350,6 @@ def validate_model(
     output_names: Optional[List[str]] = None,
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
-    sample_code: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -407,8 +406,6 @@ def validate_model(
     :param ort_logs: increases onnxruntime verbosity when creating the session
     :param quiet_input_sets: avoid raising an exception if the inputs belongs to
         that set even if quiet is False
-    :param sample_code: if specified, the function generates a code
-        which exports this model id without this package.
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

From 116b59ce894dc9c6475353f8d6a0427c5f9a9ddc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Wed, 22 Oct 2025 16:01:08 +0200
Subject: [PATCH 7/8] fix examples

---
 CHANGELOGS.rst                                 |  1 +
 _unittests/ut_torch_models/test_code_sample.py |  4 ++--
 onnx_diagnostic/torch_models/code_sample.py    | 23 +++++++++++--------
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 617cbe10..60c2a7ac 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.16
 ++++++

+* :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
 * :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
 * :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
diff --git a/_unittests/ut_torch_models/test_code_sample.py b/_unittests/ut_torch_models/test_code_sample.py
index ff742099..d3794086 100644
--- a/_unittests/ut_torch_models/test_code_sample.py
+++ b/_unittests/ut_torch_models/test_code_sample.py
@@ -15,7 +15,7 @@

 class TestCodeSample(ExtTestCase):
     @requires_transformers("4.53")
-    @requires_torch("2.7.99")
+    @requires_torch("2.9")
     @requires_experimental()
     @hide_stdout()
     def test_code_sample_tiny_llm_custom(self):
@@ -40,7 +40,7 @@ def test_code_sample_tiny_llm_custom(self):
         self.assertNotIn("Traceback", st)

     @requires_transformers("4.53")
-    @requires_torch("2.7.99")
+    @requires_torch("2.9")
     @requires_experimental()
     @hide_stdout()
     def test_code_sample_tiny_llm_dynamo(self):
diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
index 56543ae0..4d760022 100644
--- a/onnx_diagnostic/torch_models/code_sample.py
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -108,7 +108,11 @@ def make_export_code(
     args = [f"dynamic_shapes={dynamic_shapes}"]
     if output_names:
         args.append(f"output_names={output_names}")
+    code = []
+    imports = []
     if dump_folder:
+        code.append(f"os.makedirs({dump_folder!r}, exist_ok=True)")
+        imports.append("import os")
         filename = os.path.join(dump_folder, "model.onnx")
     if exporter == "custom":
         if opset:
             args.append(f"target_opset={opset}")
         if optimization:
             args.append(f"options=OptimizationOptions(patterns={optimization!r})")
         args.append(f"large_model=True, filename={filename!r}")
         sargs = ", ".join(args)
-        imports = [
-            "from experimental_experiment.torch_interpreter import to_onnx",
-            "from experimental_experiment.xbuilder import OptimizationOptions",
-        ]
-        code = [f"onx = to_onnx(model, inputs, {sargs})"]
+        imports.extend(
+            [
+                "from experimental_experiment.torch_interpreter import to_onnx",
+                "from experimental_experiment.xbuilder import OptimizationOptions",
+            ]
+        )
+        code.extend([f"onx = to_onnx(model, inputs, {sargs})"])
     elif exporter == "onnx-dynamo":
         if opset:
             args.append(f"opset_version={opset}")
         sargs = ", ".join(args)
         imports = []
         code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})"])
         if optimization:
             imports.append("import onnxscript")
             code.extend(["onnxscript.optimizer.optimize_ir(epo.model)"])
             if "os_ort" in optimization:
                 imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
                 code.extend(["ort_fusions.optimize_for_ort(epo.model)"])
         if dump_folder:
-            imports.insert(0, "import os")
-            code.extend(
-                [f"os.makedirs({dump_folder!r}, exist_ok=True)", f"epo.save({filename!r})"]
-            )
+            code.extend([f"epo.save({filename!r})"])
     else:
         raise ValueError(f"Unexpected exporter {exporter!r}")

From 1955408412d6e2d811d252e4802b2a80f58a9b57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Wed, 22 Oct 2025 18:27:17 +0200
Subject: [PATCH 8/8] status

---
 _doc/api/torch_models/code_sample.rst | 7 +++++++
 1 file changed, 7 insertions(+)
 create mode 100644 _doc/api/torch_models/code_sample.rst

diff --git a/_doc/api/torch_models/code_sample.rst b/_doc/api/torch_models/code_sample.rst
new file mode 100644
index 00000000..ae6df4f3
--- /dev/null
+++ b/_doc/api/torch_models/code_sample.rst
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_models.code_sample
+========================================
+
+.. automodule:: onnx_diagnostic.torch_models.code_sample
+    :members:
+    :no-undoc-members:
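
A minimal usage sketch of the feature added by this series, based on the unit
tests above (the model id, option values, and output file name are the ones the
tests use; ``black`` is optional and only used to format the generated code):

    import subprocess
    import sys
    from onnx_diagnostic.torch_models.code_sample import code_sample

    # Generate a Python script that exports the model with the
    # appropriate dummy inputs, dynamic shapes, and export patches.
    code = code_sample(
        "arnir0/Tiny-LLM",
        exporter="onnx-dynamo",
        optimization="ir",
        patch=True,
        dump_folder="dump_test/validate_tiny_llm_dynamo",
        dtype="float16",
        device="cpu",
    )

    # Write the generated code to a file and run it on its own,
    # mirroring what test_code_sample_tiny_llm_dynamo does.
    with open("export_tiny_llm.py", "w") as f:
        f.write(code)
    subprocess.run([sys.executable, "-u", "export_tiny_llm.py"], check=True)

The same code can be produced from the command line with
``python -m onnx_diagnostic exportsample``, which reuses the options of the
``validate`` command (patch 6 registers ``get_parser_validate`` as its parser).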