[JIT] Update freezing api #52337

Closed
wants to merge 4 commits into from

Changes from 3 commits
16 changes: 16 additions & 0 deletions test/cpp/jit/test_module_api.cpp
@@ -3,9 +3,11 @@
#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

namespace torch {
@@ -341,6 +343,20 @@ TEST(ModuleAPITest, Define) {
AT_ASSERT(result.toTensor().item<float>() == 6);
}

TEST(ModuleAPITest, Freezing) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def forward(self, x, b : int = 4):
return self.foo + x + b
)");
m.eval();
auto frozen_mod = torch::jit::freeze(m);
auto forward_g = frozen_mod.get_method("forward").graph();
testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}

TEST(ModuleAPITest, To_CUDA) {
Module m("test");
{
32 changes: 12 additions & 20 deletions test/jit/test_freezing.py
@@ -1508,7 +1508,7 @@ def test_optimize_freeze_module(self):
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
# set optimize to False here, by default freezing runs optimize_frozen_module
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
# inspect frozen mod
FileCheck().check("batch_norm").run(frozen_mod.graph)
torch.jit.optimize_frozen_module(frozen_mod)
@@ -1528,18 +1528,14 @@ def forward(self, x):
return self.dropout(x)

mod = torch.jit.script(Net())
# set optimize to False here, by default freezing runs optimize_frozen_module
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
# inspect frozen mod
FileCheck().check("aten::dropout").run(frozen_mod.graph)
torch.jit.optimize_frozen_module(frozen_mod)
# inspect mod
torch._C._jit_pass_inline(mod.graph)
FileCheck().check("aten::dropout").run(mod.graph)
frozen_mod = torch.jit.freeze(mod.eval())
FileCheck().check_not("aten::dropout").run(frozen_mod.graph)

script_mod = torch.jit.script(mod)
script_mod.eval()

input = torch.randn(2)
output_s = script_mod.forward(input)
output_s = mod.forward(input)
output_f = frozen_mod.forward(input)
self.assertEqual(output_s, output_f)

@@ -1552,18 +1548,14 @@ def __init__(self):
def forward(self, x):
return self.dropout(x)

mod = torch.jit.script(Net())
# set optimize to False here, by default freezing runs optimize_frozen_module
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
# inspect frozen mod
FileCheck().check("aten::feature_dropout").run(frozen_mod.graph)
torch.jit.optimize_frozen_module(frozen_mod)
mod = torch.jit.script(Net().eval())
# inspect mod
torch._C._jit_pass_inline(mod.graph)
FileCheck().check("aten::feature_dropout").run(mod.graph)
frozen_mod = torch.jit.freeze(mod)
FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph)

script_mod = torch.jit.script(mod)
script_mod.eval()

input = torch.randn(2, 2)
output_s = script_mod.forward(input)
output_s = mod.forward(input)
output_f = frozen_mod.forward(input)
self.assertEqual(output_s, output_f)
5 changes: 5 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -175,6 +175,11 @@ def _freeze_module(module: ScriptModule,
freeze_interfaces: _bool = True,
preserveParameters: _bool = True) -> ScriptModule: ...
def _jit_pass_optimize_frozen_graph(Graph) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...

def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
17 changes: 17 additions & 0 deletions torch/csrc/jit/api/module.cpp
@@ -8,6 +8,8 @@
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>

@@ -336,6 +338,21 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
return obj;
}

Module freeze(
const Module& module,
c10::optional<std::vector<std::string>> preserved_attrs,
bool optimize_numerics) {
TORCH_CHECK(
!module.is_training(),
"Freezing is currently only implemented for modules in eval mode. Please call .eval() before freezing");

Module out_mod = freeze_module(
module, preserved_attrs.value_or(std::vector<std::string>({})));
auto graph = out_mod.get_method("forward").graph();
OptimizeFrozenGraph(graph, optimize_numerics);
return out_mod;
}

buffer_list Module::buffers(bool recurse) const {
return buffer_list(*this, recurse, /*return_module=*/false);
}
7 changes: 7 additions & 0 deletions torch/csrc/jit/api/module.h
@@ -276,6 +276,13 @@ struct TORCH_API Module : public Object {
bool non_blocking);
};

// C++ equivalent api of `torch.jit.freeze`. See documentation there for
// details.
TORCH_API Module freeze(
const Module& module,
c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
bool optimize_numerics = true);

namespace detail {

struct TORCH_API SlotCursor {
14 changes: 9 additions & 5 deletions torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -9,13 +9,17 @@
namespace torch {
namespace jit {

void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
void OptimizeFrozenGraph(
std::shared_ptr<Graph>& graph,
bool optimize_numerics) {
removeDropout(graph);
// run a couple times to capture Conv -> Mul -> Add etc
for (size_t i = 0; i < 2; i++) {
FoldFrozenConvBatchnorm(graph);
FoldFrozenConvAddOrSub(graph);
FoldFrozenConvMulOrDiv(graph);
if (optimize_numerics) {
for (size_t i = 0; i < 2; i++) {
FoldFrozenConvBatchnorm(graph);
FoldFrozenConvAddOrSub(graph);
FoldFrozenConvMulOrDiv(graph);
}
}
}

4 changes: 3 additions & 1 deletion torch/csrc/jit/passes/frozen_graph_optimizations.h
@@ -13,7 +13,9 @@
namespace torch {
namespace jit {

TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);
TORCH_API void OptimizeFrozenGraph(
std::shared_ptr<Graph>& graph,
bool optimize_numerics = true);

} // namespace jit
} // namespace torch
30 changes: 21 additions & 9 deletions torch/jit/_freeze.py
@@ -10,7 +10,7 @@
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True):
r"""
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
@@ -26,10 +26,8 @@ def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = Tr
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
Attributes modified in preserved methods will also be preserved.

optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
`torch.jit.optimize_frozen_module.`

optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
preserve numerics. Full details of the optimizations can be found in `torch.jit.optimize_frozen_module`.

Returns:
Frozen :class:`ScriptModule`.
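
A minimal usage sketch of the renamed flag, mirroring the Conv/BatchNorm pattern exercised in test_freezing.py above (the module construction here is illustrative, not taken from the diff):

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
bn = torch.nn.BatchNorm2d(8, eps=.001)
scripted = torch.jit.script(torch.nn.Sequential(conv, bn).eval())

# Default: the numerics-affecting passes run, so conv/batch_norm get folded away.
frozen = torch.jit.freeze(scripted)
assert "batch_norm" not in str(frozen.graph)

# Opt out of the folding passes; dropout removal still runs, but batch_norm remains.
frozen_exact = torch.jit.freeze(scripted, optimize_numerics=False)
assert "batch_norm" in str(frozen_exact.graph)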
@@ -102,23 +100,29 @@ def forward(self, input):

out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
if optimize:
optimize_frozen_module(out)
optimize_frozen_module(out, optimize_numerics)

return out


def optimize_frozen_module(mod):
def optimize_frozen_module(mod, optimize_numerics: bool = True):
Collaborator:

if you get non-portable optimizations in the future - would it be a separate bool flag?

Contributor Author:

yea, that was what i was envisioning.

def optimize_frozen_module(mod, optimize_numerics: bool = True, non_portable = False):

r"""
Runs a series of optimizations looking for patterns that occur in frozen graphs.
The current set of optimizations is:
- Dropout Removal
- Conv -> Batchnorm folding
- Conv -> Add/Sub folding
- Conv -> Mul/Div folding

Args:
mod (:class:`ScriptModule`): a frozen module to be optimized

optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
preserve numerics. Each of these transformations, applied on its own, stays within the default rtol and atol
of `torch.testing.assert_allclose`, but in a module where many transformations are applied the accumulated
error may no longer fall within the default `assert_allclose` tolerance. Conv -> Batchnorm folding,
Conv -> Add/Sub folding, and Conv -> Mul/Div folding may all alter numerics.

Returns:
None

@@ -140,4 +144,12 @@ def optimize_frozen_module(mod):
assert "batch_norm" not in str(frozen_mod.graph)

"""
torch._C._jit_pass_optimize_frozen_graph(mod.graph)
# xxx: keep in sync with frozen_graph_optimizations.cpp
# intentionally duplicated to make it easier to create a custom optimization sequence
torch._C._jit_pass_remove_dropout(mod._c)
if optimize_numerics:
# run a couple times to capture Conv -> Mul -> Add etc
for _ in range(2):
Collaborator:

it's not very scientific, you probably want to iterate till the fixed point :) but it's unrelated to this diff

Contributor Author:

yea, that's kind of a todo, i'm not convinced it really matters but it would be a good follow up
torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
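
Since the passes are intentionally duplicated in Python to make a custom optimization sequence possible, the sketch below shows what such a sequence could look like, iterating the folding passes until the printed graph stops changing rather than a fixed two times, as suggested in the review thread above. The helper name and the string-comparison convergence check are illustrative assumptions, not part of this PR:

import torch

def fold_frozen_convs_to_fixed_point(frozen_mod):
    # Custom sequence assembled from the passes this PR exposes individually.
    torch._C._jit_pass_remove_dropout(frozen_mod._c)
    prev = None
    # Naive convergence check: stop once a full round of folding passes
    # leaves the printed graph unchanged.
    while prev != str(frozen_mod.graph):
        prev = str(frozen_mod.graph)
        torch._C._jit_pass_fold_frozen_conv_bn(frozen_mod.graph)
        torch._C._jit_pass_fold_frozen_conv_add_or_sub(frozen_mod.graph)
        torch._C._jit_pass_fold_frozen_conv_mul_or_div(frozen_mod.graph)

# Usage: freeze without the built-in numerics passes, then apply the custom sequence.
conv = torch.nn.Conv2d(3, 8, kernel_size=3)
bn = torch.nn.BatchNorm2d(8, eps=.001)
frozen = torch.jit.freeze(torch.jit.script(torch.nn.Sequential(conv, bn).eval()),
                          optimize_numerics=False)
fold_frozen_convs_to_fixed_point(frozen)
assert "batch_norm" not in str(frozen.graph)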