Skip to content

Commit dec8af7

Browse files
committed
[mlir] Move SelectOp from Standard to Arithmetic
This is part of splitting up the standard dialect. See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for discussion. Differential Revision: https://reviews.llvm.org/D118648
1 parent 6a8ba31 commit dec8af7

File tree

116 files changed

+1033
-1135
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+1033
-1135
lines changed

flang/lib/Optimizer/Builder/Character.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ mlir::Value genMin(fir::FirOpBuilder &builder, mlir::Location loc,
434434
mlir::Value a, mlir::Value b) {
435435
auto cmp =
436436
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, b);
437-
return builder.create<mlir::SelectOp>(loc, cmp, a, b);
437+
return builder.create<mlir::arith::SelectOp>(loc, cmp, a, b);
438438
}
439439

440440
void fir::factory::CharacterExprHelper::createAssign(
@@ -532,7 +532,8 @@ fir::CharBoxValue fir::factory::CharacterExprHelper::createSubstring(
532532
auto zero = builder.createIntegerConstant(loc, substringLen.getType(), 0);
533533
auto cdt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
534534
substringLen, zero);
535-
substringLen = builder.create<mlir::SelectOp>(loc, cdt, zero, substringLen);
535+
substringLen =
536+
builder.create<mlir::arith::SelectOp>(loc, cdt, zero, substringLen);
536537

537538
return {substringRef, substringLen};
538539
}
@@ -570,8 +571,8 @@ fir::factory::CharacterExprHelper::createLenTrim(const fir::CharBoxValue &str) {
570571
// Compute length after iteration (zero if all blanks)
571572
mlir::Value newLen =
572573
builder.create<arith::AddIOp>(loc, iterWhile.getResult(1), one);
573-
auto result =
574-
builder.create<mlir::SelectOp>(loc, iterWhile.getResult(0), zero, newLen);
574+
auto result = builder.create<mlir::arith::SelectOp>(
575+
loc, iterWhile.getResult(0), zero, newLen);
575576
return builder.createConvert(loc, builder.getCharacterLengthType(), result);
576577
}
577578

flang/lib/Optimizer/Builder/MutableBox.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,8 +675,8 @@ void fir::factory::genReallocIfNeeded(fir::FirOpBuilder &builder,
675675
// reallocate = reallocate || previous != required
676676
auto cmp = builder.create<arith::CmpIOp>(
677677
loc, arith::CmpIPredicate::ne, castPrevious, required);
678-
mustReallocate =
679-
builder.create<mlir::SelectOp>(loc, cmp, cmp, mustReallocate);
678+
mustReallocate = builder.create<mlir::arith::SelectOp>(
679+
loc, cmp, cmp, mustReallocate);
680680
};
681681
llvm::SmallVector<mlir::Value> previousLbounds;
682682
llvm::SmallVector<mlir::Value> previousExtents =

flang/lib/Optimizer/Builder/Runtime/Numeric.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ mlir::Value fir::runtime::genNearest(fir::FirOpBuilder &builder,
295295
mlir::Value False = builder.createIntegerConstant(loc, boolTy, 0);
296296
mlir::Value True = builder.createIntegerConstant(loc, boolTy, 1);
297297

298-
mlir::Value positive = builder.create<mlir::SelectOp>(loc, cmp, True, False);
298+
mlir::Value positive =
299+
builder.create<mlir::arith::SelectOp>(loc, cmp, True, False);
299300
auto args = fir::runtime::createArguments(builder, loc, funcTy, x, positive);
300301

301302
return builder.create<fir::CallOp>(loc, func, args).getResult(0);

flang/lib/Optimizer/Transforms/RewriteLoop.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
7575
auto cond = rewriter.create<mlir::arith::CmpIOp>(
7676
loc, arith::CmpIPredicate::sle, iters, zero);
7777
auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
78-
iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
78+
iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters);
7979
}
8080

8181
llvm::SmallVector<mlir::Value> loopOperands;

flang/test/Fir/loop02.fir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func private @y(%addr : !fir.ref<index>)
2323
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
2424
// CHECK: %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_6]] : index
2525
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
26-
// CHECK: %[[VAL_9:.*]] = select %[[VAL_7]], %[[VAL_8]], %[[VAL_5]] : index
26+
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_7]], %[[VAL_8]], %[[VAL_5]] : index
2727
// CHECK: br ^bb1(%[[VAL_1]], %[[VAL_9]] : index, index)
2828
// CHECK: ^bb1(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
2929
// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index

flang/unittests/Optimizer/Builder/Runtime/NumericTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ void testGenNearest(fir::FirOpBuilder &builder, mlir::Type xType,
5656
checkCallOp(nearest.getDefiningOp(), fctName, 2, /*addLocArg=*/false);
5757
auto callOp = mlir::dyn_cast<fir::CallOp>(nearest.getDefiningOp());
5858
mlir::Value select = callOp.getOperands()[1];
59-
EXPECT_TRUE(mlir::isa<mlir::SelectOp>(select.getDefiningOp()));
60-
auto selectOp = mlir::dyn_cast<mlir::SelectOp>(select.getDefiningOp());
59+
EXPECT_TRUE(mlir::isa<mlir::arith::SelectOp>(select.getDefiningOp()));
60+
auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(select.getDefiningOp());
6161
mlir::Value cmp = selectOp.getCondition();
6262
EXPECT_TRUE(mlir::isa<mlir::arith::CmpFOp>(cmp.getDefiningOp()));
6363
auto cmpOp = mlir::dyn_cast<mlir::arith::CmpFOp>(cmp.getDefiningOp());

mlir/benchmark/python/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def setup_passes(mlir_module):
2929
f"convert-scf-to-std,"
3030
f"func-bufferize,"
3131
f"arith-bufferize,"
32-
f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
32+
f"builtin.func(tensor-bufferize,finalizing-bufferize),"
3333
f"convert-vector-to-llvm"
3434
f"{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
3535
f"lower-affine,"

mlir/docs/Bufferization.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ The code, slightly simplified and annotated, is reproduced here:
8787
pm.addNestedPass<FuncOp>(createTCPBufferizePass()); // Bufferizes the downstream `tcp` dialect.
8888
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
8989
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
90-
pm.addNestedPass<FuncOp>(createStdBufferizePass());
9190
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
9291
pm.addPass(createFuncBufferizePass());
9392

mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/OpDefinition.h"
1313
#include "mlir/IR/OpImplementation.h"
1414
#include "mlir/Interfaces/CastInterfaces.h"
15+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"
1617
#include "mlir/Interfaces/VectorInterfaces.h"
1718

mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
10881088
%x = arith.cmpi "eq", %lhs, %rhs : vector<4xi64>
10891089

10901090
// Generic form of the same operation.
1091-
%x = "std.arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64}
1091+
%x = "arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64}
10921092
: (vector<4xi64>, vector<4xi64>) -> vector<4xi1>
10931093
```
10941094
}];
@@ -1161,4 +1161,55 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
11611161
let hasFolder = 1;
11621162
}
11631163

1164+
//===----------------------------------------------------------------------===//
1165+
// SelectOp
1166+
//===----------------------------------------------------------------------===//
1167+
1168+
def SelectOp : Arith_Op<"select", [
1169+
AllTypesMatch<["true_value", "false_value", "result"]>
1170+
] # ElementwiseMappable.traits> {
1171+
let summary = "select operation";
1172+
let description = [{
1173+
The `arith.select` operation chooses one value based on a binary condition
1174+
supplied as its first operand. If the value of the first operand is `1`,
1175+
the second operand is chosen, otherwise the third operand is chosen.
1176+
The second and the third operand must have the same type.
1177+
1178+
The operation applies to vectors and tensors elementwise given the _shape_
1179+
of all operands is identical. The choice is made for each element
1180+
individually based on the value at the same position as the element in the
1181+
condition operand. If an i1 is provided as the condition, the entire vector
1182+
or tensor is chosen.
1183+
1184+
Example:
1185+
1186+
```mlir
1187+
// Custom form of scalar selection.
1188+
%x = arith.select %cond, %true, %false : i32
1189+
1190+
// Generic form of the same operation.
1191+
%x = "arith.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
1192+
1193+
// Element-wise vector selection.
1194+
%vx = arith.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
1195+
1196+
// Full vector selection.
1197+
%vx = arith.select %cond, %vtrue, %vfalse : vector<42xf32>
1198+
```
1199+
}];
1200+
1201+
let arguments = (ins BoolLike:$condition,
1202+
AnyType:$true_value,
1203+
AnyType:$false_value);
1204+
let results = (outs AnyType:$result);
1205+
1206+
let hasCanonicalizer = 1;
1207+
let hasFolder = 1;
1208+
let hasVerifier = 1;
1209+
1210+
// FIXME: Switch this to use the declarative assembly format.
1211+
let printer = [{ return ::print(p, *this); }];
1212+
let parser = [{ return ::parse$cppClass(parser, result); }];
1213+
}
1214+
11641215
#endif // ARITHMETIC_OPS

0 commit comments

Comments
 (0)