Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _doc/api/helpers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ onnx_diagnostic.helpers
onnx_helper
ort_session
rt_helper
torch_test_helper
torch_helper

.. autofunction:: onnx_diagnostic.helpers.flatten_object

Expand Down
7 changes: 7 additions & 0 deletions _doc/api/helpers/torch_helper.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.helpers.torch_helper
====================================

.. automodule:: onnx_diagnostic.helpers.torch_helper
:members:
:no-undoc-members:
7 changes: 0 additions & 7 deletions _doc/api/helpers/torch_test_helper.rst

This file was deleted.

4 changes: 2 additions & 2 deletions _doc/examples/plot_export_tiny_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm


Expand Down Expand Up @@ -77,7 +77,7 @@ def _forward_(*args, _f=None, **kwargs):
model.forward = keep_model_forward

# %%
# Another syntax with :func:`onnx_diagnostic.helpers.torch_test_helper.steal_forward`.
# Another syntax with :func:`onnx_diagnostic.helpers.torch_helper.steal_forward`.

with steal_forward(model):
model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_export/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
requires_onnxscript,
)
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting

try:
from experimental_experiment.torch_interpreter import to_onnx
Expand Down
6 changes: 4 additions & 2 deletions _unittests/ut_helpers/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@
get_onnx_signature,
type_info,
onnx_dtype_name,
onnx_dtype_to_torch_dtype,
onnx_dtype_to_np_dtype,
np_dtype_to_tensor_dtype,
torch_dtype_to_onnx_dtype,
from_array_extended,
to_array_extended,
convert_endian,
from_array_ml_dtypes,
dtype_to_tensor_dtype,
)
from onnx_diagnostic.helpers.torch_helper import (
onnx_dtype_to_torch_dtype,
torch_dtype_to_onnx_dtype,
)
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config

Expand Down
133 changes: 131 additions & 2 deletions _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import unittest
from typing import Any, Dict, List
import numpy as np
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx import TensorProto
from onnx import TensorProto, FunctionProto, ValueInfoProto
from onnx.checker import check_model
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.onnx_helper import (
onnx_lighten,
onnx_unlighten,
onnx_find,
_validate_function,
check_model_ort,
iterator_initializer_constant,
from_array_extended,
tensor_statistics,
)


TFLOAT = TensorProto.FLOAT


class TestOnnxTools(ExtTestCase):
class TestOnnxHelper(ExtTestCase):

def _get_model(self):
model = oh.make_model(
Expand Down Expand Up @@ -122,6 +127,130 @@ def test_check_model_ort(self):
)
check_model_ort(model)

def test_iterate_init(self):
itype = TensorProto.FLOAT
cst = np.arange(6).astype(np.float32)
model = oh.make_model(
oh.make_graph(
[
oh.make_node("IsNaN", ["x"], ["xi"]),
oh.make_node("IsNaN", ["y"], ["yi"]),
oh.make_node("Cast", ["xi"], ["xii"], to=TensorProto.INT64),
oh.make_node("Cast", ["yi"], ["yii"], to=TensorProto.INT64),
oh.make_node("Add", ["xii", "yii"], ["gggg"]),
oh.make_node("Cast", ["gggg"], ["final"], to=itype),
],
"dummy",
[oh.make_tensor_value_info("x", itype, [None, None])],
[oh.make_tensor_value_info("final", itype, [None, None])],
[from_array_extended(cst, name="y")],
),
opset_imports=[oh.make_opsetid("", 20)],
ir_version=10,
)
li = list(iterator_initializer_constant(model))
self.assertEqual(len(li), 1)
self.assertEqual(li[0][0], "y")
self.assertEqualArray(li[0][1], cst)
li = list(iterator_initializer_constant(model, use_numpy=False))
self.assertEqual(len(li), 1)
self.assertEqual(li[0][0], "y")
self.assertEqualArray(li[0][1], cst)
self.assertIsInstance(li[0][1], torch.Tensor)

def _get_cdist_implementation(
self,
node_inputs: List[str],
node_outputs: List[str],
opsets: Dict[str, int],
**kwargs: Any,
) -> FunctionProto:
"""
Returns the CDist implementation as a function.
"""
assert len(node_inputs) == 2
assert len(node_outputs) == 1
assert opsets
assert "" in opsets
assert set(kwargs) == {"metric"}, f"kwargs={kwargs}"
metric = kwargs["metric"]
assert metric in ("euclidean", "sqeuclidean")
# subgraph
nodes = [
oh.make_node("Sub", ["next", "next_in"], ["diff"]),
oh.make_node("Constant", [], ["axis"], value_ints=[1]),
oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0),
oh.make_node("Identity", ["next_in"], ["next_out"]),
]

def make_value(name):
value = ValueInfoProto()
value.name = name
return value

graph = oh.make_graph(
nodes,
"loop",
[make_value("next_in"), make_value("next")],
[make_value("next_out"), make_value("scan_out")],
)

scan = oh.make_node(
"Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph
)
final = (
oh.make_node("Sqrt", ["zout"], ["z"])
if metric == "euclidean"
else oh.make_node("Identity", ["zout"], ["z"])
)
return oh.make_function(
"npx",
f"CDist_{metric}",
["xa", "xb"],
["z"],
[scan, final],
[oh.make_opsetid("", opsets[""])],
)

def test_iterate_function(self):
itype = TensorProto.FLOAT
proto = self._get_cdist_implementation(
["X", "Y"], ["Z"], opsets={"": 18}, metric="euclidean"
)
model = oh.make_model(
oh.make_graph(
[
oh.make_node(proto.name, ["X", "Y"], ["Z"]),
],
"dummy",
[
oh.make_tensor_value_info("X", itype, [None, None]),
oh.make_tensor_value_info("Y", itype, [None, None]),
],
[oh.make_tensor_value_info("final", itype, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
model.functions.append(proto)
li = list(iterator_initializer_constant(model))
self.assertEqual(len(li), 1)
self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis")
self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64))
li = list(iterator_initializer_constant(model, use_numpy=False))
self.assertEqual(len(li), 1)
self.assertEqual(li[0][0], "CDist_euclideanCDist_euclidean.axis")
self.assertEqualArray(li[0][1], np.array([1], dtype=np.int64))
self.assertIsInstance(li[0][1], torch.Tensor)

def test_statistics(self):
rnd = np.random.rand(40, 50).astype(np.float16)
stat = tensor_statistics(rnd)
self.assertEqual(stat["stype"], "FLOAT16")
rnd = np.random.rand(40, 50).astype(np.float32)
stat = tensor_statistics(rnd)
self.assertEqual(stat["stype"], "FLOAT")


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 2 additions & 5 deletions _unittests/ut_helpers/test_ort_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
requires_onnxruntime_training,
requires_cuda,
)
from onnx_diagnostic.helpers.onnx_helper import (
from_array_extended,
onnx_dtype_to_np_dtype,
onnx_dtype_to_torch_dtype,
)
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
from onnx_diagnostic.helpers.ort_session import (
InferenceSessionForNumpy,
InferenceSessionForTorch,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import unittest
import numpy as np
import ml_dtypes
import onnx
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers import max_diff, string_type
from onnx_diagnostic.helpers.torch_test_helper import (
from onnx_diagnostic.helpers.torch_helper import (
dummy_llm,
to_numpy,
is_torchdynamo_exporting,
Expand All @@ -24,6 +25,8 @@
make_sliding_window_cache,
)
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended
from onnx_diagnostic.helpers.torch_helper import to_tensor

TFLOAT = onnx.TensorProto.FLOAT

Expand Down Expand Up @@ -205,7 +208,7 @@ def forward(self, x, y):
else:
print("output", k, v)
print(string_type(restored, with_shape=True))
l1, l2 = 183, 192
l1, l2 = 186, 195
self.assertEqual(
[
(f"-Model-{l2}", 0, "I"),
Expand Down Expand Up @@ -344,6 +347,35 @@ def forward(self, x, y=None):
stat,
)

def test_to_tensor(self):
for dtype in [
np.int8,
np.uint8,
np.int16,
np.uint16,
np.int32,
np.uint32,
np.int64,
np.uint64,
np.float16,
np.float32,
np.float64,
]:
with self.subTest(dtype=dtype):
a = np.random.rand(4, 5).astype(dtype)
proto = from_array_extended(a)
b = to_array_extended(proto)
self.assertEqualArray(a, b)
c = to_tensor(proto)
self.assertEqualArray(a, c)

for dtype in [torch.bfloat16]:
with self.subTest(dtype=dtype):
a = torch.rand((4, 5), dtype=dtype)
proto = from_array_extended(a)
c = to_tensor(proto)
self.assertEqualArray(a, c)


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 2 additions & 5 deletions _unittests/ut_reference/test_ort_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
ignore_warnings,
requires_cuda,
)
from onnx_diagnostic.helpers.onnx_helper import (
from_array_extended,
onnx_dtype_to_torch_dtype,
onnx_dtype_to_np_dtype,
)
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, onnx_dtype_to_np_dtype
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
from onnx_diagnostic.helpers.ort_session import _InferenceSession

Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/try_tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
from onnx_diagnostic.helpers.torch_helper import steal_forward


class TestHuggingFaceHubModel(ExtTestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
patched_selector,
patched_float_arange,
)
from onnx_diagnostic.helpers.torch_test_helper import fake_torchdynamo_exporting
from onnx_diagnostic.helpers.torch_helper import fake_torchdynamo_exporting


class TestOnnxExportErrors(ExtTestCase):
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_torch_export_patches/test_patch_loops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.helpers.torch_test_helper import (
from onnx_diagnostic.helpers.torch_helper import (
is_torchdynamo_exporting,
fake_torchdynamo_exporting,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
torch_export_patches,
)
from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy


class TestPatchSerialization(ExtTestCase):
Expand Down
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
get_parser_find,
get_parser_lighten,
get_parser_print,
get_parser_stats,
get_parser_unlighten,
get_parser_validate,
)
Expand Down Expand Up @@ -63,6 +64,13 @@ def test_parser_validate(self):
text = st.getvalue()
self.assertIn("mid", text)

def test_parser_stats(self):
st = StringIO()
with redirect_stdout(st):
get_parser_stats().print_help()
text = st.getvalue()
self.assertIn("input", text)


if __name__ == "__main__":
unittest.main(verbosity=2)
Loading
Loading