[inductor] convert layout of conv weight ahead of time for inference #103642

Closed · wants to merge 18 commits

Conversation

@shunting314 (Contributor) commented Jun 15, 2023

Stack from ghstack (oldest at bottom):

This PR handles inference; we will do the same for training later.

Some manual testing shows this can improve inference perf by 2-3% (absolute improvement, not relative):

  • convmixer: 4.285x -> 4.309x
  • resnet50: 2.170x -> 2.203x

The PR is built upon freezing. Without freezing, the weight input of a conv node may not be a parameter directly but the output of precision-converting ops; it is much easier to implement this pass after freezing.
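To illustrate what "convert layout ahead of time" means, here is a minimal standalone sketch in plain eager PyTorch (assuming a channels-last target layout; this is not the PR's implementation):

```python
import torch

# After freezing, conv weights are plain constants, so the channels-last
# conversion can happen once ahead of time instead of on every call.
conv = torch.nn.Conv2d(3, 64, kernel_size=3).eval()

with torch.no_grad():
    conv.weight.data = conv.weight.data.to(memory_format=torch.channels_last)

x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
out = conv(x)  # the conv consumes a channels-last weight directly
print(out.is_contiguous(memory_format=torch.channels_last))
```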

Commands

```
TORCHINDUCTOR_FREEZING=1 python benchmarks/dynamo/timm_models.py --backend inductor --amp --performance --only convmixer_768_32 --inference
```
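For reference, freezing can also be enabled programmatically instead of via the environment variable; a minimal sketch, assuming the `torch._inductor.config.freezing` flag:

```python
import torch
import torch._inductor.config as inductor_config

# Equivalent to setting TORCHINDUCTOR_FREEZING=1 in the environment.
inductor_config.freezing = True

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval()
compiled = torch.compile(model, backend="inductor")

with torch.no_grad():
    y = compiled(torch.randn(1, 3, 32, 32))
```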

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78

@pytorch-bot bot commented Jun 15, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103642

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit b8c4556:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Jun 15, 2023
ghstack-source-id: 705643764afe28030fbc46bfe289acda5840cf37
Pull Request resolved: #103642
@shunting314 changed the title from "[wip][inductor] convert layout of conv weight ahead of time" to "[wip][inductor] convert layout of conv weight ahead of time for inference" on Jun 15, 2023
@shunting314 changed the title from "[wip][inductor] convert layout of conv weight ahead of time for inference" to "[inductor] convert layout of conv weight ahead of time for inference" on Jun 16, 2023
@eellison (Contributor) left a comment


Cool! A few comments. Let me know what you think.

Two review threads on torch/_inductor/freezing.py (outdated, resolved)
@jansel (Contributor) commented Jun 21, 2023

At a glance this looks good to me. I'll let @eellison do the final review.

@jansel removed their request for review June 21, 2023 16:03
@shunting314 (Contributor, Author) commented
@eellison @jansel this is more complex than I initially thought. I just found one issue. We convert convolution weights to channels last (either by changing the GraphModule attribute directly or by adding aten.clone nodes to the graph to do the layout conversion). Later on in compile_fx_inner, when we do FakeTensorProp, the changed strides of the conv weight get propagated through the graph. After the propagation, an output tensor that was contiguous in eager mode may become a channels-last tensor.

I think to avoid that and make sure we don't change the output tensors' layout, we need to:

  1. collect the layout of every output tensor before the pass added by this PR; these represent the eager layouts we need for the output tensors
  2. perform the pass added by this PR
  3. explicitly add aten.clone nodes at the end of the graph to convert each output tensor's strides back to those collected in step 1

Another variant of step 3: we collect all the output tensors' strides after performing step 2 (this needs an extra FakeTensorProp) and apply step 3 only to those output tensors whose strides changed. But this may not be necessary, since I assume adding redundant no-op aten.clone nodes is fine. We can remove those no-op aten.clone nodes from the graph in a post_grad pass (maybe we already do). A sketch of step 3 follows below.

Wdyt?
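A minimal sketch of step 3 as an FX graph pass (hypothetical helper; it assumes the graph outputs are a flat tuple and that the eager memory formats were collected in step 1; this is not the PR's actual code):

```python
import torch
import torch.fx

def pin_output_formats(gm: torch.fx.GraphModule, eager_formats):
    # The output node is always the last node of an FX graph.
    output_node = list(gm.graph.nodes)[-1]
    new_outputs = []
    for out, fmt in zip(output_node.args[0], eager_formats):
        with gm.graph.inserting_before(output_node):
            # aten.clone with an explicit memory_format forces the layout
            # back to what eager mode produced; a clone that is already a
            # no-op could be dropped by a later post_grad pass.
            out = gm.graph.call_function(
                torch.ops.aten.clone.default,
                args=(out,),
                kwargs={"memory_format": fmt},
            )
        new_outputs.append(out)
    output_node.args = (tuple(new_outputs),)
    gm.recompile()
    return gm
```

Here `eager_formats` would be something like `[torch.contiguous_format, ...]`, collected in step 1 by running FakeTensorProp before the layout pass.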

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/shunting314/66/head branch July 2, 2023 14:16
shunting314 added a commit that referenced this pull request Jul 17, 2023
Context: eellison's review comment [here](#103642 (comment)) flagged my code calling `torch.fx.GraphModule.recompile` after changing the graph. We didn't simply remove the call to `recompile` at the time, since that would increase the risk of users seeing or running stale Python code. In this PR, I recompile the GraphModule lazily without increasing that risk.

When training BertForMaskedLM, `GraphModule.recompile` is called 707 times and takes 1.8s in total; the whole compilation takes around 60 seconds.

By spot checking, I found the main reason we call recompile so frequently is the inductor pattern matcher. E.g., if we want to replace src_fn with dst_fn, we need to trace both src_fn and dst_fn. After tracing is done, we create a GraphModule, and the init method of GraphModule calls recompile.

By recompiling lazily, we reduce the number of calls to `GraphModule._real_recompile` (in this PR, `recompile` just marks the module as needing recompilation and is very lightweight; `_real_recompile` does the real recompilation) to 37 and reduce its total execution time to 0.045s.



[ghstack-poisoned]
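A hedged sketch of the lazy-recompile idea described in the commit message (`_real_recompile` is named in the message; the rest of the wiring here is illustrative, not the actual implementation):

```python
import torch.fx

class LazyGraphModule(torch.fx.GraphModule):
    def recompile(self):
        # Cheap: just mark the generated forward code as stale.
        # GraphModule.__init__ and graph-mutating passes (e.g. the
        # inductor pattern matcher) call this frequently.
        self._needs_recompile = True

    def _real_recompile(self):
        # Expensive: regenerate and exec the forward code, but only
        # when the module actually needs it.
        if getattr(self, "_needs_recompile", False):
            self._needs_recompile = False
            super().recompile()

    def __call__(self, *args, **kwargs):
        # Recompile right before the module runs, so users never
        # execute stale generated code.
        self._real_recompile()
        return super().__call__(*args, **kwargs)
```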