diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c40ab2df..de947320 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -106,15 +106,16 @@ jobs:
           pip install torch==${{ matrix.torch }} torchvision torchaudio
         fi
 
-    - name: Cache pip
-      if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }}
-      uses: actions/cache@v4
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
-        restore-keys: |
-          ${{ runner.os }}-pip-
-          ${{ runner.os }}-
+    # too slow
+    #- name: Cache pip
+    #  if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }}
+    #  uses: actions/cache@v4
+    #  with:
+    #    path: ~/.cache/pip
+    #    key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
+    #    restore-keys: |
+    #      ${{ runner.os }}-pip-
+    #      ${{ runner.os }}-
 
     - name: pip freeze
       run: python -m pip freeze
diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index ee0b1a3e..30fd0e48 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.8.4
 +++++
 
+* :pr:`337`: fixes extract_subset_of_nodes
 * :pr:`336`: implements versioned onnx plugs
 
 0.8.3
diff --git a/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx b/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx
new file mode 100644
index 00000000..86d72d01
Binary files /dev/null and b/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx differ
diff --git a/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx.data b/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx.data
new file mode 100644
index 00000000..85375546
Binary files /dev/null and b/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx.data differ
diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py
index 69aff474..9a9242e8 100644
--- a/_unittests/ut_helpers/test_onnx_helper.py
+++ b/_unittests/ut_helpers/test_onnx_helper.py
@@ -1,6 +1,8 @@
+import os
 import unittest
 from typing import Any, Dict, List
 import numpy as np
+import onnx
 import onnx.helper as oh
 import onnx.numpy_helper as onh
 from onnx import TensorProto, FunctionProto, ValueInfoProto
@@ -475,7 +477,7 @@ def _mkv_(name):
 
     def test_onnx_dtype_name(self):
         for k in dir(TensorProto):
-            if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL"}:
+            if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
                 self.assertEqual(k, onnx_dtype_name(getattr(TensorProto, k)))
         self.assertRaise(lambda: onnx_dtype_name(1000), ValueError)
         self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED")
@@ -532,6 +534,42 @@ def _type_rank_fn(name):
         check_model(new_model)
         self.check_ort(new_model)
 
+    def test_extract_subset_of_nodes_bigger(self):
+        model = onnx.load(
+            os.path.join(
+                os.path.dirname(__file__), "data", "test_sbs_mha_split_every_piece.onnx"
+            )
+        )
+        nodes = extract_subset_of_nodes(
+            model=model,
+            name="scaled_dot_product_attention",
+            node_index=16,
+            cut_points={
+                "linear",
+                "linear_1",
+                "linear_2",
+                "output_0",
+                "scaled_dot_product_attention",
+                "transpose_2",
+                "view_2",
+                "x",
+            },
+        )
+        self.assertEqual(
+            [
+                "Mul",
+                "Reshape",
+                "Transpose",
+                "Mul",
+                "Reshape",
+                "Transpose",
+                "FusedMatMul",
+                "Softmax",
+                "MatMul",
+            ],
+            [n.op_type for n in nodes],
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py
index c6381769..450ce6d6 100644
--- a/_unittests/ut_torch_onnx/test_sbs.py
+++ b/_unittests/ut_torch_onnx/test_sbs.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 import pandas
 import onnx
@@ -777,6 +778,81 @@ def forward(self, query, key, value, seq_lens):
         df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
         # self.clean_dump()
 
+    @hide_stdout()
+    @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
+    def test_sbs_mha_split_every_piece(self):
+        torch = self.torch
+
+        class Model(self.torch.nn.Module):
+            def __init__(self, embed_dim: int, num_heads: int):
+                super(Model, self).__init__()
+                self.embed_dim = embed_dim
+                self.num_heads = num_heads
+                self.head_dim = embed_dim // num_heads
+
+                assert embed_dim % num_heads == 0, (
+                    f"embed_dim % num_heads != 0 -> "
+                    f"{embed_dim} % {num_heads} = {embed_dim % num_heads}"
+                )
+
+                self.W_q = torch.nn.Linear(embed_dim, embed_dim)
+                self.W_k = torch.nn.Linear(embed_dim, embed_dim)
+                self.W_v = torch.nn.Linear(embed_dim, embed_dim)
+
+            def split_heads(self, t, seq_len):
+                return t.view(t.shape[0], seq_len, self.num_heads, self.head_dim).transpose(
+                    1, 2
+                )
+
+            def forward(self, x):
+                q = self.split_heads(self.W_q(x), x.shape[1])
+                k = self.split_heads(self.W_k(x), x.shape[1])
+                v = self.split_heads(self.W_v(x), x.shape[1])
+                return (
+                    torch.nn.functional.scaled_dot_product_attention(q, k, v)
+                    .transpose(1, 2)
+                    .reshape(x.shape[0], x.shape[1], self.embed_dim)
+                )
+
+        embed_dim = 16
+        num_heads = 4
+        seq_len = 10
+        batch_size = 2
+        inputs = dict(x=torch.randn(batch_size, seq_len, embed_dim))
+        model = Model(embed_dim, num_heads)
+        model(**inputs)
+        ds = dict(x={0: "batch", 1: "seqlen"})
+
+        ep = self.torch.export.export(
+            model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+        )
+        self.dump_text("test_sbs_mha_split_every_piece.ep", str(ep))
+        filename = self.get_dump_file("test_sbs_mha_split_every_piece.onnx")
+        to_onnx(ep, exporter="custom", filename=filename)
+        replay = self.get_dump_folder("test_sbs_mha_split_every_piece_replay")
+        onx = onnx.load(filename)
+        results = list(
+            run_aligned(
+                ep,
+                onx,
+                kwargs=inputs,
+                run_cls=OnnxruntimeEvaluator,
+                verbose=11,
+                use_tensor=True,
+                run_onnx_with_torch_inputs=True,
+                replay_configuration=ReplayConfiguration(
+                    dump_folder=replay, selected_op_types={"MatMul"}, threshold=2**20
+                ),
+            ),
+        )
+        df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
+        df.to_excel(self.get_dump_file("test_sbs_mha_split_every_piece.xlsx"))
+        max_abs = df["err_abs"].max()
+        self.assertLess(max_abs, 1e-5)
+        # self.clean_dump()
+        subonnx = onnx.load(os.path.join(replay, "scaled_dot_product_attention", "model.onnx"))
+        self.assertEqual(len(subonnx.graph.input), 3)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py
index 6c19b409..0da11c02 100644
--- a/onnx_diagnostic/ext_test_case.py
+++ b/onnx_diagnostic/ext_test_case.py
@@ -845,6 +845,13 @@ def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
             f.write(proto.SerializeToString())
         return fullname
 
+    def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
+        """Dumps text in a file."""
+        fullname = self.get_dump_file(name, folder=folder)
+        with open(fullname, "w") as f:
+            f.write(text)
+        return fullname
+
     def assertExists(self, name):
         """Checks the existing of a file."""
         if not os.path.exists(name):
diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py
index d12bbd95..85713459 100644
--- a/onnx_diagnostic/helpers/onnx_helper.py
+++ b/onnx_diagnostic/helpers/onnx_helper.py
@@ -332,7 +332,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
         print(onnx_dtype_name(7))
     """
     for k in dir(TensorProto):
-        if k.upper() == k and k != "EXTERNAL":
+        if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
             v = getattr(TensorProto, k)
             if v == itype:
                 return k
@@ -1219,11 +1219,14 @@ def extract_subset_of_nodes(
             if name in node.output:
                 node_index = i
                 break
-    assert (
-        node_index is not None
-        and node_index < len(model.graph.node)
-        and name in model.graph.node[node_index].output
-    ), f"node_index is still empty or wrong for result {name!r}"
+    assert node_index is not None and node_index < len(model.graph.node), (
+        f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
+        f"is still empty or wrong for result {name!r}"
+    )
+    assert name in model.graph.node[node_index].output, (
+        f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
+        f"node={pretty_onnx(model.graph.node[node_index])}"
+    )
     if cut_points is None:
         cut_points = {n.name for n in model.graph.input} | {
             n.name for n in model.graph.initializer
@@ -1236,16 +1239,26 @@
     current_node_index = node_index
     current_input_index = 0
     intermediate = {name}
+    cut_points -= {name}
     inputs = set(k for k in node.input if k)
     while not (inputs <= cut_points) and current_node_index >= 0:
         node = model.graph.node[current_node_index]
-        if current_input_index == 0:
+        if current_input_index == 0 or not node.input:
             needs = [o for o in node.output if o in intermediate and o not in cut_points]
             if needs:
                 selected.add(current_node_index)
+                if not node.input:
+                    current_node_index -= 1
+                    current_input_index = 0
+                    continue
             else:
                 current_node_index -= 1
+                current_input_index = 0
                 continue
+        assert current_input_index < len(node.input), (
+            f"current_input_index={current_input_index} but node.input={node.input}, "
+            f"node={pretty_onnx(node)}"
+        )
         res = node.input[current_input_index]
         if res not in cut_points:
             intermediate.add(res)
@@ -1290,8 +1303,8 @@ def _mkv_(name, itype, irank):
         oh.make_graph(
             nodes,
             "submodel",
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
         ),
         ir_version=ir_version,
         opset_imports=opset_imports,
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 7cac5e79..4bdfed0a 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -381,7 +381,8 @@ def _preparation_with_fx_graph(
     assert len(torch_input_names) < len(onx.graph.input), (
         f"torch_input_names={torch_input_names!r}, "
         f"onnx_input_names={[n.name for n in onx.graph.input]}, "
-        f"node.name={node.name!r} cannot be an input"
+        f"node.name={node.name!r} cannot be an input, "
+        f"placeholders_to_state_dict={sorted(placeholders_to_state_dict)}"
     )
     assert node.name not in skip_mapping_torch_onnx, (
         f"{node.name!r} is ambiguous, cannot be mapped due to "
@@ -772,9 +773,9 @@ def forward(self, x):
     # preparation with ep.graph.nodes
     ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
     placeholders_to_state_dict = {
-        **{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
-        **{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
-        **{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
+        **{f"p_{name.replace('.', '_').lower()}": name for name in ep.state_dict},
+        **{f"b_{name.replace('.', '_').lower()}": name for name, _ in ep.named_buffers()},
+        **{f"c_{name.replace('.', '_').lower()}": name for name in ep.tensor_constants},
     }
     skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
     placeholders = {}
diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py
index 234e424d..2b7255d2 100644
--- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py
+++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py
@@ -243,7 +243,17 @@ def dump(
         :return: the folder created to dump everything
         """
         if verbose:
-            print(f"[ReplayConfiguration.dump] extract subset of node for {name!r}")
+            print(
+                f"[ReplayConfiguration.dump] extract subset of nodes for "
+                f"{name!r} (onnx_id_node={onnx_id_node})"
+            )
+        if verbose >= 10:
+            print(f"[ReplayConfiguration.dump] onnx_results={sorted(onnx_results)}")
+            print(f"[ReplayConfiguration.dump] torch_results={sorted(torch_results)}")
+            print(
+                f"[ReplayConfiguration.dump] onnx_name_to_ep_name="
+                f"{sorted(onnx_name_to_ep_name)}"
+            )
         nodes = extract_subset_of_nodes(
             model=model,
             name=name,
@@ -253,7 +263,8 @@
         if not nodes:
             if verbose:
                 print(
-                    f"[ReplayConfiguration.dump] could not extract subset of node for {name!r}"
+                    f"[ReplayConfiguration.dump] could not extract subset of "
+                    f"nodes for {name!r}"
                 )
             return None
         if verbose:
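
Reviewer note: a minimal usage sketch of the repaired helper, mirroring the call
made in test_extract_subset_of_nodes_bigger above. The model path and the result
names ("attn_out", "q", "k", "v") are hypothetical placeholders; only the keyword
arguments and the cut_points fallback come from this patch.

    import onnx
    from onnx_diagnostic.helpers.onnx_helper import extract_subset_of_nodes

    model = onnx.load("model.onnx")  # hypothetical model file
    # Walks the graph backwards from the node producing `name` and collects the
    # subset of nodes needed to recompute it, stopping at the given cut points
    # (with cut_points left unset, the graph inputs and initializers are used).
    nodes = extract_subset_of_nodes(
        model=model,
        name="attn_out",             # intermediate result to recompute
        cut_points={"q", "k", "v"},  # upstream results treated as inputs
    )  # node_index=... may also be passed to pin the producing node, as in the test
    print([n.op_type for n in nodes])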