From a1ec09bd3ce4342a9d480232da719f8c26a2eae9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 7 Jan 2021 11:28:46 -0800 Subject: [PATCH] [quant][graphmode][fx] Scope support for call_method in QuantizationTracer Summary: Previously we did not set the qconfig for call_method node correctly since it requires us to know the scope (module path of the module whose forward graph contains the node) of the node. This PR modifies the QuantizationTracer to record the scope information and build a map from call_method Node to module path, which will be used when we construct qconfig_map Test Plan: python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: f30902dc69be9c263d31a851b07f757efe5ac2e1 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50173 --- test/quantization/test_quantize_fx.py | 44 +++++++++++- torch/quantization/fx/qconfig_utils.py | 5 +- torch/quantization/fx/quantize.py | 61 ++++++++-------- torch/quantization/quantize_fx.py | 96 ++++++++++++++++++++++++-- 4 files changed, 164 insertions(+), 42 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 7965b3cc88a4..29649e428e04 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -564,7 +564,6 @@ def forward(self, x): "": None, "object_type": [ (nn.Conv2d, default_qconfig), - ("chunk", None) ] } # make sure it runs @@ -915,6 +914,49 @@ def forward(self, x): m = prepare_fx(m, qconfig_dict) m = convert_fx(m) + def test_qconfig_for_call_method(self): + class Sub(torch.nn.Module): + def forward(self, x): + return x.transpose(2, 3) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sub = Sub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) # default_qconfig + x = self.sub(x) # None + x = self.conv2(x) # default_qconfig + return x.transpose(2, 3) # default_qconfig + + + m = M().eval() + # since we don't quantize sub, we should have dequantize after the self.conv1 + # and quantize before self.conv2 + # however, the dequantize after conv2 should happen after x.transpose since + # it is configured with default_qconfig + qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]} + m = prepare_fx(m, qconfig_dict) + m(torch.randn(2, 1, 3, 3)) + m = convert_fx(m) + node_list = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("dequantize"), + ns.call_method("transpose"), + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("transpose"), + ns.call_method("dequantize") + ] + self.checkGraphModuleNodes(m, expected_node_list=node_list) + + # make sure it runs + m(torch.randn(2, 1, 3, 3)) + @skipIfNoFBGEMM def test_qat_and_script(self): model = LinearModelWithSubmodule().train() diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py index 3db370c4422d..42d4c1a3f3b7 100644 --- a/torch/quantization/fx/qconfig_utils.py +++ b/torch/quantization/fx/qconfig_utils.py @@ -88,10 +88,9 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): # get qconfig for module_name, # fallback to module_name_regex_qconfig, module_type_qconfig, # global_qconfig if necessary -def get_qconfig(modules, qconfig_dict, module_name, global_qconfig): - assert modules is not None +def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig): module_type_qconfig = get_object_type_qconfig( - qconfig_dict, type(modules[module_name]), global_qconfig) + qconfig_dict, module_type, global_qconfig) module_name_regex_qconfig = get_module_name_regex_qconfig( qconfig_dict, module_name, module_type_qconfig) module_name_qconfig = get_module_name_qconfig( diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 318295270b61..7a2e46a04e03 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -61,8 +61,6 @@ from .qconfig_utils import * -import warnings - from typing import Optional, Dict, Any, List, Tuple, Set, Callable # Define helper types @@ -346,50 +344,47 @@ def _generate_qconfig_map( self, root: torch.nn.Module, input_graph: Graph, - qconfig_dict: Any) -> None: - global_qconfig = qconfig_dict.get('', None) + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]]) -> None: + global_qconfig = qconfig_dict.get("", None) self.qconfig_map = dict() for node in input_graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": module_name, _ = _parent_name(node.target) + assert self.modules is not None self.qconfig_map[node.name] = get_qconfig( - self.modules, qconfig_dict, module_name, global_qconfig) - elif node.op == 'call_function': + qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig) + elif node.op == "call_function": # precedence: [TODO] module_name_qconfig (need scope support # from fx) # > function_qconfig > global_qconfig function_qconfig = get_object_type_qconfig( qconfig_dict, node.target, global_qconfig) self.qconfig_map[node.name] = function_qconfig - elif node.op == 'call_method': - self_obj = node.args[0] - # qconfig for call_method should be the same as the `self` - # object for the call - if self_obj.name in self.qconfig_map: - qconfig = self.qconfig_map[self_obj.name] - else: - # need scope info for each node to support this - warnings.warn( - "Scope info is not yet supported, taking default " + - "qconfig for value {}".format(node.name)) - qconfig = get_qconfig( - self.modules, qconfig_dict, '', global_qconfig) - qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig) + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # use the qconfig of the module that the node belongs to + qconfig = get_qconfig( + qconfig_dict, module_type, module_path, global_qconfig) self.qconfig_map[node.name] = qconfig elif node.op == 'call_module': + assert self.modules is not None module_qconfig = get_qconfig( - self.modules, qconfig_dict, node.target, global_qconfig) + qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig) # regex is not supported eager mode propagate_qconfig_, we'll # need to set the qconfig explicitly here in case regex # is used - assert self.modules is not None self.modules[node.target].qconfig = module_qconfig self.qconfig_map[node.name] = module_qconfig - def _prepare(self, model: GraphModule, qconfig_dict: Any, - prepare_custom_config_dict: Optional[Dict[str, Any]], - is_standalone_module: bool) -> GraphModule: + def _prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]], + prepare_custom_config_dict: Optional[Dict[str, Any]], + is_standalone_module: bool) -> GraphModule: """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. @@ -428,7 +423,7 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, convert_dict_to_ordered_dict(qconfig_dict) # map from node name to qconfig, used in _find_matches - self._generate_qconfig_map(model, model.graph, qconfig_dict) + self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope) # match the patterns that will get quantized standalone_module_name_configs = prepare_custom_config_dict.get( @@ -579,11 +574,15 @@ def restore_state(self, observed: GraphModule) -> None: self.prepare_custom_config_dict = \ observed._prepare_custom_config_dict # type: ignore - def prepare(self, model: GraphModule, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None, - is_standalone_module: bool = False) -> GraphModule: + def prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]], + prepare_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: return self._prepare( - model, qconfig_dict, prepare_custom_config_dict, + model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict, is_standalone_module) def _run_weight_observers(self, observed: GraphModule) -> None: diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 89ba877ffe78..9d51570dafeb 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -1,12 +1,13 @@ import torch from torch.fx import GraphModule # type: ignore from torch.fx.symbolic_trace import Tracer # type: ignore +from torch.fx.node import Target, Node, Argument # type: ignore from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from torch.nn.intrinsic import _FusedModule -from typing import Dict, Any, List, Callable +from typing import Dict, Any, List, Callable, Tuple, Optional def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): @@ -41,20 +42,99 @@ def _fuse_fx( fuser = Fuser() return fuser.fuse(graph_module, fuse_custom_config_dict) -class CustomTracer(Tracer): +class Scope(object): + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example: + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self): + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + +class ScopeContextManager(object): + """ A context manager to track the Scope of Node during symbolic + tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + def __init__( + self, + scope: Scope, + current_module: torch.nn.Module, + current_module_path: str): + super().__init__() + self.prev_module_type = scope.module_type + self.prev_module_path = scope.module_path + self.scope = scope + self.scope.module_path = current_module_path + self.scope.module_type = type(current_module) + + def __enter__(self): + return + + def __exit__(self, *args): + self.scope.module_path = self.prev_module_path + self.scope.module_type = self.prev_module_type + return + + +class QuantizationTracer(Tracer): def __init__(self, skipped_module_names: List[str], skipped_module_classes: List[Callable]): super().__init__() self.skipped_module_names = skipped_module_names self.skipped_module_classes = skipped_module_classes - - def is_leaf_module(self, m, module_qualified_name): - return (m.__module__.startswith('torch.nn') and + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.node_name_to_scope = {} + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + return (m.__module__.startswith("torch.nn") and not isinstance(m, torch.nn.Sequential)) or \ module_qualified_name in self.skipped_module_names or \ type(m) in self.skipped_module_classes or \ isinstance(m, _FusedModule) + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + module_qualified_name = self.path_of_module(m) + if not self.is_leaf_module(m, module_qualified_name): + def scoped_forward(_args, _kwargs): + # Creating scope with information of current module + # scope will be restored automatically upon exit + with ScopeContextManager(self.scope, m, module_qualified_name): + return forward(*_args, **_kwargs) + return scoped_forward(args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) + + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) + if kind == "call_method": + self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type) + return node def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None, @@ -89,18 +169,20 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, float_custom_module_classes = get_custom_module_class_keys( prepare_custom_config_dict, "float_to_observed_custom_module_class") skipped_module_classes += float_custom_module_classes - tracer = CustomTracer(skipped_module_names, skipped_module_classes) + tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) graph_module = GraphModule(model, tracer.trace(model)) graph_module = _fuse_fx(graph_module, prepare_custom_config_dict) quantizer = Quantizer() return quantizer.prepare( graph_module, qconfig_dict, + tracer.node_name_to_scope, prepare_custom_config_dict=prepare_custom_config_dict, is_standalone_module=is_standalone_module) def _prepare_standalone_module_fx( - model: torch.nn.Module, qconfig_dict: Any, + model: torch.nn.Module, + qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module.