-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir] Reapply 141423 mlir-query combinators plus fix #146156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Reapply 141423 mlir-query combinators plus fix #146156
Conversation
Limit backward-slice with nested matching Add variadic operators Add test cases Add test cases for variadic matchers Relocate variadic matchers use signed for arithemtic & avoid copy Add verifier check for extract function Add slicing function extraction test; improve documentation; use lowercase for errors
3d88b01
to
fc67153
Compare
@llvm/pr-subscribers-mlir Author: Denzel-Brian Budii (chios202) ChangesAn uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of #141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild. Patch is 31.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146156.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..5fe6965f32efb 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;
+ // If the matcher is variadic, it can take any number of arguments.
+ virtual bool isVariadic() const = 0;
+
// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}
+ bool isVariadic() const override { return false; }
+
unsigned getNumArgs() const override { return argKinds.size(); }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
const std::vector<ArgKind> argKinds;
};
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+ using VarOp = DynMatcher::VariadicOperator;
+ VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+ VarOp varOp, StringRef matcherName)
+ : minCount(minCount), maxCount(maxCount), varOp(varOp),
+ matcherName(matcherName) {}
+
+ VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+ Diagnostics *error) const override {
+ if (args.size() < minCount || maxCount < args.size()) {
+ addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+ {llvm::Twine("requires between "), llvm::Twine(minCount),
+ llvm::Twine(" and "), llvm::Twine(maxCount),
+ llvm::Twine(" args, got "), llvm::Twine(args.size())});
+ return VariantMatcher();
+ }
+
+ std::vector<VariantMatcher> innerArgs;
+ for (int64_t i = 0, e = args.size(); i != e; ++i) {
+ const ParserValue &arg = args[i];
+ const VariantValue &value = arg.value;
+ if (!value.isMatcher()) {
+ addError(error, arg.range, ErrorType::RegistryWrongArgType,
+ {llvm::Twine(i + 1), llvm::Twine("matcher: "),
+ llvm::Twine(value.getTypeAsString())});
+ return VariantMatcher();
+ }
+ innerArgs.push_back(value.getMatcher());
+ }
+ return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+ }
+
+ bool isVariadic() const override { return true; }
+
+ unsigned getNumArgs() const override { return 0; }
+
+ void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+ kinds.push_back(ArgKind(ArgKind::Matcher));
+ }
+
+private:
+ const unsigned minCount;
+ const unsigned maxCount;
+ const VarOp varOp;
+ const StringRef matcherName;
+};
+
// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+ StringRef matcherName) {
+ return std::make_unique<VariadicOperatorMatcherDescriptor>(
+ MinCount, MaxCount, func.varOp, matcherName);
+}
} // namespace mlir::query::matcher::internal
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f8abf20ef60bb..6d06ca13d1344 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,7 +21,9 @@
namespace mlir::query::matcher {
-/// A class that provides utilities to find operations in the IR.
+/// Finds and collects matches from the IR. After construction
+/// `collectMatches` can be used to traverse the IR and apply
+/// matchers.
class MatchFinder {
public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..88109430b6feb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provides a
+// `match(...)` method whose parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators
+// (anyOf, allOf, etc.)
//
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
@@ -25,6 +25,15 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
@@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface {
MatcherFn matcherFn;
};
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+ SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+ VariadicMatcher(std::vector<DynMatcher> matchers)
+ : matchers(std::move(matchers)) {}
+
+ bool match(Operation *op) override { return Func(op, nullptr, matchers); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+ return Func(op, &matchedOps, matchers);
+ }
+
+private:
+ std::vector<DynMatcher> matchers;
+};
+
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
@@ -92,6 +122,31 @@ class DynMatcher {
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}
+ // Construct from a variadic function.
+ enum VariadicOperator {
+ // Matches operations for which all provided matchers match.
+ AllOf,
+ // Matches operations for which at least one of the provided matchers
+ // matches.
+ AnyOf
+ };
+
+ static std::unique_ptr<DynMatcher>
+ constructVariadic(VariadicOperator Op,
+ std::vector<DynMatcher> innerMatchers) {
+ switch (Op) {
+ case AllOf:
+ return std::make_unique<DynMatcher>(
+ new VariadicMatcher<internal::allOfVariadicOperator>(
+ std::move(innerMatchers)));
+ case AnyOf:
+ return std::make_unique<DynMatcher>(
+ new VariadicMatcher<internal::anyOfVariadicOperator>(
+ std::move(innerMatchers)));
+ }
+ llvm_unreachable("Invalid Op value.");
+ }
+
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +168,59 @@ class DynMatcher {
std::string functionName;
};
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+ VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+ : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+ operator std::unique_ptr<DynMatcher>() const & {
+ return DynMatcher::constructVariadic(
+ varOp, getMatchers(std::index_sequence_for<Ps...>()));
+ }
+
+ operator std::unique_ptr<DynMatcher>() && {
+ return DynMatcher::constructVariadic(
+ varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+ }
+
+private:
+ // Helper method to unpack the tuple into a vector.
+ template <std::size_t... Is>
+ std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+ return {DynMatcher(std::get<Is>(params))...};
+ }
+
+ template <std::size_t... Is>
+ std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+ return {DynMatcher(std::get<Is>(std::move(params)))...};
+ }
+
+ const DynMatcher::VariadicOperator varOp;
+ std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+ DynMatcher::VariadicOperator varOp;
+
+ template <typename... Ms>
+ VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+ static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+ "invalid number of parameters for variadic matcher");
+ return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+ }
+};
+
+namespace internal {
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+ anyOf = {DynMatcher::AnyOf};
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+ allOf = {DynMatcher::AllOf};
+} // namespace internal
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 441205b3a9615..7181648f06f89 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
//
//===----------------------------------------------------------------------===//
@@ -16,9 +17,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
@@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+ PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove)
+ : innerMatcher(std::move(innerMatcher)),
+ filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+ omitBlockArguments(omitBlockArguments),
+ omitUsesFromAbove(omitUsesFromAbove) {}
+
+ bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+ backwardSlice.clear();
+ BackwardSliceOptions options;
+ options.inclusive = inclusive;
+ options.omitUsesFromAbove = omitUsesFromAbove;
+ options.omitBlockArguments = omitBlockArguments;
+ if (innerMatcher.match(rootOp)) {
+ options.filter = [&](Operation *subOp) {
+ return !filterMatcher.match(subOp);
+ };
+ LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
+ assert(result.succeeded() && "expected backward slice to succeed");
+ (void)result;
+ return options.inclusive ? backwardSlice.size() > 1
+ : backwardSlice.size() >= 1;
+ }
+ return false;
+ }
+
+private:
+ BaseMatcher innerMatcher;
+ Filter filterMatcher;
+ bool inclusive;
+ bool omitBlockArguments;
+ bool omitUsesFromAbove;
+};
+
+/// Computes the forward-slice of all users reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateForwardSliceMatcher {
+public:
+ PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive)
+ : innerMatcher(std::move(innerMatcher)),
+ filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
+
+ bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
+ forwardSlice.clear();
+ ForwardSliceOptions options;
+ options.inclusive = inclusive;
+ if (innerMatcher.match(rootOp)) {
+ options.filter = [&](Operation *subOp) {
+ return !filterMatcher.match(subOp);
+ };
+ getForwardSlice(rootOp, &forwardSlice, options);
+ return options.inclusive ? forwardSlice.size() > 1
+ : forwardSlice.size() >= 1;
+ }
+ return false;
+ }
+
+private:
+ BaseMatcher innerMatcher;
+ Filter filterMatcher;
+ bool inclusive;
+};
+
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
@@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
@@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove) {
+ return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+ std::move(innerMatcher), std::move(filterMatcher), inclusive,
+ omitBlockArguments, omitUsesFromAbove);
+}
+
+/// Matches all users of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
+m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive) {
+ return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
+ std::move(innerMatcher), std::move(filterMatcher), inclusive);
+}
+
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
class VariantMatcher {
- class MatcherOps;
+ class MatcherOps {
+ public:
+ std::optional<DynMatcher>
+ constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> innerMatchers) const;
+ };
// Payload interface to be specialized by each matcher type. It follows a
// similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
// Clones the provided matcher.
static VariantMatcher SingleMatcher(DynMatcher matcher);
+ static VariantMatcher
+ VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> args);
// Makes the matcher the "null" matcher.
void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
: value(std::move(value)) {}
class SinglePayload;
+ class VariadicOpPayload;
std::shared_ptr<const Payload> value;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
+ MatchersInternal.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..01f412ade846b
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,33 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers) {
+ return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
+ if (matchedOps)
+ return matcher.match(op, *matchedOps);
+ return matcher.match(op);
+ });
+}
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers) {
+ return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
+ if (matchedOps)
+ return matcher.match(op, *matchedOps);
+ return matcher.match(op);
+ });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
unsigned argNumber = ctxEntry.second;
std::vector<ArgKind> nextTypeSet;
- if (argNumber < ctor->getNumArgs())
+ if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
ctor->getArgKinds(argNumber, nextTypeSet);
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
const internal::MatcherDescriptor &matcher = *m.getValue();
llvm::StringRef name = m.getKey();
- unsigned numArgs = matcher.getNumArgs();
+ unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
std::vector<std::vector<ArgKind>> argKinds(numArgs);
for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
}
}
...
[truncated]
|
@llvm/pr-subscribers-mlir-core Author: Denzel-Brian Budii (chios202) ChangesAn uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of #141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild. Patch is 31.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146156.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..5fe6965f32efb 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;
+ // If the matcher is variadic, it can take any number of arguments.
+ virtual bool isVariadic() const = 0;
+
// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}
+ bool isVariadic() const override { return false; }
+
unsigned getNumArgs() const override { return argKinds.size(); }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
const std::vector<ArgKind> argKinds;
};
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+ using VarOp = DynMatcher::VariadicOperator;
+ VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+ VarOp varOp, StringRef matcherName)
+ : minCount(minCount), maxCount(maxCount), varOp(varOp),
+ matcherName(matcherName) {}
+
+ VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+ Diagnostics *error) const override {
+ if (args.size() < minCount || maxCount < args.size()) {
+ addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+ {llvm::Twine("requires between "), llvm::Twine(minCount),
+ llvm::Twine(" and "), llvm::Twine(maxCount),
+ llvm::Twine(" args, got "), llvm::Twine(args.size())});
+ return VariantMatcher();
+ }
+
+ std::vector<VariantMatcher> innerArgs;
+ for (int64_t i = 0, e = args.size(); i != e; ++i) {
+ const ParserValue &arg = args[i];
+ const VariantValue &value = arg.value;
+ if (!value.isMatcher()) {
+ addError(error, arg.range, ErrorType::RegistryWrongArgType,
+ {llvm::Twine(i + 1), llvm::Twine("matcher: "),
+ llvm::Twine(value.getTypeAsString())});
+ return VariantMatcher();
+ }
+ innerArgs.push_back(value.getMatcher());
+ }
+ return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+ }
+
+ bool isVariadic() const override { return true; }
+
+ unsigned getNumArgs() const override { return 0; }
+
+ void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+ kinds.push_back(ArgKind(ArgKind::Matcher));
+ }
+
+private:
+ const unsigned minCount;
+ const unsigned maxCount;
+ const VarOp varOp;
+ const StringRef matcherName;
+};
+
// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+ StringRef matcherName) {
+ return std::make_unique<VariadicOperatorMatcherDescriptor>(
+ MinCount, MaxCount, func.varOp, matcherName);
+}
} // namespace mlir::query::matcher::internal
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f8abf20ef60bb..6d06ca13d1344 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,7 +21,9 @@
namespace mlir::query::matcher {
-/// A class that provides utilities to find operations in the IR.
+/// Finds and collects matches from the IR. After construction
+/// `collectMatches` can be used to traverse the IR and apply
+/// matchers.
class MatchFinder {
public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..88109430b6feb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provides a
+// `match(...)` method whose parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators
+// (anyOf, allOf, etc.)
//
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
@@ -25,6 +25,15 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
@@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface {
MatcherFn matcherFn;
};
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+ SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+ VariadicMatcher(std::vector<DynMatcher> matchers)
+ : matchers(std::move(matchers)) {}
+
+ bool match(Operation *op) override { return Func(op, nullptr, matchers); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+ return Func(op, &matchedOps, matchers);
+ }
+
+private:
+ std::vector<DynMatcher> matchers;
+};
+
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
@@ -92,6 +122,31 @@ class DynMatcher {
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}
+ // Construct from a variadic function.
+ enum VariadicOperator {
+ // Matches operations for which all provided matchers match.
+ AllOf,
+ // Matches operations for which at least one of the provided matchers
+ // matches.
+ AnyOf
+ };
+
+ static std::unique_ptr<DynMatcher>
+ constructVariadic(VariadicOperator Op,
+ std::vector<DynMatcher> innerMatchers) {
+ switch (Op) {
+ case AllOf:
+ return std::make_unique<DynMatcher>(
+ new VariadicMatcher<internal::allOfVariadicOperator>(
+ std::move(innerMatchers)));
+ case AnyOf:
+ return std::make_unique<DynMatcher>(
+ new VariadicMatcher<internal::anyOfVariadicOperator>(
+ std::move(innerMatchers)));
+ }
+ llvm_unreachable("Invalid Op value.");
+ }
+
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +168,59 @@ class DynMatcher {
std::string functionName;
};
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+ VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+ : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+ operator std::unique_ptr<DynMatcher>() const & {
+ return DynMatcher::constructVariadic(
+ varOp, getMatchers(std::index_sequence_for<Ps...>()));
+ }
+
+ operator std::unique_ptr<DynMatcher>() && {
+ return DynMatcher::constructVariadic(
+ varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+ }
+
+private:
+ // Helper method to unpack the tuple into a vector.
+ template <std::size_t... Is>
+ std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+ return {DynMatcher(std::get<Is>(params))...};
+ }
+
+ template <std::size_t... Is>
+ std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+ return {DynMatcher(std::get<Is>(std::move(params)))...};
+ }
+
+ const DynMatcher::VariadicOperator varOp;
+ std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+ DynMatcher::VariadicOperator varOp;
+
+ template <typename... Ms>
+ VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+ static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+ "invalid number of parameters for variadic matcher");
+ return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+ }
+};
+
+namespace internal {
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+ anyOf = {DynMatcher::AnyOf};
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+ allOf = {DynMatcher::AllOf};
+} // namespace internal
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 441205b3a9615..7181648f06f89 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
//
//===----------------------------------------------------------------------===//
@@ -16,9 +17,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
@@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+ PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove)
+ : innerMatcher(std::move(innerMatcher)),
+ filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+ omitBlockArguments(omitBlockArguments),
+ omitUsesFromAbove(omitUsesFromAbove) {}
+
+ bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+ backwardSlice.clear();
+ BackwardSliceOptions options;
+ options.inclusive = inclusive;
+ options.omitUsesFromAbove = omitUsesFromAbove;
+ options.omitBlockArguments = omitBlockArguments;
+ if (innerMatcher.match(rootOp)) {
+ options.filter = [&](Operation *subOp) {
+ return !filterMatcher.match(subOp);
+ };
+ LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
+ assert(result.succeeded() && "expected backward slice to succeed");
+ (void)result;
+ return options.inclusive ? backwardSlice.size() > 1
+ : backwardSlice.size() >= 1;
+ }
+ return false;
+ }
+
+private:
+ BaseMatcher innerMatcher;
+ Filter filterMatcher;
+ bool inclusive;
+ bool omitBlockArguments;
+ bool omitUsesFromAbove;
+};
+
+/// Computes the forward-slice of all users reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateForwardSliceMatcher {
+public:
+ PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive)
+ : innerMatcher(std::move(innerMatcher)),
+ filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
+
+ bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
+ forwardSlice.clear();
+ ForwardSliceOptions options;
+ options.inclusive = inclusive;
+ if (innerMatcher.match(rootOp)) {
+ options.filter = [&](Operation *subOp) {
+ return !filterMatcher.match(subOp);
+ };
+ getForwardSlice(rootOp, &forwardSlice, options);
+ return options.inclusive ? forwardSlice.size() > 1
+ : forwardSlice.size() >= 1;
+ }
+ return false;
+ }
+
+private:
+ BaseMatcher innerMatcher;
+ Filter filterMatcher;
+ bool inclusive;
+};
+
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
@@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
@@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove) {
+ return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+ std::move(innerMatcher), std::move(filterMatcher), inclusive,
+ omitBlockArguments, omitUsesFromAbove);
+}
+
+/// Matches all users of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
+m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive) {
+ return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
+ std::move(innerMatcher), std::move(filterMatcher), inclusive);
+}
+
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
class VariantMatcher {
- class MatcherOps;
+ class MatcherOps {
+ public:
+ std::optional<DynMatcher>
+ constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> innerMatchers) const;
+ };
// Payload interface to be specialized by each matcher type. It follows a
// similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
// Clones the provided matcher.
static VariantMatcher SingleMatcher(DynMatcher matcher);
+ static VariantMatcher
+ VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> args);
// Makes the matcher the "null" matcher.
void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
: value(std::move(value)) {}
class SinglePayload;
+ class VariadicOpPayload;
std::shared_ptr<const Payload> value;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
+ MatchersInternal.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..01f412ade846b
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,33 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers) {
+ return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
+ if (matchedOps)
+ return matcher.match(op, *matchedOps);
+ return matcher.match(op);
+ });
+}
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers) {
+ return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
+ if (matchedOps)
+ return matcher.match(op, *matchedOps);
+ return matcher.match(op);
+ });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
unsigned argNumber = ctxEntry.second;
std::vector<ArgKind> nextTypeSet;
- if (argNumber < ctor->getNumArgs())
+ if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
ctor->getArgKinds(argNumber, nextTypeSet);
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
const internal::MatcherDescriptor &matcher = *m.getValue();
llvm::StringRef name = m.getKey();
- unsigned numArgs = matcher.getNumArgs();
+ unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
std::vector<std::vector<ArgKind>> argKinds(numArgs);
for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
}
}
...
[truncated]
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/207/builds/3192 Here is the relevant piece of the build log for the reference
|
An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of llvm#141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild. Note: the fix is included as part of the second commit
An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of llvm#141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild. Note: the fix is included as part of the second commit
An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of #141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild.
Note: the fix is included as part of the second commit