fx quant: add typing for fuser (#48844)
Summary:
Pull Request resolved: #48844

Add types to function I/O for `Fuser` to improve readability

Test Plan:
```
mypy torch/quantization/
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25337314

fbshipit-source-id: e5074d71c7834f24975169d36bf49357e53650ff
vkuzo authored and facebook-github-bot committed Dec 5, 2020
1 parent 63a71a8 commit fa5f7d8
Showing 3 changed files with 24 additions and 7 deletions.
23 changes: 18 additions & 5 deletions torch/quantization/fx/fuse.py
@@ -2,6 +2,7 @@
 
 from torch.fx import ( # type: ignore
     GraphModule,
+    Node,
     map_arg
 )

@@ -18,19 +19,28 @@
 
 from .fusion_patterns import * # noqa: F401
 
+from .quantization_types import Pattern
+
+from typing import Callable, Tuple, Optional
 
 
 class Fuser:
-    def fuse(self, model, fuse_custom_config_dict=None):
+    def fuse(self, model: GraphModule,
+             fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule:
         if fuse_custom_config_dict is None:
             fuse_custom_config_dict = {}
 
         input_root = model
         input_graph = model.graph
         self.modules = dict(input_root.named_modules())
 
-        additional_fusion_patterns = fuse_custom_config_dict.get("additional_quant_pattern", {})
-        fusion_patterns = get_combined_dict(get_default_fusion_patterns(), additional_fusion_patterns)
+        additional_fusion_patterns = \
+            fuse_custom_config_dict.get("additional_quant_pattern", {})
+        fusion_patterns = get_combined_dict(
+            get_default_fusion_patterns(), additional_fusion_patterns)
         # find fusion
-        fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
+        fusion_pairs = self._find_matches(
+            input_root, input_graph, fusion_patterns)
         self.fused_graph = Graph()
         env: Dict[Any, Any] = {}

@@ -40,6 +50,7 @@ def load_arg(a):
         for node in input_graph.nodes:
             root_node, obj = fusion_pairs.get(node.name, (None, None))
             if root_node is node:
+                assert obj is not None
                 env[node.name] = obj.fuse(self, load_arg)
             elif root_node is None:
                 env[node.name] = self.fused_graph.node_copy(node, load_arg)
@@ -48,7 +59,9 @@ def load_arg(a):
         model = GraphModule(input_root, self.fused_graph)
         return model
 
-    def _find_matches(self, root, graph, patterns):
+    def _find_matches(self, root: GraphModule, graph: Graph,
+                      patterns: Dict[Pattern, Callable]
+                      ) -> Dict[str, Tuple[Node, Optional[Any]]]:
         modules = dict(root.named_modules())
         match_map = {} # node name -> (root_node, match_value?)

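
For context on how the newly typed `fuse` signature reads at a call site, here is a minimal usage sketch. It is illustrative only: `SmallNet` is a made-up module, and in the normal workflow `Fuser` is driven by the FX quantization entry points rather than instantiated by hand; the only assumption beyond this diff is that `Fuser` is importable from `torch.quantization.fx.fuse`.

```
# Illustrative sketch only -- SmallNet is hypothetical.
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace
from torch.quantization.fx.fuse import Fuser

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

# fuse() now takes a GraphModule plus an optional Dict[str, Any] config
# and returns a GraphModule.
traced: GraphModule = symbolic_trace(SmallNet().eval())
fused: GraphModule = Fuser().fuse(traced, fuse_custom_config_dict=None)
print(fused.graph)
```
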
5 changes: 3 additions & 2 deletions torch/quantization/fx/pattern_utils.py
@@ -1,11 +1,12 @@
 import torch
 import sys
 from collections import OrderedDict
-from typing import Dict, Any, Union, Tuple, Callable
+from typing import Dict, Any
 
+from .quantization_types import Pattern
+
 # TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
 QuantizeHandler = Any
-Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]]
 
 # pattern for conv bn fusion
 DEFAULT_FUSION_PATTERNS = OrderedDict()
3 changes: 3 additions & 0 deletions torch/quantization/fx/quantization_types.py
@@ -0,0 +1,3 @@
+from typing import Union, Callable, Tuple
+
+Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]]
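
As a quick illustration of what values the `Pattern` alias admits, the sketch below builds a small pattern-keyed dict in the same shape as a `Dict[Pattern, Callable]` fusion table. The concrete pattern keys and the handler are illustrative assumptions, not the registered defaults.

```
# Illustrative only: example values that satisfy the Pattern alias.
# dummy_handler is a stand-in, not a real fusion handler.
from typing import Any, Callable, Dict, Tuple, Union

import torch.nn as nn

Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]]

def dummy_handler(*args: Any) -> None:
    pass

# A single op, a 2-tuple, or a 3-tuple of callables are all valid keys
# for a Dict[Pattern, Callable] such as a fusion pattern table.
patterns: Dict[Pattern, Callable] = {
    nn.ReLU: dummy_handler,
    (nn.BatchNorm2d, nn.Conv2d): dummy_handler,
    (nn.ReLU, nn.BatchNorm2d, nn.Conv2d): dummy_handler,
}
```
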
