Revert D25684692: [quant][graphmode][fx] Standalone module support {input/output}_quantized_idxs

Test Plan: revert-hammer

Differential Revision: D25684692 (89b4899)

Original commit changeset: 900360e01c0e

fbshipit-source-id: 8b65fa8fbc7b364fbddb5f23cc696cd9b7db98cd
Mike Ruberry authored and facebook-github-bot committed Dec 24, 2020
1 parent ec6de6a commit 46cf6d3
Showing 5 changed files with 74 additions and 221 deletions.
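For context: the change being reverted (D25684692) let a standalone module declare which of its inputs and outputs are already quantized, via an interface config attached to the standalone_module_name / standalone_module_class entries of prepare_custom_config_dict. Below is a minimal sketch of that now-removed usage, reconstructed from the deleted tests in this diff; float_model (a module with a submodule named "standalone") and calibration_data are placeholders, not part of this commit:

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Interface config introduced by D25684692 and removed again by this revert:
interface_config = {
    "input_quantized_idxs": [0],   # input 0 arrives already quantized
    "output_quantized_idxs": [0],  # output 0 is returned quantized
}
prepare_custom_config_dict = {
    "standalone_module_name": [("standalone", None, interface_config)]
}

m = prepare_fx(float_model, {"": default_qconfig},
               prepare_custom_config_dict=prepare_custom_config_dict)
m(calibration_data)  # calibration run to populate observers
m = convert_fx(m)    # quantized parent; per the reverted change, the standalone submodule would honor interface_config

After this revert the interface config is no longer consulted and standalone modules go back to a float interface, so the quantize/dequantize ops end up inside the standalone module itself (see the restored test and the quantization_patterns.py change below).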
126 changes: 30 additions & 96 deletions test/quantization/test_quantize_fx.py
@@ -570,16 +570,7 @@ def forward(self, x):
m = convert_fx(m)
m(tensor_input)

def _test_standalone_module(
self,
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check):
""" Test standalone module with different quantized input/quantized output
configurations
"""
def test_standalone_module(self):
class StandaloneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -619,32 +610,45 @@ def forward(self, x):
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

for is_name in [True, False]:
if is_name:
prepare_config = {
"standalone_module_name": [("standalone", None, interface_config)]
}
else:
prepare_config = {
"standalone_module_class": [(StandaloneModule, None, interface_config)]
}

qconfig_dict = {"": default_qconfig}
config_name = {"standalone_module_name": [("standalone", None, None)]}
config_class = {"standalone_module_class": [(StandaloneModule, None, None)]}
for prepare_config in [config_name, config_class]:
original_m_copy = copy.deepcopy(original_m)
original_ref_m_copy = copy.deepcopy(original_ref_m)

qconfig_dict = {"": default_qconfig}
# check prepared model
m = prepare_fx(
original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
# calibration
m(data)
self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
# for input and output of conv in the standalone module
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)

# check converted/quantized model
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
count_check = {
# standalone module will take float as input and output
# so we'll see quantize and dequantize in the module
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d): 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check)
res = m(data)

# quantize the reference model
@@ -654,76 +658,6 @@ def forward(self, x):
ref_res = ref_m(data)
self.assertEqual(res, ref_res)

def test_standalone_module_float_interface(self):
float_interface_config = {
"input_quantized_idxs": [], # float input
"output_quantized_idxs": [], # float output
}
interface_config = float_interface_config
# input and output of first conv, observer for standalone module
# will be inserted in the standalone module itself
prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
# for input and output of conv in the standalone module
standalone_prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
convert_count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
ns.call_method("dequantize") : 1,
}
standalone_convert_count_check = {
# standalone module will take float as input and output
# so we'll see quantize and dequantize in the module
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d): 1,
ns.call_method("dequantize") : 1,
}
self._test_standalone_module(
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check)

def test_standalone_module_quantized_interface(self):
quantized_interface_config = {
"input_quantized_idxs": [0], # quantized input
"output_quantized_idxs": [0], # quantized output
}
interface_config = quantized_interface_config
# observer for input and output of first conv
prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 2
}
# for output of conv in the standalone module
standalone_prepare_count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 1
}
convert_count_check = {
# quantizing input for conv
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
# dequantizing output of standalone module
ns.call_method("dequantize") : 1,
}
standalone_convert_count_check = {
# quantization of input happens in parent module
# quantization of output happens in the quantized conv module
ns.call_function(torch.quantize_per_tensor) : 0,
ns.call_module(nnq.Conv2d): 1,
# dequantization for output happens in parent module
ns.call_method("dequantize") : 0,
}
self._test_standalone_module(
interface_config,
prepare_count_check,
standalone_prepare_count_check,
convert_count_check,
standalone_convert_count_check)

@skipIfNoFBGEMM
def test_qconfig_none(self):
class M(torch.nn.Module):
10 changes: 2 additions & 8 deletions torch/quantization/fx/observed_module.py
@@ -2,11 +2,11 @@
import copy
from torch.fx import GraphModule # type: ignore
from torch.fx.graph import Graph
from typing import Union, Dict, Any, List
from typing import Union, Dict, Any

class ObservedGraphModule(GraphModule):

def get_preserved_attr_names(self) -> List[str]:
def get_preserved_attr_names(self):
return ['_activation_post_process_map',
'_patterns',
'_qconfig_map',
@@ -35,12 +35,6 @@ def is_observed_module(module: Any) -> bool:
return isinstance(module, ObservedGraphModule)

class ObservedStandaloneGraphModule(ObservedGraphModule):
def get_preserved_attr_names(self) -> List[str] :
return super().get_preserved_attr_names() + [
"_standalone_module_input_quantized_idxs",
"_standalone_module_output_quantized_idxs"
]

def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
4 changes: 2 additions & 2 deletions torch/quantization/fx/quantization_patterns.py
@@ -753,10 +753,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
qconfig = quantizer.qconfig_map[node.name]
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore
observed_standalone_module = quantizer.modules[node.target]
input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs
quantized_standalone_module = convert(observed_standalone_module, debug=debug)
parent_name, name = _parent_name(node.target)
# update the modules dict
setattr(quantizer.modules[parent_name], name, quantized_standalone_module)
quantizer.modules[node.target] = quantized_standalone_module
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
# standalone module takes float input
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
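The practical effect of the restored line above: at convert time the parent copies the call into the standalone module with float arguments (load_arg(quantized=False)) instead of reading _standalone_module_input_quantized_idxs. Roughly, the converted standalone module then quantizes its own input and dequantizes its own output, matching the node counts asserted in the restored test. A conceptual sketch of that float interface (placeholder qparams and conv, not generated code):

import torch
import torch.nn.quantized as nnq

quantized_conv = nnq.Conv2d(1, 1, 1)   # placeholder for the converted conv
scale, zero_point = 1.0, 0             # placeholder quantization params

def standalone_forward(x_float):
    # float in -> quantize -> quantized conv -> dequantize -> float out
    x_q = torch.quantize_per_tensor(x_float, scale, zero_point, torch.quint8)
    y_q = quantized_conv(x_q)
    return y_q.dequantize()

out = standalone_forward(torch.randn(1, 1, 4, 4))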
