
[MLIR][NVGPU] Use gpu.dynamic_shared_memory in tests #133122

Merged: 2 commits into llvm:main on Mar 26, 2025
Conversation

grypp (Member) commented on Mar 26, 2025

Reland #133051

llvmbot (Member) commented on Mar 26, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

Reland #133051
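
In short, the tests stop modeling dynamic shared memory through a zero-sized memref.global plus memref.reinterpret_cast and instead take typed views into the block returned by gpu.dynamic_shared_memory. A condensed before/after sketch, assembled from the hunks below (%c0 is the index constant 0 already defined in the tests):

    // Before: zero-sized global reinterpreted to the desired tile shape.
    %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>

    // After: query the dynamic shared memory block and carve out each tile
    // with memref.view at an explicit byte offset.
    %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
    %lhsSlice = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>

For the test that launches via gpu.launch, the launch additionally passes the required size through dynamic_shared_memory_size.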


Patch is 21.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133122.diff

3 Files Affected:

  • (modified) mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir (+21-18)
  • (modified) mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir (+18-17)
  • (modified) mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir (+18-15)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
index 1c5cf73db6eba..a462122c80980 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
@@ -141,14 +141,19 @@ func.func @main() {
     %c16 = arith.constant 16 : index
     %c4096 = arith.constant 4096 : index
     %c8 = arith.constant 8 : index
-    %txcount = arith.constant 32768 : index     
+    %txcount = arith.constant 32768 : index    
+    %c24576 = arith.constant 24576 : index
+    %c16384 = arith.constant 16384 : index
+    %c49152 = arith.constant 49152 : index
+    %c57344 = arith.constant 57344 : index
+    %c40960 = arith.constant 40960 : index
 
     %tidx = gpu.thread_id  x
     %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
     %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
     %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
     %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
-    
+    %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
     // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
     %barrier = nvgpu.mbarrier.create -> !barrierType
     %cnd = arith.cmpi eq, %tidx, %c0 : index
@@ -161,31 +166,29 @@ func.func @main() {
     nvgpu.tma.prefetch.descriptor %descA : !lhsTensorMap
     nvgpu.tma.prefetch.descriptor %descB : !rhsTensorMap
     
-    // Step 4.1 [GPU] TMA Load Pipeline 1   
+    // Step 4.1 [GPU] TMA Load Pipeline 1       
     scf.if %cnd {
       %pipe = arith.constant 0 : index
-      %lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
-      %rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
-      %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-      %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+      %lhsSlice = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+      %halfFirst = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+      %halfSecond = memref.view %dynsmem[%c40960][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
       nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType        
       %dim = arith.muli %pipe, %c64 : index
-      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
-      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType ->  memref<128x64xf16, #gpu.address_space<workgroup>>
+      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
     }
     // Step 4.2 [GPU] TMA Load Pipeline 2
     scf.if %cnd {
       %pipe = arith.constant 1 : index
-      %lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-      %rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
-      %halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-      %halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
+      %lhsSlice = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+      %halfFirst = memref.view %dynsmem[%c49152][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+      %halfSecond = memref.view %dynsmem[%c57344][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
       nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
       %dim = arith.muli %pipe, %c64 : index  
-      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
+      nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType ->  memref<128x64xf16, #gpu.address_space<workgroup>>
+      nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+      nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
     }
     
     // Step 5. [GPU] Initiliaze accumulator matrix
@@ -271,4 +274,4 @@ func.func @main() {
   vector.print %errorCount : i32
 
   return
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
index 6e8ef2b75eae6..b257d2b0f1e34 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
@@ -142,13 +142,18 @@ func.func @main() {
     %c4096 = arith.constant 4096 : index
     %c8 = arith.constant 8 : index
     %txcount = arith.constant 32768 : index     
+    %c24576 = arith.constant 24576 : index
+    %c16384 = arith.constant 16384 : index
+    %c49152 = arith.constant 49152 : index
+    %c57344 = arith.constant 57344 : index
+    %c40960 = arith.constant 40960 : index
 
     %tidx = gpu.thread_id  x
     %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
     %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
     %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
     %rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
-    
+    %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
     // Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
     %barrier = nvgpu.mbarrier.create -> !barrierType
     
@@ -175,28 +180,25 @@ func.func @main() {
 
     // Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
     %pipe1 = arith.constant 0 : index
-    %p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
-    %p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
-    %p1halfFirst = memref.subview %p1rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-    %p1halfSecond = memref.subview %p1rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+    %lhsSlice1 = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+    %halfFirst1 = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+    %halfSecond1 = memref.view %dynsmem[%c40960][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
     nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType        
     %dim1 = arith.muli %pipe1, %c64 : index
-    nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
-    nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
-    nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
+    nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %lhsSlice1, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
+    nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %halfFirst1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+    nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %halfSecond1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
 
     // Step 5. [GPU] TMA Load Pipeline 2 (predicated)
     %pipe2 = arith.constant 1 : index
-    %p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-    %p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
-    %p2halfFirst = memref.subview %p2rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-    %p2halfSecond = memref.subview %p2rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
+    %lhsSlice2 = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+    %halfFirst2 = memref.view %dynsmem[%c49152][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+    %halfSecond2 = memref.view %dynsmem[%c57344][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
     nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
     %dim2 = arith.muli %pipe2, %c64 : index  
-    nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType ->  memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
-    nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType ->  memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
-    nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
-    
+    nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %lhsSlice2, predicate = %cnd : !lhsTensorMap, !barrierType ->  memref<128x64xf16, #gpu.address_space<workgroup>>
+    nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %halfFirst2, predicate = %cnd : !rhsTensorMap, !barrierType ->  memref<64x64xf16, #gpu.address_space<workgroup>>
+    nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %halfSecond2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
     // Step 6. [GPU] Initiliaze accumulator matrix
     %14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
 
@@ -282,4 +284,3 @@ func.func @main() {
   return
 }
 
-
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
index 462040cd04a3d..b209114e957fa 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
@@ -39,8 +39,6 @@
 
 module @mymod {
   func.func private @printMemrefF32(memref<*xf32>)
-  memref.global "private" @bufferLhsGlobal : !shmemlhs
-  memref.global "private" @bufferRhsGlobal : !shmemrhs
   llvm.func @printf(!llvm.ptr, ...) -> i32
   func.func @main() {
     %c32768 = arith.constant 32768 : index
@@ -49,7 +47,7 @@ module @mymod {
     %c64 = arith.constant 64 : index
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index
-    %c0 = arith.constant 0 : index
+    %c01 = arith.constant 0 : index
     %c128 = arith.constant 128 : index
     %c8 = arith.constant 8 : index
     
@@ -58,8 +56,8 @@ module @mymod {
     %rhs = memref.alloc() : !rhs
     %lhs32 = memref.alloc() : memref<128x64xf32>
     %rhs32 = memref.alloc() : memref<64x128xf32>
-    scf.for %i = %c0 to %c64 step %c1 {
-      scf.for %j = %c0 to %c128 step %c1 {
+    scf.for %i = %c01 to %c64 step %c1 {
+      scf.for %j = %c01 to %c128 step %c1 {
         %v0 = arith.muli %i, %c128 : index
         %v00 = arith.addi %v0, %j : index     
         %v01 = arith.divui %v00, %c8 : index 
@@ -92,15 +90,21 @@ module @mymod {
 
     %d_lhsTensorMap = nvgpu.tma.create.descriptor %d_lhs_unranked box[%c128, %c64] : memref<*xf16> -> !lhsTensorMap
     %d_rhsTensorMap = nvgpu.tma.create.descriptor %d_rhs_unranked box[%c64, %c64] : memref<*xf16> -> !rhsTensorMap
+    %c32768_i32 = arith.constant 32768 : i32
 
     // Step 4. Launch a GPU kernel
-    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) {
+    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) dynamic_shared_memory_size %c32768_i32 {
       %5 = gpu.block_dim  x
       %6 = gpu.thread_id  x
-      %lhsShmem = memref.get_global @bufferLhsGlobal : !shmemlhs
-      %rhsShmem = memref.get_global @bufferRhsGlobal : !shmemrhs
-      %rhsShmem1 = memref.subview %rhsShmem[0, 0][64, 64][1, 1] : !shmemrhs to memref<64x64xf16, strided<[128, 1]>, 3>
-      %rhsShmem2 = memref.subview %rhsShmem[32, 0][64, 64][1, 1] : !shmemrhs to memref<64x64xf16, strided<[128, 1], offset: 4096>, 3>
+      %c0 = arith.constant 0 : index
+      %txcount = arith.constant 32768 : index     
+      %c24576 = arith.constant 24576 : index
+      %c16384 = arith.constant 16384 : index
+      %dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+      %lhsSlice = memref.view %dynsmem[%c01][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+      %rhsSlice = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+      %halfFirst = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+      %halfSecond = memref.view %dynsmem[%c24576][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
     
       // Step 5. Initialize the mbarrier
       %9 = nvgpu.mbarrier.create -> !barrierType
@@ -110,9 +114,9 @@ module @mymod {
       // Step 6. First thread does TMA load
       scf.if %10 {
         gpu.printf "[GPU] TMA SIZE %d\0A", %c32768 : index
-        nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
-        nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9[%c0] to %rhsShmem1 : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1]>, 3>
-        nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9[%c0] to %rhsShmem2 : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 4096>, 3>
+        nvgpu.tma.async.load %d_lhsTensorMap[%c0, %c0], %9[%c0] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
+        nvgpu.tma.async.load %d_rhsTensorMap[%c0, %c0], %9[%c0] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
+        nvgpu.tma.async.load %d_rhsTensorMap[%c64, %c0], %9[%c0] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
         nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c32768 : !barrierType
       } else {
         nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : !barrierType
@@ -127,7 +131,7 @@ module @mymod {
         gpu.printf "===--- Matrix B ---=== %d \n", %c-1_i32 : i32
         scf.for %ii = %c0 to %c64 step %c1 {
           scf.for %j = %c0 to %c128 step %c1 {
-            %lhs0 = memref.load %rhsShmem[%ii, %j] : !shmemrhs
+            %lhs0 = memref.load %rhsSlice[%ii, %j] : memref<64x128xf16, #gpu.address_space<workgroup>>
             %lhs032 = arith.extf %lhs0: f16 to f32
             gpu.printf "%.0f,   ", %lhs032 : f32
           }
@@ -209,4 +213,3 @@ module @mymod {
 // CHECK: 938,   938,   938,   938,   938,   938,   938,   938,   939,   939,   939,   939,   939,   939,   939,   939,   936,   936,   936,   936,   936,   936,   936,   936,   937,   937,   937,   937,   937,   937,   937,   937,   942,   942,   942,   942,   942,   942,   942,   942,   943,   943,   943,   943,   943,   943,   943,   943,   940,   940,   940,   940,   940,   940,   940,   940,   941,   941,   941,   941,   941,   941,   941,   941,   955,   955,   955,   955,   955,   955,   955,   955,   954,   954,   954,   954,   954,   954,   954,   954,   953,   953,   953,   953,   953,   953,   953,   953,   952,   952,   952,   952,   952,   952,   952,   952,   959,   959,   959,   959,   959,   959,   959,   959,   958,   958,   958,   958,   958,   958,   958,   958,   957,   957,   957,   957,   957,   957,   957,   957,   956,   956,   956,   956,   956,   956,   956,   956
 // CHECK: 972,   972,   972,   972,   972,   972,   972,   972,   973,   973,   973,   973,   973,   973,   973,   973,   974,   974,   974,   974,   974,   974,   974,   974,   975,   975,   975,   975,   975,   975,   975,   975,   968,   968,   968,   968,   968,   968,   968,   968,   969,   969,   969,   969,   969,   969,   969,   969,   970,   970,   970,   970,   970,   970,   970,   970,   971,   971,   971,   971,   971,   971,   971,   971,   989,   989,   989,   989,   989,   989,   989,   989,   988,   988,   988,   988,   988,   988,   988,   988,   991,   991,   991,   991,   991,   991,   991,   991,   990,   990,   990,   990,   990,   990,   990,   990,   985,   985,   985,   985,   9...
[truncated]
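
For reference, the byte offsets fed to the memref.view ops follow from the f16 tile sizes; this breakdown is derived from the diff rather than spelled out in the patch. In the two GEMM tests:

    128x64xf16 lhs tile : 128 * 64 * 2 B = 16384 B
    64x64xf16 rhs half  :  64 * 64 * 2 B =  8192 B
    stage 0: lhs at 0,     rhs halves at 32768 and 40960
    stage 1: lhs at 16384, rhs halves at 49152 and 57344
    per-stage TMA transaction size: 16384 + 2 * 8192 = 32768 B (%txcount)

In tma_load_64x64_swizzle128b.mlir, the 128x64 lhs tile sits at offset 0, the 64x128 rhs tile at 16384, its second 64x64 half at 16384 + 8192 = 24576, and the launch requests 32768 bytes via dynamic_shared_memory_size.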

grypp merged commit e8dfd70 into llvm:main on Mar 26, 2025
8 of 10 checks passed
basioli-k (Contributor) commented:

Hi!

I've noticed that the integration tests I mentioned here are still failing for me.

It seems that none of the outputs match the reference. Are you seeing something different?

grypp (Member, Author) commented on Mar 26, 2025

Thanks for running the tests. I've checked the generated PTX; it looks the same to me before and after this PR (it wasn't correct for the previous PR). If you are still seeing failures, could you please post them here?

basioli-k (Contributor) commented:

It's still the same as my comment on the original PR.
This seems like the most relevant bit:

Input was:

<<<<<<

          1: Correct Results :0 

          2: Incorrect Results :16384 

matthias-springer (Member) commented:

@basioli-k Is this still broken or has it been resolved? Can you also post all the tests that are failing?
