Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fx quant: fix types on _find_quants #49616

Closed
wants to merge 5 commits into from
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 12 additions & 14 deletions torch/quantization/fx/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def input_is_observed(arg):
load_arg, observed_node_names_set)

def insert_observer_for_input_arg_of_observed_node(
node: Node, observed_node_names_set: Set[str], quants: Dict[str, Any],
node: Node, observed_node_names_set: Set[str],
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]],
model: torch.nn.Module,
activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
env: Dict[str, str], observed_graph: Graph,
Expand Down Expand Up @@ -378,7 +379,8 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,
# find _inputs_ to matched nodes that are not quantized, these
# have to be quantized, which requires measuring stats,
# initialize a DefaultQuantizeHandler object for each
quants = self._find_quants(model.graph, matches)
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
self._find_quants(model.graph, matches)

self.activation_post_process_map = dict()
env: Dict[Any, Any] = {}
Expand Down Expand Up @@ -549,7 +551,8 @@ def _convert(self, model: GraphModule, debug: bool = False,
model.graph, self.modules, self.patterns,
custom_module_classes=custom_module_classes)

quants = self._find_quants(model.graph, matches)
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
self._find_quants(model.graph, matches)

self.quantized_graph = Graph()
env: Dict[Any, Any] = {}
Expand All @@ -571,14 +574,9 @@ def load_non_quantized(n):
return env[n.name]

def load_quantized(n):
if n.name not in quant_env:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was flagged as unreachable, so it is being deleted.

assert n.name in env, \
'trying to load quantized node but did not find node:' + \
n.name + ' in float environment:' + str(env)
assert n.name in quants, \
'did not find quant object for node:' + n.name
quant = quants[n.name][0]
quant_env[n.name] = quant.convert(self, env[n.name])
assert n.name in quant_env, \
'trying to load quantized node but did not find node:' + \
n.name + ' in quant environment:' + str(quant_env)
return quant_env[n.name]

def load_x(n):
Expand Down Expand Up @@ -943,7 +941,7 @@ def is_standalone_module(node_target):
return match_map

def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
) -> Dict[str, Any]:
) -> Dict[str, Tuple[DefaultQuantizeHandler, Callable]]:
"""
Takes the nodes in the input graph and pending matches, and finds and
returns the input and output nodes which need to be quantized.
Expand All @@ -956,7 +954,7 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler),
activation_post_process (observer/fake_quantize module) constructor)
"""
quants: Dict[str, Any] = {}
quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = {}

def visit(node, matched_pattern, qconfig):
def visit_arg(arg):
Expand All @@ -972,7 +970,7 @@ def visit_arg(arg):
act_post_process_ctr = qconfig.weight if is_weight else \
qconfig.activation
quants[arg.name] = (
DefaultQuantizeHandler(self, arg), qconfig, is_weight)
DefaultQuantizeHandler(self, arg), qconfig)
# overwrite the constructor from qconfig
act_post_process_ctr = \
get_default_output_activation_post_process_map().get(
Expand Down