Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.4
+++++

* :pr:`338`: fixes ReplayConfiguration.dump, add a function to select a part of a model
* :pr:`337`: fixes extract_subset_of_nodes
* :pr:`336`: implements versioned onnx plugs

Expand Down
98 changes: 98 additions & 0 deletions _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
import onnx.numpy_helper as onh
from onnx import TensorProto, FunctionProto, ValueInfoProto
from onnx.checker import check_model
from onnx.external_data_helper import (
load_external_data_for_model,
_get_all_tensors,
uses_external_data,
)
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
from onnx_diagnostic.helpers.onnx_helper import (
onnx_lighten,
onnx_unlighten,
Expand All @@ -23,6 +29,7 @@
onnx_dtype_name,
extract_subset_of_nodes,
make_submodel,
select_model_inputs_outputs,
)


Expand Down Expand Up @@ -570,6 +577,97 @@ def test_extract_subset_of_nodes_bigger(self):
[n.op_type for n in nodes],
)

def _get_model_select(self):
    """Build a tiny model computing ``Z = Cast(Mul(Add(Mul(X, X), Y), W))``.

    ``Y`` and ``W`` are stored as initializers so the model can be saved
    with external data; ``X`` is the only graph input, ``Z`` (INT64) the
    only output.
    """
    x_info = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
    z_info = oh.make_tensor_value_info("Z", TensorProto.INT64, [None, None])
    # Two 4x4 float32 initializers consumed by the Add and the second Mul.
    init_y = onh.from_array(
        np.arange(16).reshape((-1, 4)).astype(np.float32), name="Y"
    )
    init_w = onh.from_array(
        (np.arange(16).reshape((-1, 4)) + 100).astype(np.float32), name="W"
    )
    nodes = [
        oh.make_node("Mul", ["X", "X"], ["X2"]),
        oh.make_node("Add", ["X2", "Y"], ["z1"]),
        oh.make_node("Mul", ["z1", "W"], ["z2"]),
        oh.make_node("Cast", ["z2"], ["Z"], to=TensorProto.INT64),
    ]
    graph = oh.make_graph(nodes, "add", [x_info], [z_info], [init_y, init_w])
    return oh.make_model(
        graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=8
    )

def test_select_model_inputs_outputs(self):
    """Checks ``select_model_inputs_outputs`` on a model saved with external data.

    The test truncates the model first at output ``X2`` then at ``z1`` and
    verifies both the in-memory proto and the saved file still evaluate
    correctly, including after reloading the external tensor data.
    """

    def enumerate_model_tensors(model):
        # Yields every tensor of the model together with a flag telling
        # whether it still points to external data.
        for tensor in _get_all_tensors(model):
            yield tensor, uses_external_data(tensor)

    model = self._get_model_select()
    root = self.get_dump_folder("test_select_model_inputs_outputs")
    name = os.path.join(root, "model_ext.onnx")
    location = os.path.basename(name) + ".data"
    # size_threshold=15 forces the 16-element initializers into the .data file.
    onnx.save(
        model, name, save_as_external_data=True, size_threshold=15, location=location
    )
    self.assertEqual(
        list(sorted(os.listdir(root))),
        ["model_ext.onnx", "model_ext.onnx.data"],
    )

    # Truncate at X2 = X * X: no initializer is needed by the sub-model.
    name2 = os.path.join(root, "sub_model_ext.onnx")
    model2 = onnx.load(name, load_external_data=False)
    new_model = select_model_inputs_outputs(model2, outputs=["X2"])
    onnx.save(new_model, name2)

    x = np.arange(16).reshape((-1, 4)).astype(np.float32)
    y = np.arange(16).reshape((-1, 4)).astype(np.float32)

    sess = ExtendedReferenceEvaluator(new_model)
    got = sess.run(None, {"X": x})[0]
    self.assertEqual((x**2).tolist(), got.tolist())

    # The saved file must evaluate the same way as the in-memory proto.
    sess = ExtendedReferenceEvaluator(name2)
    got = sess.run(None, {"X": x})[0]
    self.assertEqual((x**2).tolist(), got.tolist())

    # Truncate at z1 = X2 + Y: the sub-model keeps initializer Y,
    # which still references external data.
    name3 = os.path.join(root, "sub_model_ext_z1.onnx")
    model2 = onnx.load(name, load_external_data=False)
    new_model = select_model_inputs_outputs(model2, outputs=["z1"])
    onnx.save(new_model, name3)
    self.assertEqual(
        [
            "model_ext.onnx",
            "model_ext.onnx.data",
            "sub_model_ext.onnx",
            "sub_model_ext_z1.onnx",
        ],
        list(sorted(os.listdir(root))),
    )

    x = np.arange(16).reshape((-1, 4)).astype(np.float32)

    sess = ExtendedReferenceEvaluator(name3)
    got = sess.run(None, {"X": x})[0]
    self.assertEqual((x**2 + y).tolist(), got.tolist())

    # Exactly one tensor (Y) remains and it must still be external.
    tensors = list(enumerate_model_tensors(new_model))
    self.assertEqual(len(tensors), 1)
    self.assertIsInstance(tensors[0], tuple)
    self.assertEqual(len(tensors[0]), 2)
    self.assertTrue(tensors[0][-1])
    self.assertIsInstance(tensors[0][0], TensorProto)
    # After loading the external data back in, evaluation still matches.
    load_external_data_for_model(new_model, root)
    sess = ExtendedReferenceEvaluator(new_model)
    got = sess.run(None, {"X": x})[0]
    self.assertEqual((x**2 + y).tolist(), got.tolist())


if __name__ == "__main__":
unittest.main(verbosity=2)
Binary file not shown.
Binary file not shown.
176 changes: 176 additions & 0 deletions _unittests/ut_torch_onnx/test_discrepancies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import os
import unittest
import numpy as np
import onnx
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, has_onnxruntime
from onnx_diagnostic.reference import OnnxruntimeEvaluator
from onnx_diagnostic.helpers import max_diff, string_diff


class TestDiscrepancies(ExtTestCase):
    """Measures discrepancies between a torch reference attention and ONNX models."""

    @ignore_warnings(DeprecationWarning)
    def test_attention_opset15_in_a_loop(self):
        import torch
        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_attention import (  # noqa: E501
            patched_sdpa_attention_forward,
        )

        def qwen_sdpa_attention(
            query_states: torch.Tensor,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            cu_seqlens: torch.Tensor,
            scaling: float = 0,
            num_heads: int = 16,
        ) -> torch.Tensor:
            """Reference implementation: runs SDPA independently on each
            variable-length segment defined by cumulative lengths
            ``cu_seqlens`` and concatenates the results."""
            # Segment lengths are the differences of consecutive cumulative sums.
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2)
                for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                patched_sdpa_attention_forward(
                    None,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=scaling,
                    dropout=0.0,
                    is_causal=False,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)
            return attn_output

        for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]:
            if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"):
                # Attention-24 model requires onnxruntime>=1.24, not available here.
                continue
            with self.subTest(model=model_name):
                model = onnx.load(os.path.join(os.path.dirname(__file__), "data", model_name))
                sess = self.check_ort(model)

                # Random default inputs; cat_2 holds cumulative sequence lengths.
                feeds = dict(
                    c_lifted_tensor_0=np.array([0], dtype=np.int64),
                    cat_2=np.array(
                        [
                            0,
                            64,
                            128,
                            192,
                            256,
                            304,
                            368,
                            432,
                            496,
                            560,
                            608,
                            672,
                            736,
                            800,
                            864,
                            912,
                            976,
                            1040,
                            1104,
                            1168,
                            1216,
                            1232,
                            1248,
                            1264,
                            1280,
                            1292,
                        ],
                        dtype=np.int64,
                    ),
                    unsqueeze_4=np.random.randn(1, 16, 1292, 80).astype(np.float32),
                    unsqueeze_5=np.random.randn(1, 16, 1292, 80).astype(np.float32),
                    unsqueeze_6=np.random.randn(1, 16, 1292, 80).astype(np.float32),
                )

                # Optional replay path: if a previous run dumped real inputs,
                # use them instead of the random ones and print diagnostics.
                dummy_inputs = os.path.join(
                    os.path.dirname(__file__),
                    "..",
                    "..",
                    "dump_test",
                    "replay",
                    "qwen_sdpa_attention_loopmha",
                    "onnx_inputs.pt",
                )
                if os.path.exists(dummy_inputs):
                    print("-- use dummy inputs")

                    feeds1 = torch.load(dummy_inputs)
                    res1 = qwen_sdpa_attention(
                        feeds1["unsqueeze_4"],
                        feeds1["unsqueeze_5"],
                        feeds1["unsqueeze_6"],
                        feeds1["cat_2"],
                        scaling=0.11180339753627777,
                        num_heads=16,
                    )
                    feeds1o = {k: v.detach().cpu().numpy() for k, v in feeds1.items()}
                    reso1 = sess.run(None, feeds1o)[0]
                    # Compare against the torch-side dump of the same inputs.
                    dummy_inputs2 = dummy_inputs.replace("onnx_inputs", "torch_inputs")
                    assert dummy_inputs != dummy_inputs2
                    feeds2 = torch.load(dummy_inputs2)
                    res2 = qwen_sdpa_attention(
                        feeds2["unsqueeze_4"],
                        feeds2["unsqueeze_5"],
                        feeds2["unsqueeze_6"],
                        feeds2["cat_2"],
                        scaling=0.11180339753627777,
                        num_heads=16,
                    )
                    feeds2o = {k: v.detach().cpu().numpy() for k, v in feeds2.items()}
                    reso2 = sess.run(None, feeds2o)[0]
                    diff = max_diff(res1, res2, hist=[0.1])
                    print(f"-- diff torch-onnx: {string_diff(diff)}")
                    # NOTE(review): the o1/o2 labels below look swapped relative to
                    # the (res2, reso2) / (res1, reso1) pairs — confirm intent.
                    diff = max_diff(res2, reso2, hist=[0.1])
                    print(f"-- diff torch-onnxo1: {string_diff(diff)}")
                    diff = max_diff(res1, reso1, hist=[0.1])
                    print(f"-- diff torch-onnxo2: {string_diff(diff)}")
                    if diff["abs"] > 0.1:
                        # Large discrepancy: print per-input differences to locate it.
                        for k in feeds1:
                            print(
                                f"-- {k}: "
                                f"{string_diff(max_diff(feeds1[k], feeds2[k], hist=[0.1]))}"
                            )

                    # Replace the random feeds with the dumped ones for the
                    # assertions below.
                    feeds = {
                        k: v.detach().cpu().numpy()
                        for k, v in torch.load(dummy_inputs).items()
                    }

                    for k, v in feeds.items():
                        print(
                            f"-- {k}: "
                            f"{self.string_type(v, with_shape=True, with_min_max=True)}"
                        )

                # feeds["cat_2"] = np.array([0, 1292], dtype=np.int64)
                got = sess.run(None, feeds)
                self.assertEqual(len(got), 1)
                self.assertEqual((1, 1292, 16, 80), got[0].shape)
                expected = qwen_sdpa_attention(
                    torch.from_numpy(feeds["unsqueeze_4"]),
                    torch.from_numpy(feeds["unsqueeze_5"]),
                    torch.from_numpy(feeds["unsqueeze_6"]),
                    torch.from_numpy(feeds["cat_2"]),
                    scaling=0.11180339753627777,
                    num_heads=16,
                )
                self.assertEqualArray(expected, got[0], atol=1e-5)

                # Same comparison through OnnxruntimeEvaluator with torch tensors.
                tfeeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
                ev = OnnxruntimeEvaluator(model)
                got2 = ev.run(None, tfeeds)
                self.assertEqual(len(got2), 1)
                self.assertEqualArray(got[0], got2[0], atol=1e-5)


if __name__ == "__main__":
unittest.main(verbosity=2)
15 changes: 15 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,20 @@ def get_parser_sbs() -> ArgumentParser:
default="replay",
help="If the replay is triggered, this defines the folder where everything is dumped.",
)
parser.add_argument(
"-p",
"--replay-prefix-model",
action=BooleanOptionalAction,
default=False,
help=textwrap.dedent(
"""
There are two ways to recompute an intermediate output, the first one is to "
produce the minimal model between torch and onnx.
The second one is to dump onnx models from the inputs
to the considered intermediate results. This enables the second one.
"""
),
)

return parser

Expand Down Expand Up @@ -1431,6 +1445,7 @@ def _size(name):
set(args.replay_op_types.split(",")) if args.replay_op_types else None
),
dump_folder=args.replay_folder,
dump_prefix_model=args.replay_prefix_model,
)

print("-- starts side-by-side")
Expand Down
15 changes: 10 additions & 5 deletions onnx_diagnostic/export/control_flow_research.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ def _loop_for_op_wrapper(*args, **kwargs):

from torch._higher_order_ops.utils import setup_compilation_env

with setup_compilation_env() as backend:
return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
n_iter, body_fn, operands
)
with setup_compilation_env() as _backend:
return _loop_for_op_wrapper(n_iter, body_fn, *operands)
# return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
# n_iter, body_fn, operands
# )


def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
Expand Down Expand Up @@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands):
), f"Dense implementation operands must be a list of tensors and ints {operands}"
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
return _loop_for_onnx_fn(body_fn, n_iter, None, operands)


@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, n_iter, body_fn, operands):
    """Proxy-mode implementation: delegates to ``trace_loop_for`` so the
    loop is recorded under the active ``ProxyTorchDispatchMode``."""
    return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)


simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
Loading
Loading