Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Change Logs
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
* :pr:`310`: splits patches into multiple files
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`: improves side-by-side comparison, creates command line sbs
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs

0.8.2
+++++
Expand Down
54 changes: 54 additions & 0 deletions _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
enumerate_results,
shadowing_names,
onnx_dtype_name,
extract_subset_of_nodes,
make_submodel,
)


Expand Down Expand Up @@ -476,6 +478,58 @@ def test_onnx_dtype_name(self):
self.assertRaise(lambda: onnx_dtype_name(1000), ValueError)
self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED")

def test_extract_subset_of_nodes(self):
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
],
"dummy",
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
[
onh.from_array(
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
),
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
onh.from_array(np.array([1], dtype=np.int64), name="un"),
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)
submodel = extract_subset_of_nodes(model, "xm", cut_points={"Y", "xu2", "xm1"})
op_types = [n.op_type for n in submodel]
self.assertEqual(["Reshape", "Cast", "MatMul"], op_types)

def _type_rank_fn(name):
if name in {"Y", "xu2"}:
return TensorProto.FLOAT, 4
if name in {"xm1", "xm"}:
return TensorProto.FLOAT, 3
if name == "shape2":
return TensorProto.INT64, 1
raise AssertionError(f"unexpected name={name!r}")

new_model = make_submodel(
submodel,
ir_version=model.ir_version,
opset_imports=model.opset_import,
type_rank_fn=_type_rank_fn,
output_names=["xm"],
)
check_model(new_model)
self.check_ort(new_model)


if __name__ == "__main__":
unittest.main(verbosity=2)
54 changes: 52 additions & 2 deletions _unittests/ut_torch_onnx/test_sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord
from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord, ReplayConfiguration
from onnx_diagnostic.export.api import to_onnx


Expand All @@ -23,7 +23,7 @@ def setUpClass(cls):

def test_run_aligned_record(self):
r = RunAlignedRecord(
ep_id_node=-1,
ep_id_node=1,
onnx_id_node=-1,
ep_name="A",
onnx_name="B",
Expand Down Expand Up @@ -512,6 +512,56 @@ def forward(self, x):
self.assertEqual(onnx_op_type.count("reset"), 1)
self.clean_dump()

@hide_stdout()
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
def test_sbs_replay(self):
torch = self.torch

class Model(self.torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = torch.nn.Linear(10, 3200) # input size 10 → hidden size 32
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(3200, 1) # hidden → output
with torch.no_grad():
self.fc2.bias += 1999
self.fc1.bias += 999

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x

inputs = dict(x=self.torch.randn((5, 10), dtype=torch.float16))
ds = dict(x={0: "batch"})
model = Model()
model = model.to(torch.float16)
model(**inputs)
ep = self.torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
)
filename = self.get_dump_file("test_sbs_replay.onnx")
dump_folder = self.get_dump_folder("test_sbs_replay_linear")
to_onnx(ep, exporter="custom", filename=filename)
onx = onnx.load(filename)
results = list(
run_aligned(
ep,
onx,
kwargs=inputs,
run_cls=OnnxruntimeEvaluator,
verbose=11,
use_tensor=True,
replay_configuration=ReplayConfiguration(
dump_folder=dump_folder, selected_op_types={"Gemm"}
),
),
)
df = pandas.DataFrame(list(results))
df.to_excel(self.get_dump_file("test_sbs_replay.xlsx"))
print(df)
# self.clean_dump()


if __name__ == "__main__":
unittest.main(verbosity=2)
5 changes: 5 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def forward(self, x):
input_file = self.get_dump_file("test_h_parser_sbs.inputs.pt")
ep_file = self.get_dump_file("test_h_parser_sbs.ep")
onnx_file = self.get_dump_file("test_h_parser_sbs.model.onnx")
replay_foler = self.get_dump_folder("test_h_parser_sbs.replay")
torch.save(inputs, input_file)
to_onnx(
Model(),
Expand Down Expand Up @@ -139,6 +140,10 @@ def forward(self, x):
output,
"-m",
onnx_file,
"-t",
"Gemm",
"-f",
replay_foler,
]
)
text = st.getvalue()
Expand Down
54 changes: 48 additions & 6 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,14 @@ def get_parser_sbs() -> ArgumentParser:
- torch.export.save(ep: torch.export.ExportedProgram)
- torch.save(**inputs)
- onnx.save(...)

The Replay functionality is just a way to investigates a part of a model.
It saves torch and onnx inputs, the torch outputs, and the minimal onnx model
which shares its inputs with the exported program.
This is used to investigate the discrepancies between the torch
model (through the exported program) and its onnx conversion.
This functionality dumps everything it can to disk
so that it be replayed in a separate process.
"""
),
)
Expand Down Expand Up @@ -1222,10 +1230,33 @@ def get_parser_sbs() -> ArgumentParser:
),
)
parser.add_argument(
"--gemmlinear",
action=BooleanOptionalAction,
default=False,
help="Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side",
"-s",
"--replay-threshold",
type=float,
required=False,
default=1e6,
help="Triggers the replay if the discrepancies are higher than this value.",
)
parser.add_argument(
"-n",
"--replay-names",
required=False,
default="",
help="Triggers the replay if a result name is in this set of values (comma separated)",
)
parser.add_argument(
"-t",
"--replay-op-types",
required=False,
default="",
help="Triggers the replay if an onnx type is in this set of values (comma separated)",
)
parser.add_argument(
"-f",
"--replay-folder",
required=False,
default="replay",
help="If the replay is triggered, this defines the folder where everything is dumped.",
)

return parser
Expand All @@ -1235,7 +1266,7 @@ def _cmd_sbs(argv: List[Any]):
import pandas
import torch
from .helpers import flatten_object, max_diff, string_diff, string_type
from .torch_onnx.sbs import run_aligned
from .torch_onnx.sbs import run_aligned, ReplayConfiguration
from .reference import OnnxruntimeEvaluator

parser = get_parser_sbs()
Expand Down Expand Up @@ -1306,6 +1337,17 @@ def _size(name):
onx = onnx.load(args.onnx)
print(f"-- done in {time.perf_counter() - begin:1.1f}s")

replay_configuration = None
if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types:
replay_configuration = ReplayConfiguration(
threshold=args.replay_threshold,
selected_names=set(args.replay_names.split(",")) if args.replay_names else None,
selected_op_types=(
set(args.replay_op_types.split(",")) if args.replay_op_types else None
),
dump_folder=args.replay_folder,
)

print("-- starts side-by-side")
ratio = int(args.ratio)
data = []
Expand All @@ -1319,9 +1361,9 @@ def _size(name):
args=margs,
kwargs=mkwargs,
use_tensor=True,
gemmlinear=args.gemmlinear,
reset_names=args.reset.split(","),
exc=False,
replay_configuration=replay_configuration,
):
data.append(obs)
if (
Expand Down
104 changes: 103 additions & 1 deletion onnx_diagnostic/helpers/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
import warnings
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
import numpy.typing as npt
import onnx
Expand All @@ -15,6 +15,7 @@
GraphProto,
ModelProto,
NodeProto,
OperatorSetIdProto,
TensorProto,
ValueInfoProto,
load as onnx_load,
Expand Down Expand Up @@ -1195,3 +1196,104 @@ def shadowing_names(
existing |= not_empty
created |= not_empty
return shadow, post_shadow, created


def extract_subset_of_nodes(
model: ModelProto,
name: str,
node_index: Optional[int] = None,
cut_points: Optional[Set[str]] = None,
) -> List[NodeProto]:
"""
Extracts the minimal subgraphs which can produce the output ``name``
knowing ``cut_points``.

:param model: original model
:param name: result name
:param node_index: if the node index is known, otherwise searches for it
:param cut_points: the known results or input name otherwise
:return: minimal list of nodes
"""
if node_index is None:
for i, node in enumerate(model.graph.node):
if name in node.output:
node_index = i
break
assert (
node_index is not None
and node_index < len(model.graph.node)
and name in model.graph.node[node_index].output
), f"node_index is still empty or wrong for result {name!r}"
if cut_points is None:
cut_points = {n.name for n in model.graph.input} | {
n.name for n in model.graph.initializer
}
elif model.graph.initializer:
cut_points = cut_points | {n.name for n in model.graph.initializer}

node = model.graph.node[node_index]
selected = {node_index}
current_node_index = node_index
current_input_index = 0
intermediate = {name}
inputs = set(k for k in node.input if k)
while not (inputs <= cut_points) and current_node_index >= 0:
node = model.graph.node[current_node_index]
if current_input_index == 0:
needs = [o for o in node.output if o in intermediate and o not in cut_points]
if needs:
selected.add(current_node_index)
else:
current_node_index -= 1
continue
res = node.input[current_input_index]
if res not in cut_points:
intermediate.add(res)
current_input_index += 1
if current_input_index >= len(node.input):
current_node_index -= 1
current_input_index = 0

return [model.graph.node[i] for i in sorted(selected)]


def make_submodel(
nodes: List[NodeProto],
ir_version: int,
opset_imports: List[OperatorSetIdProto],
output_names: List[str],
type_rank_fn: Callable[[str], Tuple[int, int]],
) -> ModelProto:
"""
Creates a model with the given list of nodes.
It computes the minimum list of inputs needed for this model.
The function assumes the nodes are sorted.
It does not handle yet subgraphs.

:param nodes: list of nodes
:param ir_version: ir version
:param opset_imports: opset import
:param output_names: desired outputs
:param function: function returning the type and the rank of a result
:return: model proto
"""

def _mkv_(name, itype, irank):
return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])

not_known: Set[str] = set()
for node in nodes[::-1]:
not_known -= set(node.output)
not_known |= set(node.input)

model = oh.make_model(
oh.make_graph(
nodes,
"submodel",
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
),
ir_version=ir_version,
opset_imports=opset_imports,
)
return model
Loading
Loading