
Commit 4331a46
[export] Update dynamo_graph_capture_for_export to return GraphModule.
1 parent: 757975a
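
This commit changes dynamo_graph_capture_for_export to return the captured torch.fx.GraphModule directly, with the pytree input/output handling baked into its generated code, instead of a patched copy of the user module whose forward did the flatten/unflatten at call time. A minimal usage sketch mirroring the new test below; the model and shapes are illustrative:

    import torch
    from torch._dynamo.functional_export import dynamo_graph_capture_for_export

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return x[:-1, :] * 2

    inps = (torch.randn(10, 32),)
    gm = dynamo_graph_capture_for_export(MyModel())(*inps)
    # The capture now returns a torch.fx.GraphModule, not a module copy.
    assert isinstance(gm, torch.fx.GraphModule)
    assert torch.equal(gm(*inps), MyModel()(*inps))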

File tree: 9 files changed, +186 -77 lines

test/distributed/tensor/test_dtensor_export.py
Lines changed: 1 addition & 1 deletion

@@ -323,7 +323,7 @@ def test_annotate_aot_export_joint_with_descriptors_alone(self):
         [
             (
                 graph_capture_and_aot_export_joint_with_descriptors_v2,
-                "[[4, 10], [4], [10, 4], [10], [4, 10], [4], [10, 4], [10], [s64, 10], [s64, 10]]",
+                "[[4, 10], [4], [10, 4], [10], [s49, 10], [s49, 10]]",
             ),
             (
                 graph_capture_and_aot_export_joint_with_descriptors,

test/export/test_experimental.py
Lines changed: 17 additions & 2 deletions

@@ -7,6 +7,7 @@
 
 import torch
 import torch._dynamo
+from torch._dynamo.functional_export import dynamo_graph_capture_for_export
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export

@@ -403,8 +404,6 @@ def forward(self, x):
         self.assertEqual(res_export, res_eager)
 
     def test_dynamo_graph_capture(self):
-        from torch._dynamo.functional_export import dynamo_graph_capture_for_export
-
         class Foo(torch.nn.Module):
             def forward(self, dct, lst, bleh):
                 x = dct["a"] * lst[1][0]

@@ -439,6 +438,22 @@ def make_inputs():
         test_inputs = make_inputs()
         self.assertEqual(gm(*test_inputs), foo(*test_inputs))
 
+    def test_dynamo_graph_capture_closure(self):
+        from torch.export import Dim
+
+        N = 3
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x):
+                y = x[:-1, :]  # [s0 - 1, 32]
+                stacked = torch.stack([y] * N, dim=0)  # [N * (s0 - 1), 32]
+                reshaped = stacked.reshape(-1, N, 32)  # [(s0 - 1), N, 32]
+                return reshaped
+
+        inps = (torch.randn(10, 32),)
+        ep = dynamo_graph_capture_for_export(MyModel())(*inps)
+        self.assertEqual(ep(*inps), MyModel()(*inps))
+
 
 if __name__ == "__main__":
     run_tests()
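
Note: in test_dynamo_graph_capture_closure, `N` is a local of the enclosing test function, so `MyModel.forward` closes over it and its code object carries a closure cell; this is presumably what the closure plumbing in the changes below is exercising. A self-contained illustration of that Python behavior (not code from this commit):

    import torch

    def outer():
        N = 3

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.stack([x] * N)  # N is a free variable here

        return M

    # forward's code object records the closed-over name as a free variable.
    assert outer()().forward.__code__.co_freevars == ("N",)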

torch/_dynamo/convert_frame.py
Lines changed: 3 additions & 1 deletion

@@ -894,6 +894,7 @@ def graph_capture_output(self) -> GraphCaptureOutput:
             output_graph.import_sources,
             output_graph.traced_code,
             self.bytecode,
+            self.tracer_output.closure,
         )
 
 
@@ -925,6 +926,7 @@ class GraphCaptureOutput:
     import_sources: dict[str, str]
     traced_code: list[CodeType]
     bytecode: CodeType
+    closure: Optional[tuple[Any, ...]]
 
     def build_guards(
         self,

@@ -979,7 +981,7 @@ def forward_callable(self) -> Callable[..., Any]:
         return types.FunctionType(
             self.graph_capture_output.bytecode,
             f_globals,
-            closure=(),
+            closure=self.graph_capture_output.closure,
         )
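
Threading the traced frame's closure through GraphCaptureOutput matters because types.FunctionType validates the closure against the code object's free variables: rebuilding bytecode that references closed-over names with closure=() raises a TypeError. A standalone illustration of that CPython behavior (not code from this commit):

    import types

    def make():
        n = 3

        def f(x):
            return x * n  # n is a free variable: f.__code__.co_freevars == ("n",)

        return f

    f = make()
    # A matching closure lets the function be rebuilt from its code object...
    g = types.FunctionType(f.__code__, globals(), closure=f.__closure__)
    assert g(2) == 6
    # ...while an empty closure is rejected for code with free variables.
    try:
        types.FunctionType(f.__code__, globals(), closure=())
    except TypeError:
        pass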

torch/_dynamo/eval_frame.py
Lines changed: 3 additions & 1 deletion

@@ -1211,7 +1211,9 @@ def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
 
         # Make dynamo graph to have same input/output spec as user code
         def argument_names(
-            f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
+            f_sig: inspect.Signature,
+            args: Union[list[Any], tuple[Any, ...]],
+            kwargs: dict[str, Any],
         ) -> list[str]:
             def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
                 # Get a list of Parameter objects from the Signature object

torch/_dynamo/functional_export.py
Lines changed: 71 additions & 32 deletions

@@ -1,9 +1,8 @@
-import copy
 import inspect
 import logging
 import traceback
-import types
 from collections import namedtuple
+from dataclasses import dataclass
 from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import sympy

@@ -23,12 +22,12 @@
     DimDynamic,
     StatelessSymbolicContext,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo
+from torch.utils._pytree import TreeSpec
 
 
 if TYPE_CHECKING:
     from torch._subclasses.fake_tensor import FakeTensorMode
-    from torch.utils._pytree import TreeSpec
 
 
 log = logging.getLogger(__name__)

@@ -449,9 +448,20 @@ def _suggest_or_raise_constraint_violation(
         raise constraint_violation_error
 
 
+@dataclass(frozen=True)
+class PyTreeifyOutput:
+    graph_module: torch.fx.GraphModule
+    in_spec: TreeSpec
+    in_shuffle_graph: torch.fx.GraphModule
+    num_flat_args: int
+    out_spec: TreeSpec
+    out_shuffle_graph: torch.fx.GraphModule
+    root: Optional[torch.nn.Module] = None
+
+
 def pytreeify(
     out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any]
-) -> Any:
+) -> PyTreeifyOutput:
     """
     Given a dynamo capture output, return a callable graph module that
     contain the following information:

@@ -469,10 +479,13 @@ def pytreeify(
     backend_input = out.backend_input
     backend = out.backend_input.graph_module
 
+    root = None
     if isinstance(mod, torch.nn.Module):
         args = (mod,) + args
+        root = mod
     elif inspect.ismethod(mod):
         args = (mod.__self__,) + args
+        root = mod.__self__
 
     flat_real_args, in_spec = pytree.tree_flatten((args, kwargs))
 
@@ -538,47 +551,73 @@ def backend_dummy(*example_inputs):
     out_shuffle = OutShuffle()
     out_shuffle_graph = torch.fx.symbolic_trace(out_shuffle)
 
-    def pytree_call(*args, **kwargs):
-        import torch.export._unlift
+    assert out_shuffle.out_spec is not None
+    return PyTreeifyOutput(
+        backend_input.graph_module,
+        in_spec,
+        in_shuffle_graph,
+        len(flat_real_args),
+        out_shuffle.out_spec,
+        out_shuffle_graph,
+        root=root,  # type: ignore[arg-type]
+    )
 
-        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
-        if not torch.export._unlift.eq_spec(in_spec_runtime, in_spec):
-            raise RuntimeError(
-                f"Model input mismatch. Expected input spec: {in_spec}. Actual input spec: {in_spec_runtime}"
-            )
-        flat_outs = backend_input.graph_module(*in_shuffle_graph(*flat_args))
-        assert out_shuffle.out_spec is not None
-        return pytree.tree_unflatten(
-            out_shuffle_graph(*flat_args, *flat_outs), out_shuffle.out_spec
-        )
 
-    if isinstance(mod, torch.nn.Module):
-        compiled_mod = copy.copy(mod)
-        compiled_mod.forward = types.MethodType(pytree_call, compiled_mod)
-        if not hasattr(compiled_mod, "meta"):
-            compiled_mod.meta = {}  # type: ignore[attr-defined]
-        if isinstance(compiled_mod.meta, dict) and "fake_mode" not in compiled_mod.meta:
-            compiled_mod.meta["fake_mode"] = out.backend_input.fake_mode
-        return compiled_mod
-    elif inspect.ismethod(mod):
-        return types.MethodType(pytree_call, mod.__self__)
-    else:
-        return pytree_call
+def normalize_graph_module(gm):
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            node.meta["val"] = node.meta["example_value"]
 
 
 def dynamo_graph_capture_for_export(
     mod: Callable[..., Any],
+    constraints: Optional[list[Constraint]] = None,
 ) -> Callable[..., Any]:
     def inner(*args: Any, **kwargs: Any) -> Any:
         with (
             get_metrics_context(),
             dynamo_timed("fullgraph_capture"),
         ):
-            out = fullgraph_capture(mod, args, kwargs)
+            out = fullgraph_capture(
+                mod,
+                args,
+                kwargs,
+                constraints=constraints,
+                _is_export_deprecated_do_not_use=True,
+            )
 
         # TODO filter out side effects.
-
-        return pytreeify(out, mod, args, kwargs)
+        pyt = pytreeify(out, mod, args, kwargs)
+
+        graph_module = pyt.graph_module
+        tree_leaf_names = [
+            graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None)
+            for i in range(pyt.num_flat_args)
+        ]
+        graph_module.graph._codegen = _ExportCodeGen(
+            _PyTreeInfo(
+                argument_names(inspect.signature(mod), args, kwargs),
+                pyt.in_spec,
+                pyt.out_spec,
+            ),
+            pyt.in_shuffle_graph,
+            pyt.out_shuffle_graph,
+            tree_leaf_names,
+            pyt.root,
+        )  # type: ignore[attr-defined]
+        normalize_graph_module(graph_module)
+        graph_module._in_spec = pyt.in_spec
+        graph_module._out_spec = pyt.out_spec
+        graph_module._in_shuffle_graph = pyt.in_shuffle_graph
+        graph_module._out_shuffle_graph = pyt.out_shuffle_graph
+        object.__setattr__(graph_module, "_root", pyt.root)  # type: ignore[attr-defined]
+        graph_module.recompile()
+        graph_module.meta["module_call_specs"] = (
+            out.graph_capture_output.output_graph.export_metadata.module_call_spec
+        )
+        assert out.backend_input is not None
+        graph_module.meta["fake_mode"] = out.backend_input.fake_mode  # type: ignore[attr-defined]
+        return graph_module
 
     return inner
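
The deleted pytree_call spells out the runtime contract that the new codegen path bakes into the GraphModule's generated forward: flatten (args, kwargs) with in_spec, reorder the leaves via in_shuffle_graph, run the captured graph, then reorder via out_shuffle_graph and unflatten with out_spec. A sketch of that data flow against the new PyTreeifyOutput, assuming a graph module that still uses the flat calling convention (i.e. before the _ExportCodeGen swap above):

    import torch.utils._pytree as pytree

    def call_by_hand(pyt, *args, **kwargs):
        # Flatten user inputs into leaves, in in_spec order.
        flat_args, in_spec_runtime = pytree.tree_flatten((args, kwargs))
        assert in_spec_runtime is not None  # the deleted code compared this to pyt.in_spec
        # Reorder leaves for the graph, run it, then rebuild the user-facing output.
        flat_outs = pyt.graph_module(*pyt.in_shuffle_graph(*flat_args))
        return pytree.tree_unflatten(
            pyt.out_shuffle_graph(*flat_args, *flat_outs), pyt.out_spec
        )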

torch/_dynamo/output_graph.py
Lines changed: 2 additions & 0 deletions

@@ -2741,12 +2741,14 @@ class DynamoTracerOutput:
     error_on_graph_break: bool
     is_tracing_resume_prologue: bool
     output_graph: Optional[OutputGraph]
+    closure: Optional[tuple[Any, ...]]
 
     def __init__(
         self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
     ) -> None:
         self.error_on_graph_break = tracer.error_on_graph_break
         self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
+        self.closure = tracer.closure
         if error:
             self.output_graph = None
         else:

torch/_dynamo/symbolic_convert.py
Lines changed: 1 addition & 0 deletions

@@ -4119,6 +4119,7 @@ def __init__(
         self.f_builtins: dict[str, Any] = f_builtins
         self.code_options: dict[str, Any] = code_options
         self.f_code: types.CodeType = f_code
+        self.closure = closure
 
         # Execution record for replaying errors
        if closure is not None and config.replay_record_enabled:

torch/export/_trace.py
Lines changed: 5 additions & 7 deletions

@@ -97,7 +97,7 @@
     GuardOnDataDependentSymNode,
     ShapeEnv,
 )
-from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.fx.graph import _PyTreeInfo
 from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
 

@@ -1486,12 +1486,10 @@ def _strict_export(
 
     orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]
 
-    gm_torch_level.graph._codegen = _PyTreeCodeGen(
-        _PyTreeInfo(
-            orig_arg_names,
-            gm_torch_level._in_spec,
-            out_spec,
-        )
+    gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo(
+        orig_arg_names,
+        gm_torch_level._in_spec,
+        out_spec,
     )
     gm_torch_level.recompile()
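
Mutating pytree_info on the existing codegen, rather than constructing a fresh _PyTreeCodeGen, appears intended to preserve whatever codegen the graph already carries (e.g. the _ExportCodeGen installed by the new capture path) while still updating the output spec. A toy illustration of the difference, with hypothetical stand-in classes:

    class TreeCodeGen:  # stand-in for a codegen like _PyTreeCodeGen
        def __init__(self, pytree_info):
            self.pytree_info = pytree_info

    class ExportCodeGen(TreeCodeGen):  # stand-in for a richer subclass
        pass

    codegen = ExportCodeGen(pytree_info="old spec")
    codegen.pytree_info = "new spec"  # in-place update keeps the subclass
    assert isinstance(codegen, ExportCodeGen)

    rebuilt = TreeCodeGen(pytree_info="new spec")  # rebuilding would not
    assert not isinstance(rebuilt, ExportCodeGen)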
