
Commit

Update stablehlo.custom_call's caller_name attribute from `StringAttr` to `FlatSymbolRefAttr` after deserializing `XlaCallModule`

PiperOrigin-RevId: 534637021
tensorflower-gardener committed May 24, 2023
1 parent 37b3572 commit 75599f8
Showing 3 changed files with 23 additions and 2 deletions.
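
The change rewrites the host-callback name carried in `tf.backend_config` from a plain string attribute (`caller_name = "__tf_host_callback"`) into a flat symbol reference (`caller_name = @__tf_host_callback`), presumably so the callback reference is visible to MLIR's symbol-table machinery. As a minimal sketch of that attribute conversion using standard MLIR APIs (the helper name `SymbolizeCallerName` is illustrative only, not the pass's actual code):

#include "mlir/IR/BuiltinAttributes.h"  // mlir::StringAttr, mlir::FlatSymbolRefAttr

// Illustrative sketch: wrap a string attribute such as
//   caller_name = "__tf_host_callback"
// into a flat symbol reference such as
//   caller_name = @__tf_host_callback
mlir::FlatSymbolRefAttr SymbolizeCallerName(mlir::StringAttr caller_name) {
  // FlatSymbolRefAttr::get(StringAttr) reuses the same string as the symbol name.
  return mlir::FlatSymbolRefAttr::get(caller_name);
}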
@@ -34,7 +34,7 @@ module {

// CHECK-LABEL: func private @_stablehlo_main_0
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xi32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %[[ARG1:.*]]: tensor<*xi32>) -> (tensor<?xi32> {jax.result_info = ""}) attributes {_from_xla_call_module} {
// CHECK: stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]]) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__tf_host_callback"}} : (tensor<?xi32>, tensor<*xi32>) -> ()
// CHECK: stablehlo.custom_call @tf.call_tf_function(%[[ARG0]], %[[ARG1]]) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = @__tf_host_callback}} : (tensor<?xi32>, tensor<*xi32>) -> ()
// CHECK: return %arg0 : tensor<?xi32>
// CHECK: }
}
@@ -44,7 +44,7 @@ module {
// CHECK-SAME: {
// CHECK-SAME: api_version = 2 : i32,
// CHECK-SAME: has_side_effect = true,
// CHECK-SAME: tf.backend_config = {caller_name = "__tf_host_callback"}
// CHECK-SAME: tf.backend_config = {caller_name = @__tf_host_callback}
// CHECK-SAME: }
stablehlo.custom_call @tf.call_tf_function(%arg0, %arg1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__tf_host_callback"}} : (tensor<?xi32>, tensor<*xi32>) -> ()
// CHECK: call @_stablehlo__stablehlo_f_0
@@ -27,11 +27,13 @@ limitations under the License.
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep
#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h"
#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h"
#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep
@@ -130,6 +132,21 @@ void CopyStablehloModuleAttrs(ModuleOp stablehlo_module, XlaCallModuleOp op) {
stablehlo_module->getAttrDictionary());
}

// Update `caller_name` from `StringAttr` to `FlatSymbolRefAttr`.
LogicalResult SymbolizeCustomCallCallerName(ModuleOp module) {
  auto result = module.walk([&](stablehlo::CustomCallOp op) {
    auto name = GetTfHostCallbackName(op);
    if (failed(name)) {
      return WalkResult::interrupt();
    }
    if (*name != nullptr) {
      SetTfHostCallbackName(op, FlatSymbolRefAttr::get(op.getContext(), *name));
    }
    return WalkResult::advance();
  });
  return result.wasInterrupted() ? failure() : success();
}

LogicalResult DeserializeXlaCallModule(MLIRContext *context,
                                       SymbolTableCollection &symbol_tables,
                                       ModuleOp module, XlaCallModuleOp op) {
@@ -149,6 +166,10 @@ LogicalResult DeserializeXlaCallModule(MLIRContext *context,

  CopyFunctions(symbol_tables, stablehlo_module.get(), module);

  if (failed(SymbolizeCustomCallCallerName(module))) {
    return failure();
  }

  // Module is deserialized; we set it to an empty string instead of removing
  // it because it's a required attribute.
  op.setModule("");
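The walk above depends on `GetTfHostCallbackName` and `SetTfHostCallbackName` from `stablehlo_custom_call.h`, which this diff does not show. A hedged sketch of roughly how such helpers could behave, assuming `caller_name` lives inside the `tf.backend_config` dictionary attribute exercised by the tests above (the real helpers may differ):

// Sketch only -- assumed behavior, not the code in stablehlo_custom_call.h.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/StablehloOps.h"

constexpr char kTfBackendConfigAttr[] = "tf.backend_config";
constexpr char kCallerNameAttr[] = "caller_name";

// Returns the caller_name entry as a StringAttr, a null StringAttr if the op
// has no tf.backend_config / caller_name entry, or failure if caller_name is
// not a string (e.g. it was already converted to a symbol reference).
mlir::FailureOr<mlir::StringAttr> GetTfHostCallbackNameSketch(
    mlir::stablehlo::CustomCallOp op) {
  auto config = op->getAttrOfType<mlir::DictionaryAttr>(kTfBackendConfigAttr);
  if (!config) return mlir::StringAttr();
  mlir::Attribute name = config.get(kCallerNameAttr);
  if (!name) return mlir::StringAttr();
  auto str = llvm::dyn_cast<mlir::StringAttr>(name);
  if (!str) return mlir::failure();
  return str;
}

// Replaces the caller_name entry with the given symbol reference, rebuilding
// the tf.backend_config dictionary (DictionaryAttr is immutable).
void SetTfHostCallbackNameSketch(mlir::stablehlo::CustomCallOp op,
                                 mlir::FlatSymbolRefAttr callee) {
  auto config = op->getAttrOfType<mlir::DictionaryAttr>(kTfBackendConfigAttr);
  if (!config) return;
  llvm::SmallVector<mlir::NamedAttribute> entries =
      llvm::to_vector(config.getValue());
  for (mlir::NamedAttribute &entry : entries) {
    if (entry.getName().getValue() == kCallerNameAttr) entry.setValue(callee);
  }
  op->setAttr(kTfBackendConfigAttr,
              mlir::DictionaryAttr::get(op.getContext(), entries));
}

With helpers shaped like this, the walk in `SymbolizeCustomCallCallerName` only rewrites custom calls that still carry a string-valued `caller_name` and leaves every other `stablehlo.custom_call` untouched.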
