Skip to content

Commit

Permalink
Support simultaneous reduction of multiple tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Mar 13, 2023
1 parent 710a4bd commit eaea0a3
Show file tree
Hide file tree
Showing 14 changed files with 459 additions and 227 deletions.
49 changes: 30 additions & 19 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,39 @@
namespace mlir {

class ReduceOpHelper {
ReduceOpHelper(Operation *op, int axis, bool withIndex)
: op(op), axis(axis), withIndex(withIndex) {
srcTy = op->getOperands().front().getType().cast<RankedTensorType>();
}

public:
explicit ReduceOpHelper(triton::ReduceOp op):
ReduceOpHelper(
op.getOperation(),
op.getAxis(),
triton::ReduceOp::withIndex(op.getRedOp())) {
explicit ReduceOpHelper(triton::ReduceOp rop):
op(rop.getOperation()), axis(rop.getAxis()) {
auto srcTy = rop.getOperand().getType().cast<RankedTensorType>();
srcShape = srcTy.getShape();
srcEncoding = srcTy.getEncoding();
srcElementTypes.push_back(srcTy.getElementType());

if (triton::ReduceOp::withIndex(rop.getRedOp())) {
srcElementTypes.push_back(Builder(op).getI32Type());
}
}

explicit ReduceOpHelper(triton::GenericReduceOp op):
ReduceOpHelper(
op.getOperation(),
op.getAxis(),
/*withIndex*/false) {
explicit ReduceOpHelper(triton::GenericReduceOp rop):
op(rop.getOperation()), axis(rop.getAxis()) {
auto firstTy = rop.getOperands()[0].getType().cast<RankedTensorType>();
srcShape = firstTy.getShape();
srcEncoding = firstTy.getEncoding();
srcElementTypes = rop.getElementTypes();

for (const auto &t : rop.getInputTypes()) {
if (t.getShape() != srcShape) {
rop.emitError() << "shape mismatch";
}
if (t.getEncoding() != srcEncoding) {
rop.emitError() << "encoding mismatch";
}
}
}

ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
ArrayRef<int64_t> getSrcShape() { return srcShape; }

Attribute getSrcLayout() { return srcTy.getEncoding(); }
Attribute getSrcLayout() { return srcEncoding; }

bool isFastReduction();

Expand All @@ -51,9 +61,10 @@ class ReduceOpHelper {

private:
Operation *op;
RankedTensorType srcTy{};
ArrayRef<int64_t> srcShape;
Attribute srcEncoding;
SmallVector<Type> srcElementTypes;
int axis;
bool withIndex;
};

bool isSharedEncoding(Value value);
Expand Down
17 changes: 13 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,25 @@ def TT_ReduceOp : TT_Op<"reduce", [Pure,
def TT_GenericReduceOp: TT_Op<"generic_reduce",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>, SingleBlock]> {
let summary = "Reduction using generic combination algorithm";
let arguments = (ins TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Type:$result);
let regions = (region SizedRegion<1>:$region);
let arguments = (ins Variadic<TT_Tensor>:$operands, I32Attr:$axis);
let results = (outs Variadic<TT_Type>:$result);
let regions = (region SizedRegion<1>:$combineOp);
let builders = [
OpBuilder<(ins "ValueRange":$operands, "int":$axis)>,
];
let hasVerifier = 1;
let hasRegionVerifier = 1;
let extraClassDeclaration = [{
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
unsigned getNumOperands();
}];
}

def TT_GenericReduceReturnOp: TT_Op<"generic_reduce.return",
[HasParent<"GenericReduceOp">, Pure, Terminator, ReturnLike]> {
let summary = "terminator for reduce operator";
let arguments = (ins AnyType:$result);
let arguments = (ins Variadic<AnyType>:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}

Expand Down
10 changes: 5 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";

Expand All @@ -78,7 +78,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
}

def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";

Expand All @@ -100,10 +100,10 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
let description = [{}];

let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);
TT_Type:$true_value,
TT_Type:$false_value);

let results = (outs TT_Tensor:$result);
let results = (outs TT_Type:$result);
}


Expand Down
27 changes: 10 additions & 17 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,24 @@
namespace mlir {

bool ReduceOpHelper::isFastReduction() {
auto srcLayout = srcTy.getEncoding();
return axis == triton::gpu::getOrder(srcLayout)[0];
return axis == triton::gpu::getOrder(getSrcLayout())[0];
}

unsigned ReduceOpHelper::getInterWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = srcTy.getEncoding();
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}
Expand All @@ -46,7 +41,7 @@ SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
SmallVector<SmallVector<unsigned>> smemShapes(3);

auto argLayout = srcTy.getEncoding();
auto argLayout = getSrcLayout();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
Expand Down Expand Up @@ -76,13 +71,11 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
elems = product<unsigned>(smemShape);
}

auto tensorType = op->getOperand(0).getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;

if (withIndex)
bytes += elems * sizeof(int32_t);

return bytes;
unsigned bytes_per_elem = 0;
for (const auto &ty: srcElementTypes) {
bytes_per_elem += ty.getIntOrFloatBitWidth() / 8;
}
return bytes_per_elem * elems;
}

bool isSharedEncoding(Value value) {
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,8 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
#undef POPULATE_BINARY_OP

#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
Expand Down
Loading

0 comments on commit eaea0a3

Please sign in to comment.