[JIT] Update freezing api (#52337)
Summary:
Update the freezing API for 1.8 and add a corresponding C++ API. The `optimize` flag hasn't been publicly released yet, so we can change it without breaking backward compatibility. I will submit a PR to the release branch as well; there are a few more days to do that.
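
For illustration, a minimal sketch of the renamed flag from the Python side; the `conv_bn` module here is a made-up example rather than code from this PR:

import torch

# Hypothetical example network; any Conv -> BatchNorm model would do.
conv_bn = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)
scripted = torch.jit.script(conv_bn.eval())

# `optimize_numerics` replaces the unreleased `optimize` flag. With
# optimize_numerics=False, freezing still inlines attributes and removes
# dropout, but skips the folding passes that may change numerics.
frozen = torch.jit.freeze(scripted, optimize_numerics=False)

# The numerics-affecting folds can still be applied separately afterwards.
torch.jit.optimize_frozen_module(frozen, optimize_numerics=True)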

Pull Request resolved: #52337

Reviewed By: ejguan

Differential Revision: D26491833

Pulled By: eellison

fbshipit-source-id: 6dcd74eb8f76db64ac53183d03dabdd0f101f4b5
Elias Ellison authored and facebook-github-bot committed Feb 18, 2021
1 parent ac12116 commit e1d927e
Showing 8 changed files with 90 additions and 35 deletions.
16 changes: 16 additions & 0 deletions test/cpp/jit/test_module_api.cpp
@@ -3,9 +3,11 @@
#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

namespace torch {
@@ -341,6 +343,20 @@ TEST(ModuleAPITest, Define) {
AT_ASSERT(result.toTensor().item<float>() == 6);
}

TEST(ModuleAPITest, Freezing) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();
  auto frozen_mod = torch::jit::freeze(m);
  auto forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}

TEST(ModuleAPITest, To_CUDA) {
Module m("test");
{
32 changes: 12 additions & 20 deletions test/jit/test_freezing.py
@@ -1508,7 +1508,7 @@ def test_optimize_freeze_module(self):
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
# set optimize to False here, by default freezing runs optimize_frozen_module
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
# inspect frozen mod
FileCheck().check("batch_norm").run(frozen_mod.graph)
torch.jit.optimize_frozen_module(frozen_mod)
@@ -1528,18 +1528,14 @@ def forward(self, x):
return self.dropout(x)

mod = torch.jit.script(Net())
# set optimize to False here, by default freezing runs optimize_frozen_module
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
# inspect frozen mod
FileCheck().check("aten::dropout").run(frozen_mod.graph)
torch.jit.optimize_frozen_module(frozen_mod)
# inspect mod
torch._C._jit_pass_inline(mod.graph)
FileCheck().check("aten::dropout").run(mod.graph)
frozen_mod = torch.jit.freeze(mod.eval())
FileCheck().check_not("aten::dropout").run(frozen_mod.graph)

script_mod = torch.jit.script(mod)
script_mod.eval()

input = torch.randn(2)
output_s = script_mod.forward(input)
output_s = mod.forward(input)
output_f = frozen_mod.forward(input)
self.assertEqual(output_s, output_f)

@@ -1552,18 +1548,14 @@ def __init__(self):
def forward(self, x):
return self.dropout(x)

mod = torch.jit.script(Net())
# set optimize to False here, by default freezing runs optimize_frozen_module
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
# inspect frozen mod
FileCheck().check("aten::feature_dropout").run(frozen_mod.graph)
torch.jit.optimize_frozen_module(frozen_mod)
mod = torch.jit.script(Net().eval())
# inspect mod
torch._C._jit_pass_inline(mod.graph)
FileCheck().check("aten::feature_dropout").run(mod.graph)
frozen_mod = torch.jit.freeze(mod)
FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph)

script_mod = torch.jit.script(mod)
script_mod.eval()

input = torch.randn(2, 2)
output_s = script_mod.forward(input)
output_s = mod.forward(input)
output_f = frozen_mod.forward(input)
self.assertEqual(output_s, output_f)
5 changes: 5 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -175,6 +175,11 @@ def _freeze_module(module: ScriptModule,
freeze_interfaces: _bool = True,
preserveParameters: _bool = True) -> ScriptModule: ...
def _jit_pass_optimize_frozen_graph(Graph) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...

def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
17 changes: 17 additions & 0 deletions torch/csrc/jit/api/module.cpp
@@ -8,6 +8,8 @@
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>

@@ -336,6 +338,21 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
return obj;
}

Module freeze(
    const Module& module,
    c10::optional<std::vector<std::string>> preserved_attrs,
    bool optimize_numerics) {
  TORCH_CHECK(
      !module.is_training(),
      "Freezing is currently only implemented for modules in eval mode. Please call .eval() before freezing");

  Module out_mod = freeze_module(
      module, preserved_attrs.value_or(std::vector<std::string>({})));
  auto graph = out_mod.get_method("forward").graph();
  OptimizeFrozenGraph(graph, optimize_numerics);
  return out_mod;
}

buffer_list Module::buffers(bool recurse) const {
return buffer_list(*this, recurse, /*return_module=*/false);
}
7 changes: 7 additions & 0 deletions torch/csrc/jit/api/module.h
@@ -276,6 +276,13 @@ struct TORCH_API Module : public Object {
bool non_blocking);
};

// C++ equivalent of the Python API `torch.jit.freeze`. See its documentation
// for details.
TORCH_API Module freeze(
    const Module& module,
    c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
    bool optimize_numerics = true);

namespace detail {

struct TORCH_API SlotCursor {
14 changes: 9 additions & 5 deletions torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -9,13 +9,17 @@
namespace torch {
namespace jit {

void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
void OptimizeFrozenGraph(
std::shared_ptr<Graph>& graph,
bool optimize_numerics) {
removeDropout(graph);
// run a couple times to capture Conv -> Mul -> Add etc
for (size_t i = 0; i < 2; i++) {
FoldFrozenConvBatchnorm(graph);
FoldFrozenConvAddOrSub(graph);
FoldFrozenConvMulOrDiv(graph);
if (optimize_numerics) {
for (size_t i = 0; i < 2; i++) {
FoldFrozenConvBatchnorm(graph);
FoldFrozenConvAddOrSub(graph);
FoldFrozenConvMulOrDiv(graph);
}
}
}

4 changes: 3 additions & 1 deletion torch/csrc/jit/passes/frozen_graph_optimizations.h
@@ -13,7 +13,9 @@
namespace torch {
namespace jit {

TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);
TORCH_API void OptimizeFrozenGraph(
std::shared_ptr<Graph>& graph,
bool optimize_numerics = true);

} // namespace jit
} // namespace torch
30 changes: 21 additions & 9 deletions torch/jit/_freeze.py
@@ -10,7 +10,7 @@
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True):
r"""
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
@@ -26,10 +26,8 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = Tr
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
Attributes modified in preserved methods will also be preserved.
optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
`torch.jit.optimize_frozen_module.`
optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
preserve numerics. Full details of the optimizations can be found in `torch.jit.optimize_frozen_module`.
Returns:
Frozen :class:`ScriptModule`.
@@ -102,23 +100,29 @@ def forward(self, input):

out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
if optimize:
optimize_frozen_module(out)
optimize_frozen_module(out, optimize_numerics)

return out


def optimize_frozen_module(mod):
def optimize_frozen_module(mod, optimize_numerics: bool = True):
r"""
Runs a series of optimizations looking for patterns that occur in frozen graphs.
The current set of optimizations is:
- Dropout Removal
- Conv -> Batchnorm folding
- Conv -> Add/Sub folding
- Conv -> Mul/Div folding
Args:
mod (:class:`ScriptModule`): a frozen module to be optimized
optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
preserve numerics. These optimizations preserve the default rtol and atol of `torch.testing.assert_allclose`
when applied on a single transformation; however, in a module where many transformations are applied,
the rtol or atol may no longer fall within the default `assert_allclose` tolerance. Conv -> Batchnorm folding,
Conv -> Add/Sub folding, and Conv -> Mul/Div folding may all alter numerics.
Returns:
None
@@ -140,4 +144,12 @@ def optimize_frozen_module(mod):
assert "batch_norm" not in str(frozen_mod.graph)
"""
torch._C._jit_pass_optimize_frozen_graph(mod.graph)
# xxx: keep in sync with frozen_graph_optimizations.cpp
# intentionally duplicated to make it easier to create custom optimization sequences
torch._C._jit_pass_remove_dropout(mod._c)
if optimize_numerics:
# run a couple times to capture Conv -> Mul -> Add etc
for _ in range(2):
torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
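
Since the passes are now exposed individually, a custom optimization sequence can be assembled by hand. A minimal sketch under that assumption, using the bindings declared in torch/_C/__init__.pyi.in above (the example model itself is hypothetical):

import torch

# Hypothetical example model; any frozen ScriptModule would do.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.BatchNorm2d(8))

# Freeze without the numerics-affecting folds...
frozen_mod = torch.jit.freeze(torch.jit.script(model.eval()), optimize_numerics=False)

# ...then apply only the Conv -> BatchNorm fold, skipping the Conv -> Add/Sub
# and Conv -> Mul/Div folds.
torch._C._jit_pass_fold_frozen_conv_bn(frozen_mod.graph)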
