From 0923d19601e9c4853a90d0b6b2f73a29b0a28af6 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Sat, 5 Dec 2020 08:42:13 -0800
Subject: [PATCH] fx quant: add types to quantization_patterns (#48851)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48851

Adding typing to improve readability.

Note: this uncovered a few missing return statements, we should fix that
before landing.

Test Plan:
```
mypy torch/quantization/
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25338644

fbshipit-source-id: 0ac4405db05fdd2737bc3415217bc1937c2db684
---
 .../quantization/fx/quantization_patterns.py | 95 +++++++++++++------
 1 file changed, 67 insertions(+), 28 deletions(-)

diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index 72e165f8351e..176cd7603286 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -37,6 +37,13 @@
 import operator
 import warnings
 
+from typing import Any, Callable, Dict
+
+# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
+# Define separately to prevent circular imports.
+# TODO(future PR): improve this.
+QuantizerCls = Any
+
 # -------------------------
 # Pattern Registrations
 # -------------------------
@@ -47,7 +54,7 @@
 class QuantizeHandler(ABC):
     """ Base handler class for the quantizer patterns
     """
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
        """ Records pattern information in __init__, which will be used
        in convert
        """
@@ -58,7 +65,9 @@ def __init__(self, quantizer, node):
         self.all_node_args = True
 
     @abstractmethod
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         """ Convert the given node to a quantized node and insert
         it to the quantized graph
         """
@@ -71,18 +80,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern((torch.nn.functional.relu, operator.add))
 @register_quant_pattern((torch.nn.functional.relu, torch.add))
 class Add(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
         if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
         assert node.op == 'call_function' and node.target in [operator.add, torch.add]
         self.add_node = node
         self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)])
 
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         if self.num_node_args == 1:
             # add scalar
             if self.relu_node is not None:
@@ -119,18 +130,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern((torch.nn.functional.relu, operator.mul))
 @register_quant_pattern((torch.nn.functional.relu, torch.mul))
 class Mul(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
         if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
         assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
         self.mul_node = node
         self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)])
 
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         if self.num_node_args == 1:
             # mul scalar
             if self.relu_node is not None:
@@ -159,7 +172,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 
 @register_quant_pattern(torch.cat)
 class Cat(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         if not self.all_node_args:
             return NotImplemented
         activation_post_process = quantizer.activation_post_process_map[node.name]
@@ -191,18 +206,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
 class ConvRelu(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
         if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
         self.conv_node = node
         if node.op == 'call_module':
             self.conv = quantizer.modules[self.conv_node.target]
 
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         # TODO: debug option for conv module
         qconfig = quantizer.qconfig_map[node.name]
         activation_statically_quantized = activation_is_statically_quantized(qconfig)
@@ -230,7 +247,8 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
                 self.conv_node.target,
                 (load_arg(quantized=True)(self.conv_node.args[0]),),
                 {})
-        elif self.conv_node.op == 'call_function':
+        else:  # call_function
+            assert self.conv_node.op == 'call_function'
             if self.relu_node is not None:
                 raise Exception("functional conv + relu is not supported yet")
             if debug:
@@ -273,18 +291,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
 @register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
 class LinearReLUQuantizeHandler(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
         if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
            (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
             self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
         self.linear_node = node
         if node.op == 'call_module':
             self.linear = quantizer.modules[self.linear_node.target]
 
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         # Supported combinations are:
         # quant_type | activation (compute_type) | weight
         #  static       quint8                      qint8
@@ -338,7 +358,8 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
                 'call_module',
                 self.linear_node.target,
                 (load_arg(quantized=activation_statically_quantized)(self.linear_node.args[0]),), {})
-        elif self.linear_node.op == 'call_function':
+        else:  # call_function
+            assert self.linear_node.op == 'call_function'
             if debug:
                 quantized_input_idxs = []
                 if activation_statically_quantized:
@@ -405,13 +426,15 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
 @register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
 class BatchNorm(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         assert node.op == 'call_module'
         self.bn_node = node
         self.bn = quantizer.modules[self.bn_node.target]
 
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         if convert_custom_config_dict is None:
             convert_custom_config_dict = {}
         additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
@@ -431,10 +454,12 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern(torch.nn.EmbeddingBag)
 @mark_input_output_not_observed()
 class Embedding(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
 
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         # Supported combinations are:
         # quant_type  | activation | weight | activation_compute_type
         # weight_only |  float32   | quint8 | None
@@ -486,7 +511,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 class DefaultNode(QuantizeHandler):
     ''' Common quantized op, first input and first output will be quantized
     '''
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         if not self.all_node_args:
             return NotImplemented
         assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
@@ -528,7 +555,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 # TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
 @register_quant_pattern(torch.nn.functional.elu)
 class ELU(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         activation_post_process = quantizer.activation_post_process_map[node.name]
         scale, zero_point = activation_post_process.calculate_qparams()
         scale = float(scale)
@@ -553,7 +582,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
 @register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
 class FixedQParamsOpQuantizeHandler(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
 
 # these ops have quantized equivalents that do not need any extra information
@@ -622,13 +653,17 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
 @register_quant_pattern('unsqueeze_')
 @register_quant_pattern('view')
 class CopyNode(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
 
 # Default quantization handler, used for quantization of input and output
 # of quantizable objects (e.g. modules and functionals)
 class DefaultQuantizeHandler(QuantizeHandler):
-    def convert(self, quantizer, node):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         assert self.all_node_args
         root_module = quantizer.modules['']
         return quantize_node(
@@ -637,7 +672,9 @@ def convert(self, quantizer, node):
             node, quantizer.activation_post_process_map[node.name])
 
 class CustomModuleQuantizeHandler(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         """ Convert a float custom module to quantized custom module
         """
         assert node.op == 'call_module'
@@ -666,7 +703,9 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
     """ Converts an observed standalone module to quantized standalone
     module by calling convert_fx on the observed standalone module.
     """
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
         assert node.op == 'call_module'
         qconfig = quantizer.qconfig_map[node.name]
         convert = torch.quantization.quantize_fx._convert_standalone_module_fx  # type: ignore
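
The signature changes above all follow one pattern, sketched below as a minimal, self-contained example. This sketch is illustrative only and not part of the commit: `Node` here is a hypothetical stand-in for `torch.fx.Node`, and the dict parameter is spelled as an explicit `Optional[Dict[str, Any]]`, whereas the patch writes `Dict[str, Any] = None` and relies on mypy's implicit-Optional handling (the default behavior at the time; recent mypy releases require the explicit form).

```
from typing import Any, Callable, Dict, Optional

# The alias trick from the patch: give an unimportable class a readable
# type name while declaring it as Any, so annotations stay informative
# without creating a circular import between modules.
QuantizerCls = Any

class Node:
    """Hypothetical stand-in for torch.fx.Node (illustration only)."""
    def __init__(self, name: str) -> None:
        self.name = name

def convert(quantizer: QuantizerCls, node: Node, load_arg: Callable,
            debug: bool = False,
            convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
    # Handlers in the patch normalize the None default the same way
    # before reading keys such as "static" out of the dict.
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    return node

# Any object is accepted for `quantizer`, since QuantizerCls is Any.
print(convert(object(), Node("x"), load_arg=lambda quantized=None: None).name)
```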
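One design choice in the diff is also worth a note: the `elif self.conv_node.op == 'call_function':` and `elif self.linear_node.op == 'call_function':` branches become `else:` followed by an `assert`. Once `convert` is annotated `-> Node`, an open-ended `elif` chain leaves a path that falls off the end of the function without returning; rewriting the final branch as `else:` plus an assert makes the dispatch exhaustive for the type checker while still verifying the invariant at runtime. This appears to be the same class of issue as the Summary's note that typing "uncovered a few missing return statements."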