[AMD] Cleanup AccelerateAMDMatmulPass and enable more tests #3025

Merged: 1 commit, Jan 29, 2024
6 changes: 0 additions & 6 deletions python/test/unit/language/test_core.py
@@ -2523,12 +2523,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
if is_hip():
if (M, N, K) in [(64, 128, 128)]:
pytest.skip(f"test_dot{(M, N, K)} not supported on HIP: memory out of resource.")
Comment on lines 2524 to 2525 (Contributor, PR author):
Did not enable this test because the LDS size is not sufficient; still looking for a solution (see the rough estimate after this diff).

if (M, N, K, num_warps) in [(128, 256, 32, 8), (128, 128, 64, 4)]:
pytest.skip(f"test_dot{(M, N, K)} not supported on HIP. Reduce Warp to work")
if M == 16 or N == 16 or K == 16:
pytest.skip(f"test_dot{(M, N, K)} segfaults on HIP")
if epilogue == "softmax":
pytest.skip(f"test_dot{epilogue} segfaults on HIP")

torch.backends.cuda.matmul.allow_tf32 = allow_tf32

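A side note on the remaining skip: the comment above attributes it to LDS (shared-memory) capacity. The back-of-envelope estimate below is not from the PR; it assumes both operand tiles are staged in LDS exactly once (no software pipelining, no padding or swizzling overhead) and a 64 KiB LDS limit per workgroup, as on CDNA-class hardware.

```python
# Rough LDS footprint for the still-skipped test_dot shape (M, N, K) = (64, 128, 128).
# Assumptions (not stated in the PR): one A tile (M x K) and one B tile (K x N)
# are resident in LDS at the same time, single-buffered, with no padding.

def lds_bytes(M, N, K, elem_bytes):
    """Bytes needed to hold one A (MxK) tile and one B (KxN) tile."""
    return (M * K + K * N) * elem_bytes

LDS_LIMIT = 64 * 1024  # assumed 64 KiB of LDS per workgroup

for dtype, nbytes in [("float16", 2), ("float32", 4)]:
    need = lds_bytes(64, 128, 128, nbytes)
    verdict = "fits" if need <= LDS_LIMIT else "exceeds"
    print(f"(64, 128, 128) {dtype}: {need // 1024} KiB {verdict} the 64 KiB limit")
# float16: 48 KiB (fits single-buffered); float32: 96 KiB (already over the limit).
```

Even before multi-stage buffering, the fp32 variant of the tile does not fit, which is consistent with the author's note that the LDS size is not sufficient.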
102 changes: 64 additions & 38 deletions third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -190,48 +190,75 @@ class BlockedToMFMA : public mlir::RewritePattern {
return {nonKDim, kDim};
}

Value convertAndPromoteDotOperand(mlir::PatternRewriter &rewriter,
Value oldDotOperand,
::mlir::Attribute encoding,
Type promotedElemType) const {
assert(promotedElemType.isIntOrFloat());

auto loc = oldDotOperand.getLoc();
auto oldType = oldDotOperand.getType().cast<RankedTensorType>();
/**
* @brief Convert the layout and cast the element type of a given tensor.
*
* If the old element type differs from the new element type, this function
* creates two new operations:
* 1. %converted_value = layout_convert %value, newEncoding
* 2. %casted_value = cast(fext, ftrunc, etc.) %converted_value, newElemType
*
* If the old element type is the same as the new element type, only one
* operation is created: %converted_value = layout_convert %value, newEncoding
*
* @param rewriter pattern rewriter used to create the new operations
* @param value original tensor value to convert and cast
* @param newEncoding new encoding for the tensor
* @param newElemType new element type for the tensor
* @return converted and optionally cast tensor value
*/
Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value,
::mlir::Attribute newEncoding,
Type newElemType) const {
assert(newElemType.isIntOrFloat());

auto loc = value.getLoc();
auto oldType = value.getType().cast<RankedTensorType>();
auto oldElemType = oldType.getElementType();

assert(oldElemType.isIntOrFloat());
assert(oldElemType.isIntOrIndex() == promotedElemType.isIntOrIndex());
assert(oldElemType.isIntOrIndex() == newElemType.isIntOrIndex());

auto convertedType =
RankedTensorType::get(oldType.getShape(), oldElemType, encoding);

Value convertedDotOperand = rewriter.create<ttg::ConvertLayoutOp>(
loc, convertedType, oldDotOperand);

if (promotedElemType == oldElemType)
return convertedDotOperand;

Type promotedType = convertedType.cloneWith(std::nullopt, promotedElemType);

Value promotedDotOperand;

if (promotedElemType.isIntOrIndex()) {
// TODO Implement integer casting
assert(!promotedElemType.isIntOrIndex() &&
"need to implement integer promotion");
RankedTensorType::get(oldType.getShape(), oldElemType, newEncoding);

Value convertedTensor =
rewriter.create<ttg::ConvertLayoutOp>(loc, convertedType, value);

if (newElemType == oldElemType)
return convertedTensor;

Type castedType = convertedType.cloneWith(std::nullopt, newElemType);

Value castedTensor;

if (newElemType.isIntOrIndex()) {
unsigned oldWidth = oldElemType.getIntOrFloatBitWidth();
unsigned newWidth = newElemType.getIntOrFloatBitWidth();
if (oldWidth == newWidth)
castedTensor = rewriter.create<mlir::arith::BitcastOp>(
loc, convertedType, convertedTensor);
else if (oldWidth > newWidth)
castedTensor = rewriter.create<mlir::arith::TruncIOp>(loc, castedType,
convertedTensor);
else if (oldElemType.isSignedInteger())
castedTensor = rewriter.create<mlir::arith::ExtSIOp>(loc, castedType,
convertedTensor);
else
castedTensor = rewriter.create<mlir::arith::ExtUIOp>(loc, castedType,
convertedTensor);
} else {
if (oldElemType.isF16() && promotedElemType.isF32())
promotedDotOperand = rewriter.create<mlir::arith::ExtFOp>(
loc, promotedType, convertedDotOperand);
else if (oldElemType.isF32() && promotedElemType.isF16())
promotedDotOperand = rewriter.create<mlir::arith::TruncFOp>(
loc, promotedType, convertedDotOperand);
if (oldElemType.isF16() && newElemType.isF32())
castedTensor = rewriter.create<mlir::arith::ExtFOp>(loc, castedType,
convertedTensor);
else if (oldElemType.isF32() && newElemType.isF16())
castedTensor = rewriter.create<mlir::arith::TruncFOp>(loc, castedType,
convertedTensor);
else
promotedDotOperand = rewriter.create<tt::FpToFpOp>(loc, promotedType,
convertedDotOperand);
castedTensor =
rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
}
return promotedDotOperand;
return castedTensor;
}

mlir::LogicalResult
@@ -279,8 +306,7 @@ class BlockedToMFMA : public mlir::RewritePattern {

// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc =
convertAndPromoteDotOperand(rewriter, oldAcc, mfmaEnc, mfmaAccType);
auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);

// kWidth is the number of consecutive elements per instruction per thread
@@ -314,8 +340,8 @@ class BlockedToMFMA : public mlir::RewritePattern {
dotOp.getMaxNumImpreciseAcc());

Value dotOutput =
convertAndPromoteDotOperand(rewriter, newDot, oldRetType.getEncoding(),
oldRetType.getElementType());
convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
oldRetType.getElementType());

rewriter.replaceOp(op, dotOutput);

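For context only (this kernel is not part of the PR), the BlockedToMFMA pattern above rewrites the tt.dot operations that Triton emits from ordinary matmul kernels such as the minimal sketch below; the row-major layout, shapes divisible by the block sizes, and absence of masking are simplifying assumptions.

```python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # fp32 accumulator: the value the pass converts to the MFMA encoding
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        acc += tl.dot(a, b)  # lowered to tt.dot, which BlockedToMFMA rewrites
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```

In the pass, convertAndCastTensor moves the accumulator operand into the MFMA encoding (casting it to the MFMA accumulator type when it differs) and converts the dot result back to the original layout and element type afterwards.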