From e152183570b26cbdbfdd280cf3c6187d27539dc5 Mon Sep 17 00:00:00 2001
From: peterbell10
Date: Thu, 13 Apr 2023 01:37:39 +0000
Subject: [PATCH] [FRONTEND][BACKEND] ReduceOp to support arbitrary reduce operations (#1305)

Fixes #1285

This changes `tt.reduce` to replace `redOp` with a region containing arbitrary code. For example, `tl.sum` is now lowered as:
```mlir
%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
```
Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. For example, `argmin` is now lowered as:
```mlir
%7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
%9:2 = "tt.reduce"(%6, %8) ({
^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
  %14 = arith.cmpf olt, %arg4, %arg6 : f32
  %15 = arith.cmpf ogt, %arg4, %arg6 : f32
  %16 = arith.cmpi slt, %arg5, %arg7 : i32
  %17 = arith.select %16, %arg5, %arg7 : i32
  %18 = arith.select %15, %arg7, %17 : i32
  %19 = arith.select %14, %arg5, %18 : i32
  %20 = arith.cmpf olt, %arg4, %arg6 : f32
  %21 = arith.select %20, %arg4, %arg6 : f32
  tt.reduce.return %21, %19 : f32, i32
}) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
```
---
 include/triton/Analysis/Utility.h             |  28 +-
 .../Dialect/Triton/IR/TritonAttrDefs.td       |  24 -
 include/triton/Dialect/Triton/IR/TritonOps.td |  38 +-
 .../Dialect/TritonGPU/IR/TritonGPUOps.td      |   2 +-
 lib/Analysis/Membar.cpp                       |   5 +-
 lib/Analysis/Utility.cpp                      |  37 +-
 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp   |  14 +-
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        | 482 ++++++++----------
 .../TritonGPUToLLVM/TypeConverter.cpp         |  23 +-
 .../TritonToTritonGPUPass.cpp                 |  51 +-
 lib/Dialect/Triton/IR/Ops.cpp                 | 130 ++++-
 .../Transforms/RemoveLayoutConversions.cpp    |  64 ++-
 .../Transforms/TritonGPUConversion.cpp        |   2 +
 python/src/triton.cc                          |  44 +-
 python/test/unit/language/test_core.py        |  13 +-
 python/triton/compiler/code_generator.py      |  81 +--
 python/triton/language/core.py                | 163 +++++-
 python/triton/language/semantic.py            | 109 +---
 test/Analysis/test-allocation.mlir            |   6 +-
 test/Analysis/test-membar.mlir                |   6 +-
 test/Conversion/triton_ops.mlir               |  48 +-
 test/Conversion/triton_to_tritongpu.mlir      |  32 +-
 test/TritonGPU/combine.mlir                   |  34 +-
 23 files changed, 826 insertions(+), 610 deletions(-)

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 9c8a6fb8a059..e7873c646842 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -12,13 +12,26 @@ namespace mlir {
 class ReduceOpHelper {
 public:
-  explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
-    srcTy = op.getOperand().getType().cast();
+  explicit ReduceOpHelper(triton::ReduceOp rop)
+      : op(rop.getOperation()), axis(rop.getAxis()) {
+    auto firstTy = rop.getOperands()[0].getType().cast();
+    srcShape = firstTy.getShape();
+    srcEncoding = firstTy.getEncoding();
+    srcElementTypes = rop.getElementTypes();
+
+    for (const auto &t : rop.getInputTypes()) {
+      if (t.getShape() != srcShape) {
+        rop.emitError() << "shape mismatch";
+      }
+      if (t.getEncoding() != srcEncoding) {
+        rop.emitError() << "encoding mismatch";
+      }
+    }
   }

-  ArrayRef getSrcShape() { return srcTy.getShape(); }
+  ArrayRef getSrcShape() { return srcShape; }

-  Attribute getSrcLayout() { return srcTy.getEncoding(); }
+ Attribute getSrcLayout() { return srcEncoding; } bool isFastReduction(); @@ -37,8 +50,11 @@ class ReduceOpHelper { bool isSupportedLayout(); private: - triton::ReduceOp op; - RankedTensorType srcTy{}; + Operation *op; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; + int axis; }; bool isSharedEncoding(Value value); diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index c6b5cb9a6043..bb0e6f30b5cc 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -34,30 +34,6 @@ def TT_PaddingOptionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -// reduction -def TT_RedOpAttr : I32EnumAttr< - /*name*/"RedOp", /*summary*/"", - /*case*/ - [ - I32EnumAttrCase, - I32EnumAttrCase<"FADD", 2, "fadd">, - I32EnumAttrCase<"MIN", 3, "min">, - I32EnumAttrCase<"MAX", 4, "max">, - I32EnumAttrCase<"UMIN", 5, "umin">, - I32EnumAttrCase<"UMAX", 6, "umax">, - I32EnumAttrCase<"ARGMIN", 7, "argmin">, - I32EnumAttrCase<"ARGMAX", 8, "argmax">, - I32EnumAttrCase<"ARGUMIN", 9, "argumin">, - I32EnumAttrCase<"ARGUMAX", 10, "argumax">, - I32EnumAttrCase<"FMIN", 11, "fmin">, - I32EnumAttrCase<"FMAX", 12, "fmax">, - I32EnumAttrCase<"ARGFMIN", 13, "argfmin">, - I32EnumAttrCase<"ARGFMAX", 14, "argfmax">, - I32EnumAttrCase<"XOR", 15, "xor"> - ]> { - let cppNamespace = "::mlir::triton"; -} - // atomic def TT_AtomicRMWAttr : I32EnumAttr< "RMWOp", "", diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index f3263162616b..f29628d9450f 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -388,27 +388,35 @@ def TT_DotOp : TT_Op<"dot", [Pure, // // Reduce Op // -def TT_ReduceOp : TT_Op<"reduce", [Pure, - DeclareOpInterfaceMethods]> { - let summary = "reduce"; - - let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis); - - let results = (outs TT_Type:$result); - +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$operands, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); let builders = [ - OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>, + OpBuilder<(ins "ValueRange":$operands, "int":$axis)>, ]; - - let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)"; - + let hasVerifier = 1; + let hasRegionVerifier = 1; let extraClassDeclaration = [{ - // This member function is marked static because we need to call it before the ReduceOp - // is constructed, see the implementation of create_reduce in triton.cc. 
- static bool withIndex(mlir::triton::RedOp redOp); + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); }]; } +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + // // External Elementwise op // diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 422f1bdaeaf4..b774fd54684d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -103,7 +103,7 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise, TT_Tensor:$true_value, TT_Tensor:$false_value); - let results = (outs TT_Tensor:$result); + let results = (outs TT_Type:$result); } diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 5b2b15abcb03..3d045d2de0ec 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -74,7 +74,10 @@ void MembarAnalysis::visitTerminator(Operation *op, return; } // Otherwise, it could be a return op - assert(isa(op) && "Unknown terminator"); + if (isa(op) || isa(op)) { + return; + } + llvm_unreachable("Unknown terminator encountered in membar analysis"); } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 9d925a11e86d..dfd6302e74fc 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -10,49 +10,38 @@ namespace mlir { bool ReduceOpHelper::isFastReduction() { - auto srcLayout = srcTy.getEncoding(); - auto axis = op.getAxis(); - return axis == triton::gpu::getOrder(srcLayout)[0]; + return axis == triton::gpu::getOrder(getSrcLayout())[0]; } unsigned ReduceOpHelper::getInterWarpSize() { - auto srcLayout = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); - auto axis = op.getAxis(); auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSize(); return std::min(srcReduceDimSize / sizeIntraWarps, - triton::gpu::getWarpsPerCTA(srcLayout)[axis]); + triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getIntraWarpSize() { - auto srcLayout = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); - auto axis = op.getAxis(); auto srcReduceDimSize = static_cast(srcShape[axis]); return std::min(srcReduceDimSize, - triton::gpu::getThreadsPerWarp(srcLayout)[axis]); + triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getThreadsReductionAxis() { - auto srcLayout = srcTy.getEncoding(); - auto axis = op.getAxis(); + auto srcLayout = getSrcLayout(); return triton::gpu::getThreadsPerWarp(srcLayout)[axis] * triton::gpu::getWarpsPerCTA(srcLayout)[axis]; } SmallVector ReduceOpHelper::getScratchConfigBasic() { - auto axis = op.getAxis(); auto smemShape = convertType(getSrcShape()); smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis()); return smemShape; } SmallVector> ReduceOpHelper::getScratchConfigsFast() { - auto axis = op.getAxis(); SmallVector> smemShapes(3); - auto argLayout = srcTy.getEncoding(); + auto argLayout = getSrcLayout(); auto argLayoutMma = argLayout.dyn_cast(); if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1) @@ -64,7 +53,7 @@ SmallVector> ReduceOpHelper::getScratchConfigsFast() { /// FIXME(Qingyi): This size 
is actually larger than required. /// shared memory block1: - auto mod = op.getOperation()->getParentOfType(); + auto mod = op->getParentOfType(); unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); smemShapes[1].push_back(numWarps * 32); @@ -82,17 +71,15 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { elems = product(smemShape); } - auto tensorType = op.getOperand().getType().cast(); - unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8; - - if (triton::ReduceOp::withIndex(op.getRedOp())) - bytes += elems * sizeof(int32_t); - - return bytes; + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += ty.getIntOrFloatBitWidth() / 8; + } + return bytesPerElem * elems; } bool ReduceOpHelper::isSupportedLayout() { - auto srcLayout = srcTy.getEncoding(); + auto srcLayout = getSrcLayout(); if (srcLayout.isa()) { return true; } diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index c4924151bc00..ee2c3187ff35 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1073,12 +1073,14 @@ void populateElementwiseOpToLLVMPatterns( POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) - POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & - POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | - POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ - POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << - POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> - POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin #undef POPULATE_BINARY_OP #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 1eb225dd1857..9f8181b0378b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -23,112 +23,59 @@ struct ReduceOpConversion } private: - void accumulate(ConversionPatternRewriter &rewriter, Location loc, - RedOp redOp, Value &acc, Value cur, bool isFirst) const { + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + llvm::SmallVectorImpl &acc, ValueRange cur, + bool isFirst) const { if (isFirst) { - acc = cur; + acc.resize(cur.size()); + for (unsigned i = 0; i < cur.size(); ++i) { + acc[i] = cur[i]; + } return; } - switch (redOp) { - case RedOp::ADD: - acc = add(acc, cur); - break; - case RedOp::FADD: - acc = fadd(acc.getType(), acc, cur); - break; - case RedOp::MIN: - acc = smin(acc, cur); - break; - case RedOp::MAX: - acc = smax(acc, cur); - break; - case RedOp::UMIN: - acc = umin(acc, cur); - break; - case RedOp::UMAX: - acc = umax(acc, cur); - break; - case RedOp::FMIN: - acc = fmin(acc, cur); - break; - case RedOp::FMAX: - acc = fmax(acc, cur); - break; - case RedOp::XOR: - acc = xor_(acc, cur); - break; - case RedOp::ARGMIN: - case RedOp::ARGMAX: - case RedOp::ARGUMIN: 
- case RedOp::ARGUMAX: - case RedOp::ARGFMIN: - case RedOp::ARGFMAX: - llvm::report_fatal_error( - "This accumulate implementation is not for argmin / argmax"); - default: - llvm::report_fatal_error("Unsupported reduce op"); + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = dyn_cast(newReduce.getTerminator()); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; } - } - void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc, - RedOp redOp, Value &acc, Value &accIndex, Value cur, - Value curIndex, bool isFirst) const { - if (isFirst) { - acc = cur; - accIndex = curIndex; - return; + rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), + combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; } - switch (redOp) { - case RedOp::ARGMIN: - accIndex = select( - icmp_slt(acc, cur), accIndex, - select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = smin(acc, cur); - break; - case RedOp::ARGMAX: - accIndex = select( - icmp_sgt(acc, cur), accIndex, - select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = smax(acc, cur); - break; - case RedOp::ARGUMIN: - accIndex = select( - icmp_ult(acc, cur), accIndex, - select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = umin(acc, cur); - break; - case RedOp::ARGUMAX: - accIndex = select( - icmp_ugt(acc, cur), accIndex, - select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = umax(acc, cur); - break; - case RedOp::ARGFMIN: - accIndex = select( - fcmp_olt(acc, cur), accIndex, - select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = fmin(acc, cur); - break; - case RedOp::ARGFMAX: - accIndex = select( - fcmp_ogt(acc, cur), accIndex, - select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = fmax(acc, cur); - break; - case RedOp::ADD: - case RedOp::FADD: - case RedOp::MIN: - case RedOp::MAX: - case RedOp::UMIN: - case RedOp::UMAX: - case RedOp::FMIN: - case RedOp::FMAX: - case RedOp::XOR: - llvm::report_fatal_error( - "This accumulate implementation is only for argmin / argmax"); - default: - llvm::report_fatal_error("Unsupported reduce op"); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = getTypeConverter()->unpackLLElements(loc, operands[i], + rewriter, types[i]); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } } + return srcValues; } // Calculates the write index in the shared memory where we would be writing @@ -177,63 +124,64 @@ struct ReduceOpConversion matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { ReduceOpHelper helper(op); - Location loc = op->getLoc(); + Location loc = op.getLoc(); 
unsigned axis = op.getAxis(); - // Specifies whether the reduce operation returns an index - // rather than a value, e.g. argmax, argmin, .. etc - bool withIndex = triton::ReduceOp::withIndex(op.getRedOp()); - auto srcTy = op.getOperand().getType().cast(); - auto srcLayout = srcTy.getEncoding(); + auto srcTys = op.getInputTypes(); + auto srcLayout = helper.getSrcLayout(); if (!helper.isSupportedLayout()) { assert(false && "Unexpected srcLayout in ReduceOpConversion"); } // The order of the axes for the the threads within the warp auto srcOrd = triton::gpu::getOrder(srcLayout); auto sizePerThread = triton::gpu::getSizePerThread(srcLayout); - auto srcShape = srcTy.getShape(); + auto srcShape = helper.getSrcShape(); - auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + SmallVector elemPtrTys(srcTys.size()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto ty = srcTys[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); + } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); - auto smemShape = helper.getScratchConfigBasic(); unsigned elems = product(smemShape); - Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems)); - indexSmemBase = bitcast(indexSmemBase, indexPtrTy); - unsigned srcElems = getElemsPerThread(srcTy); + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = + bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), + elemPtrTys[i]); + } + + unsigned srcElems = getElemsPerThread(srcTys[0]); // Emits indices of the original tensor that each thread // would own - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - auto srcValues = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperand(), rewriter, srcTy); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + // Emits offsets (the offset from the base index) // of the original tensor that each thread would own + // NOTE: Assumes offsets don't actually depend on type SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTy); + emitOffsetForLayout(srcLayout, srcTys[0]); + // Keep track of accumulations and their indices - std::map, Value> accs; - std::map, Value> accIndices; + std::map, SmallVector> accs; std::map, SmallVector> indices; + Region *combineOp = &op.getCombineOp(); + // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; key[axis] = 0; bool isFirst = accs.find(key) == accs.end(); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i], - isFirst); - } else { - Value curIndex = srcIndices[i][axis]; - accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key], - accIndices[key], srcValues[i], curIndex, isFirst); - } + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } @@ -250,24 +198,20 @@ struct ReduceOpConversion // reduce across threads for (auto it : accs) { const SmallVector &key = it.first; - Value acc = it.second; - Value accIndex; - if 
(withIndex) - accIndex = accIndices[key]; + auto &acc = it.second; // get the writeIdx at which to write in smem SmallVector writeIdx; getWriteIndexBasic(rewriter, loc, srcLayout, indices[key], writeIdx, ints, axis); + // calculate the offset in smem for that writeIdx Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd); - // Get element pointers for the value and index - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset); - // Store the within-thread accumulated value at writePtr - store(acc, writePtr); - // Store the index of within-thread accumulation at indexWritePtr - if (withIndex) - store(accIndex, indexWritePtr); + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + // Store the within-thread accumulated value into shared memory + writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + store(acc[i], writePtrs[i]); + } SmallVector readIdx(writeIdx.size(), ints[0]); // Perform parallel reduction with sequential addressing @@ -286,27 +230,24 @@ struct ReduceOpConversion Value readOffset = select( readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), ints[0]); - // The readPtr is readOffset away from writePtr - Value readPtr = gep(elemPtrTy, writePtr, readOffset); + SmallVector readPtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + // The readPtr is readOffset away from writePtr + readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); + } + + barrier(); + // Combine accumulator value from another thread + SmallVector cur(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + cur[i] = load(readPtrs[i]); + } + accumulate(rewriter, *combineOp, acc, cur, false); + barrier(); - // If we do not care about the index, i.e. this is not an argmax, - // argmin, .. etc - if (!withIndex) { - // The value at the readPtr, whereas acc is the value at writePtr - Value cur = load(readPtr); - accumulate(rewriter, loc, op.getRedOp(), acc, cur, false); - barrier(); - // Update writePtr value - store(acc, writePtr); - } else { - Value cur = load(readPtr); - Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset); - Value curIndex = load(indexReadPtr); - accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, cur, - curIndex, false); - barrier(); - store(acc, writePtr); - store(accIndex, indexWritePtr); + // Publish our new accumulator value to shared memory + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + store(acc[i], writePtrs[i]); } } } @@ -314,33 +255,37 @@ struct ReduceOpConversion barrier(); // set output values - if (auto resultTy = op.getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding(); - auto resultShape = resultTy.getShape(); - - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (unsigned i = 0; i < resultElems; ++i) { - SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); - resultVals[i] = withIndex ? 
load(indexReadPtr) : load(readPtr); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + op.getResult()[i].getType().dyn_cast()) { + // nd-tensor where n >= 1 + + auto resultLayout = resultTy.getEncoding(); + + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (unsigned j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + axis, ints[0]); + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, srcOrd); + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + resultVals[j] = load(readPtr); + } + results[i] = getTypeConverter()->packLLElements(loc, resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(smemBases[i]); } - Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, ret); - } else { - // 0d-tensor -> scalar - Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase); - rewriter.replaceOp(op, resultVal); } + auto parentBlock = op.getOperation()->getBlock(); + rewriter.replaceOp(op, results); return success(); } @@ -351,60 +296,59 @@ struct ReduceOpConversion ReduceOpHelper helper(op); Location loc = op->getLoc(); unsigned axis = adaptor.getAxis(); - bool withIndex = triton::ReduceOp::withIndex(op.getRedOp()); - auto srcTy = op.getOperand().getType().cast(); - auto srcLayout = srcTy.getEncoding(); + auto srcTys = op.getInputTypes(); + auto srcLayout = helper.getSrcLayout(); if (!helper.isSupportedLayout()) { assert(false && "Unexpected srcLayout in ReduceOpConversion"); } - auto srcShape = srcTy.getShape(); - auto order = getOrder(srcLayout); - - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + auto srcOrd = triton::gpu::getOrder(srcLayout); + auto srcShape = helper.getSrcShape(); - auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + SmallVector elemPtrTys(srcTys.size()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto ty = srcTys[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); + } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); auto smemShapes = helper.getScratchConfigsFast(); unsigned elems = product(smemShapes[0]); unsigned maxElems = std::max(elems, product(smemShapes[1])); - Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems)); - indexSmemBase = bitcast(indexSmemBase, indexPtrTy); + + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = + bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)), + elemPtrTys[i]); + } unsigned sizeIntraWarps = helper.getIntraWarpSize(); unsigned sizeInterWarps = helper.getInterWarpSize(); - unsigned srcElems = getElemsPerThread(srcTy); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - auto 
srcValues = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperand(), rewriter, srcTy); + unsigned srcElems = getElemsPerThread(srcTys[0]); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + + // Assumes offsets don't actually depend on type SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTy); + emitOffsetForLayout(srcLayout, srcTys[0]); - std::map, Value> accs; - std::map, Value> accIndices; - std::map, SmallVector> indices; + auto *combineOp = &op.getCombineOp(); // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; key[axis] = 0; bool isFirst = accs.find(key) == accs.end(); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i], - isFirst); - } else { - Value curIndex = srcIndices[i][axis]; - accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key], - accIndices[key], srcValues[i], curIndex, isFirst); - } + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } @@ -414,6 +358,9 @@ struct ReduceOpConversion Value warpId = udiv(threadId, warpSize); Value laneId = urem(threadId, warpSize); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + auto order = getOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); SmallVector multiDimWarpId = @@ -427,32 +374,24 @@ struct ReduceOpConversion for (auto it : accs) { const SmallVector &key = it.first; - Value acc = it.second; - Value accIndex; - if (withIndex) - accIndex = accIndices[key]; + SmallVector acc = it.second; // Reduce within warps for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(loc, rewriter, acc, N); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false); - } else { - Value shflIndex = shflSync(loc, rewriter, accIndex, N); - accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl, - shflIndex, false); + SmallVector shfl(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], N); } + accumulate(rewriter, *combineOp, acc, shfl, false); } SmallVector writeIdx = indices[key]; writeIdx[axis] = (sizeInterWarps == 1) ? 
zero : warpIdAxis; Value writeOffset = linearize(rewriter, loc, writeIdx, smemShapes[0], order); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - storeShared(rewriter, loc, writePtr, acc, laneZero); - if (withIndex) { - Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset); - storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset); + storeShared(rewriter, loc, writePtr, acc[i], laneZero); } } @@ -469,39 +408,36 @@ struct ReduceOpConversion unsigned elemsPerThread = std::max(elems / numThreads, 1); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { - Value readPtr = gep(elemPtrTy, smemBase, readOffset); // FIXME(Qingyi): need predicate icmp_slt(threadId, // i32_val(sizeInerWarps)) - Value acc = load(readPtr); - Value accIndex; - if (withIndex) { - Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset); - accIndex = load(readIndexPtr); + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + acc[i] = load(readPtr); } for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(loc, rewriter, acc, N); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false); - } else { - Value shflIndex = shflSync(loc, rewriter, accIndex, N); - accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl, - shflIndex, false); + SmallVector shfl(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], N); } + accumulate(rewriter, *combineOp, acc, shfl, false); } // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + } Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); Value laneIdModSizeInterWarpsIsZero = icmp_eq(laneIdModSizeInterWarps, zero); Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); - storeShared(rewriter, loc, writePtr, acc, pred); - if (withIndex) { - Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset); - storeShared(rewriter, loc, writeIndexPtr, accIndex, pred); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + storeShared(rewriter, loc, writePtrs[i], acc[i], pred); } if (round != elemsPerThread - 1) { @@ -515,32 +451,34 @@ struct ReduceOpConversion barrier(); // set output values - if (auto resultTy = op.getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding().cast(); - auto resultShape = resultTy.getShape(); - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (size_t i = 0; i < resultElems; ++i) { - SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, i32_val(0)); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[0], order); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); - resultVals[i] 
= withIndex ? load(indexReadPtr) : load(readPtr); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + op.getResult()[i].getType().dyn_cast()) { + // nd-tensor where n >= 1 + auto resultLayout = resultTy.getEncoding().cast(); + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + axis, i32_val(0)); + Value readOffset = + linearize(rewriter, loc, readIdx, smemShapes[0], order); + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + resultVals[j] = load(readPtr); + } + + results[i] = getTypeConverter()->packLLElements(loc, resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(smemBases[i]); } - Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, ret); - } else { - // 0d-tensor -> scalar - Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase); - rewriter.replaceOp(op, resultVal); } + rewriter.replaceOp(op, results); return success(); } diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index cada8fab16c2..06025f7bb976 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -46,15 +46,30 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( Value TritonGPUToLLVMTypeConverter::packLLElements( Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter, Type type) { - auto structType = this->convertType(type); - if (!structType.isa()) { + auto structType = this->convertType(type).dyn_cast(); + if (!structType) { + assert(resultVals.size() == 1); return *resultVals.begin(); } + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + } Value llvmStruct = rewriter.create(loc, structType); - // llvm::outs() << structType << "\n"; for (const auto &v : llvm::enumerate(resultVals)) { - assert(v.value() && "can not insert null values"); + if (!v.value()) { + emitError(loc) + << "cannot insert null values into struct, but tried to insert" + << v.value(); + } + if (v.value().getType() != elementTypes[v.index()]) { + emitError(loc) << "invalid element type in packLLEElements. 
Expected " + << elementTypes[v.index()] << " but got " + << v.value().getType(); + } llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); } return llvmStruct; diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 921cb6b8d29d..8b393cd6eccf 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -68,13 +68,15 @@ class ArithConstantPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); auto value = adaptor.getValue().dyn_cast(); - assert(value); - if (value.getElementType().isInteger(1) && value.isSplat()) - // Workaround until https://reviews.llvm.org/D133743 is included. - value = DenseElementsAttr::get(retType, value.getSplatValue()); - else - // This is a hack. We just want to add encoding - value = value.reshape(retType); + if (dyn_cast(retType)) { + assert(value); + if (value.getElementType().isInteger(1) && value.isSplat()) + // Workaround until https://reviews.llvm.org/D133743 is included. + value = DenseElementsAttr::get(retType, value.getSplatValue()); + else + // This is a hack. We just want to add encoding + value = value.reshape(retType); + } addNamedAttrs( rewriter.replaceOpWithNewOp(op, retType, value), adaptor.getAttributes()); @@ -469,10 +471,28 @@ struct TritonReducePattern : public OpConversionPattern { LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - addNamedAttrs( - rewriter.replaceOpWithNewOp( - op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis()), - adaptor.getAttributes()); + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonReduceReturnPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getResult()), + adaptor.getAttributes()); return success(); } }; @@ -517,10 +537,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonCatPattern, - TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, - TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, - TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern, - TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context); + TritonReducePattern, TritonReduceReturnPattern, TritonTransPattern, + TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, + TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, + TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>( + typeConverter, context); } // diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index c52f321455cc..3d97b4e985bf 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -310,21 +310,10 @@ mlir::LogicalResult 
mlir::triton::DotOp::inferReturnTypes( } //-- ReduceOp -- -mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // infer shape - Value arg = operands[0]; - auto argTy = arg.getType().cast(); - auto argEltTy = argTy.getElementType(); - auto i32Ty = IntegerType::get(argEltTy.getContext(), 32); - auto redOp = - attributes.get("redOp").cast().getValue(); - bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); - auto retEltTy = withIndex ? i32Ty : argEltTy; +static mlir::LogicalResult +inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, + int axis, SmallVectorImpl &inferredReturnTypes) { auto retShape = argTy.getShape().vec(); - int axis = attributes.get("axis").cast().getInt(); retShape.erase(retShape.begin() + axis); if (retShape.empty()) { // 0d-tensor -> scalar @@ -352,15 +341,114 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( return mlir::success(); } -bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) { - return redOp == mlir::triton::RedOp::ARGMIN || - redOp == mlir::triton::RedOp::ARGMAX || - redOp == mlir::triton::RedOp::ARGUMIN || - redOp == mlir::triton::RedOp::ARGUMAX || - redOp == mlir::triton::RedOp::ARGFMIN || - redOp == mlir::triton::RedOp::ARGFMAX; +void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = operands[i].getType().cast(); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); } +mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) { + auto argTy = arg.getType().cast(); + auto retEltTy = argTy.getElementType(); + int axis = attributes.get("axis").cast().getInt(); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +mlir::LogicalResult mlir::triton::ReduceOp::verify() { + if (this->getOperands().size() < 1) { + return this->emitOpError() << "must have at least 1 operand"; + } + for (const auto &operand : this->getOperands()) { + if (!dyn_cast(operand.getType())) { + return this->emitOpError() << "operands must be RankedTensorType"; + } + } + return success(); +} + +mlir::LogicalResult mlir::triton::ReduceOp::verifyRegions() { + auto argElementTypes = this->getElementTypes(); + const auto &operands = this->getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *this->getBody(); + if (block.getNumArguments() != numArgs) { + return this->emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return this->emitOpError() + << "type mismatch on combine operation. 
Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = + dyn_cast(block.getTerminator()); + if (!terminator) { + return this->emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return this->emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return this->emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return mlir::success(); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + llvm::SmallVector srcTys; + srcTys.reserve(this->getNumOperands()); + for (const auto &ty : this->getOperands().getTypes()) { + srcTys.push_back(ty.cast()); + } + return srcTys; +} + +llvm::SmallVector ReduceOp::getElementTypes() { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(this->getNumOperands()); + for (const auto &op : this->getOperands()) { + srcElemTys.push_back( + op.getType().cast().getElementType()); + } + return srcElemTys; +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + //-- SplatOp -- OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { auto value = adaptor.getSrc(); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 579852c10a0b..fc34862ec7dd 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -101,29 +101,59 @@ class SimplifyReduceCvt : public mlir::RewritePattern { auto convert = llvm::cast(op); triton::ReduceOp reduce; for (auto &use : convert.getResult().getUses()) { - auto owner = use.getOwner(); - if (llvm::isa_and_nonnull(owner)) { - reduce = llvm::cast(owner); - break; + auto owner = llvm::dyn_cast(use.getOwner()); + if (!owner) { + continue; + } + + // TODO: This only moves conversions from the first argument which is + // fine for argmin/argmax but may not be optimal generally + if (convert.getResult() != owner.getOperands()[0]) { + continue; } + reduce = owner; + break; } if (!reduce) return mlir::failure(); + + SmallVector newOperands = reduce.getOperands(); + + newOperands[0] = convert.getOperand(); + auto newEncoding = + newOperands[0].getType().cast().getEncoding(); + // this may generate unsupported conversions in the LLVM codegen - if (convert.getOperand() - .getType() - .cast() - .getEncoding() - .isa()) - return mlir::failure(); + if (newEncoding.isa()) { + return failure(); + } + + for (unsigned i = 1; i < newOperands.size(); ++i) { + auto oldTy = newOperands[i].getType().cast(); + RankedTensorType newTy = + RankedTensorType::Builder(oldTy).setEncoding(newEncoding); + + newOperands[i] = rewriter.create( + op->getLoc(), newTy, newOperands[i]); + } + auto newReduce = rewriter.create( - op->getLoc(), reduce.getRedOp(), convert.getOperand(), - reduce.getAxis()); - Value newRet = newReduce.getResult(); - if (newRet.getType() != reduce.getResult().getType()) - newRet = rewriter.create( - op->getLoc(), reduce.getResult().getType(), newRet); - 
rewriter.replaceAllUsesWith(reduce, newRet); + op->getLoc(), newOperands, reduce.getAxis()); + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.inlineRegionBefore(reduce.getCombineOp(), newCombineOp, + newCombineOp.end()); + + SmallVector newRet = newReduce.getResult(); + auto oldTypes = reduce.getResult().getType(); + for (unsigned i = 0; i < reduce.getNumOperands(); ++i) { + // it's still beneficial to move the conversion + // to after the reduce if necessary since it will be + // done on a rank-reduced tensor hence cheaper + if (newRet[i].getType() != oldTypes[i]) + newRet[i] = rewriter.create( + op->getLoc(), oldTypes[i], newRet[i]); + } + rewriter.replaceAllUsesWith(reduce.getResult(), newRet); return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 1a791d07a97e..d20b287015e8 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -79,6 +79,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( // Some ops from SCF are illegal addIllegalOp(); + // We have custom versions of some arith operators + addIllegalOp(); addDynamicallyLegalDialect(m, "REDUCE_OP") - .value("ADD", mlir::triton::RedOp::ADD) - .value("FADD", mlir::triton::RedOp::FADD) - .value("MIN", mlir::triton::RedOp::MIN) - .value("MAX", mlir::triton::RedOp::MAX) - .value("UMIN", mlir::triton::RedOp::UMIN) - .value("UMAX", mlir::triton::RedOp::UMAX) - .value("ARGMIN", mlir::triton::RedOp::ARGMIN) - .value("ARGMAX", mlir::triton::RedOp::ARGMAX) - .value("ARGUMIN", mlir::triton::RedOp::ARGUMIN) - .value("ARGUMAX", mlir::triton::RedOp::ARGUMAX) - .value("FMIN", mlir::triton::RedOp::FMIN) - .value("FMAX", mlir::triton::RedOp::FMAX) - .value("ARGFMIN", mlir::triton::RedOp::ARGFMIN) - .value("ARGFMAX", mlir::triton::RedOp::ARGFMAX) - .value("XOR", mlir::triton::RedOp::XOR); - py::enum_(m, "ATOMIC_OP") .value("ADD", mlir::triton::RMWOp::ADD) .value("FADD", mlir::triton::RMWOp::FADD) @@ -1349,21 +1332,20 @@ void init_triton_ir(py::module &&m) { return self.create(loc, val); }) .def("create_reduce", - [](mlir::OpBuilder &self, mlir::Value &operand, - mlir::triton::RedOp redOp, int axis) -> mlir::Value { - auto loc = self.getUnknownLoc(); - auto inputTensorType = - operand.getType().dyn_cast(); - std::vector shape = inputTensorType.getShape(); - shape.erase(shape.begin() + axis); - bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); - mlir::Type resType = withIndex ? 
self.getI32Type() - : inputTensorType.getElementType(); - if (!shape.empty()) { - resType = mlir::RankedTensorType::get(shape, resType); + [](mlir::OpBuilder &self, std::vector operands, + int axis) -> mlir::OpState { + auto loc = self.getUnknownLoc(); + return self.create(loc, operands, axis); + }) + .def("create_reduce_ret", + [](mlir::OpBuilder &self, py::args args) -> mlir::OpState { + auto loc = self.getUnknownLoc(); + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); } - return self.create(loc, resType, redOp, - operand, axis); + return self.create(loc, + return_values); }) .def("create_ptr_to_int", [](mlir::OpBuilder &self, mlir::Value &val, diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e3dbd789cac0..ca18212f9be6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1295,10 +1295,15 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'): %12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr, #blocked>, tensor<{rdims_2d}xi32, #blocked> %13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf32, #blocked> %14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xf32, #blocked>) -> tensor<{M}x{N}xf32, #src> - %15 = tt.reduce %14 {{axis = {axis} : i32, redOp = 12 : i32}} : tensor<{M}x{N}xf32, #src> -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>> - %16 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>> - %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked> - tt.store %12, %17 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: f32, %arg4: f32): + %16 = "triton_gpu.cmpf"(%arg3, %arg4) {{predicate = 2 : i64}} : (f32, f32) -> i1 + %17 = arith.select %16, %arg3, %arg4 : f32 + tt.reduce.return %17 : f32 + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xf32, #src>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>> + %18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>> + %19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked> + tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked> tt.return }} }} diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 3d4a5dd45add..a0b42c512fee 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1,4 +1,5 @@ import ast +import inspect import re import sys import warnings @@ -755,6 +756,43 @@ def visit_Assert(self, node) -> Any: # Convert assert to triton's device_assert which happens on the device return language.core.device_assert(test, msg, _builder=self.builder) + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) + else constexpr(arg) for arg in args] + 
# generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug) + generator.visit(fn.parse()) + callee_ret_type = generator.last_ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func)) @@ -768,44 +806,13 @@ def visit_Call(self, node): if not self.debug: return if isinstance(fn, JITFunction): - from inspect import getcallargs - args = getcallargs(fn.fn, *args, **kws) - args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_tensor(arg) - else constexpr(arg) for arg in args] - # generate function def - attributes = dict() - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) - # generate function def if necessary - if not self.module.has_function(fn_name): - prototype = language.function_type([], arg_types) - gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug) - generator.visit(fn.parse()) - callee_ret_type = generator.last_ret_type - self.function_ret_types[fn_name] = callee_ret_type - else: - callee_ret_type = self.function_ret_types[fn_name] - symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: - return None - elif call_op.get_num_results() == 1: - return tensor(call_op.get_result(0), callee_ret_type) - else: - # should return a tuple of tl.tensor - results = [] - for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) + return self.call_JitFunction(fn, args, kws) if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): - return fn(*args, _builder=self.builder, **kws) + extra_kwargs = 
dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + return fn(*args, **extra_kwargs, **kws) if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) return fn(*args, **kws) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 9a06d951a335..1431da509f42 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import contextmanager from enum import Enum from functools import wraps from typing import Callable, List, TypeVar @@ -1190,46 +1191,166 @@ def _decorator(func: T) -> T: return _decorator +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + @builtin -@_add_reduction_docstr("maximum") -def max(input, axis, _builder=None): +def reduction(input, axis, combine_fn, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + + """ + if isinstance(input, tensor): + return reduction((input,), axis, combine_fn, + _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) + for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + axis = _constexpr_to_value(axis) - return semantic.max(input, axis, _builder) + return semantic.reduction(input, axis, make_combine_region, _builder) @builtin -@_add_reduction_docstr("maximum index") -def argmax(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.argmax(input, axis, _builder) +def _promote_reduction_input(t, _builder=None): + scalar_ty = t.type.scalar + # input is extended to 32-bits if necessary + # this increases numerical accuracy and can be done pretty much for free + # on GPUs + if scalar_ty.is_int() and scalar_ty.int_bitwidth < 32: + return t.to(int32, _builder=_builder) + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + + return t @builtin -@_add_reduction_docstr("minimum") -def min(input, axis, _builder=None): +def _argreduce(input, axis, combine_fn, _builder=None, _generator=None): axis = _constexpr_to_value(axis) - return semantic.min(input, axis, _builder) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + if len(input.shape) > 1: + new_shape = [constexpr(1)] * len(input.shape) + new_shape[axis] = constexpr(n) + index = view(index, new_shape, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) -@builtin + rvalue, rindices = reduction((input, index), axis, combine_fn, + _builder=_builder, 
_generator=_generator) + return rindices + + +@triton.jit +def _max_combine(a, b): + return maximum(a, b) + + +@triton.jit +@_add_reduction_docstr("maximum") +def max(input, axis): + input = _promote_reduction_input(input) + return reduction(input, axis, _max_combine) + + +@triton.jit +def _argmax_combine(value1, index1, value2, index2): + gt = value1 > value2 + lt = value1 < value2 + index_min = minimum(index1, index2) + index_ret = where(gt, index1, where(lt, index2, index_min)) + value_ret = maximum(value1, value2) + return value_ret, index_ret + + +@triton.jit +@_add_reduction_docstr("maximum index") +def argmax(input, axis): + input = _promote_reduction_input(input) + return _argreduce(input, axis, _argmax_combine) + + +@triton.jit +def _min_combine(a, b): + # TODO: minimum/maximum doesn't get lowered to fmin/fmax... + return minimum(a, b) + + +@triton.jit +@_add_reduction_docstr("minimum") +def min(input, axis): + input = _promote_reduction_input(input) + return reduction(input, axis, _min_combine) + + +@triton.jit +def _argmin_combine(value1, index1, value2, index2): + lt = value1 < value2 + gt = value1 > value2 + index_min = minimum(index1, index2) + index_ret = where(lt, index1, where(gt, index2, index_min)) + value_ret = minimum(value1, value2) + return value_ret, index_ret + + +@triton.jit @_add_reduction_docstr("minimum index") -def argmin(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.argmin(input, axis, _builder) +def argmin(input, axis): + input = _promote_reduction_input(input) + return _argreduce(input, axis, _argmin_combine) -@builtin +@triton.jit +def _sum_combine(a, b): + return a + b + + +@triton.jit @_add_reduction_docstr("sum") -def sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.sum(input, axis, _builder) +def sum(input, axis): + input = _promote_reduction_input(input) + return reduction(input, axis, _sum_combine) + + +@triton.jit +def _xor_combine(a, b): + return a ^ b @builtin @_add_reduction_docstr("xor sum") -def xor_sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.xor_sum(input, axis, _builder) +def xor_sum(input, axis, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = _promote_reduction_input(input, _builder=_builder) + return reduction(input, axis, _xor_combine, + _builder=_builder, _generator=_generator) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 77c5354c181a..985880a2f0de 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1,7 +1,7 @@ from __future__ import annotations # remove after python 3.11 from functools import wraps -from typing import List, Optional, Tuple, TypeVar +from typing import List, Optional, Sequence, Tuple, TypeVar from . 
import core as tl from triton._C.libtriton.triton import ir @@ -1228,91 +1228,36 @@ def where(condition: tl.tensor, return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) # ===----------------------------------------------------------------------===// -# Reductions +# Reduction # ===----------------------------------------------------------------------=== -def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, - FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: - scalar_ty = input.type.scalar - out_scalar_ty = scalar_ty - # input is extended to 32-bits if necessary - # this increases numerical accuracy and can be done pretty much for free - # on GPUs - if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: - input = cast(input, tl.int32, builder) - out_scalar_ty = tl.int32 - - # hardware doesn't support FMAX, FMIN, CMP for bfloat16 - if scalar_ty is tl.bfloat16: - input = cast(input, tl.float32, builder) - out_scalar_ty = tl.float32 - - # choose the right unsigned operation - if scalar_ty.is_int_unsigned(): - int_op_to_unit = { - ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN, - ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX, - ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN, - ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX, - } - if INT_OP in int_op_to_unit: - INT_OP = int_op_to_unit[INT_OP] - - # If we are doing an argmin or argmax we want to use an int32 output type - if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX: - out_scalar_ty = tl.int32 - elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN: - out_scalar_ty = tl.int32 - - # get result type - shape = input.type.shape - - rank = len(shape) - assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})" - - ret_shape = [] - for i, s in enumerate(shape): - if i != axis: - ret_shape.append(s) - if ret_shape: - res_ty = tl.block_type(out_scalar_ty, ret_shape) - else: - # 0d-tensor -> scalar - res_ty = out_scalar_ty - - if scalar_ty.is_floating(): - return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) - elif scalar_ty.is_int(): - return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty) - assert False - - -def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) - +def reduction( + inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder +) -> Tuple[tl.tensor, ...]: + # get result shape + shape = inputs[0].type.shape + print(shape, axis) + ret_shape = [s for i, s in enumerate(shape) if i != axis] + for t in inputs: + assert t.type.shape == shape -def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN) - - -def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) - - -def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX) - - -def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) - - -def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - scalar_ty = input.type.scalar - if not scalar_ty.is_int(): - raise ValueError("xor_sum 
only supported for integers") - return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) + def wrap_tensor(x, scalar_ty): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple( + wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) + for i in range(len(inputs)) + ) # ===----------------------------------------------------------------------=== diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index b118f4b36db1..174be59b6e1a 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -217,7 +217,11 @@ tt.func @alloc(%A : !tt.ptr) { tt.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: scratch offset = 0, size = 512 - %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0> + %b = "tt.reduce" (%cst0) ({ + ^bb0(%arg0: f16, %arg1: f16): + %add = arith.addf %arg0, %arg1 : f16 + tt.reduce.return %add : f16 + }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> tt.return // CHECK-NEXT: size = 512 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 89cc0bbac129..4946eeef5395 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -79,7 +79,11 @@ tt.func @scratch() { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.convert_layout %1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> - %2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg1: f16, %arg2: f16): + %add = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %add : f16 + }) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> tt.return } diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index e4aa9b2c777b..e48cbd3dc1cb 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -79,18 +79,42 @@ tt.func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, tt.func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { // Test if reduce ops infer types correctly - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32> - %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32> - %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32> - %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32> - %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32> - %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32 - %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32 + // CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> + %a = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) 
{axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> + // CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32> + %b = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32> + // CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32> + %c = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32> + // CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32> + %e = "tt.reduce" (%b) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32> + // CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32> + %f = "tt.reduce" (%a) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32> + // CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32 + %g = "tt.reduce" (%f) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<4xf32>) -> f32 // Avoid optimizations for c, e, and g %ptr1x2 = tt.splat %ptr : (!tt.ptr) -> tensor<1x2x!tt.ptr> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index eddcae1ce7eb..a04cd2d8be4e 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -40,14 +40,30 @@ tt.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32> %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32> %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32> - // CHECK: tensor<4x4xf32, #[[blocked0]]> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>> - %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32> - // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}> - %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32> - // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>> - %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32> - // CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>> - %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32> + // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>> + %c0_ = "tt.reduce" (%c0) ({ + ^bb0(%arg1: f32, %arg2: f32): + %add = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}> + %c1_ = "tt.reduce" (%c1) ({ + ^bb0(%arg3: f32, %arg4: f32): + %add = arith.addf %arg3, %arg4 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32> + // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>> + %c2_ = "tt.reduce" (%c1) ({ + ^bb0(%arg5: f32, %arg6: f32): + %add = arith.addf 
%arg5, %arg6 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32> + // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>> + %c3_ = "tt.reduce" (%c2) ({ + ^bb0(%arg7: f32, %arg8: f32): + %add = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<16x16xf32>) -> tensor<16xf32> tt.return } diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index a313103d16d8..a53eda6f1ecb 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -787,7 +787,11 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! %27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2> %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2> %29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2> - %30 = tt.reduce %29 {axis = 1 : i32, redOp = 12 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %30 = "tt.reduce" (%29) ({ + ^bb0(%arg4: f32, %arg5: f32): + %max = arith.maxf %arg4, %arg5 : f32 + tt.reduce.return %max : f32 + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0> %32 = triton_gpu.convert_layout %31 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1> @@ -803,7 +807,11 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
%43 = math.exp %42 : tensor<16x16xf32, #blocked2> %44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2> %45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2> - %46 = tt.reduce %45 {axis = 1 : i32, redOp = 2 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %46 = "tt.reduce" (%45) ({ + ^bb0(%arg4: f32, %arg5: f32): + %add = arith.addf %arg4, %arg5 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %47 = triton_gpu.convert_layout %46 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0> %48 = triton_gpu.convert_layout %47 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %49 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1> @@ -907,7 +915,11 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %74 = "triton_gpu.select"(%54, %73, %arg7) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> scf.yield %74 : tensor<64x64xf32, #blocked2> } - %26 = tt.reduce %25 {axis = 1 : i32, redOp = 2 : i32} : tensor<64x64xf32, #blocked2> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %26 = "tt.reduce" (%25) ({ + ^bb0(%arg8: f32, %arg9: f32): + %add = arith.addf %arg8, %arg9 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %27 = triton_gpu.convert_layout %26 : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xf32, #blocked0> %28 = triton_gpu.convert_layout %27 : (tensor<64xf32, #blocked0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> %29 = tt.expand_dims %28 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xf32, #blocked1> @@ -1016,7 +1028,11 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked> %3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked> - %4 = tt.reduce %cst {axis = 1 : i32, redOp = 1 : i32} : tensor<1x2xi32, #blocked> -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = "tt.reduce" (%cst) ({ + ^bb0(%arg3: i32, %arg4: i32): + %add = arith.addi %arg3, %arg4 : i32 + tt.reduce.return %add : i32 + }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> %5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1> %6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2> @@ -1037,7 +1053,8 @@ module attributes 
{"triton_gpu.num-warps" = 2 : i32} { // Check if the SimplifyReduceCvt handles convert_layout lifted from the for loop. // CHECK-LABEL: reduce_cvt2 -// CHECK: tt.reduce +// Match the reduction +// CHECK: }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> // CHECK-NEXT: triton_gpu.convert_layout // CHECK-NOT: triton_gpu.convert_layout #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> @@ -1092,7 +1109,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %59 = "triton_gpu.select"(%52, %58, %arg6) : (tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked>, tensor<1x256xf32, #blocked>) -> tensor<1x256xf32, #blocked> scf.yield %59 : tensor<1x256xf32, #blocked> } - %16 = tt.reduce %15 {axis = 1 : i32, redOp = 2 : i32} : tensor<1x256xf32, #blocked> -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = "tt.reduce" (%15) ({ + ^bb0(%arg7: f32, %arg8: f32): + %add = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %add : f32 + + }) {axis = 1 : i32} : (tensor<1x256xf32, #blocked>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> %17 = triton_gpu.convert_layout %16 : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xf32, #blocked1> %18 = triton_gpu.convert_layout %17 : (tensor<1xf32, #blocked1>) -> tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %19 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<1xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xf32, #blocked2>