Integrate StableHLO at openxla/stablehlo@c44d9af8
PiperOrigin-RevId: 638559828
Michael Levesque-Dion authored and tensorflower-gardener committed May 30, 2024
1 parent 374e4d7 commit bc42c0c
Showing 26 changed files with 1,739 additions and 675 deletions.
16 changes: 13 additions & 3 deletions tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
@@ -1449,7 +1449,8 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern<TF::GatherNdOp> {

auto dims_attr = GatherDimensionNumbersAttr::get(
rewriter.getContext(), offset_dims, collapsed_slice_dims,
- start_index_map, index_vector_dim);
+ /*operandBatchingDims=*/{},
+ /*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim);
// TODO(disc): Remove this if-statement once fold and canonicalization is
// implemented.
if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) {
@@ -1956,7 +1957,9 @@ class ConvertMatrixDiagPartV3Op
auto dims_attr = GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/llvm::to_vector<4>(llvm::seq<int64_t>(0, num_dims - 2)),
- /*collapsedSliceDims=*/collapsed_dims, start_index_map,
+ /*collapsedSliceDims=*/collapsed_dims,
+ /*operandBatchingDims=*/{},
+ /*startIndicesBatchingDims=*/{}, start_index_map,
/*indexVectorDim=*/0);
Value gather = rewriter.create<mhlo::GatherOp>(
loc, op.getInput(), start_indices, dims_attr,
@@ -4373,6 +4376,8 @@ class ConvertTensorScatterOp : public OpRewritePattern<OpTy> {
llvm::to_vector<4>(
llvm::seq<int64_t>(updates_rank - window_dims, updates_rank)),
llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
+ /*inputBatchingDims=*/{},
+ /*scatterIndicesBatchingDims=*/{},
llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
indices_rank - 1);

@@ -5614,7 +5619,10 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
auto dims_attr = ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
llvm::to_vector<4>(llvm::seq<int64_t>(segment_ids_rank, data_rank)),
- inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim);
+ inserted_window_dims,
+ /*inputBatchingDims=*/{},
+ /*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims,
+ index_vector_dim);

auto scatter = rewriter.create<ScatterOp>(
op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)),
@@ -5836,6 +5844,8 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
rewriter.getContext(),
/*offsetDims=*/llvm::to_vector<4>(llvm::seq<int64_t>(1, input_rank)),
/*collapsedSliceDims=*/{0},
+ /*operandBatchingDims=*/{},
+ /*startIndicesBatchingDims=*/{},
/*startIndexMap=*/{0},
/*indexVectorDim=*/1);

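Note on the legalize_tf.cc hunks above: the integrated StableHLO revision extends the gather and scatter dimension-numbers attributes with two batching-dimension lists, and the TF2XLA bridge passes empty lists for both so the existing lowerings keep their previous semantics. The fragment below is only an illustrative sketch of the new call shape: the getter names and argument order come from the hunks above, while the concrete dimension values are placeholders, not code from this commit.

// Sketch: building the attributes with the two new (empty) batching lists,
// as it would appear inside one of the rewrite patterns shown above, where
// `rewriter` is the PatternRewriter passed to matchAndRewrite.
auto gather_dims = GatherDimensionNumbersAttr::get(
    rewriter.getContext(),
    /*offsetDims=*/{1, 2},
    /*collapsedSliceDims=*/{0},
    /*operandBatchingDims=*/{},        // new parameter, left empty here
    /*startIndicesBatchingDims=*/{},   // new parameter, left empty here
    /*startIndexMap=*/{0},
    /*indexVectorDim=*/1);

auto scatter_dims = ScatterDimensionNumbersAttr::get(
    rewriter.getContext(),
    /*updateWindowDims=*/{1},
    /*insertedWindowDims=*/{0},
    /*inputBatchingDims=*/{},          // new parameter, left empty here
    /*scatterIndicesBatchingDims=*/{}, // new parameter, left empty here
    /*scatterDimsToOperandDims=*/{0},
    /*indexVectorDim=*/1);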
126 changes: 57 additions & 69 deletions third_party/stablehlo/temporary.patch
@@ -175,53 +175,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists
add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
@@ -155,7 +155,7 @@

// CHECK-LABEL: @maximum_f64
func.func @maximum_f64(%arg0 : tensor<10xf64>, %arg1 : tensor<10xf64>) -> tensor<10xf64> {
- // CHECK: stablehlo.maximum
+ // CHECK: tosa.maximum
%0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
@@ -9,8 +9,7 @@

// CHECK-LABEL: @constant_f64
func.func @constant_f64() -> tensor<10xf64> {
- // TOSA does not support 64-bit types, so this should not legalize.
- // CHECK: stablehlo.constant
+ // CHECK: tosa.const
%0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp
--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp
+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp
@@ -305,8 +305,7 @@
bool isCommutativeNoRegionMatchingDialect(OperationName innerOp,
StringRef reduceOpDialect) {
auto innerOpDialect = innerOp.getDialect();
- return innerOpDialect &&
- innerOpDialect->getNamespace().equals(reduceOpDialect) &&
+ return innerOpDialect && innerOpDialect->getNamespace() == reduceOpDialect &&
innerOp.hasTrait<mlir::OpTrait::NOperands<2>::Impl>() &&
innerOp.hasTrait<mlir::OpTrait::OneResult>() &&
(innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
@@ -359,7 +358,7 @@
// Check E5.
LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E5\n");
auto retOp = block.getTerminator();
- if (!retOp->getName().stripDialect().equals("return")) return false;
+ if (retOp->getName().stripDialect() != "return") return false;

return llvm::equal(innerOp.getResults(), retOp->getOperands());
}
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
@@ -2440,7 +2393,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy
diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
@@ -0,0 +1,170 @@
@@ -0,0 +1,171 @@
+/* Copyright 2022 The StableHLO Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
@@ -2594,6 +2547,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+
+ RewritePatternSet patterns(&getContext());
+ populateStablehloRefineShapesPatterns(&patterns, &getContext());
+ populateStablehloShapeFolderPatterns(&patterns, &getContext());
+ patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
+ patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
+ patterns.add<RefineDynamicTopKOpPattern>(&getContext());
@@ -2611,18 +2565,64 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py
--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py
+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py
@@ -115,14 +115,17 @@
operand_batching_dims=[6, 7],
start_indices_batching_dims=[8, 9],
start_index_map=[10],
- index_vector_dim=11)
- assert attr is not None
- assert str(attr) == ("#stablehlo.gather<offset_dims = [1, 2], "
- "collapsed_slice_dims = [3, 4, 5], "
- "operand_batching_dims = [6, 7], "
- "start_indices_batching_dims = [8, 9], "
- "start_index_map = [10], "
- "index_vector_dim = 11>")
+ index_vector_dim=11,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.gather<offset_dims = [1, 2], "
+ "collapsed_slice_dims = [3, 4, 5], "
+ "operand_batching_dims = [6, 7], "
+ "start_indices_batching_dims = [8, 9], "
+ "start_index_map = [10], "
+ "index_vector_dim = 11>"
+ )
assert attr.offset_dims == [1, 2]
assert attr.collapsed_slice_dims == [3, 4, 5]
assert attr.operand_batching_dims == [6, 7]
@@ -178,14 +181,17 @@
input_batching_dims=[6, 7],
scatter_indices_batching_dims=[8, 9],
scattered_dims_to_operand_dims=[10, 11],
- index_vector_dim=12)
- assert attr is not None
- assert str(attr) == ("#stablehlo.scatter<update_window_dims = [1, 2, 3], "
- "inserted_window_dims = [4, 5], "
- "input_batching_dims = [6, 7], "
- "scatter_indices_batching_dims = [8, 9], "
- "scatter_dims_to_operand_dims = [10, 11], "
- "index_vector_dim = 12>")
+ index_vector_dim=12,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.scatter<update_window_dims = [1, 2, 3], "
+ "inserted_window_dims = [4, 5], "
+ "input_batching_dims = [6, 7], "
+ "scatter_indices_batching_dims = [8, 9], "
+ "scatter_dims_to_operand_dims = [10, 11], "
+ "index_vector_dim = 12>"
+ )
assert attr.update_window_dims == [1, 2, 3]
assert attr.inserted_window_dims == [4, 5]
assert attr.input_batching_dims == [6, 7]
diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/reference/Api.cpp
--- stablehlo/stablehlo/reference/Api.cpp
+++ stablehlo/stablehlo/reference/Api.cpp
@@ -51,7 +51,7 @@
auto functions = module.getOps<func::FuncOp>();

for (auto funcOp : functions)
- if (funcOp.getSymName().equals(mainName)) return funcOp;
+ if (funcOp.getSymName() == mainName) return funcOp;

bool isSingleFunction =
std::distance(functions.begin(), functions.end()) == 1;
@@ -68,7 +68,7 @@
class DefaultInterpreterFallback : public InterpreterFallback {
public:
@@ -2632,16 +2632,4 @@ diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/referen

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -764,7 +764,7 @@

// Clean up operand buffers after refinement
// Must do in this pattern to avoid needing multiple refinement iterations
- if (op.getCallTargetName().equals(kCustomCallOperandBarrierTarget)) {
+ if (op.getCallTargetName() == kCustomCallOperandBarrierTarget) {
Value operand = op.getOperand(0);
if (operand.getType() == op.getResult(0).getType()) {
op.replaceAllUsesWith(ValueRange(operand));
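
Note on the temporary.patch change above: besides dropping the TOSA test and AssemblyFormat.cpp hunks, the refreshed patch adds one line to the experimental StablehloRefineShapes pass so that it also registers the shape-folder patterns. A condensed sketch of the resulting registration order — the function and pattern names are taken from the patch, while the surrounding pass boilerplate and the final greedy-driver call are assumed and elided:

// Inside the experimental refine-shapes pass body (boilerplate elided).
RewritePatternSet patterns(&getContext());
populateStablehloRefineShapesPatterns(&patterns, &getContext());
populateStablehloShapeFolderPatterns(&patterns, &getContext());  // line added by this integrate
patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
patterns.add<RefineDynamicTopKOpPattern>(&getContext());
// The patterns are then applied with the usual greedy rewrite driver.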

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "797bee217e1a041e9aac22cad4db207274596d94"
STABLEHLO_SHA256 = "e5619033e131ea2eeb9eab8c8e362f3ba12e111c6b4a15dac789ca216ff22c58"
STABLEHLO_COMMIT = "c44d9af8d4879adccf1054cb61a53377ae5898cb"
STABLEHLO_SHA256 = "a8f5d4df0256e9d1c7b35fead77c31b9d8d985a0909eb198374faa9f7de15e94"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
126 changes: 57 additions & 69 deletions third_party/xla/third_party/stablehlo/temporary.patch
@@ -175,53 +175,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists
add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/binary.mlir
@@ -155,7 +155,7 @@

// CHECK-LABEL: @maximum_f64
func.func @maximum_f64(%arg0 : tensor<10xf64>, %arg1 : tensor<10xf64>) -> tensor<10xf64> {
- // CHECK: stablehlo.maximum
+ // CHECK: tosa.maximum
%0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
--- stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
+++ stablehlo/stablehlo/conversions/tosa/tests/nullary.mlir
@@ -9,8 +9,7 @@

// CHECK-LABEL: @constant_f64
func.func @constant_f64() -> tensor<10xf64> {
- // TOSA does not support 64-bit types, so this should not legalize.
- // CHECK: stablehlo.constant
+ // CHECK: tosa.const
%0 = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
return %0 : tensor<10xf64>
}
diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp
--- stablehlo/stablehlo/dialect/AssemblyFormat.cpp
+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp
@@ -305,8 +305,7 @@
bool isCommutativeNoRegionMatchingDialect(OperationName innerOp,
StringRef reduceOpDialect) {
auto innerOpDialect = innerOp.getDialect();
- return innerOpDialect &&
- innerOpDialect->getNamespace().equals(reduceOpDialect) &&
+ return innerOpDialect && innerOpDialect->getNamespace() == reduceOpDialect &&
innerOp.hasTrait<mlir::OpTrait::NOperands<2>::Impl>() &&
innerOp.hasTrait<mlir::OpTrait::OneResult>() &&
(innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
@@ -359,7 +358,7 @@
// Check E5.
LLVM_DEBUG(llvm::dbgs() << "Checking ReduceOp compact print E5\n");
auto retOp = block.getTerminator();
- if (!retOp->getName().stripDialect().equals("return")) return false;
+ if (retOp->getName().stripDialect() != "return") return false;

return llvm::equal(innerOp.getResults(), retOp->getOperands());
}
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
@@ -2440,7 +2393,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloCanonicalizeDy
diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.cpp
@@ -0,0 +1,170 @@
@@ -0,0 +1,171 @@
+/* Copyright 2022 The StableHLO Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
@@ -2594,6 +2547,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+
+ RewritePatternSet patterns(&getContext());
+ populateStablehloRefineShapesPatterns(&patterns, &getContext());
+ populateStablehloShapeFolderPatterns(&patterns, &getContext());
+ patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
+ patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
+ patterns.add<RefineDynamicTopKOpPattern>(&getContext());
@@ -2611,18 +2565,64 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py
--- stablehlo/stablehlo/integrations/python/tests/stablehlo.py
+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py
@@ -115,14 +115,17 @@
operand_batching_dims=[6, 7],
start_indices_batching_dims=[8, 9],
start_index_map=[10],
- index_vector_dim=11)
- assert attr is not None
- assert str(attr) == ("#stablehlo.gather<offset_dims = [1, 2], "
- "collapsed_slice_dims = [3, 4, 5], "
- "operand_batching_dims = [6, 7], "
- "start_indices_batching_dims = [8, 9], "
- "start_index_map = [10], "
- "index_vector_dim = 11>")
+ index_vector_dim=11,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.gather<offset_dims = [1, 2], "
+ "collapsed_slice_dims = [3, 4, 5], "
+ "operand_batching_dims = [6, 7], "
+ "start_indices_batching_dims = [8, 9], "
+ "start_index_map = [10], "
+ "index_vector_dim = 11>"
+ )
assert attr.offset_dims == [1, 2]
assert attr.collapsed_slice_dims == [3, 4, 5]
assert attr.operand_batching_dims == [6, 7]
@@ -178,14 +181,17 @@
input_batching_dims=[6, 7],
scatter_indices_batching_dims=[8, 9],
scattered_dims_to_operand_dims=[10, 11],
- index_vector_dim=12)
- assert attr is not None
- assert str(attr) == ("#stablehlo.scatter<update_window_dims = [1, 2, 3], "
- "inserted_window_dims = [4, 5], "
- "input_batching_dims = [6, 7], "
- "scatter_indices_batching_dims = [8, 9], "
- "scatter_dims_to_operand_dims = [10, 11], "
- "index_vector_dim = 12>")
+ index_vector_dim=12,
+ )
+ assert attr is not None
+ assert str(attr) == (
+ "#stablehlo.scatter<update_window_dims = [1, 2, 3], "
+ "inserted_window_dims = [4, 5], "
+ "input_batching_dims = [6, 7], "
+ "scatter_indices_batching_dims = [8, 9], "
+ "scatter_dims_to_operand_dims = [10, 11], "
+ "index_vector_dim = 12>"
+ )
assert attr.update_window_dims == [1, 2, 3]
assert attr.inserted_window_dims == [4, 5]
assert attr.input_batching_dims == [6, 7]
diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/reference/Api.cpp
--- stablehlo/stablehlo/reference/Api.cpp
+++ stablehlo/stablehlo/reference/Api.cpp
@@ -51,7 +51,7 @@
auto functions = module.getOps<func::FuncOp>();

for (auto funcOp : functions)
- if (funcOp.getSymName().equals(mainName)) return funcOp;
+ if (funcOp.getSymName() == mainName) return funcOp;

bool isSingleFunction =
std::distance(functions.begin(), functions.end()) == 1;
@@ -68,7 +68,7 @@
class DefaultInterpreterFallback : public InterpreterFallback {
public:
@@ -2632,16 +2632,4 @@ diff --ruN a/stablehlo/stablehlo/reference/Api.cpp b/stablehlo/stablehlo/referen

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -764,7 +764,7 @@

// Clean up operand buffers after refinement
// Must do in this pattern to avoid needing multiple refinement iterations
- if (op.getCallTargetName().equals(kCustomCallOperandBarrierTarget)) {
+ if (op.getCallTargetName() == kCustomCallOperandBarrierTarget) {
Value operand = op.getOperand(0);
if (operand.getType() == op.getResult(0).getType()) {
op.replaceAllUsesWith(ValueRange(operand));

4 changes: 2 additions & 2 deletions third_party/xla/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "797bee217e1a041e9aac22cad4db207274596d94"
STABLEHLO_SHA256 = "e5619033e131ea2eeb9eab8c8e362f3ba12e111c6b4a15dac789ca216ff22c58"
STABLEHLO_COMMIT = "c44d9af8d4879adccf1054cb61a53377ae5898cb"
STABLEHLO_SHA256 = "a8f5d4df0256e9d1c7b35fead77c31b9d8d985a0909eb198374faa9f7de15e94"
# LINT.ThenChange(Google-internal path)

tf_http_archive(