diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 78d7b723f6a4..f390d5a1e928 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2523,12 +2523,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
     if is_hip():
         if (M, N, K) in [(64, 128, 128)]:
             pytest.skip(f"test_dot{(M, N, K)} not supported on HIP: memory out of resource.")
-        if (M, N, K, num_warps) in [(128, 256, 32, 8), (128, 128, 64, 4)]:
-            pytest.skip(f"test_dot{(M, N, K)} not supported on HIP. Reduce Warp to work")
-        if M == 16 or N == 16 or K == 16:
-            pytest.skip(f"test_dot{(M, N, K)} segfaults on HIP")
-        if epilogue == "softmax":
-            pytest.skip(f"test_dot{epilogue} segfaults on HIP")
 
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
 
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 39fc317638d6..89a4bc6860c5 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -190,48 +190,75 @@ class BlockedToMFMA : public mlir::RewritePattern {
     return {nonKDim, kDim};
   }
 
-  Value convertAndPromoteDotOperand(mlir::PatternRewriter &rewriter,
-                                    Value oldDotOperand,
-                                    ::mlir::Attribute encoding,
-                                    Type promotedElemType) const {
-    assert(promotedElemType.isIntOrFloat());
-
-    auto loc = oldDotOperand.getLoc();
-    auto oldType = oldDotOperand.getType().cast<RankedTensorType>();
+  /**
+   * @brief Convert layout and cast element type of a given tensor
+   *
+   * If the old element type differs from the new element type, this function
+   * creates two new operations:
+   * 1. %converted_value = layout_convert %value, newEncoding
+   * 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType
+   *
+   * If the old element type is the same as the new element type, this function
+   * creates only one operation: %converted_value = layout_convert %value, newEncoding
+   *
+   * @param rewriter
+   * @param value original tensor value, which we need to convert and cast
+   * @param newEncoding new encoding for the tensor
+   * @param newElemType new element type for the tensor
+   * @return converted and optionally casted tensor value
+   */
+  Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value,
+                             ::mlir::Attribute newEncoding,
+                             Type newElemType) const {
+    assert(newElemType.isIntOrFloat());
+
+    auto loc = value.getLoc();
+    auto oldType = value.getType().cast<RankedTensorType>();
     auto oldElemType = oldType.getElementType();
 
     assert(oldElemType.isIntOrFloat());
-    assert(oldElemType.isIntOrIndex() == promotedElemType.isIntOrIndex());
+    assert(oldElemType.isIntOrIndex() == newElemType.isIntOrIndex());
 
     auto convertedType =
-        RankedTensorType::get(oldType.getShape(), oldElemType, encoding);
-
-    Value convertedDotOperand = rewriter.create<ttg::ConvertLayoutOp>(
-        loc, convertedType, oldDotOperand);
-
-    if (promotedElemType == oldElemType)
-      return convertedDotOperand;
-
-    Type promotedType = convertedType.cloneWith(std::nullopt, promotedElemType);
-
-    Value promotedDotOperand;
-
-    if (promotedElemType.isIntOrIndex()) {
-      // TODO Implement integer casting
-      assert(!promotedElemType.isIntOrIndex() &&
-             "need to implement integer promotion");
+        RankedTensorType::get(oldType.getShape(), oldElemType, newEncoding);
+
+    Value convertedTensor =
+        rewriter.create<ttg::ConvertLayoutOp>(loc, convertedType, value);
+
+    if (newElemType == oldElemType)
+      return convertedTensor;
+
+    Type castedType = convertedType.cloneWith(std::nullopt, newElemType);
+
+    Value castedTensor;
+
+    if (newElemType.isIntOrIndex()) {
+      unsigned oldWidth = oldElemType.getIntOrFloatBitWidth();
+      unsigned newWidth = newElemType.getIntOrFloatBitWidth();
+      if (oldWidth == newWidth)
+        castedTensor = rewriter.create<mlir::arith::BitcastOp>(
+            loc, convertedType, convertedTensor);
+      else if (oldWidth > newWidth)
+        castedTensor = rewriter.create<mlir::arith::TruncIOp>(loc, castedType,
+                                                              convertedTensor);
+      else if (oldElemType.isSignedInteger())
+        castedTensor = rewriter.create<mlir::arith::ExtSIOp>(loc, castedType,
+                                                             convertedTensor);
+      else
+        castedTensor = rewriter.create<mlir::arith::ExtUIOp>(loc, castedType,
+                                                             convertedTensor);
     } else {
-      if (oldElemType.isF16() && promotedElemType.isF32())
-        promotedDotOperand = rewriter.create<mlir::arith::ExtFOp>(
-            loc, promotedType, convertedDotOperand);
-      else if (oldElemType.isF32() && promotedElemType.isF16())
-        promotedDotOperand = rewriter.create<mlir::arith::TruncFOp>(
-            loc, promotedType, convertedDotOperand);
+      if (oldElemType.isF16() && newElemType.isF32())
+        castedTensor = rewriter.create<mlir::arith::ExtFOp>(loc, castedType,
+                                                            convertedTensor);
+      else if (oldElemType.isF32() && newElemType.isF16())
+        castedTensor = rewriter.create<mlir::arith::TruncFOp>(loc, castedType,
+                                                              convertedTensor);
       else
-        promotedDotOperand = rewriter.create<tt::FpToFpOp>(loc, promotedType,
-                                                           convertedDotOperand);
+        castedTensor =
+            rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
     }
-    return promotedDotOperand;
+    return castedTensor;
   }
 
   mlir::LogicalResult
@@ -279,8 +306,7 @@ class BlockedToMFMA : public mlir::RewritePattern {
 
     // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
-    auto newAcc =
-        convertAndPromoteDotOperand(rewriter, oldAcc, mfmaEnc, mfmaAccType);
+    auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);
 
     // kWidth is a number of consecutive elements per one instruction per one
     // thread
@@ -314,8 +340,8 @@ class BlockedToMFMA : public mlir::RewritePattern {
         dotOp.getMaxNumImpreciseAcc());
 
     Value dotOutput =
-        convertAndPromoteDotOperand(rewriter, newDot, oldRetType.getEncoding(),
-                                    oldRetType.getElementType());
+        convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
+                             oldRetType.getElementType());
 
     rewriter.replaceOp(op, dotOutput);
 
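
Not part of the patch, only an illustration: a minimal sketch of the smallest dot shape the
removed skips used to guard (M == 16 or N == 16 or K == 16, previously reported as segfaulting
on HIP). The kernel and helper names below are made up for this example and are not taken from
test_core.py; it assumes a Triton build that includes this change and a CUDA- or ROCm-enabled
torch.

# Illustrative only; a 16x16x16 tl.dot, the smallest shape previously skipped on HIP.
import torch
import triton
import triton.language as tl


@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Row-major offsets for a contiguous BLOCK x BLOCK tile.
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    c = tl.dot(a, b)  # fp16 x fp16 inputs, fp32 accumulator
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


def check_small_dot(block=16, device="cuda"):
    a = torch.randn((block, block), dtype=torch.float16, device=device)
    b = torch.randn((block, block), dtype=torch.float16, device=device)
    c = torch.empty((block, block), dtype=torch.float32, device=device)
    small_dot_kernel[(1,)](a, b, c, BLOCK=block)
    ref = torch.matmul(a.float(), b.float())
    torch.testing.assert_close(c, ref, atol=1e-2, rtol=1e-2)

The real coverage comes from the re-enabled test_dot parametrizations above (including the
softmax epilogue); this snippet only shows the kind of kernel those cases exercise.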