Update on "recompile fx.GraphModule lazily"
Context: eellison's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time, since that would increase the risk that users see or run stale Python code. In this PR, I recompile the GraphModule lazily without increasing that risk.

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total; the whole compilation takes around 60 seconds.

By spot checking, I found that the main reason we call `recompile` so frequently is the inductor pattern matcher. E.g., if we want to replace `src_fn` with `dst_fn`, we need to trace both `src_fn` and `dst_fn`. After tracing is done, we create a GraphModule, and the `__init__` method of GraphModule calls `recompile`.
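As a minimal illustration of this (not the pattern matcher itself; `src_pattern`/`dst_pattern` are made-up names), each `symbolic_trace` call below constructs a GraphModule, and constructing one triggers a `recompile`:

```python
import torch
import torch.fx as fx

def src_pattern(x):
    # pattern to search for
    return torch.add(x, x)

def dst_pattern(x):
    # replacement pattern
    return torch.mul(x, 2)

# Each symbolic_trace call builds a GraphModule; GraphModule.__init__
# calls recompile() to generate Python code for the traced graph.
src_gm = fx.symbolic_trace(src_pattern)
dst_gm = fx.symbolic_trace(dst_pattern)
print(src_gm.code)  # the Python source produced by recompile()
```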

By recompiling lazily, we reduce the number of calls to `GraphModule._real_recompile` to 37 and reduce its total execution time to 0.045s. (In this PR, `recompile` just marks the class as needing recompilation and is very lightweight; `_real_recompile` does the real recompilation.)
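For intuition, here is a standalone sketch of the lazy pattern (assumed names, stubbed-out codegen; not the PR's actual implementation on `GraphModule`):

```python
# A minimal sketch of lazy recompilation: recompile() only sets a flag,
# and the expensive code generation runs once, on first access.
class LazyGraphModuleSketch:
    def __init__(self):
        self._needs_recompile = True
        self._code = None

    def recompile(self):
        # Cheap: just mark the generated code as stale.
        self._needs_recompile = True

    def _real_recompile(self):
        # Expensive: regenerate Python source from the graph (stub).
        self._code = "def forward(self, x): ..."
        self._needs_recompile = False

    @property
    def code(self):
        # Realize the recompilation on first access, so callers can
        # never observe stale generated code.
        if self._needs_recompile:
            self._real_recompile()
        return self._code
```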



[ghstack-poisoned]
shunting314 committed Jul 18, 2023
2 parents 963bed2 + d2227ce commit 0b7b0f5
Showing 4 changed files with 16 additions and 1 deletion.
test/fx/test_lazy_recompile.py (7 additions, 0 deletions)
@@ -100,5 +100,12 @@ def f(x):
         _ = gm.code
         self.assertFalse(gm._needs_recompile())
 
+    def test_graph_module_str(self):
+        def f(x):
+            return x.sin()
+
+        gm = fx.symbolic_trace(f)
+        self.assertTrue("sin" in str(gm))
+
 if __name__ == "__main__":
     run_tests()
torch/_dynamo/eval_frame.py (4 additions, 0 deletions)
@@ -242,6 +242,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
     def __call__(self, fn):
        fn = innermost_fn(fn)
+        if isinstance(fn, torch.fx.GraphModule):
+            # Do the real recompile so Dynamo doesn't need to handle the
+            # lazily recompiled forward method.
+            fn.real_recompile()
         # Optimize the forward method of torch.nn.Module object
         if isinstance(fn, torch.nn.Module):
             mod = fn
torch/fx/_symbolic_trace.py (4 additions, 0 deletions)
@@ -1146,6 +1146,10 @@ def f(x):
     Returns:
         GraphModule: a Module created from the recorded operations from ``root``.
     """
+    if isinstance(root, GraphModule):
+        # If we are retracing a GraphModule, make sure we realize the
+        # lazy recompilation first.
+        root.real_recompile()
     tracer = Tracer()
     graph = tracer.trace(root, concrete_args)
     name = (
torch/fx/graph_module.py (1 addition, 1 deletion)
@@ -814,7 +814,7 @@ def print_readable(self, print_output=True):
     def __str__(self) -> str:
         orig_str = super().__str__()
         print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`"
-        return '\n'.join([orig_str, self._code, print_readable_reminder])
+        return '\n'.join([orig_str, self.code, print_readable_reminder])
 
     def _replicate_for_data_parallel(self):
         new_gm = self.__copy__()
