1- import copy
21import inspect
32import logging
43import traceback
5- import types
64from collections import namedtuple
5+ from dataclasses import dataclass
76from typing import Any , Callable , Optional , TYPE_CHECKING , Union
87
98import sympy
2322 DimDynamic ,
2423 StatelessSymbolicContext ,
2524)
26- from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
25+ from torch .fx .graph import _ExportCodeGen , _PyTreeCodeGen , _PyTreeInfo
26+ from torch .utils ._pytree import TreeSpec
2727
2828
2929if TYPE_CHECKING :
3030 from torch ._subclasses .fake_tensor import FakeTensorMode
31- from torch .utils ._pytree import TreeSpec
3231
3332
3433log = logging .getLogger (__name__ )
@@ -449,9 +448,20 @@ def _suggest_or_raise_constraint_violation(
449448 raise constraint_violation_error
450449
451450
451+ @dataclass (frozen = True )
452+ class PyTreeifyOutput :
453+ graph_module : torch .fx .GraphModule
454+ in_spec : TreeSpec
455+ in_shuffle_graph : torch .fx .GraphModule
456+ num_flat_args : int
457+ out_spec : TreeSpec
458+ out_shuffle_graph : torch .fx .GraphModule
459+ root : Optional [torch .nn .Module ] = None
460+
461+
452462def pytreeify (
453463 out : CaptureOutput , mod : Any , args : tuple [Any , ...], kwargs : dict [str , Any ]
454- ) -> Any :
464+ ) -> PyTreeifyOutput :
455465 """
456466 Given a dynamo capture output, return a callable graph module that
457467 contain the following information:
@@ -469,10 +479,13 @@ def pytreeify(
469479 backend_input = out .backend_input
470480 backend = out .backend_input .graph_module
471481
482+ root = None
472483 if isinstance (mod , torch .nn .Module ):
473484 args = (mod ,) + args
485+ root = mod
474486 elif inspect .ismethod (mod ):
475487 args = (mod .__self__ ,) + args
488+ root = mod .__self__
476489
477490 flat_real_args , in_spec = pytree .tree_flatten ((args , kwargs ))
478491
@@ -538,47 +551,73 @@ def backend_dummy(*example_inputs):
538551 out_shuffle = OutShuffle ()
539552 out_shuffle_graph = torch .fx .symbolic_trace (out_shuffle )
540553
541- def pytree_call (* args , ** kwargs ):
542- import torch .export ._unlift
554+ assert out_shuffle .out_spec is not None
555+ return PyTreeifyOutput (
556+ backend_input .graph_module ,
557+ in_spec ,
558+ in_shuffle_graph ,
559+ len (flat_real_args ),
560+ out_shuffle .out_spec ,
561+ out_shuffle_graph ,
562+ root = root , # type: ignore[arg-type]
563+ )
543564
544- flat_args , in_spec_runtime = pytree .tree_flatten ((args , kwargs ))
545- if not torch .export ._unlift .eq_spec (in_spec_runtime , in_spec ):
546- raise RuntimeError (
547- f"Model input mismatch. Expected input spec: { in_spec } . Actual input spec: { in_spec_runtime } "
548- )
549- flat_outs = backend_input .graph_module (* in_shuffle_graph (* flat_args ))
550- assert out_shuffle .out_spec is not None
551- return pytree .tree_unflatten (
552- out_shuffle_graph (* flat_args , * flat_outs ), out_shuffle .out_spec
553- )
554565
555- if isinstance (mod , torch .nn .Module ):
556- compiled_mod = copy .copy (mod )
557- compiled_mod .forward = types .MethodType (pytree_call , compiled_mod )
558- if not hasattr (compiled_mod , "meta" ):
559- compiled_mod .meta = {} # type: ignore[attr-defined]
560- if isinstance (compiled_mod .meta , dict ) and "fake_mode" not in compiled_mod .meta :
561- compiled_mod .meta ["fake_mode" ] = out .backend_input .fake_mode
562- return compiled_mod
563- elif inspect .ismethod (mod ):
564- return types .MethodType (pytree_call , mod .__self__ )
565- else :
566- return pytree_call
566+ def normalize_graph_module (gm ):
567+ for node in gm .graph .nodes :
568+ if node .op == "placeholder" :
569+ node .meta ["val" ] = node .meta ["example_value" ]
567570
568571
569572def dynamo_graph_capture_for_export (
570573 mod : Callable [..., Any ],
574+ constraints : Optional [list [Constraint ]] = None ,
571575) -> Callable [..., Any ]:
572576 def inner (* args : Any , ** kwargs : Any ) -> Any :
573577 with (
574578 get_metrics_context (),
575579 dynamo_timed ("fullgraph_capture" ),
576580 ):
577- out = fullgraph_capture (mod , args , kwargs )
581+ out = fullgraph_capture (
582+ mod ,
583+ args ,
584+ kwargs ,
585+ constraints = constraints ,
586+ _is_export_deprecated_do_not_use = True ,
587+ )
578588
579589 # TODO filter out side effects.
580-
581- return pytreeify (out , mod , args , kwargs )
590+ pyt = pytreeify (out , mod , args , kwargs )
591+
592+ graph_module = pyt .graph_module
593+ tree_leaf_names = [
594+ graph_module .graph ._graph_namespace .create_name (f"_tree_leaf_{ i } " , None )
595+ for i in range (pyt .num_flat_args )
596+ ]
597+ graph_module .graph ._codegen = _ExportCodeGen (
598+ _PyTreeInfo (
599+ argument_names (inspect .signature (mod ), args , kwargs ),
600+ pyt .in_spec ,
601+ pyt .out_spec ,
602+ ),
603+ pyt .in_shuffle_graph ,
604+ pyt .out_shuffle_graph ,
605+ tree_leaf_names ,
606+ pyt .root ,
607+ ) # type: ignore[attr-defined]
608+ normalize_graph_module (graph_module )
609+ graph_module ._in_spec = pyt .in_spec
610+ graph_module ._out_spec = pyt .out_spec
611+ graph_module ._in_shuffle_graph = pyt .in_shuffle_graph
612+ graph_module ._out_shuffle_graph = pyt .out_shuffle_graph
613+ object .__setattr__ (graph_module , "_root" , pyt .root ) # type: ignore[attr-defined]
614+ graph_module .recompile ()
615+ graph_module .meta ["module_call_specs" ] = (
616+ out .graph_capture_output .output_graph .export_metadata .module_call_spec
617+ )
618+ assert out .backend_input is not None
619+ graph_module .meta ["fake_mode" ] = out .backend_input .fake_mode # type: ignore[attr-defined]
620+ return graph_module
582621
583622 return inner
584623
0 commit comments