-
Notifications
You must be signed in to change notification settings - Fork 21.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Post Freezing Optimizations, turn on by default in torch.jit.freeze #50222
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 8896c3e (more details on the Dr. CI page):
Extra GitHub checks: 1 failed
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. This comment has been revised 72 times. |
…ch.jit.freeze" [ghstack-poisoned]
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. [ghstack-poisoned]
ghstack-source-id: 5acbb012b21c13cc7b6817f8ca6e0d20a616d8ab Pull Request resolved: #50222
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. [ghstack-poisoned]
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. [ghstack-poisoned]
ghstack-source-id: 61c598b050a62772117920d85c54ddbc06edd914 Pull Request resolved: #50222
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. [ghstack-poisoned]
ghstack-source-id: 75fece97a678d3ab95c4d4a0c251ffdc499d1f1f Pull Request resolved: #50222
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does it interact with graph mode quantization? And also for mobile? I suggest to have couple a test?
It doesn't, because they all use _freeze_module, not the torch.jit.freeze API. If they did tests would break. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
namespace jit { | ||
|
||
void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) { | ||
// run a couple times to capture Conv -> Mul -> Add etc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be worth running that while something is changed (with some threshold to not run for too long).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be nice but i'm not sure how relevant it is, might do as a follow up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mentioned this simply because it's a common pattern in compilers. I agree with you that it's not clear whether it matters here.
@@ -10,7 +10,7 @@ | |||
from torch.jit._script import RecursiveScriptModule, ScriptModule | |||
|
|||
|
|||
def freeze(mod, preserved_attrs: Optional[List[str]] = None): | |||
def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be easier to land this in two steps: 1) add a new flag with the new functionality, but disabled by default, 2) flip the default. This way if something breaks, only the flag switch would be reverted.
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in \~prototype\~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. [ghstack-poisoned]
…ch.jit.freeze" This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal. I would like some feedback on the API. torch.jit.freeze is technically in \~prototype\~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use `freeze_module`, not the python API, so this shouldn't break anything. I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a `torch::jit::freeze` which follows the same api as `torch.jit.freeze` intended for C++ use, and runs the optimizations. Differential Revision: [D25856264](https://our.internmc.facebook.com/intern/diff/D25856264) [ghstack-poisoned]
@eellison as a suggestion for another post freezing optimisation: you could convert the dtype of the weights in instances such as AMP. Currently you end up with a torchscript file containing 32-bit weights and script that converts specific weights to 16-bit wherever they are used. |
@xsacha thanks for the suggestion! Do you have a repro by any chance ? |
Stack from ghstack:
This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding, Conv->Add/Sub/Mul/Div folding and i'm also planning on adding dropout removal.
I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model, and then run in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing all use
freeze_module
, not the python API, so this shouldn't break anything.I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR i would like to add a
torch::jit::freeze
which follows the same api astorch.jit.freeze
intended for C++ use, and runs the optimizations.Differential Revision: D25856264