diff --git a/_doc/api/helpers/index.rst b/_doc/api/helpers/index.rst index ede2174e..59f3846e 100644 --- a/_doc/api/helpers/index.rst +++ b/_doc/api/helpers/index.rst @@ -16,7 +16,7 @@ onnx_diagnostic.helpers onnx_helper ort_session rt_helper - torch_test_helper + torch_helper .. autofunction:: onnx_diagnostic.helpers.flatten_object diff --git a/_doc/api/helpers/torch_helper.rst b/_doc/api/helpers/torch_helper.rst new file mode 100644 index 00000000..e887ba81 --- /dev/null +++ b/_doc/api/helpers/torch_helper.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.helpers.torch_helper +==================================== + +.. automodule:: onnx_diagnostic.helpers.torch_helper + :members: + :no-undoc-members: diff --git a/_doc/api/helpers/torch_test_helper.rst b/_doc/api/helpers/torch_test_helper.rst deleted file mode 100644 index 64782990..00000000 --- a/_doc/api/helpers/torch_test_helper.rst +++ /dev/null @@ -1,7 +0,0 @@ - -onnx_diagnostic.helpers.torch_test_helper -========================================= - -.. automodule:: onnx_diagnostic.helpers.torch_test_helper - :members: - :no-undoc-members: diff --git a/_doc/examples/plot_export_tiny_llm.py b/_doc/examples/plot_export_tiny_llm.py index 05762187..06c90384 100644 --- a/_doc/examples/plot_export_tiny_llm.py +++ b/_doc/examples/plot_export_tiny_llm.py @@ -31,7 +31,7 @@ import transformers from onnx_diagnostic import doc from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.helpers.torch_test_helper import steal_forward +from onnx_diagnostic.helpers.torch_helper import steal_forward from onnx_diagnostic.torch_models.llms import get_tiny_llm @@ -77,7 +77,7 @@ def _forward_(*args, _f=None, **kwargs): model.forward = keep_model_forward # %% -# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steal_forward`. +# Another syntax with :func:`onnx_diagnostic.helpers.torch_helper.steal_forward`. 
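+# ``steal_forward`` wraps the manual pattern above in a context manager: it
+# replaces ``model.forward`` with a wrapper that prints the inputs and outputs
+# of every call, then restores the original method on exit.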
with steal_forward(model): model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True) diff --git a/_unittests/ut_export/test_jit.py b/_unittests/ut_export/test_jit.py index 525eae01..0ae60482 100644 --- a/_unittests/ut_export/test_jit.py +++ b/_unittests/ut_export/test_jit.py @@ -8,7 +8,7 @@ requires_onnxscript, ) from onnx_diagnostic.reference import ExtendedReferenceEvaluator -from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting +from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting try: from experimental_experiment.torch_interpreter import to_onnx diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index 5a380bdb..faa09fc6 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -28,16 +28,18 @@ get_onnx_signature, type_info, onnx_dtype_name, - onnx_dtype_to_torch_dtype, onnx_dtype_to_np_dtype, np_dtype_to_tensor_dtype, - torch_dtype_to_onnx_dtype, from_array_extended, to_array_extended, convert_endian, from_array_ml_dtypes, dtype_to_tensor_dtype, ) +from onnx_diagnostic.helpers.torch_helper import ( + onnx_dtype_to_torch_dtype, + torch_dtype_to_onnx_dtype, +) from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index b4fd46ce..9d22409e 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -1,9 +1,11 @@ import unittest +from typing import Any, Dict, List import numpy as np import onnx.helper as oh import onnx.numpy_helper as onh -from onnx import TensorProto +from onnx import TensorProto, FunctionProto, ValueInfoProto from onnx.checker import check_model +import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout from onnx_diagnostic.helpers.onnx_helper import ( onnx_lighten, @@ -11,13 +13,16 @@ onnx_find, _validate_function, check_model_ort, + iterator_initializer_constant, + from_array_extended, + tensor_statistics, ) TFLOAT = TensorProto.FLOAT -class TestOnnxTools(ExtTestCase): +class TestOnnxHelper(ExtTestCase): def _get_model(self): model = oh.make_model( @@ -122,6 +127,130 @@ def test_check_model_ort(self): ) check_model_ort(model) + def test_iterate_init(self): + itype = TensorProto.FLOAT + cst = np.arange(6).astype(np.float32) + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("IsNaN", ["x"], ["xi"]), + oh.make_node("IsNaN", ["y"], ["yi"]), + oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64), + oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64), + oh.make_node("Add", ["xii", "yii"], ["gggg"]), + oh.make_node("Cast", ["gggg"], ["final"], to=itype), + ], + "dummy", + [oh.make_tensor_value_info("x", itype, [None, None])], + [oh.make_tensor_value_info("final", itype, [None, None])], + [from_array_extended(cst, name="y")], + ), + opset_imports=[oh.make_opsetid("", 20)], + ir_version=10, + ) + li = list(iterator_initializer_constant(model)) + self.assertEqual(len(li), 1) + self.assertEqual(li[0][0], "y") + self.assertEqualArray(li[0][1], cst) + li = list(iterator_initializer_constant(model, use_numpy=False)) + self.assertEqual(len(li), 1) + self.assertEqual(li[0][0], "y") + self.assertEqualArray(li[0][1], cst) + self.assertIsInstance(li[0][1], torch.Tensor) + + def _get_cdist_implementation( + self, + node_inputs: 
List[str], + node_outputs: List[str], + opsets: Dict[str, int], + **kwargs: Any, + ) -> FunctionProto: + """ + Returns the CDist implementation as a function. + """ + assert len(node_inputs) == 2 + assert len(node_outputs) == 1 + assert opsets + assert "" in opsets + assert set(kwargs) == {"metric"}, f"kwargs={kwargs}" + metric = kwargs["metric"] + assert metric in ("euclidean", "sqeuclidean") + # subgraph + nodes = [ + oh.make_node("Sub", ["next", "next_in"], ["diff"]), + oh.make_node("Constant", [], ["axis"], value_ints=[1]), + oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0), + oh.make_node("Identity", ["next_in"], ["next_out"]), + ] + + def make_value(name): + value = ValueInfoProto() + value.name = name + return value + + graph = oh.make_graph( + nodes, + "loop", + [make_value("next_in"), make_value("next")], + [make_value("next_out"), make_value("scan_out")], + ) + + scan = oh.make_node( + "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph + ) + final = ( + oh.make_node("Sqrt", ["zout"], ["z"]) + if metric == "euclidean" + else oh.make_node("Identity", ["zout"], ["z"]) + ) + return oh.make_function( + "npx", + f"CDist_{metric}", + ["xa", "xb"], + ["z"], + [scan, final], + [oh.make_opsetid("", opsets[""])], + ) + + def test_iterate_function(self): + itype = TensorProto.FLOAT + proto = self._get_cdist_implementation( + ["X", "Y"], ["Z"], opsets={"": 18}, metric="euclidean" + ) + model = oh.make_model( + oh.make_graph( + [ + oh.make_node(proto.name, ["X", "Y"], ["Z"]), + ], + "dummy", + [ + oh.make_tensor_value_info("X", itype, [None, None]), + oh.make_tensor_value_info("Y", itype, [None, None]), + ], + [oh.make_tensor_value_info("final", itype, [None, None])], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=10, + ) + model.functions.append(proto) + li = list(iterator_initializer_constant(model)) + self.assertEqual(len(li), 1) + self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis") + self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64)) + li = list(iterator_initializer_constant(model, use_numpy=False)) + self.assertEqual(len(li), 1) + self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis") + self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64)) + self.assertIsInstance(li[0][1], torch.Tensor) + + def test_statistics(self): + rnd = np.random.rand(40, 50).astype(np.float16) + stat = tensor_statistics(rnd) + self.assertEqual(stat["stype"], "FLOAT16") + rnd = np.random.rand(40, 50).astype(np.float32) + stat = tensor_statistics(rnd) + self.assertEqual(stat["stype"], "FLOAT") + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_helpers/test_ort_session.py b/_unittests/ut_helpers/test_ort_session.py index ca257794..dc899f2e 100644 --- a/_unittests/ut_helpers/test_ort_session.py +++ b/_unittests/ut_helpers/test_ort_session.py @@ -12,11 +12,8 @@ requires_onnxruntime_training, requires_cuda, ) -from onnx_diagnostic.helpers.onnx_helper import ( - from_array_extended, - onnx_dtype_to_np_dtype, - onnx_dtype_to_torch_dtype, -) +from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype +from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype from onnx_diagnostic.helpers.ort_session import ( InferenceSessionForNumpy, InferenceSessionForTorch, diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_helper.py similarity index 91% rename from 
_unittests/ut_helpers/test_torch_test_helper.py rename to _unittests/ut_helpers/test_torch_helper.py index e53ee65a..c71b6ea9 100644 --- a/_unittests/ut_helpers/test_torch_test_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -1,11 +1,12 @@ import unittest +import numpy as np import ml_dtypes import onnx import torch import transformers from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout from onnx_diagnostic.helpers import max_diff, string_type -from onnx_diagnostic.helpers.torch_test_helper import ( +from onnx_diagnostic.helpers.torch_helper import ( dummy_llm, to_numpy, is_torchdynamo_exporting, @@ -24,6 +25,8 @@ make_sliding_window_cache, ) from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model +from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended +from onnx_diagnostic.helpers.torch_helper import to_tensor TFLOAT = onnx.TensorProto.FLOAT @@ -205,7 +208,7 @@ def forward(self, x, y): else: print("output", k, v) print(string_type(restored, with_shape=True)) - l1, l2 = 183, 192 + l1, l2 = 186, 195 self.assertEqual( [ (f"-Model-{l2}", 0, "I"), @@ -344,6 +347,35 @@ def forward(self, x, y=None): stat, ) + def test_to_tensor(self): + for dtype in [ + np.int8, + np.uint8, + np.int16, + np.uint16, + np.int32, + np.uint32, + np.int64, + np.uint64, + np.float16, + np.float32, + np.float64, + ]: + with self.subTest(dtype=dtype): + a = np.random.rand(4, 5).astype(dtype) + proto = from_array_extended(a) + b = to_array_extended(proto) + self.assertEqualArray(a, b) + c = to_tensor(proto) + self.assertEqualArray(a, c) + + for dtype in [torch.bfloat16]: + with self.subTest(dtype=dtype): + a = torch.rand((4, 5), dtype=dtype) + proto = from_array_extended(a) + c = to_tensor(proto) + self.assertEqualArray(a, c) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_reference/test_ort_evaluator.py b/_unittests/ut_reference/test_ort_evaluator.py index 37ca827d..653c18b0 100644 --- a/_unittests/ut_reference/test_ort_evaluator.py +++ b/_unittests/ut_reference/test_ort_evaluator.py @@ -14,11 +14,8 @@ ignore_warnings, requires_cuda, ) -from onnx_diagnostic.helpers.onnx_helper import ( - from_array_extended, - onnx_dtype_to_torch_dtype, - onnx_dtype_to_np_dtype, -) +from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype +from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.helpers.ort_session import _InferenceSession diff --git a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index 76fe731e..9217da63 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -1,7 +1,7 @@ import unittest from onnx_diagnostic.ext_test_case import ExtTestCase, never_test from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.helpers.torch_test_helper import steal_forward +from onnx_diagnostic.helpers.torch_helper import steal_forward class TestHuggingFaceHubModel(ExtTestCase): diff --git a/_unittests/ut_torch_export_patches/test_patch_expressions.py b/_unittests/ut_torch_export_patches/test_patch_expressions.py index 610d9757..b65908c0 100644 --- a/_unittests/ut_torch_export_patches/test_patch_expressions.py +++ b/_unittests/ut_torch_export_patches/test_patch_expressions.py @@ -7,7 +7,7 @@ patched_selector, patched_float_arange, ) -from onnx_diagnostic.helpers.torch_test_helper 
import fake_torchdynamo_exporting +from onnx_diagnostic.helpers.torch_helper import fake_torchdynamo_exporting class TestOnnxExportErrors(ExtTestCase): diff --git a/_unittests/ut_torch_export_patches/test_patch_loops.py b/_unittests/ut_torch_export_patches/test_patch_loops.py index 346d9cb4..4ec01c4a 100644 --- a/_unittests/ut_torch_export_patches/test_patch_loops.py +++ b/_unittests/ut_torch_export_patches/test_patch_loops.py @@ -1,7 +1,7 @@ import unittest import torch from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch -from onnx_diagnostic.helpers.torch_test_helper import ( +from onnx_diagnostic.helpers.torch_helper import ( is_torchdynamo_exporting, fake_torchdynamo_exporting, ) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py index 1e0bcb12..851627a5 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py @@ -11,7 +11,7 @@ from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( torch_export_patches, ) -from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy class TestPatchSerialization(ExtTestCase): diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py index e15fdebb..9f2eabb6 100644 --- a/_unittests/ut_xrun_doc/test_command_lines.py +++ b/_unittests/ut_xrun_doc/test_command_lines.py @@ -8,6 +8,7 @@ get_parser_find, get_parser_lighten, get_parser_print, + get_parser_stats, get_parser_unlighten, get_parser_validate, ) @@ -63,6 +64,13 @@ def test_parser_validate(self): text = st.getvalue() self.assertIn("mid", text) + def test_parser_stats(self): + st = StringIO() + with redirect_stdout(st): + get_parser_stats().print_help() + text = st.getvalue() + self.assertIn("input", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index 09bf111b..e9beb9e6 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -20,6 +20,15 @@ def test_parser_print(self): text = st.getvalue() self.assertIn("Add", text) + def test_parser_stats(self): + output = self.get_dump_file("test_parser_stats.xlsx") + st = StringIO() + with redirect_stdout(st): + main(["stats", "-i", self.dummy_path, "-o", output, "-r", ".*"]) + text = st.getvalue() + self.assertIn("processing", text) + self.assertExists(output) + def test_parser_find(self): st = StringIO() with redirect_stdout(st): diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 39ec66ce..9cd6a3cb 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1,5 +1,7 @@ import argparse import json +import os +import re import sys import textwrap import onnx @@ -425,6 +427,106 @@ def _cmd_validate(argv: List[Any]): print(f":{k},{v};") +def get_parser_stats() -> ArgumentParser: + parser = ArgumentParser( + prog="stats", + description=dedent( + """ + Prints out statistics on an ONNX model. 
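+
+            Example, assuming the package is run as a module
+            ("model.onnx" is a placeholder file name):
+
+                python -m onnx_diagnostic stats -i model.onnx -o stats.xlsx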
+ """ + ), + epilog="", + ) + parser.add_argument( + "-i", + "--input", + type=str, + required=True, + help="ONNX file", + ) + parser.add_argument( + "-o", + "--output", + required=False, + default="", + help="outputs the statistics in a file", + ) + parser.add_argument( + "-v", + "--verbose", + required=False, + default=1, + type=int, + help="verbosity", + ) + parser.add_argument( + "-e", + "--end", + required=False, + default=-1, + type=int, + help="ends after this many tensors", + ) + parser.add_argument( + "-b", + "--begin", + required=False, + default=0, + type=int, + help="starts after this many tensors", + ) + parser.add_argument( + "-r", + "--regex", + required=False, + default="", + type=str, + help="keeps only tensors whose name verifies " + "this regular expression, empty = no filter", + ) + return parser + + +def _cmd_stats(argv: List[Any]): + from .helpers.onnx_helper import iterator_initializer_constant, tensor_statistics + + parser = get_parser_stats() + args = parser.parse_args(argv[1:]) + assert os.path.exists(args.input), f"Missing filename {args.input!r}" + if args.verbose: + print(f"Loading {args.input}") + onx = onnx.load(args.input) + reg = re.compile(args.regex) if args.regex else None + data = [] + for index, (name, init) in enumerate(iterator_initializer_constant(onx)): + if reg and not reg.search(name): + continue + if index < args.begin: + continue + if args.end > 0 and index >= args.end: + break + if args.verbose: + print(f"processing {index + 1}: {name!r}") + stats = tensor_statistics(init) + if not args.output: + print(f"{name}: {stats}") + stats["name"] = name + data.append(stats) + if args.output: + if args.verbose: + print(f"saving into {args.output!r}") + import pandas + + df = pandas.DataFrame(data) + ext = os.path.splitext(args.output) + if ext[-1] == ".xlsx": + df.to_excel(args.output, index=False) + else: + df.to_csv(args.output, index=False) + if args.verbose: + print("done.") + + def get_main_parser() -> ArgumentParser: parser = ArgumentParser( prog="onnx_diagnostic", @@ -441,12 +543,13 @@ def get_main_parser() -> ArgumentParser: unlighten - restores an onnx model produces by the previous experiment print - prints the model on standard output validate - validate a model + stats - produces statistics on a model """ ), ) parser.add_argument( "cmd", - choices=["config", "find", "lighten", "print", "unlighten", "validate"], + choices=["config", "find", "lighten", "print", "stats", "unlighten", "validate"], help="Selects a command.", ) return parser @@ -460,6 +563,7 @@ def main(argv: Optional[List[Any]] = None): find=_cmd_find, config=_cmd_config, validate=_cmd_validate, + stats=_cmd_stats, ) if argv is None: @@ -480,6 +584,7 @@ def main(argv: Optional[List[Any]] = None): find=get_parser_find, config=get_parser_config, validate=get_parser_validate, + stats=get_parser_stats, ) cmd = argv[0] if cmd not in parsers: diff --git a/onnx_diagnostic/export/validate.py b/onnx_diagnostic/export/validate.py index f98b13f1..8e7a000f 100644 --- a/onnx_diagnostic/export/validate.py +++ b/onnx_diagnostic/export/validate.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from ..helpers import string_type, max_diff, string_diff -from ..helpers.torch_test_helper import torch_deepcopy +from ..helpers.torch_helper import torch_deepcopy from .dynamic_shapes import CoupleInputsDynamicShapes diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 7e006e57..9662dc74 100644 --- 
a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -871,7 +871,7 @@ def assertEqualArray( raise AssertionError("\n".join(rows)) # noqa: B904 return - from .helpers.torch_test_helper import to_numpy + from .helpers.torch_helper import to_numpy if hasattr(expected, "detach"): expected = to_numpy(expected.detach().cpu()) diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index c2a69195..51a1c054 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -427,7 +427,7 @@ def string_type( # Tensors if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor): - from .onnx_helper import torch_dtype_to_onnx_dtype + from .torch_helper import torch_dtype_to_onnx_dtype i = torch_dtype_to_onnx_dtype(obj.dtype) prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else "" @@ -439,7 +439,7 @@ def string_type( print(f"[string_type] F2:{type(obj)}") return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}" if isinstance(obj, torch.Tensor): - from .onnx_helper import torch_dtype_to_onnx_dtype + from .torch_helper import torch_dtype_to_onnx_dtype if with_min_max: s = string_type(obj, with_shape=with_shape, with_device=with_device) @@ -1260,10 +1260,21 @@ def max_diff( return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) # nan are replace by 1e10, any discrepancies in that order of magnitude # is likely caused by nans - exp_cpu = expected.to(torch.float64).cpu().nan_to_num(1e10) - got_cpu = got.to(torch.float64).cpu().nan_to_num(1e10) + exp_cpu = expected.to(torch.float64).nan_to_num(1e10) + got_cpu = got.to(torch.float64).nan_to_num(1e10) + if got_cpu.device != exp_cpu.device: + if torch.device("cuda:0") in {got_cpu.device, exp_cpu.device}: + got_cpu = got_cpu.to("cuda:0") + exp_cpu = exp_cpu.to("cuda:0") + expected = expected.to("cuda:0") + got = got.to("cuda:0") + else: + got_cpu = got_cpu.detach().to("cpu") + exp_cpu = exp_cpu.detach().to("cpu") + expected = expected.to("cpu") + got = got.to("cpu") diff = (got_cpu - exp_cpu).abs() - ndiff = (expected.isnan().cpu().to(int) - got.isnan().cpu().to(int)).abs() + ndiff = (expected.isnan().to(int) - got.isnan().to(int)).abs() rdiff = diff / (exp_cpu.abs() + 1e-3) if diff.numel() > 0: abs_diff, rel_diff, sum_diff, n_diff, nan_diff = ( @@ -1320,6 +1331,7 @@ def max_diff( hist = torch.tensor( [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype ) + hist = hist.to(diff.device) ind = torch.bucketize(diff.reshape((-1,)), hist, right=False) cou = torch.bincount(ind, minlength=ind.shape[0] + 1) res["rep"] = dict( diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 0558b2a6..ec9cab1d 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1,9 +1,9 @@ -import ctypes import functools import json import os import sys -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +import warnings +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union import numpy as np import numpy.typing as npt import onnx @@ -331,9 +331,10 @@ def onnx_dtype_name(itype: int) -> str: print(onnx_dtype_name(7)) """ for k in dir(TensorProto): - v = getattr(TensorProto, k) - if v == itype: - return k + if "FLOAT" in k or "INT" in k or "TEXT" in k or "BOOL" in k: + v = getattr(TensorProto, k) + if v == itype: + return k raise ValueError(f"Unexpected value itype: {itype}") @@ -518,76 +519,6 @@ def from_array_ml_dtypes(arr: npt.ArrayLike, name: 
Optional[str] = None) -> Tens } -def proto_from_tensor( - arr: "torch.Tensor", # noqa: F821 - name: Optional[str] = None, - verbose: int = 0, -) -> TensorProto: - """ - Converts a torch Tensor into a TensorProto. - - :param arr: tensor - :param verbose: display the type and shape - :return: a TensorProto - """ - import torch - - if not isinstance(arr, torch.Tensor): - raise TypeError(f"Unexpected type {type(arr)}.") - if arr.is_sparse: - raise NotImplementedError( - f"Sparse tensor is not supported yet but initializer {name!r} is." - ) - - # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this. - if arr.is_contiguous(): - arr_cpu = arr.cpu() - else: - arr_cpu = arr.contiguous().cpu() - - numel = torch.numel(arr_cpu) - element_size = arr_cpu.element_size() - - if arr_cpu.dtype in {torch.bfloat16}: - np_arr = arr_cpu - elif arr_cpu.data_ptr() == arr.data_ptr(): - copy = arr_cpu.clone().detach().requires_grad_(False) - assert ( - arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr() - ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}" - np_arr = np.from_dlpack(copy) - else: - np_arr = np.from_dlpack(arr_cpu.detach()) - - tensor = TensorProto() - tensor.dims.extend(arr_cpu.shape) - if name: - tensor.name = name - itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype) - assert not hasattr(TensorProto, "INT4") or itype not in { - TensorProto.INT4, - TensorProto.UINT4, - }, f"Type {arr.dtype} is not supported yet for name={name!r}" - tensor.data_type = itype - - if verbose > 1 and numel > 100: - print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]") - - if isinstance(np_arr, torch.Tensor): - byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr()) - tensor.raw_data = bytes(byte_data) - if sys.byteorder == "big": - np_dtype = _STORAGE_TYPE[tensor.data_type] # type: ignore - np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True) # type: ignore - else: - tensor.raw_data = np_arr.tobytes() - if sys.byteorder == "big": - np_dtype = tensor_dtype_to_np_dtype(tensor.data_type) - np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True) - - return tensor - - def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto: """ Converts an array into a :class:`onnx.TensorProto`. @@ -596,11 +527,13 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te :param name: name :return: TensorProto """ - try: + if not isinstance(tensor, np.ndarray): import torch - except ImportError: - torch = None - if torch is not None and isinstance(tensor, torch.Tensor): + from .torch_helper import proto_from_tensor + + assert isinstance( + tensor, torch.Tensor + ), f"Unable to convert type {type(tensor)} into TensorProto." return proto_from_tensor(tensor, name=name) from onnx.reference.ops.op_cast import ( @@ -657,50 +590,6 @@ def to_array_extended(proto: TensorProto) -> npt.ArrayLike: return arr -def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype": # noqa: F821 - """ - Converts an onnx type into a torch dtype. 
-
-    :param to: onnx dtype
-    :return: torch dtype
-    """
-    import torch
-
-    if itype == TensorProto.FLOAT:
-        return torch.float32
-    if itype == TensorProto.FLOAT16:
-        return torch.float16
-    if itype == TensorProto.BFLOAT16:
-        return torch.bfloat16
-    if itype == TensorProto.DOUBLE:
-        return torch.float64
-    if itype == TensorProto.INT32:
-        return torch.int32
-    if itype == TensorProto.INT64:
-        return torch.int64
-    if itype == TensorProto.UINT32:
-        return torch.uint32
-    if itype == TensorProto.UINT64:
-        return torch.uint64
-    if itype == TensorProto.BOOL:
-        return torch.bool
-    if itype == TensorProto.INT16:
-        return torch.int16
-    if itype == TensorProto.UINT16:
-        return torch.uint16
-    if itype == TensorProto.INT8:
-        return torch.int16
-    if itype == TensorProto.UINT8:
-        return torch.uint16
-    if itype == TensorProto.COMPLEX64:
-        return torch.complex64
-    if itype == TensorProto.COMPLEX128:
-        return torch.complex128
-    raise NotImplementedError(
-        f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
-    )
-
-
 def onnx_dtype_to_np_dtype(itype: int) -> Any:
     """
     Converts an onnx type into a to numpy dtype.
@@ -746,52 +635,6 @@ def onnx_dtype_to_np_dtype(itype: int) -> Any:
     )
 
 
-def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
-    """
-    Converts a torch dtype into a onnx element type.
-
-    :param to: torch dtype
-    :return: onnx type
-    """
-    import torch
-
-    if to == torch.float32:
-        return TensorProto.FLOAT
-    if to == torch.float16:
-        return TensorProto.FLOAT16
-    if to == torch.bfloat16:
-        return TensorProto.BFLOAT16
-    if to == torch.float64:
-        return TensorProto.DOUBLE
-    if to == torch.int64:
-        return TensorProto.INT64
-    if to == torch.int32:
-        return TensorProto.INT32
-    if to == torch.uint64:
-        return TensorProto.UINT64
-    if to == torch.uint32:
-        return TensorProto.UINT32
-    if to == torch.bool:
-        return TensorProto.BOOL
-    if to == torch.SymInt:
-        return TensorProto.INT64
-    if to == torch.int16:
-        return TensorProto.INT16
-    if to == torch.uint16:
-        return TensorProto.UINT16
-    if to == torch.int8:
-        return TensorProto.INT8
-    if to == torch.uint8:
-        return TensorProto.UINT8
-    if to == torch.SymFloat:
-        return TensorProto.FLOAT
-    if to == torch.complex64:
-        return TensorProto.COMPLEX64
-    if to == torch.complex128:
-        return TensorProto.COMPLEX128
-    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
-
-
 def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F821
     """
     Converts a torch dtype or numpy dtype into a onnx element type.
@@ -803,6 +646,8 @@ def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F
         return np_dtype_to_tensor_dtype(dt)
     except (KeyError, TypeError, ValueError):
         pass
+    from .torch_helper import torch_dtype_to_onnx_dtype
+
     return torch_dtype_to_onnx_dtype(dt)
 
 
@@ -919,3 +764,147 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
             return mapping[tensor_dtype]
 
     return oh.tensor_dtype_to_np_dtype(tensor_dtype)
+
+
+def iterator_initializer_constant(
+    model: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
+    use_numpy: bool = True,
+    prefix: str = "",
+) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]:  # noqa: F821
+    """
+    Iterates on initializers and constants in an onnx model.
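+
+    A minimal usage sketch (``model`` stands for any loaded ModelProto)::
+
+        for name, value in iterator_initializer_constant(model):
+            print(name, value.dtype, value.shape)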
+
+    :param model: model
+    :param use_numpy: use numpy or pytorch
+    :param prefix: for subgraph
+    :return: iterator
+    """
+    if not isinstance(model, onnx.FunctionProto):
+        graph = model if isinstance(model, onnx.GraphProto) else model.graph
+        if not use_numpy:
+            from .torch_helper import to_tensor
+        if prefix:
+            prefix += "."
+        for init in graph.initializer:
+            yield f"{prefix}{init.name}", (
+                to_array_extended(init) if use_numpy else to_tensor(init)
+            )
+        nodes = graph.node
+        name = graph.name
+        if isinstance(model, onnx.ModelProto):
+            for f in model.functions:
+                yield from iterator_initializer_constant(
+                    f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
+                )
+    else:
+        nodes = model.node
+        name = model.name
+    for node in nodes:
+        if node.op_type == "Constant" and node.domain == "":
+            from ..reference import ExtendedReferenceEvaluator as Inference
+
+            if not use_numpy:
+                import torch
+            sess = Inference(node)
+            value = sess.run(None, {})[0]
+            yield f"{prefix}{node.output[0]}", (
+                value if use_numpy else torch.from_numpy(value)
+            )
+
+        if node.op_type in {"Loop", "If", "Scan"}:
+            for att in node.attribute:
+                assert (
+                    att.type != onnx.AttributeProto.GRAPHS
+                ), "Not implemented for type AttributeProto.GRAPHS."
+                if att.type == onnx.AttributeProto.GRAPH:
+                    yield from iterator_initializer_constant(
+                        att.g, use_numpy=use_numpy, prefix=f"{prefix}{name}"
+                    )
+
+
+def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union[float, str]]:
+    """
+    Produces statistics on a tensor.
+
+    :param tensor: tensor
+    :return: statistics
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import numpy as np
+        from onnx_diagnostic.helpers.onnx_helper import tensor_statistics
+
+        t = np.random.rand(40, 50).astype(np.float16)
+        pprint.pprint(tensor_statistics(t))
+    """
+    from .helper import size_type
+
+    if isinstance(tensor, TensorProto):
+        tensor = to_array_extended(tensor)
+    itype = np_dtype_to_tensor_dtype(tensor.dtype)
+    stat = dict(
+        mean=float(tensor.mean()),
+        std=float(tensor.std()),
+        shape="x".join(map(str, tensor.shape)),
+        numel=tensor.size,
+        size=tensor.size * size_type(tensor.dtype),
+        itype=itype,
+        stype=onnx_dtype_name(itype),
+        min=float(tensor.min()),
+        max=float(tensor.max()),
+        nnan=float(np.isnan(tensor).sum()),
+    )
+
+    if tensor.size < 8:
+        return stat
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            hist = np.array(
+                [
+                    0,
+                    1e-10,
+                    1e-8,
+                    1e-7,
+                    1e-6,
+                    1e-5,
+                    0.0001,
+                    0.001,
+                    0.01,
+                    0.1,
+                    0.5,
+                    1,
+                    1.96,
+                    10,
+                    1e2,
+                    1e3,
+                    1e4,
+                    1e5,
+                    1e6,
+                    1e7,
+                    1e8,
+                    1e10,
+                    1e50,
+                ],
+                dtype=tensor.dtype,
+            )
+        except OverflowError as e:
+            from .helper import string_type
+
+            raise ValueError(
+                f"Unable to convert one value into {tensor.dtype}, "
+                f"tensor={string_type(tensor, with_shape=True)}"
+            ) from e
+        hist = np.array(sorted(set(hist[~np.isinf(hist)])), dtype=tensor.dtype)
+        ind = np.digitize(np.abs(tensor).reshape((-1,)), hist, right=True)
+        cou = np.bincount(ind, minlength=ind.shape[0] + 1)
+        stat.update(
+            dict(zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))]))
+        )
+        ii = (np.arange(9) + 1) / 10
+        qu = np.quantile(tensor, ii)
+        stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
+    return stat
diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py
index 79937e6b..8138f896 100644
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -8,11 +8,12 @@
 from onnxruntime.capi import _pybind_state as ORTC
 
 from .helper 
import size_type from .onnx_helper import ( - torch_dtype_to_onnx_dtype, onnx_dtype_to_np_dtype, np_dtype_to_tensor_dtype, onnx_dtype_name, ) +from .torch_helper import torch_dtype_to_onnx_dtype + DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)} diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_helper.py similarity index 76% rename from onnx_diagnostic/helpers/torch_test_helper.py rename to onnx_diagnostic/helpers/torch_helper.py index d0c1fcb2..96fdd60b 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -1,10 +1,14 @@ import contextlib +import ctypes import inspect import os +import sys +import warnings from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import onnx +from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data import torch from .helper import string_type, size_type from .cache_helper import ( @@ -14,6 +18,171 @@ make_mamba_cache, ) from .mini_onnx_builder import create_onnx_model_from_input_tensors +from .onnx_helper import ( + to_array_extended, + tensor_dtype_to_np_dtype, + _STORAGE_TYPE, + onnx_dtype_name, +) + + +def proto_from_tensor( + arr: "torch.Tensor", # noqa: F821 + name: Optional[str] = None, + verbose: int = 0, +) -> onnx.TensorProto: + """ + Converts a torch Tensor into a TensorProto. + + :param arr: tensor + :param verbose: display the type and shape + :return: a TensorProto + """ + import torch + + if not isinstance(arr, torch.Tensor): + raise TypeError(f"Unexpected type {type(arr)}.") + if arr.is_sparse: + raise NotImplementedError( + f"Sparse tensor is not supported yet but initializer {name!r} is." + ) + + # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this. 
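+    # A non-contiguous tensor (e.g. a transposed view) must be materialized
+    # before its bytes can be read, hence the copy below.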
+    if arr.is_contiguous():
+        arr_cpu = arr.cpu()
+    else:
+        arr_cpu = arr.contiguous().cpu()
+
+    numel = torch.numel(arr_cpu)
+    element_size = arr_cpu.element_size()
+
+    if arr_cpu.dtype in {torch.bfloat16}:
+        np_arr = arr_cpu
+    elif arr_cpu.data_ptr() == arr.data_ptr():
+        copy = arr_cpu.clone().detach().requires_grad_(False)
+        assert (
+            arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
+        ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
+        np_arr = np.from_dlpack(copy)
+    else:
+        np_arr = np.from_dlpack(arr_cpu.detach())
+
+    tensor = onnx.TensorProto()
+    tensor.dims.extend(arr_cpu.shape)
+    if name:
+        tensor.name = name
+    itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype)
+    assert not hasattr(onnx.TensorProto, "INT4") or itype not in {
+        onnx.TensorProto.INT4,
+        onnx.TensorProto.UINT4,
+    }, f"Type {arr.dtype} is not supported yet for name={name!r}"
+    tensor.data_type = itype
+
+    if verbose > 1 and numel > 100:
+        print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")
+
+    if isinstance(np_arr, torch.Tensor):
+        byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
+        tensor.raw_data = bytes(byte_data)
+        if sys.byteorder == "big":
+            # numpy has no module-level byteswap, swap through ndarray.byteswap
+            np_dtype = _STORAGE_TYPE[tensor.data_type]  # type: ignore
+            tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+    else:
+        tensor.raw_data = np_arr.tobytes()
+        if sys.byteorder == "big":
+            np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+            tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+    return tensor
+
+
+def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+    """
+    Converts an onnx type into a torch dtype.
+
+    :param itype: onnx dtype
+    :return: torch dtype
+    """
+    import torch
+
+    if itype == onnx.TensorProto.FLOAT:
+        return torch.float32
+    if itype == onnx.TensorProto.FLOAT16:
+        return torch.float16
+    if itype == onnx.TensorProto.BFLOAT16:
+        return torch.bfloat16
+    if itype == onnx.TensorProto.DOUBLE:
+        return torch.float64
+    if itype == onnx.TensorProto.INT32:
+        return torch.int32
+    if itype == onnx.TensorProto.INT64:
+        return torch.int64
+    if itype == onnx.TensorProto.UINT32:
+        return torch.uint32
+    if itype == onnx.TensorProto.UINT64:
+        return torch.uint64
+    if itype == onnx.TensorProto.BOOL:
+        return torch.bool
+    if itype == onnx.TensorProto.INT16:
+        return torch.int16
+    if itype == onnx.TensorProto.UINT16:
+        return torch.uint16
+    if itype == onnx.TensorProto.INT8:
+        return torch.int8
+    if itype == onnx.TensorProto.UINT8:
+        return torch.uint8
+    if itype == onnx.TensorProto.COMPLEX64:
+        return torch.complex64
+    if itype == onnx.TensorProto.COMPLEX128:
+        return torch.complex128
+    raise NotImplementedError(
+        f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
+    )
+
+
+def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+    """
+    Converts a torch dtype into an onnx element type.
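+
+    A quick sanity check of the mapping below::
+
+        assert torch_dtype_to_onnx_dtype(torch.float32) == onnx.TensorProto.FLOAT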
+
+    :param to: torch dtype
+    :return: onnx type
+    """
+    import torch
+
+    if to == torch.float32:
+        return onnx.TensorProto.FLOAT
+    if to == torch.float16:
+        return onnx.TensorProto.FLOAT16
+    if to == torch.bfloat16:
+        return onnx.TensorProto.BFLOAT16
+    if to == torch.float64:
+        return onnx.TensorProto.DOUBLE
+    if to == torch.int64:
+        return onnx.TensorProto.INT64
+    if to == torch.int32:
+        return onnx.TensorProto.INT32
+    if to == torch.uint64:
+        return onnx.TensorProto.UINT64
+    if to == torch.uint32:
+        return onnx.TensorProto.UINT32
+    if to == torch.bool:
+        return onnx.TensorProto.BOOL
+    if to == torch.SymInt:
+        return onnx.TensorProto.INT64
+    if to == torch.int16:
+        return onnx.TensorProto.INT16
+    if to == torch.uint16:
+        return onnx.TensorProto.UINT16
+    if to == torch.int8:
+        return onnx.TensorProto.INT8
+    if to == torch.uint8:
+        return onnx.TensorProto.UINT8
+    if to == torch.SymFloat:
+        return onnx.TensorProto.FLOAT
+    if to == torch.complex64:
+        return onnx.TensorProto.COMPLEX64
+    if to == torch.complex128:
+        return onnx.TensorProto.COMPLEX128
+    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
 
 
 def _forward_(
@@ -144,7 +313,7 @@ def steal_forward(
         :showcode:
 
         import torch
-        from onnx_diagnostic.helpers.torch_test_helper import steal_forward
+        from onnx_diagnostic.helpers.torch_helper import steal_forward
 
         class SubModel(torch.nn.Module):
             def forward(self, x):
@@ -331,7 +500,7 @@ def dummy_llm(
     .. runpython::
         :showcode:
 
-        from onnx_diagnostic.helpers.torch_test_helper import dummy_llm
+        from onnx_diagnostic.helpers.torch_helper import dummy_llm
 
         print(dummy_llm())
@@ -656,3 +825,38 @@ def model_statistics(model: torch.nn.Module):
     )
     res.update(sizes)
     return res
+
+
+def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
+    """
+    Converts a TensorProto to a torch Tensor.
+
+    :param tensor: a TensorProto object.
+    :param base_dir: if external tensor exists, base_dir can help to find the path to it
+    :return: the converted tensor
+    """
+    assert not tensor.HasField("segment"), "Currently not supporting loading segments."
+    assert (
+        tensor.data_type != onnx.TensorProto.UNDEFINED
+    ), "The element type in the input tensor is not defined."
+    assert tensor.data_type != onnx.TensorProto.STRING, "to_tensor not implemented for strings"
+
+    tensor_dtype = tensor.data_type
+    torch_dtype = onnx_dtype_to_torch_dtype(tensor_dtype)
+    dims = tuple(tensor.dims)
+    if uses_external_data(tensor):
+        # Load raw data from external tensor if it exists
+        load_external_data_for_tensor(tensor, base_dir)
+
+    if tensor.HasField("raw_data"):
+        raw_data = tensor.raw_data
+        if sys.byteorder == "big":
+            # Convert endianness: torch has no byteswap, so swap the bytes
+            # with numpy before building the tensor.
+            np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+            raw_data = np.frombuffer(raw_data, dtype=np_dtype).byteswap().tobytes()
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return torch.frombuffer(raw_data, dtype=torch_dtype).reshape(dims)
+
+    # Other cases, it should be a small tensor. We use numpy.
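+    # Without raw_data, the values are stored in typed fields (float_data,
+    # int32_data, ...), which to_array_extended already decodes.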
+ np_tensor = to_array_extended(tensor) + return torch.from_numpy(np_tensor) diff --git a/onnx_diagnostic/torch_export_patches/patch_expressions.py b/onnx_diagnostic/torch_export_patches/patch_expressions.py index a25f0d5f..0b6b1990 100644 --- a/onnx_diagnostic/torch_export_patches/patch_expressions.py +++ b/onnx_diagnostic/torch_export_patches/patch_expressions.py @@ -1,6 +1,6 @@ from typing import Callable, Set import torch -from ..helpers.torch_test_helper import is_torchdynamo_exporting +from ..helpers.torch_helper import is_torchdynamo_exporting def make_undefined_dimension(i: int) -> torch.SymInt: diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index a27005ce..fa0facc8 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -6,7 +6,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.cache_utils import StaticCache, Cache, DynamicCache from ...ext_test_case import has_transformers -from ...helpers.torch_test_helper import is_torchdynamo_exporting +from ...helpers.torch_helper import is_torchdynamo_exporting def _patch_make_causal_mask( diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 5d1d472d..b187de3d 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -11,7 +11,7 @@ from ..helpers import max_diff, string_type, string_diff from ..helpers.helper import flatten_object from ..helpers.rt_helper import make_feeds -from ..helpers.torch_test_helper import to_any, torch_deepcopy +from ..helpers.torch_helper import to_any, torch_deepcopy from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes from ..tasks import random_input_kwargs from ..torch_export_patches import torch_export_patches
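To summarize how the pieces added by this patch fit together, here is a minimal usage sketch. It is illustrative only: `model.onnx` is a placeholder path, and the floating-point filter is an assumption since `tensor_statistics` relies on `np.isnan` and quantiles.

```python
import onnx
from onnx_diagnostic.helpers.onnx_helper import (
    iterator_initializer_constant,
    tensor_statistics,
)
from onnx_diagnostic.helpers.torch_helper import to_tensor

# "model.onnx" is a placeholder; any ONNX file works.
model = onnx.load("model.onnx")

# Walk initializers and Constant nodes, including those nested in
# functions and Loop/Scan subgraphs, as numpy arrays.
for name, value in iterator_initializer_constant(model, use_numpy=True):
    if value.dtype.kind == "f":  # tensor_statistics uses np.isnan and np.quantile
        stats = tensor_statistics(value)
        print(name, stats["stype"], stats["shape"], stats["mean"])

# to_tensor maps a TensorProto straight to torch, which also covers
# bfloat16 weights that numpy cannot represent natively.
for init in model.graph.initializer:
    t = to_tensor(init)
    print(init.name, t.dtype, tuple(t.shape))
```

This is the same traversal the new `stats` command performs before writing its report to CSV or Excel.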