[JIT] Update freezing api #52337
Changes from 2 commits
@@ -10,7 +10,7 @@
 from torch.jit._script import RecursiveScriptModule, ScriptModule


-def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
+def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True):
     r"""
     Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
     module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.

@@ -26,10 +26,8 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = Tr
         preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
             Attributes modified in preserved methods will also be preserved.

-        optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
-            in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
-            `torch.jit.optimize_frozen_module.`
+        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
+            preserve numerics. Full details of optimization can be found at `torch.jit.optimize_frozen_module`.

     Returns:
         Frozen :class:`ScriptModule`.

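For illustration, a minimal usage sketch of the renamed argument. The module, input shape, and tolerance check below are illustrative only (not part of the PR), and assume a PyTorch build that includes this change:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.bn(self.conv(x)))

scripted = torch.jit.script(Net().eval())

# Default: also run the extra inference optimizations described above.
frozen = torch.jit.freeze(scripted, optimize_numerics=True)

# Opt out of the passes that may not strictly preserve numerics.
frozen_exact = torch.jit.freeze(scripted, optimize_numerics=False)

x = torch.randn(1, 3, 32, 32)
# A single conv/bn fold stays within assert_allclose's default tolerances.
torch.testing.assert_allclose(scripted(x), frozen(x))
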
@@ -102,23 +100,29 @@ def forward(self, input):

     out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
     RecursiveScriptModule._finalize_scriptmodule(out)
     if optimize:
-        optimize_frozen_module(out)
+        optimize_frozen_module(out, optimize_numerics)

     return out


-def optimize_frozen_module(mod):
+def optimize_frozen_module(mod, optimize_numerics: bool = True):
     r"""
     Runs a series of optimizations looking for patterns that occur in frozen graphs.
     The current set of optimizations is:
         - Dropout Removal
         - Conv -> Batchnorm folding
         - Conv -> Add/Sub folding
         - Conv -> Mul/Div folding

     Args:
         mod (:class:`ScriptModule`): a frozen module to be optimized

+        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
+            preserve numerics. These optimizations preserve default rtol and atol of `torch.testing.assert_allclose`
+            when applied on a single transformation, however in a module where many transformations are applied
+            the rtol or atol may no longer fall within the default `assert_allclose` tolerance. Conv -> Batchnorm folding,
+            Conv -> Add/Sub, and Conv -> Mul/Div folding all may alter numerics.

     Returns:
         None

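To make the numerics caveat concrete, here is a sketch of the arithmetic behind Conv -> Batchnorm folding. The hand-rolled fold and shapes are illustrative only, not the pass's actual implementation:

import torch

conv = torch.nn.Conv2d(3, 8, 3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)   # pretend the BN has seen data
bn.running_var.uniform_(0.5, 1.5)

# Folding bakes the BN affine transform into the conv parameters:
#   w' = w * gamma / sqrt(var + eps)
#   b' = (b - mean) * gamma / sqrt(var + eps) + beta
with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded = torch.nn.Conv2d(3, 8, 3, bias=True).eval()
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    folded.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(1, 3, 16, 16)
# The reordered floating-point math is close, but not bit-identical, to conv-then-bn,
# which is why stacking many such folds can drift past the default tolerances.
torch.testing.assert_allclose(folded(x), bn(conv(x)))
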
@@ -140,4 +144,12 @@ def optimize_frozen_module(mod):
         assert "batch_norm" not in str(frozen_mod.graph)

     """
-    torch._C._jit_pass_optimize_frozen_graph(mod.graph)
+    # xxx: keep in sync with frozen_graph_optimization.cpp
+    # intentionally duplicated to make it easier to create a custom optimization sequence
+    torch._C._jit_pass_remove_dropout(mod._c)
+    if optimize_numerics:
+        # run a couple times to capture Conv -> Mul -> Add etc
+        for _ in range(2):
+            torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
+            torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
+            torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)

Review comment (on the `for _ in range(2):` line): it's not very scientific, you probably want to iterate till the fixed point :) but it's unrelated to this diff

Reply: yea, that's kind of a todo, i'm not convinced it really matters but it would be a good follow up
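For reference, the fixed-point iteration the reviewer alludes to could look roughly like the sketch below. It is not part of the PR and assumes the printed graph is an acceptable change detector:

import torch

def _fold_until_stable(mod):
    # Keep re-running the folding passes until the graph stops changing,
    # instead of a hard-coded two iterations.
    prev = None
    while prev != str(mod.graph):
        prev = str(mod.graph)
        torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
        torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
        torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
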
Review comment: if you get non-portable optimizations in the future - would it be a separate `bool` flag?

Reply: yea, that was what i was envisioning.

    def optimize_frozen_module(mod, optimize_numerics: bool = True, non_portable = False):
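A sketch of how that hypothetical flag could sit alongside the existing one. `non_portable` and anything gated behind it are purely speculative and are not in this PR:

import torch

def optimize_frozen_module(mod, optimize_numerics: bool = True, non_portable: bool = False):
    torch._C._jit_pass_remove_dropout(mod._c)
    if optimize_numerics:
        # same numerics-affecting folds as in the PR
        for _ in range(2):
            torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
            torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
            torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
    if non_portable:
        # Hypothetical: passes that specialize the graph to the current
        # hardware/backend would go behind this flag, since the resulting
        # module could no longer be shipped to arbitrary machines.
        pass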