[FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)

Fixes #1285

This changes `tt.reduce`, replacing the `redOp` attribute with a region
containing arbitrary combine code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
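
From the frontend, the same reduction can be written with a JIT-compiled combine function. Below is a minimal sketch of that usage, assuming the Python frontend's `tl.reduce(input, axis, combine_fn)` entry point; the kernel, the pointer names, and the contiguous BLOCK_M x BLOCK_N tile layout are illustrative assumptions, not part of this commit:

```python
import triton
import triton.language as tl


@triton.jit
def combine_add(a, b):
    # This function becomes the body of the tt.reduce region shown above.
    return a + b


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Load a contiguous BLOCK_M x BLOCK_N tile (row-major).
    x = tl.load(x_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :])
    row_sum = tl.reduce(x, 1, combine_add)  # reduce along axis 1
    tl.store(out_ptr + offs_m, row_sum)
```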
Support for index reductions at the MLIR level is also dropped in favor of
simultaneous reductions over multiple tensors, which generalizes the code
without loss of performance. For example, `argmin` is now lowered as:
```mlir
  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
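
The same multi-tensor form is reachable from the frontend by reducing a tuple of tensors with a combine function of matching arity. A minimal sketch, assuming `tl.reduce` accepts a tuple of inputs (names and the int32 output pointer are illustrative; imports as in the sketch above):

```python
@triton.jit
def combine_argmin(value1, index1, value2, index2):
    # Mirrors the region above: keep the smaller value and,
    # on ties, the smaller index.
    lt = value1 < value2
    gt = value1 > value2
    tie_index = tl.minimum(index1, index2)
    index = tl.where(lt, index1, tl.where(gt, index2, tie_index))
    value = tl.minimum(value1, value2)
    return value, index


@triton.jit
def argmin_kernel(x_ptr, out_ptr, BLOCK_N: tl.constexpr):
    offs = tl.arange(0, BLOCK_N)
    x = tl.load(x_ptr + offs)
    # Both tensors are reduced simultaneously inside one tt.reduce region.
    _, idx = tl.reduce((x, offs), 0, combine_argmin)
    tl.store(out_ptr, idx)  # out_ptr points at a single int32
```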
peterbell10 committed Apr 13, 2023
1 parent 5b91191 commit e152183
Showing 23 changed files with 826 additions and 610 deletions.
28 changes: 22 additions & 6 deletions include/triton/Analysis/Utility.h
```diff
@@ -12,13 +12,26 @@ namespace mlir {
 
 class ReduceOpHelper {
 public:
-  explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
-    srcTy = op.getOperand().getType().cast<RankedTensorType>();
+  explicit ReduceOpHelper(triton::ReduceOp rop)
+      : op(rop.getOperation()), axis(rop.getAxis()) {
+    auto firstTy = rop.getOperands()[0].getType().cast<RankedTensorType>();
+    srcShape = firstTy.getShape();
+    srcEncoding = firstTy.getEncoding();
+    srcElementTypes = rop.getElementTypes();
+
+    for (const auto &t : rop.getInputTypes()) {
+      if (t.getShape() != srcShape) {
+        rop.emitError() << "shape mismatch";
+      }
+      if (t.getEncoding() != srcEncoding) {
+        rop.emitError() << "encoding mismatch";
+      }
+    }
   }
 
-  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
+  ArrayRef<int64_t> getSrcShape() { return srcShape; }
 
-  Attribute getSrcLayout() { return srcTy.getEncoding(); }
+  Attribute getSrcLayout() { return srcEncoding; }
 
   bool isFastReduction();
 
@@ -37,8 +50,11 @@ class ReduceOpHelper {
   bool isSupportedLayout();
 
 private:
-  triton::ReduceOp op;
-  RankedTensorType srcTy{};
+  Operation *op;
+  ArrayRef<int64_t> srcShape;
+  Attribute srcEncoding;
+  SmallVector<Type> srcElementTypes;
+  int axis;
 };
 
 bool isSharedEncoding(Value value);
```
24 changes: 0 additions & 24 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
```diff
@@ -34,30 +34,6 @@ def TT_PaddingOptionAttr : I32EnumAttr<
   let cppNamespace = "::mlir::triton";
 }
 
-// reduction
-def TT_RedOpAttr : I32EnumAttr<
-    /*name*/"RedOp", /*summary*/"",
-    /*case*/
-    [
-      I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
-      I32EnumAttrCase<"FADD", 2, "fadd">,
-      I32EnumAttrCase<"MIN", 3, "min">,
-      I32EnumAttrCase<"MAX", 4, "max">,
-      I32EnumAttrCase<"UMIN", 5, "umin">,
-      I32EnumAttrCase<"UMAX", 6, "umax">,
-      I32EnumAttrCase<"ARGMIN", 7, "argmin">,
-      I32EnumAttrCase<"ARGMAX", 8, "argmax">,
-      I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
-      I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
-      I32EnumAttrCase<"FMIN", 11, "fmin">,
-      I32EnumAttrCase<"FMAX", 12, "fmax">,
-      I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
-      I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
-      I32EnumAttrCase<"XOR", 15, "xor">
-    ]> {
-  let cppNamespace = "::mlir::triton";
-}
-
 // atomic
 def TT_AtomicRMWAttr : I32EnumAttr<
     "RMWOp", "",
```
38 changes: 23 additions & 15 deletions include/triton/Dialect/Triton/IR/TritonOps.td
```diff
@@ -388,27 +388,35 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 //
 // Reduce Op
 //
-def TT_ReduceOp : TT_Op<"reduce", [Pure,
-                                   DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "reduce";
-
-  let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
-
-  let results = (outs TT_Type:$result);
-
+def TT_ReduceOp: TT_Op<"reduce",
+                       [Pure,
+                        SameOperandsEncoding,
+                        SingleBlock,
+                        DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Reduction using generic combination algorithm";
+  let arguments = (ins Variadic<TT_Tensor>:$operands, I32Attr:$axis);
+  let results = (outs Variadic<TT_Type>:$result);
+  let regions = (region SizedRegion<1>:$combineOp);
   let builders = [
-    OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
+    OpBuilder<(ins "ValueRange":$operands, "int":$axis)>,
   ];
-
-  let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
-
-  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
   let extraClassDeclaration = [{
-    // This member function is marked static because we need to call it before the ReduceOp
-    // is constructed, see the implementation of create_reduce in triton.cc.
-    static bool withIndex(mlir::triton::RedOp redOp);
+    llvm::SmallVector<RankedTensorType> getInputTypes();
+    llvm::SmallVector<Type> getElementTypes();
+    unsigned getNumOperands();
   }];
 }
 
+def TT_ReduceReturnOp: TT_Op<"reduce.return",
+                             [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> {
+  let summary = "terminator for reduce operator";
+  let arguments = (ins Variadic<AnyType>:$result);
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
 
 //
 // External Elementwise op
 //
```
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
```diff
@@ -103,7 +103,7 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
                           TT_Tensor:$true_value,
                           TT_Tensor:$false_value);
 
-  let results = (outs TT_Tensor:$result);
+  let results = (outs TT_Type:$result);
 }
 
```
5 changes: 4 additions & 1 deletion lib/Analysis/Membar.cpp
```diff
@@ -74,7 +74,10 @@ void MembarAnalysis::visitTerminator(Operation *op,
     return;
   }
   // Otherwise, it could be a return op
-  assert(isa<triton::ReturnOp>(op) && "Unknown terminator");
+  if (isa<triton::ReduceReturnOp>(op) || isa<triton::ReturnOp>(op)) {
+    return;
+  }
+  llvm_unreachable("Unknown terminator encountered in membar analysis");
 }
 
 void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
```
37 changes: 12 additions & 25 deletions lib/Analysis/Utility.cpp
```diff
@@ -10,49 +10,38 @@
 namespace mlir {
 
 bool ReduceOpHelper::isFastReduction() {
-  auto srcLayout = srcTy.getEncoding();
-  auto axis = op.getAxis();
-  return axis == triton::gpu::getOrder(srcLayout)[0];
+  return axis == triton::gpu::getOrder(getSrcLayout())[0];
 }
 
 unsigned ReduceOpHelper::getInterWarpSize() {
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
-  auto axis = op.getAxis();
   auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
   unsigned sizeIntraWarps = getIntraWarpSize();
   return std::min(srcReduceDimSize / sizeIntraWarps,
-                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
+                  triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
 }
 
 unsigned ReduceOpHelper::getIntraWarpSize() {
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
-  auto axis = op.getAxis();
   auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
   return std::min(srcReduceDimSize,
-                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
+                  triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
 }
 
 unsigned ReduceOpHelper::getThreadsReductionAxis() {
-  auto srcLayout = srcTy.getEncoding();
-  auto axis = op.getAxis();
+  auto srcLayout = getSrcLayout();
   return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
          triton::gpu::getWarpsPerCTA(srcLayout)[axis];
 }
 
 SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
-  auto axis = op.getAxis();
   auto smemShape = convertType<unsigned>(getSrcShape());
   smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
   return smemShape;
 }
 
 SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
-  auto axis = op.getAxis();
   SmallVector<SmallVector<unsigned>> smemShapes(3);
 
-  auto argLayout = srcTy.getEncoding();
+  auto argLayout = getSrcLayout();
   auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
   if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
       triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
@@ -64,7 +53,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
 
   /// FIXME(Qingyi): This size is actually larger than required.
   /// shared memory block1:
-  auto mod = op.getOperation()->getParentOfType<ModuleOp>();
+  auto mod = op->getParentOfType<ModuleOp>();
   unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
   smemShapes[1].push_back(numWarps * 32);
 
@@ -82,17 +71,15 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
     elems = product<unsigned>(smemShape);
   }
 
-  auto tensorType = op.getOperand().getType().cast<RankedTensorType>();
-  unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
-
-  if (triton::ReduceOp::withIndex(op.getRedOp()))
-    bytes += elems * sizeof(int32_t);
-
-  return bytes;
+  unsigned bytesPerElem = 0;
+  for (const auto &ty : srcElementTypes) {
+    bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
+  }
+  return bytesPerElem * elems;
 }
 
 bool ReduceOpHelper::isSupportedLayout() {
-  auto srcLayout = srcTy.getEncoding();
+  auto srcLayout = getSrcLayout();
   if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
     return true;
   }
```
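
For intuition about the new scratch-size rule in the last hunk: scratch memory now holds one slot per reduced tensor for every element, so the byte count is the element count times the summed element widths. A small Python sketch of that arithmetic (hypothetical helper, not part of the commit):

```python
def scratch_size_in_bytes(smem_shape, element_bitwidths):
    # Mirrors ReduceOpHelper::getScratchSizeInBytes after this change.
    elems = 1
    for dim in smem_shape:
        elems *= dim
    bytes_per_elem = sum(bits // 8 for bits in element_bitwidths)
    return elems * bytes_per_elem


# argmin over (f32 values, i32 indices) with a 1x256 scratch block:
assert scratch_size_in_bytes([1, 256], [32, 32]) == 2048
```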
14 changes: 8 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
```diff
@@ -1073,12 +1073,14 @@ void populateElementwiseOpToLLVMPatterns(
   POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
   POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
   POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
-  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)   // &
-  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)     // |
-  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)   // ^
-  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)   // <<
-  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
-  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
+  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)    // &
+  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)      // |
+  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)    // ^
+  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)    // <<
+  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp)  // >>
+  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp)  // >>
+  POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
+  POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp)  // smin
 #undef POPULATE_BINARY_OP
 
 #define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
```