Add Post Freezing Optimizations, turn on by default in torch.jit.freeze (#50222)

Summary:
Pull Request resolved: #50222

This PR adds a pass which runs a set of optimizations after freezing. Currently this encompasses Conv-BN folding and Conv->Add/Sub/Mul/Div folding; I am also planning to add dropout removal.

I would like some feedback on the API. torch.jit.freeze is technically in the ~prototype~ phase, so we have some leeway around making changes. In the majority of cases the user is going to want to freeze their model and then run it for inference, so I would prefer the optimization to be opt-out instead of opt-in. All internal/framework use cases of freezing use `freeze_module`, not the Python API, so this shouldn't break anything.

I have separated out the optimization pass as its own API to keep things modular, even though I suspect that is an unlikely use case. In a future PR I would like to add a `torch::jit::freeze` for C++ use that follows the same API as `torch.jit.freeze` and runs the optimizations.
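
To make the intended usage concrete, here is a minimal sketch of the new surface (illustrative only; it assumes the `optimize` flag and `torch.jit.optimize_frozen_module` added in this PR):

```python
import torch
import torch.nn as nn

# A toy Conv -> BatchNorm model, the pattern the post-freezing pass folds.
model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3), nn.BatchNorm2d(32)).eval()
scripted = torch.jit.script(model)

# Default: freezing now also runs the post-freezing optimizations.
frozen = torch.jit.freeze(scripted)

# Opt-out: freeze only, then run the optimizations explicitly later (or not at all).
frozen_unopt = torch.jit.freeze(scripted, optimize=False)
torch.jit.optimize_frozen_module(frozen_unopt)
```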

Test Plan: Imported from OSS

Reviewed By: tugsbayasgalan

Differential Revision: D25856264

Pulled By: eellison

fbshipit-source-id: 56be1f12cfc459b4c4421d4dfdedff8b9ac77112
Elias Ellison authored and facebook-github-bot committed Jan 12, 2021
1 parent 30aeed7 commit a389b30
Showing 8 changed files with 110 additions and 8 deletions.
16 changes: 16 additions & 0 deletions test/jit/test_freezing.py
@@ -1501,3 +1501,19 @@ def forward(self, x):
        # add with different dtype
        test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False,
                         add_tensor=torch.rand(1).to(torch.int), expect_success=False)

    def test_optimize_freeze_module(self):
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize to False here, by default freezing runs optimize_frozen_module
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
        # inspect frozen mod
        FileCheck().check("batch_norm").run(frozen_mod.graph)
        torch.jit.optimize_frozen_module(frozen_mod)
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

        # optimize_frozen_module should be run
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)
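
For context on what the Conv-BN folding pass exploits, here is a hedged, standalone sketch of the folding algebra (not the JIT pass implementation; the tensor names such as `folded_weight` are illustrative):

```python
import torch

# In eval mode, BatchNorm applies a per-channel affine transform using its
# running statistics, so it can be folded into the preceding conv:
#   W' = W * gamma / sqrt(var + eps)
#   b' = (b - mu) * gamma / sqrt(var + eps) + beta
conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8, eps=1e-3).eval()

scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
folded_weight = conv.weight * scale.reshape(-1, 1, 1, 1)  # broadcast over output channels
folded_bias = (conv.bias - bn.running_mean) * scale + bn.bias

fused = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
with torch.no_grad():
    fused.weight.copy_(folded_weight)
    fused.bias.copy_(folded_bias)

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```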
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -201,6 +201,7 @@ core_sources_full_mobile = [
"torch/csrc/jit/passes/prepack_folding.cpp",
"torch/csrc/jit/passes/fold_conv_bn.cpp",
"torch/csrc/jit/passes/frozen_conv_folding.cpp",
"torch/csrc/jit/passes/frozen_graph_optimizations.cpp",
"torch/csrc/jit/passes/remove_expands.cpp",
"torch/csrc/jit/passes/remove_dropout.cpp",
"torch/csrc/jit/passes/requires_grad_analysis.cpp",
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
@@ -170,6 +170,7 @@ def _freeze_module(module: ScriptModule,
preserved_attrs: List[str] = [],
freeze_interfaces: _bool = True,
preserveParameters: _bool = True) -> ScriptModule: ...
def _jit_pass_optimize_frozen_graph(Graph) -> None: ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
21 changes: 21 additions & 0 deletions torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -0,0 +1,21 @@
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

namespace torch {
namespace jit {

void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
  // run a couple times to capture Conv -> Mul -> Add etc
  for (size_t i = 0; i < 2; i++) {
    FoldFrozenConvBatchnorm(graph);
    FoldFrozenConvAddOrSub(graph);
    FoldFrozenConvMulOrDiv(graph);
  }
}

} // namespace jit
} // namespace torch
19 changes: 19 additions & 0 deletions torch/csrc/jit/passes/frozen_graph_optimizations.h
@@ -0,0 +1,19 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

/** \brief Runs a set of Optimizations that Optimize Frozen Graphs
*
* Currently this set of optimizations is:
* - FoldFrozenConvBatchnorm
* - FoldFrozenConvAddOrSub
* - FoldFrozenConvMulOrDiv
*/

namespace torch {
namespace jit {

TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
2 changes: 2 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -23,6 +23,7 @@
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
@@ -299,6 +300,7 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
.def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
.def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
.def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
.def("_jit_pass_fuse_linear", &FuseLinear)
.def(
"_jit_pass_fuse_add_relu",
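
For readers following along in Python, the new binding can be approximated by composing the per-pass bindings registered just above; a rough sketch (the helper name `optimize_frozen_graph_py` and the frozen module `frozen_mod` are assumptions for illustration):

```python
import torch

def optimize_frozen_graph_py(graph):
    # Mirror the loop in OptimizeFrozenGraph: run twice to catch chains
    # such as Conv -> Mul -> Add.
    for _ in range(2):
        torch._C._jit_pass_fold_frozen_conv_bn(graph)
        torch._C._jit_pass_fold_frozen_conv_add_or_sub(graph)
        torch._C._jit_pass_fold_frozen_conv_mul_or_div(graph)

# Equivalent, via the binding added in this file:
# torch._C._jit_pass_optimize_frozen_graph(frozen_mod.graph)
```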
14 changes: 7 additions & 7 deletions torch/jit/__init__.py
@@ -45,7 +45,7 @@
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph

from torch.jit.cuda import stream
from torch.jit._freeze import freeze
from torch.jit._freeze import freeze, optimize_frozen_module

# For backwards compatibility
_fork = fork
@@ -93,20 +93,20 @@ def script_if_tracing(fn):
    return _script_if_tracing(fn)


# for torch.jit.isinstance
def isinstance(obj, target_type):
    """
    This function provides for container type refinement in TorchScript. It can refine
    parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
    ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
    refine basic types such as bools and ints that are available in TorchScript.
    Args:
        obj: object to refine the type of
        target_type: type to try to refine obj to
    Returns:
        ``bool``: True if obj was successfully refined to the type of target_type,
            False otherwise with no new type refinement
    Example (using ``torch.jit.isinstance`` for type refinement):
44 changes: 43 additions & 1 deletion torch/jit/_freeze.py
@@ -10,7 +10,7 @@
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(mod, preserved_attrs: Optional[List[str]] = None):
def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
r"""
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
@@ -26,6 +26,11 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None):
        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
            Attributes modified in preserved methods will also be preserved.
        optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
            in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
            `torch.jit.optimize_frozen_module`.
    Returns:
        Frozen :class:`ScriptModule`.
@@ -97,5 +102,42 @@ def forward(self, input):

    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)
    if optimize:
        optimize_frozen_module(out)

    return out


def optimize_frozen_module(mod):
    r"""
    Runs a series of optimizations looking for patterns that occur in frozen graphs.
    The current set of optimizations is:
        - Conv -> Batchnorm folding
        - Conv -> Add/Sub folding
        - Conv -> Mul/Div folding
    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized
    Returns:
        None
    Note:
        On rare occasions, this can result in slower execution.
    Example (Freezing a module with Conv->Batchnorm):
    .. code-block:: python
        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize to False here, by default freezing runs optimize_frozen_module
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.optimize_frozen_module(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)
    """
    torch._C._jit_pass_optimize_frozen_graph(mod.graph)
