From 300ff8816b1678f77c7f1090da248aa33fc809f6 Mon Sep 17 00:00:00 2001
From: ydwu4
Date: Wed, 20 Sep 2023 14:36:32 -0700
Subject: [PATCH] Replace node.meta source_fn with source_fn_stack (#108595)

Summary:
X-link: https://github.com/pytorch/executorch/pull/210

A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copying over the description:

This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to replace source_fn with source_fn_stack.

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
the captured graph is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        add = l_x_ + l_y_; l_x_ = l_y_ = None
        return add

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
        return cond

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        sin = l_x_.sin(); l_x_ = None
        cos = l_y_.cos(); l_y_ = None
        sub = sin - cos; sin = cos = None
        return sub

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        cos = l_x_.cos(); l_x_ = None
        sin = l_y_.sin(); l_y_ = None
        sub = cos - sin; cos = sin = None
        return sub
```
Before this change, the source_fn for the inner cond, sin, cos, and sub nodes is a single (name, target) tuple:
```
('cond', )
('sin', 'sin')
('cos', 'cos')
('sub', )
```

After this PR, the source_fn_stack is a list of (name, target) tuples. The bottom of the stack is at the end of the list:
```
[('cond', ), ('cond', )],
[('cond', ), ('cond', ), ('sin', 'sin')],
[('cond', ), ('cond', ), ('cos', 'cos')],
[('cond', ), ('cond', ), ('sub', )]
```

Test Plan:
See the added tests in test_higher_order_ops.py and the modified existing tests.
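For reference, here is a minimal sketch (not part of the patch itself) of how the new metadata can be read back, continuing the cond_f example above and mirroring the _get_source_fn_stack helper added in test_higher_order_ops.py:
```python
# Rough sketch: walk the captured graph (including the cond branch submodules)
# and print each node's source_fn_stack. Assumes cond_f was compiled with the
# EagerAndRecordGraphs backend as in the snippet above and has been called once,
# e.g.:
#   cond_f(torch.tensor(True), torch.tensor(False), torch.randn(3, 3), torch.randn(3, 3))
gm = backend.graphs[0]
for mod in gm.modules():
    if not isinstance(mod, torch.fx.GraphModule):
        continue
    for node in mod.graph.nodes:
        # source_fn_stack is a list of (node_name, target) tuples; the entry
        # for the node itself sits at the end of the list.
        stack = node.meta.get("source_fn_stack", [])
        if stack:
            print(node.name, [name for name, _ in stack])
# e.g. the sin node prints: sin ['cond', 'cond', 'sin']
```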
Also updated bin by running: "buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov Reviewed By: angelayi Differential Revision: D48984986 Pulled By: ydwu4 --- test/dynamo/test_aot_autograd.py | 7 +- test/dynamo/test_export.py | 5 +- test/dynamo/test_higher_order_ops.py | 149 ++++++++++++++++++ test/export/test_export.py | 11 +- test/export/test_serialize.py | 6 +- test/functorch/test_control_flow.py | 4 +- test/test_fx.py | 4 +- torch/_dynamo/output_graph.py | 58 +++++-- torch/_dynamo/variables/higher_order_ops.py | 11 +- torch/_export/serde/serialize.py | 24 +-- torch/_inductor/utils.py | 9 +- torch/fx/passes/utils/source_matcher_utils.py | 3 +- torch/fx/proxy.py | 2 +- 13 files changed, 248 insertions(+), 45 deletions(-) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 0555eaba827e4..520e5243a0108 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -798,10 +798,11 @@ def _prepare_model_args(): continue if min_seq_nr < 0: min_seq_nr = seq_nr - mod_name = node.meta.get("source_fn", "") + source_fn_stack = node.meta.get("source_fn_stack", []) orig_aten = node.meta.get("original_aten", "") - if isinstance(mod_name, tuple): - mod_name = mod_name[0] + mod_name = "" + if len(source_fn_stack) > 0: + mod_name = source_fn_stack[-1][0] # Make all seq_nr relative so it starts at 0 seq_nr = seq_nr - min_seq_nr seq_table = seq_table + f"{seq_nr}|{orig_aten}|{mod_name}\n" diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 0828efd8563b9..6a3eb5926acc3 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -964,7 +964,7 @@ def forward(self, x): if node.op not in {"placeholder", "output"}: self.assertTrue(node.stack_trace is not None) self.assertTrue(node.meta["nn_module_stack"] is not None) - self.assertTrue(node.meta["source_fn"] is not None) + self.assertTrue(node.meta["source_fn_stack"] is not None) torch._dynamo.reset() @@ -974,7 +974,7 @@ def forward(self, x): if node.op == "call_function": self.assertTrue(node.stack_trace is not None) self.assertTrue(node.meta["nn_module_stack"] is not None) - self.assertTrue(node.meta["source_fn"] is not None) + self.assertTrue(node.meta["source_fn_stack"] is not None) self.assertTrue(node.meta["val"] is not None) self.assertTrue(node.meta["original_aten"] is not None) @@ -4014,7 +4014,6 @@ def fn(x): self.assertEqual( nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"] ) - self.assertEqual(nd1.meta["source_fn"], nd2.meta["source_fn"]) self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"]) def test_preserve_fx_node_metadata_recompile(self): diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 1a093459b543e..3c77a31f1235a 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import functools +import pprint import re import unittest @@ -1692,6 +1693,154 @@ def fn(x): self.assertTrue(activations.keys() == forward_handles.keys()) + def _get_source_fn_stack(self, gm, node_names): + ret = {} + for mod in gm.modules(): + for node in mod.graph.nodes: + if node.name in node_names: + actual_stack = [ + name for name, _ in node.meta.get("source_fn_stack", []) + ] + ret[node.name] = 
actual_stack + return ret + + def test_wrap_source_fn_stack(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x) + + mod = MockModule() + + def gn(x): + return torch.cos(x) + wrap(mod, x) + + def fn(x): + return wrap(gn, x) + + backend = EagerAndRecordGraphs() + inp = torch.randn((4, 4)) + torch.compile(fn, backend=backend, fullgraph=True)(inp) + + gm = backend.graphs[0] + actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"}) + self.assertExpectedInline( + pprint.pformat(actual_stack), + """\ +{'add': ['wrap', 'add'], + 'cos': ['wrap', 'cos'], + 'linear': ['wrap', 'wrap', 'linear']}""", + ) + + def test_cond_source_fn_stack(self): + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def cond_f(pred, pred2, x, y): + def true_fn(pred2, x, y): + return x + y + + def false_fn(pred2, x, y): + def true_fn2(x, y): + return x.sin() - y.cos() + + def false_fn2(x, y): + return x.cos() - y.sin() + + return control_flow.cond(pred2, true_fn2, false_fn2, [x, y]) + + return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y]) + + pred = torch.tensor(True) + pred2 = torch.tensor(False) + xs = torch.randn(2, 3, 3) + y = torch.randn(3, 3) + cond_f(pred, pred2, xs, y) + + gm = backend.graphs[0] + actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"}) + self.assertExpectedInline( + pprint.pformat(actual_stack), + """\ +{'add': ['cond', 'add'], + 'cos': ['cond', 'cond', 'cos'], + 'sin': ['cond', 'cond', 'sin'], + 'sub': ['cond', 'cond', 'sub']}""", + ) + + def test_map_source_fn_stack(self): + backend = EagerAndRecordGraphs() + + xs = torch.randn(2, 3, 3) + y = torch.randn(3) + + @torch.compile(backend=backend, fullgraph=True) + def map_f(xs, y): + def inner(x, y): + def inner2(x, y): + return x + y + + return control_flow.map(inner2, x, y) * y.cos() + + return control_flow.map(inner, xs, y).sin() + + result = map_f(xs, y) + + gm = backend.graphs[0] + actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"}) + self.assertExpectedInline( + pprint.pformat(actual_stack), + """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""", + ) + + def test_grad_source_fn_stack(self): + backend = EagerAndRecordGraphs() + + def fn(x): + return x.sin().sum() + + @torch.compile(backend=backend, fullgraph=False) + def wrapper_fn(x): + return torch.func.grad(torch.func.grad(fn))(x) + + x = torch.randn(()) + + wrapper_fn(x) + gm = backend.graphs[0] + actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"}) + self.assertExpectedInline( + pprint.pformat(actual_stack), + """\ +{'sin': ['grad_impl', 'grad_impl', 'sin'], + 'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""", + ) + + def test_vmap_source_fn_stack(self): + backend = EagerAndRecordGraphs() + + def inner_fn(x): + return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x) + + @torch.compile(backend=backend, fullgraph=True) + def fn(x): + return torch.func.vmap(lambda x: inner_fn(x.cos()))(x) + + x = torch.randn(3, 3, 3, 3) + fn(x) + gm = backend.graphs[0] + actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sum_2", "cos", "add"}) + self.assertExpectedInline( + pprint.pformat(actual_stack), + """\ +{'add': ['vmap_impl', 'vmap_impl', 'add'], + 'cos': ['vmap_impl', 'cos'], + 'sum_1': ['vmap_impl', 'vmap_impl', 'sum_1'], + 'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""", + ) + class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase): 
def run(self, result=None): diff --git a/test/export/test_export.py b/test/export/test_export.py index e817288608cf3..0e67a9dd04dca 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -346,7 +346,7 @@ def forward(self, x): node.name in ep.graph_signature.inputs_to_buffers or node.name in ep.graph_signature.inputs_to_parameters ): - self.assertTrue("source_fn" in node.meta) + self.assertTrue("source_fn_stack" in node.meta) self.assertTrue("nn_module_stack" in node.meta) @@ -1071,8 +1071,13 @@ def forward(self, x): for mod in gm.modules(): for node in mod.graph.nodes: if node.name in {"sin", "cos"}: - actual_source_fns.append(node.meta.get("source_fn", None)) - exp_source_fns = [("cos", "cos"), ("sin", "sin")] + source_fn_st = node.meta.get("source_fn_stack", None) + if source_fn_st is not None: + source_names = [] + for source_fn in source_fn_st: + source_names.append(source_fn[0]) + actual_source_fns.append(source_names) + exp_source_fns = [["cond", "cos"], ["cond", "sin"]] self.assertEqual(actual_source_fns, exp_source_fns) def test_lift_constants(self) -> None: diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 6151d4ed360fa..7bb79da01f7a3 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -273,10 +273,10 @@ def _check_graph_nodes(gm1, gm2, _check_meta=True): # node1.meta.get("nn_module_stack", None), # node2.meta.get("nn_module_stack", None), # ) - # Check "source_fn" metadata + # Check "source_fn_stack" metadata self.assertEqual( - node1.meta.get("source_fn", None), - node2.meta.get("source_fn", None), + node1.meta.get("source_fn_stack", None), + node2.meta.get("source_fn_stack", None), ) _check_graph_nodes(ep.graph_module, deserialized_ep.graph_module, _check_meta) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index fce2b9ce09e5b..e613f0ae06dea 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1501,7 +1501,7 @@ def false_fn(x): return x * x.sin() def foo(x): - return cond(x.shape[0] == 4, true_fn, false_fn, [x]) + return cond(x.shape[0] == 4, true_fn, false_fn, (x,)) inp = torch.randn([4, 3]) gm, _ = torch._dynamo.export(foo)(inp) @@ -1512,7 +1512,7 @@ def run_with_interpreter(*args): checked_ops = {"add", "mul", "sin", "cos"} - checked_meta = ["source_fn", "stack_trace"] + checked_meta = ["source_fn_stack", "stack_trace"] all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta) new_source_fns = collect_meta_for_filtered_nodes(new_gm, checked_ops, checked_meta) self.assertEqual(all_source_fns, new_source_fns) diff --git a/test/test_fx.py b/test/test_fx.py index 43f97f35dcd81..eafea7d1498ee 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1777,13 +1777,13 @@ def forward(self, x): if node.op == 'get_attr': node.meta["nn_module_stack"] = "self" node.meta["stack_trace"] = "stack_trace" - node.meta["source_fn"] = "source_fn" + node.meta["source_fn_stack"] = "source_fn_stack" new_gm = Transformer(gm).transform() for node in new_gm.graph.nodes: if node.op == 'get_attr': self.assertEqual(node.meta["nn_module_stack"], "self") self.assertEqual(node.meta["stack_trace"], "stack_trace") - self.assertEqual(node.meta["source_fn"], "source_fn") + self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack") def test_interpreter(self): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 32707cee559be..5e674726a88aa 100644 --- a/torch/_dynamo/output_graph.py +++ 
b/torch/_dynamo/output_graph.py @@ -423,11 +423,13 @@ def remove_node(self, *args, **kwargs): return self.current_tracer.remove_node(*args, **kwargs) @contextlib.contextmanager - def new_subtracer(self): + def new_subtracer(self, source_target): new_scope_ctx = enter_new_scope() try: new_scope_ctx.__enter__() - tracer = SubgraphTracer(self, parent=self.current_tracer) + tracer = SubgraphTracer( + self, parent=self.current_tracer, source_target=source_target + ) self.tracers.append(tracer) yield tracer finally: @@ -1181,7 +1183,9 @@ class SubgraphTracer(fx.Tracer): compiling and executing the graph. """ - def __init__(self, output_graph, parent=None, export_root=False): + def __init__( + self, output_graph, parent=None, export_root=False, source_target=None + ): super().__init__() self.output_graph = weakref.proxy(output_graph) self.graph = torch.fx.Graph() @@ -1220,6 +1224,16 @@ def __init__(self, output_graph, parent=None, export_root=False): self._orig_gm_meta = None self._orig_gm_lineno_map = None self._orig_gm_firstlineno = None + # Each SubgraphTracer is associated with a source target, which indicates + # which operator this subgraph is attached to. We compute a source_fn_stack + # based on the source tareget. For the root tracer, it's set to []. + # This is useful for debugging and transforming the exported graph. + if self.parent is None: + self.source_fn_stack = [] + else: + self.source_fn_stack = self.parent.source_fn_stack + [ + (self.graph._target_to_str(source_target), source_target) + ] def create_proxy( self, @@ -1315,6 +1329,24 @@ def get_trace_call_log_str(): self._orig_gm_meta = None self._orig_gm_lineno_map = None self._orig_gm_firstlineno = None + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + unimplemented("Invoking an nn.Module inside HigherOrderOperator") + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + rv.node.meta["nn_module_stack"][target][1], + ) + ] # preserve original meta if it is available if ( @@ -1332,26 +1364,30 @@ def get_trace_call_log_str(): meta = self._orig_gm_meta[node_idx] if "stack_trace" in meta: rv.node.meta["stack_trace"] = meta["stack_trace"] - if "nn_module_stack" in meta and "source_fn" in meta: + if "nn_module_stack" in meta and "source_fn_stack" in meta: rv.node.meta["nn_module_stack"] = meta["nn_module_stack"] - rv.node.meta["source_fn"] = meta["source_fn"] + rv.node.meta["source_fn_stack"] = meta["source_fn_stack"] if "nn_module_stack" not in rv.node.meta: nn_module_stack = tx.nn_module_stack if nn_module_stack: rv.node.meta["nn_module_stack"] = nn_module_stack.copy() - if "source_fn" not in rv.node.meta: + if "source_fn_stack" not in rv.node.meta: if kind in {"call_function", "call_method"}: - rv.node.meta["source_fn"] = (rv.node.name, target) + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] elif kind == "call_module": if self.parent is not None: unimplemented("Invoking an nn.Module inside HigherOrderOperator") # For modules we store the class - rv.node.meta["source_fn"] = ( - rv.node.name, - rv.node.meta["nn_module_stack"][target][1], - ) + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + 
rv.node.meta["nn_module_stack"][target][1], + ) + ] if "stack_trace" not in rv.node.meta: frame_summaries: List[traceback.FrameSummary] = [] diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 857ab879fbdb9..208bb69905ac7 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -153,6 +153,9 @@ def speculate_subgraph( checkpoint, description, *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. + source_target=None, always_restore=False, enable_grad=False, # NOTE [Temporary argument `manually_set_subgraph_inputs`] @@ -175,7 +178,7 @@ def speculate_subgraph( ) try: - with tx.output.new_subtracer() as tracer: + with tx.output.new_subtracer(source_target) as tracer: args = validate_args_and_maybe_create_graph_inputs( sub_args, tracer, tx, manually_set_subgraph_inputs ) @@ -431,6 +434,7 @@ def speculate_branch(branch): graph_checkpoint, checkpoint, "cond", + source_target=self.value, ) if not isinstance(ret_val, TensorVariable): @@ -583,6 +587,7 @@ def call_function( tx.output.graph, checkpoint, "torch.ops.higher_order.map", + source_target=self.value, ) body_nn_modules = tx.copy_graphstate().output.nn_modules @@ -718,6 +723,7 @@ def call_function( graph_checkpoint, checkpoint, "torch.func.grad", + source_target=self.value, # See NOTE [HACK: Enable autograd while tracing function] enable_grad=True, ) @@ -929,6 +935,7 @@ def call_function( graph_checkpoint, checkpoint, "torch.vmap", + source_target=self.value, ) body_name = add_subgraph( @@ -1039,6 +1046,7 @@ def call_function( graph_checkpoint, checkpoint, "the user-defined autograd.Function", + source_target=self.value, # Backwards should never, ever be stored! 
always_restore=always_restore, restore_side_effects=False, @@ -1096,6 +1104,7 @@ def create_wrapped_node(self, tx, args, kwargs, description): graph_checkpoint, checkpoint, description, + source_target=self.value, manually_set_subgraph_inputs=False, ) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 4f9a08ca9a22b..d5c1d85647623 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -79,6 +79,8 @@ def _reverse_map(d: Dict[Any, Enum]): MetaType = Union[FakeTensor, int, torch.SymInt, bool, torch.SymBool] +ST_DELIMITER = ";" + _TORCH_TO_SERIALIZE_DTYPE = { torch.uint8: ScalarType.BYTE, torch.int8: ScalarType.CHAR, @@ -465,12 +467,11 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: f"{k}:({v[0]},{self.serialize_operator(v[1])})" for k, v in nn_module_stack.items() ] - ret["nn_module_stack"] = ";".join(nn_module_list) + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) - if source_fn := node.meta.get("source_fn"): - # Serialize to "fx_node_name,op_str" - op = self.serialize_operator(source_fn[1]) - ret["source_fn"] = f"{source_fn[0]},{op}" + if source_fn_st := node.meta.get("source_fn_stack"): + source_fn_list = [f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" for source_fn in source_fn_st] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) return ret @@ -1286,7 +1287,7 @@ def deserialize_meta_func(serialized_target: str): if nn_module_stack_str := metadata.get("nn_module_stack"): # Originally serialized to "fx_node_name:(orig_ref,type_str)" - nn_module_stack_list = nn_module_stack_str.split(";") + nn_module_stack_list = nn_module_stack_str.split(ST_DELIMITER) nn_module_stack = {} for kv in nn_module_stack_list: key_idx = kv.find(":") @@ -1310,12 +1311,13 @@ def deserialize_meta_func(serialized_target: str): nn_module_stack[key] = (kv[key_idx + 2:comma_idx], module) ret["nn_module_stack"] = nn_module_stack - if source_fn_str := metadata.get("source_fn"): + if source_fn_st_str := metadata.get("source_fn_stack"): # Originally serializes to "fx_node_name,op_str" - source_fn = source_fn_str.split(",") - op = deserialize_meta_func(source_fn[1]) - ret["source_fn"] = (source_fn[0], op) - + source_fn_st = [] + for source_fn_str in source_fn_st_str.split(ST_DELIMITER): + name, target_str = source_fn_str.split(",") + source_fn_st.append((name, deserialize_meta_func(target_str))) + ret["source_fn_stack"] = source_fn_st return ret def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 22b84ff3b01e8..dd43e4c8aa1d2 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -383,11 +383,12 @@ def get_fused_kernel_name(node_schedule, descriptive_names): # Bases the kernel name off of the top-level "torch" operator (i.e. 
post-dynamo graph) sources = [] for origin in all_origins: - if origin.op == "call_function" and "source_fn" in origin.meta: - if isinstance(origin.meta["source_fn"][1], str): - sources.append(origin.meta["source_fn"][1]) + if origin.op == "call_function" and "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + if isinstance(source_fn[1], str): + sources.append(source_fn[1]) else: - sources.append(origin.meta["source_fn"][1].__name__) + sources.append(source_fn[1].__name__) sources = sorted(set(sources)) elif descriptive_names == "inductor_node": sources = [ diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index d5060589af5d2..2830f60d5eab1 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -73,9 +73,10 @@ def get_source_partitions( # function, or the type of module if the node is decomposed from a leaf # module - if (source_fn := node.meta.get("source_fn", None)) is None: + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None: continue + source_fn = source_fn_st[-1] if source_fn[1] not in wanted_sources: continue diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 152a906475859..e19b3c7a73f91 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -83,7 +83,7 @@ def __exit__(self, *args): return -_COPY_META_FIELDS = ["nn_module_stack", "source_fn", "original_aten", "recompute", "from_node"] +_COPY_META_FIELDS = ["nn_module_stack", "source_fn_stack", "original_aten", "recompute", "from_node"] @compatibility(is_backward_compatible=True)
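
Usage note (not part of the patch): downstream passes that previously read node.meta["source_fn"] now look at the last entry of node.meta["source_fn_stack"], as the get_source_partitions change above shows. A minimal sketch of that consumer-side view, assuming a graph captured with the EagerAndRecordGraphs backend from torch._dynamo.testing:
```python
# Rough sketch: group nodes of a dynamo-captured graph by the torch-level
# function they were traced from, via the updated get_source_partitions
# (which now matches wanted sources against source_fn_stack[-1][1]).
import torch
from torch._dynamo.testing import EagerAndRecordGraphs
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def f(x):
    return torch.sin(x) + torch.cos(x)

f(torch.randn(3))
gm = backend.graphs[0]

partitions = get_source_partitions(gm.graph, [torch.sin, torch.cos])
print({source: len(parts) for source, parts in partitions.items()})
```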