[quant][graphmode][fx] Merge all quantization mode (#45292)
Summary:
Pull Request resolved: #45292

This PR merges all quantization modes and will only expose the following top-level functions:
```
prepare_fx
prepare_qat_fx
convert_fx
```
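
As an illustration, here is a minimal sketch of the merged post-training static workflow, written to match how the updated tests below call the new API; the module, qconfig, calibration input, and import paths are assumptions for the example, not part of this change:
```
import torch
from torch.fx import symbolic_trace
# assumed import location for the merged top-level functions
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

m = symbolic_trace(M().eval())        # graph mode starts from a traced, eval-mode module
qconfig_dict = {"": default_qconfig}  # single global qconfig, same shape as in the tests
m = prepare_fx(m, qconfig_dict)       # insert observers
m(torch.randn(1, 1, 4, 4))            # calibration pass
m = convert_fx(m)                     # produce the quantized GraphModule
```
Dynamic quantization goes through the same two calls: pass a dynamic qconfig such as `{'': default_dynamic_qconfig}` to `prepare_fx` and call `convert_fx` without a calibration run, as the updated tests below do.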

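Quantization-aware training presumably follows the same shape through prepare_qat_fx; a hedged sketch, assuming it accepts the same qconfig_dict form and a train-mode traced module (reusing M from the sketch above):
```
from torch.quantization import get_default_qat_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx

qat_model = symbolic_trace(M().train())                   # QAT starts from a train-mode module
qat_qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
qat_model = prepare_qat_fx(qat_model, qat_qconfig_dict)   # insert fake-quant observers
# ... training loop over qat_model goes here ...
quantized = convert_fx(qat_model.eval())                  # same convert_fx entry point
```
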
Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23913105

fbshipit-source-id: 4e335286d6de225839daf51d1df54322d52d68e5
jerryzh168 authored and facebook-github-bot committed Oct 1, 2020
1 parent 3f440d7 commit ffcb098
Showing 7 changed files with 295 additions and 383 deletions.
80 changes: 41 additions & 39 deletions test/quantization/test_quantize_fx.py
@@ -17,10 +17,6 @@
fuse_fx,
prepare_fx,
convert_fx,
prepare_static_fx,
convert_static_fx,
quantize_static_fx,
quantize_dynamic_fx,
prepare_qat_fx,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
@@ -158,11 +154,11 @@ def test_functional_debug(self):
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
node_occurrence = dict()
if weight_prepack_node:
node_occurrence[weight_prepack_node] = 1
node_occurrence[weight_prepack_node] = 0
node_occurrence[quantized_node] = 0
self.checkGraphModeFxOp(
ModuleClass(*module_constructor_inputs),
inputs, quant_type,
expected_node=quantized_node,
expected_node_occurrence=node_occurrence,
debug=True)

@@ -183,7 +179,8 @@ def forward(self, x):
original = symbolic_trace(m)
qconfig = default_dynamic_qconfig
qconfig_dict = {'': qconfig}
quantized = quantize_dynamic_fx(original, qconfig_dict, debug=True)
prepared = prepare_fx(original, qconfig_dict)
quantized = convert_fx(prepared, debug=True)
qparams = (quantized._scale_0, quantized._zero_point_0)
weight_obs = qconfig.weight()
weight_obs(quantized.weight)
@@ -226,14 +223,12 @@ def forward(self, x):
for debug in [True, False]:
node_occurrence = dict()
if weight_prepack_node:
if debug:
node_occurrence[weight_prepack_node] = 1
else:
node_occurrence[weight_prepack_node] = 0
node_occurrence[weight_prepack_node] = 0
m = ModuleClass(*module_constructor_inputs).eval()
m = symbolic_trace(m)
qconfig_dict = {"": float16_dynamic_qconfig}
m = quantize_dynamic_fx(m, qconfig_dict, debug=debug)
m = prepare_fx(m, qconfig_dict)
m = convert_fx(m, debug=debug)
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)


@@ -293,13 +288,19 @@ def __init__(self):
def forward(self, x):
return self.conv(x)

model = symbolic_trace(M().eval())
model = M().eval()
model = symbolic_trace(model)
qconfig_dict = {'': default_qconfig}
non_inplace_model = quantize_static_fx(
model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=False)
inplace_model = model
inplace_model = quantize_static_fx(
inplace_model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True)
prepared = prepare_fx(
model, qconfig_dict, inplace=False)
test_only_eval_fn(model, self.img_data_2d)
non_inplace_model = convert_fx(prepared, inplace=True)

prepared = prepare_fx(
model, qconfig_dict, inplace=True)
test_only_eval_fn(model, self.img_data_2d)
inplace_model = convert_fx(prepared, inplace=True)

non_inplace_res = non_inplace_model(self.img_data_2d[0][0])
inplace_res = inplace_model(self.img_data_2d[0][0])
self.assertEqual(non_inplace_res, inplace_res)
@@ -319,9 +320,9 @@ def forward(self, x):
dict_input = {"input": torch.randn(1, 1, 1, 1)}
m = symbolic_trace(M()).eval()
qconfig_dict = {"": default_qconfig}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
m(dict_input)
m = convert_static_fx(m)
m = convert_fx(m)
m(dict_input)

def test_standalone_module_class(self):
@@ -431,10 +432,10 @@ def forward(self, x):
m = symbolic_trace(m)
qconfig_dict = {"": default_qconfig,
"module_name": [("conv2", None)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
# first conv is quantized, second conv is not quantized
node_list = [
@@ -460,10 +461,10 @@ def forward(self, x):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
# first conv is quantized, second conv is not quantized
node_list = [
@@ -485,10 +486,10 @@ def forward(self, x, y):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data, data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data, data)
# first conv is quantized, second conv is not quantized
node_list = [
@@ -513,10 +514,10 @@ def forward(self, x):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
# first conv is quantized, second conv is not quantized
node_list = [
@@ -558,7 +559,7 @@ def forward(self, x):
"object_type": [(nn.Conv2d, object_type_qconfig)],
"module_name_regex": [("module_conv*", module_name_regex_qconfig)],
"module_name": [("module_conv2", module_name_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
self.assertEqual(m.linear.qconfig, global_qconfig)
self.assertEqual(m.conv.qconfig, object_type_qconfig)
self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig)
@@ -577,10 +578,10 @@ def forward(self, x):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {'': default_qconfig}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
for name, module in m.named_modules():
self.assertFalse(hasattr(module, 'qconfig'),
@@ -632,12 +633,13 @@ def test_save_observer_state_dict(self):
qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
# symbolically trace
model = symbolic_trace(model)
model = prepare_static_fx(model, qconfig_dict)
model = prepare_fx(model, qconfig_dict)

# run it through input
x = torch.randn(5, 5)
model(x)

quant = convert_static_fx(model)
quant = convert_fx(model)

# save state_dict of model
obs_dict = torch.quantization.get_observer_state_dict(model)
@@ -648,12 +650,12 @@ def test_save_observer_state_dict(self):
# Load the stats into new model
model_2 = orig
model_2 = symbolic_trace(model_2)
model_2 = prepare_static_fx(model_2, qconfig_dict)
model_2 = prepare_fx(model_2, qconfig_dict)

loaded_dict = torch.load(b)
torch.quantization.load_observer_state_dict(model_2, loaded_dict)

quant_2 = convert_static_fx(model_2)
quant_2 = convert_fx(model_2)

# Verify that loaded state dict produces same results.
self.assertEqual(quant(x), quant_2(x))
@@ -765,7 +767,7 @@ def is_leaf_module(self, m, module_qualified_name):
m = CustomTracer().trace(original_m).eval()
qconfig_dict = {'': default_qconfig}
# check prepared model
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
# calibration
m(data)
# all activation observers are inserted in the top level module
@@ -775,7 +777,7 @@ def is_leaf_module(self, m, module_qualified_name):
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)

# check converted/quantized model
m = convert_static_fx(m)
m = convert_fx(m)
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
@@ -1347,7 +1349,7 @@ def forward(self, x):
data = torch.rand(1, 3, 10, 10)
# This model is not executable since we just put all ops
# in the same forward
m = M()
m = M().eval()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
@@ -1442,7 +1444,7 @@ def forward(self, x):

# This model is not executable since we just put all ops
# in the same forward
m = M()
m = M().eval()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
12 changes: 0 additions & 12 deletions torch/quantization/fx/pattern_utils.py
@@ -25,18 +25,6 @@ def insert(fn):
def get_quant_patterns():
return QUANTIZATION_PATTERNS

DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
# Register pattern for dynamic quantization
def register_dynamic_quant_pattern(pattern):
def insert(fn):
DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert

# Get patterns for dynamic quantization
def get_dynamic_quant_patterns():
return DYNAMIC_QUANTIZATION_PATTERNS

# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
