Update on "recompile fx.GraphModule lazily"
Context: eellison's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at the time, since that would increase the risk of users seeing or running stale Python code. In this PR, I recompile the GraphModule lazily without increasing that risk.
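
To illustrate the staleness risk, here is a minimal sketch using stock `torch.fx` (not code from this PR): if you mutate a `GraphModule`'s graph and skip `recompile`, the previously generated `forward`/`code` keep executing the old graph.

```python
import torch
from torch import fx

def f(x):
    return x.sin()

gm = fx.symbolic_trace(f)

# Mutate the graph: swap sin for cos.
for node in gm.graph.nodes:
    if node.op == "call_method" and node.target == "sin":
        node.target = "cos"

# gm.forward (and gm.code) were generated before the mutation,
# so they still run sin() -- stale relative to gm.graph.
x = torch.randn(3)
assert torch.allclose(gm(x), x.sin())  # stale behavior

gm.recompile()  # regenerate forward/code from the mutated graph
assert torch.allclose(gm(x), x.cos())
```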

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total, out of roughly 60 seconds for the whole compilation.

By spot checking, I found that the main reason we call `recompile` so frequently is the inductor pattern matcher. E.g., to replace `src_fn` with `dst_fn`, we need to trace both `src_fn` and `dst_fn`. Each trace produces a `GraphModule`, and the `__init__` method of `GraphModule` calls `recompile`.
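
A rough sketch of why pattern replacement is recompile-heavy (illustrative, not the inductor pattern matcher itself):

```python
import torch
from torch import fx

def src_fn(x):
    return torch.sin(x) + 1

def dst_fn(x):
    return torch.cos(x) + 1

def f(x):
    return torch.sin(x) + 1

# Every symbolic_trace builds a GraphModule, and GraphModule.__init__
# eagerly calls recompile() -- codegen runs even though the generated
# Python for these tiny pattern graphs may never be executed.
src_gm = fx.symbolic_trace(src_fn)
dst_gm = fx.symbolic_trace(dst_fn)

# fx.replace_pattern traces its pattern/replacement callables
# internally too, so each rewrite pays the construction + recompile
# cost again.
gm = fx.symbolic_trace(f)
fx.replace_pattern(gm, src_fn, dst_fn)
```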

By recompiling lazily, we reduce the number of calls to `GraphModule._real_recompile` to 37 and its total execution time to 0.045s. (In this PR, `recompile` just marks the module as needing recompilation and is very lightweight; `_real_recompile` does the actual recompilation.)
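
The deferral follows a familiar dirty-flag pattern. A hypothetical sketch of the bookkeeping (the names mirror the PR, but the bodies are illustrative, not the actual implementation):

```python
class LazyRecompileMixin:
    """Hypothetical sketch of the dirty-flag scheme used for lazy recompile."""

    _needs_recompile_flag: bool = False

    def recompile(self):
        # Cheap: only record that the generated code is out of date.
        self._needs_recompile_flag = True

    def _needs_recompile(self) -> bool:
        return self._needs_recompile_flag

    def real_recompile(self):
        # Called from entry points that expose generated code
        # (forward, .code, in_spec/out_spec).
        if self._needs_recompile():
            self._real_recompile()

    def _real_recompile(self):
        # Expensive: regenerate the Python code from the graph.
        self._needs_recompile_flag = False
        # ... real GraphModule codegen would happen here ...
```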



[ghstack-poisoned]
shunting314 committed Jul 17, 2023
2 parents 01a5e32 + 90c8cd5 commit 4694c6d
Showing 4 changed files with 38 additions and 3 deletions.
10 changes: 10 additions & 0 deletions test/fx/test_lazy_recompile.py
@@ -3,6 +3,7 @@
from torch.testing._internal.common_utils import TestCase, run_tests
from torch import fx
import torch
import torch._export

class TestLazyRecompile(TestCase):
def test_replace_sin_with_cos(self):
@@ -42,6 +43,15 @@ def f(x):
print(f"sin {x.sin()}, cos {x.cos()}, expected {expected}, actual {actual}")
self.assertTrue(torch.allclose(expected, actual))

def test_export(self):
"""
torch.export will access GraphModule._out_spec. Make sure we generate it
if we have not done so yet.
"""
def f(x):
return x.sin()
gm = torch._export.export(f, (torch.randn(2, 3),))
self.assertTrue(isinstance(gm, torch._export.ExportedProgram))

if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion torch/_export/__init__.py
@@ -212,7 +212,7 @@ def export(
# because aot_export expects a tuple as return type
return_val = f(*args)
flat_args, in_spec = pytree.tree_flatten(args)
out_spec = orig_out_spec = gm_torch_level._out_spec
out_spec = orig_out_spec = gm_torch_level.out_spec
# this means it is scalar return value, so will make it tuple
if not isinstance(return_val, (list, tuple)):
out_spec = pytree.tree_flatten((return_val,))[1]
5 changes: 5 additions & 0 deletions torch/_functorch/compilers.py
@@ -183,6 +183,11 @@ def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
@make_boxed_compiler
def simple_ts_compile(fx_g, _):
strip_overloads(fx_g)

# realize the lazy recompilation to make jit.script happy.
if fx_g._needs_recompile():
fx_g._real_recompile()

f = torch.jit.script(fx_g)
f = torch.jit.freeze(f.eval())
return f
24 changes: 22 additions & 2 deletions torch/fx/graph_module.py
@@ -642,13 +642,22 @@ def code(self) -> str:
Return the Python code generated from the ``Graph`` underlying this
``GraphModule``.
"""
if self._needs_recompile():
self._real_recompile()

self.real_recompile()
if not hasattr(self, '_code'):
raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
return self._code

@property
def in_spec(self):
self.real_recompile()
return getattr(self, "_in_spec", None)

@property
def out_spec(self):
self.real_recompile()
return getattr(self, "_out_spec", None)

@compatibility(is_backward_compatible=True)
@classmethod
def recompile(cls):

[GitHub Actions / bc_linter notice on line 663 in torch/fx/graph_module.py: Function GraphModule.recompile: self was renamed to cls]
@@ -665,6 +674,17 @@ def _lazy_forward(self, *args, **kwargs):

forward = _lazy_forward

def real_recompile(self):
"""
A TorchScript-safe wrapper around _real_recompile.
Call _real_recompile only if we have not done so since the last
change to the fx.Graph.
"""
# JIT scripting cannot handle `_needs_recompile` or `_real_recompile`.
if not torch.jit.is_scripting():
if self._needs_recompile():
self._real_recompile()

def _real_recompile(self) -> PythonCode:
"""
Recompile this GraphModule from its ``graph`` attribute. This should be
