Skip to content

Commit

Permalink
[quant][graphmode][fx] Support non_traceable_module/module_class
Browse files Browse the repository at this point in the history
Summary:
Allow the user to specify a list of qualified names for non-traceable submodules,
or a list of types (classes) of non-traceable submodules
See quantize_fx.py for api

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Oct 14, 2020
1 parent 8ce1bc6 commit a881f90
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 14 deletions.
53 changes: 53 additions & 0 deletions test/quantization/test_quantize_fx.py
Expand Up @@ -761,6 +761,59 @@ def forward(self, x):
ref_res = ref_m(data)
self.assertEqual(res, ref_res)

@skipIfNoFBGEMM
def test_non_traceable_module(self):
    """Modules listed under ``non_traceable_module_name`` or
    ``non_traceable_module_class`` in ``prepare_custom_config_dict``
    must be kept as leaf call_module nodes instead of being traced into.
    """
    class NonTraceable(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # iterating/printing dict contents is not symbolically traceable
            for key in x.keys():
                print(x[key])
            return x

    class NonTraceable2(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # data dependent control flow is not traceable
            for item in x:
                print(item)
            return x

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.m1 = NonTraceable()
            self.m2 = NonTraceable2()

        def forward(self, x):
            return self.m2(self.m1(x))

    model = M().eval()
    qconfig_dict = {"": default_qconfig}
    # skip m1 by qualified name and NonTraceable2 by class
    custom_config = {
        "non_traceable_module_name": ["m1"],
        "non_traceable_module_class": [NonTraceable2],
    }
    model = prepare_fx(
        model, qconfig_dict,
        prepare_custom_config_dict=custom_config)

    expected = {
        ns.call_module(NonTraceable): 1,
        ns.call_module(NonTraceable2): 1,
    }
    # make sure these modules are not traced
    self.checkGraphModuleNodes(model, expected_node_occurrence=expected)

class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
"""
Expand Down
40 changes: 26 additions & 14 deletions torch/quantization/quantize_fx.py
Expand Up @@ -24,16 +24,16 @@ def _fuse_fx(graph_module, inplace=False):
return fuser.fuse(graph_module, inplace)

class CustomTracer(Tracer):
    """FX tracer that treats user-specified submodules as leaf nodes.

    Leaf modules are recorded as ``call_module`` nodes instead of being
    traced into, which lets non-traceable, standalone, or custom modules
    survive symbolic tracing intact.

    Args:
        skipped_module_names: list of qualified names (e.g. ``"sub.m1"``)
            of submodules that should not be traced into
        skipped_module_classes: list of module classes whose instances
            should not be traced into
    """
    def __init__(self, skipped_module_names, skipped_module_classes):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes

    def is_leaf_module(self, m, module_qualified_name):
        # A module is a leaf when it is a standard torch.nn module
        # (except containers such as nn.Sequential, which we do want to
        # trace through), or when the user asked to skip it either by
        # qualified name or by class.
        return (m.__module__.startswith('torch.nn') and
                not isinstance(m, torch.nn.Sequential)) or \
            module_qualified_name in self.skipped_module_names or \
            type(m) in self.skipped_module_classes


def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, is_standalone_module=False):
Expand All @@ -50,17 +50,19 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}

skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])

# symbolically trace the model
if is_standalone_module:
# a standalone module is traced before quantizing standalone modules
graph_module = symbolic_trace(model)
else:
standalone_modules = prepare_custom_config_dict.get('standalone_module_name', [])
if not is_standalone_module:
# standalone module and custom module config are applied in top level module
standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', [])
skipped_module_names += standalone_module_names
custom_module_config = prepare_custom_config_dict.get('float_to_observed_custom_module_class', {})
custom_module_classes = list(custom_module_config.keys())
# skipping tracing standalone modules when tracing top level module
tracer = CustomTracer(standalone_modules, custom_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
skipped_module_classes += custom_module_classes
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, inplace)
quantizer = Quantizer()
return quantizer.prepare(
Expand Down Expand Up @@ -156,7 +158,17 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No
# float custom module to observed custom module
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
},
# qualified names of submodules that should be skipped during symbolic tracing
"non_traceable_module_name": [
"non_traceable_module"
],
# module classes whose instances should be skipped during symbolic tracing
"non_traceable_module_class": [
NonTraceableModule
]
}
Expand Down

0 comments on commit a881f90

Please sign in to comment.