fx quant: add types to quantization_patterns (#48851)
Summary:
Pull Request resolved: #48851

Adding type annotations to improve readability.

Note: this uncovered a few missing return statements; we should
fix those before landing.

Test Plan:
```
mypy torch/quantization/
```
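
For context on the note above: once `convert` is annotated with a non-None return type such as `Node`, `mypy` reports any branch that can fall off the end of the function. A minimal, hypothetical sketch (the names below are illustrative, not from this PR):

```python
# Hypothetical, self-contained sketch of the kind of error mypy surfaces
# once a return type is annotated; `Node` here is only a stand-in class.


class Node:
    pass


def convert_example(node: Node, quantized: bool) -> Node:
    if quantized:
        return node
    # No return on this path, so mypy reports:
    # error: Missing return statement  [return]
```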

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25338644

fbshipit-source-id: 0ac4405db05fdd2737bc3415217bc1937c2db684
vkuzo authored and facebook-github-bot committed Dec 5, 2020
1 parent fa5f7d8 commit 0923d19
Showing 1 changed file with 67 additions and 28 deletions.
95 changes: 67 additions & 28 deletions torch/quantization/fx/quantization_patterns.py
@@ -37,6 +37,13 @@
import operator
import warnings

+from typing import Any, Callable, Dict
+
+# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
+# Define separately to prevent circular imports.
+# TODO(future PR): improve this.
+QuantizerCls = Any
+
# -------------------------
# Pattern Registrations
# -------------------------
@@ -47,7 +54,7 @@
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
""" Records pattern information in __init__, which will be used
in convert
"""
@@ -58,7 +65,9 @@ def __init__(self, quantizer, node):
self.all_node_args = True

@abstractmethod
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
""" Convert the given node to a quantized node and insert
it to the quantized graph
"""
@@ -71,18 +80,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
class Add(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
assert node.op == 'call_function' and node.target in [operator.add, torch.add]
self.add_node = node
self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)])

-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if self.num_node_args == 1:
# add scalar
if self.relu_node is not None:
@@ -119,18 +130,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
class Mul(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
self.mul_node = node
self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)])

-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if self.num_node_args == 1:
# mul scalar
if self.relu_node is not None:
@@ -159,7 +172,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_

@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if not self.all_node_args:
return NotImplemented
activation_post_process = quantizer.activation_post_process_map[node.name]
@@ -191,18 +206,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
self.conv_node = node
if node.op == 'call_module':
self.conv = quantizer.modules[self.conv_node.target]

-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
# TODO: debug option for conv module
qconfig = quantizer.qconfig_map[node.name]
activation_statically_quantized = activation_is_statically_quantized(qconfig)
@@ -230,7 +247,8 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
self.conv_node.target,
(load_arg(quantized=True)(self.conv_node.args[0]),),
{})
-        elif self.conv_node.op == 'call_function':
+        else:  # call_function
+            assert self.conv_node.op == 'call_function'
if self.relu_node is not None:
raise Exception("functional conv + relu is not supported yet")
if debug:
@@ -273,18 +291,20 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLUQuantizeHandler(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
-            node = node.args[0]
+            node = node.args[0]  # type: ignore
self.linear_node = node
if node.op == 'call_module':
self.linear = quantizer.modules[self.linear_node.target]

-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
# Supported combinations are:
# quant_type | activation (compute_type) | weight
# static quint8 qint8
@@ -338,7 +358,8 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
'call_module',
self.linear_node.target,
(load_arg(quantized=activation_statically_quantized)(self.linear_node.args[0]),), {})
-        elif self.linear_node.op == 'call_function':
+        else:  # call_function
+            assert self.linear_node.op == 'call_function'
if debug:
quantized_input_idxs = []
if activation_statically_quantized:
@@ -405,13 +426,15 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
class BatchNorm(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
assert node.op == 'call_module'
self.bn_node = node
self.bn = quantizer.modules[self.bn_node.target]

-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
@@ -431,10 +454,12 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern(torch.nn.EmbeddingBag)
@mark_input_output_not_observed()
class Embedding(QuantizeHandler):
-    def __init__(self, quantizer, node):
+    def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)

-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
# Supported combinations are:
# quant_type | activation | weight | activation_compute_type
# weight_only | float32 | quint8 | None
@@ -486,7 +511,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
class DefaultNode(QuantizeHandler):
''' Common quantized op, first input and first output will be quantized
'''
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
if not self.all_node_args:
return NotImplemented
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
@@ -528,7 +555,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
@register_quant_pattern(torch.nn.functional.elu)
class ELU(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
@@ -553,7 +582,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

# these ops have quantized equivalents that do not need any extra information
@@ -622,13 +653,17 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
@register_quant_pattern('unsqueeze_')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

# Default quantization handler, used for quantization of input and output
# of quantizable objects (e.g. modules and functionals)
class DefaultQuantizeHandler(QuantizeHandler):
-    def convert(self, quantizer, node):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
assert self.all_node_args
root_module = quantizer.modules['']
return quantize_node(
@@ -637,7 +672,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
node, quantizer.activation_post_process_map[node.name])

class CustomModuleQuantizeHandler(QuantizeHandler):
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
""" Convert a float custom module to quantized custom module
"""
assert node.op == 'call_module'
@@ -666,7 +703,9 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
""" Converts an observed standalone module to quantized standalone module
by calling convert_fx on the observed standalone module.
"""
-    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
+    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
+                debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
assert node.op == 'call_module'
qconfig = quantizer.qconfig_map[node.name]
convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore
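
On the `QuantizerCls = Any` alias added near the top of the diff (and its TODO): one common way to tighten this later, sketched here purely as a hedged assumption rather than anything this PR does, is a `typing.TYPE_CHECKING` guard plus a string forward reference, so the real `Quantizer` type is visible to mypy without creating a runtime import cycle:

```python
# Hedged sketch only: assumes the Quantizer class lives in
# torch/quantization/fx/quantize.py, as the comment in the diff states.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so the circular import is avoided.
    from torch.quantization.fx.quantize import Quantizer


class ExampleHandler:
    def __init__(self, quantizer: "Quantizer") -> None:
        self.quantizer = quantizer
```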
