Skip to content

Commit

Permalink
fx quant: fix types on _find_quants (#49616)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #49616

Add types to `_find_quants` I/O and fix resulting errors,
needed for an upcoming bug fix.

Test Plan:
```
mypy torch/quantization
python test/test_quantization.py TestQuantizeFx
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25645719

fbshipit-source-id: 4bf788b55fd4fd086c83a4438b9c2df22b9cff49
  • Loading branch information
vkuzo authored and facebook-github-bot committed Dec 22, 2020
1 parent 7c90b20 commit edce6b1
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions torch/quantization/fx/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def input_is_observed(arg):
load_arg, observed_node_names_set)

def insert_observer_for_input_arg_of_observed_node(
node: Node, observed_node_names_set: Set[str], quants: Dict[str, Any],
node: Node, observed_node_names_set: Set[str],
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
model: torch.nn.Module,
activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
env: Dict[str, str], observed_graph: Graph,
Expand Down Expand Up @@ -376,7 +377,8 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,
# find _inputs_ to matched nodes that are not quantized, these
# have to be quantized, which requires measuring stats,
# initialize an DefaultQuantizeHandler object for each
quants = self._find_quants(model.graph, matches)
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
self._find_quants(model.graph, matches)

self.activation_post_process_map = dict()
env: Dict[Any, Any] = {}
Expand Down Expand Up @@ -547,7 +549,8 @@ def _convert(self, model: GraphModule, debug: bool = False,
model.graph, self.modules, self.patterns,
custom_module_classes=custom_module_classes)

quants = self._find_quants(model.graph, matches)
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
self._find_quants(model.graph, matches)

self.quantized_graph = Graph()
env: Dict[Any, Any] = {}
Expand All @@ -569,14 +572,9 @@ def load_non_quantized(n):
return env[n.name]

def load_quantized(n):
if n.name not in quant_env:
assert n.name in env, \
'trying to load quantized node but did not find node:' + \
n.name + ' in float environment:' + str(env)
assert n.name in quants, \
'did not find quant object for node:' + n.name
quant = quants[n.name][0]
quant_env[n.name] = quant.convert(self, env[n.name])
assert n.name in quant_env, \
'trying to load quantized node but did not find node:' + \
n.name + ' in quant environment:' + str(quant_env)
return quant_env[n.name]

def load_x(n):
Expand Down Expand Up @@ -941,7 +939,7 @@ def is_standalone_module(node_target):
return match_map

def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
) -> Dict[str, Any]:
) -> Dict[str, Tuple[DefaultQuantizeHandler, Callable]]:
"""
Takes the nodes in the input graph and pending matches, and finds and
returns the input and output nodes which need to be quantized.
Expand All @@ -954,7 +952,7 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler),
activation_post_process (observer/fake_quantize module) constructor)
"""
quants: Dict[str, Any] = {}
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = {}

def visit(node, matched_pattern, qconfig):
def visit_arg(arg):
Expand All @@ -969,15 +967,11 @@ def visit_arg(arg):
(activation_is_statically_quantized(qconfig) or is_weight):
act_post_process_ctr = qconfig.weight if is_weight else \
qconfig.activation
quants[arg.name] = (
DefaultQuantizeHandler(self, arg), qconfig, is_weight)
# overwrite the constructor from qconfig
act_post_process_ctr = \
get_default_output_activation_post_process_map().get(
matched_pattern,
act_post_process_ctr)
# overwrite previous activation post process constructor if
# necessary
quants[arg.name] = (
DefaultQuantizeHandler(self, arg), act_post_process_ctr)
return visit_arg
Expand Down

0 comments on commit edce6b1

Please sign in to comment.