[quant][graphmode][fx] Scope support for call_method in QuantizationTracer

Summary:
Previously we did not set the qconfig for call_method nodes correctly, since doing so requires knowing
the scope of the node (the path of the module whose forward graph contains the node). This
PR modifies the QuantizationTracer to record scope information and build a map from each call_method
Node to its module path, which is then used when we construct the qconfig_map.
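
For context, a minimal usage sketch of the behavior this enables, mirroring the new test added in this commit (prepare_fx/convert_fx and the qconfig_dict format are the ones exercised by that test):

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class Sub(torch.nn.Module):
    def forward(self, x):
        # call_method node; with scope support its qconfig is looked up
        # under the "sub" module path instead of the global default
        return x.transpose(2, 3)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.sub(self.conv(x))

m = M().eval()
# "sub" is excluded from quantization, so the transpose inside Sub.forward
# gets qconfig None rather than falling back to the global default_qconfig
qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]}
m = prepare_fx(m, qconfig_dict)
m(torch.randn(2, 1, 3, 3))  # calibration run
m = convert_fx(m)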

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qconfig_for_call_method
Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f30902dc69be9c263d31a851b07f757efe5ac2e1
Pull Request resolved: #50173
jerryzh168 committed Jan 7, 2021
1 parent 6838ece commit a1ec09b
Showing 4 changed files with 164 additions and 42 deletions.
44 changes: 43 additions & 1 deletion test/quantization/test_quantize_fx.py
@@ -564,7 +564,6 @@ def forward(self, x):
"": None,
"object_type": [
(nn.Conv2d, default_qconfig),
("chunk", None)
]
}
# make sure it runs
@@ -915,6 +914,49 @@ def forward(self, x):
m = prepare_fx(m, qconfig_dict)
m = convert_fx(m)

def test_qconfig_for_call_method(self):
class Sub(torch.nn.Module):
def forward(self, x):
return x.transpose(2, 3)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.sub = Sub()
self.conv1 = torch.nn.Conv2d(1, 1, 1)
self.conv2 = torch.nn.Conv2d(1, 1, 1)

def forward(self, x):
x = self.conv1(x) # default_qconfig
x = self.sub(x) # None
x = self.conv2(x) # default_qconfig
return x.transpose(2, 3) # default_qconfig


m = M().eval()
# since we don't quantize sub, we should have dequantize after the self.conv1
# and quantize before self.conv2
# however, the dequantize after conv2 should happen after x.transpose since
# it is configured with default_qconfig
qconfig_dict = {"": default_qconfig, "module_name": [("sub", None)]}
m = prepare_fx(m, qconfig_dict)
m(torch.randn(2, 1, 3, 3))
m = convert_fx(m)
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_method("dequantize"),
ns.call_method("transpose"),
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_method("transpose"),
ns.call_method("dequantize")
]
self.checkGraphModuleNodes(m, expected_node_list=node_list)

# make sure it runs
m(torch.randn(2, 1, 3, 3))

@skipIfNoFBGEMM
def test_qat_and_script(self):
model = LinearModelWithSubmodule().train()
5 changes: 2 additions & 3 deletions torch/quantization/fx/qconfig_utils.py
@@ -88,10 +88,9 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig):
# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
assert modules is not None
def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig):
module_type_qconfig = get_object_type_qconfig(
qconfig_dict, type(modules[module_name]), global_qconfig)
qconfig_dict, module_type, global_qconfig)
module_name_regex_qconfig = get_module_name_regex_qconfig(
qconfig_dict, module_name, module_type_qconfig)
module_name_qconfig = get_module_name_qconfig(
61 changes: 30 additions & 31 deletions torch/quantization/fx/quantize.py
@@ -61,8 +61,6 @@

from .qconfig_utils import *

import warnings

from typing import Optional, Dict, Any, List, Tuple, Set, Callable

# Define helper types
@@ -346,50 +344,47 @@ def _generate_qconfig_map(
self,
root: torch.nn.Module,
input_graph: Graph,
qconfig_dict: Any) -> None:
global_qconfig = qconfig_dict.get('', None)
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, Any]]) -> None:
global_qconfig = qconfig_dict.get("", None)

self.qconfig_map = dict()
for node in input_graph.nodes:
if node.op == 'get_attr':
if node.op == "get_attr":
module_name, _ = _parent_name(node.target)
assert self.modules is not None
self.qconfig_map[node.name] = get_qconfig(
self.modules, qconfig_dict, module_name, global_qconfig)
elif node.op == 'call_function':
qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig)
elif node.op == "call_function":
# precedence: [TODO] module_name_qconfig (need scope support
# from fx)
# > function_qconfig > global_qconfig
function_qconfig = get_object_type_qconfig(
qconfig_dict, node.target, global_qconfig)
self.qconfig_map[node.name] = function_qconfig
elif node.op == 'call_method':
self_obj = node.args[0]
# qconfig for call_method should be the same as the `self`
# object for the call
if self_obj.name in self.qconfig_map:
qconfig = self.qconfig_map[self_obj.name]
else:
# need scope info for each node to support this
warnings.warn(
"Scope info is not yet supported, taking default " +
"qconfig for value {}".format(node.name))
qconfig = get_qconfig(
self.modules, qconfig_dict, '', global_qconfig)
qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig)
elif node.op == "call_method":
module_path, module_type = node_name_to_scope[node.name]
# use the qconfig of the module that the node belongs to
qconfig = get_qconfig(
qconfig_dict, module_type, module_path, global_qconfig)
self.qconfig_map[node.name] = qconfig
elif node.op == 'call_module':
assert self.modules is not None
module_qconfig = get_qconfig(
self.modules, qconfig_dict, node.target, global_qconfig)
qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig)
# regex is not supported eager mode propagate_qconfig_, we'll
# need to set the qconfig explicitly here in case regex
# is used
assert self.modules is not None
self.modules[node.target].qconfig = module_qconfig
self.qconfig_map[node.name] = module_qconfig

def _prepare(self, model: GraphModule, qconfig_dict: Any,
prepare_custom_config_dict: Optional[Dict[str, Any]],
is_standalone_module: bool) -> GraphModule:
def _prepare(
self,
model: GraphModule,
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, Any]],
prepare_custom_config_dict: Optional[Dict[str, Any]],
is_standalone_module: bool) -> GraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
@@ -428,7 +423,7 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any,

convert_dict_to_ordered_dict(qconfig_dict)
# map from node name to qconfig, used in _find_matches
self._generate_qconfig_map(model, model.graph, qconfig_dict)
self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope)

# match the patterns that will get quantized
standalone_module_name_configs = prepare_custom_config_dict.get(
@@ -579,11 +574,15 @@ def restore_state(self, observed: GraphModule) -> None:
self.prepare_custom_config_dict = \
observed._prepare_custom_config_dict # type: ignore

def prepare(self, model: GraphModule, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> GraphModule:
def prepare(
self,
model: GraphModule,
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, Any]],
prepare_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> GraphModule:
return self._prepare(
model, qconfig_dict, prepare_custom_config_dict,
model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict,
is_standalone_module)

def _run_weight_observers(self, observed: GraphModule) -> None:
96 changes: 89 additions & 7 deletions torch/quantization/quantize_fx.py
@@ -1,12 +1,13 @@
import torch
from torch.fx import GraphModule # type: ignore
from torch.fx.symbolic_trace import Tracer # type: ignore
from torch.fx.node import Target, Node, Argument # type: ignore
from .fx import Fuser # noqa: F401
from .fx import Quantizer # noqa: F401
from .fx.utils import graph_pretty_str # noqa: F401
from .fx.utils import get_custom_module_class_keys # noqa: F401
from torch.nn.intrinsic import _FusedModule
from typing import Dict, Any, List, Callable
from typing import Dict, Any, List, Callable, Tuple, Optional

def _check_is_graph_module(model: torch.nn.Module) -> None:
if not isinstance(model, GraphModule):
@@ -41,20 +42,99 @@ def _fuse_fx(
fuser = Fuser()
return fuser.fuse(graph_module, fuse_custom_config_dict)

class CustomTracer(Tracer):
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example:
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self):
self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
# scope for this would be (module_path="", None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type

class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic
tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
def __init__(
self,
scope: Scope,
current_module: torch.nn.Module,
current_module_path: str):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)

def __enter__(self):
return

def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return


class QuantizationTracer(Tracer):
def __init__(self, skipped_module_names: List[str],
skipped_module_classes: List[Callable]):
super().__init__()
self.skipped_module_names = skipped_module_names
self.skipped_module_classes = skipped_module_classes

def is_leaf_module(self, m, module_qualified_name):
return (m.__module__.startswith('torch.nn') and
# NB: initialized the module_type of top level module to None
# we are assuming people won't configure the model with the type of top level
# module here, since people can use "" for global config
# We can change this if there is a use case that configures
# qconfig using top level module type
self.scope = Scope("", None)
self.node_name_to_scope = {}

def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
return (m.__module__.startswith("torch.nn") and
not isinstance(m, torch.nn.Sequential)) or \
module_qualified_name in self.skipped_module_names or \
type(m) in self.skipped_module_classes or \
isinstance(m, _FusedModule)

def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
module_qualified_name = self.path_of_module(m)
if not self.is_leaf_module(m, module_qualified_name):
def scoped_forward(_args, _kwargs):
# Creating scope with information of current module
# scope will be restored automatically upon exit
with ScopeContextManager(self.scope, m, module_qualified_name):
return forward(*_args, **_kwargs)
return scoped_forward(args, kwargs)
return self.create_proxy("call_module", module_qualified_name, args, kwargs)

def create_node(self, kind : str, target : Target,
args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
type_expr : Optional[Any] = None) -> Node:
node = super().create_node(kind, target, args, kwargs, name, type_expr)
if kind == "call_method":
self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type)
return node

def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None,
@@ -89,18 +169,20 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
float_custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
skipped_module_classes += float_custom_module_classes
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
quantizer = Quantizer()
return quantizer.prepare(
graph_module,
qconfig_dict,
tracer.node_name_to_scope,
prepare_custom_config_dict=prepare_custom_config_dict,
is_standalone_module=is_standalone_module)

def _prepare_standalone_module_fx(
model: torch.nn.Module, qconfig_dict: Any,
model: torch.nn.Module,
qconfig_dict: Any,
prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
parent module.
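As an aside, a hypothetical sketch of inspecting the scope map that the new tracer records (QuantizationTracer lives in torch/quantization/quantize_fx.py and is internal, so this is for illustration only; the Sub/M modules below are made up):

import torch
from torch.quantization.quantize_fx import QuantizationTracer

class Sub(torch.nn.Module):
    def forward(self, x):
        return x.transpose(2, 3)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x).transpose(2, 3)

# no modules are skipped for this example
tracer = QuantizationTracer([], [])
graph = tracer.trace(M())
# node_name_to_scope maps each call_method node name to
# (module_path, module_type): the transpose inside Sub.forward maps to
# ("sub", Sub), the one in M.forward to ("", None) because the top-level
# module type is deliberately recorded as None
print(tracer.node_name_to_scope)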
