Update on "recompile fx.GraphModule lazily"
Context: eellison's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time since that would increase the risk that users see or run stale Python code. In this PR, I recompile the GraphModule lazily without increasing the risk that users see or run stale Python code.

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total, while the whole compilation takes around 60 seconds.

By spot checking, I found the main reason we call recompile so frequently is the inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need to trace both src_fn and dst_fn. After tracing is done, we create a GraphModule for each, and the `__init__` method of GraphModule calls recompile.
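To make the cost concrete, here is a rough illustration (a sketch, not the actual inductor pattern-matcher code; `src_fn`/`dst_fn` are made-up examples): tracing a pattern and its replacement each produces a `GraphModule`, and constructing a `GraphModule` calls `recompile`.

```python
import torch.fx as fx

# Hypothetical pattern/replacement pair, only to show where the recompiles come from.
def src_fn(x):
    return (x + 1) * 2

def dst_fn(x):
    return x * 2 + 2

# symbolic_trace builds a GraphModule for each function; GraphModule.__init__
# calls recompile() to generate Python source, so every traced pattern pays
# the codegen cost even if the generated code is never read.
src_gm = fx.symbolic_trace(src_fn)
dst_gm = fx.symbolic_trace(dst_fn)
print(src_gm.code)
```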

By recompiling lazily, we reduce the number of calls to `GraphModule._real_recompile` to 37 and reduce its total execution time to 0.045s (in this PR, `recompile` just marks the class as needing recompilation and is very lightweight; `_real_recompile` does the real recompilation).
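Here is a minimal sketch of the lazy-recompile idea (illustrative only; the class name `LazilyRecompiledModule` and the flag `_needs_recompile` are made up for this example, while `Graph.python_code(root_module="self").src` is the same fx API the real `_real_recompile` uses):

```python
import torch.fx as fx

class LazilyRecompiledModule:
    """Sketch: recompile() only marks the generated code as stale;
    _real_recompile() regenerates it; accessors realize the pending
    recompilation so callers never see or run stale code."""

    def __init__(self, graph: fx.Graph):
        self._graph = graph
        self._code = None
        self._needs_recompile = False  # hypothetical flag name
        self.recompile()               # GraphModule.__init__ also calls recompile()

    def recompile(self):
        # Cheap: O(1) bookkeeping instead of regenerating Python source.
        self._needs_recompile = True

    def real_recompile(self):
        if self._needs_recompile:
            self._real_recompile()

    def _real_recompile(self):
        # Expensive: codegen from the graph.
        self._code = self._graph.python_code(root_module="self").src
        self._needs_recompile = False

    @property
    def code(self):
        self.real_recompile()  # realized lazily, right before the code is read
        return self._code

def f(x):
    return x + 1

gm = fx.symbolic_trace(f)
lazy = LazilyRecompiledModule(gm.graph)
for _ in range(100):
    lazy.recompile()   # 100 cheap calls
print(lazy.code)       # exactly one real codegen happens here
```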



cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
shunting314 committed Jul 18, 2023
2 parents 0b7b0f5 + 0697ed5 commit 6c3ed1d
Showing 4 changed files with 36 additions and 15 deletions.
2 changes: 1 addition & 1 deletion test/quantization/fx/test_numeric_suite_fx.py
@@ -1275,7 +1275,7 @@ def forward(self, x):
             m, (torch.randn(1, 1, 2, 2),),
             results_len=2)
 
-    @skipIfNoFBGEMM
+    # @skipIfNoFBGEMM
     def test_add_mul_inputs_activations(self):
         m = AddMulFunctional().eval()
         res = self._test_match_activations(
4 changes: 2 additions & 2 deletions torch/_export/__init__.py
@@ -212,7 +212,7 @@ def export(
     # because aot_export expects a tuple as return type
     return_val = f(*args)
     flat_args, in_spec = pytree.tree_flatten(args)
-    out_spec = orig_out_spec = gm_torch_level.out_spec
+    out_spec = orig_out_spec = gm_torch_level.graph_out_spec
     # this means it is scalar return value, so will make it tuple
     if not isinstance(return_val, (list, tuple)):
         out_spec = pytree.tree_flatten((return_val,))[1]
@@ -222,7 +222,7 @@ def export(
     gm_torch_level.graph._codegen = _PyTreeCodeGen(
         _PyTreeInfo(
             orig_args,
-            gm_torch_level._in_spec,
+            gm_torch_level.graph_in_spec,
             out_spec,
         )
     )
9 changes: 5 additions & 4 deletions torch/fx/_symbolic_trace.py
@@ -721,6 +721,11 @@ def trace(
             A ``Graph`` representing the semantics of the passed-in ``root``.
         """
+        if isinstance(root, GraphModule):
+            # If we are retracing a GraphModule, make sure we realize the lazy
+            # recompilation.
+            root.real_recompile()
+
         global _is_fx_tracing_flag
         old_is_fx_tracing_flag = _is_fx_tracing_flag
         _is_fx_tracing_flag = True
@@ -1146,10 +1151,6 @@ def f(x):
     Returns:
         GraphModule: a Module created from the recorded operations from ``root``.
     """
-    if isinstance(root, GraphModule):
-        # If we retracing a GraphModule, make sure we realize the lazy
-        # recompilation.
-        root.real_recompile()
     tracer = Tracer()
     graph = tracer.trace(root, concrete_args)
     name = (
36 changes: 28 additions & 8 deletions torch/fx/graph_module.py
@@ -649,14 +649,34 @@ def code(self) -> str:
         return self._code
 
     @property
-    def in_spec(self):
+    def graph_in_spec(self):
         self.real_recompile()
-        return getattr(self, "_in_spec", None)
+        # Even after recompilation, _graph_in_spec may still be undefined
+        # if self._graph._codegen is not a _PyTreeCodeGen, so we need to provide
+        # a default value for getattr.
+        return getattr(self, "_graph_in_spec", None)
+
+    @property
+    def _in_spec(self):
+        """
+        Deprecated. Use graph_in_spec instead.
+        """
+        return self.graph_in_spec
 
     @property
-    def out_spec(self):
+    def graph_out_spec(self):
         self.real_recompile()
-        return getattr(self, "_out_spec", None)
+        # Even after recompilation, _graph_out_spec may still be undefined
+        # if self._graph._codegen is not a _PyTreeCodeGen, so we need to provide
+        # a default value for getattr.
+        return getattr(self, "_graph_out_spec", None)
+
+    @property
+    def _out_spec(self):
+        """
+        Deprecated. Use graph_out_spec instead.
+        """
+        return self.graph_out_spec
 
     @compatibility(is_backward_compatible=True)
     @classmethod
@@ -692,8 +712,8 @@ def _real_recompile(self) -> PythonCode:
         code of this ``GraphModule`` will be out of date.
         """
         if isinstance(self._graph._codegen, _PyTreeCodeGen):
-            self._in_spec = self._graph._codegen.pytree_info.in_spec
-            self._out_spec = self._graph._codegen.pytree_info.out_spec
+            self._graph_in_spec = self._graph._codegen.pytree_info.in_spec
+            self._graph_out_spec = self._graph._codegen.pytree_info.out_spec
         python_code = self._graph.python_code(root_module='self')
         self._code = python_code.src
 
@@ -726,7 +746,7 @@ def __reduce_deploy__(self, importer: Importer):
         dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
         del dict_without_graph['_graph']
 
-        python_code = self.recompile()
+        python_code = self._real_recompile()
         import_block = _format_import_block(python_code.globals, importer)
         return (reduce_deploy_graph_module, (dict_without_graph, import_block))
 
@@ -736,7 +756,7 @@ def __reduce_package__(self, exporter: PackageExporter):
         del dict_without_graph['_graph']
 
         generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
-        python_code = self.recompile()
+        python_code = self._real_recompile()
         import_block = _format_import_block(python_code.globals, exporter.importer)
         module_code = import_block + self.code
         exporter.save_source_string(generated_module_name, module_code)