[quant][graphmode][fx] Change standalone module api (#49719)
Summary:
Pull Request resolved: #49719

We found there are multiple use cases for standalone modules: one use case requires the standalone module to produce a module that takes a float Tensor as input and outputs a float Tensor, while the other needs to produce a module that takes a quantized Tensor as input and outputs a quantized Tensor.

This is similar to `input_quantized_idxs` and `output_quantized_idxs`, so we nest prepare_custom_config_dict in the standalone module configuration. For maximum flexibility we also include a qconfig_dict for the standalone module, in case the user needs a special qconfig_dict for it in the future.

Changed from
```python
prepare_custom_config_dict = {
    "standalone_module_name": ["standalone_module"],
    "standalone_module_class": [StandaloneModule]
}
```
to
```python
prepare_custom_config_dict = {
    "standalone_module_name": [("standalone_module", qconfig_dict1, prepare_custom_config_dict1)],
    "standalone_module_class": [(StandaloneModule, qconfig_dict2, prepare_custom_config_dict2)]
}
```
Each entry in the config is a tuple of:
1. the name/module_class
2. an optional qconfig_dict; when it is None, we'll use {"": qconfig}, where qconfig is the one from the parent qconfig_dict
3. an optional prepare_custom_config_dict; when it is None, we'll use the default value of prepare_custom_config_dict for the prepare API (None)
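
For illustration, here is a minimal sketch of the new format in use (the toy model, module names, and backend choice are hypothetical, not part of this PR):
```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

class StandaloneModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.standalone = StandaloneModule()

    def forward(self, x):
        return self.standalone(self.conv(x))

qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {
    "standalone_module_name": [
        # (name, qconfig_dict, prepare_custom_config_dict):
        # a None qconfig_dict inherits the parent qconfig, and a None
        # prepare_custom_config_dict uses the prepare API's default
        ("standalone", None, None)
    ]
}
prepared = prepare_fx(Model().eval(), qconfig_dict, prepare_custom_config_dict)
```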

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_standalone_module

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D25675704

fbshipit-source-id: 0889f519a3e55a7a677f0e2db4db9a18d87a93d4
jerryzh168 authored and facebook-github-bot committed Dec 23, 2020
1 parent af1b636 commit f474ffa
Showing 3 changed files with 39 additions and 21 deletions.
4 changes: 2 additions & 2 deletions test/quantization/test_quantize_fx.py
@@ -611,8 +611,8 @@ def forward(self, x):
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

qconfig_dict = {"": default_qconfig}
- config_name = {"standalone_module_name": ["standalone"]}
- config_class = {"standalone_module_class": [StandaloneModule]}
+ config_name = {"standalone_module_name": [("standalone", None, None)]}
+ config_class = {"standalone_module_class": [(StandaloneModule, None, None)]}
for prepare_config in [config_name, config_class]:
original_m_copy = copy.deepcopy(original_m)
original_ref_m_copy = copy.deepcopy(original_ref_m)
36 changes: 23 additions & 13 deletions torch/quantization/fx/quantize.py
@@ -124,11 +124,18 @@ def insert_observer_for_special_module(
elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
# observe standalone module
standalone_module = modules[node.target] # type: ignore
+ standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", [])
+ standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", [])
+ class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs}
+ name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs}
+ config = class_config_map.get(type(standalone_module), (None, None))
+ config = name_config_map.get(node.target, config)  # a name config overrides a class config
+ standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0]
+ standalone_prepare_config_dict = {} if config[1] is None else config[1]
prepare = \
    torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
observed_standalone_module = \
-     prepare(standalone_module, {"": qconfig})
- observed_standalone_module.qconfig = qconfig
+     prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict)
observed_standalone_module = mark_observed_standalone_module(
observed_standalone_module)
parent_name, name = _parent_name(node.target)
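
To make the fallback semantics of the hunk above easy to read in isolation, here is a minimal standalone sketch (the helper name is hypothetical, not part of this diff):
```python
from typing import Any, Dict, Optional, Tuple

StandaloneConfig = Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]

def resolve_standalone_config(config: StandaloneConfig,
                              parent_qconfig: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Mirrors the hunk above: a None qconfig_dict inherits the parent's
    # global qconfig, and a None prepare_custom_config_dict becomes {}.
    qconfig_dict, prepare_config_dict = config
    if qconfig_dict is None:
        qconfig_dict = {"": parent_qconfig}
    if prepare_config_dict is None:
        prepare_config_dict = {}
    return qconfig_dict, prepare_config_dict
```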
@@ -395,10 +402,13 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,
self._generate_qconfig_map(model, model.graph, qconfig_dict)

# match the patterns that will get quantized
- standalone_module_names = prepare_custom_config_dict.get(
-     "standalone_module_name", None)
- standalone_module_classes = prepare_custom_config_dict.get(
-     "standalone_module_class", None)
+ standalone_module_name_configs = prepare_custom_config_dict.get(
+     "standalone_module_name", [])
+ standalone_module_class_configs = prepare_custom_config_dict.get(
+     "standalone_module_class", [])
+
+ standalone_module_names = [config[0] for config in standalone_module_name_configs]
+ standalone_module_classes = [config[0] for config in standalone_module_class_configs]
custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
assert self.patterns is not None
@@ -754,21 +764,21 @@ def insert_quantize_node(node: Node) -> None:
root_node, matched, matched_pattern, obj, qconfig = \
matches.get(node.name, (None, None, None, None, None))
if root_node is node:
- if qconfig is None:
+ is_observed_standalone_module_node = (
+     node.op == 'call_module' and
+     is_observed_standalone_module(
+         self.modules[node.target]) # type: ignore
+ )
+ if qconfig is None and not is_observed_standalone_module_node:
result = self.quantized_graph.node_copy(
node, load_non_quantized)
quantized = False
else:
assert obj is not None
- is_standalone_module_node = (
-     node.op == 'call_module' and
-     is_observed_standalone_module(
-         self.modules[node.target]) # type: ignore
- )
result = obj.convert(
self, node, load_arg, debug=debug,
convert_custom_config_dict=convert_custom_config_dict)
- if is_standalone_module_node:
+ if is_observed_standalone_module_node:
quantized = False
else:
quantized = is_output_quantized(node, obj)
20 changes: 14 additions & 6 deletions torch/quantization/quantize_fx.py
@@ -81,11 +81,11 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
# symbolically trace the model
if not is_standalone_module:
# standalone module and custom module config are applied in top level module
- standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', [])
- skipped_module_names += standalone_module_names
+ standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", [])
+ skipped_module_names += [config[0] for config in standalone_module_name_configs]

- standalone_module_classes = prepare_custom_config_dict.get('standalone_module_class', [])
- skipped_module_classes += standalone_module_classes
+ standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", [])
+ skipped_module_classes += [config[0] for config in standalone_module_class_configs]
float_custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
skipped_module_classes += float_custom_module_classes
@@ -178,11 +178,19 @@ def prepare_fx(
# optional: specify the path for standalone modules
# These modules are symbolically traced and quantized as one unit
"standalone_module_name": [
"submodule.standalone"
# module_name, qconfig_dict, prepare_custom_config_dict
("submodule.standalone",
None, # qconfig_dict for the prepare function called in the submodule,
# None means use qconfig from parent qconfig_dict
{"input_quantized_idxs": [], "output_quantized_idxs": []}) # prepare_custom_config_dict
],
"standalone_module_class": [
-     StandaloneModule
+     # (module_class, qconfig_dict, prepare_custom_config_dict)
+     (StandaloneModule,
+      None,  # qconfig_dict for the prepare function called in the submodule,
+             # None means use qconfig from parent qconfig_dict
+      {"input_quantized_idxs": [0], "output_quantized_idxs": [0]})  # prepare_custom_config_dict
],
# user will manually define the corresponding observed
