fx quant: fix types on _find_quants #49616

Closed
wants to merge 5 commits into from
Changes from 1 commit
torch/quantization/fx/quantize.py (18 changes: 11 additions & 7 deletions)
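For context: the commit narrows the return type of `_find_quants` from `Dict[str, Any]` to a concrete tuple type, and threads the new third tuple element, `is_weight`, through the call sites. A minimal sketch of the resulting shape, with a simplified stand-in for the real `torch/quantization/fx` handler class (the stub and example values below are assumptions for illustration only):

```python
from typing import Callable, Dict, Tuple

class DefaultQuantizeHandler:
    """Simplified stand-in for the handler class in torch/quantization/fx."""

# After this change, _find_quants returns a mapping of
#   node_name -> (handler, activation_post_process constructor, is_weight)
QuantsMap = Dict[str, Tuple[DefaultQuantizeHandler, Callable, bool]]

example: QuantsMap = {
    "conv1_input": (DefaultQuantizeHandler(), lambda: None, False),
}
```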
@@ -227,7 +227,7 @@ def insert_observer_for_input_arg_of_observed_node(
         env: Dict[str, str], observed_graph: Graph,
         load_arg: Callable):
     if node.name not in observed_node_names_set and node.name in quants:
-        _, activation_post_process_ctr = quants[node.name]
+        _, activation_post_process_ctr, is_weight = quants[node.name]
         if activation_post_process_ctr is not None:
             insert_observer(
                 node, activation_post_process_ctr(),
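Since each entry now carries three elements, every consumer of `quants` must unpack all three, even when a value such as `is_weight` goes unused at that site, as in the hunk above. A self-contained sketch of the pattern (the dict contents are invented for illustration):

```python
from typing import Callable, Dict, Tuple

# Hypothetical entry mirroring the (handler, constructor, is_weight) shape;
# object() stands in for the real DefaultQuantizeHandler instance.
quants: Dict[str, Tuple[object, Callable, bool]] = {
    "relu_1": (object(), lambda: "observer", False),
}

_, activation_post_process_ctr, is_weight = quants["relu_1"]
if activation_post_process_ctr is not None:
    observer = activation_post_process_ctr()  # constructs the observer
```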
@@ -378,7 +378,8 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,
         # find _inputs_ to matched nodes that are not quantized, these
         # have to be quantized, which requires measuring stats,
         # initialize an DefaultQuantizeHandler object for each
-        quants = self._find_quants(model.graph, matches)
+        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable, bool]] = \
+            self._find_quants(model.graph, matches)

         self.activation_post_process_map = dict()
         env: Dict[Any, Any] = {}
@@ -549,7 +550,8 @@ def _convert(self, model: GraphModule, debug: bool = False,
             model.graph, self.modules, self.patterns,
             custom_module_classes=custom_module_classes)

-        quants = self._find_quants(model.graph, matches)
+        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable, bool]] = \
+            self._find_quants(model.graph, matches)

         self.quantized_graph = Graph()
         env: Dict[Any, Any] = {}
@@ -578,7 +580,8 @@ def load_quantized(n):
                 assert n.name in quants, \
                     'did not find quant object for node:' + n.name
                 quant = quants[n.name][0]
-                quant_env[n.name] = quant.convert(self, env[n.name])
+                # TODO: ideally fix the type error before land
+                quant_env[n.name] = quant.convert(self, env[n.name])  # type: ignore
                 return quant_env[n.name]

             def load_x(n):
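The `# type: ignore` above silences the mypy error at the call site rather than fixing the underlying signature mismatch, which the TODO acknowledges. A generic illustration of the mechanism (invented names, not PyTorch code; the quoted error text is approximate):

```python
def double(x: int) -> int:
    return x * 2

value = "21"  # inferred as str
# Without the trailing comment, mypy reports roughly:
#   error: Argument 1 to "double" has incompatible type "str"; expected "int"
result = double(value)  # type: ignore
```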
@@ -943,7 +946,7 @@ def is_standalone_module(node_target):
         return match_map

     def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
-                     ) -> Dict[str, Any]:
+                     ) -> Dict[str, Tuple[DefaultQuantizeHandler, Callable, bool]]:
         """
         Takes the nodes in the input graph and pending matches, and finds and
         returns the input and output nodes which need to be quantized.
@@ -956,7 +959,8 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
           node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler),
           activation_post_process (observer/fake_quantize module) constructor)
         """
-        quants: Dict[str, Any] = {}
+        # quants: Dict[str, Any] = {}
+        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable, bool]] = {}

(vkuzo marked this conversation as resolved.)

         def visit(node, matched_pattern, qconfig):
             def visit_arg(arg):
@@ -981,7 +985,7 @@ def visit_arg(arg):
                     # overwrite previous activation post process constructor if
                     # necessary
                     quants[arg.name] = (
-                        DefaultQuantizeHandler(self, arg), act_post_process_ctr)
+                        DefaultQuantizeHandler(self, arg), act_post_process_ctr, is_weight)
Contributor: I think we don't need is_weight; this info is already used to decide act_post_process_ctr at L976.

Contributor: Also, I think we can remove L978-979.

             return visit_arg

         for node in graph.nodes:
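To make the first review comment concrete: `is_weight` is already consumed when the constructor is chosen, so storing it in the tuple as well duplicates information the constructor encodes. A hedged sketch of that selection pattern; `QConfig` mirrors the real namedtuple in `torch.quantization`, while the helper name is invented:

```python
from collections import namedtuple

# Mirrors torch.quantization.QConfig, a namedtuple of two observer constructors.
QConfig = namedtuple("QConfig", ["activation", "weight"])

def choose_act_post_process_ctr(qconfig: QConfig, is_weight: bool):
    # is_weight fully determines which constructor is returned, so a caller
    # holding the constructor arguably does not also need is_weight
    return qconfig.weight if is_weight else qconfig.activation
```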