-
Notifications
You must be signed in to change notification settings - Fork 14.4k
Revert "[mlir] Improve mlir-query by adding matcher combinators" #145534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
qinkunbao
merged 1 commit into
main
from
revert-141423-Improve-Mlir-Query-with-matcher-combinators
Jun 24, 2025
Merged
Revert "[mlir] Improve mlir-query by adding matcher combinators" #145534
qinkunbao
merged 1 commit into
main
from
revert-141423-Improve-Mlir-Query-with-matcher-combinators
Jun 24, 2025
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
)" This reverts commit 12611a7.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Qinkun Bao (qinkunbao) ChangesReverts llvm/llvm-project#141423 Patch is 31.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145534.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 5fe6965f32efb..012bf7b9ec4a9 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,9 +108,6 @@ class MatcherDescriptor {
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;
- // If the matcher is variadic, it can take any number of arguments.
- virtual bool isVariadic() const = 0;
-
// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;
@@ -143,8 +140,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}
- bool isVariadic() const override { return false; }
-
unsigned getNumArgs() const override { return argKinds.size(); }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -158,54 +153,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
const std::vector<ArgKind> argKinds;
};
-class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
-public:
- using VarOp = DynMatcher::VariadicOperator;
- VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
- VarOp varOp, StringRef matcherName)
- : minCount(minCount), maxCount(maxCount), varOp(varOp),
- matcherName(matcherName) {}
-
- VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
- Diagnostics *error) const override {
- if (args.size() < minCount || maxCount < args.size()) {
- addError(error, nameRange, ErrorType::RegistryWrongArgCount,
- {llvm::Twine("requires between "), llvm::Twine(minCount),
- llvm::Twine(" and "), llvm::Twine(maxCount),
- llvm::Twine(" args, got "), llvm::Twine(args.size())});
- return VariantMatcher();
- }
-
- std::vector<VariantMatcher> innerArgs;
- for (int64_t i = 0, e = args.size(); i != e; ++i) {
- const ParserValue &arg = args[i];
- const VariantValue &value = arg.value;
- if (!value.isMatcher()) {
- addError(error, arg.range, ErrorType::RegistryWrongArgType,
- {llvm::Twine(i + 1), llvm::Twine("matcher: "),
- llvm::Twine(value.getTypeAsString())});
- return VariantMatcher();
- }
- innerArgs.push_back(value.getMatcher());
- }
- return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
- }
-
- bool isVariadic() const override { return true; }
-
- unsigned getNumArgs() const override { return 0; }
-
- void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
- kinds.push_back(ArgKind(ArgKind::Matcher));
- }
-
-private:
- const unsigned minCount;
- const unsigned maxCount;
- const VarOp varOp;
- const StringRef matcherName;
-};
-
// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
@@ -277,14 +224,6 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}
-// Variadic operator overload.
-template <unsigned MinCount, unsigned MaxCount>
-std::unique_ptr<MatcherDescriptor>
-makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
- StringRef matcherName) {
- return std::make_unique<VariadicOperatorMatcherDescriptor>(
- MinCount, MaxCount, func.varOp, matcherName);
-}
} // namespace mlir::query::matcher::internal
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 6d06ca13d1344..f8abf20ef60bb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,9 +21,7 @@
namespace mlir::query::matcher {
-/// Finds and collects matches from the IR. After construction
-/// `collectMatches` can be used to traverse the IR and apply
-/// matchers.
+/// A class that provides utilities to find operations in the IR.
class MatchFinder {
public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 88109430b6feb..183b2514e109f 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a
-// `match(...)` method whose parameters define the context of the match.
-// Support includes simple (unary) matchers as well as matcher combinators
-// (anyOf, allOf, etc.)
+// Matchers are methods that return a Matcher which provides a method one of the
+// following methods: match(Operation *op), match(Operation *op,
+// SetVector<Operation *> &matchedOps)
//
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
@@ -25,15 +25,6 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir::query::matcher {
-class DynMatcher;
-namespace internal {
-
-bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers);
-bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers);
-
-} // namespace internal
// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
@@ -93,27 +84,6 @@ class MatcherFnImpl : public MatcherInterface {
MatcherFn matcherFn;
};
-// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
-// match the given operation.
-using VariadicOperatorFunction = bool (*)(Operation *op,
- SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers);
-
-template <VariadicOperatorFunction Func>
-class VariadicMatcher : public MatcherInterface {
-public:
- VariadicMatcher(std::vector<DynMatcher> matchers)
- : matchers(std::move(matchers)) {}
-
- bool match(Operation *op) override { return Func(op, nullptr, matchers); }
- bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
- return Func(op, &matchedOps, matchers);
- }
-
-private:
- std::vector<DynMatcher> matchers;
-};
-
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
@@ -122,31 +92,6 @@ class DynMatcher {
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}
- // Construct from a variadic function.
- enum VariadicOperator {
- // Matches operations for which all provided matchers match.
- AllOf,
- // Matches operations for which at least one of the provided matchers
- // matches.
- AnyOf
- };
-
- static std::unique_ptr<DynMatcher>
- constructVariadic(VariadicOperator Op,
- std::vector<DynMatcher> innerMatchers) {
- switch (Op) {
- case AllOf:
- return std::make_unique<DynMatcher>(
- new VariadicMatcher<internal::allOfVariadicOperator>(
- std::move(innerMatchers)));
- case AnyOf:
- return std::make_unique<DynMatcher>(
- new VariadicMatcher<internal::anyOfVariadicOperator>(
- std::move(innerMatchers)));
- }
- llvm_unreachable("Invalid Op value.");
- }
-
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -168,59 +113,6 @@ class DynMatcher {
std::string functionName;
};
-// VariadicOperatorMatcher related types.
-template <typename... Ps>
-class VariadicOperatorMatcher {
-public:
- VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
- : varOp(varOp), params(std::forward<Ps>(params)...) {}
-
- operator std::unique_ptr<DynMatcher>() const & {
- return DynMatcher::constructVariadic(
- varOp, getMatchers(std::index_sequence_for<Ps...>()));
- }
-
- operator std::unique_ptr<DynMatcher>() && {
- return DynMatcher::constructVariadic(
- varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
- }
-
-private:
- // Helper method to unpack the tuple into a vector.
- template <std::size_t... Is>
- std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
- return {DynMatcher(std::get<Is>(params))...};
- }
-
- template <std::size_t... Is>
- std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
- return {DynMatcher(std::get<Is>(std::move(params)))...};
- }
-
- const DynMatcher::VariadicOperator varOp;
- std::tuple<Ps...> params;
-};
-
-// Overloaded function object to generate VariadicOperatorMatcher objects from
-// arbitrary matchers.
-template <unsigned MinCount, unsigned MaxCount>
-struct VariadicOperatorMatcherFunc {
- DynMatcher::VariadicOperator varOp;
-
- template <typename... Ms>
- VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
- static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
- "invalid number of parameters for variadic matcher");
- return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
- }
-};
-
-namespace internal {
-const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
- anyOf = {DynMatcher::AnyOf};
-const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
- allOf = {DynMatcher::AllOf};
-} // namespace internal
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 7181648f06f89..441205b3a9615 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,8 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines slicing-analysis matchers that extend and abstract the
-// core implementations from `SliceAnalysis.h`.
+// This file provides matchers for MLIRQuery that peform slicing analysis
//
//===----------------------------------------------------------------------===//
@@ -17,9 +16,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
-/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
-/// if `innerMatcher` matches. The traversal stops once the desired depth level
-/// is reached.
+/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
+/// Additionally, it limits the slice computation to a certain depth level using
+/// a custom filter.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
@@ -120,77 +119,6 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}
-/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
-/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
-template <typename BaseMatcher, typename Filter>
-class PredicateBackwardSliceMatcher {
-public:
- PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive, bool omitBlockArguments,
- bool omitUsesFromAbove)
- : innerMatcher(std::move(innerMatcher)),
- filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
- omitBlockArguments(omitBlockArguments),
- omitUsesFromAbove(omitUsesFromAbove) {}
-
- bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
- backwardSlice.clear();
- BackwardSliceOptions options;
- options.inclusive = inclusive;
- options.omitUsesFromAbove = omitUsesFromAbove;
- options.omitBlockArguments = omitBlockArguments;
- if (innerMatcher.match(rootOp)) {
- options.filter = [&](Operation *subOp) {
- return !filterMatcher.match(subOp);
- };
- LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
- assert(result.succeeded() && "expected backward slice to succeed");
- (void)result;
- return options.inclusive ? backwardSlice.size() > 1
- : backwardSlice.size() >= 1;
- }
- return false;
- }
-
-private:
- BaseMatcher innerMatcher;
- Filter filterMatcher;
- bool inclusive;
- bool omitBlockArguments;
- bool omitUsesFromAbove;
-};
-
-/// Computes the forward-slice of all users reachable from `rootOp`,
-/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
-template <typename BaseMatcher, typename Filter>
-class PredicateForwardSliceMatcher {
-public:
- PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive)
- : innerMatcher(std::move(innerMatcher)),
- filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
-
- bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
- forwardSlice.clear();
- ForwardSliceOptions options;
- options.inclusive = inclusive;
- if (innerMatcher.match(rootOp)) {
- options.filter = [&](Operation *subOp) {
- return !filterMatcher.match(subOp);
- };
- getForwardSlice(rootOp, &forwardSlice, options);
- return options.inclusive ? forwardSlice.size() > 1
- : forwardSlice.size() >= 1;
- }
- return false;
- }
-
-private:
- BaseMatcher innerMatcher;
- Filter filterMatcher;
- bool inclusive;
-};
-
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
@@ -202,7 +130,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}
-/// Matches all transitive defs of a top-level operation up to N levels.
+/// Matches all transitive defs of a top-level operation up to N levels
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
@@ -211,28 +139,6 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}
-/// Matches all transitive defs of a top-level operation and stops where
-/// `filterMatcher` rejects.
-template <typename BaseMatcher, typename Filter>
-inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
-m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive, bool omitBlockArguments,
- bool omitUsesFromAbove) {
- return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
- std::move(innerMatcher), std::move(filterMatcher), inclusive,
- omitBlockArguments, omitUsesFromAbove);
-}
-
-/// Matches all users of a top-level operation and stops where
-/// `filterMatcher` rejects.
-template <typename BaseMatcher, typename Filter>
-inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
-m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive) {
- return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
- std::move(innerMatcher), std::move(filterMatcher), inclusive);
-}
-
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 1a47576de1841..98c0a18e25101 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,12 +26,7 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
class VariantMatcher {
- class MatcherOps {
- public:
- std::optional<DynMatcher>
- constructVariadicOperator(DynMatcher::VariadicOperator varOp,
- ArrayRef<VariantMatcher> innerMatchers) const;
- };
+ class MatcherOps;
// Payload interface to be specialized by each matcher type. It follows a
// similar interface as VariantMatcher itself.
@@ -48,9 +43,6 @@ class VariantMatcher {
// Clones the provided matcher.
static VariantMatcher SingleMatcher(DynMatcher matcher);
- static VariantMatcher
- VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
- ArrayRef<VariantMatcher> args);
// Makes the matcher the "null" matcher.
void reset();
@@ -69,7 +61,6 @@ class VariantMatcher {
: value(std::move(value)) {}
class SinglePayload;
- class VariadicOpPayload;
std::shared_ptr<const Payload> value;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index ba202762fdfbb..629479bf7adc1 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
- MatchersInternal.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
deleted file mode 100644
index 01f412ade846b..0000000000000
--- a/mlir/lib/Query/Matcher/MatchersInternal.cpp
+++ /dev/null
@@ -1,33 +0,0 @@
-//===--- MatchersInternal.cpp----------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Query/Matcher/MatchersInternal.h"
-#include "llvm/ADT/SetVector.h"
-
-namespace mlir::query::matcher {
-
-namespace internal {
-
-bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers) {
- return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
- if (matchedOps)
- return matcher.match(op, *matchedOps);
- return matcher.match(op);
- });
-}
-bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers) {
- return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
- if (matchedOps)
- return matcher.match(op, *matchedOps);
- return matcher.match(op);
- });
-}
-} // namespace internal
-} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 08b610453b11a..4b511c5f009e7 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
unsigned argNumber = ctxEntry.second;
std::vector<ArgKind> nextTypeSet;
- if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
+ if (argNumber < ctor->getNumArgs())
ctor->getArgKinds(argNumber, nextTypeSet);
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
const internal::MatcherDescriptor &matcher = *m.getValue();
llvm::StringRef name = m.getKey();
- unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
+ unsigned numArgs = matcher.getNumArgs();
std::vector<std::vector<ArgKind>> argKinds(numArgs);
for (const ArgKind &kind : acceptedTypes) {
@@ -115,9 +115,6 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
}
...
[truncated]
|
DrSergei
pushed a commit
to DrSergei/llvm-project
that referenced
this pull request
Jun 24, 2025
anthonyhatran
pushed a commit
to anthonyhatran/llvm-project
that referenced
this pull request
Jun 26, 2025
rlavaee
pushed a commit
to rlavaee/llvm-project
that referenced
this pull request
Jul 1, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Reverts #141423