[quant][pt2] Fix and rename move_model_to_eval (#108891) (#109027)
Summary:
This commit fixes two silent correctness problems with
the current implementation of `move_model_to_eval`:

(1) Previously the user had to manually call `eliminate_dead_code`
before calling `move_model_to_eval`; otherwise the dropout pattern
would not actually be replaced. This is because the subgraph rewriter
complains that the match is not self-contained, and so silently skips
the replacement.
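
For reference, the old flow looked roughly like this (a sketch based on
the test being updated in this commit; the toy module, input shapes, and
the capture_pre_autograd_graph import path are illustrative):

```python
import torch
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(x)

example_inputs = (torch.randn(1),)
m = capture_pre_autograd_graph(M().train(), example_inputs)

# Workaround previously required from the user: without these two calls,
# the subgraph rewriter silently skipped the dropout replacement.
m.graph.eliminate_dead_code()
m.recompile()

torch.ao.quantization.move_model_to_eval(m)  # old name
```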

(2) We wish to raise an error when the user calls `model.train()` or
`model.eval()` on an exported model. Today this error is raised
correctly immediately after export, but it is no longer raised after
the user calls prepare or convert.

We fix (1) by moving the `eliminate_dead_code` call into
`move_model_to_eval`, and fix (2) by ensuring the respective
errors are raised after prepare and convert as well.

Additionally, this commit renames `move_model_to_eval` to
`move_exported_model_to_eval` to be more explicit.
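
After this change, the intended flow looks roughly like the following
(a sketch assembled from the tests in this PR; the toy module, shapes,
quantizer configuration, and import paths are illustrative rather than
an authoritative API reference):

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.linear(x))

example_inputs = (torch.randn(2, 4),)
m = capture_pre_autograd_graph(M().train(), example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate
m = convert_pt2e(m)

# No manual eliminate_dead_code() / recompile() is needed anymore:
torch.ao.quantization.move_exported_model_to_eval(m)

# m.eval() and m.train() now raise NotImplementedError at every stage
# after export, including after prepare_pt2e and convert_pt2e.
```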

bypass-github-export-checks

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_disallow_eval_train
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_to_eval

Imported from OSS

Differential Revision: D49097293

Pull Request resolved: #108891
Approved by: https://github.com/jerryzh168
andrewor14 committed Sep 11, 2023
1 parent 71c9d5c commit 7e23b49
Showing 6 changed files with 69 additions and 13 deletions.
2 changes: 1 addition & 1 deletion test/inductor/test_inductor_freezing.py
@@ -629,7 +629,7 @@ def test_functional_constant_folding_after_dynamo_export(self):
prepare_model(*example_inputs)

convert_model = convert_pt2e(prepare_model)
- convert_model.eval()
+ torch.ao.quantization.move_exported_model_to_eval(convert_model)
compiler_model = compile_fx(convert_model, example_inputs)

# First Run
3 changes: 2 additions & 1 deletion test/inductor/test_mkldnn_pattern_matcher.py
@@ -112,7 +112,8 @@ def _test_common(
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
- convert_model = convert_pt2e(prepare_model).eval()
+ convert_model = convert_pt2e(prepare_model)
+ torch.ao.quantization.move_exported_model_to_eval(convert_model)
_ = torch.compile(convert_model)(*inputs)
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], matcher_count
38 changes: 33 additions & 5 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -361,7 +361,7 @@ def _verify_symmetric_qnnpack_qat_numerics(
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

if verify_convert:
- torch.ao.quantization.move_model_to_eval(model_pt2e)
+ torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
@@ -2392,7 +2392,7 @@ def forward(self, x, y):
non_ref_node_occurrence
)

- def test_move_model_to_eval(self):
+ def test_move_exported_model_to_eval(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -2404,8 +2404,6 @@ def forward(self, x):
example_inputs = (torch.randn(1),)
m = M().train()
m = capture_pre_autograd_graph(m, example_inputs)
- m.graph.eliminate_dead_code()
- m.recompile()

# Assert that dropout op exists and is in train mode
dropout_node = None
@@ -2417,13 +2415,43 @@ def forward(self, x):
self.assertTrue(dropout_node.args[2])

# Do the subgraph rewriting
- torch.ao.quantization.move_model_to_eval(m)
+ torch.ao.quantization.move_exported_model_to_eval(m)

# Assert that dropout op is now replaced with a clone op
targets = [n.target for n in m.graph.nodes]
self.assertTrue(torch.ops.aten.clone.default in targets)
self.assertTrue(torch.ops.aten.native_dropout.default not in targets)

+ def test_disallow_eval_train(self):
+ m = TestHelperModules.ConvWithBNRelu(relu=True)
+ example_inputs = (torch.rand(3, 3, 5, 5),)
+
+ # Before export: this is OK
+ m.eval()
+ m.train()
+
+ # After export: this is not OK
+ m = capture_pre_autograd_graph(m, example_inputs)
+ with self.assertRaises(NotImplementedError):
+ m.eval()
+ with self.assertRaises(NotImplementedError):
+ m.train()
+
+ # After prepare: still not OK
+ quantizer = XNNPACKQuantizer()
+ m = prepare_qat_pt2e(m, quantizer)
+ with self.assertRaises(NotImplementedError):
+ m.eval()
+ with self.assertRaises(NotImplementedError):
+ m.train()
+
+ # After convert: still not OK
+ m = convert_pt2e(m)
+ with self.assertRaises(NotImplementedError):
+ m.eval()
+ with self.assertRaises(NotImplementedError):
+ m.train()


@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):
4 changes: 2 additions & 2 deletions torch/ao/quantization/__init__.py
@@ -12,7 +12,7 @@
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
- from .pt2e.utils import move_model_to_eval
+ from .pt2e.utils import move_exported_model_to_eval
from typing import Union, List, Callable, Tuple, Optional
from torch import Tensor
import torch
@@ -120,7 +120,7 @@
"get_quantized_operator",
"get_static_quant_module_class",
"load_observer_state_dict",
"move_model_to_eval",
"move_exported_model_to_eval",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",
"prepare",
31 changes: 27 additions & 4 deletions torch/ao/quantization/pt2e/utils.py
@@ -1,3 +1,6 @@
+ import operator
+ import types

import torch
from torch.fx import (
GraphModule,
@@ -6,14 +9,13 @@
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights
- import operator
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec

__all__ = [
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
"move_model_to_eval",
"move_exported_model_to_eval",
"remove_tensor_overload_for_qdq_ops",
]

@@ -200,6 +202,10 @@ def _replace_dropout_for_eval(m: GraphModule):
See https://github.com/pytorch/pytorch/issues/103681.
"""
+ # Needed to ensure subgraph matches are self-contained
+ m.graph.eliminate_dead_code()
+ m.recompile()

def dropout_train(x):
return F.dropout(x, p=0.5, training=True)

@@ -390,9 +396,9 @@ def replacement(x_i8, scale, zero_point, quant_min, quant_max):
node.args = new_args
return gm

- # TODO: also support move_model_to_train
+ # TODO: also support move_exported_model_to_train
# TODO: also support standalone batchnorm
- def move_model_to_eval(m: GraphModule):
+ def move_exported_model_to_eval(m: GraphModule):
"""
Move an exported GraphModule to eval mode.
@@ -401,3 +407,20 @@ def move_model_to_eval(m: GraphModule):
"""
_replace_dropout_for_eval(m)
return m

+ # TODO: Handle this in export itself and don't wrap the model in another GraphModule
+ # in prepare and convert
+ def _disallow_eval_train(model: GraphModule):
+ """
+ Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
+ This is useful for exported models, where these methods don't actually behave as expected.
+ """
+ def _train(self, mode: bool = True):
+ raise NotImplementedError("Calling train() is not supported yet.")
+
+ def _eval(self, mode: bool = True):
+ raise NotImplementedError("Calling eval() is not supported yet.")
+
+ model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
+ model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
+ return model
4 changes: 4 additions & 0 deletions torch/ao/quantization/quantize_pt2e.py
@@ -9,6 +9,7 @@
from .pt2e.utils import (
_get_node_name_to_scope,
_fuse_conv_bn_,
+ _disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
@@ -70,6 +71,7 @@ def prepare_pt2e(
propagate_annotation(model)
model = prepare(model, node_name_to_scope, is_qat=False)
model.meta.update(original_graph_meta)
+ model = _disallow_eval_train(model)
return model

def prepare_qat_pt2e(
@@ -87,6 +89,7 @@ def prepare_qat_pt2e(
_fuse_conv_bn_qat(model)
model = prepare(model, node_name_to_scope, is_qat=True)
model.meta.update(original_graph_meta)
+ model = _disallow_eval_train(model)
return model

def convert_pt2e(
Expand All @@ -100,4 +103,5 @@ def convert_pt2e(
model = reference_representation_rewrite(model)

model.meta.update(original_graph_meta)
+ model = _disallow_eval_train(model)
return model
