[quant][graphmode][fx] Embedding/EmbeddingBag works in static quant qconfig (#48062)

Summary:
Pull Request resolved: #48062

When Embedding/EmbeddingBag are configured with a static quant qconfig, we skip inserting observers for them in the graph, keep the op unchanged, and print a warning. This aligns with eager mode behavior.

We'll enforce the same behavior for other ops that only support dynamic/weight_only quant but not static quant.

We use a global variable `DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER`; it is not exposed to users right now, but we can add that later if needed.
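
For illustration, a minimal sketch (not part of this PR) of the behavior described above, assuming the two-argument `prepare_fx(model, qconfig_dict)` API used in the test changes below:

```python
# Sketch only: with a static qconfig, the Embedding gets no observers and
# convert_fx leaves the float module in place while printing a warning.
import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)

    def forward(self, indices):
        return self.emb(indices)

m = M().eval()
m = prepare_fx(m, {"": default_qconfig})  # no observers inserted for self.emb
m = convert_fx(m)                         # warns; self.emb stays nn.Embedding
m(torch.tensor([0, 1, 2]))                # still runs as a float op
```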

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D25007537

fbshipit-source-id: 6ab9e025269b44bbfd0d6dd5bb9f95fe3ca9dead
jerryzh168 authored and facebook-github-bot committed Nov 17, 2020
1 parent 3846e35 commit d7e8384
Showing 4 changed files with 41 additions and 17 deletions.
26 changes: 14 additions & 12 deletions test/quantization/test_quantize_fx.py
@@ -1948,7 +1948,8 @@ def forward(self, indices):
quantized_node = ns.call_module(nnq.Embedding)
configs = [
(float_qparams_dynamic_qconfig, ns.call_module(nnq.Embedding)),
(None, ns.call_module(nn.Embedding))
(None, ns.call_module(nn.Embedding)),
(default_qconfig, ns.call_module(nn.Embedding)),
]

for qconfig, node in configs:
@@ -1991,17 +1992,18 @@ def forward(self, indices, offsets):
custom_qconfig=float_qparams_qconfig
)

# check it works in None qconfig
qconfig_dict = {"": None}
m = M().eval()
m = prepare_fx(model, qconfig_dict)
self.checkGraphModuleNodes(m, expected_node_occurrence={
ns.call_module(torch.quantization.MinMaxObserver): 0
})
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
# make sure it runs
m(*inputs)
# check that it works with None and static qconfig
for qconfig in [None, default_qconfig]:
qconfig_dict = {"": qconfig}
m = M().eval()
m = prepare_fx(m, qconfig_dict)
self.checkGraphModuleNodes(m, expected_node_occurrence={
ns.call_module(torch.quantization.MinMaxObserver): 0
})
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
# make sure it runs
m(*inputs)

class TestQuantizeFxModels(QuantizationTestCase):
def _test_model_impl(
12 changes: 12 additions & 0 deletions torch/quantization/fx/pattern_utils.py
@@ -37,6 +37,18 @@ def get_default_quant_patterns():
def get_default_output_activation_post_process_map():
return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP

# a set of QuantizeHandler classes that are not observed
# we'll skip inserting observers for input and output for these QuantizeHandlers
# used for ops that only support dynamic/weight_only quantization
DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER = set()
def mark_input_output_not_observed():
def insert(fn):
DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER.add(fn)
return fn
return insert

def input_output_observed(qh):
return type(qh) not in DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER

# Example use of register pattern function:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
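To make the intended use of the new registry concrete, here is a small self-contained sketch: the three definitions are copied from the hunk above, while `MyDynamicOnlyHandler` is a hypothetical handler, not part of the PR (the real registrations are the Embedding/EmbeddingBag decorators in quantization_patterns.py below).

```python
# Definitions copied from pattern_utils.py above; the handler class is a
# made-up stand-in for a real QuantizeHandler subclass.
DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER = set()

def mark_input_output_not_observed():
    def insert(fn):
        DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER.add(fn)
        return fn
    return insert

def input_output_observed(qh):
    return type(qh) not in DEFAULT_NOT_OBSERVED_QUANTIZE_HANDLER

@mark_input_output_not_observed()
class MyDynamicOnlyHandler:  # hypothetical handler class for illustration
    pass

print(input_output_observed(MyDynamicOnlyHandler()))  # False -> skip observers
```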
12 changes: 10 additions & 2 deletions torch/quantization/fx/quantization_patterns.py
@@ -15,6 +15,7 @@
)
from .pattern_utils import (
register_quant_pattern,
mark_input_output_not_observed,
)
from .utils import (
_parent_name,
@@ -30,6 +31,7 @@

from abc import ABC, abstractmethod
import operator
import warnings

# -------------------------
# Pattern Registrations
@@ -418,6 +420,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_

@register_quant_pattern(torch.nn.Embedding)
@register_quant_pattern(torch.nn.EmbeddingBag)
@mark_input_output_not_observed()
class Embedding(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
@@ -437,8 +440,13 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
emb = quantizer.modules[emb_node.target]
qconfig = quantizer.qconfig_map[node.name]
dtypes = get_qconfig_dtypes(qconfig)
assert dtypes in supported_dtypes, "qconfig dtype pair not supported:" \
" {}, supported dtypes are: {}".format(dtypes, supported_dtypes)
if dtypes not in supported_dtypes:
warnings.warn(
"dtype combination: {} is not "
"supported by Embedding/EmbeddingBag, "
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

qemb = get_static_quant_module_class(type(emb))
quantized = qemb.from_float(emb)
parent_name, name = _parent_name(emb_node.target)
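The key behavioral change in `Embedding.convert` is that an unsupported dtype combination no longer raises an assertion; it warns and copies the float node through. A standalone sketch of that control flow follows (the dtype tuple and the return strings are illustrative placeholders, not the real `supported_dtypes` list or the FX graph-manipulation calls):

```python
# Illustrative only: the real code calls quantizer.quantized_graph.node_copy(...)
# to keep the float op, or qemb.from_float(emb) to swap in the quantized module.
import warnings
import torch

SUPPORTED_DTYPES = [(torch.quint8, torch.float16)]  # placeholder values

def convert_embedding(dtypes):
    if dtypes not in SUPPORTED_DTYPES:
        warnings.warn(
            "dtype combination: {} is not supported by Embedding/EmbeddingBag, "
            "supported dtype combinations are: {}".format(dtypes, SUPPORTED_DTYPES))
        return "keep float nn.Embedding"   # before this PR: AssertionError
    return "swap in quantized Embedding"

print(convert_embedding((torch.float32, torch.float32)))  # warns, keeps float op
```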
8 changes: 5 additions & 3 deletions torch/quantization/fx/quantize.py
@@ -32,6 +32,7 @@
is_match,
get_default_quant_patterns,
get_default_output_activation_post_process_map,
input_output_observed,
)

from .observed_module import (
@@ -479,7 +480,7 @@ def input_is_observed(arg):
output_is_observed = self.modules[node.target]._output_is_observed
if output_is_observed:
observed_node_names_set.add(node.name)
elif quantize_handler.all_node_args:
elif quantize_handler.all_node_args and input_output_observed(quantize_handler):
# observer for outputs
new_observer = qconfig.activation()
insert_observer(node, new_observer)
@@ -710,7 +711,8 @@ def is_output_quantized(node):
'CopyNode of type ' + node.op + ' is not handled'
quantized = is_quantized(node.args[0])

if not activation_is_statically_quantized(qconfig):
if not activation_is_statically_quantized(qconfig) or \
not input_output_observed(obj):
quantized = False

return quantized
@@ -975,7 +977,7 @@ def visit_arg(arg):
# don't attach observer/fake_quant for CopyNode
if isinstance(quantize_handler, CopyNode):
qconfig = None
if root_node is node:
if root_node is node and input_output_observed(quantize_handler):
# matched_nodes[-1] is the first op in the sequence and
# matched_nodes[0] is the last op in the sequence
# inputs
