[quant][pt2] Fix and rename move_model_to_eval (#108891) #109027

Merged
merged 1 commit on Sep 11, 2023
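
This PR renames move_model_to_eval to move_exported_model_to_eval and switches callers from Module.eval() to the new helper, since eval()/train() no longer behave as expected on exported models and are now made to raise NotImplementedError. The sketch below shows the post-training flow exercised by the inductor tests in this diff; it is illustrative only (the toy model, quantizer settings, and input shapes are assumptions, not part of this PR):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
example_inputs = (torch.randn(1, 4),)

# Export, prepare, calibrate, and convert as before.
model = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
model = prepare_pt2e(model, quantizer)
model(*example_inputs)  # calibration
model = convert_pt2e(model)

# model.eval() would now raise NotImplementedError on the exported GraphModule;
# use the renamed helper to switch ops like dropout to inference behavior.
torch.ao.quantization.move_exported_model_to_eval(model)
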
test/inductor/test_inductor_freezing.py (2 changes: 1 addition & 1 deletion)
@@ -629,7 +629,7 @@ def test_functional_constant_folding_after_dynamo_export(self):
prepare_model(*example_inputs)

convert_model = convert_pt2e(prepare_model)
-convert_model.eval()
+torch.ao.quantization.move_exported_model_to_eval(convert_model)
compiler_model = compile_fx(convert_model, example_inputs)

# First Run
test/inductor/test_mkldnn_pattern_matcher.py (3 changes: 2 additions & 1 deletion)
@@ -112,7 +112,8 @@ def _test_common(
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
-convert_model = convert_pt2e(prepare_model).eval()
+convert_model = convert_pt2e(prepare_model)
+torch.ao.quantization.move_exported_model_to_eval(convert_model)
_ = torch.compile(convert_model)(*inputs)
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], matcher_count
test/quantization/pt2e/test_quantize_pt2e.py (38 changes: 33 additions & 5 deletions)
@@ -361,7 +361,7 @@ def _verify_symmetric_qnnpack_qat_numerics(
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

if verify_convert:
-torch.ao.quantization.move_model_to_eval(model_pt2e)
+torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
@@ -2392,7 +2392,7 @@ def forward(self, x, y):
non_ref_node_occurrence
)

-def test_move_model_to_eval(self):
+def test_move_exported_model_to_eval(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -2404,8 +2404,6 @@ def forward(self, x):
example_inputs = (torch.randn(1),)
m = M().train()
m = capture_pre_autograd_graph(m, example_inputs)
-m.graph.eliminate_dead_code()
-m.recompile()

# Assert that dropout op exists and is in train mode
dropout_node = None
@@ -2417,13 +2415,43 @@ def forward(self, x):
self.assertTrue(dropout_node.args[2])

# Do the subgraph rewriting
-torch.ao.quantization.move_model_to_eval(m)
+torch.ao.quantization.move_exported_model_to_eval(m)

# Assert that dropout op is now replaced with a clone op
targets = [n.target for n in m.graph.nodes]
self.assertTrue(torch.ops.aten.clone.default in targets)
self.assertTrue(torch.ops.aten.native_dropout.default not in targets)

+def test_disallow_eval_train(self):
+m = TestHelperModules.ConvWithBNRelu(relu=True)
+example_inputs = (torch.rand(3, 3, 5, 5),)
+
+# Before export: this is OK
+m.eval()
+m.train()
+
+# After export: this is not OK
+m = capture_pre_autograd_graph(m, example_inputs)
+with self.assertRaises(NotImplementedError):
+m.eval()
+with self.assertRaises(NotImplementedError):
+m.train()
+
+# After prepare: still not OK
+quantizer = XNNPACKQuantizer()
+m = prepare_qat_pt2e(m, quantizer)
+with self.assertRaises(NotImplementedError):
+m.eval()
+with self.assertRaises(NotImplementedError):
+m.train()
+
+# After convert: still not OK
+m = convert_pt2e(m)
+with self.assertRaises(NotImplementedError):
+m.eval()
+with self.assertRaises(NotImplementedError):
+m.train()
+

@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):
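
For reference, the assertions in test_move_exported_model_to_eval above inspect the exported aten graph directly. The same check can be run outside the test harness roughly as follows; the module definition is an assumption, since the test's M is only partially shown in this hunk:

import torch
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(x)

example_inputs = (torch.randn(1),)
m = capture_pre_autograd_graph(M().train(), example_inputs)

# aten.native_dropout takes (input, p, train), so args[2] is the training flag
dropout_node = next(
    n for n in m.graph.nodes if n.target == torch.ops.aten.native_dropout.default
)
assert dropout_node.args[2]

torch.ao.quantization.move_exported_model_to_eval(m)

# The dropout op is now replaced with a clone op
targets = [n.target for n in m.graph.nodes]
assert torch.ops.aten.clone.default in targets
assert torch.ops.aten.native_dropout.default not in targets
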
torch/ao/quantization/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -12,7 +12,7 @@
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
-from .pt2e.utils import move_model_to_eval
+from .pt2e.utils import move_exported_model_to_eval
from typing import Union, List, Callable, Tuple, Optional
from torch import Tensor
import torch
@@ -120,7 +120,7 @@
"get_quantized_operator",
"get_static_quant_module_class",
"load_observer_state_dict",
"move_model_to_eval",
"move_exported_model_to_eval",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",
"prepare",
torch/ao/quantization/pt2e/utils.py (31 changes: 27 additions & 4 deletions)
@@ -1,3 +1,6 @@
+import operator
+import types
+
import torch
from torch.fx import (
GraphModule,
@@ -6,14 +9,13 @@
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights
-import operator
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec

__all__ = [
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
"move_model_to_eval",
"move_exported_model_to_eval",
"remove_tensor_overload_for_qdq_ops",
]

@@ -200,6 +202,10 @@ def _replace_dropout_for_eval(m: GraphModule):

See https://github.com/pytorch/pytorch/issues/103681.
"""
+# Needed to ensure subgraph matches are self-contained
+m.graph.eliminate_dead_code()
+m.recompile()
+
def dropout_train(x):
return F.dropout(x, p=0.5, training=True)
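
_replace_dropout_for_eval (only partially visible in this hunk) swaps the training-mode dropout pattern for its eval-mode counterpart via FX subgraph rewriting. Below is a self-contained sketch of that rewriting technique on an eager-traced module, using torch.fx.subgraph_rewriter.replace_pattern rather than the exported-aten-graph machinery this file actually uses; the toy module and pattern functions are illustrative only:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern

class WithDropout(torch.nn.Module):
    def forward(self, x):
        return F.dropout(x, p=0.5, training=True) + 1

def pattern(x):
    return F.dropout(x, p=0.5, training=True)

def replacement(x):
    return F.dropout(x, p=0.5, training=False)

gm = symbolic_trace(WithDropout())
replace_pattern(gm, pattern, replacement)
print(gm.code)  # the dropout call now runs with training=False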

@@ -390,9 +396,9 @@ def replacement(x_i8, scale, zero_point, quant_min, quant_max):
node.args = new_args
return gm

-# TODO: also support move_model_to_train
+# TODO: also support move_exported_model_to_train
# TODO: also support standalone batchnorm
-def move_model_to_eval(m: GraphModule):
+def move_exported_model_to_eval(m: GraphModule):
"""
Move an exported GraphModule to eval mode.

@@ -401,3 +407,20 @@ def move_model_to_eval(m: GraphModule):
"""
_replace_dropout_for_eval(m)
return m

+# TODO: Handle this in export itself and don't wrap the model in another GraphModule
+# in prepare and convert
+def _disallow_eval_train(model: GraphModule):
+"""
+Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
+This is useful for exported models, where these methods don't actually behave as expected.
+"""
+def _train(self, mode: bool = True):
+raise NotImplementedError("Calling train() is not supported yet.")
+
+def _eval(self, mode: bool = True):
+raise NotImplementedError("Calling eval() is not supported yet.")
+
+model.train = types.MethodType(_train, model) # type: ignore[method-assign]
+model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
+return model
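
_disallow_eval_train binds the raising _train/_eval functions to the specific GraphModule instance with types.MethodType, so only that instance is affected and other modules keep their normal train()/eval(). A standalone illustration of that instance-level override (the Linear module is just a stand-in):

import types
import torch

m = torch.nn.Linear(2, 2)

def _train(self, mode: bool = True):
    raise NotImplementedError("Calling train() is not supported yet.")

m.train = types.MethodType(_train, m)  # shadows nn.Module.train for this instance only

try:
    m.train()
except NotImplementedError as e:
    print(e)

torch.nn.Linear(2, 2).train()  # other instances are unaffected
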
torch/ao/quantization/quantize_pt2e.py (4 changes: 4 additions & 0 deletions)
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from .pt2e.utils import (
_get_node_name_to_scope,
_fuse_conv_bn_,
+_disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
@@ -70,6 +71,7 @@ def prepare_pt2e(
propagate_annotation(model)
model = prepare(model, node_name_to_scope, is_qat=False)
model.meta.update(original_graph_meta)
+model = _disallow_eval_train(model)
return model

def prepare_qat_pt2e(
@@ -87,6 +89,7 @@ def prepare_qat_pt2e(
_fuse_conv_bn_qat(model)
model = prepare(model, node_name_to_scope, is_qat=True)
model.meta.update(original_graph_meta)
+model = _disallow_eval_train(model)
return model

def convert_pt2e(
@@ -100,4 +103,5 @@ def convert_pt2e(
model = reference_representation_rewrite(model)

model.meta.update(original_graph_meta)
+model = _disallow_eval_train(model)
return model
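
With _disallow_eval_train applied in prepare_pt2e, prepare_qat_pt2e, and convert_pt2e, every model returned by these entry points rejects train()/eval(), leaving move_exported_model_to_eval as the supported way to switch to inference behavior. A rough QAT-flavored sketch following the ordering used in _verify_symmetric_qnnpack_qat_numerics above; the model, quantizer settings, and training loop are placeholders, not part of this PR:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, 3),
    torch.nn.BatchNorm2d(3),
    torch.nn.ReLU(),
)
example_inputs = (torch.randn(1, 3, 8, 8),)

model = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
model = prepare_qat_pt2e(model, quantizer)

# ... QAT training loop would run here; calling model.train() or model.eval()
# on the prepared model raises NotImplementedError after this PR ...

torch.ao.quantization.move_exported_model_to_eval(model)  # instead of model.eval()
model = convert_pt2e(model)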