
fx quant: add more typehints #48774

Closed · wants to merge 1 commit
84 changes: 44 additions & 40 deletions torch/quantization/fx/quantize.py
@@ -57,12 +57,12 @@
 import warnings
 import re

-from typing import Optional, Dict, Any, List, Union, Tuple, Set
+from typing import Optional, Dict, Any, List, Union, Tuple, Set, Callable

 # Define helper types

 QConfigAny = Union[torch.quantization.QConfig,
-                   torch.quantization.QConfigDynamic]
+                   torch.quantization.QConfigDynamic, None]
 MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                     QConfigAny]
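Adding None to the alias documents that a qconfig of None is a legal value meaning "do not quantize". A minimal sketch, assuming QConfigAny is in scope (e.g. inside this module):

from torch.quantization import get_default_qconfig

qconfig: QConfigAny = get_default_qconfig('fbgemm')  # an ordinary QConfig
qconfig = None  # also type-checks now: None marks a module as not quantized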

@@ -75,9 +75,9 @@
 # >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
 # >> new_name = get_new_observer_name(module)
 # new_name will be an unused attribute name on module, e.g. `_observer_1`
-def get_new_attr_name_with_prefix(prefix):
-    def get_new_attr_name(module):
-        def get_attr_name(i):
+def get_new_attr_name_with_prefix(prefix: str) -> Callable:
+    def get_new_attr_name(module: torch.nn.Module):
+        def get_attr_name(i: int):
             return prefix + str(i)
         i = 0
         attr_name = get_attr_name(i)
@@ -87,7 +87,7 @@ def get_attr_name(i):
         return attr_name
     return get_new_attr_name
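A usage sketch of the helper, following the comment above (the module and prefix are illustrative):

import torch

get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
module = torch.nn.Conv2d(1, 1, 1)
name = get_new_observer_name(module)  # '_observer_0' on a fresh module
setattr(module, name, torch.nn.Identity())
# the next call skips the attribute name that is now taken
assert get_new_observer_name(module) == '_observer_1'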

-def collect_producer_nodes(node):
+def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
     r''' Starting from a target node, trace back until we hit input or
     getattr node. This is used to extract the chain of operators
     starting from getattr to the target node, for example
@@ -114,7 +114,8 @@ def forward(self, x):
             frontier.append(arg)
     return nodes
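A sketch of the chain this helper extracts, assuming it and torch.fx are in scope; the module is illustrative:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight.t())

traced = symbolic_trace(M())
linear_node = [n for n in traced.graph.nodes
               if n.op == 'call_function'
               and n.target == torch.nn.functional.linear][0]
# Trace back from the t() node that produces linear's weight argument;
# the chain ends at the get_attr node for `weight`. If a graph input
# (placeholder) were reached instead, the helper would return None.
producer_nodes = collect_producer_nodes(linear_node.args[1])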

-def graph_module_from_producer_nodes(root, producer_nodes):
+def graph_module_from_producer_nodes(
+        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
     r''' Construct a graph module from extracted producer nodes
     from `collect_producer_nodes` function
     Args:
@@ -137,7 +138,7 @@ def load_arg(a):
     graph_module = GraphModule(root, graph)
     return graph_module
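Continuing the sketch above: since the producer chain starts at a get_attr node rather than a graph input, the reconstructed GraphModule should take no arguments, so (under that assumption) it can be evaluated directly, which is what makes folding the weight-producing subgraph possible:

weight_module = graph_module_from_producer_nodes(traced, producer_nodes)
weight = weight_module()  # re-runs get_attr -> t() and yields the tensor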

-def assert_and_get_unique_device(module):
+def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
     """
     Returns the unique device for a module, or None if no device is found.
     Throws an error if multiple devices are detected.
@@ -151,13 +152,10 @@ def assert_and_get_unique_device(module):
     device = next(iter(devices)) if len(devices) > 0 else None
     return device
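Behavior sketch for the device helper; in practice it returns a torch.device or None, which the new annotation approximates with Any:

import torch

assert_and_get_unique_device(torch.nn.Linear(2, 2))  # torch.device('cpu')
assert_and_get_unique_device(torch.nn.Identity())    # no params or buffers -> None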

-def is_submodule_of_fake_quant(name, module, named_modules):
-    parent_name, _ = _parent_name(name)
-    return is_activation_post_process(named_modules[parent_name])
-
-def is_observed_standalone_module_node(node, modules):
+def is_observed_standalone_module_node(
+        node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
     return node.op == 'call_module' and \
-        is_observed_standalone_module(modules[node.target])
+        is_observed_standalone_module(modules[node.target])  # type: ignore


 def get_flattened_qconfig_dict(qconfig_dict):
@@ -247,9 +245,11 @@ def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
     return module_name_qconfig

 def insert_observer(
-        node, observer, model_device, model,
-        activation_post_process_map, env, observed_graph, load_arg,
-        observed_node_names_set):
+        node: Node, observer: torch.quantization.ObserverBase,
+        model_device: Any, model: torch.nn.Module,
+        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
+        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
+        observed_node_names_set: Set[str]):
     """Insert observer for node by modifying the observed_graph and
     attaching the observer module to the model
     Args:
@@ -273,15 +273,15 @@ def insert_observer(
     observed_node_names_set.add(node.name)

 def insert_observer_for_special_module(
-        quantize_handler, modules, prepare_custom_config_dict, qconfig,
-        node):
+        quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
+        prepare_custom_config_dict: Any, qconfig: Any, node: Node):
     """ Insert observer for custom module and standalone module
     Returns: standalone_module_input_idxs: the indices for inputs that
     need to be observed by parent module
     """
     assert modules is not None
     if isinstance(quantize_handler, CustomModuleQuantizeHandler):
-        custom_module = modules[node.target]
+        custom_module = modules[node.target]  # type: ignore
         custom_module_class_mapping = prepare_custom_config_dict.get(
             "float_to_observed_custom_module_class", {})
         observed_custom_module_class = \
@@ -293,7 +293,7 @@ def insert_observer_for_special_module(
             setattr(modules[parent_name], name, observed_custom_module)
     elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
         # observe standalone module
-        standalone_module = modules[node.target]
+        standalone_module = modules[node.target]  # type: ignore
         prepare = \
             torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
         observed_standalone_module = \
@@ -304,22 +304,22 @@
         parent_name, name = _parent_name(node.target)
         setattr(modules[parent_name], name,
                 observed_standalone_module)
-        modules[node.target] = observed_standalone_module
+        modules[node.target] = observed_standalone_module  # type: ignore

 def insert_observer_for_output_of_the_node(
-        node,
-        quantize_handler,
-        qconfig,
-        modules,
-        model,
-        pattern,
-        model_device,
-        activation_post_process_map,
-        env,
-        observed_graph,
-        load_arg,
-        observed_node_names_set,
-        matched_nodes):
+        node: Node,
+        quantize_handler: QuantizeHandler,
+        qconfig: Any,
+        modules: Dict[str, torch.nn.Module],
+        model: torch.nn.Module,
+        pattern: Any,
+        model_device: Any,
+        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
+        env: Dict[Any, Any],
+        observed_graph: Graph,
+        load_arg: Callable,
+        observed_node_names_set: Set[str],
+        matched_nodes: Optional[List[Node]]):
     """ Insert observer/fake_quantize module for output of the observed
     module if needed
     """
@@ -391,9 +391,11 @@ def input_is_observed(arg):
         load_arg, observed_node_names_set)

 def insert_observer_for_input_arg_of_observed_node(
-        node, observed_node_names_set, quants,
-        model_device, model, activation_post_process_map, env, observed_graph,
-        load_arg):
+        node: Node, observed_node_names_set: Set[str], quants: Dict[str, Any],
+        model_device: Any, model: torch.nn.Module,
+        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
+        env: Dict[str, str], observed_graph: Graph,
+        load_arg: Callable):
     if node.name not in observed_node_names_set and node.name in quants:
         _, activation_post_process_ctr = quants[node.name]
         if activation_post_process_ctr is not None:
@@ -578,6 +580,7 @@ def load_arg(a):
                     # index for input of custom module that needs to be observed in
                     # parent
                     if qconfig is not None:
+                        assert obj is not None
                         insert_observer_for_special_module(
                             obj, self.modules, prepare_custom_config_dict, qconfig,
                             node)
@@ -1049,7 +1052,8 @@ def is_standalone_module(node_target):

         return match_map

-    def _find_quants(self, graph, matches):
+    def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
+                     ) -> Dict[str, Any]:
         """
         Takes the nodes in the input graph and pending matches, and finds and
         returns the input and output nodes which need to be quantized.
@@ -1062,7 +1066,7 @@ def _find_quants(self, graph, matches):
         node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler),
         activation_post_process (observer/fake_quantize module) constructor)
         """
-        quants: Dict[Any, Any] = {}
+        quants: Dict[str, Any] = {}

         def visit(node, matched_pattern, qconfig):
             def visit_arg(arg):