-
Notifications
You must be signed in to change notification settings - Fork 25.3k
recompile fx.GraphModule lazily #105257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
recompile fx.GraphModule lazily #105257
Conversation
[ghstack-poisoned]
I'll review after tests are passing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests seem pretty broken.
Rather than hooking into __call__
(which won't work if one calls forward
directly), I'd suggest:
def _lazy_forward(self, *args, **kwargs):
self._real_recompile()
assert not self._needs_recompile()
return self.forward(*args, **kwargs)
forward = _lazy_forward
@classmethod
def _needs_recompile(cls):
return cls.forward is cls._lazy_forward
@classmethod
def recompile(cls):
cls.forward = cls._lazy_forward
Context: eellison 's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time since that increases the risk that user see or run stale python code. In this PR, I recompile GraphModule lazily without increasing the risk that user see/run stale python code. When training BertForMaskedLM, the `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds. By spot checking, I found the main reason we call recompile so frequently is due to inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need trace both src_fn and dst_fn. After tracing is done, we create a GraphModule. The init method of GraphModule will call recompile. By doing recompile lazily, we reduce the number of calls for `GraphModule.recompile` to 37 times and reduces its total execution time to 0.045s. [ghstack-poisoned]
Context: eellison 's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time since that increases the risk that user see or run stale python code. In this PR, I recompile GraphModule lazily without increasing the risk that user see/run stale python code. When training BertForMaskedLM, the `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds. By spot checking, I found the main reason we call recompile so frequently is due to inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need trace both src_fn and dst_fn. After tracing is done, we create a GraphModule. The init method of GraphModule will call recompile. By doing recompile lazily, we reduce the number of calls for `GraphModule._real_recompile` (in this PR, `recompile` just mark the class as needing recompilation and is very light weight. `_real_recompile` does the real recompilation) to 37 times and reduces its total execution time to 0.045s. [ghstack-poisoned]
Actually, you might not want the |
The I think this 'a separate subclass per GraphModule instance' trick is not 100% necessary. An alternative is as you said, we change |
Context: eellison 's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time since that increases the risk that user see or run stale python code. In this PR, I recompile GraphModule lazily without increasing the risk that user see/run stale python code. When training BertForMaskedLM, the `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds. By spot checking, I found the main reason we call recompile so frequently is due to inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need trace both src_fn and dst_fn. After tracing is done, we create a GraphModule. The init method of GraphModule will call recompile. By doing recompile lazily, we reduce the number of calls for `GraphModule._real_recompile` (in this PR, `recompile` just mark the class as needing recompilation and is very light weight. `_real_recompile` does the real recompilation) to 37 times and reduces its total execution time to 0.045s. [ghstack-poisoned]
Context: eellison 's review comment [here](#103642 (comment)) complains about my code calling `torch.fx.GraphModule.recompile` after I changed the graph. We didn't simply remove the call to `recompile` at that time since that increases the risk that user see or run stale python code. In this PR, I recompile GraphModule lazily without increasing the risk that user see/run stale python code. When training BertForMaskedLM, the `GraphModule.recompile` is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds. By spot checking, I found the main reason we call recompile so frequently is due to inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need trace both src_fn and dst_fn. After tracing is done, we create a GraphModule. The init method of GraphModule will call recompile. By doing recompile lazily, we reduce the number of calls for `GraphModule._real_recompile` (in this PR, `recompile` just mark the class as needing recompilation and is very light weight. `_real_recompile` does the real recompilation) to 37 times and reduces its total execution time to 0.045s. [ghstack-poisoned]
|
||
def real_recompile(self): | ||
""" | ||
A torch script safe wrapper around _real_recompile. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's incredibly strange to me that recompile is being called from TorchScriptable code!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is how it happens from looking at some failure stack trace. When torchscript gets a GraphModule, it will try to inspect the GraphModule's attributes. When it looks at the 'code' property, it will try to 'script' the method implementation. Due to the lazy recompilation, we may need to recompile the graph module when 'code' property is being accessed. That's how what you described happens.
I think a not too bad fix is to always force a real recompilation when TorchScript backend gets a graph module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like TorchScript should manage that, not the code inside
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally yes. But this hits some limitation of torchscript - some of the lazy recompile related methods can not be handled by torchscript. Since we are deprecating torchscript, probably we should workaround rather than spending efforts to improve torchscript?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang so we encounter similar issues when passing a lazy recompiled GraphModule to dynamo. Similar to the torchscript case, there is an option to enhance the graph capturing layer (dynamo or torchscript) to handle the GraphModule lazy recompilation. But that may take quite some effort and I'm not sure if it's necessary. it would be handy to just force the recompile at the beginning of these graph capturing layers. Wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this gets hard to land you might have better luck with adding a lazy-recompile method which we call inside torch/_inductor / PT2 stack, instead of trying to always make recompile() always lazy. As you said, it's mostly within the pattern matcher that the invocations come from.
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
I feel it's easier to open a new PR rather than iterating on the previous PR (#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. Pull Request resolved: #117911 Approved by: https://github.com/jansel
I feel it's easier to open a new PR rather than iterating on the previous PR (pytorch#105257 ) since this is more like a rewrite. In this PR, instead of changing GraphModule directly which can easily causes BC issue, I create a LazyGraphModule class as Zachary & Jason suggested in comments from the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how re-compile for the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule will cause the recompilation. While in LazyGraphModule, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. call forward method, access the code property etc.). In a lot of cases in torch.compile, the real recompilation eventually is not triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. Pull Request resolved: pytorch#117911 Approved by: https://github.com/jansel
Stack from ghstack (oldest at bottom):
Context: @eellison 's review comment here complains about my code calling
torch.fx.GraphModule.recompile
after I changed the graph. We didn't simply remove the call torecompile
at that time since that increases the risk that user see or run stale python code. In this PR, I recompile GraphModule lazily without increasing the risk that user see/run stale python code.When training BertForMaskedLM, the
GraphModule.recompile
is called 707 times and takes 1.8s in total. The whole compilation takes around 60 seconds.By spot checking, I found the main reason we call recompile so frequently is due to inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need trace both src_fn and dst_fn. After tracing is done, we create a GraphModule. The init method of GraphModule will call recompile.
By doing recompile lazily, we reduce the number of calls for
GraphModule._real_recompile
(in this PR,recompile
just mark the class as needing recompilation and is very light weight._real_recompile
does the real recompilation) to 37 times and reduces its total execution time to 0.045s.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @Xia-Weiwen @anijain2305 @ipiszy