-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][Vector] Add vector.shuffle
tree transformation
#145740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesThis PR adds a new transformation that turns sequences of Example:
The algorithm leverages the structured extraction/insertion information of There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along. Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 16,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ VectorShuffleTreeBuilder() = delete;
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ FromElementsOp fromElementsOp;
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ unsigned numLevels;
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
+ toElementsDefs.push_back(toElementsDefs.back());
+ }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+ // Map `vector.to_elements` ops to their ordinal position in the
+ // `vector.from_elements` operand list. Make sure duplicated
+ // `vector.to_elements` ops are mapped to the its first occurrence.
+ DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+ for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+ toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+ // Compute intervals for each input vector in the shuffle tree. The first
+ // level computation is special-cased to keep the implementation simpler.
+
+ SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (const auto &[idx, element] :
+ llvm::enumerate(fromElementsOp.getElements())) {
+ auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ Interval ¤tInterval = firstLevelIntervals[inputIdx];
+
+ // Set lower bound to the first occurrence of the `vector.to_elements`.
+ if (currentInterval.first == kMaxUnsigned)
+ currentInterval.first = idx;
+
+ // Set upper bound to the last occurrence of the `vector.to_elements`.
+ currentInterval.second = idx;
+ }
+
+ // If the number of `vector.to_elements` is odd and the last op was
+ // duplicated, the interval for the duplicated op was not computed in the
+ // previous step as all the input occurrences were mapped to the original op.
+ // We copy the interval of the original op to the interval of the duplicated
+ // op manually.
+ if (firstLevelIntervals.back().second == kMaxUnsigned)
+ firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+ inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+ // Compute intervals for the remaining levels.
+ unsigned outputNumElements =
+ cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ for (unsigned level = 1; level < numLevels; ++level) {
+ const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+ SmallVector<Interval> currentLevelIntervals(
+ llvm::divideCeil(prevLevelIntervals.size(), 2),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+ ++inputIdx) {
+ auto &interval = currentLevelIntervals[inputIdx];
+ const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+ const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+ // The interval of a vector at the current level is the union of the
+ // intervals of the two input vectors from the previous level being
+ // shuffled at this level.
+ interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+ interval.second =
+ std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+ outputNumElements - 1);
+ }
+
+ inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+ }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// Intervals:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+///
+/// Vector sizes:
+/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+ // Compute vector size for each level.
+ for (unsigned level = 0; level < numLevels; ++level) {
+ const auto ¤tLevelIntervals = inputIntervalsPerLevel[level];
+ unsigned currentVectorSize = 1;
+ for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+ const auto &lhsInterval = currentLevelIntervals[i];
+ const auto &rhsInterval = currentLevelIntervals[i + 1];
+ unsigned combinedIntervalSize =
+ std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+ 1;
+ currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+ }
+ vectorSizePerLevel[level] = currentVectorSize;
+ }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+ LLVM_DEBUG({
+ unsigned indLv = 0;
+
+ llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+ ++indLv;
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+ ++indLv;
+ for (const auto &toElementsOp : toElementsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ --indLv;
+
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Total levels: " << numLevels << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Vector sizes per level: [";
+ llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Input intervals per level:\n";
+ ++indLv;
+ for (const auto &[level, intervals] :
+ llvm::enumerate(inputIntervalsPerLevel)) {
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+ << ": ";
+ llvm::interleaveComma(intervals, llvm::dbgs(),
+ [](const Interval &interval) {
+ llvm::dbgs() << "[" << interval.first << ","
+ << interval.second << "]";
+ });
+ llvm::dbgs() << "\n";
+ }
+ });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// build a tree that looks like:
+///
+/// %2 %1 %0 %0
+/// \ / \ /
+/// %2_1 = vector.shuffle %0_0 = vector.shuffle
+/// \ /
+/// %2_1_0_0 =vector.shuffle
+///
+/// The configuration comprises of computing the intervals of the input vectors
+/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please, take a look at the
+/// corresponding utility functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+ // Initialize shuffle tree information based on its size.
+ assert(toElementsDefs.size() > 1 &&
+ "At least two 'vector.to_elements' ops are required");
+ numLevels = llvm::Log2_64(toElementsDefs.size());
+ vectorSizePerLevel.resize(numLevels, 0);
+ inputIntervalsPerLevel.reserve(numLevels);
+
+ computeInputVectorIntervals();
+ computeOutputVectorSizePerLevel();
+ dump();
+
+ return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// =>
+///
+/// // Level 0, vector length = 8
+/// %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds a new transformation that turns sequences of Example:
The algorithm leverages the structured extraction/insertion information of There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along. Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 16,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ VectorShuffleTreeBuilder() = delete;
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ FromElementsOp fromElementsOp;
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ unsigned numLevels;
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
+ toElementsDefs.push_back(toElementsDefs.back());
+ }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+ // Map `vector.to_elements` ops to their ordinal position in the
+ // `vector.from_elements` operand list. Make sure duplicated
+ // `vector.to_elements` ops are mapped to the its first occurrence.
+ DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+ for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+ toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+ // Compute intervals for each input vector in the shuffle tree. The first
+ // level computation is special-cased to keep the implementation simpler.
+
+ SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (const auto &[idx, element] :
+ llvm::enumerate(fromElementsOp.getElements())) {
+ auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ Interval ¤tInterval = firstLevelIntervals[inputIdx];
+
+ // Set lower bound to the first occurrence of the `vector.to_elements`.
+ if (currentInterval.first == kMaxUnsigned)
+ currentInterval.first = idx;
+
+ // Set upper bound to the last occurrence of the `vector.to_elements`.
+ currentInterval.second = idx;
+ }
+
+ // If the number of `vector.to_elements` is odd and the last op was
+ // duplicated, the interval for the duplicated op was not computed in the
+ // previous step as all the input occurrences were mapped to the original op.
+ // We copy the interval of the original op to the interval of the duplicated
+ // op manually.
+ if (firstLevelIntervals.back().second == kMaxUnsigned)
+ firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+ inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+ // Compute intervals for the remaining levels.
+ unsigned outputNumElements =
+ cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ for (unsigned level = 1; level < numLevels; ++level) {
+ const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+ SmallVector<Interval> currentLevelIntervals(
+ llvm::divideCeil(prevLevelIntervals.size(), 2),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+ ++inputIdx) {
+ auto &interval = currentLevelIntervals[inputIdx];
+ const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+ const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+ // The interval of a vector at the current level is the union of the
+ // intervals of the two input vectors from the previous level being
+ // shuffled at this level.
+ interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+ interval.second =
+ std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+ outputNumElements - 1);
+ }
+
+ inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+ }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// Intervals:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+///
+/// Vector sizes:
+/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+ // Compute vector size for each level.
+ for (unsigned level = 0; level < numLevels; ++level) {
+ const auto ¤tLevelIntervals = inputIntervalsPerLevel[level];
+ unsigned currentVectorSize = 1;
+ for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+ const auto &lhsInterval = currentLevelIntervals[i];
+ const auto &rhsInterval = currentLevelIntervals[i + 1];
+ unsigned combinedIntervalSize =
+ std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+ 1;
+ currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+ }
+ vectorSizePerLevel[level] = currentVectorSize;
+ }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+ LLVM_DEBUG({
+ unsigned indLv = 0;
+
+ llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+ ++indLv;
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+ ++indLv;
+ for (const auto &toElementsOp : toElementsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ --indLv;
+
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Total levels: " << numLevels << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Vector sizes per level: [";
+ llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Input intervals per level:\n";
+ ++indLv;
+ for (const auto &[level, intervals] :
+ llvm::enumerate(inputIntervalsPerLevel)) {
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+ << ": ";
+ llvm::interleaveComma(intervals, llvm::dbgs(),
+ [](const Interval &interval) {
+ llvm::dbgs() << "[" << interval.first << ","
+ << interval.second << "]";
+ });
+ llvm::dbgs() << "\n";
+ }
+ });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// build a tree that looks like:
+///
+/// %2 %1 %0 %0
+/// \ / \ /
+/// %2_1 = vector.shuffle %0_0 = vector.shuffle
+/// \ /
+/// %2_1_0_0 =vector.shuffle
+///
+/// The configuration comprises of computing the intervals of the input vectors
+/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please, take a look at the
+/// corresponding utility functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+ // Initialize shuffle tree information based on its size.
+ assert(toElementsDefs.size() > 1 &&
+ "At least two 'vector.to_elements' ops are required");
+ numLevels = llvm::Log2_64(toElementsDefs.size());
+ vectorSizePerLevel.resize(numLevels, 0);
+ inputIntervalsPerLevel.reserve(numLevels);
+
+ computeInputVectorIntervals();
+ computeOutputVectorSizePerLevel();
+ dump();
+
+ return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// =>
+///
+/// // Level 0, vector length = 8
+/// %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the nice documention! I think I get the basic idea, but I need to spend some more time getting into the details. Possible edge case to test out:
func.func @foo(
%a : vector<2xf32>,
%b : vector<1xf32>,
%c : vector<f32>,
%d : vector<f32>,
%e : vector<f32>) -> vector<6xf32> {
%0:2 = vector.to_elements %a : vector<2xf32>
%1:1 = vector.to_elements %b : vector<1xf32>
%2:1 = vector.to_elements %c : vector<f32>
%3:1 = vector.to_elements %d : vector<f32>
%4:1 = vector.to_elements %e : vector<f32>
%5 = vector.from_elements %0#0, %1#0, %2#0, %3#0, %4#0, %0#1 : vector<6xf32>
return %5 : vector<6xf32>
}
LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = fromElementsOp.getType(); | ||
if (resultType.getRank() != 1 || resultType.isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this PR, but this rank check got me wondering. I would like to propose removing the implicit abillity to do a shape_cast out of vector.to_elements
and vector.from_elements
operations, so that they must act on rank-1 vectors. Actually I've thought this before for other Vector ops that do reshape-like things.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check here comes from the limitations of vector.shape
to represent n-D shuffles, not really from the vector.to_/from_elements
. That limitation is actually more like a TODO that we should address at some point.
vector.to_/from_elements
semantics naturally extend to n-D vectors given the extraction/insertion order they define but, yes, I guess we could see it as an "implicit shape cast"...
I think, though, we've been moving towards the opposite direction. To have a cohesive multi-dimensional vector layer we need these "implicit shape casts" so that ops work nicely across the board without having to special-case 1-D from n-D... This supports even more the idea that shape casts are really no-ops...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the check for scalable vectors is on the same line ... :)
FromElementsOp
doesn't support scalable vectors. It would be good to add a comment - or better yet, replace failure with notifyMatchFailure
. 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the solution to this particular TODO is to have the final shuffle create a 1-D vector that feeds a shape_cast op that is the final replacement of the from_elements.
I think that ops could still work nicely across the board with more explicit shape_casts, without detracting from their n-D nature. But if the only operation that allows the rank to change is a shape_cast, many of the (quadratic) interactions between ops would be greatly simplified. A topic for another place and time, let me focus on this PR now!
|
||
// Duplicate the last vector if the number of `vector.to_elements` is odd to | ||
// simplify the shuffle tree algorithm. | ||
if (toElementsDefs.size() % 2 != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a check that it is a power of 2?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check refers to the number of vector.to_elements
inputs to combine so we want to be able to combine an arbitrary number of inputs. If that number is not even, we duplicate the las input to simplify the algorithm (the shuffle for that input would have the same input vector twice). Does it make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It makes sense. I was thinking that you might want to 'pad' all the way to a power of 2 so that shuffles at all depths were good. But your approach of a padding to a power of 2 at each level is more efficiently (O(logN) vs O(N) padding).
++inputIdx) { | ||
auto &interval = currentLevelIntervals[inputIdx]; | ||
const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; | ||
const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to power-of-2 comment: If previous level here had 3 intervals, current level has 2. If inputIdx = 1 here, you're accessing index 3 of previous intervals -- problem? That's why I think it might be necessary to ensure the number of starting intervals is a power of 2 (stricter than just being even).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I thought I had a check to duplicate the last input, similar to the one in the constructor, but I must have removed it at some point. Let me fix that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! Happy to clarify any questions you may have!
|
||
// Duplicate the last vector if the number of `vector.to_elements` is odd to | ||
// simplify the shuffle tree algorithm. | ||
if (toElementsDefs.size() % 2 != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check refers to the number of vector.to_elements
inputs to combine so we want to be able to combine an arbitrary number of inputs. If that number is not even, we duplicate the las input to simplify the algorithm (the shuffle for that input would have the same input vector twice). Does it make sense?
++inputIdx) { | ||
auto &interval = currentLevelIntervals[inputIdx]; | ||
const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; | ||
const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I thought I had a check to duplicate the last input, similar to the one in the constructor, but I must have removed it at some point. Let me fix that.
LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = fromElementsOp.getType(); | ||
if (resultType.getRank() != 1 || resultType.isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check here comes from the limitations of vector.shape
to represent n-D shuffles, not really from the vector.to_/from_elements
. That limitation is actually more like a TODO that we should address at some point.
vector.to_/from_elements
semantics naturally extend to n-D vectors given the extraction/insertion order they define but, yes, I guess we could see it as an "implicit shape cast"...
I think, though, we've been moving towards the opposite direction. To have a cohesive multi-dimensional vector layer we need these "implicit shape casts" so that ops work nicely across the board without having to special-case 1-D from n-D... This supports even more the idea that shape casts are really no-ops...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - this is quite involved, but you've done a great job documenting and modularising it!
The high-level logic makes sense, but some of the finer details are still unclear to me. I’ll definitely need a few more passes through it 😅
As usual, I started with the tests to get a broad overview. I’ve left a few comments there - mostly suggesting more emphasis on edge cases. Maybe you could consider grouping the tests to make those clearer?
Also, replacing some (or most) uses of failure with notifyMatchFailure would be great 🙂 #selfDocumentingCode
FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs) | ||
: fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) { | ||
|
||
assert(fromElementsOp && "from_elements op is required"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this first assert required? fromElemOp
is a mandatory argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it can be null?
mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
Show resolved
Hide resolved
/// 2. Each input vector to each level is used only once. | ||
/// 3. The number of levels in the tree is: | ||
/// ceil(log2(# `vector.to_elements` ops)). | ||
/// 4. Vectors at each level of the tree have the same vector length. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if the inputs to vector.to_elements
don't meet this criteria?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation bails out, although it should be easy to support...
mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
Show resolved
Hide resolved
LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType resultType = fromElementsOp.getType(); | ||
if (resultType.getRank() != 1 || resultType.isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the check for scalable vectors is on the same line ... :)
FromElementsOp
doesn't support scalable vectors. It would be good to add a comment - or better yet, replace failure with notifyMatchFailure
. 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Drop to_from
(and all the variants of that) from function names. The test file already encodes the fact that all tests exercise the vector.to_elements
+ vector.from_elements
-> vector.shuffle
.
Also, what are the high-level categories in this test files? I see two:
- genuine shuffle (e.g.
@to_from_elements_single_input_shuffle
) - concat
@to_from_elements_shuffle_tree_concat_4x8_to_32
- concat with poison values (e.g.
@to_from_elements_shuffle_tree_concat_3x4_to_12
)
Anything else? If this is correct, it would be good to clarify this split.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's complicated but it general we have 3 categories: concatenations, broadcast and arbitrary shuffles. I'm using those tags in the function names. Poison vs non-poison is a bit orthogonal as poison may appear at any level of the tree (or not...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, makes sense!
What about cases where the input vectors have different length? I don't see any tests for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have shuffle_tree_arbitrary_mixed_sizes
// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to | ||
// the shuffle index within that level. | ||
|
||
func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comparing these function names its hard to tell what the difference is:
@to_from_elements_single_input_shuffle
,@from_elements_to_elements_single_shuffle
Wouldn't this be clearer:
@single_input
@multiple_inputs
or@two_inputs
Ultimately, it's:
// single input
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
vs
// two inputs
%2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic that I've followed is:
- Shuffles have two inputs so no need to specify "multi_input" shuffle every time. It's the "default".
- Single input shuffle is the exception so it's worth adding the "single_input" tag for it.
- Shuffle tree has multiple shuffles in general so no need to specify "multi_suffle". It's the "default"
- Single shuffle tree is the exception so it's worth adding the "single_shuffle" tag for it.
|
||
// ----- | ||
|
||
func.func @to_from_elements_shuffle_tree_concat_64x4_256( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO, this example is a bit too long and I'm not sure whether it adds much unique coverage. Do we believe that jumping from e.g. 4 to 64 input vectors changes much?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it changes the depth of the tree... the other tests are mostly generating 1 or 2 levels. I think it's important to test a large depth at least once.
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] | ||
/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] | ||
/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] | ||
/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is confusing to me. Looking at this sentence above:
/// The interval of an input vector is the range of positions in the final output that the input vector contributes to.
And:
/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
I see that the range for %0 is [0, 4]
(%0#
+ %0#4
), but:
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
Could you add a bit more explanation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow the comment. Let me explain to see where the gap is:
Level 0 has 4 inputs (%2, %1, %0, %0, ...
- inputIntervalsPerLevel[0][0]
The first index corresponds to the level (0) and the second to the input at that level, so the input 0 at level 0 is %2
. That's why:
/// * inputIntervalsPerLevel[0][0] = interval(%2)
Could you help me understand what is missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I see why I got confused:
/// The interval of an input vector is the range of positions in the final output that the input vector contributes to.
vs input in inputIntervalsPerLevel
. Perhaps:
inputIntervalsPerLevel
->outputIntervalsPerLevelPerVector
?outputIntervals
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The intervals model what an input vector contributes to the final output but they belong to the input vectors to a specific level. They are used to produce an output interval, which will be an input interval for the next level. That is:
input vector 0 input vector 1 input vector 2 input vector 3
| | | |
Iteration 0:
input interval 0 input interval 1 input interval 2 input interval 3
\ / \ /
------------------------------------------- tree level 0 ---------------------------------------
| |
output interval 0-1 output interval 2-3
| |
Iteration 1:
| |
input interval 0-1 input interval 2-3
\ /
------------------------------------------- tree level 1 ---------------------------------------
|
output interval 0-1-2-3
I think if we use output
in the name of what is called "input interval" right now is going to be confusing.
Actually, I can add this amazing ASCII to the doc and elaborate a bit more if that helps! :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or perhaps we can just call them intervals... Let me try to simplify this...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this transformation in the first place? Is it to make lowering to llvm/spirv easier?
This is mostly implementing "2. Simplified Pattern Recognition and Optimization" in https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779. It's not about making the lowering easier but far more efficient both in terms of performance and compile time. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback!
- Fixed value/op/interval duplication bug.
- Constrained the match to uniform vector inputs with rank.
- Added more tests
- Addressed misc. feedback
/// 2. Each input vector to each level is used only once. | ||
/// 3. The number of levels in the tree is: | ||
/// ceil(log2(# `vector.to_elements` ops)). | ||
/// 4. Vectors at each level of the tree have the same vector length. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation bails out, although it should be easy to support...
mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
Show resolved
Hide resolved
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] | ||
/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] | ||
/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] | ||
/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow the comment. Let me explain to see where the gap is:
Level 0 has 4 inputs (%2, %1, %0, %0, ...
- inputIntervalsPerLevel[0][0]
The first index corresponds to the level (0) and the second to the input at that level, so the input 0 at level 0 is %2
. That's why:
/// * inputIntervalsPerLevel[0][0] = interval(%2)
Could you help me understand what is missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's complicated but it general we have 3 categories: concatenations, broadcast and arbitrary shuffles. I'm using those tags in the function names. Poison vs non-poison is a bit orthogonal as poison may appear at any level of the tree (or not...)
// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to | ||
// the shuffle index within that level. | ||
|
||
func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic that I've followed is:
- Shuffles have two inputs so no need to specify "multi_input" shuffle every time. It's the "default".
- Single input shuffle is the exception so it's worth adding the "single_input" tag for it.
- Shuffle tree has multiple shuffles in general so no need to specify "multi_suffle". It's the "default"
- Single shuffle tree is the exception so it's worth adding the "single_shuffle" tag for it.
|
||
// ----- | ||
|
||
func.func @to_from_elements_shuffle_tree_concat_64x4_256( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it changes the depth of the tree... the other tests are mostly generating 1 or 2 levels. I think it's important to test a large depth at least once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates!
From what I can tell, only rank-1 vectors with equal number of elements are supported? It would be good to capture this high-level restriction somewhere in the code + comments, as well as in tests.
Given the complexity of this logic, would it make sense to extend it to more generic cases?
/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] | ||
/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] | ||
/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] | ||
/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I see why I got confused:
/// The interval of an input vector is the range of positions in the final output that the input vector contributes to.
vs input in inputIntervalsPerLevel
. Perhaps:
inputIntervalsPerLevel
->outputIntervalsPerLevelPerVector
?outputIntervals
?
if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs))) | ||
return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources"); | ||
|
||
int64_t numElements = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear what numElements
captures.
int64_t numElements = | |
int64_t expectedNumElementsPerVector = |
Actually, just after sending the update I realized that inputs with mixed vector sizes are supported because |
If full generality is not required, I would skip. The number of test case might get out of hand very quickly. This change is already quite large. For me it would be sufficient if you left a TODO somewhere at the top of the file and added a negative test for mixed sizes. Sometimes less is more :) |
|
I enabled it as combining elements from two vectors with arbitrary vector sizes should actually be a very common use case... and I already had tests for it. Implementation-wise, it doesn't change anything |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates! Just a few more minor comments from me in this review.
Regarding the top-level motivation:
This is mostly implementing "2. Simplified Pattern Recognition and Optimization" in https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779. It's not about making the lowering easier but far more efficient both in terms of performance and compile time.
I agree with this, but can we please get a few more details? Here are my initial thoughts along this line. Suppose the total number of elements in the final from_elements vector is N
and the total number of to_elements ops is T
. So T
< N
, and the ratio N/T
is the mean number of extracted elements that go into the final vector.
My understanding is that the total number of shuffles will still be O(T)
, so the number of ops isn't reduced. But the mean size of the shuffles (mask size) will now hopefully be O(1)
because of the leaves being merged with small masks. This is better than the O(N)
with the current (naive) lowering. Is that correct, and is that where the efficiency gain will come from?
Worst case scenario for this pattern? I'm wondering what happens with something that looks like a transpose where the elements are 'scattered'. Like
[[ 0#0 1#0 ... 7#0 ]
[ 0#1 1#1 ... 7#1 ]]
If I understand correctly this case the mean size of the shuffle is still O(N) (with a better constant) because from level 1 onwards the intervals are large. This is no worse than the 'naive' case, so it seems like the pattern can only improve the situation. Wondering if anyone has an even worse case in mind.
Only partially related to this PR, but I also wonder if there is something like this is LLVMIR, where a sequence of shufflevector instructions is turned into a tree to reduce the average mask size.
/// Populate patterns to rewrite sequences of `vector.to_elements` + | ||
/// `vector.from_elements` operations into a tree of `vector.shuffle` | ||
/// operations. | ||
void populateVectorToFromElementsToShuffleTreePatterns( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you envisage this being used more directly as a pattern, or as the pass lower-vector-to-from-elements-to-shuffle-tree
? I suppose that if it were to be used only as a pass, it would be more efficient if implemented directly as a walk over all ops (just a thought, not a request for change).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Different projects may apply it in different ways... that's mostly why I followed the same structure as we follow for other transformations...
// ===---------------------------------------------------------------------===// | ||
|
||
/// Compute the intervals for all the vectors in the shuffle tree. The interval | ||
/// of a vector is the range of positions that vector contributes to the final |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// of a vector is the range of positions that vector contributes to the final | |
/// of a vector is the range of positions that the vector contributes to the final |
// The interval of a vector at the current level is the union of the | ||
// intervals of the two vectors from the previous level being shuffled at | ||
// this level. | ||
interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this std::min(.,.) is redundant and this is always prevLhsInterval.first
Similarly, I think the std::min(.,.) isn't needed in the interval.second case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense. These were needed when I was trying other shuffling/interval schemes.
mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
Show resolved
Hide resolved
The number of shuffles is linear to the number of
More "scattered" patterns will lead to wider In terms of code size, the actual "gain" comes when we lower the
LLVM's InstCombine implements something similar but at a much higher cost. It goes through the large sequences of extractelement/insertelement instructions to reconstruct the information that InstCombine also applies some canonicalization to shufflevector instructions but I'm not aware of any major transformation happening to the actual tree shuffle sequence generated by this transformation. |
11e63c8
to
bf1c747
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Addressing some of the comments
/// Populate patterns to rewrite sequences of `vector.to_elements` + | ||
/// `vector.from_elements` operations into a tree of `vector.shuffle` | ||
/// operations. | ||
void populateVectorToFromElementsToShuffleTreePatterns( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Different projects may apply it in different ways... that's mostly why I followed the same structure as we follow for other transformations...
// The interval of a vector at the current level is the union of the | ||
// intervals of the two vectors from the previous level being shuffled at | ||
// this level. | ||
interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense. These were needed when I was trying other shuffling/interval schemes.
mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The high level logic makes a lot of sense to me, thanks for all the comments!
I've left a couple of minor suggestions, but these are minor and non-blocking. Approving as is.
Give the size and complexity, please wait for at least one more +1 before landing.
|
||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] DELETEME
func.func @unsupported_multi_dim_vector_inputs(%a: vector<2x4xf32>, %b: vector<2x4xf32>) -> vector<4xf32> { | ||
%0:8 = vector.to_elements %a : vector<2x4xf32> | ||
%1:8 = vector.to_elements %b : vector<2x4xf32> | ||
%2 = vector.from_elements %0#0, %0#7, | ||
%1#0, %1#7 : vector<4xf32> | ||
return %2 : vector<4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @unsupported_multi_dim_vector_output(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<2x2xf32> { | ||
%0:8 = vector.to_elements %a : vector<8xf32> | ||
%1:8 = vector.to_elements %b : vector<8xf32> | ||
%2 = vector.from_elements %0#0, %0#7, | ||
%1#0, %1#7 : vector<2x2xf32> | ||
return %2 : vector<2x2xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing CHECK
lines
func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { | ||
%0:8 = vector.to_elements %a : vector<8xf32> | ||
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> | ||
return %1 : vector<8xf32> | ||
} | ||
|
||
// CHECK-LABEL: func @single_input_shuffle( | ||
// CHECK-SAME: %[[A:.*]]: vector<8xf32> | ||
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32> | ||
// CHECK: return %[[L0SH0]] | ||
|
||
// ----- | ||
|
||
func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the tests that follow, I would update these names for consistency - at least in my head it feels more consistent :)
func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { | |
%0:8 = vector.to_elements %a : vector<8xf32> | |
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> | |
return %1 : vector<8xf32> | |
} | |
// CHECK-LABEL: func @single_input_shuffle( | |
// CHECK-SAME: %[[A:.*]]: vector<8xf32> | |
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32> | |
// CHECK: return %[[L0SH0]] | |
// ----- | |
func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>, | |
func.func @shuffle_no_tree_single_input(%a: vector<8xf32>) -> vector<8xf32> { | |
%0:8 = vector.to_elements %a : vector<8xf32> | |
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> | |
return %1 : vector<8xf32> | |
} | |
// CHECK-LABEL: func @single_input_shuffle( | |
// CHECK-SAME: %[[A:.*]]: vector<8xf32> | |
// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32> | |
// CHECK: return %[[L0SH0]] | |
// ----- | |
func.func @shuffle_no_tree_multiple_inputs(%a: vector<8xf32>, |
%b: vector<4xf32>, | ||
%c: vector<4xf32>, | ||
%d: vector<4xf32>) -> vector<16xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Here and in other places, the indentation is off.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, makes sense!
What about cases where the input vectors have different length? I don't see any tests for that?
/// Compute the intervals for all the vectors in the shuffle tree. The interval | ||
/// interval of a vector is the range of positions that the vector contributes | ||
/// to the final output vector. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Compute the intervals for all the vectors in the shuffle tree. The interval | |
/// interval of a vector is the range of positions that the vector contributes | |
/// to the final output vector. | |
/// Compute the intervals for all the vectors in the shuffle tree. The interval | |
/// of a vector is the range of positions that the vector contributes | |
/// to in the final output vector. |
/// algorithm generates two kinds of shuffle masks: permutation masks and | ||
/// permutation masks and propagation masks: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// algorithm generates two kinds of shuffle masks: permutation masks and | |
/// permutation masks and propagation masks: | |
/// algorithm generates two kinds of shuffle masks: permutation masks and | |
/// propagation masks: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Please ensure all the tests have CHECKS.
Just a thought, maybe for future consideration. What about, instead of creating a balanced tree, creating a tree which tries to shuffle vectors of similar size? i.e. at each step of construction, shuffle the 2 smallest vectors together.
Consider
// Concatenate vectors of 8, 4, 2, 1 and 1 elements. A total of 16 elements.
func.func @concat_of_sorts(%a : vector<8xf32>,
%b : vector<4xf32>,
%c : vector<2xf32>,
%d : vector<1xf32>,
%e : vector<1xf32>) -> vector<16xf32> {
%0:8 = vector.to_elements %a : vector<8xf32>
%1:4 = vector.to_elements %b : vector<4xf32>
%2:2 = vector.to_elements %c : vector<2xf32>
%3:1 = vector.to_elements %d : vector<1xf32>
%4:1 = vector.to_elements %e : vector<1xf32>
%out = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %3#0, %4#0 : vector<16xf32>
return %out : vector<16xf32>
}
Current tree generated is
(((a, b), (c, d)), ((d, P), P))
If it was
(a, (b, (c, (d, e))))
it'd have fewer nodes (avoids evenness constraint) and potentially the potentially the smallest mean mask size.
But maybe keeping masks small isn't high priority?
Another question actually. What is the motivation for
- Vectors at each level of the tree have the same vector length.
Is this because of the way vector.shuffle lowers to shufflevector? Because vector.shuffle does not have this constraint (but shufflevector does)
Thanks for the clear and interesting PR!
Thanks! Those are great ideas! Definitely something to consider for improvements... We could even have different strategies depending on the use case. My goal was to implement something "simple" and deterministic as a first step (all the simple vector shuffles can be :)). We can definitely iterate based on use cases. The other big TODO is implementing mask compression to reduce the mask size and poison spuriousness. I'm more worried about the latter than the former.
Adding to what I just said, it's a priority but not the main one. Vector sizes will be legalized anyways by the backend so no matter what we generate, the backend will split it and transform it anyways.
Yes, vector shuffle op with mixed sizes are currently lowered to insert/extract ops, which is kind of defeating the purpose of this transformation. We can relax some of these constraints but since this transformation should happen later in the pipeline, close to the lowering to LLVM, I wanted to reduce the number of steps needed to get "acceptable" code out of it. Keeping the vector size uniform is also a good "canonical" form. |
This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. (Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779). Example: ``` %0:4 = vector.to_elements %a : vector<4xf32> %1:4 = vector.to_elements %b : vector<4xf32> %2:4 = vector.to_elements %c : vector<4xf32> %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32> ==> %0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> %1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> ``` The algorithm leverages the structured extraction/insertion information of `vector.to_elements` and `vector.from_elements` operations and builds a set of intervals to determine the vector length that should be used at each level of the tree. There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.
bf1c747
to
e60d973
Compare
This PR adds a new transformation that turns sequences of
vector.to_elements
andvector.from_elements
into a binary tree ofvector.shuffle
operations.(Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779).
Example:
The algorithm leverages the structured extraction/insertion information of
vector.to_elements
andvector.from_elements
operations and builds a set of intervals to determine the vector length that should be used at each level of the tree to combine the level inputs in pairs.There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.