Skip to content

Commit

Permalink
Fix SubgraphMatch for case of no anchor found
Browse files Browse the repository at this point in the history
ghstack-source-id: 43e81f4ad4c9ca4903a82a91191f8900f34ef56e
Pull Request resolved: #86421
  • Loading branch information
SherlockNoMad committed Oct 6, 2022
1 parent a75edfa commit 5fb70c5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
22 changes: 22 additions & 0 deletions test/test_fx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,27 @@ def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
TestCase(False, False, 1),
]

class NoAnchorFound:
# This test case is for pattern where no matching anchor is found in the target graph
# `anchor` is the starting point of the pattern matching, it's usually the boundary returning nodes
@staticmethod
def forward(x):
x = x + 1
return x

@staticmethod
def pattern(a):
b1 = a.relu()
return b1

test_cases = [
# match_output, match_placeholder, num_matches
TestCase(False, False, 0),
TestCase(True, False, 0),
TestCase(False, True, 0),
TestCase(True, True, 0)
]

@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):

Expand All @@ -683,6 +704,7 @@ class TestFXMatcherUtils(JitTestCase):
MultipleOutputsHorizontalPattern,
MultiOutputWithWithInvalidMatches,
QuantizationFp8Pattern,
NoAnchorFound,
])
def test_subgraph_matcher(self, test_model):

Expand Down
28 changes: 25 additions & 3 deletions torch/fx/passes/utils/matcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,26 @@
import torch.utils._pytree as pytree
from typing import Dict, List, Set, Any
import logging
import os

__all__ = ['SubgraphMatcher', 'InternalMatch']

logger = logging.getLogger(__name__)
# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger():
logger = logging.getLogger(__name__)

level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter("%(filename)s > %(message)s")
console.setFormatter(formatter)
console.setLevel(level)
# add the handlers to the logger
logger.addHandler(console)
logger.propagate = False
return logger

logger = _init_logger()

@compatibility(is_backward_compatible=False)
@dataclass
Expand Down Expand Up @@ -126,7 +142,7 @@ def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[Inte
return non_overlapping_matches

def _match_args(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
assert not(isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"

if isinstance(pn, Node) and not isinstance(gn, Node):
if pn.op == "placeholder":
Expand Down Expand Up @@ -247,6 +263,9 @@ def match(self, graph: Graph) -> List[InternalMatch]:
if self._nodes_are_equal(pattern_anchor, node):
match_candidates[pattern_anchor].append(node)
match_candidates_list = list(match_candidates.items())

logger.info(f"Initial match_candidates_list: {match_candidates_list}\n")

matches: List[InternalMatch] = []

def backtracking(anchor_index, match):
Expand Down Expand Up @@ -275,7 +294,8 @@ def backtracking(anchor_index, match):
match = copy.copy(saved_match)

match = InternalMatch(anchors=self.pattern_anchors)
backtracking(0, match)
if match_candidates_list:
backtracking(0, match)

# filter out the matches where the subgraph is not fully_contained
before = len(matches)
Expand All @@ -302,4 +322,6 @@ def backtracking(anchor_index, match):
if before != after:
logger.info(f"Filtered out {before - after} matches because matched subgraphs are overlapping")

logger.info(f"Matches returned: {matches}")

return matches

0 comments on commit 5fb70c5

Please sign in to comment.