@@ -30,6 +30,7 @@ using namespace mlir;
3030
3131static constexpr const char *kBindMemRef1DFloat = " bindMemRef1DFloat" ;
3232static constexpr const char *kBindMemRef2DFloat = " bindMemRef2DFloat" ;
33+ static constexpr const char *kBindMemRef3DFloat = " bindMemRef3DFloat" ;
3334static constexpr const char *kCInterfaceVulkanLaunch =
3435 " _mlir_ciface_vulkanLaunch" ;
3536static constexpr const char *kDeinitVulkan = " deinitVulkan" ;
@@ -76,10 +77,12 @@ class VulkanLaunchFuncToVulkanCallsPass
7677 llvmPointerType = LLVM::LLVMType::getInt8PtrTy (llvmDialect);
7778 llvmInt32Type = LLVM::LLVMType::getInt32Ty (llvmDialect);
7879 llvmInt64Type = LLVM::LLVMType::getInt64Ty (llvmDialect);
79- initializeMemRefTypes ();
80+ llvmMemRef1DFloat = getMemRefType (1 );
81+ llvmMemRef2DFloat = getMemRefType (2 );
82+ llvmMemRef3DFloat = getMemRefType (3 );
8083 }
8184
82- void initializeMemRefTypes ( ) {
85+ LLVM::LLVMType getMemRefType ( uint32_t rank ) {
8386 // According to the MLIR doc memref argument is converted into a
8487 // pointer-to-struct argument of type:
8588 // template <typename Elem, size_t Rank>
@@ -91,22 +94,15 @@ class VulkanLaunchFuncToVulkanCallsPass
9194 // int64_t strides[Rank]; // omitted when rank == 0
9295 // };
9396 auto llvmPtrToFloatType = getFloatType ().getPointerTo ();
94- auto llvmArrayOneElementSizeType =
95- LLVM::LLVMType::getArrayTy (getInt64Type (), 1 );
96- auto llvmArrayTwoElementSizeType =
97- LLVM::LLVMType::getArrayTy (getInt64Type (), 2 );
97+ auto llvmArrayRankElementSizeType =
98+ LLVM::LLVMType::getArrayTy (getInt64Type (), rank);
9899
99- // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
100- llvmMemRef1DFloat = LLVM::LLVMType::getStructTy (
100+ // Create a type
101+ // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
102+ return LLVM::LLVMType::getStructTy (
101103 llvmDialect,
102104 {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type (),
103- llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
104-
105- // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`.
106- llvmMemRef2DFloat = LLVM::LLVMType::getStructTy (
107- llvmDialect,
108- {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type (),
109- llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType});
105+ llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
110106 }
111107
112108 LLVM::LLVMType getFloatType () { return llvmFloatType; }
@@ -116,6 +112,7 @@ class VulkanLaunchFuncToVulkanCallsPass
116112 LLVM::LLVMType getInt64Type () { return llvmInt64Type; }
117113 LLVM::LLVMType getMemRef1DFloat () { return llvmMemRef1DFloat; }
118114 LLVM::LLVMType getMemRef2DFloat () { return llvmMemRef2DFloat; }
115+ LLVM::LLVMType getMemRef3DFloat () { return llvmMemRef3DFloat; }
119116
120117 // / Creates a LLVM global for the given `name`.
121118 Value createEntryPointNameConstant (StringRef name, Location loc,
@@ -164,6 +161,7 @@ class VulkanLaunchFuncToVulkanCallsPass
164161 LLVM::LLVMType llvmInt64Type;
165162 LLVM::LLVMType llvmMemRef1DFloat;
166163 LLVM::LLVMType llvmMemRef2DFloat;
164+ LLVM::LLVMType llvmMemRef3DFloat;
167165
168166 // TODO: Use an associative array to support multiple vulkan launch calls.
169167 std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -335,6 +333,16 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
335333 /* isVarArg=*/ false ));
336334 }
337335
336+ if (!module .lookupSymbol (kBindMemRef3DFloat )) {
337+ builder.create <LLVM::LLVMFuncOp>(
338+ loc, kBindMemRef3DFloat ,
339+ LLVM::LLVMType::getFunctionTy (getVoidType (),
340+ {getPointerType (), getInt32Type (),
341+ getInt32Type (),
342+ getMemRef3DFloat ().getPointerTo ()},
343+ /* isVarArg=*/ false ));
344+ }
345+
338346 if (!module .lookupSymbol (kInitVulkan )) {
339347 builder.create <LLVM::LLVMFuncOp>(
340348 loc, kInitVulkan ,
0 commit comments