[StableHLO] Port gather to linalg lowering pattern (iree-org#13779)
This pattern fell through the cracks during the initial porting of
hlo-to-linalg lowering in iree-org#12957.

With this pattern and the most recent canonicalization patterns, we produce
the same code as the mhlo input conversion pipeline on the input from
iree-org#13729.

Also fixed issues with undefined FileCheck variables in tests.

Issue: iree-org#12678
kuhar authored and nhasabni committed Aug 24, 2023
1 parent 5723b11 commit cfc60ab
Showing 7 changed files with 756 additions and 0 deletions.
@@ -1763,6 +1763,210 @@ class MapOpToMapConverter : public OpConversionPattern<mlir::stablehlo::MapOp> {
}
};

/// This lowering encompasses the full range of the Gather operation and
/// therefore is very general: it just loops over the output and calculates the
/// corresponding input index. It follows the explanation at
/// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler
/// should be able to optimize that a bit, but in order to get efficient
/// lowerings, special cases of gather should be extracted into separate
/// lowerings, ideally encapsulated as separate ops or canonicalization
/// patterns.
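///
/// As a purely illustrative example (values chosen for exposition, not taken
/// from the tests): with an operand of shape [3, 4], start_indices of shape
/// [2, 1], index_vector_dim = 1, offset_dims = [1],
/// collapsed_slice_dims = [0], start_index_map = [0], and
/// slice_sizes = [1, 4], the result has shape [2, 4], and the element at
/// result index (b, o) is read from
/// operand[clamp(start_indices[b, 0], 0, 2), o].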
struct GatherConversion final : OpConversionPattern<mlir::stablehlo::GatherOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::stablehlo::GatherOp gatherOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
Location loc = gatherOp.getLoc();

Value startIndices = adaptor.getStartIndices();
Value operand = adaptor.getOperand();

auto resultType =
getTypeConverter()->convertType<RankedTensorType>(gatherOp.getType());
RankedTensorType startIndicesType =
dyn_cast<RankedTensorType>(startIndices.getType());
// We could actually deal with an unranked result by inferring the result
// rank, but the current reifyReturnTypes doesn't support unranked either.
if (!resultType || !startIndicesType) {
return rewriter.notifyMatchFailure(gatherOp,
"unranked start indices or result");
}

int64_t resultRank = resultType.getRank();
// slice_sizes has to have as many elements as the operand rank, and computing
// the rank this way also permits an unranked operand.
int64_t operandRank = gatherOp.getSliceSizes().getNumElements();

int64_t indexVectorDim = gatherOp.getDimensionNumbers().getIndexVectorDim();

ArrayRef<int64_t> offsetDims =
gatherOp.getDimensionNumbers().getOffsetDims();
ArrayRef<int64_t> collapsedSliceDims =
gatherOp.getDimensionNumbers().getCollapsedSliceDims();
ArrayRef<int64_t> startIndexMap =
gatherOp.getDimensionNumbers().getStartIndexMap();

// We'll need these later; creating them on demand would produce duplicates,
// which also makes lit tests really hard to write.
SmallVector<Value> constants;
for (int64_t i = 0, e = std::max({resultRank, operandRank, int64_t{2}});
i < e; ++i) {
constants.push_back(
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));
}

Value emptyOp = getEmptyTensorFor(rewriter, loc, resultType, gatherOp,
adaptor.getOperands());

ValueRange ins;
SmallVector<AffineMap, 1> indexingMaps(
{rewriter.getMultiDimIdentityMap(resultRank)});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/resultType,
/*inputs=*/ins,
/*outputs=*/emptyOp, indexingMaps, getNParallelLoopsAttrs(resultRank),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(gatherOp));

// Now populate the linalg generic region
Region& region = linalgOp.getRegion();
Block* block = rewriter.createBlock(&region, region.end());
block->addArguments(resultType.getElementType(), loc);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);

// Dimensions in the result that aren't offset dimensions are called batch
// dimensions.
SmallVector<int64_t> batchDims;
for (int64_t dim = 0; dim < resultRank; ++dim) {
if (!llvm::is_contained(offsetDims, dim)) {
batchDims.push_back(dim);
}
}

// Same as with the constants. Creating these all up front is easier than
// potentially getting duplicates later.
SmallVector<Value> linalgIndices;
for (int64_t i = 0; i < resultRank; ++i) {
linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
}

// Now the complicated part. For a given output dimension we build up an
// index into the input. It's composed of two parts: the index coming from
// start_indices, and the offset from that index along the offset
// dimensions. Both parts involve dimension shuffling and remapping, because
// gather is defined with extra attributes that allow any-layout input.
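// In rough pseudocode (a sketch only, not the exact IR built below), for
// every operand dimension d the final index is:
//   index[d] = clamp(startContribution[d], 0, operandSize[d] - sliceSize[d])
//              + offsetContribution[d]
// where startContribution comes from start_indices remapped through
// start_index_map (zero for unmapped dimensions), and offsetContribution is
// the output index along the offset dimensions (zero elsewhere).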

// The base gather index (`G` in the documentation) points to a place in
// start_indices along the batch dimensions.
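// E.g. (illustrative): with batchDims = [0, 2], the gather index for the
// current output element is (out[0], out[2]).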
SmallVector<Value> gatherIndex;
for (int64_t dim : batchDims) {
gatherIndex.push_back(linalgIndices[dim]);
}

SmallVector<Value> indexFromStartIndices;
for (size_t i = 0, e = startIndexMap.size(); i != e; ++i) {
// The index along the index_vector dimension of start_indices varies.
// Basically, indexFromStartIndices indexes into a "row" along
// index_vector_dim, where the row is selected by the current output
// index.
// But if index_vector_dim is equal to start_indices.rank, then
// start_indices gets a trailing 1 dimension added. In that case the row
// we're extracting always has length 1 and the index into it is always 0,
// so we just use the gather index directly.
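// E.g. (illustrative values): for start_indices of shape [4, 2] with
// index_vector_dim = 1, component i of the start index for batch index b is
// start_indices[b, i]; for start_indices of shape [4] with
// index_vector_dim = 1 (== rank), it is simply start_indices[b].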
SmallVector<Value> gCombine(gatherIndex);
if (indexVectorDim != startIndicesType.getRank()) {
assert(indexVectorDim <= static_cast<int64_t>(gCombine.size()));
gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
}

indexFromStartIndices.push_back(extractIndexFromTensor(
rewriter, loc, startIndices, gatherOp.getStartIndices().getType(),
gCombine));
}

// But then start indices are shuffled by the start index map. To make a
// full index into the operand, all missing indices are zeroes.
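// E.g. (illustrative): with operandRank = 3 and start_index_map = [2, 0],
// the remapped index is (S[1], 0, S[0]), where S = indexFromStartIndices.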
SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
for (auto [idx, value] : llvm::enumerate(startIndexMap)) {
remappedIndexFromIndices[value] = indexFromStartIndices[idx];
}

// Now we construct the index based on the offset. First we need to remap
// the offset dimensions by dropping the collapsed indices.
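// E.g. (illustrative): with operandRank = 3 and collapsed_slice_dims = [1],
// remappedOffsetDims = [0, 2], i.e. the first offset dimension of the output
// walks operand dimension 0 and the second walks operand dimension 2.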
SmallVector<unsigned> remappedOffsetDims;
for (int64_t i = 0; i < operandRank; ++i) {
if (!llvm::is_contained(collapsedSliceDims, i)) {
remappedOffsetDims.push_back(static_cast<unsigned>(i));
}
}

assert(remappedOffsetDims.size() == offsetDims.size());

// Clamp out of bounds indices.
for (int i = 0, operandIndexDim = 0; i < operandRank; ++i) {
// Compute the size of the output shape dimension corresponding to this
// index dimension. If it's collapsed, set it to 1.
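// (For a non-collapsed operand dimension this output size equals the slice
// size along that dimension, so the clamp below keeps the whole slice in
// bounds; collapsed dimensions have a slice size of 1.)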
Value outputDimSize = constants[1];
if (!llvm::is_contained(collapsedSliceDims, i)) {
outputDimSize = rewriter.createOrFold<tensor::DimOp>(
loc, emptyOp, offsetDims[operandIndexDim++]);
}

// If this is a skipped dimension, we're done and don't have to clamp.
if (remappedIndexFromIndices[i] == constants[0]) continue;

Value operandDimSize =
rewriter.createOrFold<tensor::DimOp>(loc, operand, i);
Value largestValidIndex = rewriter.createOrFold<arith::SubIOp>(
loc, operandDimSize, outputDimSize);

// Clamp the start index to [0, operand_dim - output_dim].
Value clamp = rewriter.create<arith::MinSIOp>(
loc,
rewriter.create<arith::MaxSIOp>(loc, constants[0],
remappedIndexFromIndices[i]),
largestValidIndex);
remappedIndexFromIndices[i] = clamp;
}

// For the (remapped) offset dimensions, the index is the current index in
// the output. As before this is expanded to a full index into the operand
// by using zeros for the missing indices.
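// E.g. (illustrative): with remappedOffsetDims = [1, 2],
// offset_dims = [1, 3], and an operand of rank 3, indexFromOffset is
// (0, out[1], out[3]).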
SmallVector<Value> indexFromOffset(operandRank, constants[0]);
for (auto [remappedOffsetDim, offsetDim] :
llvm::zip_equal(remappedOffsetDims, offsetDims)) {
indexFromOffset[remappedOffsetDim] = linalgIndices[offsetDim];
}

// Now we add together our two indices to get the final index into the
// operand.
SmallVector<Value> combinedIndex;
for (int64_t i = 0; i < operandRank; ++i)
combinedIndex.push_back(rewriter.createOrFold<arith::AddIOp>(
loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
indexFromOffset[i]));

Value extractOperand;
if (isa<RankedTensorType>(operand.getType())) {
extractOperand = operand;
} else {
// Cannot extract from unranked tensors; cast to ranked first.
SmallVector<int64_t> dims(operandRank, ShapedType::kDynamic);
auto type = RankedTensorType::get(
dims, cast<TensorType>(operand.getType()).getElementType());
extractOperand = rewriter.create<tensor::CastOp>(loc, type, operand);
}
Value element =
rewriter.create<tensor::ExtractOp>(loc, extractOperand, combinedIndex);
rewriter.create<linalg::YieldOp>(loc, element);

rewriter.replaceOp(gatherOp, linalgOp.getResults());

return success();
}
};

/// Converts xla-hlo.select_and_scatter op to a sequence of linalg.generic ops.
/// The current version computes the scattered index and populates the correct
/// value for each tile. It does not currently handle overlapping tiles.
@@ -2451,6 +2655,7 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context,
ConcatenateConverter,
ConstConverterTensor,
EinsumToLinalgConverter,
GatherConversion,
RealDynamicSliceConverter,
ReshapeOpConverter,
ReverseConverter,
@@ -25,6 +25,7 @@ iree_lit_test_suite(
"stablehlo_to_linalg_convolution.mlir",
"stablehlo_to_linalg_dot_prod.mlir",
"stablehlo_to_linalg_ext.mlir",
"stablehlo_to_linalg_gather.mlir",
"stablehlo_to_linalg_pointwise.mlir",
"stablehlo_to_linalg_random.mlir",
"stablehlo_to_linalg_reduce.mlir",
@@ -22,6 +22,7 @@ iree_lit_test_suite(
"stablehlo_to_linalg_convolution.mlir"
"stablehlo_to_linalg_dot_prod.mlir"
"stablehlo_to_linalg_ext.mlir"
"stablehlo_to_linalg_gather.mlir"
"stablehlo_to_linalg_pointwise.mlir"
"stablehlo_to_linalg_random.mlir"
"stablehlo_to_linalg_reduce.mlir"