[quant][graphmode][fx] Scope support for call_method in QuantizationTracer #50173

Closed
wants to merge 8 commits
70 changes: 69 additions & 1 deletion test/quantization/test_quantize_fx.py
@@ -564,7 +564,6 @@ def forward(self, x):
"": None,
"object_type": [
(nn.Conv2d, default_qconfig),
("chunk", None)
]
}
# make sure it runs
@@ -915,6 +914,75 @@ def forward(self, x):
m = prepare_fx(m, qconfig_dict)
m = convert_fx(m)

def test_qconfig_for_call_method(self):
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 1, 1)

def forward(self, x):
x = x.transpose(2, 3)
x = self.conv(x)
return x.transpose(2, 3)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.sub = Sub()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
self.conv2 = torch.nn.Conv2d(1, 1, 1)

def forward(self, x):
x = self.conv1(x)
x = self.sub(x)
x = self.conv2(x)
return x.transpose(2, 3)

qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]}
        # since sub is configured with qconfig None, we should dequantize the output
        # of self.conv1 and quantize the input of self.conv2
        # the dequantize after conv2 should happen after the transpose, since
        # the transpose is configured with default_qconfig
        # nodes in the Sub module instance are not quantized
node_list1 = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_method("dequantize"),
ns.call_method("transpose"),
ns.call_module(nn.Conv2d),
ns.call_method("transpose"),
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_method("transpose"),
ns.call_method("dequantize")
]

qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]}
        # Only nodes in the Sub module instance are quantized
        # the first transpose is not quantized because its input is not quantized
node_list2 = [
ns.call_module(nn.Conv2d),
ns.call_method("transpose"),
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_method("transpose"),
ns.call_method("dequantize"),
ns.call_module(nn.Conv2d),
ns.call_method("transpose"),
]

for qconfig_dict, node_list in [
(qconfig_dict1, node_list1),
(qconfig_dict2, node_list2)
]:
m = M().eval()
m = prepare_fx(m, qconfig_dict)
m(torch.randn(2, 1, 3, 3))
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node_list=node_list)
# make sure it runs
m(torch.randn(2, 1, 3, 3))

@skipIfNoFBGEMM
def test_qat_and_script(self):
model = LinearModelWithSubmodule().train()
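
As a rough, runnable sketch of what the new test above exercises (an illustration, not part of the PR): reusing the M and Sub modules defined in test_qconfig_for_call_method, the converted graph's nodes can be printed and compared against node_list1. prepare_fx, convert_fx and default_qconfig are the same entry points the test itself uses.

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

m = M().eval()  # M and Sub as defined in the test above
qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]}
m = prepare_fx(m, qconfig_dict)
m(torch.randn(2, 1, 3, 3))    # calibrate the observers
m = convert_fx(m)
for node in m.graph.nodes:    # compare op/target against node_list1
    print(node.op, node.target)
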
5 changes: 2 additions & 3 deletions torch/quantization/fx/qconfig_utils.py
@@ -88,10 +88,9 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig):
# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
assert modules is not None
def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig):
module_type_qconfig = get_object_type_qconfig(
qconfig_dict, type(modules[module_name]), global_qconfig)
qconfig_dict, module_type, global_qconfig)
module_name_regex_qconfig = get_module_name_regex_qconfig(
qconfig_dict, module_name, module_type_qconfig)
module_name_qconfig = get_module_name_qconfig(
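
The comment above spells out the lookup order that get_qconfig implements: an exact module_name entry wins over a module_name_regex entry, which wins over an object_type (module type) entry, which wins over the global "" qconfig. Below is a simplified, self-contained sketch of that chain for illustration only; it assumes each category is a plain dict, whereas the real qconfig_dict uses lists of (key, qconfig) pairs.

import re

def resolve_qconfig_sketch(qconfig_dict, module_type, module_name):
    # start from the global default, then let more specific entries override it
    qconfig = qconfig_dict.get("", None)
    for typ, type_qconfig in qconfig_dict.get("object_type", {}).items():
        if typ == module_type:
            qconfig = type_qconfig          # object_type beats global
    for pattern, regex_qconfig in qconfig_dict.get("module_name_regex", {}).items():
        if re.match(pattern, module_name):
            qconfig = regex_qconfig         # regex beats object_type
    for name, name_qconfig in qconfig_dict.get("module_name", {}).items():
        if name == module_name:
            qconfig = name_qconfig          # exact name beats regex
    return qconfig
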
61 changes: 30 additions & 31 deletions torch/quantization/fx/quantize.py
@@ -61,8 +61,6 @@

from .qconfig_utils import *

import warnings

from typing import Optional, Dict, Any, List, Tuple, Set, Callable

# Define helper types
@@ -346,50 +344,47 @@ def _generate_qconfig_map(
self,
root: torch.nn.Module,
input_graph: Graph,
qconfig_dict: Any) -> None:
global_qconfig = qconfig_dict.get('', None)
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
global_qconfig = qconfig_dict.get("", None)

self.qconfig_map = dict()
for node in input_graph.nodes:
if node.op == 'get_attr':
if node.op == "get_attr":
module_name, _ = _parent_name(node.target)
assert self.modules is not None
self.qconfig_map[node.name] = get_qconfig(
self.modules, qconfig_dict, module_name, global_qconfig)
elif node.op == 'call_function':
qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig)
elif node.op == "call_function":
# precedence: [TODO] module_name_qconfig (need scope support
# from fx)
# > function_qconfig > global_qconfig
function_qconfig = get_object_type_qconfig(
qconfig_dict, node.target, global_qconfig)
self.qconfig_map[node.name] = function_qconfig
elif node.op == 'call_method':
self_obj = node.args[0]
# qconfig for call_method should be the same as the `self`
# object for the call
if self_obj.name in self.qconfig_map:
qconfig = self.qconfig_map[self_obj.name]
else:
# need scope info for each node to support this
warnings.warn(
"Scope info is not yet supported, taking default " +
"qconfig for value {}".format(node.name))
qconfig = get_qconfig(
self.modules, qconfig_dict, '', global_qconfig)
qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig)
elif node.op == "call_method":
module_path, module_type = node_name_to_scope[node.name]
# use the qconfig of the module that the node belongs to
qconfig = get_qconfig(
qconfig_dict, module_type, module_path, global_qconfig)
self.qconfig_map[node.name] = qconfig
elif node.op == 'call_module':
assert self.modules is not None
module_qconfig = get_qconfig(
self.modules, qconfig_dict, node.target, global_qconfig)
qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig)
                # regex is not supported by eager mode propagate_qconfig_, so we
                # need to set the qconfig explicitly here in case a regex
                # is used
assert self.modules is not None
self.modules[node.target].qconfig = module_qconfig
self.qconfig_map[node.name] = module_qconfig

def _prepare(self, model: GraphModule, qconfig_dict: Any,
prepare_custom_config_dict: Optional[Dict[str, Any]],
is_standalone_module: bool) -> GraphModule:
def _prepare(
self,
model: GraphModule,
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, type]],
prepare_custom_config_dict: Optional[Dict[str, Any]],
is_standalone_module: bool) -> GraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.

@@ -428,7 +423,7 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,

convert_dict_to_ordered_dict(qconfig_dict)
# map from node name to qconfig, used in _find_matches
self._generate_qconfig_map(model, model.graph, qconfig_dict)
self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope)

# match the patterns that will get quantized
standalone_module_name_configs = prepare_custom_config_dict.get(
@@ -579,11 +574,15 @@ def restore_state(self, observed: GraphModule) -> None:
self.prepare_custom_config_dict = \
observed._prepare_custom_config_dict # type: ignore

def prepare(self, model: GraphModule, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> GraphModule:
def prepare(
self,
model: GraphModule,
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, type]],
prepare_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> GraphModule:
return self._prepare(
model, qconfig_dict, prepare_custom_config_dict,
model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict,
is_standalone_module)

def _run_weight_observers(self, observed: GraphModule) -> None:
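
For the M and Sub modules from the test above, the mapping produced by the tracer and consumed here would look roughly as follows (an illustration only; the exact node names generated by torch.fx may differ). The call_method branch then resolves each node's qconfig from the module it was traced in, instead of inheriting the qconfig of its self argument as the removed code did.

# Hypothetical contents of node_name_to_scope for the M/Sub example:
node_name_to_scope = {
    "transpose":   ("sub", Sub),   # x.transpose(2, 3) inside Sub.forward
    "transpose_1": ("sub", Sub),   # second transpose inside Sub.forward
    "transpose_2": ("", None),     # transpose in M.forward, top-level scope
}
# With qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]},
# the two transposes inside sub resolve to qconfig None (left unquantized),
# while the top-level transpose falls back to the global default_qconfig.
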
92 changes: 85 additions & 7 deletions torch/quantization/quantize_fx.py
@@ -1,12 +1,13 @@
import torch
from torch.fx import GraphModule # type: ignore
from torch.fx.symbolic_trace import Tracer # type: ignore
from torch.fx.node import Target, Node, Argument # type: ignore
from .fx import Fuser # noqa: F401
from .fx import Quantizer # noqa: F401
from .fx.utils import graph_pretty_str # noqa: F401
from .fx.utils import get_custom_module_class_keys # noqa: F401
from torch.nn.intrinsic import _FusedModule
from typing import Dict, Any, List, Callable
from typing import Dict, Any, List, Callable, Tuple, Optional

def _check_is_graph_module(model: torch.nn.Module) -> None:
if not isinstance(model, GraphModule):
@@ -41,20 +42,95 @@ def _fuse_fx(
fuser = Fuser()
return fuser.fuse(graph_module, fuse_custom_config_dict)

class CustomTracer(Tracer):
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example:
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
x = x.transpose(1, 2)
x = self.sub(x)
return x

"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type

class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic
tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
def __init__(
self,
scope: Scope,
current_module: torch.nn.Module,
current_module_path: str):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)

def __enter__(self):
return

def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return


class QuantizationTracer(Tracer):
def __init__(self, skipped_module_names: List[str],
skipped_module_classes: List[Callable]):
super().__init__()
self.skipped_module_names = skipped_module_names
self.skipped_module_classes = skipped_module_classes

def is_leaf_module(self, m, module_qualified_name):
return (m.__module__.startswith('torch.nn') and
        # NB: the module_type of the top-level module is initialized to None;
        # we assume people won't configure qconfig with the type of the top-level
        # module here, since they can use "" for the global config.
        # We can change this if there is a use case that configures
        # qconfig using the top-level module type.
self.scope = Scope("", None)
self.node_name_to_scope : Dict[str, Tuple[str, type]] = {}

def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
return (m.__module__.startswith("torch.nn") and
not isinstance(m, torch.nn.Sequential)) or \
module_qualified_name in self.skipped_module_names or \
type(m) in self.skipped_module_classes or \
isinstance(m, _FusedModule)

def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
module_qualified_name = self.path_of_module(m)
Contributor Author:
@jamesr66a this is copied from the default call_module code. Would it be a problem if the default code changes? Do you think it's better to provide an API in call_module to allow this kind of extension?

Collaborator:
I think the future-proofed way to do this is to instantiate the guard and call super().call_module.

Contributor Author:
OK, sounds good.

        # Create a scope with the information of the current module;
        # the scope is restored automatically upon exit
with ScopeContextManager(self.scope, m, module_qualified_name):
return super().call_module(m, forward, args, kwargs)

def create_node(self, kind : str, target : Target,
args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
type_expr : Optional[Any] = None) -> Node:
node = super().create_node(kind, target, args, kwargs, name, type_expr)
if kind == "call_method":
self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type)
return node

def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None,
@@ -89,18 +165,20 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
float_custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
skipped_module_classes += float_custom_module_classes
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
quantizer = Quantizer()
return quantizer.prepare(
graph_module,
qconfig_dict,
tracer.node_name_to_scope,
prepare_custom_config_dict=prepare_custom_config_dict,
is_standalone_module=is_standalone_module)

def _prepare_standalone_module_fx(
model: torch.nn.Module, qconfig_dict: Any,
model: torch.nn.Module,
qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
parent module.
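
A small end-to-end sketch of the tracer itself (QuantizationTracer is internal to torch.quantization.quantize_fx, so treat this as an illustration rather than a supported API): tracing a toy module records, for every call_method node, the path and type of the module whose forward produced it.

import torch
from torch.quantization.quantize_fx import QuantizationTracer

class Sub(torch.nn.Module):
    def forward(self, x):
        return x.transpose(1, 2)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x).transpose(1, 2)

tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
tracer.trace(M())
for name, (path, typ) in tracer.node_name_to_scope.items():
    # expected roughly: transpose -> ("sub", Sub), transpose_1 -> ("", None)
    print(name, path, typ)
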