From c35b966d3710b8e774ef7cbf8da11b56123cbf0f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 6 Jan 2021 16:15:44 -0800 Subject: [PATCH 1/5] [quant][graphmode][fx] Scope support for call_method in QuantizationTracer Summary: Previously we did not set the qconfig for call_method node correctly since it requires us to know the scope (module path of the module whose forward graph contains the node) of the node. This PR modifies the QuantizationTracer to record the scope information and build a map from call_method Node to module path, which will be used when we construct qconfig_map Test Plan: python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/quantization/test_quantize_fx.py | 44 ++++++++++++++++++- torch/quantization/fx/qconfig_utils.py | 5 +-- torch/quantization/fx/quantize.py | 56 ++++++++++++------------ torch/quantization/quantize_fx.py | 60 +++++++++++++++++++++++--- 4 files changed, 126 insertions(+), 39 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 66324f928f04..fabc3002f403 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -561,7 +561,6 @@ def forward(self, x): "": None, "object_type": [ (nn.Conv2d, default_qconfig), - ("chunk", None) ] } # make sure it runs @@ -846,6 +845,49 @@ def forward(self, x): m = prepare_fx(m, qconfig_dict) m = convert_fx(m) + def test_qconfig_for_call_method(self): + class Sub(torch.nn.Module): + def forward(self, x): + return x.transpose(2, 3) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sub = Sub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) # default_qconfig + x = self.sub(x) # None + x = self.conv2(x) # default_qconfig + return x.transpose(2, 3) # default_qconfig + + + m = M().eval() + # since we don't quantize sub, we should have dequantize after the self.conv1 + # and quantize before self.conv2 + # however, the dequantize after conv2 should happen after x.transpose since + # it is configured with default_qconfig + qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]} + m = prepare_fx(m, qconfig_dict) + m(torch.randn(2, 1, 3, 3)) + m = convert_fx(m) + node_list = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("dequantize"), + ns.call_method("transpose"), + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_method("transpose"), + ns.call_method("dequantize") + ] + self.checkGraphModuleNodes(m, expected_node_list=node_list) + + # make sure it runs + m(torch.randn(2, 1, 3, 3)) + @skipIfNoFBGEMM def test_qat_and_script(self): model = LinearModelWithSubmodule().train() diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py index 3db370c4422d..42d4c1a3f3b7 100644 --- a/torch/quantization/fx/qconfig_utils.py +++ b/torch/quantization/fx/qconfig_utils.py @@ -88,10 +88,9 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig): # get qconfig for module_name, # fallback to module_name_regex_qconfig, module_type_qconfig, # global_qconfig if necessary -def get_qconfig(modules, qconfig_dict, module_name, global_qconfig): - assert modules is not None +def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig): module_type_qconfig = get_object_type_qconfig( - qconfig_dict, 
type(modules[module_name]), global_qconfig) + qconfig_dict, module_type, global_qconfig) module_name_regex_qconfig = get_module_name_regex_qconfig( qconfig_dict, module_name, module_type_qconfig) module_name_qconfig = get_module_name_qconfig( diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index af9496a66a63..c83a663c98fd 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -326,40 +326,32 @@ def _generate_qconfig_map( self, root: torch.nn.Module, input_graph: Graph, - qconfig_dict: Any) -> None: - global_qconfig = qconfig_dict.get('', None) + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]]) -> None: + global_qconfig = qconfig_dict.get("", None) self.qconfig_map = dict() for node in input_graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": module_name, _ = _parent_name(node.target) self.qconfig_map[node.name] = get_qconfig( - self.modules, qconfig_dict, module_name, global_qconfig) - elif node.op == 'call_function': + qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig) + elif node.op == "call_function": # precedence: [TODO] module_name_qconfig (need scope support # from fx) # > function_qconfig > global_qconfig function_qconfig = get_object_type_qconfig( qconfig_dict, node.target, global_qconfig) self.qconfig_map[node.name] = function_qconfig - elif node.op == 'call_method': - self_obj = node.args[0] - # qconfig for call_method should be the same as the `self` - # object for the call - if self_obj.name in self.qconfig_map: - qconfig = self.qconfig_map[self_obj.name] - else: - # need scope info for each node to support this - warnings.warn( - "Scope info is not yet supported, taking default " + - "qconfig for value {}".format(node.name)) - qconfig = get_qconfig( - self.modules, qconfig_dict, '', global_qconfig) - qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig) + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # use the qconfig of the module that the node belongs to + qconfig = get_qconfig( + qconfig_dict, module_type, module_path, global_qconfig) self.qconfig_map[node.name] = qconfig elif node.op == 'call_module': module_qconfig = get_qconfig( - self.modules, qconfig_dict, node.target, global_qconfig) + qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig) # regex is not supported eager mode propagate_qconfig_, we'll # need to set the qconfig explicitly here in case regex # is used @@ -367,9 +359,13 @@ def _generate_qconfig_map( self.modules[node.target].qconfig = module_qconfig self.qconfig_map[node.name] = module_qconfig - def _prepare(self, model: GraphModule, qconfig_dict: Any, - prepare_custom_config_dict: Optional[Dict[str, Any]], - is_standalone_module: bool) -> GraphModule: + def _prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]], + prepare_custom_config_dict: Optional[Dict[str, Any]], + is_standalone_module: bool) -> GraphModule: """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. 
@@ -399,7 +395,7 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, convert_dict_to_ordered_dict(qconfig_dict) # map from node name to qconfig, used in _find_matches - self._generate_qconfig_map(model, model.graph, qconfig_dict) + self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope) # match the patterns that will get quantized standalone_module_name_configs = prepare_custom_config_dict.get( @@ -536,11 +532,15 @@ def restore_state(self, observed: GraphModule) -> None: self.prepare_custom_config_dict = \ observed._prepare_custom_config_dict # type: ignore - def prepare(self, model: GraphModule, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None, - is_standalone_module: bool = False) -> GraphModule: + def prepare( + self, + model: GraphModule, + qconfig_dict: Any, + node_name_to_scope: Dict[str, str], + prepare_custom_config_dict: Dict[str, Any] = None, + is_standalone_module: bool = False) -> GraphModule: return self._prepare( - model, qconfig_dict, prepare_custom_config_dict, + model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict, is_standalone_module) def _run_weight_observers(self, observed: GraphModule) -> None: diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index cba104b8f783..c449ad2a9061 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -1,12 +1,13 @@ import torch from torch.fx import GraphModule # type: ignore from torch.fx.symbolic_trace import Tracer # type: ignore +from torch.fx.node import Target, Node, Argument # type: ignore from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from torch.nn.intrinsic import _FusedModule -from typing import Dict, Any, List, Callable +from typing import Dict, Any, List, Callable, Tuple, Optional def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): @@ -41,20 +42,62 @@ def _fuse_fx( fuser = Fuser() return fuser.fuse(graph_module, fuse_custom_config_dict) -class CustomTracer(Tracer): +class Scope(object): + def __init__(self, module_path, module_type): + super().__init__() + self.module_path = module_path + self.module_type = module_type + +class ScopeContextManager(object): + def __init__(self, scope, current_module, current_module_path): + super().__init__() + self.prev_module_type = scope.module_type + self.prev_module_path = scope.module_path + self.scope = scope + self.scope.module_path = current_module_path + self.scope.module_type = type(current_module) + + def __enter__(self): + return + + def __exit__(self, *args): + self.scope.module_path = self.prev_module_path + self.scope.module_type = self.prev_module_type + return + + +class QuantizationTracer(Tracer): def __init__(self, skipped_module_names: List[str], skipped_module_classes: List[Callable]): super().__init__() self.skipped_module_names = skipped_module_names self.skipped_module_classes = skipped_module_classes + self.scope = Scope("", None) + self.node_name_to_scope = {} - def is_leaf_module(self, m, module_qualified_name): - return (m.__module__.startswith('torch.nn') and + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + return (m.__module__.startswith("torch.nn") and not isinstance(m, torch.nn.Sequential)) or \ module_qualified_name in self.skipped_module_names or \ type(m) in self.skipped_module_classes or \ 
isinstance(m, _FusedModule) + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + module_qualified_name = self.path_of_module(m) + if not self.is_leaf_module(m, module_qualified_name): + def scoped_forward(_args, _kwargs): + with ScopeContextManager(self.scope, m, module_qualified_name): + return forward(*_args, **_kwargs) + return scoped_forward(args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) + + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) + if kind == "call_method": + self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type) + return node def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, prepare_custom_config_dict: Dict[str, Any] = None, @@ -89,18 +132,21 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, float_custom_module_classes = get_custom_module_class_keys( prepare_custom_config_dict, "float_to_observed_custom_module_class") skipped_module_classes += float_custom_module_classes - tracer = CustomTracer(skipped_module_names, skipped_module_classes) + tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) graph_module = GraphModule(model, tracer.trace(model)) graph_module = _fuse_fx(graph_module, prepare_custom_config_dict) quantizer = Quantizer() return quantizer.prepare( graph_module, qconfig_dict, + tracer.node_name_to_scope, prepare_custom_config_dict=prepare_custom_config_dict, is_standalone_module=is_standalone_module) def _prepare_standalone_module_fx( - model: torch.nn.Module, qconfig_dict: Any, + model: torch.nn.Module, + qconfig_dict: Any, + node_name_to_scope: Dict[str, Tuple[str, Any]], prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. @@ -110,7 +156,7 @@ def _prepare_standalone_module_fx( Both input and output of the module are observed in the standalone module. """ - return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) + return _prepare_fx(model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict, is_standalone_module=True) def fuse_fx(model: torch.nn.Module, fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: From bf520b5f9a573082ab3d6ce043ec699104547ef8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 6 Jan 2021 17:07:12 -0800 Subject: [PATCH 2/5] Update on "[quant][graphmode][fx] Scope support for call_method in QuantizationTracer" Summary: Previously we did not set the qconfig for call_method node correctly since it requires us to know the scope (module path of the module whose forward graph contains the node) of the node. 
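For context, the scope tracking added here is a plain save/restore pattern. A minimal sketch of how the tracer is expected to use it (simplified; `run_forward` is just an illustrative stand-in for the logic in `QuantizationTracer.call_module`):

```python
# Simplified sketch, not the exact patch code: while tracing a submodule's
# forward, the shared Scope temporarily points at that submodule, so any
# call_method node created inside it can be attributed to its path/type.
scope = Scope("", None)  # top level module: empty path, no type

def run_forward(module, module_qualified_name, forward, args, kwargs):
    # The previous (module_path, module_type) is restored when the block exits.
    with ScopeContextManager(scope, module, module_qualified_name):
        return forward(*args, **kwargs)
```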
This PR modifies the QuantizationTracer to record the scope information and build a map from call_method Node to module path, which will be used when we construct qconfig_map Test Plan: python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D25818132](https://our.internmc.facebook.com/intern/diff/D25818132) [ghstack-poisoned] --- torch/quantization/fx/quantize.py | 2 -- torch/quantization/quantize_fx.py | 8 ++++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index c83a663c98fd..43141cda885c 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -61,8 +61,6 @@ from .qconfig_utils import * -import warnings - from typing import Optional, Dict, Any, List, Tuple, Set, Callable # Define helper types diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index c449ad2a9061..3fd614b36cbd 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -43,13 +43,17 @@ def _fuse_fx( return fuser.fuse(graph_module, fuse_custom_config_dict) class Scope(object): - def __init__(self, module_path, module_type): + def __init__(self, module_path: str, module_type: Any): super().__init__() self.module_path = module_path self.module_type = module_type class ScopeContextManager(object): - def __init__(self, scope, current_module, current_module_path): + def __init__( + self, + scope: Scope, + current_module: torch.nn.Module, + current_module_path: str): super().__init__() self.prev_module_type = scope.module_type self.prev_module_path = scope.module_path From c16325cd0bdd7e3585c461bdf267cff5535e7ef2 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 6 Jan 2021 17:29:29 -0800 Subject: [PATCH 3/5] Update on "[quant][graphmode][fx] Scope support for call_method in QuantizationTracer" Summary: Previously we did not set the qconfig for call_method node correctly since it requires us to know the scope (module path of the module whose forward graph contains the node) of the node. This PR modifies the QuantizationTracer to record the scope information and build a map from call_method Node to module path, which will be used when we construct qconfig_map Test Plan: python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D25818132](https://our.internmc.facebook.com/intern/diff/D25818132) [ghstack-poisoned] --- torch/quantization/quantize_fx.py | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 204abc91f9a1..65b9b8c18f55 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -43,12 +43,38 @@ def _fuse_fx( return fuser.fuse(graph_module, fuse_custom_config_dict) class Scope(object): + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. 
For example: + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self): + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ def __init__(self, module_path: str, module_type: Any): super().__init__() self.module_path = module_path self.module_type = module_type class ScopeContextManager(object): + """ A context manager to track the Scope of Node during symbolic + tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ def __init__( self, scope: Scope, @@ -76,6 +102,11 @@ def __init__(self, skipped_module_names: List[str], super().__init__() self.skipped_module_names = skipped_module_names self.skipped_module_classes = skipped_module_classes + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type self.scope = Scope("", None) self.node_name_to_scope = {} @@ -90,6 +121,8 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu module_qualified_name = self.path_of_module(m) if not self.is_leaf_module(m, module_qualified_name): def scoped_forward(_args, _kwargs): + # Creating scope with information of current module + # scope will be restored automatically upon exit with ScopeContextManager(self.scope, m, module_qualified_name): return forward(*_args, **_kwargs) return scoped_forward(args, kwargs) From 835ae2199e6b9d327a83cb880c9fba7835e74b8e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 7 Jan 2021 11:28:46 -0800 Subject: [PATCH 4/5] Update on "[quant][graphmode][fx] Scope support for call_method in QuantizationTracer" Summary: Previously we did not set the qconfig for call_method node correctly since it requires us to know the scope (module path of the module whose forward graph contains the node) of the node. 
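Concretely, once the tracer has recorded which module each call_method node was traced in, the qconfig lookup can reuse the existing precedence rules. A rough sketch (illustrative helper; the real logic lives in `_generate_qconfig_map`):

```python
# Illustrative only: resolve a call_method node's qconfig from the scope of
# the module whose forward produced it.
def qconfig_for_call_method(node, node_name_to_scope, qconfig_dict, global_qconfig):
    module_path, module_type = node_name_to_scope[node.name]
    # get_qconfig precedence: module_name > module_name_regex
    #                         > object_type (module_type) > global ("")
    return get_qconfig(qconfig_dict, module_type, module_path, global_qconfig)
```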
This PR modifies the QuantizationTracer to record the scope information and build a map from call_method Node to module path, which will be used when we construct qconfig_map Test Plan: python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D25818132](https://our.internmc.facebook.com/intern/diff/D25818132) [ghstack-poisoned] --- torch/quantization/fx/quantize.py | 5 +++-- torch/quantization/quantize_fx.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index e0b05cf349c8..7a2e46a04e03 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -352,6 +352,7 @@ def _generate_qconfig_map( for node in input_graph.nodes: if node.op == "get_attr": module_name, _ = _parent_name(node.target) + assert self.modules is not None self.qconfig_map[node.name] = get_qconfig( qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig) elif node.op == "call_function": @@ -368,12 +369,12 @@ def _generate_qconfig_map( qconfig_dict, module_type, module_path, global_qconfig) self.qconfig_map[node.name] = qconfig elif node.op == 'call_module': + assert self.modules is not None module_qconfig = get_qconfig( qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig) # regex is not supported eager mode propagate_qconfig_, we'll # need to set the qconfig explicitly here in case regex # is used - assert self.modules is not None self.modules[node.target].qconfig = module_qconfig self.qconfig_map[node.name] = module_qconfig @@ -577,7 +578,7 @@ def prepare( self, model: GraphModule, qconfig_dict: Any, - node_name_to_scope: Dict[str, str], + node_name_to_scope: Dict[str, Tuple[str, Any]], prepare_custom_config_dict: Dict[str, Any] = None, is_standalone_module: bool = False) -> GraphModule: return self._prepare( diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 65b9b8c18f55..9d51570dafeb 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -183,7 +183,6 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, def _prepare_standalone_module_fx( model: torch.nn.Module, qconfig_dict: Any, - node_name_to_scope: Dict[str, Tuple[str, Any]], prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. @@ -205,7 +204,7 @@ def _prepare_standalone_module_fx( same as input_quantized_idxs configuration provided for the standalone module """ - return _prepare_fx(model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict, is_standalone_module=True) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) def fuse_fx(model: torch.nn.Module, fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: From 46d4585b01266318645fd6afae589c6710721711 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 7 Jan 2021 13:55:05 -0800 Subject: [PATCH 5/5] Update on "[quant][graphmode][fx] Scope support for call_method in QuantizationTracer" Summary: Previously we did not set the qconfig for call_method node correctly since it requires us to know the scope (module path and type of the module whose forward graph contains the node) of the node. 
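As described below, the tracer records a map from each call_method node to the (module_path, module_type) scope it was created in. For the test model in this stack it would contain roughly the following (node names are illustrative; FX assigns the actual names during tracing):

```python
# Roughly what QuantizationTracer.node_name_to_scope holds after tracing the
# M/Sub model from the test; node names are illustrative.
node_name_to_scope = {
    "transpose": ("sub", Sub),  # x.transpose(2, 3) inside Sub.forward
    "transpose_1": ("", None),  # x.transpose(2, 3) in M.forward (top level)
}
```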
This PR modifies the QuantizationTracer to record the scope information and build a map from call_method Node to (module_path, module_type), which will be used when we construct qconfig_map Test Plan: python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D25818132](https://our.internmc.facebook.com/intern/diff/D25818132) [ghstack-poisoned] --- test/quantization/test_quantize_fx.py | 8 ++++---- torch/quantization/quantize_fx.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 29649e428e04..3801012c35a6 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -934,14 +934,14 @@ def forward(self, x): m = M().eval() - # since we don't quantize sub, we should have dequantize after the self.conv1 - # and quantize before self.conv2 - # however, the dequantize after conv2 should happen after x.transpose since - # it is configured with default_qconfig qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]} m = prepare_fx(m, qconfig_dict) m(torch.randn(2, 1, 3, 3)) m = convert_fx(m) + # since sub is configured to have qconfig None, we should dequantize the output + # of self.conv1 and quantize the input of self.conv2 + # dequantize after conv2 should happen after transpose since + # it is configured with default_qconfig node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nnq.Conv2d), diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 9d51570dafeb..641ab4fddfde 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -108,7 +108,7 @@ def __init__(self, skipped_module_names: List[str], # We can change this if there is a use case that configures # qconfig using top level module type self.scope = Scope("", None) - self.node_name_to_scope = {} + self.node_name_to_scope : Dict[str, Tuple[str, Any]] = {} def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: return (m.__module__.startswith("torch.nn") and