Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632778493
  • Loading branch information
tensorflower-gardener committed May 11, 2024
1 parent 69e94a9 commit d325637
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 24 deletions.
56 changes: 45 additions & 11 deletions tensorflow/compiler/tf2xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -175,26 +175,29 @@ cc_library(
":tf2xla_util",
":xla_compiler",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:import_model",
"//tensorflow/compiler/mlir/tensorflow:import_utils",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow/transforms:bridge",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes",
"//tensorflow/compiler/mlir/tf2xla:compile_mlir_util",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:framework_types_hdr",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:device_set",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:ShapeDialect",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/client",
"@local_xla//xla/client:xla_computation",
"@local_xla//xla/mlir_hlo",
"@local_xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
],
)

Expand Down Expand Up @@ -781,11 +784,12 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:status",
"@local_xla//xla:literal",
"@local_xla//xla:shape_util",
"@local_xla//xla:status_macros",
"@local_xla//xla:xla_data_proto_cc",
],
)
Expand Down Expand Up @@ -1008,9 +1012,14 @@ tf_cc_test(
deps = [
":common",
"//tensorflow/core:framework",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
"@local_xla//xla:literal",
"@local_xla//xla:literal_util",
],
Expand Down Expand Up @@ -1047,9 +1056,19 @@ cc_library(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla/kernels:light_outside_compilation",
"//tensorflow/core:framework",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/stream_executor:device_memory",
"@local_xla//xla/stream_executor:stream_executor_install_hdrs_gather",
],
alwayslink = 1,
)
Expand Down Expand Up @@ -1172,6 +1191,7 @@ cc_library(
hdrs = ["mlir_bridge_pass.h"],
visibility = [":internal"],
deps = [
":tf2xla_defs",
":xla_op_registry",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
Expand All @@ -1190,13 +1210,15 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime:device_set",
"//tensorflow/core/protobuf:for_core_protos_cc",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
Expand All @@ -1213,8 +1235,8 @@ cc_library(
],
deps = [
":mlir_bridge_pass",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
"//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,
)
Expand Down Expand Up @@ -1435,11 +1457,23 @@ cc_library(
hdrs = ["mlir_xla_op_kernel.h"],
deps = [
":xla_compiler",
":xla_expression",
"//tensorflow/compiler/jit:xla_compile_util",
"//tensorflow/compiler/mlir/tf2xla/api/v1:compile_mlir_util_no_tf_dialect_passes",
"//tensorflow/compiler/mlir/utils:array_container_utils",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:resource_base",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
"@local_xla//xla/client:xla_builder",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,28 @@ limitations under the License.
#include <algorithm>
#include <string>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

// Sample kernels for the light outside compilation test.

Expand Down
14 changes: 14 additions & 0 deletions tensorflow/compiler/tf2xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,24 @@ limitations under the License.

#include "tensorflow/compiler/tf2xla/literal_util.h"

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tsl/platform/errors.h"

namespace tensorflow {

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/literal_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/compiler/tf2xla/literal_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@ limitations under the License.

#include "tensorflow/compiler/tf2xla/literal_util.h"

#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tsl/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace {
Expand Down
9 changes: 6 additions & 3 deletions tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ limitations under the License.
#include <string>

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
Expand All @@ -36,15 +36,18 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/device_name_utils.h"
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/compiler/tf2xla/mlir_bridge_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ limitations under the License.

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <memory>

#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"

namespace tensorflow {
Expand Down
32 changes: 23 additions & 9 deletions tensorflow/compiler/tf2xla/mlir_tf2xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,38 @@ limitations under the License.
#include <utility>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "absl/algorithm/container.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "xla/client/xla_computation.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_debug_info.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace tensorflow {

Expand Down
21 changes: 20 additions & 1 deletion tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,31 @@ limitations under the License.

#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/jit/xla_compile_util.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "xla/client/xla_builder.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/graph_debug_info.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_base.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace tensorflow {

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ cc_library(
"//learning/deepmind/tensorflow/queues:__pkg__",
"//learning/deepmind/tensorflow/sstable:__pkg__",
"//tensorflow/compiler/mlir/tools/kernel_gen:__pkg__",
"//tensorflow/compiler/tf2xla:__pkg__",
"//third_party/py/grain/_src/tensorflow/ops:__pkg__",
"//waymo/ml/compiler/frontend/kernels:__pkg__",
],
Expand Down

0 comments on commit d325637

Please sign in to comment.