Skip to content

Commit

Permalink
[XLA:GPU][MLIR-Based emitters] Fix canonicalization patterns for Appl…
Browse files Browse the repository at this point in the history
…yIndexingOp.

PiperOrigin-RevId: 631526519
  • Loading branch information
pifon2a authored and tensorflower-gardener committed May 7, 2024
1 parent 49d3c9f commit 2100915
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 39 deletions.
96 changes: 67 additions & 29 deletions third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h"

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

Expand All @@ -35,6 +36,7 @@ limitations under the License.
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project // IWYU pragma: keep
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
Expand Down Expand Up @@ -160,6 +162,16 @@ void AllocateSharedOp::getAsmResultNames(
// ApplyIndexingOp
//===----------------------------------------------------------------------===//

// Builder overload that takes the dimension and symbol operands as two
// separate ranges. The values are gathered into one operand list, with all
// dimensions first and all symbols after them, and then handed to the builder
// overload that accepts a single operand range together with the indexing map.
void ApplyIndexingOp::build(OpBuilder &builder, OperationState &result,
                            ValueRange dims, ValueRange symbols,
                            const IndexingMap &indexing_map) {
  SmallVector<Value, 4> merged;
  // Reserve the final size once so that the pushes below never reallocate.
  merged.reserve(dims.size() + symbols.size());
  for (Value dim : dims) {
    merged.push_back(dim);
  }
  for (Value symbol : symbols) {
    merged.push_back(symbol);
  }
  build(builder, result, merged, indexing_map);
}

void ApplyIndexingOp::build(OpBuilder &builder, OperationState &result,
ValueRange operands,
const IndexingMap &indexing_map) {
Expand Down Expand Up @@ -306,10 +318,6 @@ LogicalResult ApplyIndexingOp::verify() {
"operand, lower_bounds, upper_bounds count and affine map dimension "
"and symbol count must match");
}
IndexingMap indexing_map = getIndexingMap();
if (indexing_map.IsKnownEmpty()) {
return emitOpError("indexing map is empty");
}
return success();
}

Expand Down Expand Up @@ -384,41 +392,66 @@ struct FoldApplyIndexingOperands
MLIRContext *ctx = affine_map.getContext();
unsigned num_operands = indexing_op->getNumOperands();
unsigned num_dims = affine_map.getNumDims();
llvm::SmallBitVector constant_operands(num_operands, false);
mlir::DenseMap<AffineExpr, AffineExpr> replacements;
unsigned num_nonconstant_operands = num_operands;
unsigned num_symbols = affine_map.getNumSymbols();

SmallVector<std::optional<int64_t>> constant_values(num_operands,
std::nullopt);
bool constant_found = false;
SmallVector<int64_t> dim_id_map(num_dims, -1);
SmallVector<int64_t> symbol_id_map(num_symbols, -1);
for (auto &operand : indexing_op->getOpOperands()) {
unsigned operand_number = operand.getOperandNumber();
auto constant = operand.get().getDefiningOp<arith::ConstantIndexOp>();
if (!constant) continue;

unsigned operand_number = operand.getOperandNumber();
replacements[operand_number < num_dims
? getAffineDimExpr(operand_number, ctx)
: getAffineSymbolExpr(operand_number - num_dims, ctx)] =
getAffineConstantExpr(constant.value(), ctx);
constant_operands.set(operand_number);
--num_nonconstant_operands;
constant_values[operand_number] = constant.value();
constant_found = true;
}
if (replacements.empty()) {
if (!constant_found) {
return rewriter.notifyMatchFailure(indexing_op,
"No constant operands found");
}
unsigned new_num_dims = 0;
unsigned new_num_symbols = 0;
SmallVector<AffineExpr, 2> dim_replacements, symbol_replacements;
dim_replacements.reserve(num_dims);
symbol_replacements.reserve(num_symbols);

unsigned new_num_operands = new_num_dims + new_num_symbols;
SmallVector<Value, 4> new_operands;
new_operands.reserve(num_nonconstant_operands);
ArrayRef<int64_t> lbs = indexing_op.getLowerBounds();
ArrayRef<int64_t> ubs = indexing_op.getUpperBounds();
new_operands.reserve(new_num_operands);
SmallVector<int64_t, 4> new_lbs, new_ubs;
new_lbs.reserve(num_nonconstant_operands);
new_ubs.reserve(num_nonconstant_operands);
for (auto [index, operand] : llvm::enumerate(indexing_op.getOperands())) {
if (constant_operands[index]) continue;
new_operands.push_back(operand);
new_lbs.push_back(lbs[index]);
new_ubs.push_back(ubs[index]);
new_lbs.reserve(new_num_operands);
new_ubs.reserve(new_num_operands);

for (auto [operand, constant_value, lb, ub] : llvm::zip(
indexing_op->getOpOperands(), constant_values,
indexing_op.getLowerBounds(), indexing_op.getUpperBounds())) {
unsigned operand_id = operand.getOperandNumber();
if (constant_value.has_value()) {
if (operand_id < num_dims) {
dim_replacements.push_back(
getAffineConstantExpr(*constant_value, ctx));
} else {
symbol_replacements.push_back(
getAffineConstantExpr(*constant_value, ctx));
}
} else {
if (operand_id < num_dims) {
dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx));
} else {
symbol_replacements.push_back(
getAffineSymbolExpr(new_num_symbols++, ctx));
}
new_operands.push_back(operand.get());
new_lbs.push_back(lb);
new_ubs.push_back(ub);
}
}
rewriter.replaceOpWithNewOp<ApplyIndexingOp>(
indexing_op, new_operands, affine_map.replace(replacements), new_lbs,
new_ubs);
indexing_op, new_operands,
affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements,
new_num_dims, new_num_symbols),
new_lbs, new_ubs);
return success();
}
};
Expand All @@ -432,15 +465,20 @@ struct FoldApplyIndexingResults
PatternRewriter &rewriter) const override {
mlir::Location loc = indexing_op.getLoc();
IndexingMap indexing_map = indexing_op.getIndexingMap();
if (indexing_map.IsKnownEmpty()) {
return rewriter.notifyMatchFailure(indexing_op,
"Domain of the indexing map is empty");
}
AffineMap *affine_map = &indexing_map.GetMutableAffineMap();
unsigned num_results = affine_map->getNumResults();
SmallVector<AffineExpr, 4> new_exprs;
new_exprs.reserve(num_results);
SmallVector<Value, 4> new_values;
new_values.reserve(num_results);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
for (mlir::OpResult opresult : indexing_op->getOpResults()) {
if (opresult.use_empty()) {
new_values.push_back(opresult);
new_values.push_back(zero);
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> {
let results = (outs Variadic<Index>);

let builders = [
OpBuilder<(ins "mlir::ValueRange":$dims, "mlir::ValueRange":$symbols,
"const IndexingMap&":$indexing_map)>,
OpBuilder<(ins "mlir::ValueRange":$operands,
"const IndexingMap&":$indexing_map)>,
OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,17 @@ func.func @fold_operands(%d0: index) -> index {

// CHECK-LABEL: func.func @fold_operands
// CHECK-SAME: %[[ARG_0:.*]]: index)
// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10])
// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10])

// -----

// Canonicalization test for an apply_indexing op with two results: the first
// result of the affine map is the constant 0 and the second is the dimension
// d1 on its own. The FileCheck directives that follow this function expect
// that no apply_indexing op remains and that the function returns a constant
// 0 together with %arg1 directly.
func.func @fold_operands_and_results(%arg0: index, %arg1: index)
-> (index, index) {
%0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)>
(%arg0 in [0, 4], %arg1 in [0, 5])
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func.func @fold_operands_and_results
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index)
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,4 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index)
// expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}}
%0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2])
func.return %0#0, %0#1 : index, index
}

// -----

#map0 = affine_map<(d0) -> (d0)>
// Negative test: %d0 is given the bounds [100, 0], i.e. a lower bound that is
// greater than the upper bound. The diagnostic annotation inside the body
// states that the op on the following line is reported with the message
// "indexing map is empty". Only %d0 is used; %d1 and %s0 are unused arguments.
func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> index {
// expected-error @+1 {{indexing map is empty}}
%0 = xla_gpu.apply_indexing #map0 (%d0 in [100, 0])
func.return %0 : index
}

0 comments on commit 2100915

Please sign in to comment.