[quant][graphmode][fx][fix] Fix error that DefaultQuantizer is not inserted after a module configured with None qconfig #47316
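The failing scenario: one submodule is excluded from quantization by giving it a None qconfig while the rest of the model uses the default qconfig, so the default quantize step has to be inserted after the unquantized module. The sketch below mirrors the test added in this PR and is illustrative only, assuming the FX graph mode quantization APIs of this release (torch.quantization.quantize_fx.prepare_fx / convert_fx and torch.quantization.default_qconfig).

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1)
        self.conv2 = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = x.transpose(1, 2)
        return self.conv2(x)

m = M().eval()
qconfig_dict = {
    "": default_qconfig,               # quantize everything by default ...
    "module_name": [("conv1", None)],  # ... but leave conv1 unquantized
}
# Before this fix the default quantize was not inserted after conv1 (per the
# PR title); with the fix, prepare and convert run end to end.
m = prepare_fx(m, qconfig_dict)
m = convert_fx(m)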

25 changes: 23 additions & 2 deletions test/quantization/test_quantize_fx.py
@@ -636,7 +636,6 @@ def forward(self, x):
        self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig)
        self.assertEqual(m.module_conv2.qconfig, module_name_qconfig)


    def test_remove_qconfig(self):
        class M(torch.nn.Module):
            def __init__(self):
@@ -657,6 +656,29 @@ def forward(self, x):
            self.assertFalse(hasattr(module, 'qconfig'),
                             'qconfig is not removed for ' + name)

    def test_default_quant_after_none_qconfig(self):
        """ Make sure default quant is inserted properly"""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = x.transpose(1, 2)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {
            "": default_qconfig,
            "module_name": [
                ("conv1", None)
            ]
        }
        m = prepare_fx(m, qconfig_dict)
        m = convert_fx(m)

    @skipIfNoFBGEMM
    def test_qat_and_script(self):
        model = LinearModelWithSubmodule().train()
@@ -961,7 +983,6 @@ def forward(self, x):
        # quantize, should run with no errors
        quantized = convert_fx(prepared_copy)


@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
    """Unit tests for individual ops
167 changes: 82 additions & 85 deletions torch/quantization/fx/quantize.py
@@ -388,8 +388,6 @@ def load_arg(a):
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
@@ -401,93 +399,92 @@ def insert_observer(node, observer, device):
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
                    observed_custom_module_class = \
                        get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name, observed_custom_module)

                # index for input of custom module that needs to be observed in parent
                standalone_module_input_idxs = None
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = mark_observed_standalone_module(observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name, observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module


                # don't need to insert observer for output if activation does not
                # need to be statically quantized
                if not activation_is_statically_quantized(qconfig):
                    continue

                if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training:
                    # we only insert fake quantize module in qat
                    activation_post_process_ctr = \
                        get_default_output_activation_post_process_map().get(pattern, None)
                    assert activation_post_process_ctr is not None, \
                        'activation_post_process constructor not provided for ' + \
                        'pattern:' + str(pattern)
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, activation_post_process_ctr(), device)
                elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and
                      not model.training) or isinstance(obj, CopyNode):
                    # inserting observers for output of observed module, or mark the output
                    # as observed
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
                    input_node = matched_nodes[-1] # first node in the sequence

                    def input_is_observed(arg):
                        return isinstance(arg, Node) and arg.name in observed_node_names_set
                    # This is checking if one of the argument of add/mul
                    # is an observed node
                    # If both of the inputs are number,
                    # we will not consider the output to be observed
                    if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
                        observed_node_names_set.add(node.name)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_node_args:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            device = assert_and_get_unique_device(model)
                            insert_observer(node.args[idx], new_observer, device)
                if qconfig is not None:
                    if isinstance(obj, CustomModuleQuantizeHandler):
                        custom_module = self.modules[node.target]
                        custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
                        observed_custom_module_class = \
                            get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
                        observed_custom_module = \
                            observed_custom_module_class.from_float(custom_module)
                        parent_name, name = _parent_name(node.target)
                        setattr(self.modules[parent_name], name, observed_custom_module)

                    elif isinstance(obj, StandaloneModuleQuantizeHandler):
                        # observe standalone module
                        standalone_module = self.modules[node.target]
                        prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                        observed_standalone_module = prepare(standalone_module, {'': qconfig})
                        observed_standalone_module.qconfig = qconfig
                        standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                        observed_standalone_module = mark_observed_standalone_module(observed_standalone_module)
                        parent_name, name = _parent_name(node.target)
                        setattr(self.modules[parent_name], name, observed_standalone_module)
                        self.modules[node.target] = observed_standalone_module


                    # don't need to insert observer for output if activation does not
                    # need to be statically quantized
                    if activation_is_statically_quantized(qconfig):
                        if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training:
                            # we only insert fake quantize module in qat
                            activation_post_process_ctr = \
                                get_default_output_activation_post_process_map().get(pattern, None)
                            assert activation_post_process_ctr is not None, \
                                "activation_post_process constructor not provided for " + \
                                "pattern:" + str(pattern)
                            device = assert_and_get_unique_device(model)
                            insert_observer(node, activation_post_process_ctr(), device)
                        elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and
                              not model.training) or isinstance(obj, CopyNode):
                            # inserting observers for output of observed module, or mark the output
                            # as observed
                            assert node.op in [
                                'call_module',
                                'call_function',
                                'call_method'], \
                                'CopyNode of type ' + node.op + ' is not handled'

                            def is_observed(input_arg):
                                if isinstance(input_arg, Node):
                                    return input_arg.name in observed_node_names_set
                                elif isinstance(input_arg, list):
                                    return all(map(is_observed, input_arg))
                            # propagate observed property from input
                            if is_observed(node.args[0]):
                                observed_node_names_set.add(node.name)
                        elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
                            input_node = matched_nodes[-1] # first node in the sequence

                            def input_is_observed(arg):
                                return isinstance(arg, Node) and arg.name in observed_node_names_set
                            # This is checking if one of the argument of add/mul
                            # is an observed node
                            # If both of the inputs are number,
                            # we will not consider the output to be observed
                            if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
                                observed_node_names_set.add(node.name)
                        elif isinstance(obj, StandaloneModuleQuantizeHandler):
                            assert node.op == 'call_module'
                            output_is_observed = self.modules[node.target]._output_is_observed
                            if output_is_observed:
                                observed_node_names_set.add(node.name)
                        elif obj.all_node_args:
                            # observer for outputs
                            new_observer = qconfig.activation()
                            # respect device affinity when adding observers
                            device = assert_and_get_unique_device(model)
                            insert_observer(node, new_observer, device)

                    # insert observer for input of standalone module
                    if standalone_module_input_idxs is not None:
                        for idx in standalone_module_input_idxs:
                            if node.args[idx].name not in observed_node_names_set:
                                new_observer = qconfig.activation()
                                device = assert_and_get_unique_device(model)
                                insert_observer(node.args[idx], new_observer, device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
