[quant][graphmode][fx] custom_module support static/dynamic/weight_only quant

Summary:
Previously we only supported static quant; this PR adds support for the other quantization types (dynamic and weight-only).

Note that qat is orthogonal to these quant types; this refers to the convert step, where we
convert the observed module to a quantized module.

For qat, the user will provide a CustomModule -> FakeQuantizedCustomModule mapping in prepare_custom_config_dict
and a FakeQuantizedCustomModule -> static/dynamic/weight_only quantized CustomModule mapping in convert_custom_config_dict.
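
For reference, here is a minimal sketch of the new quant-type-keyed config format, mirroring the test added in this PR; the module classes below are stand-ins for the user's own classes, and the two dicts are passed to prepare_fx and convert_fx as in the test.

import torch

# Stand-in custom module classes; see the test diff below for full definitions.
class CustomModule(torch.nn.Module): ...
class ObservedCustomModule(torch.nn.Module): ...
class StaticQuantCustomModule(torch.nn.Module): ...
class DynamicQuantCustomModule(torch.nn.Module): ...

# Static quant: observe the custom module in prepare, then swap the observed
# class for the quantized class in convert. For qat the prepare mapping would
# instead be CustomModule -> FakeQuantizedCustomModule, as described above.
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        "static": {CustomModule: ObservedCustomModule}
    }
}

convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "static": {ObservedCustomModule: StaticQuantCustomModule},
        # dynamic (and weight_only) custom modules are swapped directly from
        # the float class, as in the dynamic branch of the test below
        "dynamic": {CustomModule: DynamicQuantCustomModule},
    }
}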

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 849441a7a8eec52c6510889df3bc1f8597a84546
Pull Request resolved: #46786
jerryzh168 committed Oct 27, 2020
1 parent 7731370 commit 05d4721
Showing 7 changed files with 199 additions and 69 deletions.
123 changes: 74 additions & 49 deletions test/quantization/test_quantize_fx.py
@@ -9,6 +9,7 @@
# graph mode quantization based on fx
from torch.quantization import (
QuantType,
quant_type_to_str,
prepare_fx,
convert_fx,
prepare_qat_fx,
@@ -632,104 +633,126 @@ def test_custom_module_class(self):
class CustomModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
return self.conv(x)
return self.linear(x)

class ObservedCustomModule(torch.nn.Module):
def __init__(self, conv):
def __init__(self, linear):
super().__init__()
self.conv = conv
self.linear = linear

def forward(self, x):
return self.conv(x)
return self.linear(x)

@classmethod
def from_float(cls, float_module):
assert hasattr(float_module, 'qconfig')
observed = cls(float_module.conv)
observed = cls(float_module.linear)
observed.qconfig = float_module.qconfig
return observed

class QuantizedCustomModule(torch.nn.Module):
def __init__(self, conv):
class StaticQuantCustomModule(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.conv = conv
self.linear = linear

def forward(self, x):
return self.conv(x)
return self.linear(x)

@classmethod
def from_observed(cls, observed_module):
assert hasattr(observed_module, 'qconfig')
assert hasattr(observed_module, 'activation_post_process')
observed_module.conv.activation_post_process = \
observed_module.linear.activation_post_process = \
observed_module.activation_post_process
quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
quantized = cls(nnq.Linear.from_float(observed_module.linear))
return quantized

class DynamicallyQuantizedCustomModule(torch.nn.Module):
def __init__(self, conv):
class DynamicQuantCustomModule(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.conv = conv
self.linear = linear

def forward(self, x):
return self.conv(x)
return self.linear(x)

@classmethod
def from_observed(cls, observed_module):
assert hasattr(observed_module, 'qconfig')
assert hasattr(observed_module, 'activation_post_process')
quantized = cls(nnqd.Conv2d.from_float(observed_module.conv))
quantized = cls(nnqd.Linear.from_float(observed_module.linear))
return quantized

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.linear = torch.nn.Linear(3, 3)
self.custom = CustomModule()

def forward(self, x):
x = self.conv(x)
x = self.linear(x)
x = self.custom(x)
return x

class RefM(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
self.conv2 = torch.nn.Conv2d(1, 1, 1)
self.linear1 = torch.nn.Linear(3, 3)
self.linear2 = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.linear1(x)
x = self.linear2(x)
return x

data = torch.randn(1, 1, 1, 1)
data = torch.randn(3, 3)
# instantiate M and RefM and align the parameters
original_m = M().eval()
original_ref_m = RefM().eval()
original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach())
original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach())
original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach())
original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach())

test_configs = {
"static": (default_qconfig, StaticQuantCustomModule, 3),
"dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0)
}

# TODO: add other quant types after mixed mode support
for quant_type in [QuantType.STATIC]:
qconfig_dict = {
"": default_qconfig,
}
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
for quant_type in [QuantType.DYNAMIC]:
key = quant_type_to_str(quant_type)
qconfig, quantized_module_class, num_observers = test_configs[key]
qconfig_dict = {"": qconfig}
if key == "static":
prepare_custom_config_dict = {
"float_to_observed_custom_module_class": {
"static": {
CustomModule: ObservedCustomModule
}
}
}
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"static": {
ObservedCustomModule: quantized_module_class
}
}
}
}
else:
prepare_custom_config_dict = {
"non_traceable_module_class": [
CustomModule
]
}
convert_custom_config_dict = {
"observed_to_quantized_custom_module_class": {
"dynamic": {
CustomModule: quantized_module_class
}
}
}

# check prepared model
m = prepare_fx(
original_m,
@@ -739,20 +762,22 @@ def forward(self, x):
m(data)
# all activation observers are inserted in the top level module
count_check = {
ns.call_module(torch.quantization.MinMaxObserver): 3
ns.call_module(torch.quantization.MinMaxObserver): num_observers
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)

# check converted/quantized model
m = convert_fx(
m,
convert_custom_config_dict=convert_custom_config_dict)
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
if quant_type == QuantType.STATIC:
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Linear) : 1,
ns.call_method('dequantize') : 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
self.assertEqual(type(m.custom), quantized_module_class)
res = m(data)

# quantize the reference model
2 changes: 1 addition & 1 deletion torch/quantization/__init__.py
@@ -28,7 +28,7 @@ def default_eval_fn(model, calib_data):
# Top level API for graph mode quantization on GraphModule(torch.fx)
'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
'QuantType', # quantization type
'QuantType', 'quant_type_to_str', # quantization type
# custom module APIs
'get_default_static_quant_module_mappings', 'get_static_quant_module_class',
'get_default_dynamic_quant_module_mappings',
16 changes: 12 additions & 4 deletions torch/quantization/fx/quantization_patterns.py
@@ -16,6 +16,7 @@
_parent_name,
quantize_node,
get_per_tensor_qparams,
get_swapped_custom_module_class,
activation_is_statically_quantized,
weight_is_quantized,
weight_dtype,
@@ -176,6 +177,12 @@ def __init__(self, quantizer, node):

def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
# TODO: debug option for conv module
qconfig = quantizer.qconfig_map[node.name]
activation_statically_quantized = activation_is_statically_quantized(qconfig)
# only static quantization (for both ptq and qat) is supported for conv
if not activation_statically_quantized:
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

if self.conv_node.op == 'call_module':
# note that relu should already be fused into conv module in the fusion step
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
@@ -587,13 +594,14 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_
assert convert_custom_config_dict is not None
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
assert custom_module_class_mapping is not None
qconfig = quantizer.qconfig_map[node.name]
observed_custom_module = quantizer.modules[node.target]
if node.name in quantizer.activation_post_process_map:
if activation_is_statically_quantized(qconfig):
assert node.name in quantizer.activation_post_process_map
observed_custom_module.activation_post_process = \
quantizer.activation_post_process_map[node.name]
quantized_custom_module_class = custom_module_class_mapping.get(type(observed_custom_module), None)
assert quantized_custom_module_class is not None, "did not found quantized custom module for:" + \
str(type(observed_custom_module))
quantized_custom_module_class = get_swapped_custom_module_class(
observed_custom_module, custom_module_class_mapping, qconfig)
quantized_custom_module = \
quantized_custom_module_class.from_observed(observed_custom_module)
parent_name, name = _parent_name(node.target)
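Note: the quantization_patterns.py hunk above and the quantize.py hunks below route custom module swapping through two helpers imported from fx/utils.py, get_custom_module_class_keys and get_swapped_custom_module_class, whose bodies are not part of this diff. The sketch below is inferred from the call sites only and is not the actual utils.py code; in particular, _quant_type_from_qconfig is an assumed helper for deriving the quant type from a qconfig.

import torch
from torch.quantization import QuantType, quant_type_to_str

def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key):
    # Flatten the per-quant-type sub-dicts into a plain list of custom module
    # classes, which is what _find_matches consumes.
    custom_module_classes = set()
    custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {})
    for quant_mode_mapping in custom_module_mapping.values():
        custom_module_classes |= set(quant_mode_mapping.keys())
    return list(custom_module_classes)

def _quant_type_from_qconfig(qconfig):
    # Assumed heuristic: an activation observer with a quantized dtype means
    # static quant, anything else is treated as dynamic; the real logic may
    # also distinguish weight_only and qat qconfigs.
    activation = qconfig.activation()
    if getattr(activation, "dtype", torch.float) in (torch.quint8, torch.qint8):
        return QuantType.STATIC
    return QuantType.DYNAMIC

def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    # Look up the swapped (observed or quantized) class, keyed first by the
    # quant type string ("static"/"dynamic"/"weight_only") and then by the
    # module's current class.
    quant_type_str = quant_type_to_str(_quant_type_from_qconfig(qconfig))
    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
    assert type(custom_module) in class_mapping, \
        "did not find a swapped class for {} under quant type {}".format(
            type(custom_module), quant_type_str)
    return class_mapping[type(custom_module)]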
22 changes: 13 additions & 9 deletions torch/quantization/fx/quantize.py
@@ -36,6 +36,8 @@
from .utils import (
_parent_name,
quantize_node,
get_custom_module_class_keys,
get_swapped_custom_module_class,
activation_is_statically_quantized,
)

@@ -347,9 +349,9 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_

# match the patterns that will get quantized
standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None)
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", None)
custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class")
matches = self._find_matches(
model.graph, self.modules, self.patterns, standalone_module_names, custom_module_class_mapping)
model.graph, self.modules, self.patterns, standalone_module_names, custom_module_classes)

# find _inputs_ to matched nodes that are not quantized, these
# have to be quantized, which requires measuring stats,
@@ -403,8 +405,9 @@ def insert_observer(node, observer, device):

if isinstance(obj, CustomModuleQuantizeHandler):
custom_module = self.modules[node.target]
custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
observed_custom_module_class = \
custom_module_class_mapping[type(custom_module)]
get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
observed_custom_module = \
observed_custom_module_class.from_float(custom_module)
parent_name, name = _parent_name(node.target)
@@ -569,10 +572,11 @@ def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict
model.eval().cpu()
self.modules = dict(model.named_modules())

custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
custom_module_classes = get_custom_module_class_keys(
convert_custom_config_dict, "observed_to_quantized_custom_module_class")
matches = self._find_matches(
model.graph, self.modules, self.patterns,
custom_module_class_mapping=custom_module_class_mapping)
custom_module_classes=custom_module_classes)

quants = self._find_quants(model.graph, matches)

@@ -818,7 +822,7 @@ def convert(self, model, inplace=False, debug=False, convert_custom_config_dict=

def _find_matches(
self, graph, modules, patterns,
standalone_module_names=None, custom_module_class_mapping=None):
standalone_module_names=None, custom_module_classes=None):
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.
@@ -839,8 +843,8 @@ def _find_matches(
...
}
"""
if custom_module_class_mapping is None:
custom_module_class_mapping = {}
if custom_module_classes is None:
custom_module_classes = []

match_map = {}
all_matched = set()
@@ -870,7 +874,7 @@ def record_match(pattern, node, matched):
# add custom module instances to the match result
for node in graph.nodes:
if node.op == 'call_module' and \
type(self.modules[node.target]) in custom_module_class_mapping:
type(self.modules[node.target]) in custom_module_classes:
custom_module_qconfig = self.qconfig_map[node.name]
match_map[node.name] = (
node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig)
