[quant][graphmode][fx] Merge all quantization modes
Summary:
This PR merges all quantization modes and will only expose the following top-level functions:
```
prepare_fx
prepare_qat_fx
convert_fx
```
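
For orientation, a minimal usage sketch of the merged entry points. This is not part of the diff: the import path, `MyModel`, and `calibration_data` are assumptions, `prepare_qat_fx` follows the same pattern with a QAT qconfig, and depending on the snapshot the model may need to be symbolically traced before `prepare_fx`.
```
import torch
from torch.quantization import get_default_qconfig
# assumed import location for the newly exposed top-level functions
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = MyModel().eval()                      # MyModel: hypothetical float module
qconfig_dict = {"": get_default_qconfig("fbgemm")}  # global ("") key, as in the tests below

prepared = prepare_fx(float_model, qconfig_dict)    # insert observers according to qconfig_dict
for sample in calibration_data:                     # calibration_data: hypothetical iterable
    prepared(sample)                                # calibrate the observers
quantized = convert_fx(prepared)                    # swap observed ops for quantized ops
```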

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 184888bba530fefefcf0fac48ead6b87c095c6a2
Pull Request resolved: #45292
jerryzh168 committed Sep 25, 2020
1 parent d9d21fe commit d7eb4ca
Showing 5 changed files with 144 additions and 67 deletions.
9 changes: 3 additions & 6 deletions test/quantization/test_quantize_fx.py
@@ -156,11 +156,11 @@ def test_functional_debug(self):
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
node_occurrence = dict()
if weight_prepack_node:
node_occurrence[weight_prepack_node] = 1
node_occurrence[weight_prepack_node] = 0
node_occurrence[quantized_node] = 0
self.checkGraphModeFxOp(
ModuleClass(*module_constructor_inputs),
inputs, quant_type,
expected_node=quantized_node,
expected_node_occurrence=node_occurrence,
debug=True)

@@ -224,10 +224,7 @@ def forward(self, x):
for debug in [True, False]:
node_occurrence = dict()
if weight_prepack_node:
if debug:
node_occurrence[weight_prepack_node] = 1
else:
node_occurrence[weight_prepack_node] = 0
node_occurrence[weight_prepack_node] = 0
m = ModuleClass(*module_constructor_inputs).eval()
m = symbolic_trace(m)
qconfig_dict = {"": float16_dynamic_qconfig}
100 changes: 71 additions & 29 deletions torch/quantization/fx/quantization_patterns.py
@@ -2,6 +2,9 @@
from torch.fx.graph import (
Node,
)
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd

from ..quantization_mappings import (
get_static_quant_module_class,
get_quantized_operator,
@@ -17,6 +20,11 @@
_parent_name,
quantize_node,
get_per_tensor_qparams,
activation_is_dynamically_quantized,
activation_is_statically_quantized,
weight_is_quantized,
weight_dtype,
get_linear_prepack_op_for_dtype,
)

from abc import ABC, abstractmethod
@@ -235,7 +243,7 @@ def convert(self, quantizer, node, load_arg, debug=False):
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
class LinearReLUQuantizeHandler(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
@@ -248,50 +256,76 @@ def __init__(self, quantizer, node):
self.linear = quantizer.modules[self.linear_node.target]

def convert(self, quantizer, node, load_arg, debug=False):
qconfig = quantizer.qconfig_map[node.name]
activation_statically_quantized = activation_is_statically_quantized(qconfig)
# TODO: debug option for linear module
if self.linear_node.op == 'call_module':
# note that relu should already be fused into linear module in the fusion step
assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
# 1. attach output activation post process to linear module
if node.name in quantizer.activation_post_process_map:
# this is the static quantization case
output_activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
output_activation_post_process = None

if output_activation_post_process:
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
float_linear_module = self.linear[1]
else:
float_linear_module = self.linear
float_linear_module.activation_post_process = output_activation_post_process

# 2. select corresponding quantized linear class for the float linear class
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
qlinear = torch.nn.quantized.Linear
qlinear = nnq.Linear if activation_statically_quantized else nnqd.Linear
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
assert activation_statically_quantized, \
'Only static quantization is supported for LinearReLU'
qlinear = torch.nn.intrinsic.quantized.LinearReLU
else:
raise Exception("unhandled linear type:", type(self.linear))
quantized = qlinear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
# activation needs to be quantized for static quantization
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {})
self.linear_node.target,
(load_arg(quantized=activation_statically_quantized)(self.linear_node.args[0]),), {})
elif self.linear_node.op == 'call_function':
if debug:
args = load_arg(quantized=[0, 1])(self.linear_node.args)
quantized_input_idxs = []
if activation_statically_quantized:
quantized_input_idxs.append(0)
if weight_is_quantized(qconfig):
quantized_input_idxs.append(1)
args = load_arg(quantized=quantized_input_idxs)(self.linear_node.args)
args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
linear_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
root_module = quantizer.modules['']
return quantize_node(
root_module,
quantizer.quantized_graph,
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
# TODO: this code can be merged with dynamic linear code
if activation_statically_quantized:
# quantize output for statically quantized linear op
root_module = quantizer.modules['']
return quantize_node(
root_module,
quantizer.quantized_graph,
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
# output for dynamically quantized linear op is not quantized
return linear_out
else: # non-debug option
# linear args
# (x, weight, bias, ...)
args = load_arg(quantized=[0, 1])(self.linear_node.args)
weight_quantized = weight_is_quantized(qconfig)
linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1])

# get other arguments
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
# all args after bias, including bias
other_args = load_arg(quantized=False)(self.linear_node.args[2:])
@@ -303,17 +337,24 @@ def convert(self, quantizer, node, load_arg, debug=False):
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = (weight, bias)
prepack_args = (linear_weight, bias)
prepack_op = get_linear_prepack_op_for_dtype(weight_dtype(qconfig))
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
'call_function', prepack_op, prepack_args, {})
# construct linear input
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qlinear_args = (linear_input, packed_weight, scale, zero_point)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
if activation_statically_quantized:
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qlinear_args = (linear_input, packed_weight, scale, zero_point)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
else:
linear_input = load_arg(quantized=False)(self.linear_node.args[0])
qlinear_args = (linear_input, packed_weight)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_dynamic, qlinear_args, kwargs)

@register_quant_pattern(torch.nn.BatchNorm2d)
@register_quant_pattern(torch.nn.BatchNorm3d)
@@ -537,7 +578,8 @@ class StandaloneModuleQuantizeHandler(QuantizeHandler):
"""
def convert(self, quantizer, node, load_arg, debug=False):
assert node.op == 'call_module'
if quantizer.is_dynamic_quant:
qconfig = quantizer.qconfig_map[node.name]
if activation_is_dynamically_quantized(qconfig):
convert = torch.quantization.convert_dynamic_child_module_fx
else:
convert = torch.quantization.convert_child_module_fx
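For reference, a hedged sketch of the two lowered forms the functional-linear branch above now emits, written as eager op calls rather than graph nodes; the scales and zero points are arbitrary placeholders, and running it requires a backend with quantized CPU kernels (e.g. fbgemm).
```
import torch

weight = torch.randn(4, 8)
bias = torch.zeros(4)
x = torch.randn(2, 8)

# Statically quantized activations: the qint8 weight is prepacked and the
# output qparams (scale, zero_point) are passed to quantized::linear.
qweight = torch.quantize_per_tensor(weight, scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(qweight, bias)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=64, dtype=torch.quint8)
out_static = torch.ops.quantized.linear(qx, packed, 0.2, 0)

# Dynamically quantized activations: the input stays float and no output
# qparams are needed; quantized::linear_dynamic quantizes activations on the fly.
out_dynamic = torch.ops.quantized.linear_dynamic(x, packed)
```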
44 changes: 20 additions & 24 deletions torch/quantization/fx/quantize.py
@@ -31,7 +31,6 @@
from .pattern_utils import (
is_match,
get_quant_patterns,
get_dynamic_quant_patterns,
)

from .standalone_module import (
@@ -44,6 +43,8 @@
from .utils import (
_parent_name,
quantize_node,
activation_is_dynamically_quantized,
activation_is_statically_quantized,
)

from collections import OrderedDict
@@ -307,11 +308,7 @@ def get_qconfig(module_name):
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant, is_child_module):
if not inplace:
model = copy.deepcopy(model)
self.is_dynamic_quant = is_dynamic_quant
if self.is_dynamic_quant:
self.patterns = get_dynamic_quant_patterns()
else:
self.patterns = get_quant_patterns()
self.patterns = get_quant_patterns()

flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
@@ -391,7 +388,7 @@ def insert_observer(node, observer, device):
# observe standalone module
standalone_module = self.modules[node.target]
traced_standalone_module = symbolic_trace(standalone_module)
if self.is_dynamic_quant:
if activation_is_dynamically_quantized(qconfig):
prepare = torch.quantization.prepare_dynamic_child_module_fx
else:
prepare = torch.quantization.prepare_child_module_fx
@@ -404,8 +401,9 @@ def insert_observer(node, observer, device):
self.modules[node.target] = observed_standalone_module


# don't need to insert observer for output in dynamic quantization
if self.is_dynamic_quant:
# don't need to insert observer for output if activation does not
# need to be statically quantized
if not activation_is_statically_quantized(qconfig):
continue

# inserting observers for output of observed module, or mark the output
@@ -509,10 +507,10 @@ def prepare_dynamic(self, model, qconfig_dict, inplace=False, is_child_module=Fa
return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True, is_child_module=is_child_module)

def _run_weight_observers(self, observed):
r''' Extract the subgraph that produces the weight for dynamically quantized
node and run the subgraph to observe the weight.
Note that the observers of dynamically quantized modules are run during
the conversion step.
r''' Extract the subgraph that produces the weight for dynamic quant
or weight only quant node and run the subgraph to observe the weight.
Note that the observers of dynamic quant or weight only quant ops are run during
the convert step.
'''
for node in observed.graph.nodes:
if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
@@ -530,13 +528,11 @@ def _run_weight_observers(self, observed):
def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False, is_child_module=False):
self.restore_state(model)
# TODO: uncomment after deepcopy is fixed
# if not inplace:
# model = copy.deepcopy(model)
self.is_dynamic_quant = is_dynamic_quant
# run weight observers before inserting quant dequant nodes
# for dynamic quantization
if self.is_dynamic_quant:
self._run_weight_observers(model)
if not inplace:
model = copy.deepcopy(model)
# always run weight observers in the top level forward method
# for dynamic quant ops or weight only quant ops
self._run_weight_observers(model)

# move to cpu since we only have quantized cpu kernels
model.eval().cpu()
@@ -637,7 +633,7 @@ def is_quantized(node):
result = self.quantized_graph.node_copy(node, load_non_quantized)
quantized = False
else:
result = obj.convert(self, node, load_arg)
result = obj.convert(self, node, load_arg, debug=debug)
if node.op == 'call_module' and is_observed_standalone_module(self.modules[node.target]):
quantized = self.modules[node.target]._output_is_observed
else:
@@ -652,8 +648,7 @@ def is_quantized(node):
'CopyNode of type ' + node.op + ' is not handled'
quantized = is_quantized(node.args[0])

# output of dynamic quantization is not quantized
if self.is_dynamic_quant:
if not activation_is_statically_quantized(qconfig):
quantized = False

if quantized:
@@ -882,7 +877,8 @@ def visit_arg(arg):
for i, node_arg in enumerate(node.args):
if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
is_weight = True
if (not self.is_dynamic_quant) or is_weight:
if qconfig is not None and \
(activation_is_statically_quantized(qconfig) or is_weight):
# overwrite previous quant config
quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
return visit_arg
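A minimal sketch of the decision the prepare pass now makes per node; this is not the actual pass, it just mirrors the qconfig-driven logic above using the stock qconfigs from torch.quantization.
```
import torch
from torch.quantization import default_qconfig, default_dynamic_qconfig
from torch.quantization.fx.utils import activation_is_statically_quantized

def needs_output_observer(qconfig):
    # Only statically quantized activations get an output observer at prepare
    # time; dynamic / weight-only quantization runs its weight observers at
    # convert time instead (see _run_weight_observers above).
    return qconfig is not None and activation_is_statically_quantized(qconfig)

print(needs_output_observer(default_qconfig))          # expected: True
print(needs_output_observer(default_dynamic_qconfig))  # expected: False
```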
35 changes: 35 additions & 0 deletions torch/quantization/fx/utils.py
@@ -138,3 +138,38 @@ def get_next_qparams_idx(module, qparams):
qparam_full_path = key + str(idx)
inputs.append(graph.create_node('get_attr', qparam_full_path))
return graph.create_node('call_function', quantize_op, tuple(inputs), {})

def activation_is_dynamically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
dynamically quantized or not
"""
assert qconfig is not None
activation = qconfig.activation()
return activation.dtype in [torch.float32, torch.float16]

def activation_is_statically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
statically quantized or not
"""
assert qconfig is not None
activation = qconfig.activation()
return activation.dtype in [torch.quint8, torch.qint8]

def weight_dtype(qconfig):
assert qconfig is not None
weight = qconfig.weight()
return weight.dtype

def weight_is_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
quantized or not
"""
return weight_dtype(qconfig) in [torch.quint8, torch.qint8]

def get_linear_prepack_op_for_dtype(dtype):
if dtype == torch.float16:
return torch.ops.quantized.linear_prepack_fp16
elif dtype == torch.qint8:
return torch.ops.quantized.linear_prepack
else:
raise Exception("can't get linear prepack op for dtype:", dtype)
