Skip to content

Commit

Permalink
[StableHLO] Add canonicalization pattern for select (iree-org#13768)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar authored and nhasabni committed Aug 24, 2023
1 parent da1cf1f commit 81aea5c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -383,6 +384,60 @@ struct CompareOpCanon final : OpRewritePattern<mlir::stablehlo::CompareOp> {
}
};

/// Canonicalizes `stablehlo.select`:
///   * `select(p, x, x)`            -> `x` (predicate is irrelevant)
///   * `select(splat-const, x, y)`  -> `x` or `y`
///   * all-constant operands        -> a single (likely non-splat) constant,
///     bounded by `kFoldOpEltLimit` to avoid materializing huge attributes.
struct SelectOpCanon final : OpRewritePattern<mlir::stablehlo::SelectOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy) return failure();

    Value onTrue = op.getOnTrue();
    Value onFalse = op.getOnFalse();

    // Both branches yield the same value; the predicate cannot matter.
    if (onTrue == onFalse) {
      rewriter.replaceOp(op, onTrue);
      return success();
    }

    // Everything below requires a constant predicate.
    DenseElementsAttr predAttr;
    if (!matchPattern(op.getPred(), m_Constant(&predAttr))) {
      return failure();
    }

    // Uniform predicate: forward the selected branch directly.
    if (predAttr.isSplat()) {
      Value chosen = predAttr.getSplatValue<bool>() ? onTrue : onFalse;
      rewriter.replaceOp(op, chosen);
      return success();
    }

    // Mixed predicate: fold elementwise, but only when both branches are
    // constants and the element count stays under the folding limit, since
    // this materializes a fresh non-splat constant.
    if (predAttr.getNumElements() > kFoldOpEltLimit) return failure();

    DenseElementsAttr onTrueAttr;
    if (!matchPattern(onTrue, m_Constant(&onTrueAttr))) return failure();

    DenseElementsAttr onFalseAttr;
    if (!matchPattern(onFalse, m_Constant(&onFalseAttr))) return failure();

    SmallVector<Attribute> folded;
    folded.reserve(predAttr.getNumElements());
    for (auto [p, t, f] :
         llvm::zip_equal(predAttr.getValues<bool>(),
                         onTrueAttr.getValues<Attribute>(),
                         onFalseAttr.getValues<Attribute>())) {
      folded.push_back(p ? t : f);
    }

    rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
        op, DenseElementsAttr::get(resultTy, folded));
    return success();
  }
};

struct BroadcastInDimOpCanon final
: OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -648,7 +703,7 @@ void populateCanonicalizationPatterns(MLIRContext *context,
PatternBenefit benefit) {
patterns->add<
// Arithmetic ops.
AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon,
AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon,
// Complex ops.
RealOpCanon, ImagOpCanon,
// Query ops.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,39 @@ func.func @compare_folds()

// -----

// Exercises the select canonicalization pattern:
//   %0       : identical branches fold to the branch value (predicate ignored).
//   %1 - %4  : splat constant predicates (scalar and per-element, true and
//              false) select exactly one branch.
//   %5       : non-constant predicate — the select must be left in place.
//   %6       : non-splat constant predicate with constant branches folds
//              elementwise into a new constant.
// CHECK-LABEL: func.func @select
// CHECK-SAME: ([[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARGC:%.+]]: tensor<2xi1>)
func.func @select(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %argC: tensor<2xi1>)
  -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<4xi32>) {
  // Scalar splat predicates.
  %c0 = stablehlo.constant dense<false> : tensor<i1>
  %c1 = stablehlo.constant dense<true> : tensor<i1>

  // Per-element splat predicates.
  %c0x2 = stablehlo.constant dense<false> : tensor<2xi1>
  %c1x2 = stablehlo.constant dense<true> : tensor<2xi1>

  // Non-splat predicate plus constant branches for the elementwise fold.
  %cond = stablehlo.constant dense<[false, true, false, true]> : tensor<4xi1>
  %foo = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
  %bar = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi32>

  %0 = stablehlo.select %argC, %arg0, %arg0 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %1 = stablehlo.select %c0, %arg0, %arg1 : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %2 = stablehlo.select %c1, %arg0, %arg1 : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %3 = stablehlo.select %c0x2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %4 = stablehlo.select %c1x2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %5 = stablehlo.select %argC, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>

  %6 = stablehlo.select %cond, %foo, %bar : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

  // Only the non-constant-predicate select (%5) survives; %6 becomes the
  // elementwise-folded constant [5, 2, 7, 4].
  // CHECK-DAG:  [[R0:%.+]] = stablehlo.select [[ARGC]], [[ARG0]], [[ARG1]]
  // CHECK-DAG:  [[C0:%.+]] = stablehlo.constant dense<[5, 2, 7, 4]> : tensor<4xi32>

  // CHECK-NEXT: return [[ARG0]], [[ARG1]], [[ARG0]], [[ARG1]], [[ARG0]], [[R0]], [[C0]]
  return %0, %1, %2, %3, %4, %5, %6 :
    tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<4xi32>
}

// -----

// CHECK-LABEL: func.func @broadcast_in_dim
// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)
Expand Down

0 comments on commit 81aea5c

Please sign in to comment.