diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 6a681fd933971..da7693e189ca3 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) { return cast<DILocalScope>(Scope)->getSubprogram(); } -/// Erase \p V from \p BB and move \II forward to avoid invalidating -/// iterators. -static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, - BasicBlock &BB) { - auto *Inst = cast<Instruction>(V); - // Still used, don't erase. - if (!Inst->use_empty()) - return; - if (II != BB.rend() && Inst == &*II) - ++II; - Inst->eraseFromParent(); -} - /// Return true if V is a splat of a value (which is used when multiplying a /// matrix with a scalar). static bool isSplat(Value *V) { @@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) { /// Return the ShapeInfo for the result of \p I, it it can be determined. static std::optional<ShapeInfo> computeShapeInfoForInst(Instruction *I, - const ValueMap<Value *, ShapeInfo> &ShapeMap) { + const DenseMap<Value *, ShapeInfo> &ShapeMap) { Value *M; Value *N; Value *K; @@ -492,10 +479,16 @@ class LowerMatrixIntrinsics { /// the result value of the instruction, with the only exceptions being store /// instructions and the matrix_column_major_store intrinsics. For those, the /// shape information indicates that those instructions should be lowered - /// using shape information as well. A ValueMap is used so that when - /// sub-passes like optimizeTransposes performs RAUW the map stays - /// up-to-date. - ValueMap<Value *, ShapeInfo> ShapeMap; + /// using shape information as well. Note that extra care is needed when + /// erasing or RAUW'ing a value that is present in ShapeMap. If the + /// replacement is also a matrix operation, use + /// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to + /// ShapeMap.
We don't use ValueMap, as there are also cases where we do not + /// want to add shape information for a replacement instruction. When directly + /// erasing a value with an entry in ShapeMap, use + /// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated + /// accordingly. + DenseMap<Value *, ShapeInfo> ShapeMap; /// List of instructions to remove. While lowering, we are not replacing all /// users of a lowered instruction, if shape information is available and @@ -759,6 +752,30 @@ class LowerMatrixIntrinsics { return Operation(T0, Shape0.t(), T1, Shape1.t()); } + /// Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst + /// itself. + void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) { + auto Iter = ShapeMap.find(Inst); + if (Iter != ShapeMap.end()) + ShapeMap.erase(Iter); + Inst->eraseFromParent(); + } + + /// Erase \p V from \p BB and move \II forward to avoid invalidating + /// iterators. + void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, + BasicBlock &BB) { + auto *Inst = cast<Instruction>(V); + // Still used, don't erase. + if (!Inst->use_empty()) + return; + if (II != BB.rend() && Inst == &*II) + ++II; + eraseFromParentAndRemoveFromShapeMap(Inst); + } + + /// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the + /// entry for \p Old and replace all uses of \p Old with \p New. void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { // We need to remove Old from the ShapeMap otherwise RAUW will replace it // with New. We should only add New it it supportsShapeInfo so we insert @@ -872,13 +889,13 @@ class LowerMatrixIntrinsics { void liftTranspose(Instruction &I) { // Erase dead Instructions after lifting transposes from binops.
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { + auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) { if (T.use_empty()) - T.eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(&T); if (A->use_empty()) - cast<Instruction>(A)->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A)); if (A != B && B->use_empty()) - cast<Instruction>(B)->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B)); }; Value *A, *B, *AT, *BT; @@ -908,8 +925,7 @@ class LowerMatrixIntrinsics { match(B, m_Intrinsic<Intrinsic::matrix_transpose>( m_Value(BT), m_ConstantInt(), m_ConstantInt()))) { IRBuilder<> Builder(&I); - auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); - setShapeInfo(Add, {R, C}); + auto *Add = Builder.CreateFAdd(AT, BT, "mfadd"); MatrixBuilder MBuilder(Builder); Instruction *NewInst = MBuilder.CreateMatrixTranspose( Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t"); @@ -918,9 +934,13 @@ class LowerMatrixIntrinsics { computeShapeInfoForInst(&I, ShapeMap) && "Shape of new instruction doesn't match original shape."); CleanupBinOp(I, A, B); - assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) == - ShapeMap[Add] && - "Shape of updated addition doesn't match cached shape."); + if (auto *AddI = dyn_cast<Instruction>(Add)) { + setShapeInfo(AddI, {R, C}); + assert( + computeShapeInfoForInst(AddI, ShapeMap).value_or(ShapeMap[AddI]) == + ShapeMap[AddI] && + "Shape of updated addition doesn't match cached shape."); + } } } @@ -1014,7 +1034,8 @@ class LowerMatrixIntrinsics { // Third, try to fuse candidates.
for (CallInst *CI : MaybeFusableInsts) - LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); + if (!FusedInsts.contains(CI)) + LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); Changed = !FusedInsts.empty(); @@ -1475,7 +1496,7 @@ class LowerMatrixIntrinsics { m_Value(Arg)))) { auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); Op->replaceAllUsesWith(NewLoad); - cast<Instruction>(Op)->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op)); return; } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( m_Value(Arg)))) { @@ -1844,15 +1865,15 @@ class LowerMatrixIntrinsics { // Mark eliminated instructions as fused and remove them. FusedInsts.insert(Store); FusedInsts.insert(MatMul); - Store->eraseFromParent(); - MatMul->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(Store); + eraseFromParentAndRemoveFromShapeMap(MatMul); if (LoadOp0->hasNUses(0)) { FusedInsts.insert(LoadOp0); - LoadOp0->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(LoadOp0); } if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) { FusedInsts.insert(LoadOp1); - LoadOp1->eraseFromParent(); + eraseFromParentAndRemoveFromShapeMap(LoadOp1); } } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll new file mode 100644 index 0000000000000..b78d56646d9e4 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll @@ -0,0 +1,49 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s + +define void @test(ptr %p, <8 x i32> %x) { +; CHECK-LABEL: define void @test( +; CHECK-SAME: ptr [[P:%.*]], <8 x i32> [[X:%.*]]) { +; CHECK-NEXT: [[L:%.*]] = load <8 x i32>, ptr [[P]], align 4 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> zeroinitializer +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8
x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 1> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 2> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 3> +; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 4> +; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 5> +; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 6> +; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 7> +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <1 x i32> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x i32> poison, i32 [[TMP1]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <1 x i32> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP4:%.*]] = insertelement <8 x i32> [[TMP2]], i32 [[TMP3]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <1 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = insertelement <8 x i32> [[TMP4]], i32 [[TMP5]], i64 2 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <1 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[TMP7]], i64 3 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <1 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = insertelement <8 x i32> [[TMP8]], i32 [[TMP9]], i64 4 +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <1 x i32> [[SPLIT5]], i64 0 +; CHECK-NEXT: [[TMP12:%.*]] = insertelement <8 x i32> [[TMP10]], i32 [[TMP11]], i64 5 +; CHECK-NEXT: [[TMP13:%.*]] = extractelement <1 x i32> [[SPLIT6]], i64 0 +; CHECK-NEXT: [[TMP14:%.*]] = insertelement <8 x i32> [[TMP12]], i32 [[TMP13]], i64 6 +; CHECK-NEXT: [[TMP15:%.*]] = extractelement <1 x i32> [[SPLIT7]], i64 0 +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x i32> [[TMP14]], i32 [[TMP15]], i64 7 +; CHECK-NEXT: [[TMP17:%.*]] = mul <8 x i32> [[L]], [[TMP16]] +; CHECK-NEXT: [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32>
[[TMP17]]) +; CHECK-NEXT: [[TMP19:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i64 0 +; CHECK-NEXT: [[E:%.*]] = extractelement <1 x i32> [[TMP19]], i64 0 +; CHECK-NEXT: store i32 [[E]], ptr [[P]], align 4 +; CHECK-NEXT: ret void +; + %l = load <8 x i32>, ptr %p, align 4 + %t = tail call <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32> %x, i32 1, i32 8) + %m = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %l, <8 x i32> %t, i32 1, i32 8, i32 1) + %e = extractelement <1 x i32> %m, i64 0 + store i32 %e, ptr %p, align 4 + ret void +} + +declare <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32>, i32 immarg, i32 immarg) + +declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32 immarg, i32 immarg, i32 immarg) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll index 2fd77e245a34e..aadaf1ffffb23 100644 --- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll @@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>, declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1 declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0 + +define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) { +; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2
x i32> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1 +; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]] +; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]]) +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0 +; CHECK-NEXT: ret <1 x i32> [[TMP11]] +; +entry: + %t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2) + %shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1> + %t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1) + %m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1) + ret <1 x i32> %m +} + +declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll new file mode 100644 index 0000000000000..5ac92da75409e --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting-constant-folds.ll @@ -0,0 +1,39 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s + +target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128" + +define <8 x float>
@transpose_constant_fold_fadd_AT_BT() { +; CHECK-LABEL: define <8 x float> @transpose_constant_fold_fadd_AT_BT() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: ret <8 x float> splat (float 2.000000e+00) +; +entry: + %t = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> splat (float 1.0), i32 8, i32 1) + %f = fadd <8 x float> %t, %t + ret <8 x float> %f +} + +define <8 x float> @transpose_constant_fold_fmul_A_k() { +; CHECK-LABEL: define <8 x float> @transpose_constant_fold_fmul_A_k() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <8 x float> splat (float 3.000000e+00), <8 x float> poison, <8 x i32> zeroinitializer +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x float> [[SPLAT]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x float> [[SPLAT]], <8 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[TMP0:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[SPLIT]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[SPLIT1]] +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x float> [[TMP0]], <4 x float> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: ret <8 x float> [[TMP2]] +; +entry: + %t.1 = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> splat (float 1.0), i32 4, i32 2) + %splat = shufflevector <8 x float> splat (float 3.0), <8 x float> poison, <8 x i32> zeroinitializer + %m = fmul <8 x float> %t.1, %splat + %t.2 = tail call <8 x float> @llvm.matrix.transpose.v8f32(<8 x float> %m, i32 2, i32 4) + ret <8 x float> %t.2 +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare <8 x float> @llvm.matrix.transpose.v8f32(<8 x float>, i32 immarg, i32 immarg) #0 + +attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll index fcf83b03bc3d2..1b3b41d8cfe1f 100--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll +++
b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll @@ -144,8 +144,28 @@ entry: ret <6 x double> %mul } +define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) { +; CHECK-LABEL: define void @test_remove_entries_from_shape_map( +; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2) +; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]] +; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2) +; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4 +; CHECK-NEXT: ret void +; +entry: + %m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2) + %add = fadd <6 x float> %c, %m + %t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2) + store <6 x float> %t, ptr %dst, align 4 + ret void +} + declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32) declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32) declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32) declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32) declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32) +declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg) +declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)