Update on "recompile fx.GraphModule lazily"
Context: eellison's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at the time, since that would increase the risk of users seeing or running stale Python code. In this PR, I recompile the GraphModule lazily, without increasing the risk that users see or run stale Python code.

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total, while the whole compilation takes around 60 seconds.

By spot checking, I found that the main reason we call recompile so frequently is the inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need to trace both src_fn and dst_fn. Once tracing is done, we create a GraphModule, and `GraphModule.__init__` calls recompile (see the illustration below).
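
As a small illustration (the toy `src_fn`/`dst_fn` below are stand-ins, not the actual inductor patterns), every `symbolic_trace` used to build a replacement pattern constructs a GraphModule and therefore pays for a recompile, even if the generated code is never read:

```python
import torch.fx as fx

# Toy source/replacement functions; the real pattern matcher traces many such pairs.
def src_fn(x):
    return x.sin()

def dst_fn(x):
    return x.cos()

# Each symbolic_trace builds a GraphModule, and GraphModule.__init__ calls
# recompile() to generate Python code for the graph -- code that the pattern
# matcher may never execute before the graph changes again.
src_gm = fx.symbolic_trace(src_fn)
dst_gm = fx.symbolic_trace(dst_fn)
```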

By recompiling lazily, we reduce the number of calls to `GraphModule._real_recompile` to 37 and cut its total execution time to 0.045s. (In this PR, `recompile` just marks the module as needing recompilation and is very lightweight; `_real_recompile` does the real recompilation.)
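
A minimal sketch of the idea (a simplified stand-in, not the actual GraphModule implementation; `LazyModule` and its code-generation stub are made up for illustration):

```python
class LazyModule:
    """Sketch of lazy recompilation: recompile() only flips a flag; the
    expensive work happens in _real_recompile(), the first time the
    generated code is actually needed."""

    def __init__(self, graph_expr: str):
        self.graph_expr = graph_expr
        self._code = None
        self.recompile()  # __init__ still "recompiles", but now it is cheap

    def recompile(self):
        # cheap: just mark the module as stale
        self._needs_recompile_flag = True

    def _needs_recompile(self):
        return self._needs_recompile_flag

    def _real_recompile(self):
        # expensive: regenerate the Python code from the graph (stubbed out here)
        self._code = f"def forward(x):\n    return {self.graph_expr}"
        self._needs_recompile_flag = False

    @property
    def code(self):
        # Accessing .code (or calling forward) forces the real recompilation,
        # so callers never see or run stale code.
        if self._needs_recompile_flag:
            self._real_recompile()
        return self._code
```

Calling the module's forward would go through the same check, which is what the new tests below exercise.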



[ghstack-poisoned]
shunting314 committed Jul 18, 2023
2 parents ec1e26a + 7b4fe21 commit 0acef59
67 changes: 57 additions & 10 deletions test/fx/test_lazy_recompile.py
@@ -6,16 +6,19 @@
 import torch._export

 class TestLazyRecompile(TestCase):
+    @staticmethod
+    def replace_sin_with_cos(gm):
+        for n in gm.graph.nodes:
+            if n.target == "sin":
+                n.target = "cos"
+
     def test_replace_sin_with_cos(self):
         def f(x):
             return x.sin()

         x = torch.randn(2, 3)
         gm = fx.symbolic_trace(f)

-        for n in gm.graph.nodes:
-            if n.target == "sin":
-                n.target = "cos"
-
+        self.replace_sin_with_cos(gm)
         gm.recompile()
         expected = x.cos()
@@ -31,16 +34,11 @@ def f(x):

         x = torch.randn(2, 3)
         gm = fx.symbolic_trace(f)

-        for n in gm.graph.nodes:
-            if n.target == "sin":
-                n.target = "cos"
-
+        self.replace_sin_with_cos(gm)
         gm.recompile()
         expected = x.cos()
         actual = gm.forward(x)

-        print(f"sin {x.sin()}, cos {x.cos()}, expected {expected}, actual {actual}")
         self.assertTrue(torch.allclose(expected, actual))

     def test_export(self):
@@ -53,5 +51,54 @@ def f(x):
         gm = torch._export.export(f, (torch.randn(2, 3),))
         self.assertTrue(isinstance(gm, torch._export.ExportedProgram))

+    def test_needs_recompile(self):
+        """
+        Make sure needs_recompile() returns the correct state.
+        """
+        def f(x):
+            return x.sin()
+
+        gm = fx.symbolic_trace(f)
+        self.assertTrue(gm._needs_recompile())
+        gm(torch.randn(2, 3))
+        self.assertFalse(gm._needs_recompile())
+
+    def test_multi_recompile(self):
+        """
+        Cover the case where multiple recompilations happen.
+        """
+        def f(x):
+            return x.sin()
+
+        gm = fx.symbolic_trace(f)
+        self.assertTrue(gm._needs_recompile())
+        x = torch.randn(2, 3)
+        # trigger the first recompilation
+        self.assertTrue(torch.allclose(x.sin(), gm(x)))
+        self.assertFalse(gm._needs_recompile())
+
+        self.replace_sin_with_cos(gm)
+        self.assertFalse(gm._needs_recompile())
+        gm.recompile()
+        self.assertTrue(gm._needs_recompile())
+        # trigger the second recompilation
+        self.assertTrue(torch.allclose(x.cos(), gm(x)))
+        self.assertFalse(gm._needs_recompile())
+
+    def test_accessing_code_cause_recompiling(self):
+        """
+        Make sure we recompile, if we have not done so yet, when the code
+        property of a GraphModule is accessed.
+        """
+        def f(x):
+            return x.sin()
+
+        gm = fx.symbolic_trace(f)
+        self.assertTrue(gm._needs_recompile())
+        # should trigger a recompilation
+        _ = gm.code
+        self.assertFalse(gm._needs_recompile())
+
 if __name__ == "__main__":
     run_tests()
