Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quant][graphmode][fx] Merge all quantization mode #45292

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
ab5a2a9
[quant][graphmode][fx] Merge all quantization mode
jerryzh168 Sep 24, 2020
37fbe18
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
07eb451
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
0c151b5
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
b843407
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
7ff4192
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
ecd6c0d
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
a2b89c2
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
e8d7249
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
cb35222
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 25, 2020
2fc0268
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 26, 2020
405a051
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 28, 2020
d3f8352
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 29, 2020
1063712
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 29, 2020
ca2c71f
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 29, 2020
87f8d7e
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 29, 2020
1e74234
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 29, 2020
1a7d246
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 29, 2020
0625b25
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 30, 2020
8541de3
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 30, 2020
1d191e0
Update on "[quant][graphmode][fx] Merge all quantization mode"
jerryzh168 Sep 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
80 changes: 41 additions & 39 deletions test/quantization/test_quantize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
fuse_fx,
prepare_fx,
convert_fx,
prepare_static_fx,
convert_static_fx,
quantize_static_fx,
quantize_dynamic_fx,
prepare_qat_fx,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
Expand Down Expand Up @@ -158,11 +154,11 @@ def test_functional_debug(self):
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
node_occurrence = dict()
if weight_prepack_node:
node_occurrence[weight_prepack_node] = 1
node_occurrence[weight_prepack_node] = 0
node_occurrence[quantized_node] = 0
self.checkGraphModeFxOp(
ModuleClass(*module_constructor_inputs),
inputs, quant_type,
expected_node=quantized_node,
expected_node_occurrence=node_occurrence,
debug=True)

Expand All @@ -183,7 +179,8 @@ def forward(self, x):
original = symbolic_trace(m)
qconfig = default_dynamic_qconfig
qconfig_dict = {'': qconfig}
quantized = quantize_dynamic_fx(original, qconfig_dict, debug=True)
prepared = prepare_fx(original, qconfig_dict)
quantized = convert_fx(prepared, debug=True)
qparams = (quantized._scale_0, quantized._zero_point_0)
weight_obs = qconfig.weight()
weight_obs(quantized.weight)
Expand Down Expand Up @@ -226,14 +223,12 @@ def forward(self, x):
for debug in [True, False]:
node_occurrence = dict()
if weight_prepack_node:
if debug:
node_occurrence[weight_prepack_node] = 1
else:
node_occurrence[weight_prepack_node] = 0
node_occurrence[weight_prepack_node] = 0
m = ModuleClass(*module_constructor_inputs).eval()
m = symbolic_trace(m)
qconfig_dict = {"": float16_dynamic_qconfig}
m = quantize_dynamic_fx(m, qconfig_dict, debug=debug)
m = prepare_fx(m, qconfig_dict)
m = convert_fx(m, debug=debug)
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)


Expand Down Expand Up @@ -293,13 +288,19 @@ def __init__(self):
def forward(self, x):
return self.conv(x)

model = symbolic_trace(M().eval())
model = M().eval()
model = symbolic_trace(model)
qconfig_dict = {'': default_qconfig}
non_inplace_model = quantize_static_fx(
model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=False)
inplace_model = model
inplace_model = quantize_static_fx(
inplace_model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True)
prepared = prepare_fx(
model, qconfig_dict, inplace=False)
test_only_eval_fn(model, self.img_data_2d)
non_inplace_model = convert_fx(prepared, inplace=True)

prepared = prepare_fx(
model, qconfig_dict, inplace=True)
test_only_eval_fn(model, self.img_data_2d)
inplace_model = convert_fx(prepared, inplace=True)

non_inplace_res = non_inplace_model(self.img_data_2d[0][0])
inplace_res = inplace_model(self.img_data_2d[0][0])
self.assertEqual(non_inplace_res, inplace_res)
Expand All @@ -319,9 +320,9 @@ def forward(self, x):
dict_input = {"input": torch.randn(1, 1, 1, 1)}
m = symbolic_trace(M()).eval()
qconfig_dict = {"": default_qconfig}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
m(dict_input)
m = convert_static_fx(m)
m = convert_fx(m)
m(dict_input)

def test_standalone_module_class(self):
Expand Down Expand Up @@ -431,10 +432,10 @@ def forward(self, x):
m = symbolic_trace(m)
qconfig_dict = {"": default_qconfig,
"module_name": [("conv2", None)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
# first conv is quantized, second conv is not quantized
node_list = [
Expand All @@ -460,10 +461,10 @@ def forward(self, x):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
# first conv is quantized, second conv is not quantized
node_list = [
Expand All @@ -485,10 +486,10 @@ def forward(self, x, y):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data, data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data, data)
# first conv is quantized, second conv is not quantized
node_list = [
Expand All @@ -513,10 +514,10 @@ def forward(self, x):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
# first conv is quantized, second conv is not quantized
node_list = [
Expand Down Expand Up @@ -558,7 +559,7 @@ def forward(self, x):
"object_type": [(nn.Conv2d, object_type_qconfig)],
"module_name_regex": [("module_conv*", module_name_regex_qconfig)],
"module_name": [("module_conv2", module_name_qconfig)]}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
self.assertEqual(m.linear.qconfig, global_qconfig)
self.assertEqual(m.conv.qconfig, object_type_qconfig)
self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig)
Expand All @@ -577,10 +578,10 @@ def forward(self, x):
m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {'': default_qconfig}
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
m(data)
m = convert_static_fx(m)
m = convert_fx(m)
m(data)
for name, module in m.named_modules():
self.assertFalse(hasattr(module, 'qconfig'),
Expand Down Expand Up @@ -632,12 +633,13 @@ def test_save_observer_state_dict(self):
qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
# symbolically trace
model = symbolic_trace(model)
model = prepare_static_fx(model, qconfig_dict)
model = prepare_fx(model, qconfig_dict)

# run it through input
x = torch.randn(5, 5)
model(x)

quant = convert_static_fx(model)
quant = convert_fx(model)

# save state_dict of model
obs_dict = torch.quantization.get_observer_state_dict(model)
Expand All @@ -648,12 +650,12 @@ def test_save_observer_state_dict(self):
# Load the stats into new model
model_2 = orig
model_2 = symbolic_trace(model_2)
model_2 = prepare_static_fx(model_2, qconfig_dict)
model_2 = prepare_fx(model_2, qconfig_dict)

loaded_dict = torch.load(b)
torch.quantization.load_observer_state_dict(model_2, loaded_dict)

quant_2 = convert_static_fx(model_2)
quant_2 = convert_fx(model_2)

# Verify that loaded state dict produces same results.
self.assertEqual(quant(x), quant_2(x))
Expand Down Expand Up @@ -765,7 +767,7 @@ def is_leaf_module(self, m, module_qualified_name):
m = CustomTracer().trace(original_m).eval()
qconfig_dict = {'': default_qconfig}
# check prepared model
m = prepare_static_fx(m, qconfig_dict)
m = prepare_fx(m, qconfig_dict)
# calibration
m(data)
# all activation observers are inserted in the top level module
Expand All @@ -775,7 +777,7 @@ def is_leaf_module(self, m, module_qualified_name):
self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)

# check converted/quantized model
m = convert_static_fx(m)
m = convert_fx(m)
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_module(nnq.Conv2d) : 1,
Expand Down Expand Up @@ -1347,7 +1349,7 @@ def forward(self, x):
data = torch.rand(1, 3, 10, 10)
# This model is not executable since we just put all ops
# in the same forward
m = M()
m = M().eval()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
Expand Down Expand Up @@ -1442,7 +1444,7 @@ def forward(self, x):

# This model is not executable since we just put all ops
# in the same forward
m = M()
m = M().eval()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
Expand Down
12 changes: 0 additions & 12 deletions torch/quantization/fx/pattern_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,6 @@ def insert(fn):
def get_quant_patterns():
return QUANTIZATION_PATTERNS

DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
# Register pattern for dynamic quantization
def register_dynamic_quant_pattern(pattern):
def insert(fn):
DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert

# Get patterns for dynamic quantization
def get_dynamic_quant_patterns():
return DYNAMIC_QUANTIZATION_PATTERNS

# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
Expand Down