Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fx quant: clean up functions in _generate_qconfig_map #48772

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
92 changes: 46 additions & 46 deletions torch/quantization/fx/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,44 @@ def _convert_to_ordered_dict(key, qconfig_dict):
_convert_to_ordered_dict('module_name_regex', qconfig_dict)
_convert_to_ordered_dict('module_name', qconfig_dict)

def get_module_type_qconfig(qconfig_dict, module_type, fallback_qconfig):
return qconfig_dict['object_type'].get(
module_type, fallback_qconfig)

def get_function_qconfig(qconfig_dict, function, fallback_qconfig):
return qconfig_dict['object_type'].get(function, fallback_qconfig)

def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig):
for regex_pattern, qconfig in \
qconfig_dict['module_name_regex'].items():
if re.match(regex_pattern, module_name):
# first match wins
return qconfig
return fallback_qconfig

def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig):
if module_name == '':
# module name qconfig not found
return fallback_qconfig
if module_name in qconfig_dict['module_name']:
return qconfig_dict['module_name'][module_name]
else:
parent, _ = _parent_name(module_name)
return get_module_name_qconfig(qconfig_dict, parent, fallback_qconfig)

# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
assert modules is not None
module_type_qconfig = get_module_type_qconfig(
qconfig_dict, type(modules[module_name]), global_qconfig)
module_name_regex_qconfig = get_module_name_regex_qconfig(
qconfig_dict, module_name, module_type_qconfig)
module_name_qconfig = get_module_name_qconfig(
qconfig_dict, module_name, module_name_regex_qconfig)
return module_name_qconfig

# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
torch.nn.functional.conv2d : [1],
Expand Down Expand Up @@ -262,58 +300,18 @@ def _generate_qconfig_map(self,
qconfig_dict):
global_qconfig = qconfig_dict.get('', None)

def get_module_type_qconfig(
module_type, fallback_qconfig=global_qconfig):
return qconfig_dict['object_type'].get(
module_type, fallback_qconfig)

def get_function_qconfig(
function, fallback_qconfig=global_qconfig):
return qconfig_dict['object_type'].get(function, fallback_qconfig)

def get_module_name_regex_qconfig(
module_name, fallback_qconfig=global_qconfig):
for regex_pattern, qconfig in \
qconfig_dict['module_name_regex'].items():
if re.match(regex_pattern, module_name):
# first match wins
return qconfig
return fallback_qconfig

def get_module_name_qconfig(
module_name, fallback_qconfig=global_qconfig):
if module_name == '':
# module name qconfig not found
return fallback_qconfig
if module_name in qconfig_dict['module_name']:
return qconfig_dict['module_name'][module_name]
else:
parent, _ = _parent_name(module_name)
return get_module_name_qconfig(parent, fallback_qconfig)

# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
def get_qconfig(module_name):
assert self.modules is not None
module_type_qconfig = \
get_module_type_qconfig(type(self.modules[module_name]))
module_name_regex_qconfig = \
get_module_name_regex_qconfig(module_name, module_type_qconfig)
module_name_qconfig = \
get_module_name_qconfig(module_name, module_name_regex_qconfig)
return module_name_qconfig

self.qconfig_map = dict()
for node in input_graph.nodes:
if node.op == 'get_attr':
module_name, _ = _parent_name(node.target)
self.qconfig_map[node.name] = get_qconfig(module_name)
self.qconfig_map[node.name] = get_qconfig(
self.modules, qconfig_dict, module_name, global_qconfig)
elif node.op == 'call_function':
# precedence: [TODO] module_name_qconfig (need scope support
# from fx)
# > function_qconfig > global_qconfig
function_qconfig = get_function_qconfig(node.target)
function_qconfig = get_function_qconfig(
qconfig_dict, node.target, global_qconfig)
self.qconfig_map[node.name] = function_qconfig
elif node.op == 'call_method':
self_obj = node.args[0]
Expand All @@ -326,10 +324,12 @@ def get_qconfig(module_name):
warnings.warn(
"Scope info is not yet supported, taking default " +
"qconfig for value {}".format(node.name))
qconfig = get_qconfig('')
qconfig = get_qconfig(
self.modules, qconfig_dict, '', global_qconfig)
self.qconfig_map[node.name] = qconfig
elif node.op == 'call_module':
module_qconfig = get_qconfig(node.target)
module_qconfig = get_qconfig(
self.modules, qconfig_dict, node.target, global_qconfig)
# regex is not supported eager mode propagate_qconfig_, we'll
# need to set the qconfig explicitly here in case regex
# is used
Expand Down