Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate StableHLO at openxla/stablehlo@c44d9af8 #68292

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,8 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern<TF::GatherNdOp> {

auto dims_attr = GatherDimensionNumbersAttr::get(
rewriter.getContext(), offset_dims, collapsed_slice_dims,
start_index_map, index_vector_dim);
/*operandBatchingDims=*/{},
/*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim);
// TODO(disc): Remove this if-statement once fold and canonicalization is
// implemented.
if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) {
Expand Down Expand Up @@ -1956,7 +1957,9 @@ class ConvertMatrixDiagPartV3Op
auto dims_attr = GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/llvm::to_vector<4>(llvm::seq<int64_t>(0, num_dims - 2)),
/*collapsedSliceDims=*/collapsed_dims, start_index_map,
/*collapsedSliceDims=*/collapsed_dims,
/*operandBatchingDims=*/{},
/*startIndicesBatchingDims=*/{}, start_index_map,
/*indexVectorDim=*/0);
Value gather = rewriter.create<mhlo::GatherOp>(
loc, op.getInput(), start_indices, dims_attr,
Expand Down Expand Up @@ -4373,6 +4376,8 @@ class ConvertTensorScatterOp : public OpRewritePattern<OpTy> {
llvm::to_vector<4>(
llvm::seq<int64_t>(updates_rank - window_dims, updates_rank)),
llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
/*inputBatchingDims=*/{},
/*scatterIndicesBatchingDims=*/{},
llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
indices_rank - 1);

Expand Down Expand Up @@ -5614,7 +5619,10 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
auto dims_attr = ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
llvm::to_vector<4>(llvm::seq<int64_t>(segment_ids_rank, data_rank)),
inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim);
inserted_window_dims,
/*inputBatchingDims=*/{},
/*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims,
index_vector_dim);

auto scatter = rewriter.create<ScatterOp>(
op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)),
Expand Down Expand Up @@ -5836,6 +5844,8 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
rewriter.getContext(),
/*offsetDims=*/llvm::to_vector<4>(llvm::seq<int64_t>(1, input_rank)),
/*collapsedSliceDims=*/{0},
/*operandBatchingDims=*/{},
/*startIndicesBatchingDims=*/{},
/*startIndexMap=*/{0},
/*indexVectorDim=*/1);

Expand Down
126 changes: 57 additions & 69 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -175,53 +175,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists
add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
@@ -155,7 +155,7 @@

// CHECK-LABEL: @maximum_f64
func.func @maximum_f64(%arg0 : tensor<10xf64>, %arg1 : tensor<10xf64>) -> tensor<10xf64> {
- // CHECK: stablehlo.maximum
+ // CHECK: tosa.maximum
%0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
@@ -9,8 +9,7 @@

// CHECK-LABEL: @constant_f64
func.func @constant_f64() -> tensor<10xf64> {
- // TOSA does not support 64-bit types, so this should not legalize.
- // CHECK: stablehlo.constant
+ // CHECK: tosa.const
%0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp
--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp
+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp
@@ -305,8 +305,7 @@
bool isCommutativeNoRegionMatchingDialect(OperationName innerOp,
StringRef reduceOpDialect) {
auto innerOpDialect = innerOp.getDialect();
- return innerOpDialect &&
- innerOpDialect->getNamespace().equals(reduceOpDialect) &&
+ return innerOpDialect && innerOpDialect->getNamespace() == reduceOpDialect &&
innerOp.hasTrait<mlir::OpTrait::NOperands<2>::Impl>() &&
innerOp.hasTrait<mlir::OpTrait::OneResult>() &&
(innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
@@ -359,7 +358,7 @@
// Check E5.
LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E5\n");
auto retOp = block.getTerminator();
- if (!retOp->getName().stripDialect().equals("return")) return false;
+ if (retOp->getName().stripDialect() != "return") return false;

return llvm::equal(innerOp.getResults(), retOp->getOperands());
}
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
Expand Down Expand Up @@ -2440,7 +2393,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy
diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
@@ -0,0 +1,170 @@
@@ -0,0 +1,171 @@
+/* Copyright 2022 The StableHLO Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -2594,6 +2547,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+
+ RewritePatternSet patterns(&getContext());
+ populateStablehloRefineShapesPatterns(&patterns, &getContext());
+ populateStablehloShapeFolderPatterns(&patterns, &getContext());
+ patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
+ patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
+ patterns.add<RefineDynamicTopKOpPattern>(&getContext());
Expand All @@ -2611,18 +2565,64 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py
--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py
+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py
@@ -115,14 +115,17 @@
operand_batching_dims=[6, 7],
start_indices_batching_dims=[8, 9],
start_index_map=[10],
- index_vector_dim=11)
- assert attr is not None
- assert str(attr) == ("#stablehlo.gather<offset_dims = [1, 2], "
- "collapsed_slice_dims = [3, 4, 5], "
- "operand_batching_dims = [6, 7], "
- "start_indices_batching_dims = [8, 9], "
- "start_index_map = [10], "
- "index_vector_dim = 11>")
+ index_vector_dim=11,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.gather<offset_dims = [1, 2], "
+ "collapsed_slice_dims = [3, 4, 5], "
+ "operand_batching_dims = [6, 7], "
+ "start_indices_batching_dims = [8, 9], "
+ "start_index_map = [10], "
+ "index_vector_dim = 11>"
+ )
assert attr.offset_dims == [1, 2]
assert attr.collapsed_slice_dims == [3, 4, 5]
assert attr.operand_batching_dims == [6, 7]
@@ -178,14 +181,17 @@
input_batching_dims=[6, 7],
scatter_indices_batching_dims=[8, 9],
scattered_dims_to_operand_dims=[10, 11],
- index_vector_dim=12)
- assert attr is not None
- assert str(attr) == ("#stablehlo.scatter<update_window_dims = [1, 2, 3], "
- "inserted_window_dims = [4, 5], "
- "input_batching_dims = [6, 7], "
- "scatter_indices_batching_dims = [8, 9], "
- "scatter_dims_to_operand_dims = [10, 11], "
- "index_vector_dim = 12>")
+ index_vector_dim=12,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.scatter<update_window_dims = [1, 2, 3], "
+ "inserted_window_dims = [4, 5], "
+ "input_batching_dims = [6, 7], "
+ "scatter_indices_batching_dims = [8, 9], "
+ "scatter_dims_to_operand_dims = [10, 11], "
+ "index_vector_dim = 12>"
+ )
assert attr.update_window_dims == [1, 2, 3]
assert attr.inserted_window_dims == [4, 5]
assert attr.input_batching_dims == [6, 7]
diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/reference/Api.cpp
--- stablehlo/stablehlo/reference/Api.cpp
+++ stablehlo/stablehlo/reference/Api.cpp
@@ -51,7 +51,7 @@
auto functions = module.getOps<func::FuncOp>();

for (auto funcOp : functions)
- if (funcOp.getSymName().equals(mainName)) return funcOp;
+ if (funcOp.getSymName() == mainName) return funcOp;

bool isSingleFunction =
std::distance(functions.begin(), functions.end()) == 1;
@@ -68,7 +68,7 @@
class DefaultInterpreterFallback : public InterpreterFallback {
public:
Expand All @@ -2632,16 +2632,4 @@ diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/referen

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -764,7 +764,7 @@

// Clean up operand buffers after refinement
// Must do in this pattern to avoid needing multiple refinement iterations
- if (op.getCallTargetName().equals(kCustomCallOperandBarrierTarget)) {
+ if (op.getCallTargetName() == kCustomCallOperandBarrierTarget) {
Value operand = op.getOperand(0);
if (operand.getType() == op.getResult(0).getType()) {
op.replaceAllUsesWith(ValueRange(operand));

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "797bee217e1a041e9aac22cad4db207274596d94"
STABLEHLO_SHA256 = "e5619033e131ea2eeb9eab8c8e362f3ba12e111c6b4a15dac789ca216ff22c58"
STABLEHLO_COMMIT = "c44d9af8d4879adccf1054cb61a53377ae5898cb"
STABLEHLO_SHA256 = "a8f5d4df0256e9d1c7b35fead77c31b9d8d985a0909eb198374faa9f7de15e94"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down
126 changes: 57 additions & 69 deletions third_party/xla/third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -175,53 +175,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists
add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
@@ -155,7 +155,7 @@

// CHECK-LABEL: @maximum_f64
func.func @maximum_f64(%arg0 : tensor<10xf64>, %arg1 : tensor<10xf64>) -> tensor<10xf64> {
- // CHECK: stablehlo.maximum
+ // CHECK: tosa.maximum
%0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
@@ -9,8 +9,7 @@

// CHECK-LABEL: @constant_f64
func.func @constant_f64() -> tensor<10xf64> {
- // TOSA does not support 64-bit types, so this should not legalize.
- // CHECK: stablehlo.constant
+ // CHECK: tosa.const
%0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp
--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp
+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp
@@ -305,8 +305,7 @@
bool isCommutativeNoRegionMatchingDialect(OperationName innerOp,
StringRef reduceOpDialect) {
auto innerOpDialect = innerOp.getDialect();
- return innerOpDialect &&
- innerOpDialect->getNamespace().equals(reduceOpDialect) &&
+ return innerOpDialect && innerOpDialect->getNamespace() == reduceOpDialect &&
innerOp.hasTrait<mlir::OpTrait::NOperands<2>::Impl>() &&
innerOp.hasTrait<mlir::OpTrait::OneResult>() &&
(innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
@@ -359,7 +358,7 @@
// Check E5.
LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E5\n");
auto retOp = block.getTerminator();
- if (!retOp->getName().stripDialect().equals("return")) return false;
+ if (retOp->getName().stripDialect() != "return") return false;

return llvm::equal(innerOp.getResults(), retOp->getOperands());
}
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
Expand Down Expand Up @@ -2440,7 +2393,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy
diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
@@ -0,0 +1,170 @@
@@ -0,0 +1,171 @@
+/* Copyright 2022 The StableHLO Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -2594,6 +2547,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+
+ RewritePatternSet patterns(&getContext());
+ populateStablehloRefineShapesPatterns(&patterns, &getContext());
+ populateStablehloShapeFolderPatterns(&patterns, &getContext());
+ patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
+ patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
+ patterns.add<RefineDynamicTopKOpPattern>(&getContext());
Expand All @@ -2611,18 +2565,64 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py
--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py
+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py
@@ -115,14 +115,17 @@
operand_batching_dims=[6, 7],
start_indices_batching_dims=[8, 9],
start_index_map=[10],
- index_vector_dim=11)
- assert attr is not None
- assert str(attr) == ("#stablehlo.gather<offset_dims = [1, 2], "
- "collapsed_slice_dims = [3, 4, 5], "
- "operand_batching_dims = [6, 7], "
- "start_indices_batching_dims = [8, 9], "
- "start_index_map = [10], "
- "index_vector_dim = 11>")
+ index_vector_dim=11,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.gather<offset_dims = [1, 2], "
+ "collapsed_slice_dims = [3, 4, 5], "
+ "operand_batching_dims = [6, 7], "
+ "start_indices_batching_dims = [8, 9], "
+ "start_index_map = [10], "
+ "index_vector_dim = 11>"
+ )
assert attr.offset_dims == [1, 2]
assert attr.collapsed_slice_dims == [3, 4, 5]
assert attr.operand_batching_dims == [6, 7]
@@ -178,14 +181,17 @@
input_batching_dims=[6, 7],
scatter_indices_batching_dims=[8, 9],
scattered_dims_to_operand_dims=[10, 11],
- index_vector_dim=12)
- assert attr is not None
- assert str(attr) == ("#stablehlo.scatter<update_window_dims = [1, 2, 3], "
- "inserted_window_dims = [4, 5], "
- "input_batching_dims = [6, 7], "
- "scatter_indices_batching_dims = [8, 9], "
- "scatter_dims_to_operand_dims = [10, 11], "
- "index_vector_dim = 12>")
+ index_vector_dim=12,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.scatter<update_window_dims = [1, 2, 3], "
+ "inserted_window_dims = [4, 5], "
+ "input_batching_dims = [6, 7], "
+ "scatter_indices_batching_dims = [8, 9], "
+ "scatter_dims_to_operand_dims = [10, 11], "
+ "index_vector_dim = 12>"
+ )
assert attr.update_window_dims == [1, 2, 3]
assert attr.inserted_window_dims == [4, 5]
assert attr.input_batching_dims == [6, 7]
diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/reference/Api.cpp
--- stablehlo/stablehlo/reference/Api.cpp
+++ stablehlo/stablehlo/reference/Api.cpp
@@ -51,7 +51,7 @@
auto functions = module.getOps<func::FuncOp>();

for (auto funcOp : functions)
- if (funcOp.getSymName().equals(mainName)) return funcOp;
+ if (funcOp.getSymName() == mainName) return funcOp;

bool isSingleFunction =
std::distance(functions.begin(), functions.end()) == 1;
@@ -68,7 +68,7 @@
class DefaultInterpreterFallback : public InterpreterFallback {
public:
Expand All @@ -2632,16 +2632,4 @@ diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/referen

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -764,7 +764,7 @@

// Clean up operand buffers after refinement
// Must do in this pattern to avoid needing multiple refinement iterations
- if (op.getCallTargetName().equals(kCustomCallOperandBarrierTarget)) {
+ if (op.getCallTargetName() == kCustomCallOperandBarrierTarget) {
Value operand = op.getOperand(0);
if (operand.getType() == op.getResult(0).getType()) {
op.replaceAllUsesWith(ValueRange(operand));

4 changes: 2 additions & 2 deletions third_party/xla/third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "797bee217e1a041e9aac22cad4db207274596d94"
STABLEHLO_SHA256 = "e5619033e131ea2eeb9eab8c8e362f3ba12e111c6b4a15dac789ca216ff22c58"
STABLEHLO_COMMIT = "c44d9af8d4879adccf1054cb61a53377ae5898cb"
STABLEHLO_SHA256 = "a8f5d4df0256e9d1c7b35fead77c31b9d8d985a0909eb198374faa9f7de15e94"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down
Loading
Loading