diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 8c776f515a3a..fe7dc53a8019 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -447,15 +447,18 @@ def __init__(self):
         self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
 
 
-    def _qat_swap_modules(self, root, additional_qat_module_mapping):
+    def _qat_swap_modules(
+            self, root: torch.nn.Module,
+            additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
         all_mappings = get_combined_dict(
             get_default_qat_module_mappings(), additional_qat_module_mapping)
         convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
 
-    def _generate_qconfig_map(self,
-                              root,
-                              input_graph,
-                              qconfig_dict):
+    def _generate_qconfig_map(
+            self,
+            root: torch.nn.Module,
+            input_graph: Graph,
+            qconfig_dict: Any) -> None:
         global_qconfig = qconfig_dict.get('', None)
 
         self.qconfig_map = dict()
@@ -495,8 +498,9 @@ def _generate_qconfig_map(self,
                 self.modules[node.target].qconfig = module_qconfig
                 self.qconfig_map[node.name] = module_qconfig
 
-    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
-                 is_standalone_module):
+    def _prepare(self, model: GraphModule, qconfig_dict: Any,
+                 prepare_custom_config_dict: Optional[Dict[str, Any]],
+                 is_standalone_module: bool) -> GraphModule:
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
 
@@ -534,6 +538,7 @@ def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
             "standalone_module_class", None)
         custom_module_classes = get_custom_module_class_keys(
             prepare_custom_config_dict, "float_to_observed_custom_module_class")
+        assert self.patterns is not None
         matches = self._find_matches(
             model.graph, self.modules, self.patterns, standalone_module_names,
             standalone_module_classes, custom_module_classes)
@@ -552,7 +557,7 @@ def load_arg(a):
             return map_arg(a, lambda node: env[node.name])
 
         # indexes for the inputs that needs to be observed
-        standalone_module_observed_input_idxs = []
+        standalone_module_observed_input_idxs: List[int] = []
         graph_inputs = []
         for node in model.graph.nodes:
             if node.op == 'placeholder':
@@ -602,25 +607,28 @@ def load_arg(a):
         model = mark_observed_module(model)
         return model
 
-    def save_state(self, observed):
-        observed._activation_post_process_map = self.activation_post_process_map
-        observed._patterns = self.patterns
-        observed._qconfig_map = self.qconfig_map
+    def save_state(self, observed: GraphModule) -> None:
+        observed._activation_post_process_map = \
+            self.activation_post_process_map  # type: ignore
+        observed._patterns = self.patterns  # type: ignore
+        observed._qconfig_map = self.qconfig_map  # type: ignore
 
-    def restore_state(self, observed):
+    def restore_state(self, observed: GraphModule) -> None:
         assert is_observed_module(observed), \
             'incoming model must be produced by prepare_fx'
-        self.activation_post_process_map = observed._activation_post_process_map
-        self.patterns = observed._patterns
-        self.qconfig_map = observed._qconfig_map
-
-    def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None,
-                is_standalone_module=False):
+        self.activation_post_process_map = \
+            observed._activation_post_process_map  # type: ignore
+        self.patterns = observed._patterns  # type: ignore
+        self.qconfig_map = observed._qconfig_map  # type: ignore
+
+    def prepare(self, model: GraphModule, qconfig_dict: Any,
+                prepare_custom_config_dict: Dict[str, Any] = None,
+                is_standalone_module: bool = False) -> GraphModule:
         return self._prepare(
             model, qconfig_dict, prepare_custom_config_dict,
             is_standalone_module)
 
-    def _run_weight_observers(self, observed):
+    def _run_weight_observers(self, observed: GraphModule) -> None:
         r''' Extract the subgraph that produces the weight for dynamic quant
         or weight only quant node and run the subgraph to observe the weight.
         Note that the observers of dynamic quant or weight only quant ops are
@@ -640,8 +648,9 @@ def _run_weight_observers(self, observed):
                     weight_observer_module()
         return
 
-    def _convert(self, model, debug=False, convert_custom_config_dict=None,
-                 is_standalone_module=False):
+    def _convert(self, model: GraphModule, debug: bool = False,
+                 convert_custom_config_dict: Dict[str, Any] = None,
+                 is_standalone_module: bool = False) -> GraphModule:
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
 
@@ -662,6 +671,7 @@ def _convert(self, model, debug=False, convert_custom_config_dict=None,
         custom_module_classes = get_custom_module_class_keys(
             convert_custom_config_dict,
             "observed_to_quantized_custom_module_class")
+        assert self.patterns is not None
         matches = self._find_matches(
             model.graph, self.modules, self.patterns,
             custom_module_classes=custom_module_classes)
@@ -905,7 +915,7 @@ def load_arg(a):  # type: ignore
     # Trace back from the weight node util we hit getattr, reconstruct the
     # graph module with the traced nodes and run the graph module to pack the
     # weight. then replace the original chain of ops with the packed weight.
-    def _fold_weight(self, quantized):
+    def _fold_weight(self, quantized: GraphModule) -> GraphModule:
         packed_weights = dict()
         # map from folded node name to the prepacked weight name
         folded_nodes = dict()
@@ -951,8 +961,9 @@ def load_arg(a):
         quantized = GraphModule(quantized_root, folded_graph)
         return quantized
 
-    def convert(self, model, debug=False, convert_custom_config_dict=None,
-                is_standalone_module=False):
+    def convert(self, model: GraphModule, debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None,
+                is_standalone_module: bool = False) -> GraphModule:
         quantized = self._convert(
             model, debug, convert_custom_config_dict, is_standalone_module)
         if not debug:
@@ -960,10 +971,11 @@ def convert(self, model, debug=False, convert_custom_config_dict=None,
         return quantized
 
     def _find_matches(
-            self, graph, modules, patterns,
-            standalone_module_names=None,
-            standalone_module_classes=None,
-            custom_module_classes=None) -> Dict[str, MatchResult]:
+            self, graph: Graph, modules: Dict[str, torch.nn.Module],
+            patterns: Dict[Pattern, QuantizeHandler],
+            standalone_module_names: List[str] = None,
+            standalone_module_classes: List[Callable] = None,
+            custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
         """ Matches the nodes in the input graph to quantization patterns, and
         outputs the information needed to quantize them in future steps.
 
@@ -1017,7 +1029,7 @@ def record_match(pattern, node, matched):
                         record_match(pattern, node, matched)
                         for n in matched:
                             match_map[n.name] = (
-                                node, matched, pattern, value(self, node),
+                                node, matched, pattern, value(self, node),  # type: ignore
                                 self.qconfig_map[n.name])
                             all_matched.add(n.name)
                         # break after finding the first match
@@ -1035,8 +1047,10 @@ def record_match(pattern, node, matched):
 
         def is_standalone_module(node_target):
             assert self.modules is not None
-            return node_target in standalone_module_names or \
-                type(self.modules[node_target]) in standalone_module_classes
+            return (
+                node_target in standalone_module_names or  # type: ignore
+                type(self.modules[node_target]) in standalone_module_classes  # type: ignore
+            )
 
         # add standalone modules to the match
         for node in graph.nodes:
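
For reference, the Quantizer.prepare() and Quantizer.convert() methods annotated above are normally reached through the public FX graph mode entry points torch.quantization.quantize_fx.prepare_fx and convert_fx, which take the same qconfig_dict that _generate_qconfig_map consumes. Below is a minimal usage sketch, assuming the qconfig_dict-based API of this PyTorch version; the toy model and calibration tensor are hypothetical and only illustrate the flow:

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    # Hypothetical float model; any symbolically traceable module works.
    float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()

    # The '' key supplies the global qconfig, which is what
    # _generate_qconfig_map reads via qconfig_dict.get('', None).
    qconfig_dict = {"": get_default_qconfig("fbgemm")}

    prepared = prepare_fx(float_model, qconfig_dict)  # insert observers (Quantizer.prepare)
    prepared(torch.randn(2, 4))                       # calibration pass on hypothetical data
    quantized = convert_fx(prepared)                  # swap in quantized ops (Quantizer.convert)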