quant: add type annotations on quantization.fx.Quantizer matches
Summary:

As titled, continuing to incrementally type quantization.fx.Quantizer.

Test Plan:

```
mypy torch/quantization/
python test/test_quantization.py TestQuantizeFx
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 020fe46b33bb9c54e9c0f9bbd20d1972f9dbb8e8
Pull Request resolved: #48350
vkuzo committed Nov 21, 2020
1 parent 7e346ba commit e6ebe99
Showing 1 changed file with 8 additions and 3 deletions.
torch/quantization/fx/quantize.py (8 additions, 3 deletions)

```diff
@@ -57,11 +57,13 @@
 import warnings
 import re
 
-from typing import Optional, Dict, Any, List, Union
+from typing import Optional, Dict, Any, List, Union, Tuple
 
 # Define helper types
 
 QConfigAny = Union[torch.quantization.QConfig, torch.quantization.QConfigDynamic]
+MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
+                    QConfigAny]
 
 # ------------------------
 # Helper Functions
@@ -449,6 +451,7 @@ def insert_observer_for_output_of_the_node(
     if activation_is_statically_quantized(qconfig):
         if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and model.training:
             # we only insert fake quantize module in qat
+            assert pattern is not None
             activation_post_process_ctr = \
                 get_default_output_activation_post_process_map().get(pattern, None)
             assert activation_post_process_ctr is not None, \
@@ -475,6 +478,7 @@ def is_observed(input_arg):
                 observed_node_names_set.add(node.name)
             elif ((isinstance(quantize_handler, Add) or isinstance(quantize_handler, Mul)) and
                   quantize_handler.num_node_args == 1):
+                assert matched_nodes is not None
                 input_node = matched_nodes[-1]  # first node in the sequence
 
                 def input_is_observed(arg):
@@ -766,6 +770,7 @@ def insert_quantize_node(node):
                 result = self.quantized_graph.node_copy(node, load_non_quantized)
                 quantized = False
             else:
+                assert obj is not None
                 result = obj.convert(self, node, load_arg, debug=debug, convert_custom_config_dict=convert_custom_config_dict)
                 quantized = is_output_quantized(node)
 
@@ -869,7 +874,7 @@ def _find_matches(
             self, graph, modules, patterns,
             standalone_module_names=None,
             standalone_module_classes=None,
-            custom_module_classes=None):
+            custom_module_classes=None) -> Dict[str, MatchResult]:
         """
         Matches the nodes in the input graph to quantization patterns, and
         outputs the information needed to quantize them in future steps.
@@ -899,7 +904,7 @@ def _find_matches(
         if standalone_module_names is None:
            standalone_module_names = []
 
-        match_map: Dict[Any, Any] = {}
+        match_map: Dict[str, MatchResult] = {}
         all_matched = set()
 
         def record_match(pattern, node, matched):
```
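The core of the change is the `MatchResult` alias, which names the five-element tuple stored for every match instead of leaving `match_map` as `Dict[Any, Any]`. The sketch below illustrates the same pattern outside of PyTorch; `SimpleNode` and `SimpleHandler` are placeholders invented for this illustration, not the real `torch.fx.Node` and `QuantizeHandler` types.

```python
from typing import Any, Dict, List, Optional, Tuple


class SimpleNode:
    """Placeholder standing in for torch.fx.Node in this sketch."""

    def __init__(self, name: str) -> None:
        self.name = name


class SimpleHandler:
    """Placeholder standing in for a QuantizeHandler subclass."""


# Naming the tuple shape once documents what each match record contains:
# (root node, matched nodes, pattern, handler, qconfig).
SimpleMatchResult = Tuple[SimpleNode, List[SimpleNode], Optional[str], SimpleHandler, Any]

# The dictionary that would otherwise be Dict[Any, Any] is now checkable.
match_map: Dict[str, SimpleMatchResult] = {}

node = SimpleNode("linear_1")
match_map[node.name] = (node, [node], "linear_pattern", SimpleHandler(), None)

# Positional unpacking mirrors how a match record would be consumed.
root_node, matched_nodes, pattern, handler, qconfig = match_map["linear_1"]
print(root_node.name, len(matched_nodes), pattern)
```

With the alias in place, mypy can check both the values written into the dictionary and the positional unpacking at each call site, instead of treating every entry as `Any`.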

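Three of the added lines are `assert ... is not None` statements; on valid inputs they do not change behavior, but they narrow `Optional` types so mypy accepts the calls that follow. A minimal sketch of that narrowing, using a made-up `lookup` helper:

```python
from typing import Dict, Optional

_registry: Dict[str, int] = {"conv": 8}


def lookup(name: str) -> Optional[int]:
    """Hypothetical helper that may return None, like the Optional values above."""
    return _registry.get(name)


def bits_for(name: str) -> int:
    value = lookup(name)
    # Without the assert, mypy reports that Optional[int] is not compatible
    # with the declared int return type; the assert narrows value to int
    # for the remainder of the function.
    assert value is not None, f"no entry registered for {name}"
    return value


print(bits_for("conv"))
```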