Update on "recompile fx.GraphModule lazily"
Context: eellison's review comment [here](#103642 (comment)) points out that my code calls `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time since that would increase the risk that users see or run stale Python code. In this PR, I recompile the GraphModule lazily without increasing the risk that users see or run stale Python code.

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds.

By spot checking, I found that the main reason we call recompile so frequently is the inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need to trace both src_fn and dst_fn. After tracing is done, we create a GraphModule, and the init method of GraphModule calls recompile.
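
For reference, a minimal sketch of the eager behavior described above (this uses `torch.fx.symbolic_trace` on a toy pattern rather than the inductor pattern matcher itself):

```python
from torch import fx

# Tracing a pattern produces a GraphModule, and GraphModule.__init__
# eagerly calls recompile() to regenerate Python code for the graph.
def src_fn(x):
    return x.sin() + x.cos()

gm = fx.symbolic_trace(src_fn)  # GraphModule.__init__ -> recompile()
print(gm.code)                  # the Python code generated from the traced graph
```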

By doing recompile lazily, we reduce the number of calls to `GraphModule._real_recompile` (in this PR, `recompile` just marks the module as needing recompilation and is very lightweight; `_real_recompile` does the real recompilation) to 37 and reduce its total execution time to 0.045s.
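
The pattern is roughly the following. This is a self-contained toy sketch, not the actual `GraphModule` implementation; the class name and the `graph_expr`/`_needs_recompile_flag` attributes are hypothetical:

```python
class LazyRecompileModule:
    """Toy illustration of lazy recompilation; not the real torch.fx.GraphModule."""

    def __init__(self, graph_expr):
        self.graph_expr = graph_expr
        self._code = None
        self._needs_recompile_flag = False
        self.recompile()  # cheap: only sets a flag

    def recompile(self):
        # Lightweight: just mark the module as needing recompilation.
        self._needs_recompile_flag = True

    def _needs_recompile(self):
        return self._needs_recompile_flag

    def _real_recompile(self):
        # Expensive: regenerate the Python code from the graph.
        self._code = f"def forward(x):\n    return {self.graph_expr}\n"
        self._needs_recompile_flag = False

    @property
    def code(self):
        # Accessing the generated code forces the real recompilation,
        # so callers never observe stale code.
        if self._needs_recompile_flag:
            self._real_recompile()
        return self._code


m = LazyRecompileModule("x.sin()")
assert m._needs_recompile()       # recompile() only set the flag
print(m.code)                     # first access triggers _real_recompile()
assert not m._needs_recompile()
```

Presumably the real forcing points are the properties and methods that expose the generated code, which is what keeps users from seeing or running stale code.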



cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
shunting314 committed Jul 18, 2023
2 parents 6c3ed1d + 05f4671 commit 411e03d
Showing 3 changed files with 60 additions and 4 deletions.
59 changes: 59 additions & 0 deletions test/fx/test_lazy_recompile.py
@@ -2,6 +2,7 @@

from torch.testing._internal.common_utils import TestCase, run_tests
from torch import fx
from torch.fx.experimental.proxy_tensor import make_fx
import torch
import torch._export

@@ -107,5 +108,63 @@ def f(x):
        gm = fx.symbolic_trace(f)
        self.assertTrue("sin" in str(gm))

    def test_recapture_with_make_fx(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(gm._needs_recompile())
        gm2 = make_fx(gm)(torch.randn(2, 3))

        # gm still has a pending recompilation; make_fx can smoothly handle
        # lazy recompilation since it's implemented through the dispatcher.
        self.assertTrue(gm._needs_recompile())

    def test_recapture_with_symbolic_trace(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(gm._needs_recompile())
        gm2 = fx.symbolic_trace(gm)

        # The lazy recompilation is already realized. We realize the
        # recompilation at the beginning of symbolic_trace since symbolic_trace
        # cannot handle tracing a lazily-recompiled module.
        self.assertFalse(gm._needs_recompile())

    def test_recapture_with_dynamo(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(gm._needs_recompile())
        gm2 = torch.compile(gm)(torch.rand(2, 3))

        # The lazy recompilation is already realized. We realize the
        # recompilation at the beginning of dynamo since dynamo cannot
        # handle tracing a lazily-recompiled module.
        self.assertFalse(gm._needs_recompile())


    def test_recapture_with_torchscript(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(gm._needs_recompile())
        gm2 = torch.jit.script(gm)

        # The lazy recompilation is already realized. We realize the
        # recompilation at the beginning of torchscript since torchscript cannot
        # handle tracing a lazily-recompiled module.
        #
        # The real recompilation is triggered automatically for torchscript
        # when the get_overload_annotations API in torch/jit/_recursive.py is
        # called. That API accesses properties like graph_in_spec etc., which
        # forces recompilation.
        self.assertFalse(gm._needs_recompile())


if __name__ == "__main__":
    run_tests()
2 changes: 1 addition & 1 deletion test/quantization/fx/test_numeric_suite_fx.py
@@ -1275,7 +1275,7 @@ def forward(self, x):
            m, (torch.randn(1, 1, 2, 2),),
            results_len=2)

    # @skipIfNoFBGEMM
    @skipIfNoFBGEMM
    def test_add_mul_inputs_activations(self):
        m = AddMulFunctional().eval()
        res = self._test_match_activations(
3 changes: 0 additions & 3 deletions torch/_functorch/compilers.py
@@ -184,9 +184,6 @@ def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)

    # realize the lazy recompilication to make jit.script happy.
    fx_g.real_recompile()

    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f
