Skip to content

[mlir][spirv] Add support for Aligned memory operand in CoopMatrix memory operations #145480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

IgWod-IMG
Copy link
Contributor

@IgWod-IMG IgWod-IMG commented Jun 24, 2025

In the process of adding support for Aligned, I have noticed that the support for MakePointerAvailable and MakePointerVisible is incomplete as the operation does not accept a scope nor check for NonPrivatePointer. The PR does not address it, but the relevant issues has been created #145485.

@IgWod-IMG IgWod-IMG marked this pull request as ready for review June 24, 2025 09:46
@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

Changes

In the process of adding support for Aligned, I have noticed that the support for MakePointerAvailable and MakePointerVisible is incomplete as the operation does not accept a scope nor check for NonPrivatePointer. The PR does not address it, but the relevant issues has been created #145485.


Full diff: https://github.com/llvm/llvm-project/pull/145480.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+10-6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-10)
  • (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+31-3)
  • (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 46732ba19afed..fd75532ae3d70 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   }];
 
   let assemblyFormat = [{
-    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
       type(operands) `->` type($result)
   }];
 
@@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
     Capability<[SPIRV_C_CooperativeMatrixKHR]>
   ];
 
+  // TODO: Add scope operand for MakePointer*. See #145485.
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
     SPIRV_Integer:$stride,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+    OptionalAttr<I32Attr>:$alignment
   );
 
   let results = (outs
@@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
                    "spirv::ConstantOp":$stride,
                    "spirv::CooperativeMatrixLayoutKHR":$layout), [{
       build($_builder, $_state, result, pointer, layout, stride,
-            spirv::MemoryAccessAttr{});
+            spirv::MemoryAccessAttr{}, IntegerAttr{});
     }]>
   ];
 }
@@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   }];
 
   let assemblyFormat = [{
-    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
       type(operands)
   }];
 
@@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
     Capability<[SPIRV_C_CooperativeMatrixKHR]>
   ];
 
+  // TODO: Add scope operand for MakePointer*. See #145485.
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_AnyCooperativeMatrix:$object,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
     SPIRV_Integer:$stride,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+    OptionalAttr<I32Attr>:$alignment
   );
 
   let results = (outs);
@@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
                    "spirv::ConstantOp":$stride,
                    "spirv::CooperativeMatrixLayoutKHR":$layout), [{
       build($_builder, $_state, pointer, object, layout, stride,
-            spirv::MemoryAccessAttr{});
+            spirv::MemoryAccessAttr{}, IntegerAttr{});
     }]>
   ];
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 2ff3efdc96a7f..fa20cc179f892 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -23,7 +23,8 @@ namespace mlir::spirv {
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
-                       spirv::MemoryAccessAttr memoryOperand) {
+                       spirv::MemoryAccessAttr memoryOperand,
+                       IntegerAttr alignment) {
   auto pointerType = cast<PointerType>(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa<ScalarType, VectorType>(pointeeType)) {
@@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
           "not compatible with memory operand 'MakePointerVisible'");
     }
 
-    // The 'Aligned' memory operand requires an alignment literal to follow,
-    // which needs to be implemented on the level of op parsing and
-    // (de-)serialization.
-    // TODO: Consider adding support for this attribute value.
-    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                  spirv::MemoryAccess::Aligned)) {
-      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
+    // #145485.
+
+    if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+        !alignment) {
+      return op->emitOpError("missing value for the 'Aligned' memory operand");
+    }
+
+    if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+        alignment) {
+      return op->emitOpError(
+          "found alignment attribute for non-'Aligned' memory operand");
     }
   }
 
@@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
-                                getResult().getType(), getMemoryOperandAttr());
+                                getResult().getType(), getMemoryOperandAttr(),
+                                getAlignmentAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
-                                getObject().getType(), getMemoryOperandAttr());
+                                getObject().getType(), getMemoryOperandAttr(),
+                                getAlignmentAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..56d477cca97b7 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_load_aligned
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
 // CHECK-LABEL: @cooperative_matrix_store
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
@@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBu
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_store_aligned
+spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
 // -----
 
 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuf
 // -----
 
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
@@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer
 // -----
 
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
@@ -179,7 +198,7 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageB
 
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
   spirv.Return
@@ -187,6 +206,15 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %str
 
 // -----
 
+spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <MakePointerVisible>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 153ff47937972..77949908e8883 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires
     spirv.Return
   }
 
+  // CHECK-LABEL: @cooperative_matrix_load_3
+  spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+      !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    spirv.Return
+  }
+
   // CHECK-LABEL: @cooperative_matrix_store_1
   spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                          %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
@@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
 
+    // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+      !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
     // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

Changes

In the process of adding support for Aligned, I have noticed that the support for MakePointerAvailable and MakePointerVisible is incomplete as the operation does not accept a scope nor check for NonPrivatePointer. The PR does not address it, but the relevant issues has been created #145485.


Full diff: https://github.com/llvm/llvm-project/pull/145480.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+10-6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-10)
  • (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+31-3)
  • (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 46732ba19afed..fd75532ae3d70 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   }];
 
   let assemblyFormat = [{
-    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+    $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
       type(operands) `->` type($result)
   }];
 
@@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
     Capability<[SPIRV_C_CooperativeMatrixKHR]>
   ];
 
+  // TODO: Add scope operand for MakePointer*. See #145485.
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
     SPIRV_Integer:$stride,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+    OptionalAttr<I32Attr>:$alignment
   );
 
   let results = (outs
@@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
                    "spirv::ConstantOp":$stride,
                    "spirv::CooperativeMatrixLayoutKHR":$layout), [{
       build($_builder, $_state, result, pointer, layout, stride,
-            spirv::MemoryAccessAttr{});
+            spirv::MemoryAccessAttr{}, IntegerAttr{});
     }]>
   ];
 }
@@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   }];
 
   let assemblyFormat = [{
-    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
+    $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
       type(operands)
   }];
 
@@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
     Capability<[SPIRV_C_CooperativeMatrixKHR]>
   ];
 
+  // TODO: Add scope operand for MakePointer*. See #145485.
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
     SPIRV_AnyCooperativeMatrix:$object,
     SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
     SPIRV_Integer:$stride,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
+    OptionalAttr<I32Attr>:$alignment
   );
 
   let results = (outs);
@@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
                    "spirv::ConstantOp":$stride,
                    "spirv::CooperativeMatrixLayoutKHR":$layout), [{
       build($_builder, $_state, pointer, object, layout, stride,
-            spirv::MemoryAccessAttr{});
+            spirv::MemoryAccessAttr{}, IntegerAttr{});
     }]>
   ];
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 2ff3efdc96a7f..fa20cc179f892 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -23,7 +23,8 @@ namespace mlir::spirv {
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
-                       spirv::MemoryAccessAttr memoryOperand) {
+                       spirv::MemoryAccessAttr memoryOperand,
+                       IntegerAttr alignment) {
   auto pointerType = cast<PointerType>(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa<ScalarType, VectorType>(pointeeType)) {
@@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
           "not compatible with memory operand 'MakePointerVisible'");
     }
 
-    // The 'Aligned' memory operand requires an alignment literal to follow,
-    // which needs to be implemented on the level of op parsing and
-    // (de-)serialization.
-    // TODO: Consider adding support for this attribute value.
-    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                  spirv::MemoryAccess::Aligned)) {
-      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
+    // #145485.
+
+    if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+        !alignment) {
+      return op->emitOpError("missing value for the 'Aligned' memory operand");
+    }
+
+    if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
+        alignment) {
+      return op->emitOpError(
+          "found alignment attribute for non-'Aligned' memory operand");
     }
   }
 
@@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
-                                getResult().getType(), getMemoryOperandAttr());
+                                getResult().getType(), getMemoryOperandAttr(),
+                                getAlignmentAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
-                                getObject().getType(), getMemoryOperandAttr());
+                                getObject().getType(), getMemoryOperandAttr(),
+                                getAlignmentAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..56d477cca97b7 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_load_aligned
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
 // CHECK-LABEL: @cooperative_matrix_store
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
@@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBu
   spirv.Return
 }
 
+// CHECK-LABEL: @cooperative_matrix_store_aligned
+spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
 // -----
 
 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuf
 // -----
 
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
@@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer
 // -----
 
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
   spirv.Return
@@ -179,7 +198,7 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageB
 
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
-  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  // expected-error @+1 {{missing value for the 'Aligned' memory operand}}
   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
   spirv.Return
@@ -187,6 +206,15 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %str
 
 // -----
 
+spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <MakePointerVisible>, 16 :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 153ff47937972..77949908e8883 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires
     spirv.Return
   }
 
+  // CHECK-LABEL: @cooperative_matrix_load_3
+  spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
+      !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    spirv.Return
+  }
+
   // CHECK-LABEL: @cooperative_matrix_store_1
   spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                          %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
@@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
 
+    // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
+      !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
     // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
       !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants