diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 3cb416dc..9029a24d 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.7.16
 ++++++
 
+* :pr:`270`: adds sample code generation to export a specific model id with the appropriate inputs
 * :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
 * :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
 * :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
diff --git a/_doc/api/torch_models/code_sample.rst b/_doc/api/torch_models/code_sample.rst
new file mode 100644
index 00000000..ae6df4f3
--- /dev/null
+++ b/_doc/api/torch_models/code_sample.rst
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_models.code_sample
+========================================
+
+.. automodule:: onnx_diagnostic.torch_models.code_sample
+    :members:
+    :no-undoc-members:
diff --git a/_doc/api/torch_models/index.rst b/_doc/api/torch_models/index.rst
index 3279e918..9a472cec 100644
--- a/_doc/api/torch_models/index.rst
+++ b/_doc/api/torch_models/index.rst
@@ -5,6 +5,7 @@ onnx_diagnostic.torch_models
     :maxdepth: 1
     :caption: submodules
 
+    code_sample
     hghub/index
     llms
     validate
diff --git a/_unittests/ut_torch_models/test_code_sample.py b/_unittests/ut_torch_models/test_code_sample.py
new file mode 100644
index 00000000..d3794086
--- /dev/null
+++ b/_unittests/ut_torch_models/test_code_sample.py
@@ -0,0 +1,100 @@
+import unittest
+import subprocess
+import sys
+import torch
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_torch,
+    requires_experimental,
+    requires_transformers,
+)
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.torch_models.code_sample import code_sample, make_code_for_inputs
+
+
+class TestCodeSample(ExtTestCase):
+    @requires_transformers("4.53")
+    @requires_torch("2.9")
+    @requires_experimental()
+    @hide_stdout()
+    def test_code_sample_tiny_llm_custom(self):
+        code = code_sample(
+            "arnir0/Tiny-LLM",
+            verbose=2,
+            exporter="custom",
+            patch=True,
+            dump_folder="dump_test/validate_tiny_llm_custom",
+            dtype="float16",
+            device="cpu",
+            optimization="default",
+        )
+        filename = self.get_dump_file("test_code_sample_tiny_llm_custom.py")
+        with open(filename, "w") as f:
+            f.write(code)
+        cmds = [sys.executable, "-u", filename]
+        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        res = p.communicate()
+        _out, err = res
+        st = err.decode("ascii", errors="ignore")
+        self.assertNotIn("Traceback", st)
+
+    @requires_transformers("4.53")
+    @requires_torch("2.9")
+    @requires_experimental()
+    @hide_stdout()
+    def test_code_sample_tiny_llm_dynamo(self):
+        code = code_sample(
+            "arnir0/Tiny-LLM",
+            verbose=2,
+            exporter="onnx-dynamo",
+            patch=True,
+            dump_folder="dump_test/validate_tiny_llm_dynamo",
+            dtype="float16",
+            device="cpu",
+            optimization="ir",
+        )
+        filename = self.get_dump_file("test_code_sample_tiny_llm_dynamo.py")
+        with open(filename, "w") as f:
+            f.write(code)
+        cmds = [sys.executable, "-u", filename]
+        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        res = p.communicate()
+        _out, err = res
+        st = err.decode("ascii", errors="ignore")
+        self.assertNotIn("Traceback", st)
+
+    def test_make_code_for_inputs(self):
+        values = [
+            ("dict(a=True)", dict(a=True)),
+            ("dict(a=1)", dict(a=1)),
+            (
+                "dict(a=torch.randint(3, size=(2,), dtype=torch.int64))",
+                dict(a=torch.tensor([2, 3], dtype=torch.int64)),
+            ),
+            (
+                "dict(a=torch.rand((2,), dtype=torch.float16))",
+                dict(a=torch.tensor([2, 3], dtype=torch.float16)),
+            ),
+        ]
+        for res, inputs in values:
+            self.assertEqual(res, make_code_for_inputs(inputs))
+
+        res = make_code_for_inputs(
+            dict(
+                cc=make_dynamic_cache(
+                    [(torch.randn(2, 2, 2, 2), torch.randn(2, 2, 2, 2)) for i in range(2)]
+                )
+            )
+        )
+        self.assertEqual(
+            "dict(cc=make_dynamic_cache([(torch.rand((2, 2, 2, 2), "
+            "dtype=torch.float32),torch.rand((2, 2, 2, 2), dtype=torch.float32)), "
+            "(torch.rand((2, 2, 2, 2), dtype=torch.float32),"
+            "torch.rand((2, 2, 2, 2), dtype=torch.float32))]))",
+            res,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index cf7cc863..f4c2eca5 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -557,7 +557,7 @@ def get_parser_validate() -> ArgumentParser:
         "--quiet-input-sets",
         default="",
         help="Avoids raising an exception when an input sets does not work with "
-        "the exported model, example: --quiet-input-sets=inputs,inputs22",
+        "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
     )
     return parser
 
@@ -631,6 +631,94 @@ def _cmd_validate(argv: List[Any]):
             print(f":{k},{v};")
 
 
+def _cmd_export_sample(argv: List[Any]):
+    from .helpers import string_type
+    from .torch_models.validate import get_inputs_for_task, _make_folder_name
+    from .torch_models.code_sample import code_sample
+    from .tasks import supported_tasks
+
+    parser = get_parser_validate()
+    args = parser.parse_args(argv[1:])
+    if not args.task and not args.mid:
+        print("-- list of supported tasks:")
+        print("\n".join(supported_tasks()))
+    elif not args.mid:
+        data = get_inputs_for_task(args.task)
+        if args.verbose:
+            print(f"task: {args.task}")
+        max_length = max(len(k) for k in data["inputs"]) + 1
+        print("-- inputs")
+        for k, v in data["inputs"].items():
+            print(f"  + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
+        print("-- dynamic_shapes")
+        for k, v in data["dynamic_shapes"].items():
+            print(f"  + {k.ljust(max_length)}: {string_type(v)}")
+    else:
+        # Let's skip any invalid combination if known to be unsupported
+        if (
+            "onnx" not in (args.export or "")
+            and "custom" not in (args.export or "")
+            and (args.opt or "")
+        ):
+            print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
+            return
+        patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
+        code = code_sample(
+            model_id=args.mid,
+            task=args.task,
+            do_run=args.run,
+            verbose=args.verbose,
+            quiet=args.quiet,
+            same_as_pretrained=args.same_as_trained,
+            use_pretrained=args.trained,
+            dtype=args.dtype,
+            device=args.device,
+            patch=patch_dict,
+            rewrite=args.rewrite and patch_dict.get("patch", True),
+            stop_if_static=args.stop_if_static,
+            optimization=args.opt,
+            exporter=args.export,
+            dump_folder=args.dump_folder,
+            drop_inputs=None if not args.drop else args.drop.split(","),
+            input_options=args.iop,
+            model_options=args.mop,
+            subfolder=args.subfolder,
+            opset=args.opset,
+            runtime=args.runtime,
+            output_names=(
+                None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
+            ),
+        )
+        if args.dump_folder:
+            os.makedirs(args.dump_folder, exist_ok=True)
+            name = (
+                _make_folder_name(
+                    model_id=args.mid,
+                    exporter=args.export,
+                    optimization=args.opt,
+                    dtype=args.dtype,
+                    device=args.device,
+                    subfolder=args.subfolder,
+                    opset=args.opset,
+                    drop_inputs=None if not args.drop else args.drop.split(","),
+                    same_as_pretrained=args.same_as_trained,
+                    use_pretrained=args.trained,
+                    task=args.task,
+                ).replace("/", "-")
+                + ".py"
+            )
+            fullname = os.path.join(args.dump_folder, name)
+            if args.verbose:
+                print(f"-- writes code in {fullname!r}")
+                print("--")
+            with open(fullname, "w") as f:
+                f.write(code)
+            if args.verbose:
+                print("-- done")
+        else:
+            print(code)
+
+
 def get_parser_stats() -> ArgumentParser:
     parser = ArgumentParser(
         prog="stats",
@@ -960,14 +1048,15 @@ def get_main_parser() -> ArgumentParser:
 
         Type 'python -m onnx_diagnostic --help' to get help for a specific command.
 
-        agg        - aggregates statistics from multiple files
-        config     - prints a configuration for a model id
-        find       - find node consuming or producing a result
-        lighten    - makes an onnx model lighter by removing the weights,
-        print      - prints the model on standard output
-        stats      - produces statistics on a model
-        unlighten  - restores an onnx model produces by the previous experiment
-        validate   - validate a model
+        agg          - aggregates statistics from multiple files
+        config       - prints a configuration for a model id
+        exportsample - produces sample code to export a model
+        find         - find node consuming or producing a result
+        lighten      - makes an onnx model lighter by removing the weights
+        print        - prints the model on standard output
+        stats        - produces statistics on a model
+        unlighten    - restores an onnx model produced by the previous experiment
+        validate     - validate a model
     """
         ),
     )
@@ -976,6 +1065,7 @@ def get_main_parser() -> ArgumentParser:
         choices=[
             "agg",
             "config",
+            "exportsample",
             "find",
             "lighten",
             "print",
@@ -998,6 +1088,7 @@ def main(argv: Optional[List[Any]] = None):
         validate=_cmd_validate,
         stats=_cmd_stats,
         agg=_cmd_agg,
+        exportsample=_cmd_export_sample,
     )
 
     if argv is None:
@@ -1020,6 +1111,7 @@ def main(argv: Optional[List[Any]] = None):
         validate=get_parser_validate,
         stats=get_parser_stats,
         agg=get_parser_agg,
+        exportsample=get_parser_validate,
     )
     cmd = argv[0]
     if cmd not in parsers:
diff --git a/onnx_diagnostic/torch_models/code_sample.py b/onnx_diagnostic/torch_models/code_sample.py
new file mode 100644
index 00000000..cdcd8970
--- /dev/null
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -0,0 +1,343 @@
+import os
+import textwrap
+import torch
+from typing import Any, Dict, List, Optional, Tuple, Union
+from ..helpers import flatten_object
+from ..helpers.torch_helper import to_any
+from .hghub.model_inputs import _preprocess_model_id
+from .hghub import get_untrained_model_with_inputs
+from .validate import filter_inputs, make_patch_kwargs
+
+
+CODE_SAMPLES = {
+    "imports": "from typing import Any\nimport torch",
+    "get_model_with_inputs": textwrap.dedent(
+        """
+        def get_model_with_inputs(
+            model_id: str,
+            subfolder: str | None = None,
+            dtype: str | torch.dtype | None = None,
+            device: str | torch.device | None = None,
+            same_as_pretrained: bool = False,
+            use_pretrained: bool = False,
+            input_options: dict[str, Any] | None = None,
+            model_options: dict[str, Any] | None = None,
+        ) -> torch.nn.Module:
+            if use_pretrained:
+                import transformers
+                assert same_as_pretrained, (
+                    "same_as_pretrained must be True if use_pretrained is True"
+                )
+                # tokenizer = AutoTokenizer.from_pretrained(model_path)
+                model = transformers.AutoModel.from_pretrained(
+                    model_id,
+                    trust_remote_code=True,
+                    subfolder=subfolder,
+                    dtype=dtype,
+                    device=device,
+                )
+                data = {"model": model}
+                assert not input_options, f"Not implemented yet with input_options={input_options}"
+                assert not model_options, f"Not implemented yet with model_options={model_options}"
+            else:
+                from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+                data = get_untrained_model_with_inputs(
+                    model_id,
+                    use_pretrained=use_pretrained,
+                    same_as_pretrained=same_as_pretrained,
+                    inputs_kwargs=input_options,
+                    model_kwargs=model_options,
+                    subfolder=subfolder,
+                    add_second_input=False,
+                )
+            if dtype:
+                data["model"] = data["model"].to(
+                    getattr(torch, dtype) if isinstance(dtype, str) else dtype
+                )
+            if device:
+                data["model"] = data["model"].to(device)
+            return data["model"]
+        """
+    ),
+}
+
+
+def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
+    """
+    Creates the code generating random inputs.
+
+    :param inputs: dictionary
+    :return: code
+    """
+    codes = []
+    for k, v in inputs.items():
+        if isinstance(v, (int, bool, float)):
+            code = f"{k}={v}"
+        elif isinstance(v, torch.Tensor):
+            shape = tuple(map(int, v.shape))
+            if v.dtype in (torch.int32, torch.int64):
+                code = f"{k}=torch.randint({v.max()}, size={shape}, dtype={v.dtype})"
+            elif v.dtype in (torch.float32, torch.float16, torch.bfloat16):
+                code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
+            else:
+                raise ValueError(f"Unexpected dtype = {v.dtype} for k={k!r}")
+        elif v.__class__.__name__ == "DynamicCache":
+            obj = flatten_object(v)
+            cc = [f"torch.rand({tuple(map(int,_.shape))}, dtype={_.dtype})" for _ in obj]
+            va = [f"({a},{b})" for a, b in zip(cc[: len(cc) // 2], cc[len(cc) // 2 :])]
+            va2 = ", ".join(va)
+            code = f"{k}=make_dynamic_cache([{va2}])"
+        else:
+            raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
+        codes.append(code)
+    st = ", ".join(codes)
+    return f"dict({st})"
+
+
+def make_export_code(
+    exporter: str,
+    optimization: Optional[str] = None,
+    patch_kwargs: Optional[Dict[str, Any]] = None,
+    stop_if_static: int = 0,
+    dump_folder: Optional[str] = None,
+    opset: Optional[int] = None,
+    dynamic_shapes: Optional[Dict[str, Any]] = None,
+    output_names: Optional[List[str]] = None,
+    verbose: int = 0,
+) -> Tuple[str, str]:
+    args = [f"dynamic_shapes={dynamic_shapes}"]
+    if output_names:
+        args.append(f"output_names={output_names}")
+    code = []
+    imports = []
+    if dump_folder:
+        code.append(f"os.makedirs({dump_folder!r})")
+        imports.append("import os")
+    filename = os.path.join(dump_folder or ".", "model.onnx")
+    if exporter == "custom":
+        if opset:
+            args.append(f"target_opset={opset}")
+        if optimization:
+            args.append(f"options=OptimizationOptions(patterns={optimization!r})")
+        args.append(f"large_model=True, filename={filename!r}")
+        sargs = ", ".join(args)
+        imports.extend(
+            [
+                "from experimental_experiment.torch_interpreter import to_onnx",
+                "from experimental_experiment.xbuilder import OptimizationOptions",
+            ]
+        )
+        code.extend([f"onx = to_onnx(model, inputs, {sargs})"])
+    elif exporter == "onnx-dynamo":
+        if opset:
+            args.append(f"opset_version={opset}")
+        sargs = ", ".join(args)
+        code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})"])
+        if optimization:
+            imports.append("import onnxscript")
+            code.extend(["onnxscript.optimizer.optimize_ir(epo.model)"])
+            if "os_ort" in optimization:
+                imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
+                code.extend(["ort_fusions.optimize_for_ort(epo.model)"])
+        if dump_folder:
+            code.extend([f"epo.save({filename!r})"])
+    else:
+        raise ValueError(f"Unexpected exporter {exporter!r}")
+    if not patch_kwargs:
+        return "\n".join(imports), "\n".join(code)
+
+    imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
+    if stop_if_static:
+        patch_kwargs["stop_if_static"] = stop_if_static
+    sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
+    code = [f"with torch_export_patches({sargs}):", *["    " + _ for _ in code]]
+    return "\n".join(imports), "\n".join(code)
+
+
+def code_sample(
+    model_id: str,
+    task: Optional[str] = None,
+    do_run: bool = False,
+    exporter: Optional[str] = None,
+    do_same: bool = False,
+    verbose: int = 0,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    same_as_pretrained: bool = False,
+    use_pretrained: bool = False,
+    optimization: Optional[str] = None,
+    quiet: bool = False,
+    patch: Union[bool, str, Dict[str, bool]] = False,
+    rewrite: bool = False,
+    stop_if_static: int = 1,
+    dump_folder: Optional[str] = None,
+    drop_inputs: Optional[List[str]] = None,
+    input_options: Optional[Dict[str, Any]] = None,
+    model_options: Optional[Dict[str, Any]] = None,
+    subfolder: Optional[str] = None,
+    opset: Optional[int] = None,
+    runtime: str = "onnxruntime",
+    output_names: Optional[List[str]] = None,
+) -> str:
+    """
+    Generates the code to export a model with the proper settings.
+
+    :param model_id: model id to export
+    :param task: task used to generate the necessary inputs,
+        can be left empty to use the default task for this model
+        if it can be determined
+    :param do_run: checks the model works with the defined inputs
+    :param exporter: exports the model with this exporter,
+        available values: ``custom``, ``onnx-dynamo``,
+        see :func:`make_export_code`
+    :param do_same: checks the discrepancies of the exported model
+    :param verbose: verbosity level
+    :param dtype: uses this dtype to check the model
+    :param device: does the verification on this device
+    :param same_as_pretrained: use a model equivalent to the trained one,
+        this is not always possible
+    :param use_pretrained: use the trained model, not the untrained one
+    :param optimization: optimization to apply to the exported model,
+        depends on the exporter
+    :param quiet: if quiet, catches exceptions if any issue arises
+    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
+        before exporting if True,
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
+        a string can be used to specify only one of them
+    :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+    :param stop_if_static: stops if a dynamic dimension becomes static,
+        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+    :param dump_folder: dumps everything in a subfolder of this one
+    :param drop_inputs: drops this list of inputs (given their names)
+    :param input_options: additional options to define the dummy inputs
+        used to export
+    :param model_options: additional options when creating the model such as
+        ``num_hidden_layers`` or ``attn_implementation``
+    :param subfolder: version or subfolder to use when retrieving a model id
+    :param opset: onnx opset to use for the conversion
+    :param runtime: onnx runtime to use to check for discrepancies,
+        possible values ``onnxruntime``, ``torch``, ``orteval``,
+        ``orteval10``, ``ref``, only used if `do_run` is true
+    :param output_names: output names the onnx exporter should use
+    :return: the generated code
+
+    .. runpython::
+        :showcode:
+
+        from onnx_diagnostic.torch_models.code_sample import code_sample
+
+        print(
+            code_sample(
+                "arnir0/Tiny-LLM",
+                exporter="onnx-dynamo",
+                optimization="ir",
+                patch=True,
+            )
+        )
+    """
+    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+        model_id,
+        subfolder,
+        same_as_pretrained=same_as_pretrained,
+        use_pretrained=use_pretrained,
+    )
+    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
+
+    iop = input_options or {}
+    mop = model_options or {}
+    data = get_untrained_model_with_inputs(
+        model_id,
+        verbose=verbose,
+        task=task,
+        use_pretrained=use_pretrained,
+        same_as_pretrained=same_as_pretrained,
+        inputs_kwargs=iop,
+        model_kwargs=mop,
+        subfolder=subfolder,
+        add_second_input=False,
+    )
+    if drop_inputs:
+        update = {}
+        for k in data:
+            if k.startswith("inputs"):
+                update[k], ds = filter_inputs(
+                    data[k],
+                    drop_names=drop_inputs,
+                    model=data["model"],
+                    dynamic_shapes=data["dynamic_shapes"],
+                )
+        update["dynamic_shapes"] = ds
+        data.update(update)
+
+    update = {}
+    for k in data:
+        if k.startswith("inputs"):
+            v = data[k]
+            if dtype:
+                update[k] = v = to_any(
+                    v, getattr(torch, dtype) if isinstance(dtype, str) else dtype
+                )
+            if device:
+                update[k] = v = to_any(v, device)
+    if update:
+        data.update(update)
+
+    args = [f"{model_id!r}"]
+    if subfolder:
+        args.append(f"subfolder={subfolder!r}")
+    if dtype:
+        args.append(f"dtype={dtype!r}")
+    if device:
+        args.append(f"device={device!r}")
+    if same_as_pretrained:
+        args.append(f"same_as_pretrained={same_as_pretrained!r}")
+    if use_pretrained:
+        args.append(f"use_pretrained={use_pretrained!r}")
+    if input_options:
+        args.append(f"input_options={input_options!r}")
+    if model_options:
+        args.append(f"model_options={model_options!r}")
+    model_args = ", ".join(args)
+    imports, exporter_code = (
+        make_export_code(
+            exporter=exporter,
+            patch_kwargs=patch_kwargs,
+            verbose=verbose,
+            optimization=optimization,
+            stop_if_static=stop_if_static,
+            dump_folder=dump_folder,
+            opset=opset,
+            dynamic_shapes=data["dynamic_shapes"],
+        )
+        if exporter is not None
+        else ("", "")
+    )
+    input_code = make_code_for_inputs(data["inputs"])
+    cache_import = (
+        "from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache"
+        if "dynamic_cache" in input_code
+        else ""
+    )
+
+    pieces = [
+        CODE_SAMPLES["imports"],
+        imports,
+        cache_import,
+        CODE_SAMPLES["get_model_with_inputs"],
+        textwrap.dedent(
+            f"""
+            model = get_model_with_inputs({model_args})
+            """
+        ),
+        f"inputs = {input_code}",
+        exporter_code,
+    ]
+    code = "\n".join(pieces)
+    try:
+        import black
+    except ImportError:
+        # No black formatting.
+        return code
+
+    return black.format_str(code, mode=black.Mode())
diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py
index fca1cfbd..e9706f4f 100644
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -293,6 +293,33 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return new_cfg
 
 
+def make_patch_kwargs(
+    patch: Union[bool, str, Dict[str, bool]] = False,
+    rewrite: bool = False,
+) -> Dict[str, Any]:
+    """Creates patch arguments."""
+    default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
+    if isinstance(patch, bool):
+        patch_kwargs = default_patch if patch else dict(patch=False)
+    elif isinstance(patch, str):
+        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
+    else:
+        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
+        patch_kwargs = patch.copy()
+        if "patch" not in patch_kwargs:
+            if any(patch_kwargs.values()):
+                patch_kwargs["patch"] = True
+            elif len(patch) == 1 and patch.get("patch", False):
+                patch_kwargs.update(default_patch)
+
+    assert not rewrite or patch_kwargs.get("patch", False), (
+        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
+        f"patch must be True to enable rewriting, "
+        f"if --patch=0 was specified on the command line, rewrites are disabled."
+    )
+    return patch_kwargs
+
+
 def validate_model(
     model_id: str,
     task: Optional[str] = None,
@@ -420,25 +447,8 @@ def validate_model(
         use_pretrained=use_pretrained,
     )
     time_preprocess_model_id = time.perf_counter() - main_validation_begin
-    default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
-    if isinstance(patch, bool):
-        patch_kwargs = default_patch if patch else dict(patch=False)
-    elif isinstance(patch, str):
-        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
-    else:
-        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
-        patch_kwargs = patch.copy()
-        if "patch" not in patch_kwargs:
-            if any(patch_kwargs.values()):
-                patch_kwargs["patch"] = True
-            elif len(patch) == 1 and patch.get("patch", False):
-                patch_kwargs.update(default_patch)
+    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
 
-    assert not rewrite or patch_kwargs.get("patch", False), (
-        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
-        f"patch must be True to enable rewriting, "
-        f"if --patch=0 was specified on the command line, rewrites are disabled."
-    )
     summary = version_summary()
     summary.update(
         dict(
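A minimal usage sketch of ``make_code_for_inputs`` as added above, following its implementation and the unit test in this diff; the input tensor is made up for illustration:

    import torch
    from onnx_diagnostic.torch_models.code_sample import make_code_for_inputs

    # Integer tensors become a torch.randint call bounded by the tensor's maximum value,
    # floating-point tensors become a torch.rand call with the same shape and dtype.
    inputs = dict(input_ids=torch.tensor([[1, 2, 3]], dtype=torch.int64))
    print(make_code_for_inputs(inputs))
    # dict(input_ids=torch.randint(3, size=(1, 3), dtype=torch.int64))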