Remove XlaOutlineEntryFunctions pass
PiperOrigin-RevId: 623407968
jcai19 authored and tensorflower-gardener committed Apr 10, 2024
1 parent 4a5bfc0 commit eebfd2e
Showing 7 changed files with 1 addition and 306 deletions.

This file was deleted.

@@ -189,8 +189,6 @@ void AddNonReplicatedBridgeClusteringPipelinePasses(OpPassManager& pm) {
// inference.
pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
pm.addPass(mlir::TF::CreateTFShapeInferencePass());
-pm.addPass(
-    tensorflow::tf2xla::internal::CreateXlaOutlineEntryFunctionsPass());
// Encapsulate PartitionedCall ops within a cluster so that the composite
// resource ops can be decomposed.
pm.addPass(tensorflow::tf2xla::internal::CreateXlaClusterFormationPass());
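After this deletion, the affected region of the pipeline builder reduces to the sketch below. This is a reconstruction from the hunk above rather than the full function; the helper name and include paths are assumptions, not part of the change.

```c++
// Sketch of the pipeline region touched by this commit, with the
// XlaOutlineEntryFunctions pass no longer inserted between shape inference
// and cluster formation. Helper name and include paths are assumed.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"

// Hypothetical excerpt, not the full AddNonReplicatedBridgeClusteringPipelinePasses.
void AddClusteringExcerpt(mlir::OpPassManager& pm) {
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // Encapsulate PartitionedCall ops within a cluster so that the composite
  // resource ops can be decomposed.
  pm.addPass(tensorflow::tf2xla::internal::CreateXlaClusterFormationPass());
}
```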
@@ -35,7 +35,7 @@ TEST(ClusteringBridgePassesTest, AddsNonTPUBridgePasses) {
OpPassManager pass_manager;
AddNonReplicatedBridgeClusteringPipelinePasses(pass_manager);

-EXPECT_EQ(pass_manager.size(), 16);
+EXPECT_EQ(pass_manager.size(), 15);
}

}; // namespace internal
31 changes: 0 additions & 31 deletions tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD
@@ -33,7 +33,6 @@ cc_library(
":verify_clustering_pass",
":xla_broadcast",
":xla_cluster_formation",
":xla_outline_entry_functions",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -289,36 +288,6 @@ cc_library(
],
)

cc_library(
name = "xla_outline_entry_functions",
srcs = ["xla_outline_entry_functions.cc"],
textual_hdrs = [
"clustering_passes.h.inc",
],
deps = [
":clustering_passes_inc_gen",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
"//tensorflow/compiler/mlir/tensorflow:call_graph_util",
"//tensorflow/compiler/mlir/tensorflow:cluster_util",
"//tensorflow/compiler/mlir/tensorflow:string_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
)

cc_library(
name = "mark_ops_for_outside_compilation",
srcs = ["mark_ops_for_outside_compilation.cc"],
@@ -46,11 +46,6 @@ CreateExtractOutsideCompilationPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateXlaClusterFormationPass();

-// Creates a pass that rewrites entry functions with `_xla_compile_device_type`
-// into a `tf.StatefulPartitionedCall` to the original function.
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-CreateXlaOutlineEntryFunctionsPass();

// Creates a pass that marks unsupported ops in device cluster for outside
// compilation.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
@@ -253,56 +253,6 @@ def XlaClusterFormationPass : Pass<"tf-xla-cluster-formation", "ModuleOp"> {
let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"];
}

def XlaOutlineEntryFunctionsPass : Pass<"tf-xla-outline-entry-functions", "ModuleOp"> {
let summary = "Outline the body of an entry function into a call to the "
"original function body";
let description = [{
This pass adds support for top-level function with
`_xla_compile_device_type` attribute in MLIR generic pipeline.
It renames such a function, and creates a new function taking the original
name with a `tf.StatefulPartitionedCall` to the original function. It
allows the MLIR generic pipeline to handle such functions the same way it
handles other partitioned calls with the attribute.

For example, the following code

```mlir
func.func @main(%arg0: tensor<i32>) -> tensor<i32> attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true, tf.entry_function = {}} {
%0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", _xla_compile_device_type = "CPU", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.Add"(%0, %1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
func.return %2 : tensor<i32>
}

func.func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0 : tensor<i32>
}
```

will be replaced as

```mlir
func.func private @main_outlined(%arg0: tensor<i32>) -> tensor<i32> attributes {_xla_compile_device_type = "CPU", allow_soft_placement = true} {
%0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> tensor<i32>
%cst = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Add"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}

func.func @main(%arg0: tensor<i32>) -> tensor<i32> attributes {_xla_compile_device_type = "CPU", tf.entry_function = {}} {
%0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", allow_soft_placement = true, config = "", config_proto = "", executor_type = "", f = @main_outlined} : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}

func.func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
func.return %arg0 : tensor<i32>
}
```
}];
let constructor = "tensorflow::tf2xla::internal::CreateXlaOutlineEntryFunctionsPass()";
let dependentDialects = ["mlir::tf_device::TensorFlowDeviceDialect"];
}

def MarkOpsForOutsideCompilationPass : Pass<"tf-mark-ops-for-outside-compilation", "ModuleOp"> {
let summary = "Marks ops in device cluster for outside compilation if they are unsupported on device.";


This file was deleted.
