
Add Post Freezing Optimizations, turn on by default in torch.jit.freeze #50222

Closed
wants to merge 9 commits

Conversation

@eellison (Contributor) commented Jan 7, 2021


This PR adds a pass that runs a set of optimizations after freezing. Currently this encompasses Conv-BN folding and Conv->Add/Sub/Mul/Div folding; I'm also planning to add dropout removal.
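For concreteness, a minimal sketch of the arithmetic behind Conv-BN folding (illustrative only, not the implementation in this PR): BN's per-channel affine transform is absorbed into the conv's weight and bias.

```python
import torch

conv = torch.nn.Conv2d(3, 8, 3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()

# BN(z) = gamma * (z - mean) / sqrt(var + eps) + beta folds into the conv
# as W' = W * s and b' = (b - mean) * s + beta, with s = gamma / sqrt(var + eps).
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
folded = torch.nn.Conv2d(3, 8, 3, bias=True).eval()
with torch.no_grad():
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    folded.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), folded(x), atol=1e-5)
```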

I would like some feedback on the API. torch.jit.freeze is technically in the ~prototype~ phase, so we have some leeway to make changes. I think in the majority of cases the user is going to want to freeze their model and then run it for inference, so I would prefer the optimization to be opt-out rather than opt-in. All internal/framework use cases of freezing use `freeze_module`, not the Python API, so this shouldn't break anything.

I have separated the optimization pass out as a separate API to keep things potentially modular, even though I suspect that is an unlikely use case. In a future PR I would like to add a `torch::jit::freeze` for C++ use that follows the same API as `torch.jit.freeze` and runs the optimizations.
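To make the proposed API concrete, a sketch of the intended usage (the `optimize` flag below is the one added in this PR's diff):

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)
        self.bn = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn(self.conv(x))

scripted = torch.jit.script(Net().eval())

# Post-freezing optimizations (e.g. Conv-BN folding) run by default ...
frozen = torch.jit.freeze(scripted)

# ... but can be opted out of
frozen_unoptimized = torch.jit.freeze(scripted, optimize=False)
```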

Differential Revision: D25856264

@facebook-github-bot (Contributor) commented Jan 7, 2021

💊 CI failures summary and remediations

As of commit 8896c3e (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 2/2 non-CircleCI failure(s)

Extra GitHub checks: 1 failed



eellison pushed a commit that referenced this pull request Jan 7, 2021
ghstack-source-id: 5acbb012b21c13cc7b6817f8ca6e0d20a616d8ab
Pull Request resolved: #50222
eellison pushed a commit that referenced this pull request Jan 7, 2021
ghstack-source-id: 61c598b050a62772117920d85c54ddbc06edd914
Pull Request resolved: #50222
eellison pushed a commit that referenced this pull request Jan 7, 2021
ghstack-source-id: 75fece97a678d3ab95c4d4a0c251ffdc499d1f1f
Pull Request resolved: #50222
@bzinodev (Contributor) left a comment


How does this interact with graph mode quantization? And with mobile? I suggest adding a couple of tests.

@eellison (Contributor, Author) commented Jan 8, 2021

It doesn't, because they all use `_freeze_module`, not the `torch.jit.freeze` API. If they did, tests would break.

@eellison requested a review from bzinodev January 8, 2021 19:35
@ZolotukhinM left a comment


Looks good to me!

namespace jit {

void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
  // run a couple of times to capture Conv -> Mul -> Add etc.


Might be worth running that while something changes (with some threshold so it doesn't run for too long).
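Something like this, as a rough sketch (hypothetical driver; `passes` and the return-True-on-change convention are illustrative, not the actual API):

```python
MAX_ITERATIONS = 8  # threshold so the loop can't run for too long

def run_until_fixed_point(graph, passes):
    for _ in range(MAX_ITERATIONS):
        changed = False
        for opt_pass in passes:
            # each pass is assumed to return True if it mutated the graph
            changed |= opt_pass(graph)
        if not changed:
            break
    return graph
```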

@eellison (Contributor, Author)

This would be nice, but I'm not sure how relevant it is; I might do it as a follow-up.

@ZolotukhinM

I mentioned this simply because it's a common pattern in compilers. I agree with you that it's not clear whether it matters here.

@@ -10,7 +10,7 @@
 from torch.jit._script import RecursiveScriptModule, ScriptModule


-def freeze(mod, preserved_attrs: Optional[List[str]] = None):
+def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):


It might be easier to land this in two steps: 1) add a new flag with the new functionality, but disabled by default; 2) flip the default. This way, if something breaks, only the flag flip would need to be reverted.

eellison added 2 commits January 8, 2021 12:59
@facebook-github-bot (Contributor)

@eellison merged this pull request in a389b30.

@facebook-github-bot deleted the gh/eellison/153/head branch January 16, 2021 15:18
@xsacha (Contributor) commented Feb 8, 2021

@eellison as a suggestion for another post-freezing optimisation: you could convert the dtype of the weights in cases such as AMP. Currently you end up with a TorchScript file containing 32-bit weights plus script that converts specific weights to 16-bit wherever they are used.
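Roughly, the shape of the issue (a hypothetical module standing in for AMP-exported TorchScript; names are illustrative):

```python
import torch

class AmpStyle(torch.nn.Module):
    # Weights are stored as fp32; each use site casts them to fp16,
    # mimicking what an AMP-exported TorchScript graph looks like.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.nn.functional.linear(
            x.half(), self.linear.weight.half(), self.linear.bias.half()
        )

frozen = torch.jit.freeze(torch.jit.script(AmpStyle().eval()))
# Depending on how much constant propagation runs, the frozen graph may
# still hold fp32 weights and cast them on every call; the suggestion is
# to constant-fold the casts so the weights are stored as fp16.
print(frozen.graph)
```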

@eellison (Contributor, Author) commented Feb 8, 2021

@xsacha thanks for the suggestion! Do you have a repro, by any chance?

Labels
cla signed · Merged · oncall: jit

5 participants