1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.16
++++++

* :pr:`270`: adds sample code generation to export a specific model id with the appropriate inputs
* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
* :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
* :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
7 changes: 7 additions & 0 deletions _doc/api/torch_models/code_sample.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_models.code_sample
========================================

.. automodule:: onnx_diagnostic.torch_models.code_sample
    :members:
    :no-undoc-members:
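Note (not part of the diff): the new onnx_diagnostic.torch_models.code_sample module documented above exposes ``code_sample``, which returns a standalone Python script as a string. A minimal usage sketch, inferring the signature and argument values from the unit tests further below; nothing here beyond those calls is confirmed by this PR:

    # Usage sketch inferred from test_code_sample.py; not part of the PR.
    from onnx_diagnostic.torch_models.code_sample import code_sample

    # Returns a string containing a Python script that exports the model id
    # with the appropriate inputs.
    code = code_sample(
        "arnir0/Tiny-LLM",       # model id on the Hugging Face hub
        verbose=2,
        exporter="custom",       # the tests also exercise "onnx-dynamo"
        patch=True,              # presumably applies export patches (assumption)
        dump_folder="dump_test/validate_tiny_llm_custom",
        dtype="float16",
        device="cpu",
        optimization="default",  # the tests pair "ir" with "onnx-dynamo"
    )
    with open("export_tiny_llm.py", "w") as f:
        f.write(code)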
1 change: 1 addition & 0 deletions _doc/api/torch_models/index.rst
@@ -5,6 +5,7 @@ onnx_diagnostic.torch_models
    :maxdepth: 1
    :caption: submodules

    code_sample
    hghub/index
    llms
    validate
100 changes: 100 additions & 0 deletions _unittests/ut_torch_models/test_code_sample.py
@@ -0,0 +1,100 @@
import unittest
import subprocess
import sys
import torch
from onnx_diagnostic.ext_test_case import (
    ExtTestCase,
    hide_stdout,
    requires_torch,
    requires_experimental,
    requires_transformers,
)
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_models.code_sample import code_sample, make_code_for_inputs


class TestCodeSample(ExtTestCase):
    @requires_transformers("4.53")
    @requires_torch("2.9")
    @requires_experimental()
    @hide_stdout()
    def test_code_sample_tiny_llm_custom(self):
        code = code_sample(
            "arnir0/Tiny-LLM",
            verbose=2,
            exporter="custom",
            patch=True,
            dump_folder="dump_test/validate_tiny_llm_custom",
            dtype="float16",
            device="cpu",
            optimization="default",
        )
        filename = self.get_dump_file("test_code_sample_tiny_llm_custom.py")
        with open(filename, "w") as f:
            f.write(code)
        cmds = [sys.executable, "-u", filename]
        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        res = p.communicate()
        _out, err = res
        st = err.decode("ascii", errors="ignore")
        self.assertNotIn("Traceback", st)

    @requires_transformers("4.53")
    @requires_torch("2.9")
    @requires_experimental()
    @hide_stdout()
    def test_code_sample_tiny_llm_dynamo(self):
        code = code_sample(
            "arnir0/Tiny-LLM",
            verbose=2,
            exporter="onnx-dynamo",
            patch=True,
            dump_folder="dump_test/validate_tiny_llm_dynamo",
            dtype="float16",
            device="cpu",
            optimization="ir",
        )
        filename = self.get_dump_file("test_code_sample_tiny_llm_dynamo.py")
        with open(filename, "w") as f:
            f.write(code)
        cmds = [sys.executable, "-u", filename]
        p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        res = p.communicate()
        _out, err = res
        st = err.decode("ascii", errors="ignore")
        self.assertNotIn("Traceback", st)

    def test_make_code_for_inputs(self):
        values = [
            ("dict(a=True)", dict(a=True)),
            ("dict(a=1)", dict(a=1)),
            (
                "dict(a=torch.randint(3, size=(2,), dtype=torch.int64))",
                dict(a=torch.tensor([2, 3], dtype=torch.int64)),
            ),
            (
                "dict(a=torch.rand((2,), dtype=torch.float16))",
                dict(a=torch.tensor([2, 3], dtype=torch.float16)),
            ),
        ]
        for res, inputs in values:
            self.assertEqual(res, make_code_for_inputs(inputs))

        res = make_code_for_inputs(
            dict(
                cc=make_dynamic_cache(
                    [(torch.randn(2, 2, 2, 2), torch.randn(2, 2, 2, 2)) for i in range(2)]
                )
            )
        )
        self.assertEqual(
            "dict(cc=make_dynamic_cache([(torch.rand((2, 2, 2, 2), "
            "dtype=torch.float32),torch.rand((2, 2, 2, 2), dtype=torch.float32)), "
            "(torch.rand((2, 2, 2, 2), dtype=torch.float32),"
            "torch.rand((2, 2, 2, 2), dtype=torch.float32))]))",
            res,
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)
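Note (not part of the diff): the expected strings in ``test_make_code_for_inputs`` suggest that ``make_code_for_inputs`` maps each input to an expression rebuilding a value of the same shape and dtype: ``torch.randint`` for integer tensors, ``torch.rand`` for floating-point tensors, and a ``make_dynamic_cache([...])`` call for caches. A minimal sketch of that dispatch, covering only the cases exercised above; the real implementation surely handles more types:

    # Sketch of the dispatch implied by the expected strings above;
    # NOT the actual implementation.
    import torch

    def make_code_for_inputs_sketch(inputs: dict) -> str:
        parts = []
        for name, value in inputs.items():
            if isinstance(value, (bool, int)):
                parts.append(f"{name}={value!r}")
            elif isinstance(value, torch.Tensor):
                shape = tuple(value.shape)
                if value.dtype in (torch.int32, torch.int64):
                    # integer tensors become torch.randint expressions
                    parts.append(f"{name}=torch.randint(3, size={shape}, dtype={value.dtype})")
                else:
                    # floating-point tensors become torch.rand expressions
                    parts.append(f"{name}=torch.rand({shape}, dtype={value.dtype})")
            else:
                raise NotImplementedError(f"no code generation for {type(value)}")
        return f"dict({', '.join(parts)})"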
110 changes: 101 additions & 9 deletions onnx_diagnostic/_command_lines_parser.py
@@ -557,7 +557,7 @@ def get_parser_validate() -> ArgumentParser:
"--quiet-input-sets",
default="",
help="Avoids raising an exception when an input sets does not work with "
"the exported model, example: --quiet-input-sets=inputs,inputs22",
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
)
return parser

@@ -631,6 +631,94 @@ def _cmd_validate(argv: List[Any]):
print(f":{k},{v};")


def _cmd_export_sample(argv: List[Any]):
from .helpers import string_type
from .torch_models.validate import get_inputs_for_task, _make_folder_name
from .torch_models.code_sample import code_sample
from .tasks import supported_tasks

parser = get_parser_validate()
args = parser.parse_args(argv[1:])
if not args.task and not args.mid:
print("-- list of supported tasks:")
print("\n".join(supported_tasks()))
elif not args.mid:
data = get_inputs_for_task(args.task)
if args.verbose:
print(f"task: {args.task}")
max_length = max(len(k) for k in data["inputs"]) + 1
print("-- inputs")
for k, v in data["inputs"].items():
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
print("-- dynamic_shapes")
for k, v in data["dynamic_shapes"].items():
print(f" + {k.ljust(max_length)}: {string_type(v)}")
else:
# Let's skip any invalid combination if known to be unsupported
if (
"onnx" not in (args.export or "")
and "custom" not in (args.export or "")
and (args.opt or "")
):
print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
return
patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
code = code_sample(
model_id=args.mid,
task=args.task,
do_run=args.run,
verbose=args.verbose,
quiet=args.quiet,
same_as_pretrained=args.same_as_trained,
use_pretrained=args.trained,
dtype=args.dtype,
device=args.device,
patch=patch_dict,
rewrite=args.rewrite and patch_dict.get("patch", True),
stop_if_static=args.stop_if_static,
optimization=args.opt,
exporter=args.export,
dump_folder=args.dump_folder,
drop_inputs=None if not args.drop else args.drop.split(","),
input_options=args.iop,
model_options=args.mop,
subfolder=args.subfolder,
opset=args.opset,
runtime=args.runtime,
output_names=(
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
),
)
if args.dump_folder:
os.makedirs(args.dump_folder, exist_ok=True)
name = (
_make_folder_name(
model_id=args.model_id,
exporter=args.exporter,
optimization=args.optimization,
dtype=args.dtype,
device=args.device,
subfolder=args.subfolder,
opset=args.opset,
drop_inputs=None if not args.drop else args.drop.split(","),
same_as_pretrained=args.same_as_pretrained,
use_pretrained=args.use_pretrained,
task=args.task,
).replace("/", "-")
+ ".py"
)
fullname = os.path.join(args.dump_folder, name)
if args.verbose:
print(f"-- prints code in {fullname!r}")
print("--")
with open(fullname, "w") as f:
f.write(code)
if args.verbose:
print("-- done")
else:
print(code)


def get_parser_stats() -> ArgumentParser:
    parser = ArgumentParser(
        prog="stats",
@@ -960,14 +1048,15 @@ def get_main_parser() -> ArgumentParser:
            Type 'python -m onnx_diagnostic <cmd> --help'
            to get help for a specific command.

            agg       - aggregates statistics from multiple files
            config    - prints a configuration for a model id
            find      - find node consuming or producing a result
            lighten   - makes an onnx model lighter by removing the weights
            print     - prints the model on standard output
            stats     - produces statistics on a model
            unlighten - restores an onnx model produced by the previous experiment
            validate  - validate a model
            agg          - aggregates statistics from multiple files
            config       - prints a configuration for a model id
            exportsample - produces code to export a model
            find         - find node consuming or producing a result
            lighten      - makes an onnx model lighter by removing the weights
            print        - prints the model on standard output
            stats        - produces statistics on a model
            unlighten    - restores an onnx model produced by the previous experiment
            validate     - validate a model
            """
        ),
    )
@@ -976,6 +1065,7 @@
        choices=[
            "agg",
            "config",
            "exportsample",
            "find",
            "lighten",
            "print",
@@ -998,6 +1088,7 @@ def main(argv: Optional[List[Any]] = None):
        validate=_cmd_validate,
        stats=_cmd_stats,
        agg=_cmd_agg,
        exportsample=_cmd_export_sample,
    )

if argv is None:
@@ -1020,6 +1111,7 @@
        validate=get_parser_validate,
        stats=get_parser_stats,
        agg=get_parser_agg,
        exportsample=get_parser_validate,
    )
    cmd = argv[0]
    if cmd not in parsers:
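Note (not part of the diff): with ``exportsample`` registered in both the command table and the parser table (where it reuses ``get_parser_validate``), the new command can be driven programmatically through ``main``. A hedged sketch; the flag spellings are assumptions based on the ``args.*`` attributes used in ``_cmd_export_sample``, since the validate parser's definitions are not visible in this diff:

    # Hedged sketch: flag names (-m, --export, --dump-folder, ...) are assumed.
    from onnx_diagnostic._command_lines_parser import main

    main([
        "exportsample",            # routed to _cmd_export_sample
        "-m", "arnir0/Tiny-LLM",   # model id -> args.mid (assumed flag)
        "--export", "custom",      # exporter -> args.export (assumed flag)
        "--dtype", "float16",
        "--device", "cpu",
        "--dump-folder", "dump_test",  # the generated .py file lands here
    ])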