fx quant: more typehints, part 2
Summary:

Adds some more typehints throughout quantization/fx/quantize.py to help with readability.

Test Plan:

```
mypy torch/quantization/fx/quantize.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Dec 3, 2020
1 parent cc25808 commit b9ef984
Showing 1 changed file with 46 additions and 32 deletions.
torch/quantization/fx/quantize.py (46 additions, 32 deletions):
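Before the hunks, a note on the overall pattern: every change below either annotates an existing signature, adds an assert to narrow an `Optional`, or attaches a per-line `# type: ignore`. As a minimal sketch of the signature-annotation pattern (the names here are hypothetical, not from quantize.py):

```python
from typing import Any, Dict


class Sketch:
    # before: def lookup(self, qconfig_dict, name):
    def lookup(self, qconfig_dict: Dict[str, Any], name: str) -> Any:
        # with the parameter typed, mypy can check this .get() call
        # and every call site of lookup()
        return qconfig_dict.get(name, None)
```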
```diff
@@ -447,15 +447,18 @@ def __init__(self):
         self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
 
 
-    def _qat_swap_modules(self, root, additional_qat_module_mapping):
+    def _qat_swap_modules(
+            self, root: torch.nn.Module,
+            additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
         all_mappings = get_combined_dict(
             get_default_qat_module_mappings(), additional_qat_module_mapping)
         convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
 
-    def _generate_qconfig_map(self,
-                              root,
-                              input_graph,
-                              qconfig_dict):
+    def _generate_qconfig_map(
+            self,
+            root: torch.nn.Module,
+            input_graph: Graph,
+            qconfig_dict: Any) -> None:
         global_qconfig = qconfig_dict.get('', None)
 
         self.qconfig_map = dict()
```
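The `Dict[Callable, Callable]` annotation describes the float-to-QAT module mapping. An illustrative value of that shape, assuming the standard `torch.nn.qat` modules (this particular dict is an example, not taken from the diff):

```python
from typing import Callable, Dict

import torch.nn as nn
import torch.nn.qat as nnqat

# maps a float module class to its quantization-aware-training counterpart
additional_qat_module_mapping: Dict[Callable, Callable] = {
    nn.Linear: nnqat.Linear,
}
```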
```diff
@@ -495,8 +498,9 @@ def _generate_qconfig_map(self,
                 self.modules[node.target].qconfig = module_qconfig
             self.qconfig_map[node.name] = module_qconfig
 
-    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
-                 is_standalone_module):
+    def _prepare(self, model: GraphModule, qconfig_dict: Any,
+                 prepare_custom_config_dict: Optional[Dict[str, Any]],
+                 is_standalone_module: bool) -> GraphModule:
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
```
```diff
@@ -534,6 +538,7 @@ def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
             "standalone_module_class", None)
         custom_module_classes = get_custom_module_class_keys(
             prepare_custom_config_dict, "float_to_observed_custom_module_class")
+        assert self.patterns is not None
         matches = self._find_matches(
             model.graph, self.modules, self.patterns, standalone_module_names,
             standalone_module_classes, custom_module_classes)
```
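The added assert is presumably there so mypy narrows `self.patterns` from the `Optional[Dict[Pattern, QuantizeHandler]]` declared in the first hunk down to a plain `Dict` before it is passed to `_find_matches`. A standalone illustration of that narrowing (the class is a stand-in, only the attribute pattern mirrors quantize.py):

```python
from typing import Dict, Optional


class Stub:
    def __init__(self) -> None:
        self.patterns: Optional[Dict[str, int]] = None

    def use_patterns(self) -> None:
        # without this assert, mypy flags any use of self.patterns that
        # requires a plain Dict, since the attribute may still be None
        assert self.patterns is not None
        for name, count in self.patterns.items():
            print(name, count)
```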
```diff
@@ -552,7 +557,7 @@ def load_arg(a):
             return map_arg(a, lambda node: env[node.name])
 
         # indexes for the inputs that needs to be observed
-        standalone_module_observed_input_idxs = []
+        standalone_module_observed_input_idxs: List[int] = []
         graph_inputs = []
         for node in model.graph.nodes:
             if node.op == 'placeholder':
```
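Annotating the empty list matters because mypy cannot infer an element type from a bare `[]`; the explicit `List[int]` tells it what the later appends will hold. For example:

```python
from typing import List

# without the annotation, mypy reports "Need type annotation" for an
# empty list literal assigned to a fresh variable
standalone_module_observed_input_idxs: List[int] = []
standalone_module_observed_input_idxs.append(0)  # ok: int matches
```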
```diff
@@ -602,25 +607,28 @@ def load_arg(a):
         model = mark_observed_module(model)
         return model
 
-    def save_state(self, observed):
-        observed._activation_post_process_map = self.activation_post_process_map
-        observed._patterns = self.patterns
-        observed._qconfig_map = self.qconfig_map
+    def save_state(self, observed: GraphModule) -> None:
+        observed._activation_post_process_map = \
+            self.activation_post_process_map  # type: ignore
+        observed._patterns = self.patterns  # type: ignore
+        observed._qconfig_map = self.qconfig_map  # type: ignore
 
-    def restore_state(self, observed):
+    def restore_state(self, observed: GraphModule) -> None:
         assert is_observed_module(observed), \
             'incoming model must be produced by prepare_fx'
-        self.activation_post_process_map = observed._activation_post_process_map
-        self.patterns = observed._patterns
-        self.qconfig_map = observed._qconfig_map
+        self.activation_post_process_map = \
+            observed._activation_post_process_map  # type: ignore
+        self.patterns = observed._patterns  # type: ignore
+        self.qconfig_map = observed._qconfig_map  # type: ignore
 
-    def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None,
-                is_standalone_module=False):
+    def prepare(self, model: GraphModule, qconfig_dict: Any,
+                prepare_custom_config_dict: Dict[str, Any] = None,
+                is_standalone_module: bool = False) -> GraphModule:
         return self._prepare(
             model, qconfig_dict, prepare_custom_config_dict,
             is_standalone_module)
 
-    def _run_weight_observers(self, observed):
+    def _run_weight_observers(self, observed: GraphModule) -> None:
         r''' Extract the subgraph that produces the weight for dynamic quant
         or weight only quant node and run the subgraph to observe the weight.
         Note that the observers of dynamic quant or weight only quant ops are
```
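A likely reason for the ignores in `save_state`/`restore_state`: these underscore attributes are stashed on the `GraphModule` dynamically and are not declared on the class, so mypy flags both the writes and the reads. A stand-in illustration (the attribute name is from the diff, the tiny module is not):

```python
import torch.nn as nn

# Module declares no _qconfig_map attribute, so mypy flags both the
# write and the read; the per-line ignores suppress exactly those errors
observed = nn.Module()
observed._qconfig_map = {}  # type: ignore
print(observed._qconfig_map)  # type: ignore
```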
```diff
@@ -640,8 +648,9 @@ def _run_weight_observers(self, observed):
                         weight_observer_module()
         return
 
-    def _convert(self, model, debug=False, convert_custom_config_dict=None,
-                 is_standalone_module=False):
+    def _convert(self, model: GraphModule, debug: bool = False,
+                 convert_custom_config_dict: Dict[str, Any] = None,
+                 is_standalone_module: bool = False) -> GraphModule:
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
```
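One detail worth noting: `convert_custom_config_dict: Dict[str, Any] = None` relies on mypy's implicit-Optional handling of `None` defaults, which was still the default behavior when this landed; newer mypy versions reject it by default and require the explicit form sketched below.

```python
from typing import Any, Dict, Optional


# explicit-Optional spelling of the same defaulted parameter; modern
# mypy (no_implicit_optional enabled by default) insists on this form
def _convert(convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> None:
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
```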
Expand All @@ -662,6 +671,7 @@ def _convert(self, model, debug=False, convert_custom_config_dict=None,
custom_module_classes = get_custom_module_class_keys(
convert_custom_config_dict,
"observed_to_quantized_custom_module_class")
assert self.patterns is not None
matches = self._find_matches(
model.graph, self.modules, self.patterns,
custom_module_classes=custom_module_classes)
```diff
@@ -905,7 +915,7 @@ def load_arg(a):  # type: ignore
     # Trace back from the weight node util we hit getattr, reconstruct the
     # graph module with the traced nodes and run the graph module to pack the
     # weight. then replace the original chain of ops with the packed weight.
-    def _fold_weight(self, quantized):
+    def _fold_weight(self, quantized: GraphModule) -> GraphModule:
         packed_weights = dict()
         # map from folded node name to the prepacked weight name
         folded_nodes = dict()
```
```diff
@@ -951,19 +961,21 @@ def load_arg(a):
         quantized = GraphModule(quantized_root, folded_graph)
         return quantized
 
-    def convert(self, model, debug=False, convert_custom_config_dict=None,
-                is_standalone_module=False):
+    def convert(self, model: GraphModule, debug: bool = False,
+                convert_custom_config_dict: Dict[str, Any] = None,
+                is_standalone_module: bool = False) -> GraphModule:
         quantized = self._convert(
             model, debug, convert_custom_config_dict, is_standalone_module)
         if not debug:
             quantized = self._fold_weight(quantized)
         return quantized
 
     def _find_matches(
-            self, graph, modules, patterns,
-            standalone_module_names=None,
-            standalone_module_classes=None,
-            custom_module_classes=None) -> Dict[str, MatchResult]:
+            self, graph: Graph, modules: Dict[str, torch.nn.Module],
+            patterns: Dict[Pattern, QuantizeHandler],
+            standalone_module_names: List[str] = None,
+            standalone_module_classes: List[Callable] = None,
+            custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
         """
         Matches the nodes in the input graph to quantization patterns, and
         outputs the information needed to quantize them in future steps.
```
```diff
@@ -1017,7 +1029,7 @@ def record_match(pattern, node, matched):
                         record_match(pattern, node, matched)
                         for n in matched:
                             match_map[n.name] = (
-                                node, matched, pattern, value(self, node),
+                                node, matched, pattern, value(self, node),  # type: ignore
                                 self.qconfig_map[n.name])
                             all_matched.add(n.name)
                         # break after finding the first match
```
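A guess at why `value(self, node)` needs an ignore: `patterns` is annotated `Dict[Pattern, QuantizeHandler]`, i.e. with the handler instance type, while the registry actually stores handler classes, and calling something typed as an instance looks wrong to mypy. Typing the values as `Type[QuantizeHandler]` would express the intent, as in this stand-in:

```python
from typing import Dict, Type


class Handler:
    """Stand-in for QuantizeHandler; only the class-vs-instance typing matters."""

    def __init__(self, quantizer: object, node: object) -> None:
        self.node = node


# typed as classes (Type[Handler]), so calling value(...) type-checks
patterns: Dict[str, Type[Handler]] = {"linear": Handler}
for pattern, value in patterns.items():
    handler = value(None, None)
```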
```diff
@@ -1035,8 +1047,10 @@ def record_match(pattern, node, matched):
 
         def is_standalone_module(node_target):
             assert self.modules is not None
-            return node_target in standalone_module_names or \
-                type(self.modules[node_target]) in standalone_module_classes
+            return (
+                node_target in standalone_module_names or  # type: ignore
+                type(self.modules[node_target]) in standalone_module_classes  # type: ignore
+            )
 
         # add standalone modules to the match
         for node in graph.nodes:
```
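The last hunk likely rewrites the backslash continuation into parentheses because a comment cannot follow a line-continuation backslash (the backslash must be the final character on its line), so there is no other way to attach the two per-line ignores. A small sketch of the shape:

```python
# the parenthesized form gives each line of the expression its own
# slot for a trailing comment such as "# type: ignore"
def check(n: int, names: object = None) -> bool:
    return (
        n < 10 or                  # a per-line comment fits here
        isinstance(names, list)    # ...and here
    )
```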