Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quant][graphmode][fx] Support quantization for standalone module #44074

Closed
wants to merge 41 commits into from
Closed
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
510a368
[quant][graphmode][fx] Support quantization for custom module
jerryzh168 Sep 2, 2020
3c083a1
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 2, 2020
e671e62
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 3, 2020
7637725
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 8, 2020
f61dc18
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 8, 2020
1007c5c
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 9, 2020
a286d90
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 9, 2020
a9f935e
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 9, 2020
8d02562
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 9, 2020
a0f666c
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 9, 2020
7673461
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
b802e20
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
66f4b1b
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
0a35c1c
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
500f4a1
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
4a44531
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
6c0bc37
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 10, 2020
0e8e655
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 22, 2020
11ac09c
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 22, 2020
62ba7ed
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 23, 2020
9ad9402
Update on "[quant][graphmode][fx] Support quantization for custom mod…
jerryzh168 Sep 23, 2020
7400339
Update on "[quant][graphmode][fx] Support quantization for traceable …
jerryzh168 Sep 23, 2020
e9a6ee5
Update on "[quant][graphmode][fx] Support quantization for traceable …
jerryzh168 Sep 23, 2020
72dc90b
Update on "[quant][graphmode][fx] Support quantization for traceable …
jerryzh168 Sep 24, 2020
f5d1072
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
8ceac95
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
7220244
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
71d4d1e
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
657a796
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
3e7deee
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
dc95134
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
4037602
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
a7398c5
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 25, 2020
d8801b8
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 26, 2020
326034b
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 28, 2020
f451b2b
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 29, 2020
425b7d9
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 29, 2020
5c65719
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 29, 2020
b6b9998
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 29, 2020
6587a4b
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 29, 2020
7dd90f4
Update on "[quant][graphmode][fx] Support quantization for standalone…
jerryzh168 Sep 29, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
92 changes: 92 additions & 0 deletions test/quantization/test_quantize_fx.py
Expand Up @@ -9,6 +9,8 @@
# symbolic trace
from torch.fx import symbolic_trace

from torch.fx.symbolic_trace import Tracer

# graph mode quantization based on fx
from torch.quantization import (
QuantType,
Expand Down Expand Up @@ -320,6 +322,96 @@ def forward(self, x):
m = convert_static_fx(m)
m(dict_input)

def test_standalone_module_class(self):
    """Quantize a model containing a module traced as a leaf ("standalone")
    and check that the result matches an equivalent flat reference model.

    Observers/quant-dequant for the standalone module's input are expected
    in the parent graph; those for its internals live inside the standalone
    module itself.
    """
    class StandaloneModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv(x)

    class CustomTracer(Tracer):
        # Treat StandaloneModule as a leaf so it appears as a single
        # call_module node in the parent graph instead of being inlined.
        def is_leaf_module(self, m, module_qualified_name):
            if isinstance(m, StandaloneModule):
                return True
            return m.__module__.startswith('torch.nn') and \
                not isinstance(m, torch.nn.Sequential)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.standalone = StandaloneModule()

        def forward(self, x):
            return self.standalone(self.conv(x))

    class RefM(torch.nn.Module):
        # Flat equivalent of M: conv2 plays the role of the conv inside
        # the standalone module.
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv2(self.conv1(x))

    inp = torch.randn(1, 1, 1, 1)
    # Instantiate both models and share weights so their outputs are
    # numerically comparable after quantization.
    model = M()
    ref_model = RefM()
    ref_model.conv1.weight = torch.nn.Parameter(model.conv.weight.detach())
    ref_model.conv1.bias = torch.nn.Parameter(model.conv.bias.detach())
    ref_model.conv2.weight = torch.nn.Parameter(model.standalone.conv.weight.detach())
    ref_model.conv2.bias = torch.nn.Parameter(model.standalone.conv.bias.detach())

    qm = CustomTracer().trace(model).eval()
    qconfig_dict = {'': default_qconfig, 'standalone_module_name': ['standalone']}
    # Prepare and calibrate.
    qm = prepare_fx(qm, qconfig_dict)
    qm(inp)
    # Parent graph: observers for the first conv's input and output; the
    # standalone module observes its own internals.
    self.checkGraphModuleNodes(
        qm,
        expected_node_occurrence={
            ns.call_module(torch.quantization.MinMaxObserver): 2
        })
    # Standalone module: a single observer for its conv output.
    self.checkGraphModuleNodes(
        qm.standalone,
        expected_node_occurrence={
            ns.call_module(torch.quantization.MinMaxObserver): 1
        })

    # Convert and verify the quantized graphs.
    qm = convert_fx(qm)
    self.checkGraphModuleNodes(
        qm,
        expected_node_occurrence={
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nnq.Conv2d): 1,
            ns.call_method('dequantize'): 1,
        })
    self.checkGraphModuleNodes(
        qm.standalone,
        expected_node_occurrence={
            # quantization of input happens in parent module
            # quantization of output happens in the quantized conv module
            ns.call_function(torch.quantize_per_tensor): 0,
            # dequantization for output happens in parent module
            ns.call_method('dequantize'): 0,
        })
    res = qm(inp)

    # Quantize the flat reference model the ordinary way and compare.
    qref = symbolic_trace(ref_model).eval()
    qref = prepare_fx(qref, qconfig_dict)
    qref(inp)
    qref = convert_fx(qref)
    self.assertEqual(res, qref(inp))

@skipIfNoFBGEMM
def test_qconfig_none(self):
class M(torch.nn.Module):
Expand Down
18 changes: 18 additions & 0 deletions torch/quantization/fx/quantization_patterns.py
Expand Up @@ -531,6 +531,24 @@ def convert(self, quantizer, node, load_arg, debug=False):
# module attribute like module._QUANTIZED_INPUT_INDEXES
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

class StandaloneModuleQuantizeHandler(QuantizeHandler):
    """Converts an observed standalone module to a quantized standalone
    module by calling the fx child-module convert on it.
    """
    def convert(self, quantizer, node, load_arg, debug=False):
        # A standalone module always shows up as a call_module node.
        assert node.op == 'call_module'
        convert_fn = (torch.quantization.convert_dynamic_child_module_fx
                      if quantizer.is_dynamic_quant
                      else torch.quantization.convert_child_module_fx)
        observed_module = quantizer.modules[node.target]
        quantized_module = convert_fn(observed_module, debug=debug)
        parent_name, name = _parent_name(node.target)
        # Swap the quantized module in on the parent and keep the flat
        # modules dict in sync with it.
        setattr(quantizer.modules[parent_name], name, quantized_module)
        quantizer.modules[node.target] = quantized_module
        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))


# 2. Post Training Dynamic Quantization Patterns
@register_dynamic_quant_pattern(torch.nn.Linear)
Expand Down