Replace node.meta source_fn with source_fn_stack (#108595)
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. The description is copied over below:

This is a follow-up to the discussion in #108356, where we decided to replace source_fn with source_fn_stack.

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for the inner cond, sin, cos, and sub nodes is a single (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub', <built-in function sub>)
```

After this PR, source_fn_stack is a list of (name, target) tuples: the outermost enclosing operator comes first, and the node's own (name, target) entry is the last element of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```
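
Since node.meta now carries a stack rather than a single tuple, consumers index into the list instead of unpacking one pair; the enclosing higher-order ops come first and the node's own entry is last. Here is a minimal sketch of walking a captured graph and printing each node's stack (the helper name is made up for illustration; gm is assumed to be a GraphModule recorded by a backend such as EagerAndRecordGraphs):
```python
import torch

def print_source_fn_stacks(gm):
    # Visit the root GraphModule and any nested cond/map submodules.
    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            stack = node.meta.get("source_fn_stack", [])
            if stack:
                # Each entry is a (name, target) tuple; enclosing
                # higher-order ops come first, the node itself last.
                names = [name for name, _ in stack]
                print(f"{node.name}: {' -> '.join(names)}")
```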


Test Plan:
See the added tests in test_higher_order_ops.py and the modifications to existing tests.
Also updated the generated files by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
ydwu4 authored and facebook-github-bot committed Sep 18, 2023
1 parent 8a567bb commit 418a30e
Showing 13 changed files with 248 additions and 45 deletions.
7 changes: 4 additions & 3 deletions test/dynamo/test_aot_autograd.py
@@ -798,10 +798,11 @@ def _prepare_model_args():
                     continue
                 if min_seq_nr < 0:
                     min_seq_nr = seq_nr
-                mod_name = node.meta.get("source_fn", "")
+                source_fn_stack = node.meta.get("source_fn_stack", [])
                 orig_aten = node.meta.get("original_aten", "")
-                if isinstance(mod_name, tuple):
-                    mod_name = mod_name[0]
+                mod_name = ""
+                if len(source_fn_stack) > 0:
+                    mod_name = source_fn_stack[-1][0]
                 # Make all seq_nr relative so it starts at 0
                 seq_nr = seq_nr - min_seq_nr
                 seq_table = seq_table + f"{seq_nr}|{orig_aten}|{mod_name}\n"
5 changes: 2 additions & 3 deletions test/dynamo/test_export.py
@@ -964,7 +964,7 @@ def forward(self, x):
             if node.op not in {"placeholder", "output"}:
                 self.assertTrue(node.stack_trace is not None)
                 self.assertTrue(node.meta["nn_module_stack"] is not None)
-                self.assertTrue(node.meta["source_fn"] is not None)
+                self.assertTrue(node.meta["source_fn_stack"] is not None)

         torch._dynamo.reset()

@@ -974,7 +974,7 @@ def forward(self, x):
             if node.op == "call_function":
                 self.assertTrue(node.stack_trace is not None)
                 self.assertTrue(node.meta["nn_module_stack"] is not None)
-                self.assertTrue(node.meta["source_fn"] is not None)
+                self.assertTrue(node.meta["source_fn_stack"] is not None)
                 self.assertTrue(node.meta["val"] is not None)
                 self.assertTrue(node.meta["original_aten"] is not None)

@@ -4013,7 +4013,6 @@ def fn(x):
                 self.assertEqual(
                     nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
                 )
-                self.assertEqual(nd1.meta["source_fn"], nd2.meta["source_fn"])
                 self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])

     def test_preserve_fx_node_metadata_recompile(self):
149 changes: 149 additions & 0 deletions test/dynamo/test_higher_order_ops.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import functools
+import pprint
 import re
 import unittest

@@ -1692,6 +1693,154 @@ def fn(x):

         self.assertTrue(activations.keys() == forward_handles.keys())

+    def _get_source_fn_stack(self, gm, node_names):
+        ret = {}
+        for mod in gm.modules():
+            for node in mod.graph.nodes:
+                if node.name in node_names:
+                    actual_stack = [
+                        name for name, _ in node.meta.get("source_fn_stack", [])
+                    ]
+                    ret[node.name] = actual_stack
+        return ret
+
+    def test_wrap_source_fn_stack(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        mod = MockModule()
+
+        def gn(x):
+            return torch.cos(x) + wrap(mod, x)
+
+        def fn(x):
+            return wrap(gn, x)
+
+        backend = EagerAndRecordGraphs()
+        inp = torch.randn((4, 4))
+        torch.compile(fn, backend=backend, fullgraph=True)(inp)
+
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'add': ['wrap', 'add'],
+ 'cos': ['wrap', 'cos'],
+ 'linear': ['wrap', 'wrap', 'linear']}""",
+        )
+
+    def test_cond_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def cond_f(pred, pred2, x, y):
+            def true_fn(pred2, x, y):
+                return x + y
+
+            def false_fn(pred2, x, y):
+                def true_fn2(x, y):
+                    return x.sin() - y.cos()
+
+                def false_fn2(x, y):
+                    return x.cos() - y.sin()
+
+                return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
+
+            return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
+
+        pred = torch.tensor(True)
+        pred2 = torch.tensor(False)
+        xs = torch.randn(2, 3, 3)
+        y = torch.randn(3, 3)
+        cond_f(pred, pred2, xs, y)
+
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'add': ['cond', 'add'],
+ 'cos': ['cond', 'cond', 'cos'],
+ 'sin': ['cond', 'cond', 'sin'],
+ 'sub': ['cond', 'cond', 'sub']}""",
+        )
+
+    def test_map_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        xs = torch.randn(2, 3, 3)
+        y = torch.randn(3)
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def map_f(xs, y):
+            def inner(x, y):
+                def inner2(x, y):
+                    return x + y
+
+                return control_flow.map(inner2, x, y) * y.cos()
+
+            return control_flow.map(inner, xs, y).sin()
+
+        result = map_f(xs, y)
+
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
+        )
+
+    def test_grad_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        def fn(x):
+            return x.sin().sum()
+
+        @torch.compile(backend=backend, fullgraph=False)
+        def wrapper_fn(x):
+            return torch.func.grad(torch.func.grad(fn))(x)
+
+        x = torch.randn(())
+
+        wrapper_fn(x)
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'sin': ['grad_impl', 'grad_impl', 'sin'],
+ 'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""",
+        )
+
+    def test_vmap_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        def inner_fn(x):
+            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def fn(x):
+            return torch.func.vmap(lambda x: inner_fn(x.cos()))(x)
+
+        x = torch.randn(3, 3, 3, 3)
+        fn(x)
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sum_2", "cos", "add"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'add': ['vmap_impl', 'vmap_impl', 'add'],
+ 'cos': ['vmap_impl', 'cos'],
+ 'sum_1': ['vmap_impl', 'vmap_impl', 'sum_1'],
+ 'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""",
+        )
+

 class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
     def run(self, result=None):
11 changes: 8 additions & 3 deletions test/export/test_export.py
@@ -346,7 +346,7 @@ def forward(self, x):
                 node.name in ep.graph_signature.inputs_to_buffers or
                 node.name in ep.graph_signature.inputs_to_parameters
             ):
-                self.assertTrue("source_fn" in node.meta)
+                self.assertTrue("source_fn_stack" in node.meta)
                 self.assertTrue("nn_module_stack" in node.meta)


@@ -1071,8 +1071,13 @@ def forward(self, x):
         for mod in gm.modules():
             for node in mod.graph.nodes:
                 if node.name in {"sin", "cos"}:
-                    actual_source_fns.append(node.meta.get("source_fn", None))
-        exp_source_fns = [("cos", "cos"), ("sin", "sin")]
+                    source_fn_st = node.meta.get("source_fn_stack", None)
+                    if source_fn_st is not None:
+                        source_names = []
+                        for source_fn in source_fn_st:
+                            source_names.append(source_fn[0])
+                        actual_source_fns.append(source_names)
+        exp_source_fns = [["cond", "cos"], ["cond", "sin"]]
         self.assertEqual(actual_source_fns, exp_source_fns)

     def test_lift_constants(self) -> None:
6 changes: 3 additions & 3 deletions test/export/test_serialize.py
@@ -273,10 +273,10 @@ def _check_graph_nodes(gm1, gm2, _check_meta=True):
                 #     node1.meta.get("nn_module_stack", None),
                 #     node2.meta.get("nn_module_stack", None),
                 # )
-                # Check "source_fn" metadata
+                # Check "source_fn_stack" metadata
                 self.assertEqual(
-                    node1.meta.get("source_fn", None),
-                    node2.meta.get("source_fn", None),
+                    node1.meta.get("source_fn_stack", None),
+                    node2.meta.get("source_fn_stack", None),
                 )

         _check_graph_nodes(ep.graph_module, deserialized_ep.graph_module, _check_meta)
4 changes: 2 additions & 2 deletions test/functorch/test_control_flow.py
@@ -1465,7 +1465,7 @@ def false_fn(x):
             return x * x.sin()

         def foo(x):
-            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
+            return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
         inp = torch.randn([4, 3])
         gm, _ = torch._dynamo.export(foo)(inp)

@@ -1476,7 +1476,7 @@ def run_with_interpreter(*args):


         checked_ops = {"add", "mul", "sin", "cos"}
-        checked_meta = ["source_fn", "stack_trace"]
+        checked_meta = ["source_fn_stack", "stack_trace"]
         all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
         new_source_fns = collect_meta_for_filtered_nodes(new_gm, checked_ops, checked_meta)
         self.assertEqual(all_source_fns, new_source_fns)
4 changes: 2 additions & 2 deletions test/test_fx.py
@@ -1777,13 +1777,13 @@ def forward(self, x):
             if node.op == 'get_attr':
                 node.meta["nn_module_stack"] = "self"
                 node.meta["stack_trace"] = "stack_trace"
-                node.meta["source_fn"] = "source_fn"
+                node.meta["source_fn_stack"] = "source_fn_stack"
         new_gm = Transformer(gm).transform()
         for node in new_gm.graph.nodes:
             if node.op == 'get_attr':
                 self.assertEqual(node.meta["nn_module_stack"], "self")
                 self.assertEqual(node.meta["stack_trace"], "stack_trace")
-                self.assertEqual(node.meta["source_fn"], "source_fn")
+                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")


     def test_interpreter(self):
58 changes: 47 additions & 11 deletions torch/_dynamo/output_graph.py
@@ -423,11 +423,13 @@ def remove_node(self, *args, **kwargs):
         return self.current_tracer.remove_node(*args, **kwargs)

     @contextlib.contextmanager
-    def new_subtracer(self):
+    def new_subtracer(self, source_target):
         new_scope_ctx = enter_new_scope()
         try:
             new_scope_ctx.__enter__()
-            tracer = SubgraphTracer(self, parent=self.current_tracer)
+            tracer = SubgraphTracer(
+                self, parent=self.current_tracer, source_target=source_target
+            )
             self.tracers.append(tracer)
             yield tracer
         finally:
@@ -1181,7 +1183,9 @@ class SubgraphTracer(fx.Tracer):
     compiling and executing the graph.
     """

-    def __init__(self, output_graph, parent=None, export_root=False):
+    def __init__(
+        self, output_graph, parent=None, export_root=False, source_target=None
+    ):
         super().__init__()
         self.output_graph = weakref.proxy(output_graph)
         self.graph = torch.fx.Graph()
@@ -1220,6 +1224,16 @@ def __init__(self, output_graph, parent=None, export_root=False):
         self._orig_gm_meta = None
         self._orig_gm_lineno_map = None
         self._orig_gm_firstlineno = None
+        # Each SubgraphTracer is associated with a source target, which indicates
+        # which operator this subgraph is attached to. We compute a source_fn_stack
+        # based on the source target. For the root tracer, it's set to [].
+        # This is useful for debugging and transforming the exported graph.
+        if self.parent is None:
+            self.source_fn_stack = []
+        else:
+            self.source_fn_stack = self.parent.source_fn_stack + [
+                (self.graph._target_to_str(source_target), source_target)
+            ]

     def create_proxy(
         self,
@@ -1315,6 +1329,24 @@ def get_trace_call_log_str():
             self._orig_gm_meta = None
             self._orig_gm_lineno_map = None
             self._orig_gm_firstlineno = None
+        nn_module_stack = tx.nn_module_stack
+        if nn_module_stack:
+            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
+
+        if kind in {"call_function", "call_method"}:
+            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                (rv.node.name, target)
+            ]
+        elif kind == "call_module":
+            if self.parent is not None:
+                unimplemented("Invoking an nn.Module inside HigherOrderOperator")
+            # For modules we store the class
+            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                (
+                    rv.node.name,
+                    rv.node.meta["nn_module_stack"][target][1],
+                )
+            ]

         # preserve original meta if it is available
         if (
@@ -1332,26 +1364,30 @@ def get_trace_call_log_str():
                 meta = self._orig_gm_meta[node_idx]
                 if "stack_trace" in meta:
                     rv.node.meta["stack_trace"] = meta["stack_trace"]
-                if "nn_module_stack" in meta and "source_fn" in meta:
+                if "nn_module_stack" in meta and "source_fn_stack" in meta:
                     rv.node.meta["nn_module_stack"] = meta["nn_module_stack"]
-                    rv.node.meta["source_fn"] = meta["source_fn"]
+                    rv.node.meta["source_fn_stack"] = meta["source_fn_stack"]

         if "nn_module_stack" not in rv.node.meta:
             nn_module_stack = tx.nn_module_stack
             if nn_module_stack:
                 rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

-        if "source_fn" not in rv.node.meta:
+        if "source_fn_stack" not in rv.node.meta:
             if kind in {"call_function", "call_method"}:
-                rv.node.meta["source_fn"] = (rv.node.name, target)
+                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                    (rv.node.name, target)
+                ]
             elif kind == "call_module":
                 if self.parent is not None:
                     unimplemented("Invoking an nn.Module inside HigherOrderOperator")
                 # For modules we store the class
-                rv.node.meta["source_fn"] = (
-                    rv.node.name,
-                    rv.node.meta["nn_module_stack"][target][1],
-                )
+                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                    (
+                        rv.node.name,
+                        rv.node.meta["nn_module_stack"][target][1],
+                    )
+                ]

         if "stack_trace" not in rv.node.meta:
             frame_summaries: List[traceback.FrameSummary] = []
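Taken together, the last two hunks implement the stack semantics described in the summary: each nested SubgraphTracer extends its parent's source_fn_stack with the higher-order operator that spawned it, and create_proxy appends the node's own (name, target) entry when stamping node.meta. The following standalone sketch mimics that composition rule; TracerSketch and node_meta are made-up names for illustration, not the actual Dynamo classes:
```python
class TracerSketch:
    """Toy model of how nested SubgraphTracers compose source_fn_stack."""

    def __init__(self, parent=None, source_target=None):
        if parent is None:
            # Root tracer: empty stack.
            self.source_fn_stack = []
        else:
            # Child tracer: parent's stack plus the higher-order op
            # (e.g. "cond" or "map") that created this subgraph.
            self.source_fn_stack = parent.source_fn_stack + [
                (str(source_target), source_target)
            ]

    def node_meta(self, node_name, target):
        # What create_proxy stores on each node: the tracer's stack
        # with the node's own (name, target) appended at the end.
        return self.source_fn_stack + [(node_name, target)]


# A sin() traced two cond levels deep ends up with a three-entry stack.
root = TracerSketch()
outer = TracerSketch(parent=root, source_target="cond")
inner = TracerSketch(parent=outer, source_target="cond")
assert [name for name, _ in inner.node_meta("sin", "sin")] == ["cond", "cond", "sin"]
```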
