diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 7965b3cc88a4..29649e428e04 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -564,7 +564,6 @@ def forward(self, x): "": None, "object_type": [ (nn.Conv2d, default_qconfig), - ("chunk", None) ] } # make sure it runs @@ -915,6 +914,49 @@ def forward(self, x): m = prepare_fx(m, qconfig_dict) m = convert_fx(m) + def test_qconfig_for_call_method(self): + class Sub(torch.nn.Module): + def forward(self, x): + return x.transpose(2, 3) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sub = Sub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) # default_qconfig + x = self.sub(x) # None + x = self.conv2(x) # default_qconfig + return x.transpose(2, 3) # default_qconfig + + + m = M().eval() + # since we don't quantize sub, we should have dequantize after the self.conv1 + # and quantize before self.conv2 + # however, the dequantize after conv2 should happen after x.transpose since + # it is configured with default_qconfig + qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]} + m = prepare_fx(m, qconfig_dict) + m(torch.randn(2, 1, 3, 3)) + m = convert_fx(m) + node_list = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("dequantize"), + ns.call_method("transpose"), + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("transpose"), + ns.call_method("dequantize") + ] + self.checkGraphModuleNodes(m, expected_node_list=node_list) + + # make sure it runs + m(torch.randn(2, 1, 3, 3)) + @skipIfNoFBGEMM def test_qat_and_script(self): model = LinearModelWithSubmodule().train() diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py index 3db370c4422d..42d4c1a3f3b7 100644 --- a/torch/quantization/fx/qconfig_utils.py +++ b/torch/quantization/fx/qconfig_utils.py @@ -88,10 +88,9 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): # get qconfig for module_name, # fallback to module_name_regex_qconfig, module_type_qconfig, # global_qconfig if necessary -def get_qconfig(modules, qconfig_dict, module_name, global_qconfig): - assert modules is not None +def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig): module_type_qconfig = get_object_type_qconfig( - qconfig_dict, type(modules[module_name]), global_qconfig) + qconfig_dict, module_type, global_qconfig) module_name_regex_qconfig = get_module_name_regex_qconfig( qconfig_dict, module_name, module_type_qconfig) module_name_qconfig = get_module_name_qconfig( diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 318295270b61..7a2e46a04e03 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -61,8 +61,6 @@ from .qconfig_utils import * -import warnings - from typing import Optional, Dict, Any, List, Tuple, Set, Callable # Define helper types @@ -346,50 +344,47 @@ def _generate_qconfig_map( self, root: torch.nn.Module, input_graph: Graph, - qconfig_dict: Any) -> None: - global_qconfig = qconfig_dict.get('', None) + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]]) -> None: + global_qconfig = qconfig_dict.get("", None) self.qconfig_map = dict() for node in input_graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": module_name, _ = _parent_name(node.target) + assert self.modules is not None self.qconfig_map[node.name] = get_qconfig( - self.modules, qconfig_dict, module_name, global_qconfig) - elif node.op == 'call_function': + qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig) + elif node.op == "call_function": # precedence: [TODO] module_name_qconfig (need scope support # from fx) # > function_qconfig > global_qconfig function_qconfig = get_object_type_qconfig( qconfig_dict, node.target, global_qconfig) self.qconfig_map[node.name] = function_qconfig - elif node.op == 'call_method': - self_obj = node.args[0] - # qconfig for call_method should be the same as the `self` - # object for the call - if self_obj.name in self.qconfig_map: - qconfig = self.qconfig_map[self_obj.name] - else: - # need scope info for each node to support this - warnings.warn( - "Scope info is not yet supported, taking default " + - "qconfig for value {}".format(node.name)) - qconfig = get_qconfig( - self.modules, qconfig_dict, '', global_qconfig) - qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig) + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # use the qconfig of the module that the node belongs to + qconfig = get_qconfig( + qconfig_dict, module_type, module_path, global_qconfig) self.qconfig_map[node.name] = qconfig elif node.op == 'call_module': + assert self.modules is not None module_qconfig = get_qconfig( - self.modules, qconfig_dict, node.target, global_qconfig) + qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig) # regex is not supported eager mode propagate_qconfig_, we'll # need to set the qconfig explicitly here in case regex # is used - assert self.modules is not None self.modules[node.target].qconfig = module_qconfig self.qconfig_map[node.name] = module_qconfig - def _prepare(self, model: GraphModule, qconfig_dict: Any, - prepare_custom_config_dict: Optional[Dict[str, Any]], - is_standalone_module: bool) -> GraphModule: + def _prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]], + prepare_custom_config_dict: Optional[Dict[str, Any]], + is_standalone_module: bool) -> GraphModule: """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. @@ -428,7 +423,7 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, convert_dict_to_ordered_dict(qconfig_dict) # map from node name to qconfig, used in _find_matches - self._generate_qconfig_map(model, model.graph, qconfig_dict) + self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope) # match the patterns that will get quantized standalone_module_name_configs = prepare_custom_config_dict.get( @@ -579,11 +574,15 @@ def restore_state(self, observed: GraphModule) -> None: self.prepare_custom_config_dict = \ observed._prepare_custom_config_dict # type: ignore - def prepare(self, model: GraphModule, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None, - is_standalone_module: bool = False) -> GraphModule: + def prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]], + prepare_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: return self._prepare( - model, qconfig_dict, prepare_custom_config_dict, + model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict, is_standalone_module) def _run_weight_observers(self, observed: GraphModule) -> None: diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 89ba877ffe78..9d51570dafeb 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -1,12 +1,13 @@ import torch from torch.fx import GraphModule # type: ignore from torch.fx.symbolic_trace import Tracer # type: ignore +from torch.fx.node import Target, Node, Argument # type: ignore from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from torch.nn.intrinsic import _FusedModule -from typing import Dict, Any, List, Callable +from typing import Dict, Any, List, Callable, Tuple, Optional def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): @@ -41,20 +42,99 @@ def _fuse_fx( fuser = Fuser() return fuser.fuse(graph_module, fuse_custom_config_dict) -class CustomTracer(Tracer): +class Scope(object): + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example: + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self): + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + +class ScopeContextManager(object): + """ A context manager to track the Scope of Node during symbolic + tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + def __init__( + self, + scope: Scope, + current_module: torch.nn.Module, + current_module_path: str): + super().__init__() + self.prev_module_type = scope.module_type + self.prev_module_path = scope.module_path + self.scope = scope + self.scope.module_path = current_module_path + self.scope.module_type = type(current_module) + + def __enter__(self): + return + + def __exit__(self, *args): + self.scope.module_path = self.prev_module_path + self.scope.module_type = self.prev_module_type + return + + +class QuantizationTracer(Tracer): def __init__(self, skipped_module_names: List[str], skipped_module_classes: List[Callable]): super().__init__() self.skipped_module_names = skipped_module_names self.skipped_module_classes = skipped_module_classes - - def is_leaf_module(self, m, module_qualified_name): - return (m.__module__.startswith('torch.nn') and + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.node_name_to_scope = {} + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + return (m.__module__.startswith("torch.nn") and not isinstance(m, torch.nn.Sequential)) or \ module_qualified_name in self.skipped_module_names or \ type(m) in self.skipped_module_classes or \ isinstance(m, _FusedModule) + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + module_qualified_name = self.path_of_module(m) + if not self.is_leaf_module(m, module_qualified_name): + def scoped_forward(_args, _kwargs): + # Creating scope with information of current module + # scope will be restored automatically upon exit + with ScopeContextManager(self.scope, m, module_qualified_name): + return forward(*_args, **_kwargs) + return scoped_forward(args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) + + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) + if kind == "call_method": + self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type) + return node def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None, @@ -89,18 +169,20 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, float_custom_module_classes = get_custom_module_class_keys( prepare_custom_config_dict, "float_to_observed_custom_module_class") skipped_module_classes += float_custom_module_classes - tracer = CustomTracer(skipped_module_names, skipped_module_classes) + tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) graph_module = GraphModule(model, tracer.trace(model)) graph_module = _fuse_fx(graph_module, prepare_custom_config_dict) quantizer = Quantizer() return quantizer.prepare( graph_module, qconfig_dict, + tracer.node_name_to_scope, prepare_custom_config_dict=prepare_custom_config_dict, is_standalone_module=is_standalone_module) def _prepare_standalone_module_fx( - model: torch.nn.Module, qconfig_dict: Any, + model: torch.nn.Module, + qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module.