Skip to content

Commit

Permalink
Back out "Support regex-style matching for Any and Oneof (#82853)" (#…
Browse files Browse the repository at this point in the history
…83922)

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/fc470cf9806643efdbc1df650f9e8eafb671ba17

Original Phabricator Test Plan:

Reviewed By: yinghai, hl475, khabinov, terrycsy

Differential Revision: D38945806

fbshipit-source-id: 225a79458a38181904f780ffcfb2ea65786cc9f9
  • Loading branch information
Open Source Bot authored and facebook-github-bot committed Aug 25, 2022
1 parent d2d25aa commit 40dc914
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 107 deletions.
60 changes: 2 additions & 58 deletions test/test_fx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def pattern(a):
TestCase(False, True, 0),
]


class MultipleOutputsHorizontalPattern:
@staticmethod
def forward(x):
Expand All @@ -598,61 +599,6 @@ def pattern(a):
TestCase(True, True, 0)
]

class PatternWithPseudoAny:
@staticmethod
def forward(x):
x = x.relu()
x = x.sigmoid()

y = x.relu()
y = y + 1

z = y.relu()
z = z.relu()

return z

@staticmethod
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.any(y)
return z

test_cases = [
# match_output, match_placeholder, num_matches
TestCase(False, False, 3),
TestCase(True, False, 1),
TestCase(False, True, 1),
TestCase(True, True, 0)
]

class PatternWithPseudoOneof:
@staticmethod
def forward(x):
x = x.relu()
x = torch.sigmoid(x)

z = x.relu()
z = torch.relu(z)

y = x.relu()
y = y + 1

return y

@staticmethod
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.oneof(y, targets=["torch.sigmoid", "operator.add"])
return z

test_cases = [
# match_output, match_placeholder, num_matches
TestCase(False, False, 2),
TestCase(True, False, 1),
TestCase(False, True, 1),
TestCase(True, True, 0)
]

@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):
Expand All @@ -670,9 +616,7 @@ class TestFXMatcherUtils(JitTestCase):
MultipleOutputsMultipleOverlappingMatches,
MultipleOutputsMultipleNonOverlappingMatches,
MultipleOutputsIdenticalAnchor,
MultipleOutputsHorizontalPattern,
PatternWithPseudoAny,
PatternWithPseudoOneof,
MultipleOutputsHorizontalPattern
])
def test_subgraph_matcher(self, test_model):
traced = symbolic_trace(test_model.forward)
Expand Down
49 changes: 0 additions & 49 deletions torch/fx/passes/utils/matcher_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from collections import defaultdict
import copy
import torch.library
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
Expand All @@ -10,42 +9,6 @@
__all__ = ['SubgraphMatcher', 'InternalMatch']


pseudo = torch.library.Library("pseudo", "DEF")

pseudo.define("any() -> ()")
"""
pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs.
For example, to match relu followed by one fx node:
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.any(y)
return z
"""

pseudo.define("oneof(*, str[] targets) -> ()")
"""
pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list.
`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid",
"torch.ops.aten.foo", "torch.ops.prims.bar"]
For example, using following pattern with pseudo.oneof
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
return z
It will have 3 matches in the following function
def forward(y):
z = y.relu()
x = z.relu() # first match
x = x.relu()
x = torch.sigmoid(x) # second match
x = x.relu()
return x + 1 # third match
"""

@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch():
Expand Down Expand Up @@ -117,18 +80,6 @@ def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
if not self.match_placeholder and pn.op == "placeholder":
return True

if pn.target == torch.ops.pseudo.any:
return True

if pn.target == torch.ops.pseudo.oneof:
permissible_targets: List[str] = pn.kwargs.get("targets", list()) # type: ignore[assignment]
assert isinstance(permissible_targets, list), \
"pseudo.oneof(permissible_targets=[\"foo\", \"bar\"]) only accept targets as a list"
assert len(permissible_targets) > 0, "please specific as least one target for pseudo.oneof"

if gn._pretty_print_target(gn.target) in permissible_targets:
return True

if pn.op == gn.op:
if pn.op == "placeholder" or pn.op == "output":
return True
Expand Down

0 comments on commit 40dc914

Please sign in to comment.