[xla][gpu] Move sparse tests from Triton repo to XLA repo.
PiperOrigin-RevId: 626019600
chsigg authored and tensorflower-gardener committed Apr 18, 2024
1 parent dce1770 commit ce6c42d
Showing 16 changed files with 349 additions and 768 deletions.
198 changes: 0 additions & 198 deletions third_party/triton/xla_extensions/sparse_dot_base.patch
@@ -168,204 +168,6 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
} // namespace gpu
} // namespace triton
} // namespace mlir
diff --git a/test/SparseDot/convert_to_llvm_ampere.mlir b/test/SparseDot/convert_to_llvm_ampere.mlir
new file mode 100644
--- /dev/null
+++ b/test/SparseDot/convert_to_llvm_ampere.mlir
@@ -0,0 +1,26 @@
+// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=80 | FileCheck %s
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
+#dot_meta_enc = #triton_gpu.sparse_dot_meta<{parent=#mma0}>
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+ tt.func @sparse_dot(%A: tensor<32x32xf16, #blocked0>, %B: tensor<64x32xf16, #blocked0>, %meta: tensor<32x4xi16, #blocked0>) {
+ // CHECK-COUNT-2: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+ %A_alloc = triton_gpu.local_alloc %A {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked0>) -> !tt.memdesc<32x32xf16, #shared0>
+ %A_dot = triton_gpu.local_load %A_alloc : !tt.memdesc<32x32xf16, #shared0> -> tensor<32x32xf16, #dot_operand_a>
+ // CHECK-COUNT-4: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+ %B_alloc = triton_gpu.local_alloc %B {allocation.offset = 2048 : i32} : (tensor<64x32xf16, #blocked0>) -> !tt.memdesc<64x32xf16, #shared0>
+ %B_dot = triton_gpu.local_load %B_alloc : !tt.memdesc<64x32xf16, #shared0> -> tensor<64x32xf16, #dot_operand_b>
+ // CHECK-COUNT-4: llvm.load %[[_:.*]] : !llvm.ptr<3> -> i16
+ %meta_alloc = triton_gpu.local_alloc %meta {allocation.offset = 6144 : i32} : (tensor<32x4xi16, #blocked0>) -> !tt.memdesc<32x4xi16, #shared0>
+ %meta_reg = triton_gpu.local_load %meta_alloc : !tt.memdesc<32x4xi16, #shared0> -> tensor<32x4xi16, #dot_meta_enc>
+ // CHECK-COUNT-4: mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32
+ %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma0>
+ %D = triton_gpu.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16, #dot_meta_enc> * tensor<64x32xf16, #dot_operand_b> -> tensor<32x32xf32, #mma0>
+ tt.return
+ }
+}
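
A note on the shapes in the Ampere test above: with NVIDIA 2:4 structured sparsity, operand A stores only half of the dense contracting dimension, and each i16 metadata word packs eight 2-bit indices, i.e. covers 16 dense columns. A minimal Python sketch of that arithmetic (the helper name is illustrative, not from this commit):

    # Sketch only: shape arithmetic for a 2:4 sparse dot. Each i16
    # metadata word holds eight 2-bit indices, one per kept element,
    # so it covers 16 columns of the dense K dimension.
    def sparse_dot_shapes(m, n, k_dense):
        a_shape = (m, k_dense // 2)      # A keeps 2 of every 4 values
        b_shape = (k_dense, n)           # B stays dense
        meta_shape = (m, k_dense // 16)  # one i16 per 16 dense columns
        return a_shape, b_shape, meta_shape, (m, n)

    # Matches the test: A 32x32, B 64x32, meta 32x4, D 32x32 (dense K = 64).
    assert sparse_dot_shapes(32, 32, 64) == ((32, 32), (64, 32), (32, 4), (32, 32))
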
diff --git a/test/SparseDot/convert_to_llvm_hopper.mlir b/test/SparseDot/convert_to_llvm_hopper.mlir
new file mode 100644
--- /dev/null
+++ b/test/SparseDot/convert_to_llvm_hopper.mlir
@@ -0,0 +1,28 @@
+// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 | FileCheck %s
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared1 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 64, 16]}>
+#dot_meta_enc = #triton_gpu.sparse_dot_meta<{parent=#mma0}>
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+ tt.func @sparse_dot(%A: tensor<64x32xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>, %meta: tensor<64x4xi16, #blocked0>) {
+ %A_alloc = triton_gpu.local_alloc %A {allocation.offset = 0 : i32} : (tensor<64x32xf16, #blocked0>) -> !tt.memdesc<64x32xf16, #shared0>
+ %B_alloc = triton_gpu.local_alloc %B {allocation.offset = 4096 : i32} : (tensor<64x64xf16, #blocked0>) -> !tt.memdesc<64x64xf16, #shared0>
+ // CHECK-COUNT-2: llvm.load %[[_:.*]] : !llvm.ptr<3> -> i16
+ %meta_alloc = triton_gpu.local_alloc %meta {allocation.offset = 12288 : i32} : (tensor<64x4xi16, #blocked0>) -> !tt.memdesc<64x4xi16, #shared0>
+ %meta_reg = triton_gpu.local_load %meta_alloc : !tt.memdesc<64x4xi16, #shared0> -> tensor<64x4xi16, #dot_meta_enc>
+ // CHECK: nvgpu.wgmma_fence
+ // CHECK-COUNT-2: nvgpu.wgmma_sp %[[A:.*]] meta %[[M:.*]], %[[B:.*]], %[[C:.*]] {
+ // CHECK-DAG: layoutA = 0 : i32
+ // CHECK-DAG: layoutB = 0 : i32
+ // CHECK-DAG: m = 64 : i32
+ // CHECK-DAG: n = 64 : i32
+ // CHECK-DAG: k = 32 : i32
+ // CHECK: nvgpu.wgmma_commit_group
+ %acc = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
+ %D = triton_gpu.sparse_dot %A_alloc, %B_alloc, %acc, %meta_reg : !tt.memdesc<64x32xf16, #shared0> meta tensor<64x4xi16, #dot_meta_enc> * !tt.memdesc<64x64xf16, #shared0> -> tensor<64x64xf32, #mma0>
+ tt.return
+ }
+}
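
The two expected wgmma_sp ops in the Hopper test follow from the attributes checked above: each instance covers m=64, n=64, and k=32 of the dense contraction, and the dense K here is 64 (A is the compressed 64x32 half of a 64x64 operand). A small sketch of that count, under the assumption that instructions tile the output and K evenly:

    # Sketch only: instruction count behind the CHECK-COUNT-2 above,
    # assuming each wgmma_sp issue covers an m=64, n=64, k=32 tile of
    # the dense contraction (the attributes checked on each op).
    def wgmma_sp_issues(m, n, k_dense, instr=(64, 64, 32)):
        im, inn, ik = instr
        assert m % im == 0 and n % inn == 0 and k_dense % ik == 0
        return (m // im) * (n // inn) * (k_dense // ik)

    assert wgmma_sp_issues(64, 64, 64) == 2  # dense K = 64 -> two issues
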
diff --git a/test/SparseDot/validation.mlir b/test/SparseDot/validation.mlir
new file mode 100644
--- /dev/null
+++ b/test/SparseDot/validation.mlir
@@ -0,0 +1,129 @@
+// RUN: triton-opt --split-input-file --verify-diagnostics %s
+
+tt.func @sparse_dot(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_lhs_type(%lhs: tensor<128x32xf32>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{element type of operand A is not supported}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xf32> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_lhs_shape(%lhs: tensor<1x128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{shape of operand A is incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<1x128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_rhs_type(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xf32>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{element type of operand B is not supported}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xf32> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_rhs_shape(%lhs: tensor<128x32xbf16>, %rhs: tensor<1x64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{shape of operand B is incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<1x64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_acc_type(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xbf16>
+ // expected-error @+1 {{element type of operand C is not supported}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xbf16>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_acc_shape(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<16384xf32>
+ // expected-error @+1 {{shape of operand C is incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<16384xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_mismatch_lhs_acc(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<64x128xf32>
+ // expected-error @+1 {{operand shape dimensions are incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<64x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_mismatch_rhs_acc(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x64xf32>
+ // expected-error @+1 {{operand shape dimensions are incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x64xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_mismatch_lhs_rhs(%lhs: tensor<128x32xbf16>, %rhs: tensor<32x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{operand shape dimensions are incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi16> * tensor<32x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_mismatch_input_types(%lhs: tensor<128x32xf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{operand element types do not match}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xf16> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_meta_type(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi8>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{sparse metadata tensor is invalid}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x4xi8> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_invalid_meta_shape(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<512xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{sparse metadata tensor is invalid}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<512xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_mismatch_meta_noncontracting(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<64x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{sparse metadata shape dimensions are incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<64x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+tt.func @sparse_dot_mismatch_meta_contracting(%lhs: tensor<128x32xbf16>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x8xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{sparse metadata shape dimensions are incorrect}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16> meta tensor<128x8xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
+
+// -----
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+#enc0 = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+tt.func @sparse_dot_encoding_operand_mismatch(%lhs: tensor<128x32xbf16, #enc0>, %rhs: tensor<64x128xbf16>, %meta: tensor<128x4xi16>) {
+ %acc = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+ // expected-error @+1 {{mismatching encoding between A and B operands}}
+ %res = triton_gpu.sparse_dot %lhs, %rhs, %acc, %meta : tensor<128x32xbf16, #enc0> meta tensor<128x4xi16> * tensor<64x128xbf16> -> tensor<128x128xf32>
+ tt.return
+}
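
Read together, the expected-error strings above outline the op's verifier. A rough Python mirror of those checks, reconstructed only from the diagnostics in this file (the real C++ verifier may order them differently and accept more element types):

    # Rough mirror of the verifier implied by the diagnostics above; the
    # supported element-type sets are assumptions from these tests alone.
    def verify_sparse_dot(a, b, c, meta):
        """Each argument is a (shape, elem_type) pair, e.g. ((128, 32), "bf16")."""
        (a_shape, a_ty), (b_shape, b_ty), (c_shape, c_ty), (m_shape, m_ty) = a, b, c, meta
        if a_ty not in ("f16", "bf16"): return "element type of operand A is not supported"
        if len(a_shape) != 2: return "shape of operand A is incorrect"
        if b_ty not in ("f16", "bf16"): return "element type of operand B is not supported"
        if len(b_shape) != 2: return "shape of operand B is incorrect"
        if c_ty != "f32": return "element type of operand C is not supported"
        if len(c_shape) != 2: return "shape of operand C is incorrect"
        if c_shape != (a_shape[0], b_shape[1]) or a_shape[1] * 2 != b_shape[0]:
            return "operand shape dimensions are incorrect"
        if a_ty != b_ty: return "operand element types do not match"
        if m_ty != "i16" or len(m_shape) != 2: return "sparse metadata tensor is invalid"
        if m_shape != (a_shape[0], a_shape[1] // 8):
            return "sparse metadata shape dimensions are incorrect"
        return None  # valid
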
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
19 changes: 0 additions & 19 deletions third_party/triton/xla_extensions/sparse_dot_nvgpu.patch
@@ -17,25 +17,6 @@ diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td
def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> {
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec);
let builders = [
diff --git a/test/SparseDot/test_wgmma_sp.mlir b/test/SparseDot/test_wgmma_sp.mlir
new file mode 100644
--- /dev/null
+++ b/test/SparseDot/test_wgmma_sp.mlir
@@ -0,0 +1,14 @@
+// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+ tt.func @wgmma_sp(%descA: i64, %metaA: i32, %descB: i64, %acc: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) {
+ // CHECK: @wgmma_sp(%[[LHS:.*]]: i64, %[[META:.*]]: i32, %[[RHS:.*]]: i64,
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = []
+ // CHECK-SAME: "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7}, $16, $17, $18, 0, 1, 1, 1, 0, 0;"
+ // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,l,l,r" %0, %1, %2, %3, %4, %5, %6, %7, %[[LHS]], %[[RHS]], %[[META]]
+ %acc0 = nvgpu.wgmma_sp %descA meta %metaA, %descB, %acc
+ {eltTypeA = 5 : i32, eltTypeB = 5 : i32, eltTypeC = 7 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 16 : i32, k = 32 : i32} :
+ (i64, i32, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ tt.return
+ }
+}
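
The FileCheck lines above also pin down how the integer attributes map into the PTX mnemonic: eltTypeA/B = 5 surface as bf16 and eltTypeC = 7 as f32, with the shape folded into m64n16k32. A hedged sketch of that mapping, covering only the codes observed in this test:

    # Sketch only: element-type codes inferred from this single test
    # (5 -> bf16, 7 -> f32); other codes are not covered here.
    ELT_TYPE = {5: "bf16", 7: "f32"}

    def wgmma_sp_mnemonic(m, n, k, elt_a, elt_b, elt_c):
        return (f"wgmma.mma_async.sp.sync.aligned.m{m}n{n}k{k}."
                f"{ELT_TYPE[elt_c]}.{ELT_TYPE[elt_a]}.{ELT_TYPE[elt_b]}")

    # Reproduces the CHECK-SAME mnemonic for eltType{A,B} = 5, eltTypeC = 7:
    assert (wgmma_sp_mnemonic(64, 16, 32, 5, 5, 7)
            == "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16")
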
diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
--- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
+++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
