Clean up sparsity patches.
PiperOrigin-RevId: 642262009
chsigg authored and tensorflower-gardener committed Jun 11, 2024
1 parent c64a51c commit ae908e1
Showing 6 changed files with 0 additions and 280 deletions.
25 changes: 0 additions & 25 deletions third_party/triton/temporary/sparsity_layout.patch
@@ -15,31 +15,6 @@
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return std::nullopt;
==== triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp#36 - /google/src/cloud/csigg/triton_sparse/triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp ====
# action=edit type=text
--- triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 2024-05-14 06:33:36.000000000 -0700
+++ triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 2024-06-04 04:34:23.000000000 -0700
@@ -636,7 +636,6 @@
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
- patterns.insert<TritonSparseDotPattern>(typeConverter, context);
}

//
@@ -878,12 +877,6 @@
mod->setAttr(AttrTargetName,
StringAttr::get(context, this->target.getValue()));

- // Only transform sparse dot op with undefined layout.
- target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
- [](triton::gpu::SparseDotOp op) {
- return op.getAMeta().getType().getEncoding() != nullptr;
- });
-
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

==== triton/lib/Dialect/TritonGPU/IR/Dialect.cpp#51 - /google/src/cloud/csigg/triton_sparse/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/IR/Dialect.cpp 2024-06-07 05:28:31.000000000 -0700
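
For context, the hunks above deleted the conversion wiring for sparse dot ops: the pattern registration and the dynamic-legality rule. A minimal sketch of that MLIR partial-conversion idiom, assuming the Triton headers and the TritonSparseDotPattern from the patch below are in scope (the free-standing helper convertSparseDots is ours, not part of the patch):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: rewrite only those SparseDotOps that do not yet carry a layout.
LogicalResult convertSparseDots(ModuleOp mod, MLIRContext *context,
                                TritonGPUTypeConverter &typeConverter) {
  RewritePatternSet patterns(context);
  // The deleted hunk registered this pattern next to the generic ones.
  patterns.insert<TritonSparseDotPattern>(typeConverter, context);

  ConversionTarget target(*context);
  // A sparse dot is already legal once its metadata operand carries an
  // encoding; only ops with an undefined layout are handed to the pattern.
  target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
      [](triton::gpu::SparseDotOp op) {
        return op.getAMeta().getType().getEncoding() != nullptr;
      });

  // Partial conversion rewrites the illegal ops and leaves the rest alone.
  return applyPartialConversion(mod, target, std::move(patterns));
}

In the actual pass quoted by the patch, a failed conversion calls signalPassFailure() rather than returning a LogicalResult.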
115 changes: 0 additions & 115 deletions third_party/triton/xla_extensions/sparse_dot.patch
@@ -53,121 +53,6 @@ index a87e1c44a..456a4f224 100644
+}
+
#endif
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index 4aa2712ec..16a6253d7 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -279,6 +279,89 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
};

+struct TritonSparseDotPattern
+ : public OpConversionPattern<triton::gpu::SparseDotOp> {
+ using OpConversionPattern<triton::gpu::SparseDotOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ triton::gpu::SparseDotOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ RankedTensorType origType = cast<RankedTensorType>(op.getType());
+ auto origShape = origType.getShape();
+ auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
+ int numWarps = typeConverter->getNumWarps();
+ int threadsPerWarp = typeConverter->getThreadsPerWarp();
+ int numCTAs = typeConverter->getNumCTAs();
+
+ auto rank = origShape.size();
+ auto numElements = product<int64_t>(origShape);
+ SmallVector<unsigned> retSizePerThread(rank, 1);
+ if (numElements / (numWarps * threadsPerWarp) >= 4) {
+ retSizePerThread[rank - 1] = 2;
+ retSizePerThread[rank - 2] = 2;
+ }
+ if (numElements / (numWarps * threadsPerWarp) >= 16) {
+ retSizePerThread[rank - 1] = 4;
+ retSizePerThread[rank - 2] = 4;
+ }
+ SmallVector<unsigned> retOrder(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ retOrder[i] = rank - 1 - i;
+ Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
+ getContext(), origShape, retSizePerThread, retOrder, numWarps,
+ threadsPerWarp, numCTAs);
+ RankedTensorType retType =
+ RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
+
+ // a & b must be of smem layout
+ auto aType = cast<RankedTensorType>(adaptor.getA().getType());
+ auto bType = cast<RankedTensorType>(adaptor.getB().getType());
+ Type aEltType = aType.getElementType();
+ Type bEltType = bType.getElementType();
+ Attribute aEncoding = aType.getEncoding();
+ Attribute bEncoding = bType.getEncoding();
+ if (!aEncoding || !bEncoding)
+ return failure();
+ Value a = adaptor.getA();
+ Value b = adaptor.getB();
+ Value c = adaptor.getC();
+ if (!isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
+ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+ getContext(), 0, dEncoding, aEltType);
+ auto dstType =
+ RankedTensorType::get(aType.getShape(), aEltType, encoding);
+ a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+ }
+ if (!isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
+ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+ getContext(), 1, dEncoding, bEltType);
+ auto dstType =
+ RankedTensorType::get(bType.getShape(), bEltType, encoding);
+ b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+ }
+ c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
+
+ // aMeta must be of smem layout
+ auto aMetaType = cast<RankedTensorType>(adaptor.getAMeta().getType());
+ Attribute aMetaEncoding = aMetaType.getEncoding();
+ if (!aMetaEncoding) return failure();
+ Value aMeta = adaptor.getAMeta();
+ if (!isa<triton::gpu::SparseDotMetaEncodingAttr>(aMetaEncoding)) {
+ Attribute encoding =
+ triton::gpu::SparseDotMetaEncodingAttr::get(getContext(), dEncoding);
+ auto dstType = RankedTensorType::get(
+ aMetaType.getShape(), aMetaType.getElementType(), encoding);
+ aMeta = rewriter.create<triton::gpu::ConvertLayoutOp>(aMeta.getLoc(),
+ dstType, aMeta);
+ }
+
+ addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SparseDotOp>(
+ op, retType, a, b, c, aMeta),
+ adaptor.getAttributes());
+ return success();
+ }
+};
+
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
using OpConversionPattern::OpConversionPattern;

@@ -553,6 +636,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
+ patterns.insert<TritonSparseDotPattern>(typeConverter, context);
}

//
@@ -794,6 +878,12 @@ public:
mod->setAttr(AttrTargetName,
StringAttr::get(context, this->target.getValue()));

+ // Only transform sparse dot op with undefined layout.
+ target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
+ [](triton::gpu::SparseDotOp op) {
+ return op.getAMeta().getType().getEncoding() != nullptr;
+ });
+
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 6e7868b14..035659a60 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
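
The bulk of the deleted sparse_dot.patch is TritonSparseDotPattern, which picks the result's blocked layout from a simple elements-per-thread heuristic. A self-contained sketch of just that heuristic (the helper name pickSizePerThread and the use of std::vector are ours; the pattern itself used SmallVector<unsigned> inside matchAndRewrite):

#include <cstdint>
#include <vector>

// Default to a 1x1 tile per thread; widen the two innermost dimensions to
// 2x2 once there are at least 4 elements per thread, and to 4x4 at 16+.
std::vector<unsigned> pickSizePerThread(const std::vector<int64_t> &shape,
                                        int numWarps, int threadsPerWarp) {
  int64_t numElements = 1;
  for (int64_t dim : shape) numElements *= dim;
  int64_t elementsPerThread = numElements / (numWarps * threadsPerWarp);

  size_t rank = shape.size();
  std::vector<unsigned> sizePerThread(rank, 1);
  if (elementsPerThread >= 4) {
    sizePerThread[rank - 1] = 2;
    sizePerThread[rank - 2] = 2;
  }
  if (elementsPerThread >= 16) {
    sizePerThread[rank - 1] = 4;
    sizePerThread[rank - 2] = 4;
  }
  return sizePerThread;
}

For example, a 64x64 result with 4 warps of 32 threads yields 32 elements per thread, so each thread owns a 4x4 tile; the result order is simply the reversed dimension order (innermost fastest).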
Empty file.
25 changes: 0 additions & 25 deletions third_party/xla/third_party/triton/temporary/sparsity_layout.patch
@@ -15,31 +15,6 @@
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return std::nullopt;
==== triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp#36 - /google/src/cloud/csigg/triton_sparse/triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp ====
# action=edit type=text
--- triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 2024-05-14 06:33:36.000000000 -0700
+++ triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp 2024-06-04 04:34:23.000000000 -0700
@@ -636,7 +636,6 @@
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
- patterns.insert<TritonSparseDotPattern>(typeConverter, context);
}

//
@@ -878,12 +877,6 @@
mod->setAttr(AttrTargetName,
StringAttr::get(context, this->target.getValue()));

- // Only transform sparse dot op with undefined layout.
- target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
- [](triton::gpu::SparseDotOp op) {
- return op.getAMeta().getType().getEncoding() != nullptr;
- });
-
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

==== triton/lib/Dialect/TritonGPU/IR/Dialect.cpp#51 - /google/src/cloud/csigg/triton_sparse/triton/lib/Dialect/TritonGPU/IR/Dialect.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/IR/Dialect.cpp 2024-06-07 05:28:31.000000000 -0700
115 changes: 0 additions & 115 deletions third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch
@@ -53,121 +53,6 @@ index a87e1c44a..456a4f224 100644
+}
+
#endif
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index 4aa2712ec..16a6253d7 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -279,6 +279,89 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
};

+struct TritonSparseDotPattern
+ : public OpConversionPattern<triton::gpu::SparseDotOp> {
+ using OpConversionPattern<triton::gpu::SparseDotOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ triton::gpu::SparseDotOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ RankedTensorType origType = cast<RankedTensorType>(op.getType());
+ auto origShape = origType.getShape();
+ auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
+ int numWarps = typeConverter->getNumWarps();
+ int threadsPerWarp = typeConverter->getThreadsPerWarp();
+ int numCTAs = typeConverter->getNumCTAs();
+
+ auto rank = origShape.size();
+ auto numElements = product<int64_t>(origShape);
+ SmallVector<unsigned> retSizePerThread(rank, 1);
+ if (numElements / (numWarps * threadsPerWarp) >= 4) {
+ retSizePerThread[rank - 1] = 2;
+ retSizePerThread[rank - 2] = 2;
+ }
+ if (numElements / (numWarps * threadsPerWarp) >= 16) {
+ retSizePerThread[rank - 1] = 4;
+ retSizePerThread[rank - 2] = 4;
+ }
+ SmallVector<unsigned> retOrder(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ retOrder[i] = rank - 1 - i;
+ Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
+ getContext(), origShape, retSizePerThread, retOrder, numWarps,
+ threadsPerWarp, numCTAs);
+ RankedTensorType retType =
+ RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
+
+ // a & b must be of smem layout
+ auto aType = cast<RankedTensorType>(adaptor.getA().getType());
+ auto bType = cast<RankedTensorType>(adaptor.getB().getType());
+ Type aEltType = aType.getElementType();
+ Type bEltType = bType.getElementType();
+ Attribute aEncoding = aType.getEncoding();
+ Attribute bEncoding = bType.getEncoding();
+ if (!aEncoding || !bEncoding)
+ return failure();
+ Value a = adaptor.getA();
+ Value b = adaptor.getB();
+ Value c = adaptor.getC();
+ if (!isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
+ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+ getContext(), 0, dEncoding, aEltType);
+ auto dstType =
+ RankedTensorType::get(aType.getShape(), aEltType, encoding);
+ a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+ }
+ if (!isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
+ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+ getContext(), 1, dEncoding, bEltType);
+ auto dstType =
+ RankedTensorType::get(bType.getShape(), bEltType, encoding);
+ b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+ }
+ c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
+
+ // aMeta must be of smem layout
+ auto aMetaType = cast<RankedTensorType>(adaptor.getAMeta().getType());
+ Attribute aMetaEncoding = aMetaType.getEncoding();
+ if (!aMetaEncoding) return failure();
+ Value aMeta = adaptor.getAMeta();
+ if (!isa<triton::gpu::SparseDotMetaEncodingAttr>(aMetaEncoding)) {
+ Attribute encoding =
+ triton::gpu::SparseDotMetaEncodingAttr::get(getContext(), dEncoding);
+ auto dstType = RankedTensorType::get(
+ aMetaType.getShape(), aMetaType.getElementType(), encoding);
+ aMeta = rewriter.create<triton::gpu::ConvertLayoutOp>(aMeta.getLoc(),
+ dstType, aMeta);
+ }
+
+ addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SparseDotOp>(
+ op, retType, a, b, c, aMeta),
+ adaptor.getAttributes());
+ return success();
+ }
+};
+
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
using OpConversionPattern::OpConversionPattern;

@@ -553,6 +636,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
+ patterns.insert<TritonSparseDotPattern>(typeConverter, context);
}

//
@@ -794,6 +878,12 @@ public:
mod->setAttr(AttrTargetName,
StringAttr::get(context, this->target.getValue()));

+ // Only transform sparse dot op with undefined layout.
+ target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
+ [](triton::gpu::SparseDotOp op) {
+ return op.getAMeta().getType().getEncoding() != nullptr;
+ });
+
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 6e7868b14..035659a60 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
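
The duplicated xla copy above ends with the same TritonSparseDotPattern; its other half wraps each dot operand in the encoding its position requires, inserting a ConvertLayoutOp where one is missing. A sketch of that step, assuming the MLIR/Triton headers from the patch are in scope (the helper convertOperandLayout is ours):

// Cast `operand` (operand index `opIdx`: 0 for A, 1 for B) to a tensor type
// carrying the DotOperandEncodingAttr derived from the result encoding
// `dEncoding`, unless it already has one.
Value convertOperandLayout(ConversionPatternRewriter &rewriter, Value operand,
                           unsigned opIdx, Attribute dEncoding) {
  auto type = cast<RankedTensorType>(operand.getType());
  if (isa<triton::gpu::DotOperandEncodingAttr>(type.getEncoding()))
    return operand;  // already in the layout SparseDotOp expects
  Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
      operand.getContext(), opIdx, dEncoding, type.getElementType());
  auto dstType =
      RankedTensorType::get(type.getShape(), type.getElementType(), encoding);
  return rewriter.create<triton::gpu::ConvertLayoutOp>(operand.getLoc(),
                                                       dstType, operand);
}

The metadata operand gets the same treatment with SparseDotMetaEncodingAttr, and the accumulator is converted directly to the result type.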
Empty file.
