diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 7eee0883..f8c85340 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -8,7 +8,7 @@ Change Logs * :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime * :pr:`310`: splits patches into multiple files * :pr:`308`: add option --save_ep to dump the exported program as well as torch input -* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`: improves side-by-side comparison, creates command line sbs +* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs 0.8.2 +++++ diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index bd7b4c81..a26b85dd 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -19,6 +19,8 @@ enumerate_results, shadowing_names, onnx_dtype_name, + extract_subset_of_nodes, + make_submodel, ) @@ -476,6 +478,58 @@ def test_onnx_dtype_name(self): self.assertRaise(lambda: onnx_dtype_name(1000), ValueError) self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED") + def test_extract_subset_of_nodes(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + submodel = extract_subset_of_nodes(model, "xm", cut_points={"Y", "xu2", "xm1"}) + op_types = [n.op_type for n in submodel] + self.assertEqual(["Reshape", "Cast", "MatMul"], op_types) + + def _type_rank_fn(name): + if name in {"Y", "xu2"}: + return TensorProto.FLOAT, 4 + if name in {"xm1", "xm"}: + return TensorProto.FLOAT, 3 + if name == "shape2": + return TensorProto.INT64, 1 + raise AssertionError(f"unexpected name={name!r}") + + new_model = make_submodel( + submodel, + ir_version=model.ir_version, + opset_imports=model.opset_import, + type_rank_fn=_type_rank_fn, + output_names=["xm"], + ) + check_model(new_model) + self.check_ort(new_model) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index 3588ab19..220e0f9d 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ b/_unittests/ut_torch_onnx/test_sbs.py @@ -10,7 +10,7 @@ ) from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str -from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord +from onnx_diagnostic.torch_onnx.sbs import run_aligned, RunAlignedRecord, ReplayConfiguration from onnx_diagnostic.export.api import to_onnx @@ 
-23,7 +23,7 @@ def setUpClass(cls):

     def test_run_aligned_record(self):
         r = RunAlignedRecord(
-            ep_id_node=-1,
+            ep_id_node=1,
             onnx_id_node=-1,
             ep_name="A",
             onnx_name="B",
@@ -512,6 +512,56 @@ def forward(self, x):
         self.assertEqual(onnx_op_type.count("reset"), 1)
         self.clean_dump()

+    @hide_stdout()
+    @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
+    def test_sbs_replay(self):
+        torch = self.torch
+
+        class Model(self.torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.fc1 = torch.nn.Linear(10, 3200)  # input size 10 → hidden size 3200
+                self.relu = torch.nn.ReLU()
+                self.fc2 = torch.nn.Linear(3200, 1)  # hidden → output
+                with torch.no_grad():
+                    self.fc2.bias += 1999
+                    self.fc1.bias += 999
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.fc2(x)
+                return x
+
+        inputs = dict(x=self.torch.randn((5, 10), dtype=torch.float16))
+        ds = dict(x={0: "batch"})
+        model = Model()
+        model = model.to(torch.float16)
+        model(**inputs)
+        ep = self.torch.export.export(
+            model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+        )
+        filename = self.get_dump_file("test_sbs_replay.onnx")
+        dump_folder = self.get_dump_folder("test_sbs_replay_linear")
+        to_onnx(ep, exporter="custom", filename=filename)
+        onx = onnx.load(filename)
+        results = list(
+            run_aligned(
+                ep,
+                onx,
+                kwargs=inputs,
+                run_cls=OnnxruntimeEvaluator,
+                verbose=11,
+                use_tensor=True,
+                replay_configuration=ReplayConfiguration(
+                    dump_folder=dump_folder, selected_op_types={"Gemm"}
+                ),
+            ),
+        )
+        df = pandas.DataFrame(list(results))
+        df.to_excel(self.get_dump_file("test_sbs_replay.xlsx"))
+        print(df)
+        # self.clean_dump()
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py
index 66e7f964..f9ca7c5c 100644
--- a/_unittests/ut_xrun_doc/test_command_lines_exe.py
+++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py
@@ -112,6 +112,7 @@ def forward(self, x):
         input_file = self.get_dump_file("test_h_parser_sbs.inputs.pt")
         ep_file = self.get_dump_file("test_h_parser_sbs.ep")
         onnx_file = self.get_dump_file("test_h_parser_sbs.model.onnx")
+        replay_folder = self.get_dump_folder("test_h_parser_sbs.replay")
         torch.save(inputs, input_file)
         to_onnx(
             Model(),
@@ -139,6 +140,10 @@ def forward(self, x):
                 output,
                 "-m",
                 onnx_file,
+                "-t",
+                "Gemm",
+                "-f",
+                replay_folder,
             ]
         )
         text = st.getvalue()
diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index 54b8396b..c67fc770 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -1140,6 +1140,14 @@ def get_parser_sbs() -> ArgumentParser:
     - torch.export.save(ep: torch.export.ExportedProgram)
     - torch.save(**inputs)
     - onnx.save(...)
+
+    The replay functionality is a way to investigate a part of a model.
+    It saves the torch and onnx inputs, the torch outputs, and the minimal onnx model
+    which shares its inputs with the exported program.
+    This is used to investigate the discrepancies between the torch
+    model (through the exported program) and its onnx conversion.
+    This functionality dumps everything it can to disk
+    so that it can be replayed in a separate process.
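+
+    For example, assuming the command is exposed through the package entry point
+    (invocation name assumed, other required arguments elided), the replay can be
+    triggered for every Gemm node with::
+
+        python -m onnx_diagnostic sbs ... -m model.onnx -t Gemm -f replay_folder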
""" ), ) @@ -1222,10 +1230,33 @@ def get_parser_sbs() -> ArgumentParser: ), ) parser.add_argument( - "--gemmlinear", - action=BooleanOptionalAction, - default=False, - help="Replaces Gemm(A,X.T,B) by torch...linear(A,X,B) on onnx side", + "-s", + "--replay-threshold", + type=float, + required=False, + default=1e6, + help="Triggers the replay if the discrepancies are higher than this value.", + ) + parser.add_argument( + "-n", + "--replay-names", + required=False, + default="", + help="Triggers the replay if a result name is in this set of values (comma separated)", + ) + parser.add_argument( + "-t", + "--replay-op-types", + required=False, + default="", + help="Triggers the replay if an onnx type is in this set of values (comma separated)", + ) + parser.add_argument( + "-f", + "--replay-folder", + required=False, + default="replay", + help="If the replay is triggered, this defines the folder where everything is dumped.", ) return parser @@ -1235,7 +1266,7 @@ def _cmd_sbs(argv: List[Any]): import pandas import torch from .helpers import flatten_object, max_diff, string_diff, string_type - from .torch_onnx.sbs import run_aligned + from .torch_onnx.sbs import run_aligned, ReplayConfiguration from .reference import OnnxruntimeEvaluator parser = get_parser_sbs() @@ -1306,6 +1337,17 @@ def _size(name): onx = onnx.load(args.onnx) print(f"-- done in {time.perf_counter() - begin:1.1f}s") + replay_configuration = None + if args.replay_threshold < 1e6 or args.replay_names or args.replay_op_types: + replay_configuration = ReplayConfiguration( + threshold=args.replay_threshold, + selected_names=set(args.replay_names.split(",")) if args.replay_names else None, + selected_op_types=( + set(args.replay_op_types.split(",")) if args.replay_op_types else None + ), + dump_folder=args.replay_folder, + ) + print("-- starts side-by-side") ratio = int(args.ratio) data = [] @@ -1319,9 +1361,9 @@ def _size(name): args=margs, kwargs=mkwargs, use_tensor=True, - gemmlinear=args.gemmlinear, reset_names=args.reset.split(","), exc=False, + replay_configuration=replay_configuration, ): data.append(obs) if ( diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index d0e0aaed..d12bbd95 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -3,7 +3,7 @@ import os import sys import warnings -from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union import numpy as np import numpy.typing as npt import onnx @@ -15,6 +15,7 @@ GraphProto, ModelProto, NodeProto, + OperatorSetIdProto, TensorProto, ValueInfoProto, load as onnx_load, @@ -1195,3 +1196,104 @@ def shadowing_names( existing |= not_empty created |= not_empty return shadow, post_shadow, created + + +def extract_subset_of_nodes( + model: ModelProto, + name: str, + node_index: Optional[int] = None, + cut_points: Optional[Set[str]] = None, +) -> List[NodeProto]: + """ + Extracts the minimal subgraphs which can produce the output ``name`` + knowing ``cut_points``. 
+
+    :param model: original model
+    :param name: result name
+    :param node_index: index of the node producing ``name`` if known,
+        otherwise the function searches for it
+    :param cut_points: names of the known results to stop at
+        (the graph initializers are always added);
+        defaults to the graph inputs and initializers
+    :return: minimal list of nodes
+    """
+    if node_index is None:
+        for i, node in enumerate(model.graph.node):
+            if name in node.output:
+                node_index = i
+                break
+    assert (
+        node_index is not None
+        and node_index < len(model.graph.node)
+        and name in model.graph.node[node_index].output
+    ), f"node_index is still empty or wrong for result {name!r}"
+    if cut_points is None:
+        cut_points = {n.name for n in model.graph.input} | {
+            n.name for n in model.graph.initializer
+        }
+    elif model.graph.initializer:
+        cut_points = cut_points | {n.name for n in model.graph.initializer}
+
+    node = model.graph.node[node_index]
+    selected = {node_index}
+    current_node_index = node_index
+    current_input_index = 0
+    intermediate = {name}
+    inputs = set(k for k in node.input if k)
+    while not (inputs <= cut_points) and current_node_index >= 0:
+        node = model.graph.node[current_node_index]
+        if current_input_index == 0:
+            needs = [o for o in node.output if o in intermediate and o not in cut_points]
+            if needs:
+                selected.add(current_node_index)
+            else:
+                current_node_index -= 1
+                continue
+        res = node.input[current_input_index]
+        if res not in cut_points:
+            intermediate.add(res)
+        current_input_index += 1
+        if current_input_index >= len(node.input):
+            current_node_index -= 1
+            current_input_index = 0
+
+    return [model.graph.node[i] for i in sorted(selected)]
+
+
+def make_submodel(
+    nodes: List[NodeProto],
+    ir_version: int,
+    opset_imports: List[OperatorSetIdProto],
+    output_names: List[str],
+    type_rank_fn: Callable[[str], Tuple[int, int]],
+) -> ModelProto:
+    """
+    Creates a model with the given list of nodes.
+    It computes the minimal list of inputs needed for this model.
+    The function assumes the nodes are sorted.
+    It does not yet handle subgraphs.
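+
+    A usage sketch following the unit test (``type_rank_fn`` must return the onnx
+    element type and the rank for any requested name; the lambda below is a
+    hypothetical placeholder)::
+
+        new_model = make_submodel(
+            nodes,
+            ir_version=model.ir_version,
+            opset_imports=model.opset_import,
+            type_rank_fn=lambda n: (TensorProto.FLOAT, 3),  # hypothetical
+            output_names=["xm"],
+        )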
+
+    :param nodes: list of nodes
+    :param ir_version: ir version
+    :param opset_imports: opset import
+    :param output_names: desired outputs
+    :param type_rank_fn: function returning the type and the rank of a result
+    :return: model proto
+    """
+
+    def _mkv_(name, itype, irank):
+        return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])
+
+    not_known: Set[str] = set()
+    for node in nodes[::-1]:
+        not_known -= set(node.output)
+        not_known |= set(node.input)
+
+    model = oh.make_model(
+        oh.make_graph(
+            nodes,
+            "submodel",
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
+        ),
+        ir_version=ir_version,
+        opset_imports=opset_imports,
+    )
+    return model
diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py
index 9df1b96f..a2f814bf 100644
--- a/onnx_diagnostic/torch_onnx/sbs.py
+++ b/onnx_diagnostic/torch_onnx/sbs.py
@@ -1,4 +1,5 @@
 import inspect
+import os
 import time
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
@@ -7,7 +8,7 @@
 import numpy as np
 import torch
 from ..helpers import string_type, string_diff, max_diff, flatten_object
-from ..helpers.onnx_helper import pretty_onnx
+from ..helpers.onnx_helper import pretty_onnx, extract_subset_of_nodes, make_submodel
 from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype


@@ -172,6 +173,133 @@ def prepare_args_kwargs(
     return new_args, new_kwargs


+@dataclass
+class ReplayConfiguration:
+    """
+    Configuration specifying which pieces of the onnx graph to dump
+    so that they can be replayed later to investigate
+    possible sources of discrepancies.
+
+    :param dump_folder: where to dump the onnx model corresponding to the
+        pieces to investigate
+    :param selected_names: list of result names to dump
+    :param selected_op_types: list of onnx operator types to dump
+    :param threshold: only keep the pieces whose discrepancy is greater than this threshold
+    """
+
+    dump_folder: str
+    selected_names: Optional[Set[str]] = None
+    selected_op_types: Optional[Set[str]] = None
+    threshold: float = 0.1
+
+    def __post_init__(self):
+        assert self.dump_folder, "dump_folder is empty and this is not allowed for the replay"
+
+    def select(
+        self,
+        name: Optional[str] = None,
+        op_type: Optional[str] = None,
+        err_abs: Optional[float] = None,
+    ) -> bool:
+        """
+        Tells whether a piece of the onnx model around a particular node should
+        be dumped. The result is True if one of the following conditions is true:
+
+        * ``name in self.selected_names``
+        * ``op_type in self.selected_op_types``
+        * ``err_abs >= self.threshold``
+
+        :param name: result name
+        :param op_type: operator type
+        :param err_abs: measured discrepancy
+        :return: True if this should be dumped
+        """
+        if name and self.selected_names:
+            if name in self.selected_names:
+                return True
+        if op_type and self.selected_op_types:
+            if op_type in self.selected_op_types:
+                return True
+        if err_abs is not None and err_abs >= self.threshold:
+            return True
+        return False
+
+    def dump(
+        self,
+        name: str,
+        onnx_id_node: int,
+        model: onnx.ModelProto,
+        onnx_results: Dict[str, Any],
+        torch_results: Dict[str, torch.Tensor],
+        onnx_name_to_ep_name: Dict[str, str],
+        verbose: int = 0,
+    ) -> Optional[str]:
+        """
+        Dumps the minimal graph which can be replayed outside the model.
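+
+        The folder contains ``model.onnx``, ``torch_inputs.pt``, ``onnx_inputs.pt``
+        and ``torch_outputs.pt``. A later replay in a separate process could look
+        like this sketch (the folder name is hypothetical)::
+
+            import onnx, torch
+            submodel = onnx.load("replay/xm/model.onnx")
+            onnx_inputs = torch.load("replay/xm/onnx_inputs.pt")
+            expected = torch.load("replay/xm/torch_outputs.pt")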
+
+        :param name: name of the result to look into
+        :param onnx_id_node: index of the node which produces it in model ``model``
+        :param model: onnx model
+        :param onnx_results: all known onnx results
+        :param torch_results: all known torch results
+        :param onnx_name_to_ep_name: correspondence between an onnx result name
+            and its exported program name
+        :param verbose: verbosity level
+        :return: the folder created to dump everything
+        """
+        if verbose:
+            print(f"[ReplayConfiguration.dump] extract subset of nodes for {name!r}")
+        nodes = extract_subset_of_nodes(
+            model=model,
+            name=name,
+            node_index=onnx_id_node,
+            cut_points=set(onnx_name_to_ep_name),
+        )
+        if not nodes:
+            if verbose:
+                print(
+                    f"[ReplayConfiguration.dump] could not extract subset of nodes for {name!r}"
+                )
+            return None
+        if verbose:
+            print(f"[ReplayConfiguration.dump] make model with {len(nodes)} nodes")
+        submodel = make_submodel(
+            nodes,
+            ir_version=model.ir_version,
+            opset_imports=model.opset_import,
+            output_names=[name],
+            type_rank_fn=lambda name: (
+                torch_dtype_to_onnx_dtype(onnx_results[name].dtype),
+                len(onnx_results[name].shape),
+            ),
+        )
+        input_names = [n.name for n in submodel.graph.input]
+        if verbose:
+            print(f"[ReplayConfiguration.dump] model inputs {input_names}")
+        folder = os.path.join(self.dump_folder, name.replace(":", "_").replace("/", "_"))
+        os.makedirs(folder, exist_ok=True)
+        if verbose:
+            print(f"[ReplayConfiguration.dump] dumps into folder {folder!r}")
+        onnx.save(submodel, os.path.join(folder, "model.onnx"))
+        torch_inputs = {}
+        for n in input_names:
+            if n in onnx_name_to_ep_name:
+                torch_inputs[n] = torch_results[onnx_name_to_ep_name[n]]
+            else:
+                raise AssertionError(f"n={n!r}, onnx_name_to_ep_name={onnx_name_to_ep_name}")
+        onnx_inputs = {n: onnx_results[n] for n in input_names}
+        assert (
+            name in onnx_name_to_ep_name
+        ), f"Unable to find {name!r} in {onnx_name_to_ep_name}"
+        expected_outputs = (torch_results[onnx_name_to_ep_name[name]],)
+        torch.save(torch_inputs, os.path.join(folder, "torch_inputs.pt"))
+        torch.save(onnx_inputs, os.path.join(folder, "onnx_inputs.pt"))
+        torch.save(expected_outputs, os.path.join(folder, "torch_outputs.pt"))
+        if verbose:
+            print(f"[ReplayConfiguration.dump] done {folder!r}")
+        return folder
+
+
 @dataclass
 class RunAlignedRecord:
     """
@@ -222,7 +350,15 @@ class RunAlignedRecord:
     ep_time_run: Optional[float] = None
     onnx_time_run: Optional[float] = None

+    def __post_init__(self):
+        "Validation."
+        assert self.ep_id_node is None or self.ep_id_node >= 0, (
+            f"Node ids are always positive in the exported program but "
+            f"ep_id_node={self.ep_id_node}"
+        )
+
     def set_diff(self, diff: Dict[str, Any]):
+        """Sets the error fields from ``diff``."""
         if diff is None:
             return
         if "abs" in diff:
@@ -246,19 +382,24 @@ class StatusRunAligned:
     :param n_inf: number of infinite values seen so far
     :param n_nan: number of nan values seen so far
     :param yielded_nodes: number of yielded pairs of nodes seen so far
+    :param last_replay: last result dumped on disk for later replay
     """

     max_abs: float = 0.0
     n_inf: int = 0
     n_nan: int = 0
     yielded_nodes: int = 0
+    last_replay: str = ""

     def to_str(self) -> str:
         "Nice display."
-        return (
+        s = (
             f"yielded={self.yielded_nodes} maxabs={self.max_abs:1.3f} "
             f"#inf={self.n_inf} #nan={self.n_nan}"
         )
+        if self.last_replay:
+            return f"{s} -REPLAY({self.last_replay})"
+        return s

     def update(self, err_abs: float):
         "Updates all attributes with the latest measure."
@@ -289,10 +430,10 @@ def run_aligned(
     use_tensor: bool = False,
     atol: Optional[float] = None,
     rtol: Optional[float] = None,
-    gemmlinear: bool = False,
     verbose: int = 0,
     exc: bool = True,
     reset_names: Optional[List[str]] = None,
+    replay_configuration: Optional[ReplayConfiguration] = None,
 ) -> Iterator[RunAlignedRecord]:
     """
     Runs in parallel both the exported program
@@ -316,6 +457,10 @@ def run_aligned(
     :param exc: stops if an exception is raised
     :param reset_names: list of names, the onnx execution takes the torch outputs
         instead of its own result if the names fall into that set
+    :param replay_configuration: configuration letting the user dump any problematic
+        piece of the onnx graph they want to replay and investigate later,
+        see :class:`ReplayConfiguration`
     :return: a list of :class:`RunAlignedRecord`

     Example:
@@ -543,12 +688,13 @@ def _loop_cmp(
             r = RunAlignedRecord(
                 ep_id_node=i,
                 onnx_id_node=i_onnx,
-                ep_name=o,
-                onnx_name=to,
+                ep_name=to,
+                onnx_name=o,
                 ep_shape_type=string_type(torch_results[to], **str_kws),
                 onnx_shape_type=string_type(r, **str_kws),
             )
             r.set_diff(d)
+            mapping_onnx_to_torch[to] = o
             return r
         return None

@@ -567,6 +713,7 @@ def _loop_onnx_node(
         str_kws,
         status,
         already_run,
+        torch_names_to_onnx_names,
         verbose,
     ):

@@ -593,17 +740,13 @@ def _loop_onnx_node(
         )
         assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
         begin = time.perf_counter()
-        res = None
-        if use_tensor and gemmlinear and node.op_type == "Gemm":
-            res = _gemm_linear(node, feeds, ref)
-        if res is None:
-            try:
-                res = ref.run(None, feeds)  # type: ignore[attr-defined]
-            except Exception as e:
-                raise RuntimeError(
-                    f"Unable to run node {node.op_type}, domain={node.domain} "
-                    f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
-                ) from e
+        try:
+            res = ref.run(None, feeds)  # type: ignore[attr-defined]
+        except Exception as e:
+            raise RuntimeError(
+                f"Unable to run node {node.op_type}, domain={node.domain} "
+                f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
+            ) from e
         duration = time.perf_counter() - begin
         assert (
             not has_cuda
@@ -661,6 +804,28 @@ def _loop_onnx_node(
             if tmp.err_abs is not None:
                 status.update(tmp.err_abs)
             yield tmp
+
+            # do we need to dump pieces of the graph the user can replay?
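+            # The dump is triggered by the result name, the operator type or the
+            # measured discrepancy (see ReplayConfiguration.select); everything
+            # needed to replay the piece is written below dump_folder.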
+            if replay_configuration:
+                if replay_configuration.select(
+                    name=tmp.onnx_name, op_type=tmp.onnx_op_type, err_abs=tmp.err_abs
+                ):
+                    replay_configuration.dump(
+                        name=tmp.onnx_name,
+                        onnx_id_node=tmp.onnx_id_node,
+                        model=onx,
+                        onnx_results=onnx_results,
+                        torch_results=torch_results,
+                        onnx_name_to_ep_name={
+                            **{v: k for k, v in torch_names_to_onnx_names.items()},
+                            **mapping_onnx_to_torch,
+                        },
+                        verbose=max(verbose - 1, 0),
+                    )
+                    status.last_replay = tmp.onnx_name
+
+            # reset_names: replaces onnx_results by torch_results to see
+            # if that fixes the discrepancies
             if reset_names and tmp.ep_name in reset_names:
                 assert (
                     tmp.ep_name in torch_results
@@ -701,30 +866,6 @@ def _duplicated_values(d):
             final |= set(v)
         return final

-    def _gemm_linear(node, feeds, sess):
-        if node.op_type != "Gemm" or node.domain != "":
-            return None
-        for att in node.attribute:
-            if att.name == "alpha":
-                if att.f != 1:
-                    return None
-            elif att.name == "beta":
-                if att.f != 1:
-                    return None
-            elif att.name == "transA":
-                if att.i != 0:
-                    return None
-            elif att.name == "transB":
-                if att.i != 1:
-                    return None
-        t = torch.nn.functional.linear(
-            feeds[node.input[0]], feeds[node.input[1]], feeds[node.input[2]]
-        )
-        got = sess.run(None, feeds)
-        print(f"-- GEMM {node.output[0]}: {string_diff(max_diff(t, got, hist=[0.1]))}")
-        assert t.shape == got[0].shape, f"shape mismatch {t.shape} != {got[0].shape}"
-        return [t]
-
     # preparation with ep.graph.nodes
     ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
     placeholders_to_state_dict = {
@@ -1071,6 +1212,7 @@ def _gemm_linear(node, feeds, sess):
             str_kws,
             status,
             already_run,
+            torch_names_to_onnx_names,
             verbose,
         ):
             if r:
@@ -1100,6 +1242,7 @@ def _gemm_linear(node, feeds, sess):
                 str_kws,
                 status,
                 already_run,
+                torch_names_to_onnx_names,
                 verbose,
             ):
                 if r: