[quant][graphmode][fx] Merge all quantization mode
Summary:
This PR merges all quantization modes and will only expose the following top-level functions:
```
prepare_fx
prepare_qat_fx
convert_fx
```
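
For illustration, a minimal post-training static quantization flow with the merged entry points might look as follows (a sketch, not part of this PR: it assumes the functions are importable from `torch.quantization.quantize_fx` and take a `qconfig_dict`, in line with the existing FX graph mode API):
```
# sketch only: import path and qconfig_dict format are assumptions
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

prepared = prepare_fx(float_model, qconfig_dict)  # insert observers
prepared(torch.randn(2, 4))                       # calibrate on sample data
quantized = convert_fx(prepared)                  # lower to quantized ops
```
`prepare_qat_fx` would play the role of `prepare_fx` for quantization-aware training, with a QAT qconfig and the model left in train mode; dynamic quantization is selected purely through the qconfig rather than through a separate prepare/convert function, which is the point of this PR.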

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 264ca59e1e9255e9e5ca08194101e6f2ccc10af6
Pull Request resolved: #45292
jerryzh168 committed Sep 25, 2020
1 parent 8d68115 commit 3a7cbbb
Showing 5 changed files with 128 additions and 55 deletions.
9 changes: 3 additions & 6 deletions test/quantization/test_quantize_fx.py
@@ -158,11 +158,11 @@ def test_functional_debug(self):
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
node_occurrence = dict()
if weight_prepack_node:
node_occurrence[weight_prepack_node] = 1
node_occurrence[weight_prepack_node] = 0
node_occurrence[quantized_node] = 0
self.checkGraphModeFxOp(
ModuleClass(*module_constructor_inputs),
inputs, quant_type,
expected_node=quantized_node,
expected_node_occurrence=node_occurrence,
debug=True)

@@ -226,10 +226,7 @@ def forward(self, x):
for debug in [True, False]:
node_occurrence = dict()
if weight_prepack_node:
if debug:
node_occurrence[weight_prepack_node] = 1
else:
node_occurrence[weight_prepack_node] = 0
node_occurrence[weight_prepack_node] = 0
m = ModuleClass(*module_constructor_inputs).eval()
m = symbolic_trace(m)
qconfig_dict = {"": float16_dynamic_qconfig}
99 changes: 70 additions & 29 deletions torch/quantization/fx/quantization_patterns.py
@@ -2,6 +2,9 @@
from torch.fx.graph import (
Node,
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd

from ..quantization_mappings import (
get_static_quant_module_class,
get_quantized_operator,
@@ -17,6 +20,11 @@
_parent_name,
quantize_node,
get_per_tensor_qparams,
activation_is_dynamically_quantized,
activation_is_statically_quantized,
weight_is_quantized,
weight_dtype,
get_linear_prepack_op_for_dtype,
)

from abc import ABC, abstractmethod
@@ -235,7 +243,7 @@ def convert(self, quantizer, node, load_arg, debug=False):
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
class LinearReLUQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
@@ -248,50 +256,75 @@ def __init__(self, quantizer, node):
self.linear = quantizer.modules[self.linear_node.target]

def convert(self, quantizer, node, load_arg, debug=False):
qconfig = quantizer.qconfig_map[node.name]
activation_statically_quantized = activation_is_statically_quantized(qconfig)
# TODO: debug option for linear module
if self.linear_node.op == 'call_module':
# note that relu should already be fused into the linear module in the fusion step
assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
# 1. attach output activation post process to linear module
if node.name in quantizer.activation_post_process_map:
# this is the static quantization case
output_activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
output_activation_post_process = None

if output_activation_post_process:
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
float_linear_module = self.linear[1]
else:
float_linear_module = self.linear
float_linear_module.activation_post_process = output_activation_post_process

# 2. select corresponding quantized linear class for the float linear class
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
qlinear = torch.nn.quantized.Linear
qlinear = nnq.Linear if activation_statically_quantized else nnqd.Linear
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
assert not is_dynamic_quant, 'LinearReLU does not support dynamic quantization'
qlinear = torch.nn.intrinsic.quantized.LinearReLU
else:
raise Exception("unhandled linear type:", type(self.linear))
quantized = qlinear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
# activation needs to be quantized for static quantization
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {})
self.linear_node.target,
(load_arg(quantized=activation_statically_quantized)(self.linear_node.args[0]),), {})
elif self.linear_node.op == 'call_function':
if debug:
args = load_arg(quantized=[0, 1])(self.linear_node.args)
quantized_input_idxs = []
if activation_statically_quantized:
quantized_input_idxs.append(0)
if weight_is_quantized(qconfig):
quantized_input_idxs.append(1)
args = load_arg(quantized=quantized_input_idxs)(self.linear_node.args)
args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
linear_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
root_module = quantizer.modules['']
return quantize_node(
root_module,
quantizer.quantized_graph,
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
# TODO: this code can be merged with dynamic linear code
if activation_statically_quantized:
# quantize output for statically quantized linear op
root_module = quantizer.modules['']
return quantize_node(
root_module,
quantizer.quantized_graph,
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
# output for dynamically quantized linear op is not quantized
return linear_out
else: # non-debug option
# linear args
# (x, weight, bias, ...)
args = load_arg(quantized=[0, 1])(self.linear_node.args)
weight_quantized = weight_is_quantized(qconfig)
linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1])

# get other arguments
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
# all args after bias, including bias
other_args = load_arg(quantized=False)(self.linear_node.args[2:])
@@ -303,17 +336,24 @@ def convert(self, quantizer, node, load_arg, debug=False):
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = (weight, bias)
prepack_args = (linear_weight, bias)
prepack_op = get_linear_prepack_op_for_dtype(weight_dtype(qconfig))
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
'call_function', prepack_op, prepack_args, {})
# construct linear input
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qlinear_args = (linear_input, packed_weight, scale, zero_point)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
if activation_statically_quantized:
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qlinear_args = (linear_input, packed_weight, scale, zero_point)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
else:
linear_input = load_arg(quantized=False)(self.linear_node.args[0])
qlinear_args = (linear_input, packed_weight)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_dynamic, qlinear_args, kwargs)

@register_quant_pattern(torch.nn.BatchNorm2d)
@register_quant_pattern(torch.nn.BatchNorm3d)
@@ -537,7 +577,8 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
"""
def convert(self, quantizer, node, load_arg, debug=False):
assert node.op == 'call_module'
if quantizer.is_dynamic_quant:
qconfig = quantizer.qconfig_map[node.name]
if activation_is_dynamically_quantized(qconfig):
convert = torch.quantization.convert_dynamic_child_module_fx
else:
convert = torch.quantization.convert_child_module_fx
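
The functional linear branch above lowers to one of two quantized ops depending on whether the activation is statically quantized. A rough standalone sketch of those two targets (scale/zero_point values are arbitrary placeholders; requires a CPU quantized engine such as fbgemm):
```
# sketch of the static vs. dynamic lowering targets used above
import torch

x = torch.randn(2, 4)
w = torch.randn(3, 4)
b = torch.zeros(3)

# shared step: prepack an int8 weight
qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(qw, b)

# static path: quantized input plus fixed output qparams from the observer
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=64, dtype=torch.quint8)
out_static = torch.ops.quantized.linear(qx, packed, 0.2, 64)

# dynamic path: float input, activation qparams computed at runtime, float output
out_dynamic = torch.ops.quantized.linear_dynamic(x, packed)
```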
34 changes: 17 additions & 17 deletions torch/quantization/fx/quantize.py
@@ -45,6 +45,8 @@
from .utils import (
_parent_name,
quantize_node,
activation_is_dynamically_quantized,
activation_is_statically_quantized,
)

from collections import OrderedDict
@@ -308,11 +310,10 @@ def get_qconfig(module_name):
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant, is_child_module):
if not inplace:
model = copy.deepcopy(model)
self.is_dynamic_quant = is_dynamic_quant
if self.is_dynamic_quant:
self.patterns = get_dynamic_quant_patterns()
else:
self.patterns = get_quant_patterns()
# if is_dynamic_quant:
# self.patterns = get_dynamic_quant_patterns()
# else:
self.patterns = get_quant_patterns()

flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
@@ -391,7 +392,7 @@ def insert_observer(node, observer, device):
# observe custom module
custom_module = self.modules[node.target]
traced_custom_module = symbolic_trace(custom_module)
if self.is_dynamic_quant:
if activation_is_dynamically_quantized(qconfig):
prepare = torch.quantization.prepare_dynamic_child_module_fx
else:
prepare = torch.quantization.prepare_child_module_fx
@@ -405,7 +406,7 @@ def insert_observer(node, observer, device):


# don't need to insert observer for output in dynamic quantization
if self.is_dynamic_quant:
if activation_is_dynamically_quantized(qconfig):
continue

# inserting observers for output of observed module, or mark the output
@@ -530,13 +531,11 @@ def _run_weight_observers(self, observed):
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False, is_child_module=False):
self.restore_state(model)
# TODO: uncomment after deepcopy is fixed
# if not inplace:
# model = copy.deepcopy(model)
self.is_dynamic_quant = is_dynamic_quant
# run weight observers before inserting quant dequant nodes
# for dynamic quantization
if self.is_dynamic_quant:
self._run_weight_observers(model)
if not inplace:
model = copy.deepcopy(model)
# always run weight observers in the top level forward method
# for dynamically quantized ops
self._run_weight_observers(model)

# move to cpu since we only have quantized cpu kernels
model.eval().cpu()
@@ -637,7 +636,7 @@ def is_quantized(node):
result = self.quantized_graph.node_copy(node, load_non_quantized)
quantized = False
else:
result = obj.convert(self, node, load_arg)
result = obj.convert(self, node, load_arg, debug=debug)
if node.op == 'call_module' and is_observed_standalone_module(self.modules[node.target]):
quantized = self.modules[node.target]._output_is_observed
else:
@@ -653,7 +652,7 @@ def is_quantized(node):
quantized = is_quantized(node.args[0])

# output of dynamic quantization is not quantized
if self.is_dynamic_quant:
if activation_is_dynamically_quantized(qconfig):
quantized = False

if quantized:
@@ -877,7 +876,8 @@ def visit_arg(arg):
for i, node_arg in enumerate(node.args):
if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
is_weight = True
if (not self.is_dynamic_quant) or is_weight:
if qconfig is not None and \
(activation_is_statically_quantized(qconfig) or is_weight):
# overwrite previous quant config
quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
return visit_arg
35 changes: 35 additions & 0 deletions torch/quantization/fx/utils.py
@@ -138,3 +138,38 @@ def get_next_qparams_idx(module, qparams):
qparam_full_path = key + str(idx)
inputs.append(graph.create_node('get_attr', qparam_full_path))
return graph.create_node('call_function', quantize_op, tuple(inputs), {})

def activation_is_dynamically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
dynamically quantized or not
"""
assert qconfig is not None
activation = qconfig.activation()
return activation.dtype in [torch.float32, torch.float16]

def activation_is_statically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
statically quantized or not
"""
assert qconfig is not None
activation = qconfig.activation()
return activation.dtype in [torch.quint8, torch.qint8]

def weight_dtype(qconfig):
assert qconfig is not None
weight = qconfig.weight()
return weight.dtype

def weight_is_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
dynamically quantized or not
"""
return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

def get_linear_prepack_op_for_dtype(dtype):
if dtype == torch.float16:
return torch.ops.quantized.linear_prepack_fp16
elif dtype == torch.qint8:
return torch.ops.quantized.linear_prepack
else:
raise Exception("can't get linear prepack op for dtype:", dtype)
6 changes: 3 additions & 3 deletions torch/testing/_internal/common_quantization.py
@@ -655,16 +655,16 @@ def checkGraphModeFxOp(self, model, inputs, quant_type,
self.assertEqual((result - result_debug).abs().max(), 0,
                 'Expecting debug and non-debug option to produce identical result')

qgraph_to_check = qgraph_debug if debug else qgraph
if print_debug_info:
print()
print('quant type:', quant_type)
print('original graph module:', type(model))
self.printGraphModule(original)
print()
print('quantized graph module:', type(qgraph))
self.printGraphModule(qgraph)
print('quantized graph module:', type(qgraph_to_check))
self.printGraphModule(qgraph_to_check)
print()
qgraph_to_check = qgraph_debug if debug else qgraph
self.checkGraphModuleNodes(
qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)

Expand Down
