Update on "recompile fx.GraphModule lazily"
Context: eellison's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time, since doing so increases the risk that users see or run stale Python code. In this PR, I recompile GraphModule lazily without increasing that risk.

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds.

By spot checking, I found that the main reason we call `recompile` so frequently is the inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need to trace both src_fn and dst_fn. After tracing is done, we create a GraphModule, and the init method of GraphModule calls `recompile`.

By recompiling lazily, we reduce the number of calls to `GraphModule._real_recompile` to 37 and its total execution time to 0.045s. (In this PR, `recompile` just marks the class as needing recompilation and is very lightweight; `_real_recompile` does the real recompilation.)
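
To illustrate the idea only (this is not the code in this PR), here is a minimal pure-Python sketch of the lazy pattern. `LazyModule`, `set_source`, `_needs_recompile`, and `real_recompile_calls` are hypothetical names made up for the example; only the split between a cheap `recompile` and an expensive `_real_recompile` mirrors the description above.

```python
class LazyModule:
    """Toy stand-in for a module with generated code; not fx.GraphModule."""

    def __init__(self, source: str):
        self._source = source
        self._compiled = None
        self._needs_recompile = False
        self.real_recompile_calls = 0  # counter for illustration only
        self.recompile()

    def recompile(self) -> None:
        # Cheap: only mark the generated code as stale.
        self._needs_recompile = True

    def _real_recompile(self) -> None:
        # Expensive: rebuild the callable from the current source.
        namespace = {}
        exec(self._source, namespace)
        self._compiled = namespace["forward"]
        self._needs_recompile = False
        self.real_recompile_calls += 1

    def set_source(self, source: str) -> None:
        # Every edit requests a recompile, but no real work happens yet.
        self._source = source
        self.recompile()

    def forward(self, x):
        # Any consumer of the generated code forces the real recompilation
        # first, so stale code is never run.
        if self._needs_recompile:
            self._real_recompile()
        return self._compiled(x)


m = LazyModule("def forward(x):\n    return x * 2\n")
m.set_source("def forward(x):\n    return x * 3\n")
m.set_source("def forward(x):\n    return x + 1\n")
result = m.forward(10)  # three recompile requests, one real recompilation
```

In this sketch, three `recompile` requests collapse into a single `_real_recompile` at the first `forward` call, which is the effect the numbers above describe.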



cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
shunting314 committed Jul 19, 2023
2 parents 7ad6c5a + 189f95d commit 2999a11
Showing 2 changed files with 9 additions and 4 deletions.
5 changes: 4 additions & 1 deletion test/test_fx.py
@@ -172,7 +172,7 @@ def setUp(self):
 
         if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
             lib_file_path = find_library_location('libtorchbind_test.so')
-            torch.ops.load_library(str(lib_file_path))
+            # torch.ops.load_library(str(lib_file_path))
 
     def tearDown(self):
         super().tearDown()
@@ -250,6 +250,9 @@ def forward(self, a):
                 return a * 2
 
         gm = symbolic_trace(MyModule())
+        # TODO: this need revise before landing. Adding this to see more
+        # test signals.
+        gm.real_recompile()
         self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)
 
     def test_custom_import(self):
8 changes: 5 additions & 3 deletions torch/fx/graph_module.py
@@ -746,21 +746,23 @@ def call_wrapped(self, *args, **kwargs):
     # Passing Tracer as argument allows subclasses extending fx.GraphModule
     # define their own Tracer (extending fx.Tracer).
     def __reduce_deploy__(self, importer: Importer):
+        # call _real_recompile before accessing self.__dict__ since the former
+        # may add extra keys to self.__dict__.
+        python_code = self._real_recompile()
         dict_without_graph = self.__dict__.copy()
         dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
         del dict_without_graph['_graph']
 
-        python_code = self._real_recompile()
         import_block = _format_import_block(python_code.globals, importer)
         return (reduce_deploy_graph_module, (dict_without_graph, import_block))
 
     def __reduce_package__(self, exporter: PackageExporter):
+        python_code = self._real_recompile()
         dict_without_graph = self.__dict__.copy()
         dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
         del dict_without_graph['_graph']
 
         generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
-        python_code = self._real_recompile()
         import_block = _format_import_block(python_code.globals, exporter.importer)
         module_code = import_block + self.code
         exporter.save_source_string(generated_module_name, module_code)
@@ -774,8 +776,8 @@ def __reduce__(self):
         On the deserialization side, we symbolically trace through the generated
         code to regenerate the underlying ``Graph``
         """
-        dict_without_graph = self.__dict__.copy()
         python_code = self._real_recompile()
+        dict_without_graph = self.__dict__.copy()
         import_block = _format_import_block(python_code.globals, sys_importer)
         del dict_without_graph['_graph']
         return (reduce_graph_module, (dict_without_graph, import_block))
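
The reordering in the three `__reduce*__` methods follows from the comment added in `__reduce_deploy__`: if `_real_recompile` can add keys to `self.__dict__`, a copy taken before the call misses them. A small sketch of that ordering hazard, under stated assumptions: `Stateful`, `snapshot_before`, `snapshot_after`, and the `_code` key are illustrative stand-ins, not the actual GraphModule state.

```python
class Stateful:
    """Toy object whose recompile step stores extra state on the instance."""

    def __init__(self):
        self.graph = "graph"

    def _real_recompile(self):
        # Assumed side effect for illustration: cache generated code
        # on the instance, which adds a new key to __dict__.
        self.__dict__["_code"] = "def forward(self): ..."
        return self._code


def snapshot_before(obj):
    # Old order: copy first, recompile second. The copy misses '_code'.
    state = obj.__dict__.copy()
    obj._real_recompile()
    return state


def snapshot_after(obj):
    # New order: recompile first, so the copy sees every key.
    obj._real_recompile()
    return obj.__dict__.copy()


stale = snapshot_before(Stateful())
fresh = snapshot_after(Stateful())
```

Here `stale` lacks the `_code` key while `fresh` contains it, which is why the snapshot used for pickling is taken only after the recompilation has run.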
