Rewrite ReduceOp to support arbitrary reduce operations #1305

Merged: 23 commits into main from generic-reduction, Apr 13, 2023
Commits
5f6e3de
Add GenericReduceOp to ttir and add `tl.prod` using it
peterbell10 Mar 7, 2023
786433e
Lower tt.generic_reduce to LLVM IR
peterbell10 Mar 8, 2023
ab461e8
Support simultaneous reduction of multiple tensors
peterbell10 Mar 13, 2023
0aba718
Automatically build reduction combine op region from JITFunction
peterbell10 Mar 14, 2023
08c2e35
Replace old ReduceOp entirely
peterbell10 Mar 14, 2023
4b74ce3
Misc cleanup
peterbell10 Mar 14, 2023
4b2b16a
Add SameOperandsEncoding
peterbell10 Mar 16, 2023
b3957a7
Run clang-format
peterbell10 Mar 16, 2023
bbec2fe
Fix lit tests
peterbell10 Mar 20, 2023
7e195a6
Update to newer LLVM
peterbell10 Mar 20, 2023
0f7a528
Merge remote-tracking branch 'upstream/main' into generic-reduction
peterbell10 Mar 23, 2023
19d490b
Lint
peterbell10 Mar 23, 2023
c5d928b
Merge remote-tracking branch 'upstream/main' into generic-reduction
peterbell10 Mar 30, 2023
c6f777b
Fix merge conflicts
peterbell10 Mar 30, 2023
440a39d
Respond to some review comments
peterbell10 Apr 4, 2023
3db5241
Merge remote-tracking branch 'upstream/main' into generic-reduction
peterbell10 Apr 4, 2023
c404afe
Merge remote-tracking branch 'upstream/main' into HEAD
peterbell10 Apr 7, 2023
19d31c6
Don't rematerialize ReduceOp
peterbell10 Apr 10, 2023
a6ae9e7
Merge remote-tracking branch 'upstream/main' into generic-reduction
peterbell10 Apr 10, 2023
0cd8f0f
Merge remote-tracking branch 'upstream/main' into generic-reduction
peterbell10 Apr 11, 2023
c7c8ac1
Revert "Don't rematerialize ReduceOp"
peterbell10 Apr 12, 2023
768241d
Merge remote-tracking branch 'upstream/main' into generic-reduction
peterbell10 Apr 12, 2023
04c2164
Merge branch 'main' into generic-reduction
ptillet Apr 13, 2023
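
The commit list above points at the user-facing shape of the change: the reduction operation is no longer a fixed enum, and the combine logic is built from a JITFunction ("Add GenericReduceOp to ttir and add `tl.prod` using it", "Automatically build reduction combine op region from JITFunction"). A minimal Python-side sketch of that idea follows, assuming a tl.reduce(input, axis, combine_fn) entry point; the name and signature are an assumption for illustration, not taken from this diff.

import triton
import triton.language as tl

@triton.jit
def _prod_combine(a, b):
    # Scalar combine function; the compiler builds the reduce op's
    # combine region from this JITFunction.
    return a * b

@triton.jit
def prod_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Assumed entry point: reduce x along axis 0 with a user-supplied combine.
    p = tl.reduce(x, 0, _prod_combine)
    tl.store(out_ptr, p)
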
28 changes: 22 additions & 6 deletions include/triton/Analysis/Utility.h
@@ -12,13 +12,26 @@ namespace mlir {

class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
srcTy = op.getOperand().getType().cast<RankedTensorType>();
explicit ReduceOpHelper(triton::ReduceOp rop)
: op(rop.getOperation()), axis(rop.getAxis()) {
auto firstTy = rop.getOperands()[0].getType().cast<RankedTensorType>();
srcShape = firstTy.getShape();
srcEncoding = firstTy.getEncoding();
srcElementTypes = rop.getElementTypes();

for (const auto &t : rop.getInputTypes()) {
if (t.getShape() != srcShape) {
rop.emitError() << "shape mismatch";
}
if (t.getEncoding() != srcEncoding) {
rop.emitError() << "encoding mismatch";
}
}
}

ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
ArrayRef<int64_t> getSrcShape() { return srcShape; }

Attribute getSrcLayout() { return srcTy.getEncoding(); }
Attribute getSrcLayout() { return srcEncoding; }

bool isFastReduction();

@@ -37,8 +50,11 @@ class ReduceOpHelper {
bool isSupportedLayout();

private:
triton::ReduceOp op;
RankedTensorType srcTy{};
Operation *op;
ArrayRef<int64_t> srcShape;
Attribute srcEncoding;
SmallVector<Type> srcElementTypes;
int axis;
};

bool isSharedEncoding(Value value);
24 changes: 0 additions & 24 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -34,30 +34,6 @@ def TT_PaddingOptionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// reduction
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"FADD", 2, "fadd">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"MAX", 4, "max">,
I32EnumAttrCase<"UMIN", 5, "umin">,
I32EnumAttrCase<"UMAX", 6, "umax">,
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">,
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
I32EnumAttrCase<"XOR", 15, "xor">
]> {
let cppNamespace = "::mlir::triton";
}

// atomic
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",
38 changes: 23 additions & 15 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -388,27 +388,35 @@ def TT_DotOp : TT_Op<"dot", [Pure,
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";

let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);

let results = (outs TT_Type:$result);

def TT_ReduceOp: TT_Op<"reduce",
[Pure,
SameOperandsEncoding,
SingleBlock,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Reduction using generic combination algorithm";
let arguments = (ins Variadic<TT_Tensor>:$operands, I32Attr:$axis);
let results = (outs Variadic<TT_Type>:$result);
let regions = (region SizedRegion<1>:$combineOp);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
OpBuilder<(ins "ValueRange":$operands, "int":$axis)>,
];

let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";

let hasVerifier = 1;
let hasRegionVerifier = 1;
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
unsigned getNumOperands();
}];
}

def TT_ReduceReturnOp: TT_Op<"reduce.return",
[HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> {
let summary = "terminator for reduce operator";
let arguments = (ins Variadic<AnyType>:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}


//
// External Elementwise op
//
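To make the variadic signature concrete: Variadic<TT_Tensor>:$operands together with the $combineOp region and the tt.reduce.return terminator let one reduce carry several tensors through a single combine function, which is what the "Support simultaneous reduction of multiple tensors" commit enables (argmin/argmax-style reductions that track a value and an index together). A hedged Python-side sketch follows, assuming tl.reduce also accepts a tuple of tensors with a combine function that returns a tuple; this calling convention is illustrative and not shown in this diff.

import triton
import triton.language as tl

@triton.jit
def _argmax_combine(v0, i0, v1, i1):
    # Keep the larger value and the index that travels with it.
    keep_first = v0 >= v1
    return tl.where(keep_first, v0, v1), tl.where(keep_first, i0, i1)

@triton.jit
def argmax_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Both tensors are reduced together through one combine region; the
    # combine function sees one scalar per input tensor from each side.
    _, best_idx = tl.reduce((x, offs), 0, _argmax_combine)
    tl.store(out_ptr, best_idx)
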
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -103,7 +103,7 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
TT_Tensor:$true_value,
TT_Tensor:$false_value);

let results = (outs TT_Tensor:$result);
let results = (outs TT_Type:$result);
}


5 changes: 4 additions & 1 deletion lib/Analysis/Membar.cpp
@@ -74,7 +74,10 @@ void MembarAnalysis::visitTerminator(Operation *op,
return;
}
// Otherwise, it could be a return op
assert(isa<triton::ReturnOp>(op) && "Unknown terminator");
if (isa<triton::ReduceReturnOp>(op) || isa<triton::ReturnOp>(op)) {
return;
}
llvm_unreachable("Unknown terminator encountered in membar analysis");
}

void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
37 changes: 12 additions & 25 deletions lib/Analysis/Utility.cpp
@@ -10,49 +10,38 @@
namespace mlir {

bool ReduceOpHelper::isFastReduction() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.getAxis();
return axis == triton::gpu::getOrder(srcLayout)[0];
return axis == triton::gpu::getOrder(getSrcLayout())[0];
}

unsigned ReduceOpHelper::getInterWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.getAxis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.getAxis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.getAxis();
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto axis = op.getAxis();
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
return smemShape;
}

SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.getAxis();
SmallVector<SmallVector<unsigned>> smemShapes(3);

auto argLayout = srcTy.getEncoding();
auto argLayout = getSrcLayout();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
@@ -64,7 +53,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {

/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);

@@ -82,17 +71,15 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
elems = product<unsigned>(smemShape);
}

auto tensorType = op.getOperand().getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;

if (triton::ReduceOp::withIndex(op.getRedOp()))
bytes += elems * sizeof(int32_t);

return bytes;
unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
}
return bytesPerElem * elems;
}

bool ReduceOpHelper::isSupportedLayout() {
auto srcLayout = srcTy.getEncoding();
auto srcLayout = getSrcLayout();
if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
return true;
}
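A consequence of the new getScratchSizeInBytes worth spelling out: the byte count is now summed over srcElementTypes, so multi-tensor reductions fall out of the general formula rather than a special case. As a worked example (sizes chosen for illustration), an argmin/argmax-style reduce over f32 values paired with i32 indices has srcElementTypes = {f32, i32}, giving bytesPerElem = 4 + 4 = 8; with 1024 scratch elements that is 8192 bytes, the same total the old code reached through its explicit `bytes += elems * sizeof(int32_t)` branch for index-carrying reductions.
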
14 changes: 8 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1073,12 +1073,14 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
#undef POPULATE_BINARY_OP

#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \