[quant] Attach qconfig to all modules (#42576)
Summary:
Pull Request resolved: #42576

Previously we had a qconfig propagation list and only attached qconfig to modules in that list. This works when everything is quantized in the form of modules, but now that we are expanding quantization to functional/torch ops, we need to attach qconfig to all modules.
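
To illustrate what this means in practice, here is a minimal sketch using the eager-mode torch.quantization API; CustomBlock and Model are invented examples, not part of this commit:

import torch
import torch.nn as nn
from torch.quantization import default_qconfig, prepare

class CustomBlock(nn.Module):
    # Hypothetical module type that is NOT in DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST
    def forward(self, x):
        return torch.relu(x)  # functional op that may later need the qconfig

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)    # whitelisted type -> gets an observer
        self.custom = CustomBlock()  # non-whitelisted type -> qconfig only

    def forward(self, x):
        return self.custom(self.fc(x))

m = Model()
m.qconfig = default_qconfig
m = prepare(m)

# After this change, qconfig is attached to every submodule ...
assert hasattr(m.custom, 'qconfig')
# ... but observers are still inserted only for types in the propagation list.
assert hasattr(m.fc, 'activation_post_process')
assert not hasattr(m.custom, 'activation_post_process')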

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D22939453

fbshipit-source-id: 7d6a1f73ff9bfe461b3afc75aa266fcc8f7db517
jerryzh168 authored and facebook-github-bot committed Aug 12, 2020
1 parent e845b0a commit ac93d45
Showing 3 changed files with 22 additions and 50 deletions.
14 changes: 0 additions & 14 deletions test/quantization/test_workflow_module.py
@@ -39,7 +39,6 @@
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
ModelWithNoQconfigPropagation,
AnnotatedSingleLayerLinearModel,
test_only_eval_fn,
)
@@ -391,18 +390,6 @@ def test_observer_scriptable(self):
loaded = torch.jit.load(buf)
self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

# TODO: move this to test_quantize.py
def test_no_qconfig_propagation(self):
model = ModelWithNoQconfigPropagation()
model.qconfig = torch.quantization.default_qconfig

model = prepare(model)
self.assertTrue(hasattr(model.fc1, 'qconfig'),
"QConfig is expected to propagate")
self.assertFalse(hasattr(model.no_quant_module, 'qconfig'),
"QConfig is expected to NOT propagate")


# HistogramObserver that works like it does on master
class _ReferenceHistogramObserver(HistogramObserver):
def __init__(self, *args, **kwargs):
@@ -541,7 +528,6 @@ def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
new_max = self.min_val + bin_width * (end_bin + 1)
return new_min, new_max


class TestRecordHistogramObserver(QuantizationTestCase):
# TODO: move this to quantize.py
def test_record_observer(self):
21 changes: 13 additions & 8 deletions torch/quantization/quantize.py
@@ -42,8 +42,7 @@ def _propagate_qconfig_helper(module, qconfig_dict, white_list=None,
module_qconfig = qconfig_dict.get(prefix, module_qconfig)
module_qconfig = getattr(module, 'qconfig', module_qconfig)

if type(module) in white_list:
module.qconfig = module_qconfig
module.qconfig = module_qconfig
for name, child in module.named_children():
module_prefix = prefix + '.' + name if prefix else name
_propagate_qconfig_helper(child, qconfig_dict, white_list,
@@ -85,7 +84,7 @@ def register_activation_post_process_hook(module):
'Expect activation_post_process attribut already attached to the module'
return module.register_forward_hook(_observer_forward_hook)

def add_observer_(module, non_leaf_module_list=None, device=None, prehook=None):
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, prehook=None):
r"""Add observer for the leaf child of the module.
This function insert observer module to all leaf child module that
@@ -99,6 +98,8 @@ def add_observer_(module, non_leaf_module_list=None, device=None, prehook=None):
Return:
None, module is modified inplace with added observer modules and forward_hooks
"""
if qconfig_propagation_list is None:
qconfig_propagation_list = DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST
# respect device affinity when adding observers
if device is None:
devices = get_unique_devices_(module)
@@ -125,12 +126,13 @@ def add_observer_(module, non_leaf_module_list=None, device=None, prehook=None):
child.add_module('activation_pre_process', prehook())
child.register_forward_pre_hook(_observer_forward_pre_hook)
else:
add_observer_(child, non_leaf_module_list, device, prehook)
add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, prehook)

# Insert observers only for leaf nodes, note that this observer is for
# the output of the module, for input QuantStub will observe them
if hasattr(module, 'qconfig') and module.qconfig is not None and \
len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential):
len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
and type(module) in qconfig_propagation_list:
# observer and hook will be gone after we swap the module
activation = module.qconfig.activation()
if device is not None:
@@ -172,7 +174,7 @@ def add_quant_dequant(module):
module._modules[name] = add_quant_dequant(child)
return module

def prepare(model, inplace=False, white_list=DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST,
def prepare(model, inplace=False, white_list=None,
observer_non_leaf_module_list=None, prehook=None):
r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
@@ -191,15 +193,18 @@ def prepare(model, inplace=False, white_list=DEFAULT_QCONFIG_PROPAGATE_WHITE_LIS
"""
if not inplace:
model = copy.deepcopy(model)
propagate_qconfig_(model, qconfig_dict=None, white_list=white_list)
propagate_qconfig_list = white_list
if propagate_qconfig_list is None:
propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST
propagate_qconfig_(model, qconfig_dict=None)

# sanity check common API misusage
if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
warnings.warn("None of the submodule got qconfig applied. Make sure you "
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")

add_observer_(model, observer_non_leaf_module_list, prehook=prehook)
add_observer_(model, propagate_qconfig_list, observer_non_leaf_module_list, prehook=prehook)
return model

def _remove_qconfig(module):
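
Taken together, the quantize.py changes above split the two concerns: propagate_qconfig_ now attaches qconfig to every submodule unconditionally, while add_observer_ consults the propagation list before inserting observers. A simplified sketch of the resulting flow (an illustration only; prepare_flow_sketch is a made-up name, and the real code also handles qconfig_dict, device affinity, prehooks, and forward hooks):

import torch
from torch.quantization.default_mappings import DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST

def prepare_flow_sketch(model, white_list=None):
    propagation_list = white_list if white_list is not None else DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST

    # Step 1 (propagate_qconfig_): qconfig is pushed to every submodule,
    # regardless of its type. (The real helper inherits per-parent qconfigs;
    # this sketch just reuses the root qconfig.)
    root_qconfig = getattr(model, 'qconfig', None)
    for m in model.modules():
        if not hasattr(m, 'qconfig'):
            m.qconfig = root_qconfig

    # Step 2 (add_observer_): observers are still attached only to leaf
    # modules whose type is in the propagation list.
    leaf_modules = [m for m in model.modules()
                    if getattr(m, 'qconfig', None) is not None
                    and len(m._modules) == 0
                    and not isinstance(m, torch.nn.Sequential)
                    and type(m) in propagation_list]
    for m in leaf_modules:
        m.add_module('activation_post_process', m.qconfig.activation())
    return model
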
37 changes: 9 additions & 28 deletions torch/testing/_internal/common_quantization.py
@@ -18,7 +18,10 @@
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit
from torch.quantization.default_mappings import DEFAULT_DYNAMIC_MODULE_MAPPING
from torch.quantization.default_mappings import (
DEFAULT_DYNAMIC_MODULE_MAPPING,
DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST,
)
import unittest
from torch.testing import FileCheck

@@ -182,12 +185,15 @@ def checkHasPrepModules(self, module):
self.assertTrue(hasattr(module, 'quant'))
self.assertTrue(hasattr(module, 'dequant'))

def checkObservers(self, module):
def checkObservers(self, module, propagate_qconfig_list=None):
r"""Checks the module or module's leaf descendants
have observers in preperation for quantization
"""
if propagate_qconfig_list is None:
propagate_qconfig_list = DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST
if hasattr(module, 'qconfig') and module.qconfig is not None and \
len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential):
len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
and type(module) in propagate_qconfig_list:
self.assertTrue(hasattr(module, 'activation_post_process'),
'module: ' + str(type(module)) + ' do not have observer')
for child in module.children():
@@ -998,28 +1004,3 @@ def forward(self, x):
out = out.view(-1, 3 * 2 * 2)
out = self.fc(out)
return out

"""Model to make sure that the observers are not inserted into custom modules.
"""
class ModelWithNoQconfigPropagation(nn.Module):
class ListOutModule(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
# returns a list of tensors, not supported by observers
return [x]

def __init__(self):
super().__init__()
self.fc1 = nn.Linear(5, 5).to(dtype=torch.float)
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.no_quant_module = self.ListOutModule()

def forward(self, x):
x = self.quant(x)
x = self.fc1(x)
x = self.dequant(x)
x = self.no_quant_module(x)
return x
