[RELEASE] [AMD] Additional AMD cherry-picks #4175

Merged: 7 commits, Jun 20, 2024
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -55,6 +55,11 @@ class TargetInfoBase {
StringRef message, StringRef file, StringRef func,
int line) const = 0;

// Whether to enable linear layout. This is a per-backend temporary escape
// hatch to disable linear layout while figuring out issues. Eventually we
// want to enable linear layout everywhere and delete this control.
virtual bool enableLinearLayout() const { return true; }

virtual ~TargetInfoBase() {}
};
} // namespace mlir::triton
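As an editorial aside, the escape hatch added in this hunk follows a plain virtual-override pattern: the base class defaults to true and a backend can override it to fall back to the legacy lowering path. A minimal, self-contained sketch (ToyTargetInfo, ToyAMDTargetInfo, and the toy emitIndices below are invented stand-ins, not the actual Triton classes):

#include <iostream>

// Simplified stand-in for TargetInfoBase: linear layout enabled by default.
struct ToyTargetInfo {
  virtual bool enableLinearLayout() const { return true; }
  virtual ~ToyTargetInfo() = default;
};

// Hypothetical backend that temporarily opts out while issues are debugged.
struct ToyAMDTargetInfo : ToyTargetInfo {
  bool enableLinearLayout() const override { return false; }
};

// Mirrors the guard added to emitIndices() later in this PR.
void emitIndices(const ToyTargetInfo &target, bool allowLL) {
  if (allowLL && target.enableLinearLayout())
    std::cout << "linear-layout path\n";
  else
    std::cout << "legacy path\n";
}

int main() {
  emitIndices(ToyTargetInfo{}, /*allowLL=*/true);    // linear-layout path
  emitIndices(ToyAMDTargetInfo{}, /*allowLL=*/true); // legacy path
}

Once every backend is comfortable with linear layout, the override and the guard can simply be deleted, which is what the comment above anticipates.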
87 changes: 62 additions & 25 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -908,14 +908,21 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout,

inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout,
SmallVector<SmallVector<unsigned>> &offsets,
- unsigned ctaOffsetX, unsigned ctaOffsetY) {
+ unsigned ctaBatchOffset, unsigned ctaOffsetX,
+ unsigned ctaOffsetY) {
const unsigned elemsPerThreadPerGroup = 8;
auto warpSize = getWarpSize(wmmaLayout);
assert(warpSize == 32);
auto shapePerCta = getShapePerCTATile(wmmaLayout);
auto rank = shapePerCta.size();
assert(rank == 2 || rank == 3);
SmallVector<unsigned> elemOffset(rank, 0);
if (rank == 3)
elemOffset[0] = ctaBatchOffset;
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
- offsets.push_back(
- {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]});
+ elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem;
+ elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1];
+ offsets.push_back(elemOffset);
}
}
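To make the new batched (rank-3) offset pattern concrete, here is a small stand-alone sketch of what emitWmmaOffsetForCTA now appends, assuming a 16x16 WMMA tile per wave and eight elements per thread per group; the concrete values (ctaBatchOffset = 1, ctaOffsetX = 0, ctaOffsetY = 2) are made up for illustration:

#include <array>
#include <cstdio>
#include <vector>

int main() {
  const unsigned elemsPerThreadPerGroup = 8;
  // Assumed rank-3 shape per CTA tile: {batch, M, N} with a 16x16 WMMA tile.
  const std::array<unsigned, 3> shapePerCta = {1, 16, 16};
  const unsigned ctaBatchOffset = 1, ctaOffsetX = 0, ctaOffsetY = 2;

  std::vector<std::array<unsigned, 3>> offsets;
  for (unsigned elem = 0; elem < elemsPerThreadPerGroup; ++elem)
    offsets.push_back({ctaBatchOffset,
                       ctaOffsetX * shapePerCta[1] + 2 * elem,
                       ctaOffsetY * shapePerCta[2]});

  // Prints eight offsets stepping by 2 rows inside the tile:
  // [1, 0, 32] [1, 2, 32] ... [1, 14, 32]
  for (const auto &o : offsets)
    std::printf("[%u, %u, %u]\n", o[0], o[1], o[2]);
}

The only change relative to the old rank-2 code is the extra leading batch coordinate; the M/N arithmetic is unchanged.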

@@ -925,9 +932,11 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,
RankedTensorType type) {
auto shape = type.getShape();
auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA();
- assert(_warpsPerCTA.size() == 2);
- SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
- i32_val(_warpsPerCTA[1])};
+ auto rank = _warpsPerCTA.size();
+ assert(rank == 2 || rank == 3);
+ SmallVector<Value> warpsPerCTA;
+ for (unsigned i = 0; i < rank; ++i)
+ warpsPerCTA.push_back(i32_val(_warpsPerCTA[i]));
auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();

Value threadId = getThreadId(rewriter, loc);
@@ -940,20 +949,34 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, _warpsPerCTA,
triton::gpu::getWarpOrder(wmmaLayout));
- if (shape[0] >= mnkDim[0]) {
- assert(shape[0] % mnkDim[0] == 0);
- multiDimWarpId[0] =
- urem(multiDimWarpId[0], i32_val(ceil<unsigned>(shape[0], mnkDim[0])));
+ if (shape[rank - 2] >= mnkDim[0]) {
+ assert(shape[rank - 2] % mnkDim[0] == 0);
+ multiDimWarpId[rank - 2] =
+ urem(multiDimWarpId[rank - 2],
+ i32_val(ceil<unsigned>(shape[rank - 2], mnkDim[0])));
}
- if (shape[1] >= mnkDim[1]) {
- assert(shape[1] % mnkDim[1] == 0);
- multiDimWarpId[1] =
- urem(multiDimWarpId[1], i32_val(ceil<unsigned>(shape[1], mnkDim[1])));
+ if (shape[rank - 1] >= mnkDim[1]) {
+ assert(shape[rank - 1] % mnkDim[1] == 0);
+ multiDimWarpId[rank - 1] =
+ urem(multiDimWarpId[rank - 1],
+ i32_val(ceil<unsigned>(shape[rank - 1], mnkDim[1])));
}
- Value offWarp0 = mul(multiDimWarpId[0], i32_val(mnkDim[0]));
- Value offWarp1 = mul(multiDimWarpId[1], i32_val(mnkDim[1]));
- return {add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0),
- add(laneId, offWarp1)};
+ Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0]));
+ Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1]));
+
+ SmallVector<Value> multiDimBase(rank);
+
+ multiDimBase[rank - 2] =
+ add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0);
+ multiDimBase[rank - 1] = add(laneId, offWarp1);
+
+ // TODO: It is assumed when rank = 3, warpsPerCTA is set to
+ // {numWarps, 1, 1}. We need to generalize the offset computation.
+ if (rank == 3) {
+ assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1);
+ multiDimBase[0] = urem(warpId, i32_val(shape[0]));
+ }
+ return multiDimBase;
}
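The TODO above notes that the rank-3 path assumes warpsPerCTA = {numWarps, 1, 1}, so the batch coordinate of the base index reduces to warpId % shape[0]. A tiny illustration with made-up numbers (four waves, batch dimension of two):

#include <cstdio>

int main() {
  const unsigned numWarps = 4; // assumed warpsPerCTA = {4, 1, 1}
  const unsigned batchDim = 2; // shape[0]
  for (unsigned warpId = 0; warpId < numWarps; ++warpId)
    std::printf("warp %u -> batch %u\n", warpId, warpId % batchDim);
  // warp 0 -> batch 0, warp 1 -> batch 1, warp 2 -> batch 0, warp 3 -> batch 1
}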

inline SmallVector<SmallVector<unsigned>>
@@ -964,17 +987,31 @@ emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout,
auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape);
auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();

- SmallVector<unsigned> numWarpsPerDim(2);
+ auto rank = tensorShape.size();
+ assert(rank == 2 || rank == 3);
+
+ SmallVector<unsigned> numWarpsPerDim(rank, 1);
auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();
- for (unsigned d = 0; d < 2; ++d) {
+ SmallVector<unsigned> shapePerWarp(rank, 1);
+ shapePerWarp[rank - 2] = mnkDim[0];
+ shapePerWarp[rank - 1] = mnkDim[1];
+ for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCTA[d]);
unsigned inPerWarp = ceil<unsigned>(inPerCTA, warpsPerCTA[d]);
- numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mnkDim[d]);
+ numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, shapePerWarp[d]);
}

- for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {
- for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) {
- emitWmmaOffsetForCTA(wmmaLayout, offsets, i, j);
+ unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1;
+ unsigned repM = numWarpsPerDim[rank - 2];
+ unsigned repN = numWarpsPerDim[rank - 1];
+ auto warpsPerBatch =
+ rank == 3 ? std::min<unsigned>(tensorShape[0], warpsPerCTA[0]) : 1;
+
+ for (unsigned b = 0; b < repBatch; ++b) {
+ for (unsigned i = 0; i < repM; ++i) {
+ for (unsigned j = 0; j < repN; ++j) {
+ emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j);
+ }
}
}
return offsets;
@@ -1170,7 +1207,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
bool allowLL = true) {
// Eventually the LinearLayout path will be the only one. For now we allow
// both paths so we can test that they produce the same results.
- if (allowLL) {
+ if (allowLL && target.enableLinearLayout()) {
std::optional<SmallVector<SmallVector<Value>>> llOffsets =
emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type,
withCTAOffset);
12 changes: 6 additions & 6 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -798,16 +798,18 @@ def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encod
let mnemonic = "amd_mfma";

let description = [{
- An encoding for tensors that have been produced by tensor cores of AMD MI GPUs.
+ An encoding for tensors that have been produced by MFMA matrix core instructions,
+ available on AMD Instinct GPUs of CDNA architectures.

It is characterized by the following parameters:
- - `versionMajor` and `versionMinor` indicates the GPU arch
+ - `versionMajor` and `versionMinor` indicates the GPU architecture:
- 1.0: gfx908, i.e. MI100
- 2.0: gfx90a: i.e. MI200, MI210, MI250
- 3.0: gfx940, gfx941, gfx942: MI300
- `warpsPerCTA` indicates the wave layout in the workgroup.
- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
- without going to LDS. This is used in the case of chained dot (E.g. Flash-Attention kernel).
+ without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel).

Example 1:
Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32.
Expand Down Expand Up @@ -924,16 +926,14 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,

}];

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}
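As a worked sanity check of "Example 1" in the description above (the example text itself is cut off in this view), the tiling arithmetic for a [32, 64] tensor with warpsPerCTA = [1, 2] and MDim = NDim = 32 can be spelled out as follows; this is only a sketch of the wave-level tiling, not the per-lane element mapping:

#include <cstdio>

int main() {
  const unsigned shape[2] = {32, 64};
  const unsigned warpsPerCTA[2] = {1, 2};
  const unsigned mDim = 32, nDim = 32;

  const unsigned tilesM = (shape[0] + mDim - 1) / mDim; // 1 MFMA tile along M
  const unsigned tilesN = (shape[1] + nDim - 1) / nDim; // 2 MFMA tiles along N
  const unsigned tilesPerWaveM = (tilesM + warpsPerCTA[0] - 1) / warpsPerCTA[0];
  const unsigned tilesPerWaveN = (tilesN + warpsPerCTA[1] - 1) / warpsPerCTA[1];

  std::printf("%u x %u tiles, %u x %u per wave\n",
              tilesM, tilesN, tilesPerWaveM, tilesPerWaveN);
  // -> 1 x 2 tiles, 1 x 1 per wave: each of the two waves owns one 32x32 tile.
}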

def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> {
let mnemonic = "amd_wmma";

let description = [{
An encoding for tensors that have been produced by WMMA instructions,
available on RDNA 3.
A `warpsPerCTA` parameter characterizes data distribution between waves.
An important limitation of WMMA for layout is a shape for tiles processed
by a single wave. It is [16, 16].
This encoding assumes specific access to matrix elements by threads.
8 changes: 6 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -469,6 +469,7 @@ bool supportMFMA(triton::DotOp op) {
auto bShape = bTy.getShape();

auto rank = aShape.size();
assert(bShape.size() == rank);
auto M = aShape[rank - 2];
auto N = bShape[rank - 1];
auto K = aShape[rank - 1];
@@ -521,8 +522,11 @@ bool supportWMMA(triton::DotOp op) {
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();

- assert(aShape[1] == bShape[0]);
- if (!supportWMMAGranularity(aShape[0], bShape[1], aShape[1]))
+ auto rank = aShape.size();
+ assert(bShape.size() == rank);
+ assert(aShape[rank - 1] == bShape[rank - 2]);
+ if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1],
+ aShape[rank - 1]))
return false;

return true;
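For reference, a minimal sketch of the shape convention the generalized asserts encode for a batched (rank-3) dot, with made-up shapes; the granularity check itself (supportWMMAGranularity) is not reproduced here:

#include <array>
#include <cassert>
#include <cstdio>

int main() {
  const std::array<long, 3> aShape = {4, 32, 64};  // {batch, M, K}
  const std::array<long, 3> bShape = {4, 64, 128}; // {batch, K, N}
  const size_t rank = aShape.size();
  assert(bShape.size() == rank);
  assert(aShape[rank - 1] == bShape[rank - 2]); // contracted dims must match

  const long M = aShape[rank - 2], N = bShape[rank - 1], K = aShape[rank - 1];
  std::printf("M=%ld N=%ld K=%ld\n", M, N, K); // M=32 N=128 K=64
}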
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -590,7 +590,7 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
} else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
- emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0],
+ emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
}
multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));