Skip to content

Commit

Permalink
[quant][graphmode][fx] Standalone module support {input/output}_quant…
Browse files Browse the repository at this point in the history
…ized_idxs (#49754)

Summary:
Pull Request resolved: #49754

This PR adds the support for {input/output}_quantized_idxs for standalone module.

if input_quantized_idxs = [] and output_quantized_idxs = [], the standalone module will be expecting float
input and produce float output, and will quantize the input and dequantize output internally

if input_quantized_idxs = [0] and output_quantized_idxs = [0], the standalone module will be expecting quantized
input and produce quantized output, the input will be quantized in the parent module, and output will be dequantized
in the parent module as well, this is similar to current quantized modules like nn.quantized.Conv2d

For more details, please see the test case

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_standalone_module

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D25684692

fbshipit-source-id: 900360e01c0e35b26fe85f4a887dc1fd6f7bfb66
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed Dec 24, 2020
1 parent 69b1373 commit 89b4899
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 74 deletions.
126 changes: 96 additions & 30 deletions test/quantization/test_quantize_fx.py
Expand Up @@ -570,7 +570,16 @@ def forward(self, x):
m = convert_fx(m)
m(tensor_input)

def test_standalone_module(self):
def _test_standalone_module(
self,
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check):
""" Test standalone module with different quantized input/quantized output
configurations
"""
class StandaloneModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -610,45 +619,32 @@ def forward(self, x):
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

qconfig_dict = {"": default_qconfig}
config_name = {"standalone_module_name": [("standalone", None, None)]}
config_class = {"standalone_module_class": [(StandaloneModule, None, None)]}
for prepare_config in [config_name, config_class]:
for is_name in [True, False]:
if is_name:
prepare_config = {
"standalone_module_name": [("standalone", None, interface_config)]
}
else:
prepare_config = {
"standalone_module_class": [(StandaloneModule, None, interface_config)]
}

original_m_copy = copy.deepcopy(original_m)
original_ref_m_copy = copy.deepcopy(original_ref_m)

qconfig_dict = {"": default_qconfig}
# check prepared model
m = prepare_fx(
original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
# calibration
m(data)
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
# for input and output of conv in the standalone module
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

# check converted/quantized model
m = convert_fx(m)
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
count_check = {
# standalone module will take float as input and output
# so we'll see quantize and dequantize in the module
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d): 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
res = m(data)

# quantize the reference model
Expand All @@ -658,6 +654,76 @@ def forward(self, x):
ref_res = ref_m(data)
self.assertEqual(res, ref_res)

def test_standalone_module_float_interface(self):
    """Standalone module configured with empty input/output quantized
    index lists: it is expected to consume and produce float tensors,
    so the quantize/dequantize pair is inserted inside the standalone
    module itself rather than in the parent.
    """
    interface_config = {
        "input_quantized_idxs": [],   # float input
        "output_quantized_idxs": [],  # float output
    }
    # Parent: observers for the input and output of its first conv; the
    # standalone module's conv is observed within the standalone module.
    parent_prepare_checks = {
        ns.call_module(torch.quantization.MinMaxObserver): 2
    }
    # Standalone: observers for the input and output of its own conv.
    standalone_prepare_checks = {
        ns.call_module(torch.quantization.MinMaxObserver): 2
    }
    # After convert, the parent shows a full quantize -> quantized conv
    # -> dequantize chain.
    parent_convert_checks = {
        ns.call_function(torch.quantize_per_tensor) : 1,
        ns.call_module(nnq.Conv2d) : 1,
        ns.call_method("dequantize") : 1,
    }
    # The standalone module takes float in and returns float out, so it
    # also carries its own quantize and dequantize calls.
    standalone_convert_checks = {
        ns.call_function(torch.quantize_per_tensor) : 1,
        ns.call_module(nnq.Conv2d): 1,
        ns.call_method("dequantize") : 1,
    }
    self._test_standalone_module(
        interface_config,
        parent_prepare_checks,
        standalone_prepare_checks,
        parent_convert_checks,
        standalone_convert_checks)

def test_standalone_module_quantized_interface(self):
    """Standalone module configured with input/output index 0 marked as
    quantized: the parent module quantizes the input and dequantizes the
    output, so the standalone module itself contains no explicit
    quantize/dequantize calls (mirroring modules like nn.quantized.Conv2d).
    """
    interface_config = {
        "input_quantized_idxs": [0],   # quantized input
        "output_quantized_idxs": [0],  # quantized output
    }
    # Parent: observers for the input and output of its first conv.
    parent_prepare_checks = {
        ns.call_module(torch.quantization.MinMaxObserver): 2
    }
    # Standalone: only the output of its conv needs an observer.
    standalone_prepare_checks = {
        ns.call_module(torch.quantization.MinMaxObserver): 1
    }
    parent_convert_checks = {
        # quantizing input for conv
        ns.call_function(torch.quantize_per_tensor) : 1,
        ns.call_module(nnq.Conv2d) : 1,
        # dequantizing output of standalone module
        ns.call_method("dequantize") : 1,
    }
    standalone_convert_checks = {
        # input quantization happens in the parent module; output
        # quantization happens inside the quantized conv module itself
        ns.call_function(torch.quantize_per_tensor) : 0,
        ns.call_module(nnq.Conv2d): 1,
        # output dequantization happens in the parent module
        ns.call_method("dequantize") : 0,
    }
    self._test_standalone_module(
        interface_config,
        parent_prepare_checks,
        standalone_prepare_checks,
        parent_convert_checks,
        standalone_convert_checks)

@skipIfNoFBGEMM
def test_qconfig_none(self):
class M(torch.nn.Module):
Expand Down
10 changes: 8 additions & 2 deletions torch/quantization/fx/observed_module.py
Expand Up @@ -2,11 +2,11 @@
import copy
from torch.fx import GraphModule # type: ignore
from torch.fx.graph import Graph
from typing import Union, Dict, Any
from typing import Union, Dict, Any, List

class ObservedGraphModule(GraphModule):

def get_preserved_attr_names(self):
def get_preserved_attr_names(self) -> List[str]:
return ['_activation_post_process_map',
'_patterns',
'_qconfig_map',
Expand Down Expand Up @@ -35,6 +35,12 @@ def is_observed_module(module: Any) -> bool:
return isinstance(module, ObservedGraphModule)

class ObservedStandaloneGraphModule(ObservedGraphModule):
def get_preserved_attr_names(self) -> List[str]:
    """Return the attribute names to keep across GraphModule copies.

    In addition to the base observed-module attributes, a standalone
    module preserves the recorded input/output quantized index lists.
    """
    extra_attrs = [
        "_standalone_module_input_quantized_idxs",
        "_standalone_module_output_quantized_idxs",
    ]
    return super().get_preserved_attr_names() + extra_attrs

def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
Expand Down
4 changes: 2 additions & 2 deletions torch/quantization/fx/quantization_patterns.py
Expand Up @@ -753,10 +753,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
qconfig = quantizer.qconfig_map[node.name]
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore
observed_standalone_module = quantizer.modules[node.target]
input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs
quantized_standalone_module = convert(observed_standalone_module, debug=debug)
parent_name, name = _parent_name(node.target)
# update the modules dict
setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
quantizer.modules[node.target] = quantized_standalone_module
# standalone module takes float input
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))

0 comments on commit 89b4899

Please sign in to comment.