Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ those to this list.
"""

temporary_patch_list = [
"//third_party/triton:temporary/sparsity.patch",
"//third_party/triton:temporary/replace_unreachable_by_abort.patch",
"//third_party/triton:temporary/block_k_16_fix.patch",
"//third_party/triton:temporary/index_cast_ui_axis_info.patch",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This patch is already a public patch. It should be removed in the next integration.

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index eb26ffe3b..ba87d671e 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Expand Down Expand Up @@ -136,26 +138,6 @@ index fb0e7f6fd..37795c20c 100644
return WalkResult::advance();
OpBuilder builder(op);
auto a = op->getOperand(0);
diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
index 7affd8840..52aa2c131 100644
--- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -87,6 +87,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
}

+def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> {
+ let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
+ I32Attr:$m, I32Attr:$n, I32Attr:$k,
+ WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
+ WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
+ let results = (outs LLVM_AnyStruct:$res);
+ let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
+}
+
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
let arguments = (ins BoolAttr:$bCluster);
let assemblyFormat = "attr-dict";
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index baed96a29..e9d7f5859 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
Expand Down Expand Up @@ -221,4 +203,4 @@ index df3d3b042..e38c184f6 100644
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {
})) {
2 changes: 1 addition & 1 deletion third_party/triton/xla_extensions/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to
"""

extensions_files_patch_list = [
"//third_party/triton:xla_extensions/sparse_dot.patch", # Sparsity internal patch
"//third_party/triton:xla_extensions/sparse_wgmma_op.patch", # Sparsity internal patch
# Add new patches just above this line
]
21 changes: 21 additions & 0 deletions third_party/triton/xla_extensions/sparse_wgmma_op.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Tracked in b/377656276
diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
index 7affd8840..52aa2c131 100644
--- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -87,6 +87,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
}

+def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> {
+ let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
+ I32Attr:$m, I32Attr:$n, I32Attr:$k,
+ WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
+ WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
+ let results = (outs LLVM_AnyStruct:$res);
+ let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
+}
+
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
let arguments = (ins BoolAttr:$bCluster);
let assemblyFormat = "attr-dict";
1 change: 1 addition & 0 deletions third_party/xla/third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ those to this list.
"""

temporary_patch_list = [
"//third_party/triton:temporary/sparsity.patch",
"//third_party/triton:temporary/replace_unreachable_by_abort.patch",
"//third_party/triton:temporary/block_k_16_fix.patch",
"//third_party/triton:temporary/index_cast_ui_axis_info.patch",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This patch is already a public patch. It should be removed in the next integration.

diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index eb26ffe3b..ba87d671e 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Expand Down Expand Up @@ -136,26 +138,6 @@ index fb0e7f6fd..37795c20c 100644
return WalkResult::advance();
OpBuilder builder(op);
auto a = op->getOperand(0);
diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
index 7affd8840..52aa2c131 100644
--- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -87,6 +87,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
}

+def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> {
+ let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
+ I32Attr:$m, I32Attr:$n, I32Attr:$k,
+ WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
+ WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
+ let results = (outs LLVM_AnyStruct:$res);
+ let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
+}
+
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
let arguments = (ins BoolAttr:$bCluster);
let assemblyFormat = "attr-dict";
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index baed96a29..e9d7f5859 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
Expand Down Expand Up @@ -221,4 +203,4 @@ index df3d3b042..e38c184f6 100644
+ mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {
})) {
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to
"""

extensions_files_patch_list = [
"//third_party/triton:xla_extensions/sparse_dot.patch", # Sparsity internal patch
"//third_party/triton:xla_extensions/sparse_wgmma_op.patch", # Sparsity internal patch
# Add new patches just above this line
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Tracked in b/377656276
diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
index 7affd8840..52aa2c131 100644
--- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -87,6 +87,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)";
}

+def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> {
+ let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
+ I32Attr:$m, I32Attr:$n, I32Attr:$k,
+ WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
+ WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
+ let results = (outs LLVM_AnyStruct:$res);
+ let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
+}
+
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
let arguments = (ins BoolAttr:$bCluster);
let assemblyFormat = "attr-dict";