[MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to bfmmla
#145064
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sve
Author: Momchil Velikov (momchil-velikov)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/145064.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 7385bb73b449a..c4007dd02c0d3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+
+def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "BFloat16 matrix multiply-accumulate";
+ let description = [{
+ BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices
+
+ This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
+ segment of the first source vector by the 4x2 BFloat16 matrix in the
+ corresponding segment of the second source vector, then accumulates
+ this intermediate result with the 2x2 Float32 matrix in the corresponding
+ segment of the accumulator vector, yielding the final 2x2 Float32
+ segment of the result.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
+ ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
+ ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -590,6 +619,12 @@ def UsmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def BfmmlaIntrOp :
+ ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
+ Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
+ Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
+ Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 35f2a02cc4ec6..73f388b6d81c0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,6 +25,7 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
using DupQLaneLowering =
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
@@ -191,7 +192,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
// Populate conversion patterns
// clang-format off
- patterns.add<ConvertFromSvboolOpLowering,
+ patterns.add<BfmmlaOpLowering,
+ ConvertFromSvboolOpLowering,
ConvertToSvboolOpLowering,
DupQLaneLowering,
PselOpLowering,
@@ -220,7 +222,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
- target.addLegalOp<ConvertFromSvboolIntrOp,
+ target.addLegalOp<BfmmlaIntrOp,
+ ConvertFromSvboolIntrOp,
ConvertToSvboolIntrOp,
DupQLaneIntrOp,
PselIntrOp,
@@ -241,7 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ZipX2IntrOp,
ZipX4IntrOp,
SdotIntrOp>();
- target.addIllegalOp<ConvertFromSvboolOp,
+ target.addIllegalOp<BfmmlaOp,
+ ConvertFromSvboolOp,
ConvertToSvboolOp,
DupQLaneOp,
PselOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8c658db009adf..8673b994d1e71 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -60,6 +60,15 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+ %b: vector<[8]xbf16>,
+ %c: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: arm_sve.intr.bfmmla
+ %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 64e0cff39eb06..9a653df767400 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
// -----
+func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
+ %b: vector<[8]xbf16>,
+ %c: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
+ %0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index da71cb5a63bd2..737145c74e331 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}
+// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
+llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
+ %arg1: vector<[8]xbf16>,
+ %arg2: vector<[4]xf32>)
+ -> vector<[4]xf32> {
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
+ %0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
+ (vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
+ -> vector<[4]xf32>
+ llvm.return %0 : vector<[4]xf32>
+}
+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
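
For readers unfamiliar with the instruction, the per-segment arithmetic described in the `BfmmlaOp` description can be sketched as a scalar reference model. This is only an illustration and not part of the PR: the function name `bfmmlaReference` is hypothetical, BF16 inputs are assumed to be already widened to `float`, the element layout inside each 128-bit segment (row-major 2x4 for the first source, the 4x2 matrix stored column by column for the second) is an assumption, and the instruction's actual rounding behaviour is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar reference model of the multiply-accumulate described above.
// Assumptions (not taken from the PR): inputs are BF16 values already widened
// to float; each 128-bit segment of src1 is a row-major 2x4 matrix; each
// segment of src2 stores the 4x2 matrix column by column (two runs of four
// elements). Hardware rounding is not modelled.
std::vector<float> bfmmlaReference(const std::vector<float> &acc,
                                   const std::vector<float> &src1,
                                   const std::vector<float> &src2) {
  // One 2x2 f32 tile (4 elements) per 128-bit segment.
  const std::size_t segments = acc.size() / 4;
  assert(acc.size() == segments * 4);
  assert(src1.size() == segments * 8 && src2.size() == segments * 8);

  std::vector<float> dst(acc); // start from the accumulator
  for (std::size_t s = 0; s < segments; ++s) {
    for (std::size_t i = 0; i < 2; ++i) {     // row of the 2x2 result
      for (std::size_t j = 0; j < 2; ++j) {   // column of the 2x2 result
        float sum = 0.0f;
        for (std::size_t k = 0; k < 4; ++k)   // shared dimension of length 4
          sum += src1[s * 8 + i * 4 + k] * src2[s * 8 + j * 4 + k];
        dst[s * 4 + i * 2 + j] += sum;
      }
    }
  }
  return dst;
}
```

With `vector<[8]xbf16>` sources and a `vector<[4]xf32>` accumulator, the number of segments corresponds to the runtime `vscale`, which is why the model derives it from the accumulator length rather than fixing it.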
Thanks!
There is one high-level point that I missed before (apologies for that!) and only realised when comparing this against #145038.
Since we won't be doing anything "involved" with this new Op (other than lowering directly to LLVM), using ArmSVE_IntrOp (instead of ArmSVE_Op) should be perfectly sufficient.
We should take a similar approach with e.g. UsmmlaOp, but that's a separate PR. I can take care of that.
Thanks!
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+
[nit] DELETEME
Done
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+
+def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
Suggested change:
-def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
+def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
Done
Please add some negative tests, otherwise LGTM, thanks!
Force-pushed from fcae83d to 9bcd62e.
Done.
Done.
LGTM, thanks!
No description provided.