diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 7f0a2729a487..d03f6b86268a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -55,6 +55,11 @@ class TargetInfoBase { StringRef message, StringRef file, StringRef func, int line) const = 0; + // Whether to enable linear layout. This is a per-backend temporary escape + // hatch to disable linear layout while figuring out issues. Eventually we + // want to enable linear layout everywhere and delete this control. + virtual bool enableLinearLayout() const { return true; } + virtual ~TargetInfoBase() {} }; } // namespace mlir::triton diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 186e6677564c..382f60254c6d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -908,14 +908,21 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout, inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, SmallVector> &offsets, - unsigned ctaOffsetX, unsigned ctaOffsetY) { + unsigned ctaBatchOffset, unsigned ctaOffsetX, + unsigned ctaOffsetY) { const unsigned elemsPerThreadPerGroup = 8; auto warpSize = getWarpSize(wmmaLayout); assert(warpSize == 32); auto shapePerCta = getShapePerCTATile(wmmaLayout); + auto rank = shapePerCta.size(); + assert(rank == 2 || rank == 3); + SmallVector elemOffset(rank, 0); + if (rank == 3) + elemOffset[0] = ctaBatchOffset; for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { - offsets.push_back( - {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]}); + elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + offsets.push_back(elemOffset); } } @@ -925,9 +932,11 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, RankedTensorType type) { auto shape = type.getShape(); auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA(); - assert(_warpsPerCTA.size() == 2); - SmallVector warpsPerCTA = {i32_val(_warpsPerCTA[0]), - i32_val(_warpsPerCTA[1])}; + auto rank = _warpsPerCTA.size(); + assert(rank == 2 || rank == 3); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); Value threadId = getThreadId(rewriter, loc); @@ -940,20 +949,34 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, triton::gpu::getWarpOrder(wmmaLayout)); - if (shape[0] >= mnkDim[0]) { - assert(shape[0] % mnkDim[0] == 0); - multiDimWarpId[0] = - urem(multiDimWarpId[0], i32_val(ceil(shape[0], mnkDim[0]))); + if (shape[rank - 2] >= mnkDim[0]) { + assert(shape[rank - 2] % mnkDim[0] == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mnkDim[0]))); } - if (shape[1] >= mnkDim[1]) { - assert(shape[1] % mnkDim[1] == 0); - multiDimWarpId[1] = - urem(multiDimWarpId[1], i32_val(ceil(shape[1], mnkDim[1]))); + if (shape[rank - 1] >= mnkDim[1]) { + assert(shape[rank - 1] % mnkDim[1] == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], mnkDim[1]))); } - Value offWarp0 = mul(multiDimWarpId[0], i32_val(mnkDim[0])); - Value offWarp1 
= mul(multiDimWarpId[1], i32_val(mnkDim[1]));
-  return {add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0),
-          add(laneId, offWarp1)};
+  Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0]));
+  Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1]));
+
+  SmallVector<Value> multiDimBase(rank);
+
+  multiDimBase[rank - 2] =
+      add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0);
+  multiDimBase[rank - 1] = add(laneId, offWarp1);
+
+  // TODO: It is assumed when rank = 3, warpsPerCTA is set to
+  // {numWarps, 1, 1}. We need to generalize the offset computation.
+  if (rank == 3) {
+    assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1);
+    multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+  }
+  return multiDimBase;
 }
 
 inline SmallVector<SmallVector<unsigned>>
@@ -964,17 +987,31 @@ emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout,
   auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape);
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
 
-  SmallVector<unsigned> numWarpsPerDim(2);
+  auto rank = tensorShape.size();
+  assert(rank == 2 || rank == 3);
+
+  SmallVector<unsigned> numWarpsPerDim(rank, 1);
   auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();
-  for (unsigned d = 0; d < 2; ++d) {
+  SmallVector<unsigned> shapePerWarp(rank, 1);
+  shapePerWarp[rank - 2] = mnkDim[0];
+  shapePerWarp[rank - 1] = mnkDim[1];
+  for (unsigned d = 0; d < rank; ++d) {
     unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]);
     unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]);
-    numWarpsPerDim[d] = ceil(inPerWarp, mnkDim[d]);
+    numWarpsPerDim[d] = ceil(inPerWarp, shapePerWarp[d]);
   }
-  for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {
-    for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) {
-      emitWmmaOffsetForCTA(wmmaLayout, offsets, i, j);
+  unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1;
+  unsigned repM = numWarpsPerDim[rank - 2];
+  unsigned repN = numWarpsPerDim[rank - 1];
+  auto warpsPerBatch =
+      rank == 3 ? std::min<unsigned>(tensorShape[0], warpsPerCTA[0]) : 1;
+
+  for (unsigned b = 0; b < repBatch; ++b) {
+    for (unsigned i = 0; i < repM; ++i) {
+      for (unsigned j = 0; j < repN; ++j) {
+        emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j);
+      }
     }
   }
   return offsets;
@@ -1170,7 +1207,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             bool allowLL = true) {
   // Eventually the LinearLayout path will be the only one. For now we allow
   // both paths so we can test that they produce the same results.
-  if (allowLL) {
+  if (allowLL && target.enableLinearLayout()) {
     std::optional<SmallVector<SmallVector<Value>>> llOffsets =
         emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type,
                                       withCTAOffset);
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 05f3378dc2d2..ae23f9d13cea 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -798,16 +798,18 @@ def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encod
   let mnemonic = "amd_mfma";
 
   let description = [{
-An encoding for tensors that have been produced by tensor cores of AMD MI GPUs.
+An encoding for tensors that have been produced by MFMA matrix core instructions,
+available on AMD Instinct GPUs with CDNA architectures.
+
 It is characterized by the following parameters:
-- `versionMajor` and `versionMinor` indicates the GPU arch
+- `versionMajor` and `versionMinor` indicate the GPU architecture:
   - 1.0: gfx908, i.e. MI100
   - 2.0: gfx90a: i.e. MI200, MI210, MI250
   - 3.0: gfx940, gfx941, gfx942: MI300
- `warpsPerCTA` indicates the wave layout in the workgroup.
- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
-without going to LDS. This is used in the case of chained dot (E.g. Flash-Attention kernel).
+without going to shared memory. This is used in the case of chained dot (e.g., Flash-Attention kernel).
 
 Example 1:
 Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32.
@@ -924,6 +926,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
   }];
 
+  let genVerifyDecl = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -931,9 +934,6 @@ def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encod
   let mnemonic = "amd_wmma";
 
   let description = [{
-An encoding for tensors that have been produced by WMMA instructions,
-available on RDNA 3.
-A `warpsPerCTA` parameter characterizes data distribution between waves.
 An important limitation of WMMA for layout is a shape for tiles processed
 by a single wave. It is [16, 16].
 This encoding assumes specific access to matrix elements by threads.
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 689e83b5acd9..32cc43c9d5d2 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -469,6 +469,7 @@ bool supportMFMA(triton::DotOp op) {
   auto bShape = bTy.getShape();
 
   auto rank = aShape.size();
+  assert(bShape.size() == rank);
   auto M = aShape[rank - 2];
   auto N = bShape[rank - 1];
   auto K = aShape[rank - 1];
@@ -521,8 +522,11 @@ bool supportWMMA(triton::DotOp op) {
   auto aShape = aTy.getShape();
   auto bShape = bTy.getShape();
 
-  assert(aShape[1] == bShape[0]);
-  if (!supportWMMAGranularity(aShape[0], bShape[1], aShape[1]))
+  auto rank = aShape.size();
+  assert(bShape.size() == rank);
+  assert(aShape[rank - 1] == bShape[rank - 2]);
+  if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1],
+                              aShape[rank - 1]))
     return false;
 
   return true;
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index ca7367d15b07..a80158a463e1 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -590,7 +590,7 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
       emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
                            multiDimCTAInRepId[1]);
     } else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
-      emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0],
+      emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0],
                            multiDimCTAInRepId[1]);
     }
     multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 74ae61c06d31..c2da06ee17ad 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -7,12 +7,14 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
-#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Tools/StrUtil.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/TypeSwitch.h"
 
+// Include TableGen'erated code
+#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
+
 using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
@@ -804,16 +806,22 @@ SmallVector<unsigned>
AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { size_t rank = shape.size(); - assert(rank == 2 && "Unexpected rank of wmma layout"); + assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout"); SmallVector elemsPerThread(rank); auto mnkDim = getMNKDimPerWMMAInstr(); auto elemsPerThreadPerTile = getSizePerThread(); auto warpsPerCTA = getWarpsPerCTA(); - return {ceil(shape[0], mnkDim[0] * warpsPerCTA[0]) * - elemsPerThreadPerTile[0], - ceil(shape[1], mnkDim[1] * warpsPerCTA[1]) * - elemsPerThreadPerTile[1]}; + + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = + ceil(shape[rank - 2], mnkDim[0] * warpsPerCTA[rank - 2]) * + elemsPerThreadPerTile[rank - 2]; + elemsPerThread[rank - 1] = + ceil(shape[rank - 1], mnkDim[1] * warpsPerCTA[rank - 1]) * + elemsPerThreadPerTile[rank - 1]; + return elemsPerThread; } unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, @@ -1321,6 +1329,26 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { printer << "}>"; } +LogicalResult +AMDMfmaEncodingAttr::verify(function_ref emitError, + unsigned versionMajor, unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + unsigned mDim, unsigned nDim, bool isTransposed, + mlir::triton::gpu::CTALayoutAttr) { + if (!(versionMajor >= 0 && versionMajor <= 3)) { + return emitError() << "major version must be in the [0, 3] range"; + } + if (versionMinor != 0) { + return emitError() << "minor version must be 0"; + } + if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) { + return emitError() + << "(M, N) cases other than (32, 32) or (16, 16) unimplemented"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // WMMA encoding //===----------------------------------------------------------------------===// @@ -1605,9 +1633,8 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - constexpr int waveSize = 64; auto rep = getMFMARepForOperands(shape, kWidth, opIdx); - return rep[0] * rep[1] * rep[2] * kWidth; + return product(rep) * kWidth; } SmallVector @@ -1646,8 +1673,14 @@ AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, SmallVector AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + auto mnkDim = getMNKDimPerWMMAInstr(); - return {mnkDim[0] * getWarpsPerCTA()[0], mnkDim[1] * getWarpsPerCTA()[1]}; + shapePerCTATile[rank - 2] *= mnkDim[0]; + shapePerCTATile[rank - 1] *= mnkDim[1]; + return shapePerCTATile; } SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); @@ -1668,28 +1701,43 @@ SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { return ::getOrder(*this); } SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { - return {getMNKDimPerWMMAInstr()[0] / getSizePerThread()[0], - getMNKDimPerWMMAInstr()[1] / getSizePerThread()[1]}; + auto rank = getWarpsPerCTA().size(); + SmallVector threads(rank, 1); + auto mnkInstr = getMNKDimPerWMMAInstr(); + threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2]; + threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1]; + return threads; } SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { - return 
{8, 1};
+  auto rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> sizePerThread(rank, 1);
+  sizePerThread[rank - 2] = 8;
+  sizePerThread[rank - 1] = 1;
+  return sizePerThread;
 }
 
 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> sizePerThread(rank, 1);
   if (opIdx == 0) {
-    return {1, 16};
+    sizePerThread[rank - 2] = 1;
+    sizePerThread[rank - 1] = 16;
   } else if (opIdx == 1) {
-    return {16, 1};
+    sizePerThread[rank - 2] = 16;
+    sizePerThread[rank - 1] = 1;
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
+  return sizePerThread;
 }
 
 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
                                                       int opIdx) const {
   auto parentShapePerCTA = getShapePerCTATile(shape);
+  auto rank = shape.size();
+  assert(rank == 2);
   if (opIdx == 0) {
     return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
   } else if (opIdx == 1) {
@@ -1702,7 +1750,7 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx);
-  return rep[0] * rep[1] * kWidth;
+  return product(rep) * kWidth;
 }
 
 SmallVector<int64_t>
@@ -1715,16 +1763,25 @@ AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef<int64_t> operandShape,
                                            Type elemType, int kWidth,
                                            int opIdx) const {
   auto operandTileShape = getWMMAElemsPerInstrForOperands();
+  assert(operandTileShape.size() == 2);
   auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = operandShape.size();
+  assert(rank == 2 || rank == 3);
+  int numRepBatch =
+      rank == 3 ? std::max<int64_t>(1, operandShape[0] / warpsPerCTA[0]) : 1;
   if (opIdx == 0)
-    return {std::max<int64_t>(1, operandShape[0] /
-                                     (operandTileShape[0] * warpsPerCTA[0])),
-            std::max<int64_t>(1, operandShape[1] / operandTileShape[1])};
+    return {
+        numRepBatch,
+        std::max<int64_t>(1, operandShape[rank - 2] /
+                                 (operandTileShape[0] * warpsPerCTA[rank - 2])),
+        std::max<int64_t>(1, operandShape[rank - 1] / operandTileShape[1])};
   else {
     assert(opIdx == 1);
-    return {std::max<int64_t>(1, operandShape[0] / operandTileShape[0]),
-            std::max<int64_t>(1, operandShape[1] /
-                                     (operandTileShape[1] * warpsPerCTA[1]))};
+    return {
+        numRepBatch,
+        std::max<int64_t>(1, operandShape[rank - 2] / operandTileShape[0]),
+        std::max<int64_t>(1, operandShape[rank - 1] / (operandTileShape[1] *
+                                                       warpsPerCTA[rank - 1]))};
   }
 }
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index f6604c3de3fd..5972c93d7f98 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3262,6 +3262,11 @@ def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device):
     if is_hip():
         # hip does not support tf32 precision, so use ieee for all tests
         input_precision = "ieee"
+        if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
+            if in_dtype_str == "float32":
+                pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
+            if out_dtype_str == "float16":
+                pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
     else:
         input_precision = "tf32" if in_dtype_str == 'float32' else "ieee"
 
diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
index b42edaea3d86..5a4ada339200 100644
--- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s --split-input-file 
--convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> #mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}> @@ -27,6 +27,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } + // CHECK-LABEL: wmma_dot_bf16 + tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma>) { + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> + // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> + // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> + // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> + // CHECK: llvm.mlir.undef : vector<16xbf16> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16> + // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xbf16, #mma> + tt.return + } + // CHECK-LABEL: wmma_dot_int8_32 tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> @@ -57,3 +71,33 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> +#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 1, 4]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_dot_operand3d + tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) { + // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma_dot3d + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma>) { + // CHECK-COUNT-32: llvm.extractvalue %arg0 + // CHECK-COUNT-32: llvm.insertelement + // CHECK-COUNT-32: 
llvm.extractvalue %arg1 + // CHECK-COUNT-32: llvm.insertelement + // CHECK-COUNT-8: llvm.extractvalue %arg2 + // CHECK-COUNT-8: llvm.insertelement + // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<2x16x16xf16, #mma> + // CHECK-COUNT-8: llvm.extractelement + // CHECK-COUNT-8: llvm.insertvalue + tt.return + } +} diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index 7517e23337bc..abf18381f56f 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -45,3 +45,18 @@ // expected-error@+2 {{triton_gpu.dot_op kWidth parameter supports only 16 for WMMA parent}} #wmma = #triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}> + +// ----- + +// expected-error@+1 {{major version must be in the [0, 3] range}} +#mfma = #triton_gpu.amd_mfma<{versionMajor = 10, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> + +// ----- + +// expected-error@+1 {{minor version must be 0}} +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 5, warpsPerCTA = [1, 1, 1], instrShape = [32, 32], isTransposed = false}> + +// ----- + +// expected-error@+1 {{(M, N) cases other than (32, 32) or (16, 16) unimplemented}} +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 1], instrShape = [16, 8], isTransposed = false}> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 5e1067884d4a..72e11d593ac6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -212,6 +212,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto sharedLayout = cast(aTensorTy.getEncoding()); auto order = sharedLayout.getOrder(); + assert((rank == 2 || order[2] == 0) && + "expect batch to be the slowest dimension"); auto elemTy = aTensorTy.getElementType(); auto kWidth = encoding.getKWidth(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 3e2ec71db317..950e2926a13f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -69,7 +69,9 @@ llvm::SmallVector> computeTensorElemMappingInBlock( const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { - auto numK = reps[1]; + assert(reps.size() == 3); + assert(elemsPerInstr.size() == 2); + auto numK = reps[2]; const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); @@ -77,6 +79,8 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value nonKDim = i32_val(iNonKDim); Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); 
+ auto rank = smemOffsets.size(); + for (int tile = 0; tile < numK; ++tile) { Value tileVOffset = _0; Value tileHOffset = i32_val(tile * elemsPerInstr[1]); @@ -92,8 +96,8 @@ llvm::SmallVector> computeTensorElemMappingInBlock( add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset); - Value row = add(sliceVOffset, smemOffsets[0]); - Value col = add(sliceHOffset, smemOffsets[1]); + Value row = add(sliceVOffset, smemOffsets[rank - 2]); + Value col = add(sliceHOffset, smemOffsets[rank - 1]); mapping[loadsPerThread * tile + loadId] = {row, col}; } @@ -107,61 +111,68 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread) { assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx"); - int kDimIdx = opIdx == 0 ? 1 : 0; - int nonKDimIdx = opIdx == 0 ? 0 : 1; + auto rank = smemObj.getStrides().size(); + int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; auto wmmaLayout = cast(encoding.getParent()); - auto nonKDim = wmmaLayout.getMNKDimPerWMMAInstr()[nonKDimIdx]; - assert(nonKDim == 16); + assert(wmmaLayout.getMNKDimPerWMMAInstr()[nonKDimIdx] == 16); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); auto aTensorTy = cast(tensor.getType()); ArrayRef shape = aTensorTy.getShape(); auto sharedLayout = cast(aTensorTy.getEncoding()); auto order = sharedLayout.getOrder(); + assert((rank == 2 || order[2] == 0) && + "expect batch to be the slowest dimension"); auto elemTy = aTensorTy.getElementType(); int kWidth = encoding.getKWidth(); auto elemsPerInstr = wmmaLayout.getWMMAElemsPerInstrForOperands(); - auto wmmaInstrK = elemsPerInstr[kDimIdx]; + auto wmmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0]; + auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1]; + assert(wmmaInstrNonK == 16); auto numReps = wmmaLayout.getWMMARepForOperands(shape, elemTy, kWidth, opIdx); - auto numRepNonK = numReps[nonKDimIdx]; - auto numRepK = numReps[kDimIdx]; - - unsigned iWarpSize = triton::gpu::getWarpSize(wmmaLayout); - unsigned iNumLanes = iWarpSize / 2; - assert(iWarpSize == 32); - Value warpSize = i32_val(iWarpSize); + auto numRepNonK = numReps[opIdx == 0 ? 1 : 2]; + auto numRepK = numReps[opIdx == 0 ? 2 : 1]; + auto repB = numReps[0]; + + unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout); + unsigned iNumLanes = iWaveSize / 2; + assert(iWaveSize == 32); + Value waveSize = i32_val(iWaveSize); Value numLanes = i32_val(iNumLanes); - Value linearWarpId = udiv(thread, warpSize); + Value linearWaveId = udiv(thread, waveSize); Value lane = urem(thread, numLanes); // share elem between two threads - unsigned numElemsPerThreadPerRep = - wmmaLayout.getMNKDimPerWMMAInstr()[kDimIdx]; + unsigned numElemsPerThreadPerRep = wmmaInstrK; - Value warp = udiv(thread, warpSize); - unsigned int maxNumWarps = shape[nonKDimIdx] / elemsPerInstr[nonKDimIdx]; + Value warp = udiv(thread, waveSize); + unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); + int warpsPerBatch = + rank == 3 ? 
std::min(shape[0], warpsPerCTA[0]) : 1; + Value waveIdInBatch = urem(linearWaveId, i32_val(warpsPerBatch)); elemTy = typeConverter->convertType(elemTy); SmallVector loadedValues; SmallVector offsets; Value smemBase; Value spatialWarpId = AMD::getWarpIdInBlock( - rewriter, loc, linearWarpId, warpsPerCTA, elemsPerInstr[0], + rewriter, loc, linearWaveId, warpsPerCTA, elemsPerInstr[0], shape[nonKDimIdx], nonKDimIdx, triton::gpu::getOrder(wmmaLayout)); if (opIdx == 0) { offsets = AMD::computeOffsetsAType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, spatialWarpId, lane, warpsPerBlockNonK, numElemsPerThreadPerRep, - numReps, smemObj, sharedLayout, nonKDim, wmmaInstrK); + numReps, smemObj, sharedLayout, wmmaInstrNonK, wmmaInstrK); } else { assert(opIdx == 1); offsets = AMD::computeOffsetsBType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, spatialWarpId, lane, warpsPerBlockNonK, numElemsPerThreadPerRep, - numReps, smemObj, sharedLayout, nonKDim, wmmaInstrK); + numReps, smemObj, sharedLayout, wmmaInstrNonK, wmmaInstrK); } smemBase = AMD::computeBasePtr(rewriter, loc, smemObj); @@ -171,19 +182,26 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); - for (int nonK = 0; nonK < numRepNonK; ++nonK) { - for (int k = 0; k < numRepK; ++k) { - auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); - Value valVec = undef(vecTy); - for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); - Value loadOffset = offsets[nonK * loadsPerThread * numRepK + - k * loadsPerThread + loadId]; - Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset); - Value loadedValue = load(loadVecTy, loadAddress); - for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { - Value elemVal = extract_element(elemTy, loadedValue, i32_val(elemId)); - loadedValues.push_back(elemVal); + for (int b = 0; b < repB; ++b) { + int operandSize = shape[rank - 1] * shape[rank - 2]; + Value batchOffset = mul(i32_val(operandSize), + add(waveIdInBatch, i32_val(b * warpsPerBatch))); + for (int nonK = 0; nonK < numRepNonK; ++nonK) { + for (int k = 0; k < numRepK; ++k) { + auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); + Value valVec = undef(vecTy); + for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); + Value loadOffset = offsets[nonK * loadsPerThread * numRepK + + k * loadsPerThread + loadId]; + loadOffset = add(loadOffset, batchOffset); + Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset); + Value loadedValue = load(loadVecTy, loadAddress); + for (int elemId = 0; elemId < elemsPerLoad; ++elemId) { + Value elemVal = + extract_element(elemTy, loadedValue, i32_val(elemId)); + loadedValues.push_back(elemVal); + } } } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index d8b537be68e8..3843159fc047 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -45,32 +45,40 @@ enum class WMMAInstrType : uint8_t { NOT_APPLICABLE, }; -using ValueTable = std::map, Value>; +using ValueTable = std::map, Value>; -ValueTable getValuesFromDotOperandLayoutStruct( - ConversionPatternRewriter &rewriter, const LLVMTypeConverter 
*typeConverter, - Value value, int n0, int n1, int kWidth, Type type, Location loc) { +ValueTable +getValuesFromDotOperandLayoutStruct(ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Value value, int batch, int n0, int n1, + int kWidth, Type type, Location loc) { auto elems = unpackLLElements(loc, value, rewriter); ValueTable vals; - for (int i = 0; i < n0; i++) { - for (int j = 0; j < n1; j++) { - Type elemTy = typeConverter->convertType(type); - Type ty = vec_ty(elemTy, kWidth); - Value rawElems = undef(ty); - for (int k = 0; k < kWidth; ++k) { - rawElems = insert_element(ty, rawElems, - elems[kWidth * (n1 * i + j) + k], i32_val(k)); - } + for (int b = 0; b < batch; b++) { + for (int i = 0; i < n0; i++) { + for (int j = 0; j < n1; j++) { + Type elemTy = typeConverter->convertType(type); + Type ty = vec_ty(elemTy, kWidth); + Value rawElems = undef(ty); + for (int k = 0; k < kWidth; ++k) { + rawElems = insert_element( + ty, rawElems, + elems[n0 * n1 * kWidth * b + kWidth * (n1 * i + j) + k], + i32_val(k)); + } - Value convertedElems; - if (type.isBF16() || type.isF16()) { - convertedElems = rawElems; - } else { - convertedElems = bitcast( - rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() / - i32_ty.getIntOrFloatBitWidth())); + Value convertedElems; + if (type.isF16()) { + convertedElems = rawElems; + } else if (type.isBF16()) { + convertedElems = bitcast(rawElems, vec_ty(i16_ty, kWidth)); + } else { + convertedElems = bitcast( + rawElems, vec_ty(i32_ty, kWidth * type.getIntOrFloatBitWidth() / + i32_ty.getIntOrFloatBitWidth())); + } + vals[{b, i, j}] = convertedElems; } - vals[{i, j}] = convertedElems; } } return vals; @@ -172,52 +180,56 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, auto repB = wmmaLayout.getWMMARepForOperands(bTensorTy.getShape(), elemTy, kWidth, 1); - assert(repA[1] == repB[0]); + assert(repA[2] == repB[1]); Value loadedA = adaptor.getA(); Value loadedB = adaptor.getB(); Value loadedC = adaptor.getC(); - auto numRepM = repA[0]; - auto numRepN = repB[1]; - auto numRepK = repA[1]; + auto numRepM = repA[1]; + auto numRepN = repB[2]; + auto numRepK = repA[2]; + auto numRepB = repA[0]; ValueTable ha = getValuesFromDotOperandLayoutStruct( - rewriter, typeConverter, loadedA, numRepM, numRepK, kWidth, + rewriter, typeConverter, loadedA, numRepB, numRepM, numRepK, kWidth, aTensorTy.getElementType(), loc); ValueTable hb = getValuesFromDotOperandLayoutStruct( - rewriter, typeConverter, loadedB, numRepN, numRepK, kWidth, + rewriter, typeConverter, loadedB, numRepB, numRepN, numRepK, kWidth, aTensorTy.getElementType(), loc); auto dstElemTy = dTensorTy.getElementType(); auto fc = unpackLLElements(loc, loadedC, rewriter); unsigned warpSize = triton::gpu::getWarpSize(wmmaLayout); - // TODO get rid of magic numbers - unsigned vgprElemWidth = 32; + constexpr unsigned vgprElemBitWidth = 32; unsigned paddedOutputElemSize = - vgprElemWidth / dstElemTy.getIntOrFloatBitWidth(); + vgprElemBitWidth / dstElemTy.getIntOrFloatBitWidth(); // compute number of output elements that each thread holds for one WMMA // instruction. 
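  // For example (illustrative): with an f16 accumulator, paddedOutputElemSize
  // is 32 / 16 == 2, so for a 16x16 tile and wave32 each lane holds
  // 16 * 16 / 32 == 8 result values, placed at even indices of a 16-element
  // output vector.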
auto elemsPerVec = mnkDim[0] * mnkDim[1] * paddedOutputElemSize / warpSize; auto dElemsToStorePerThread = mnkDim[0] * mnkDim[1] / warpSize; auto vecTy = vec_ty(dstElemTy, elemsPerVec); - for (int m = 0; m < numRepM; ++m) { - for (int n = 0; n < numRepN; ++n) { - Value acc = undef(vecTy); - for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { - acc = insert_element(vecTy, acc, - fc[m * numRepN * dElemsToStorePerThread + - n * dElemsToStorePerThread + v], - i32_val(v * paddedOutputElemSize)); - } - for (size_t k = 0; k < numRepK; k++) { - acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{m, k}], - hb[{n, k}], acc, aTensorTy.getElementType(), - bTensorTy.getElementType()); - } - for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { - fc[m * numRepN * dElemsToStorePerThread + n * dElemsToStorePerThread + - v] = - extract_element(dstElemTy, acc, i32_val(v * paddedOutputElemSize)); + for (int b = 0; b < numRepB; ++b) { + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + auto batchOffIdx = b * numRepM * numRepN * dElemsToStorePerThread; + auto mRepOffId = m * numRepN * dElemsToStorePerThread; + auto nRepOffId = n * dElemsToStorePerThread; + auto fcThreadOffIdx = batchOffIdx + mRepOffId + nRepOffId; + + Value acc = undef(vecTy); + for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { + acc = insert_element(vecTy, acc, fc[fcThreadOffIdx + v], + i32_val(v * paddedOutputElemSize)); + } + for (size_t k = 0; k < numRepK; k++) { + acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{b, m, k}], + hb[{b, n, k}], acc, aTensorTy.getElementType(), + bTensorTy.getElementType()); + } + for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { + fc[fcThreadOffIdx + v] = extract_element( + dstElemTy, acc, i32_val(v * paddedOutputElemSize)); + } } } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 0cf45350637e..2312c9ed6279 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -58,6 +58,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { StringRef message, StringRef file, StringRef func, int line) const override; + bool enableLinearLayout() const override { return false; } + private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, ConversionPatternRewriter &rewriter, bool useStdErr) const; diff --git a/third_party/proton/csrc/include/Driver/Device.h b/third_party/proton/csrc/include/Driver/Device.h index 79a9bf11ece1..3e414c824bf6 100644 --- a/third_party/proton/csrc/include/Driver/Device.h +++ b/third_party/proton/csrc/include/Driver/Device.h @@ -27,13 +27,13 @@ struct Device { uint64_t memoryClockRate; // khz uint64_t busWidth; uint64_t numSms; - uint64_t arch; + std::string arch; Device() = default; Device(DeviceType type, uint64_t id, uint64_t clockRate, uint64_t memoryClockRate, uint64_t busWidth, uint64_t numSms, - uint64_t arch) + std::string arch) : type(type), id(id), clockRate(clockRate), memoryClockRate(memoryClockRate), busWidth(busWidth), numSms(numSms), arch(arch) {} diff --git a/third_party/proton/csrc/include/Driver/GPU/HipApi.h b/third_party/proton/csrc/include/Driver/GPU/HipApi.h index 6b8ad082f722..fadb9c425c14 100644 --- a/third_party/proton/csrc/include/Driver/GPU/HipApi.h +++ b/third_party/proton/csrc/include/Driver/GPU/HipApi.h @@ -14,8 +14,15 @@ template hipError_t deviceGetAttribute(int *value, hipDeviceAttribute_t attribute, int deviceId); +template hipError_t 
getDeviceCount(int *count); + +template +hipError_t getDeviceProperties(hipDeviceProp_t *prop, int deviceId); + Device getDevice(uint64_t index); +const std::string getHipArchName(uint64_t index); + const char *getKernelNameRef(const hipFunction_t f); const char *getKernelNameRefByPtr(const void *hostFunction, hipStream_t stream); diff --git a/third_party/proton/csrc/include/Driver/GPU/HsaApi.h b/third_party/proton/csrc/include/Driver/GPU/HsaApi.h new file mode 100644 index 000000000000..c694a11af4b5 --- /dev/null +++ b/third_party/proton/csrc/include/Driver/GPU/HsaApi.h @@ -0,0 +1,23 @@ +#ifndef PROTON_DRIVER_GPU_HSA_H_ +#define PROTON_DRIVER_GPU_HSA_H_ + +#include "Driver/Device.h" +#include "hsa/hsa_ext_amd.h" + +namespace proton { + +namespace hsa { + +template +hsa_status_t agentGetInfo(hsa_agent_t agent, hsa_agent_info_t attribute, + void *value); + +hsa_status_t iterateAgents(hsa_status_t (*callback)(hsa_agent_t agent, + void *data), + void *data); + +} // namespace hsa + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_HSA_H_ diff --git a/third_party/proton/csrc/include/Profiler/CuptiProfiler.h b/third_party/proton/csrc/include/Profiler/CuptiProfiler.h index d3412ef06ceb..344d0fd4b9df 100644 --- a/third_party/proton/csrc/include/Profiler/CuptiProfiler.h +++ b/third_party/proton/csrc/include/Profiler/CuptiProfiler.h @@ -1,40 +1,17 @@ #ifndef PROTON_PROFILER_CUPTI_PROFILER_H_ #define PROTON_PROFILER_CUPTI_PROFILER_H_ -#include "Context/Context.h" -#include "Profiler.h" - -#include -#include +#include "GPUProfiler.h" namespace proton { -class CuptiProfiler : public Profiler, - public OpInterface, - public Singleton { +class CuptiProfiler : public GPUProfiler { public: CuptiProfiler(); virtual ~CuptiProfiler(); -protected: - // OpInterface - void startOp(const Scope &scope) override final; - void stopOp(const Scope &scope) override final; - void setOpInProgress(bool value) override final; - bool isOpInProgress() override final; - - // Profiler - void doStart() override; - void doFlush() override; - void doStop() override; - private: - // Use the pimpl idiom to hide the implementation details. This lets us avoid - // including the cupti header from this header. The cupti header and the - // equivalent header from AMD define conflicting macros, so we want to use - // those headers only within cc files. struct CuptiProfilerPimpl; - std::unique_ptr pImpl; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h new file mode 100644 index 000000000000..c3c148658349 --- /dev/null +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -0,0 +1,123 @@ +#ifndef PROTON_PROFILER_GPU_PROFILER_H_ +#define PROTON_PROFILER_GPU_PROFILER_H_ + +#include "Context/Context.h" +#include "Profiler.h" +#include "Utility/Atomic.h" +#include + +#include +#include + +namespace proton { + +// Singleton: Each concrete GPU profiler, e.g., +// CuptiProfiler, should be a singleton. 
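+// A concrete backend is expected to derive via CRTP, roughly (illustrative
+// sketch; `FooProfiler` is a hypothetical name):
+//   class FooProfiler : public GPUProfiler<FooProfiler> { ... };
+// so that Singleton<FooProfiler>::instance() yields the derived type and the
+// virtual hooks below forward to the backend's pimpl.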
+template <typename ConcreteProfilerT>
+class GPUProfiler : public Profiler,
+                    public OpInterface,
+                    public Singleton<ConcreteProfilerT> {
+public:
+  GPUProfiler() = default;
+  virtual ~GPUProfiler() = default;
+
+protected:
+  // OpInterface
+  void startOp(const Scope &scope) override { pImpl->startOp(scope); }
+  void stopOp(const Scope &scope) override { pImpl->stopOp(scope); }
+
+  void setOpInProgress(bool value) override {
+    profilerState.isRecording = value;
+  }
+
+  bool isOpInProgress() override { return profilerState.isRecording; }
+
+  // Profiler
+  virtual void doStart() override { pImpl->doStart(); }
+  virtual void doFlush() override { pImpl->doFlush(); }
+  virtual void doStop() override { pImpl->doStop(); }
+
+  struct ProfilerState {
+    ConcreteProfilerT &profiler;
+    std::set<Data *> dataSet;
+    bool isRecording{false};
+    Scope scope{};
+
+    ProfilerState(ConcreteProfilerT &profiler) : profiler(profiler) {}
+
+    void record(const Scope &scope) {
+      this->scope = scope;
+      // Take a snapshot of the current dataset
+      this->dataSet = profiler.getDataSet();
+    }
+
+    void enterOp() {
+      profiler.enterOp(scope);
+      for (auto data : dataSet)
+        data->enterOp(scope);
+    }
+
+    void exitOp() {
+      profiler.exitOp(scope);
+      for (auto data : dataSet)
+        data->exitOp(this->scope);
+    }
+  };
+
+  struct Correlation {
+    std::atomic<uint64_t> maxSubmittedCorrelationId{0};
+    std::atomic<uint64_t> maxCompletedCorrelationId{0};
+
+    Correlation() = default;
+
+    void submit(const uint64_t correlationId) {
+      atomicMax(maxSubmittedCorrelationId, correlationId);
+    }
+
+    void complete(const uint64_t correlationId) {
+      atomicMax(maxCompletedCorrelationId, correlationId);
+    }
+
+    // Re-invoke flushFn until every submitted correlation id has been
+    // completed, or until the retry budget is exhausted.
+    template <typename FlushFnT>
+    void flush(uint64_t maxRetries, uint64_t sleepMs, FlushFnT &&flushFn) {
+      flushFn();
+      auto submittedId = maxSubmittedCorrelationId.load();
+      auto completedId = maxCompletedCorrelationId.load();
+      auto retries = maxRetries;
+      while ((completedId < submittedId) && retries > 0) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
+        flushFn();
+        completedId = maxCompletedCorrelationId.load();
+        --retries;
+      }
+    }
+  };
+
+  static thread_local ProfilerState profilerState;
+  Correlation correlation;
+
+  // Use the pimpl idiom to hide the implementation details. This lets us avoid
+  // including the cupti header from this header. The cupti header and the
+  // equivalent header from AMD define conflicting macros, so we want to use
+  // those headers only within cpp files.
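+  // Each backend provides a concrete pimpl, wired up along these lines
+  // (sketch; see CuptiProfiler.cpp below for the CUPTI version):
+  //   struct FooProfilerPimpl : public GPUProfilerPimplInterface { ... };
+  //   pImpl = std::make_unique<FooProfilerPimpl>(*this);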
+ class GPUProfilerPimplInterface { + public: + GPUProfilerPimplInterface(ConcreteProfilerT &profiler) + : profiler(profiler) {} + virtual ~GPUProfilerPimplInterface() = default; + + virtual void startOp(const Scope &scope) = 0; + virtual void stopOp(const Scope &scope) = 0; + virtual void doStart() = 0; + virtual void doFlush() = 0; + virtual void doStop() = 0; + + protected: + ConcreteProfilerT &profiler; + }; + std::unique_ptr pImpl; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_GPU_PROFILER_H_ diff --git a/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h b/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h index c6c614b5ec9e..2f1791dcb506 100644 --- a/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h +++ b/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h @@ -1,98 +1,17 @@ #ifndef PROTON_PROFILER_ROCTRACER_PROFILER_H_ #define PROTON_PROFILER_ROCTRACER_PROFILER_H_ -#include "Context/Context.h" -#include "Profiler.h" -#include "roctracer/roctracer.h" - -#include +#include "GPUProfiler.h" namespace proton { -class RoctracerProfiler : public Profiler, - public OpInterface, - public Singleton { +class RoctracerProfiler : public GPUProfiler { public: - RoctracerProfiler() = default; - virtual ~RoctracerProfiler() = default; - - // External Correlation - enum CorrelationDomain { - begin, - Default = begin, - Domain0 = begin, - Domain1, - end, - size = end - }; - static void pushCorrelationID(uint64_t id, CorrelationDomain type); - static void popCorrelationID(CorrelationDomain type); - -protected: - // OpInterface - void startOp(const Scope &scope) override final; - void stopOp(const Scope &scope) override final; - void setOpInProgress(bool value) override final; - bool isOpInProgress() override final; - - // Profiler - void doStart() override; - void doFlush() override; - void doStop() override; + RoctracerProfiler(); + virtual ~RoctracerProfiler(); private: - static void apiCallback(uint32_t domain, uint32_t cid, - const void *callback_data, void *arg); - static void activityCallback(const char *begin, const char *end, void *arg); - static void processActivity(std::map &correlation, - std::set &dataSet, - const roctracer_record_t *activity); - - const inline static size_t AlignSize = 8; - const inline static size_t BufferSize = 64 * 1024 * 1024; - - std::map correlation; - std::mutex correlationLock; - - bool externalCorrelationEnabled{true}; - - struct RoctracerState { - RoctracerProfiler &profiler; - std::set dataSet; - size_t level{0}; - bool isRecording{false}; - Scope scope{}; - - RoctracerState(RoctracerProfiler &profiler) : profiler(profiler) {} - - void record(const Scope &scope, const std::set &dataSet) { - this->scope = scope; - this->dataSet.insert(dataSet.begin(), dataSet.end()); - } - - void reset() { - dataSet.clear(); - level = 0; - scope = Scope(); - } - - void enterOp() { - profiler.enterOp(scope); - for (auto data : dataSet) { - data->enterOp(scope); - } - } - - void exitOp() { - profiler.exitOp(scope); - for (auto data : dataSet) { - data->exitOp(this->scope); - } - } - }; - - static inline thread_local RoctracerState roctracerState{ - RoctracerProfiler::instance()}; + struct RoctracerProfilerPimpl; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Utility/Atomic.h b/third_party/proton/csrc/include/Utility/Atomic.h new file mode 100644 index 000000000000..d7e40e73cd24 --- /dev/null +++ b/third_party/proton/csrc/include/Utility/Atomic.h @@ -0,0 +1,19 @@ +#include + +namespace proton { + +template T 
atomicMax(std::atomic<T> &target, T value) {
+  T current = target.load();
+  while (current < value && !target.compare_exchange_weak(current, value))
+    ;
+  return current;
+}
+
+template <typename T> T atomicMin(std::atomic<T> &target, T value) {
+  T current = target.load();
+  while (current > value && !target.compare_exchange_weak(current, value))
+    ;
+  return current;
+}
+
+} // namespace proton
diff --git a/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp
index d58127dcb5c8..aae8b4ceb3ab 100644
--- a/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp
+++ b/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp
@@ -49,7 +49,8 @@ Device getDevice(uint64_t index) {
   int minor;
   cuda::deviceGetAttribute(
       &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
-  auto arch = major * 10 + minor;
+  std::string arch = std::to_string(major * 10 + minor);
+
   return Device(DeviceType::CUDA, index, clockRate, memoryClockRate, busWidth,
                 numSms, arch);
 }
diff --git a/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp
index 3e994ee82db7..18de4a4f62a3 100644
--- a/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp
+++ b/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp
@@ -21,6 +21,11 @@ DEFINE_DISPATCH(ExternLibHip, deviceSynchronize, hipDeviceSynchronize)
 DEFINE_DISPATCH(ExternLibHip, deviceGetAttribute, hipDeviceGetAttribute, int *,
                 hipDeviceAttribute_t, int);
 
+DEFINE_DISPATCH(ExternLibHip, getDeviceCount, hipGetDeviceCount, int *);
+
+DEFINE_DISPATCH(ExternLibHip, getDeviceProperties, hipGetDeviceProperties,
+                hipDeviceProp_t *, int);
+
 Device getDevice(uint64_t index) {
   int clockRate;
   (void)hip::deviceGetAttribute(&clockRate, hipDeviceAttributeClockRate,
@@ -35,13 +40,30 @@ Device getDevice(uint64_t index) {
   (void)hip::deviceGetAttribute(
       &smCount, hipDeviceAttributeMultiprocessorCount, index);
 
-  // TODO: Compute capability is a NVIDIA concept. It doesn't map naturally to
-  // AMD GPUs. Figure out a better way to support this.
-  uint64_t arch = 0;
+  std::string arch = getHipArchName(index);
+
   return Device(DeviceType::HIP, index, clockRate, memoryClockRate, busWidth,
                 smCount, arch);
 }
 
+// TODO: hipDeviceProp_t was updated from hipDeviceProp_tR0000 to
+// hipDeviceProp_tR0600 as part of a breaking API change in ROCm 6.0.
+// https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/driver.c
+// uses hipDeviceProp_tR0000 and imports the hip_deprecated.h header file to
+// be backward compatible with ROCm 5.x. PyTorch still needs to support 5.x,
+// and the hipDeviceProp_tR0600 symbol does not exist pre-ROCm 6.0. Calling
+// hipDeviceProp_tR0000 here with ROCm 6.1 causes a stack corruption. Therefore
+// we will use hipDeviceProp_t and investigate if we can unify the definitions
+// in the two files.
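+// Note: gcnArchName is a full target string such as "gfx90a:sramecc+:xnack-";
+// getHipArchName below keeps only its first six characters, i.e. the arch
+// prefix (e.g. "gfx90a").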
+ +const std::string getHipArchName(uint64_t index) { + hipDeviceProp_t devProp; + (void)hip::getDeviceProperties(&devProp, index); + std::string gcnArchName(devProp.gcnArchName); + std::string hipArch = gcnArchName.substr(0, 6); + return hipArch; +} + const char *getKernelNameRef(const hipFunction_t f) { typedef const char *(*hipKernelNameRef_t)(const hipFunction_t); static hipKernelNameRef_t func = nullptr; diff --git a/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp new file mode 100644 index 000000000000..e07f5eb1b619 --- /dev/null +++ b/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp @@ -0,0 +1,35 @@ +#include "Driver/GPU/HsaApi.h" +#include "Driver/Dispatch.h" + +namespace proton { + +namespace hsa { + +struct ExternLibHsa : public ExternLibBase { + using RetType = hsa_status_t; + static constexpr const char *name = "libhsa-runtime64.so"; + static constexpr RetType success = HSA_STATUS_SUCCESS; + static void *lib; +}; + +void *ExternLibHsa::lib = nullptr; + +DEFINE_DISPATCH(ExternLibHsa, agentGetInfo, hsa_agent_get_info, hsa_agent_t, + hsa_agent_info_t, void *); + +hsa_status_t iterateAgents(hsa_status_t (*callback)(hsa_agent_t agent, + void *data), + void *data) { + typedef hsa_status_t (*hsa_iterate_agents_t)( + hsa_status_t (*)(hsa_agent_t, void *), void *data); + static hsa_iterate_agents_t func = nullptr; + Dispatch::init(ExternLibHsa::name, &ExternLibHsa::lib); + if (func == nullptr) + func = reinterpret_cast( + dlsym(ExternLibHsa::lib, "hsa_iterate_agents")); + return (func ? func(callback, data) : HSA_STATUS_ERROR_FATAL); +} + +} // namespace hsa + +} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp index 01649d4492d8..81cef5fa08cb 100644 --- a/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp @@ -11,42 +11,11 @@ namespace proton { -namespace { -struct CuptiState { - CuptiProfiler *profiler; - std::set dataSet; - size_t level{0}; - bool isRecording{false}; - Scope scope{}; - - void record(const Scope &scope, CuptiProfiler *profiler) { - this->scope = scope; - this->profiler = profiler; - this->dataSet = profiler->getDataSet(); - } - - void reset() { - dataSet.clear(); - level = 0; - scope = Scope(); - } - - void enterOp() { - profiler->enterOp(scope); - for (auto data : dataSet) { - data->enterOp(scope); - } - } +template <> +thread_local GPUProfiler::ProfilerState + GPUProfiler::profilerState(CuptiProfiler::instance()); - void exitOp() { - profiler->exitOp(scope); - for (auto data : dataSet) { - data->exitOp(this->scope); - } - } -}; - -static thread_local CuptiState cuptiState; +namespace { std::shared_ptr convertActivityToMetric(CUpti_Activity *activity) { std::shared_ptr metric; @@ -74,45 +43,49 @@ void addMetric(size_t scopeId, std::set &dataSet, } } -void processActivityExternalCorrelation(std::map &correlation, - CUpti_Activity *activity) { - auto *externalCorrelation = +uint32_t +processActivityExternalCorrelation(std::map &corrIdToExternId, + CUpti_Activity *activity) { + auto *externalActivity = reinterpret_cast(activity); - correlation[externalCorrelation->correlationId] = - externalCorrelation->externalId; + corrIdToExternId[externalActivity->correlationId] = + externalActivity->externalId; + return externalActivity->correlationId; } -void processActivityKernel(std::map &correlation, - std::set &dataSet, - CUpti_Activity *activity) { +uint32_t 
processActivityKernel(std::map &corrIdToExternId, + std::set &dataSet, + CUpti_Activity *activity) { // Support CUDA >= 11.0 auto *kernel = reinterpret_cast(activity); auto correlationId = kernel->correlationId; - // TODO: non-triton kernels - if (correlation.find(correlationId) == correlation.end()) { - return; - } - auto externalId = correlation[correlationId]; + if (corrIdToExternId.find(correlationId) == corrIdToExternId.end()) + return correlationId; + auto externalId = corrIdToExternId[correlationId]; addMetric(externalId, dataSet, activity); // Track correlation ids from the same stream and erase those < correlationId - correlation.erase(correlationId); + corrIdToExternId.erase(correlationId); + return correlationId; } -void processActivity(std::map &correlation, - std::set &dataSet, CUpti_Activity *activity) { +uint32_t processActivity(std::map &corrIdToExternId, + std::set &dataSet, CUpti_Activity *activity) { + auto correlationId = 0; switch (activity->kind) { case CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION: { - processActivityExternalCorrelation(correlation, activity); + correlationId = + processActivityExternalCorrelation(corrIdToExternId, activity); break; } case CUPTI_ACTIVITY_KIND_KERNEL: case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { - processActivityKernel(correlation, dataSet, activity); + correlationId = processActivityKernel(corrIdToExternId, dataSet, activity); break; } default: break; } + return correlationId; } std::pair matchKernelCbId(CUpti_CallbackId cbId) { @@ -153,14 +126,14 @@ std::pair matchKernelCbId(CUpti_CallbackId cbId) { } // namespace -struct CuptiProfiler::CuptiProfilerPimpl { - CuptiProfilerPimpl() = default; +struct CuptiProfiler::CuptiProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + CuptiProfilerPimpl(CuptiProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) {} virtual ~CuptiProfilerPimpl() = default; void startOp(const Scope &scope); void stopOp(const Scope &scope); - void setOpInProgress(bool value); - bool isOpInProgress(); void doStart(); void doFlush(); @@ -173,10 +146,10 @@ struct CuptiProfiler::CuptiProfilerPimpl { static void callbackFn(void *userData, CUpti_CallbackDomain domain, CUpti_CallbackId cbId, const void *cbData); - const inline static size_t AlignSize = 8; - const inline static size_t BufferSize = 64 * 1024 * 1024; + static constexpr size_t AlignSize = 8; + static constexpr size_t BufferSize = 64 * 1024 * 1024; - std::map correlation; + std::map corrIdToExternId; CUpti_SubscriberHandle subscriber{}; }; @@ -198,15 +171,17 @@ void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, size_t validSize) { CuptiProfiler &profiler = dynamic_cast(CuptiProfiler::instance()); - auto &correlation = profiler.pImpl->correlation; + auto &pImpl = dynamic_cast(*profiler.pImpl.get()); auto &dataSet = profiler.dataSet; - + uint32_t maxCorrelationId = 0; CUptiResult status; CUpti_Activity *activity = nullptr; do { status = cupti::activityGetNextRecord(buffer, validSize, &activity); if (status == CUPTI_SUCCESS) { - processActivity(correlation, dataSet, activity); + auto correlationId = + processActivity(pImpl.corrIdToExternId, dataSet, activity); + maxCorrelationId = std::max(maxCorrelationId, correlationId); } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { break; } else { @@ -214,7 +189,9 @@ void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, } } while (true); - free(buffer); + std::free(buffer); + + profiler.correlation.complete(maxCorrelationId); } void 

 void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
@@ -230,22 +207,16 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
     const CUpti_CallbackData *callbackData =
         reinterpret_cast<const CUpti_CallbackData *>(cbData);
     if (callbackData->callbackSite == CUPTI_API_ENTER) {
-      if (callbackData->context && cuptiState.level == 0) {
+      if (callbackData->context) {
         // Valid context and outermost level of the kernel launch
         auto scopeId = Scope::getNewScopeId();
         auto scope = Scope(scopeId, callbackData->symbolName);
-        cuptiState.record(scope, &profiler);
-        cuptiState.enterOp();
+        profilerState.record(scope);
       }
-      cuptiState.level++;
+      profilerState.enterOp();
     } else if (callbackData->callbackSite == CUPTI_API_EXIT) {
-      cuptiState.level--;
-      if (cuptiState.level == 0) {
-        if (cuptiState.isRecording) {
-          cuptiState.exitOp();
-        }
-        cuptiState.reset();
-      }
+      profilerState.exitOp();
+      profiler.correlation.submit(callbackData->correlationId);
     }
   }
 }
@@ -260,14 +231,6 @@ void CuptiProfiler::CuptiProfilerPimpl::stopOp(const Scope &scope) {
                                  CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, &correlationId);
 }

-void CuptiProfiler::CuptiProfilerPimpl::setOpInProgress(bool value) {
-  cuptiState.isRecording = value;
-}
-
-bool CuptiProfiler::CuptiProfilerPimpl::isOpInProgress() {
-  return cuptiState.isRecording;
-}
-
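callbackFn now delegates the nesting bookkeeping that CuptiState used to do to the shared thread_local profilerState. ProfilerState itself is defined in GPUProfiler.h, which this diff does not show; judging from the CuptiState code deleted above, its record/enterOp/exitOp contract is plausibly along these lines (a sketch under that assumption, not the actual implementation):

#include <cstddef>
#include <string>

struct ScopeStandIn { // stand-in for proton::Scope
  std::size_t scopeId = 0;
  std::string name;
};

// Per-thread state: a nesting level ensures that only the outermost
// runtime/driver API callback reports an op, matching the old level == 0
// checks in the deleted CuptiState.
struct ProfilerStateSketch {
  std::size_t level = 0;
  bool isRecording = false;
  ScopeStandIn scope;

  void record(const ScopeStandIn &s) {
    scope = s;
    isRecording = true;
  }
  void enterOp() {
    if (level == 0 && isRecording) {
      // outermost entry: report the scope to the profiler and its data sets
    }
    ++level;
  }
  void exitOp() {
    --level;
    if (level == 0 && isRecording) {
      // outermost exit: report, then reset for the next launch
      isRecording = false;
      scope = ScopeStandIn();
    }
  }
};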
 void CuptiProfiler::CuptiProfilerPimpl::doStart() {
   cupti::activityRegisterCallbacks(allocBuffer, completeBuffer);
   cupti::activityEnable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION);
@@ -284,12 +247,28 @@
 }

 void CuptiProfiler::CuptiProfilerPimpl::doFlush() {
-  CUcontext cu_context = nullptr;
-  cuda::ctxGetCurrent(&cu_context);
-  if (cu_context) {
+  // cuptiActivityFlushAll returns the activity records associated with all
+  // contexts/streams.
+  // It is a blocking call, but it does not implicitly issue any CUDA
+  // synchronization, so there is no guarantee that all activities have
+  // completed on the underlying devices.
+  // We therefore do an "opportunistic" synchronization here to try to ensure
+  // that all activities on the current context are completed.
+  // If no current context is set, we skip the synchronization.
+  CUcontext cuContext = nullptr;
+  cuda::ctxGetCurrent(&cuContext);
+  if (cuContext)
     cuda::ctxSynchronize();
-  }
-  cupti::activityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED);
+  profiler.correlation.flush(
+      /*maxRetries=*/100, /*sleepMs=*/10,
+      /*flush=*/[]() {
+        cupti::activityFlushAll(
+            /*flag=*/0);
+      });
+  // CUPTI_ACTIVITY_FLAG_FLUSH_FORCED is used to ensure that even incomplete
+  // activities are flushed so that the next profiling session can start with
+  // new activities.
+  cupti::activityFlushAll(/*flag=*/CUPTI_ACTIVITY_FLAG_FLUSH_FORCED);
 }

 void CuptiProfiler::CuptiProfilerPimpl::doStop() {
@@ -304,25 +283,10 @@ void CuptiProfiler::CuptiProfilerPimpl::doStop() {
   cupti::finalize();
 }

-CuptiProfiler::CuptiProfiler()
-    : pImpl(std::make_unique<CuptiProfilerPimpl>()) {}
-
-CuptiProfiler::~CuptiProfiler() = default;
-
-void CuptiProfiler::startOp(const Scope &scope) { pImpl->startOp(scope); }
-
-void CuptiProfiler::stopOp(const Scope &scope) { pImpl->stopOp(scope); }
-
-void CuptiProfiler::setOpInProgress(bool value) {
-  pImpl->setOpInProgress(value);
+CuptiProfiler::CuptiProfiler() {
+  pImpl = std::make_unique<CuptiProfilerPimpl>(*this);
 }

-bool CuptiProfiler::isOpInProgress() { return pImpl->isOpInProgress(); }
-
-void CuptiProfiler::doStart() { pImpl->doStart(); }
-
-void CuptiProfiler::doFlush() { pImpl->doFlush(); }
-
-void CuptiProfiler::doStop() { pImpl->doStop(); }
+CuptiProfiler::~CuptiProfiler() = default;

 } // namespace proton
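Both backends now funnel flushing through profiler.correlation.flush(maxRetries, sleepMs, flush), whose implementation is also not part of this diff. Given the submit()/complete() calls above and the polling loop it replaces in RoctracerProfiler::doFlush below, its presumed contract is roughly the following (a sketch modeled on the deleted Flush helper, not the real code):

#include <atomic>
#include <chrono>
#include <cstdint>
#include <functional>
#include <thread>

// Presumed correlation-tracking contract: submit() records the largest
// correlation id handed to the runtime, complete() the largest id seen back
// in an activity buffer, and flush() retries until the two meet or the retry
// budget runs out.
struct CorrelationSketch {
  std::atomic<std::uint64_t> maxSubmitted{0};
  std::atomic<std::uint64_t> maxCompleted{0};

  void submit(std::uint64_t id) {
    auto prev = maxSubmitted.load();
    while (prev < id && !maxSubmitted.compare_exchange_weak(prev, id)) {
    }
  }
  void complete(std::uint64_t id) {
    auto prev = maxCompleted.load();
    while (prev < id && !maxCompleted.compare_exchange_weak(prev, id)) {
    }
  }
  void flush(int maxRetries, int sleepMs,
             const std::function<void()> &flushFn) {
    flushFn();
    auto target = maxSubmitted.load();
    while (maxCompleted.load() < target && maxRetries-- > 0) {
      std::this_thread::sleep_for(std::chrono::milliseconds(sleepMs));
      flushFn();
    }
  }
};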
diff --git a/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp
index 412598f15192..a56d23e3e1b7 100644
--- a/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp
+++ b/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp
@@ -2,6 +2,7 @@
 #include "Context/Context.h"
 #include "Data/Metric.h"
 #include "Driver/GPU/HipApi.h"
+#include "Driver/GPU/HsaApi.h"
 #include "Driver/GPU/RoctracerApi.h"

 #include "hip/amd_detail/hip_runtime_prof.h"
@@ -12,27 +13,47 @@
 #include
 #include
 #include
+#include
 #include
 #include

 namespace proton {

+template <>
+thread_local GPUProfiler<RoctracerProfiler>::ProfilerState
+    GPUProfiler<RoctracerProfiler>::profilerState(
+        RoctracerProfiler::instance());
+
 namespace {

-// Track dispatched ops to ensure a complete flush
-class Flush {
-public:
-  std::mutex mutex_;
-  std::atomic<uint64_t> maxCorrelationId_;
-  uint64_t maxCompletedCorrelationId_{0};
-  void reportCorrelation(const uint64_t cid) {
-    uint64_t prev = maxCorrelationId_;
-    while (prev < cid && !maxCorrelationId_.compare_exchange_weak(prev, cid)) {
-    }
-  }
+// Node to device id mapping
+int deviceOffset = 0x7fffffff;
+
+void createDeviceMap() {
+  int dc = 0;
+  auto ret = hip::getDeviceCount(&dc);
+  hsa::iterateAgents(
+      [](hsa_agent_t agent, void *data) {
+        auto &deviceOffset = *static_cast<int *>(data);
+        int nodeId;
+        hsa::agentGetInfo(
+            agent,
+            static_cast<hsa_agent_info_t>(HSA_AMD_AGENT_INFO_DRIVER_NODE_ID),
+            &nodeId);
+        int deviceType;
+        hsa::agentGetInfo(
+            agent, static_cast<hsa_agent_info_t>(HSA_AGENT_INFO_DEVICE),
+            &deviceType);
+        if ((nodeId < deviceOffset) && (deviceType == HSA_DEVICE_TYPE_GPU))
+          deviceOffset = nodeId;
+
+        return HSA_STATUS_SUCCESS;
+      },
+      &deviceOffset);
 };
-Flush flushState;
+
+int mapDeviceId(int id) { return id - deviceOffset; }

 std::shared_ptr<Metric>
 convertActivityToMetric(const roctracer_record_t *activity) {
@@ -42,7 +63,7 @@ convertActivityToMetric(const roctracer_record_t *activity) {
     metric = std::make_shared<KernelMetric>(
         static_cast<uint64_t>(activity->begin_ns),
         static_cast<uint64_t>(activity->end_ns), 1,
-        static_cast<uint64_t>(activity->device_id),
+        static_cast<uint64_t>(mapDeviceId(activity->device_id)),
         static_cast<uint64_t>(DeviceType::HIP));
     break;
   }
@@ -59,136 +80,38 @@ void addMetric(size_t scopeId, std::set<Data *> &dataSet,
   }
 }

-void processActivityKernel(std::map<uint64_t, size_t> &correlation,
+void processActivityKernel(std::mutex &corrIdToExternIdMutex,
+                           std::map<uint64_t, size_t> &corrIdToExternId,
                            std::set<Data *> &dataSet,
                            const roctracer_record_t *activity) {
   auto correlationId = activity->correlation_id;
-  // TODO: non-triton kernels
-  if (correlation.find(correlationId) == correlation.end()) {
+  std::unique_lock<std::mutex> lock(corrIdToExternIdMutex);
+  if (corrIdToExternId.find(correlationId) == corrIdToExternId.end())
     return;
-  }
-  auto externalId = correlation[correlationId];
+  auto externalId = corrIdToExternId[correlationId];
   addMetric(externalId, dataSet, activity);
   // Track correlation ids from the same stream and erase those < correlationId
-  correlation.erase(correlationId);
-}
-
-std::mutex externalIdLock;
-thread_local std::deque<uint64_t>
-    externalIdMap[RoctracerProfiler::CorrelationDomain::size];
-
-} // namespace
-
-// External correlation
-void RoctracerProfiler::pushCorrelationID(uint64_t id, CorrelationDomain type) {
-  if (!instance().externalCorrelationEnabled) {
-    return;
-  }
-  std::scoped_lock lock(externalIdLock);
-  externalIdMap[type].push_back(id);
-}
-
-void RoctracerProfiler::popCorrelationID(CorrelationDomain type) {
-  if (!instance().externalCorrelationEnabled) {
-    return;
-  }
-  std::scoped_lock lock(externalIdLock);
-  externalIdMap[type].pop_back();
-}
-
-void RoctracerProfiler::startOp(const Scope &scope) {
-  pushCorrelationID(scope.scopeId, Default);
-}
-
-void RoctracerProfiler::stopOp(const Scope &scope) {
-  popCorrelationID(Default);
-}
-
-void RoctracerProfiler::setOpInProgress(bool value) {
-  roctracerState.isRecording = value;
-}
-
-bool RoctracerProfiler::isOpInProgress() { return roctracerState.isRecording; }
-
-void RoctracerProfiler::doStart() {
-  // Inline Callbacks
-  // roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HSA_API,
-  //                                 api_callback, nullptr);
-  roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HIP_API, apiCallback,
-                                  nullptr);
-
-  // Activity Records
-  roctracer_properties_t properties;
-  memset(&properties, 0, sizeof(roctracer_properties_t));
-  properties.buffer_size = 0x1000;
-  properties.buffer_callback_fun = activityCallback;
-  roctracer::openPool(&properties);
-  roctracer::enableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS);
-  roctracer::start();
-}
-
-void RoctracerProfiler::doFlush() {
-  // Implement reliable flushing. Wait for all dispatched ops to be reported
-  auto ret = hip::deviceSynchronize();
-  roctracer::flushActivity();
-  std::unique_lock<std::mutex> lock(flushState.mutex_);
-  // Load ending id from the running max
-  auto correlationId = flushState.maxCorrelationId_.load();
-
-  // Poll on the worker finding the final correlation id
-  int timeout = 500;
-  while ((flushState.maxCompletedCorrelationId_ < correlationId) && --timeout) {
-    lock.unlock();
-    roctracer::flushActivity();
-    usleep(1000);
-    lock.lock();
-  }
-}
-
-void RoctracerProfiler::doStop() {
-  roctracer::stop();
-  // roctracer::disable_domain_callback(ACTIVITY_DOMAIN_HSA_API);
-  roctracer::disableDomainCallback(ACTIVITY_DOMAIN_HIP_API);
-  roctracer::disableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS);
-  roctracer::closePool();
-}
-
-void RoctracerProfiler::activityCallback(const char *begin, const char *end,
-                                         void *arg) {
-  RoctracerProfiler &profiler =
-      dynamic_cast<RoctracerProfiler &>(RoctracerProfiler::instance());
-  auto &correlation = profiler.correlation;
-  auto &dataSet = profiler.dataSet;
-
-  std::unique_lock<std::mutex> lock(flushState.mutex_);
-  const roctracer_record_t *record = (const roctracer_record_t *)(begin);
-  const roctracer_record_t *endRecord = (const roctracer_record_t *)(end);
-
-  while (record < endRecord) {
-    // Log latest completed correlation id. Used to ensure we have flushed all
-    // data on stop
-    if (record->correlation_id > flushState.maxCompletedCorrelationId_) {
-      flushState.maxCompletedCorrelationId_ = record->correlation_id;
-    }
-    std::scoped_lock lock(profiler.correlationLock);
-    processActivity(correlation, dataSet, record);
-    roctracer::getNextRecord(record, &record);
-  }
+  corrIdToExternId.erase(correlationId);
 }

-void RoctracerProfiler::processActivity(std::map<uint64_t, size_t> &correlation,
-                                        std::set<Data *> &dataSet,
-                                        const roctracer_record_t *record) {
+void processActivity(std::mutex &corrIdToExternIdMutex,
+                     std::map<uint64_t, size_t> &corrIdToExternId,
+                     std::set<Data *> &dataSet,
+                     const roctracer_record_t *record) {
   switch (record->kind) {
   case 0x11F1: // Task - kernel enqueued by graph launch
   case kHipVdiCommandKernel: {
-    processActivityKernel(correlation, dataSet, record);
+    processActivityKernel(corrIdToExternIdMutex, corrIdToExternId, dataSet,
+                          record);
     break;
   }
-  default:;
+  default:
+    break;
   }
 }

+} // namespace
+
 namespace {

 std::pair<bool, bool> matchKernelCbId(uint32_t cbId) {
@@ -215,21 +138,24 @@ std::pair<bool, bool> matchKernelCbId(uint32_t cbId) {
   }
   return std::make_pair(isRuntimeApi, isDriverApi);
 }
-
 // C++ symbol demangle
-static inline const char *cxxDemangle(const char *symbol) {
-  size_t funcnamesize;
+static inline const std::string cxxDemangle(const char *symbol) {
+  if (symbol == nullptr)
+    return "";
+  size_t funcNameSize;
   int status;
-  const char *ret =
-      (symbol != NULL)
-          ? abi::__cxa_demangle(symbol, NULL, &funcnamesize, &status)
-          : symbol;
-  return (ret != NULL) ? ret : symbol;
+  if (const char *name =
+          abi::__cxa_demangle(symbol, NULL, &funcNameSize, &status)) {
+    std::string ret(name);
+    std::free(reinterpret_cast<void *>(const_cast<char *>(name)));
+    return ret;
+  }
+  return std::string(symbol);
 }

-const char *kernelName(uint32_t domain, uint32_t cid,
-                       const void *callback_data) {
-  const char *name = "";
+const std::string kernelName(uint32_t domain, uint32_t cid,
+                             const void *callback_data) {
+  std::string name;
   if (domain == ACTIVITY_DOMAIN_HIP_API) {
     const hip_api_data_t *data = (const hip_api_data_t *)(callback_data);
     switch (cid) {
@@ -274,7 +198,8 @@ const char *kernelName(uint32_t domain, uint32_t cid,
     case HIP_API_ID_hipGraphLaunch: {
       name = "graphLaunch";
     } break;
-    default:;
+    default:
+      break;
     }
   }
   return name;
@@ -282,51 +207,137 @@ const char *kernelName(uint32_t domain, uint32_t cid,
 }

 } // namespace

-void RoctracerProfiler::apiCallback(uint32_t domain, uint32_t cid,
-                                    const void *callback_data, void *arg) {
+enum CorrelationDomain { Default, Domain0, Domain1, Count };
+
+struct RoctracerProfiler::RoctracerProfilerPimpl
+    : public GPUProfiler<RoctracerProfiler>::GPUProfilerPimplInterface {
+  RoctracerProfilerPimpl(RoctracerProfiler &profiler)
+      : GPUProfiler<RoctracerProfiler>::GPUProfilerPimplInterface(profiler) {}
+  virtual ~RoctracerProfilerPimpl() = default;
+
+  void startOp(const Scope &scope);
+  void stopOp(const Scope &scope);
+
+  void doStart();
+  void doFlush();
+  void doStop();
+
+  static void apiCallback(uint32_t domain, uint32_t cid,
+                          const void *callbackData, void *arg);
+  static void activityCallback(const char *begin, const char *end, void *arg);
+
+  static constexpr size_t BufferSize = 64 * 1024 * 1024;
+
+  std::mutex corrIdToExternIdMutex;
+  std::map<uint64_t, size_t> corrIdToExternId;
+  inline static thread_local std::deque<uint64_t>
+      externIdQueue[CorrelationDomain::Count];
+};
+
+void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback(
+    uint32_t domain, uint32_t cid, const void *callback_data, void *arg) {
   auto [isRuntimeAPI, isDriverAPI] = matchKernelCbId(cid);
   if (!(isRuntimeAPI || isDriverAPI)) {
     return;
   }

-  RoctracerProfiler &profiler =
+  auto &profiler =
       dynamic_cast<RoctracerProfiler &>(RoctracerProfiler::instance());
+  auto &pImpl = dynamic_cast<RoctracerProfilerPimpl &>(
+      *profiler.pImpl);

   if (domain == ACTIVITY_DOMAIN_HIP_API) {
     const hip_api_data_t *data = (const hip_api_data_t *)(callback_data);
     if (data->phase == ACTIVITY_API_PHASE_ENTER) {
-      // if (callbackData->context && roctracerState.level == 0) {
-      {
-        // Valid context and outermost level of the kernel launch
-        const char *name = kernelName(domain, cid, callback_data);
-        // roctracer::getOpString(ACTIVITY_DOMAIN_HIP_API, cid, 0); //
-        // proper api name
-        auto scopeId = Scope::getNewScopeId();
-        auto scope = Scope(scopeId, name);
-        roctracerState.record(scope, profiler.getDataSet());
-        roctracerState.enterOp();
-
-        // Generate and Report external correlation
-        for (int it = CorrelationDomain::begin; it < CorrelationDomain::end;
-             ++it) {
-          std::scoped_lock lock(profiler.correlationLock, externalIdLock);
-          if (externalIdMap[it].size() > 0) {
-            profiler.correlation[data->correlation_id] =
-                externalIdMap[it].back();
-          }
-        }
-      }
-      roctracerState.level++;
+      // Record a new scope for this kernel launch
+      const std::string name = kernelName(domain, cid, callback_data);
+      auto scopeId = Scope::getNewScopeId();
+      auto scope = Scope(scopeId, name);
+      profilerState.record(scope);
+      profilerState.enterOp();
+      if (externIdQueue[CorrelationDomain::Domain0].empty())
+        return;
+      std::unique_lock<std::mutex> lock(pImpl.corrIdToExternIdMutex);
+      pImpl.corrIdToExternId[data->correlation_id] =
+          externIdQueue[CorrelationDomain::Domain0].back();
     } else if (data->phase == ACTIVITY_API_PHASE_EXIT) {
-      roctracerState.level--;
-      if (roctracerState.level == 0) {
-        if (roctracerState.isRecording) {
-          roctracerState.exitOp();
-        }
-        roctracerState.reset();
-      }
-
-      // track outstanding op for flush
-      flushState.reportCorrelation(data->correlation_id);
+      profilerState.exitOp();
+      // Track outstanding op for flush
+      profiler.correlation.submit(data->correlation_id);
     }
   }
 }
+
+void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback(
+    const char *begin, const char *end, void *arg) {
+  auto &profiler =
+      dynamic_cast<RoctracerProfiler &>(RoctracerProfiler::instance());
+  auto &pImpl = dynamic_cast<RoctracerProfilerPimpl &>(
+      *profiler.pImpl);
+  auto &dataSet = profiler.dataSet;
+  auto &correlation = profiler.correlation;
+
+  const roctracer_record_t *record =
+      reinterpret_cast<const roctracer_record_t *>(begin);
+  const roctracer_record_t *endRecord =
+      reinterpret_cast<const roctracer_record_t *>(end);
+  uint64_t maxCorrelationId = 0;
+
+  while (record != endRecord) {
+    // Log the latest completed correlation id, used to ensure we have
+    // flushed all data on stop.
+    maxCorrelationId =
+        std::max<uint64_t>(maxCorrelationId, record->correlation_id);
+    processActivity(pImpl.corrIdToExternIdMutex, pImpl.corrIdToExternId,
+                    dataSet, record);
+    roctracer::getNextRecord(record, &record);
+  }
+  correlation.complete(maxCorrelationId);
+}
+
+void RoctracerProfiler::RoctracerProfilerPimpl::startOp(const Scope &scope) {
+  // Track correlation id for the scope
+  externIdQueue[CorrelationDomain::Domain0].push_back(scope.scopeId);
+}
+
+void RoctracerProfiler::RoctracerProfilerPimpl::stopOp(const Scope &scope) {
+  externIdQueue[CorrelationDomain::Domain0].pop_back();
+}
+
+void RoctracerProfiler::RoctracerProfilerPimpl::doStart() {
+  roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HIP_API, apiCallback,
+                                  nullptr);
+  // Activity Records
+  roctracer_properties_t properties{0};
+  properties.buffer_size = BufferSize;
+  properties.buffer_callback_fun = activityCallback;
+  roctracer::openPool(&properties);
+  roctracer::enableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS);
+  roctracer::start();
+}
+
+void RoctracerProfiler::RoctracerProfilerPimpl::doFlush() {
+  // Implement reliable flushing.
+  // Wait for all dispatched ops to be reported.
+  std::ignore = hip::deviceSynchronize();
+  // If flushing encounters an activity record that is still being written,
+  // it stops; a subsequent flush resumes once that record has been fully
+  // written.
+  profiler.correlation.flush(
+      /*maxRetries=*/100, /*sleepMs=*/10, /*flush=*/
+      []() { roctracer::flushActivity(); });
+}
+
+void RoctracerProfiler::RoctracerProfilerPimpl::doStop() {
+  roctracer::stop();
+  roctracer::disableDomainCallback(ACTIVITY_DOMAIN_HIP_API);
+  roctracer::disableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS);
+  roctracer::closePool();
+}
+
+RoctracerProfiler::RoctracerProfiler() {
+  pImpl = std::make_unique<RoctracerProfilerPimpl>(*this);
+  createDeviceMap();
+}
+
+RoctracerProfiler::~RoctracerProfiler() = default;
+
 } // namespace proton
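createDeviceMap above exists because roctracer activity records report device_id as an HSA driver node id, while proton and the viewer key devices by HIP ordinals starting at 0; subtracting the smallest GPU node id reconciles the two. A toy check of that normalization (the node ids below are made up):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Hypothetical HSA driver node ids for two GPU agents; CPU agents are
  // filtered out by the HSA_DEVICE_TYPE_GPU check in createDeviceMap above.
  std::vector<int> gpuNodeIds = {4, 5};
  int deviceOffset = *std::min_element(gpuNodeIds.begin(), gpuNodeIds.end());
  // mapDeviceId(id) == id - deviceOffset yields HIP-style ordinals 0, 1.
  assert(gpuNodeIds[0] - deviceOffset == 0);
  assert(gpuNodeIds[1] - deviceOffset == 1);
  return 0;
}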
"arch": "90", "bus_width": 6144, "clock_rate": 1980000, "memory_clock_rate": 2619000, diff --git a/third_party/proton/test/example_hip.json b/third_party/proton/test/example_hip.json new file mode 100644 index 000000000000..2fcfad3c5d05 --- /dev/null +++ b/third_party/proton/test/example_hip.json @@ -0,0 +1,64 @@ + [ + { + "children": [ + { + "children": [], + "frame": { + "name": "foo0", + "type": "function" + }, + "metrics": { + "Count": 1, + "DeviceId": "1", + "DeviceType": "HIP", + "Time (ns)": 204800, + "flops8": 1e11, + "bytes": 1e8 + } + }, + { + "children": [], + "frame": { + "name": "foo1", + "type": "function" + }, + "metrics": { + "Count": 1, + "DeviceId": "0", + "DeviceType": "HIP", + "Time (ns)": 204800, + "flops8": 1e10, + "bytes": 1e7 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "Count": 0, + "Time (ns)": 0, + "flops8": 0, + "bytes": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + }, + "1": { + "arch": "gfx941", + "bus_width": 8192, + "clock_rate": 5200000, + "memory_clock_rate": 2525000, + "num_sms": 304 + } + } + } +] diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py index 57295c79193f..63a74b06ce04 100644 --- a/third_party/proton/test/test_viewer.py +++ b/third_party/proton/test/test_viewer.py @@ -3,7 +3,8 @@ import numpy as np file_path = __file__ -example_file = file_path.replace("test_viewer.py", "example.json") +cuda_example_file = file_path.replace("test_viewer.py", "example_cuda.json") +hip_example_file = file_path.replace("test_viewer.py", "example_hip.json") def test_help(): @@ -13,7 +14,7 @@ def test_help(): def test_min_time_flops(): - with open(example_file, "r") as f: + with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_flops(gf.dataframe, device_info) device0_idx = gf.dataframe["DeviceId"] == "0" @@ -22,10 +23,19 @@ def test_min_time_flops(): np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000025]], atol=1e-5) # sm90 np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.00005]], atol=1e-5) + with open(hip_example_file, "r") as f: + gf, _, device_info = get_raw_metrics(f) + ret = get_min_time_flops(gf.dataframe, device_info) + device0_idx = gf.dataframe["DeviceId"] == "0" + device1_idx = gf.dataframe["DeviceId"] == "1" + # MI200 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000026]], atol=1e-5) + # MI300 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.000038]], atol=1e-5) def test_min_time_bytes(): - with open(example_file, "r") as f: + with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_bytes(gf.dataframe, device_info) device0_idx = gf.dataframe["DeviceId"] == "0" @@ -34,3 +44,12 @@ def test_min_time_bytes(): np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[9.91969e-06]], atol=1e-6) # sm90 np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[2.48584e-05]], atol=1e-6) + with open(hip_example_file, "r") as f: + gf, _, device_info = get_raw_metrics(f) + ret = get_min_time_bytes(gf.dataframe, device_info) + device0_idx = gf.dataframe["DeviceId"] == "0" + device1_idx = gf.dataframe["DeviceId"] == "1" + # MI200 + np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[6.10351e-06]], atol=1e-6) + # MI300 + np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[1.93378e-05]], atol=1e-6)