Commit

Migrate deprecated types to their replacements.
PiperOrigin-RevId: 631428582
tensorflower-gardener committed May 22, 2024
1 parent 74a9135 commit a977e1e
Showing 62 changed files with 672 additions and 309 deletions.
22 changes: 1 addition & 21 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -154,20 +154,6 @@ absl::Status ValidateSEPlatformRegistrationParams(
}
#undef TF_VALIDATE_NOT_NULL

// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
case SE_EVENT_ERROR:
return Event::Status::kError;
case SE_EVENT_PENDING:
return Event::Status::kPending;
case SE_EVENT_COMPLETE:
return Event::Status::kComplete;
default:
return Event::Status::kUnknown;
}
}

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
@@ -422,12 +408,6 @@ class CStreamExecutor : public StreamExecutor {
absl::Status s = StatusFromTF_Status(c_status.get());
return s;
}
Event::Status PollForEventStatus(Event* event) override {
SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
SE_EventStatus event_status =
stream_executor_->get_event_status(&device_, event_handle);
return SEEventStatusToEventStatus(event_status);
}
void DeallocateStream(Stream* stream) override {
static_cast<CStream*>(stream->implementation())->Destroy();
}
@@ -541,7 +521,7 @@ class CStreamExecutor : public StreamExecutor {
}

absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_, this);
auto c_event = std::make_unique<CEvent>(&device_, stream_executor_);
TF_RETURN_IF_ERROR(c_event->Create());
return std::move(c_event);
}
@@ -129,14 +129,28 @@ class CStream : public StreamInterface {

class CEvent : public Event {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor,
StreamExecutorInterface* executor_interface)
: Event(executor_interface),
device_(device),
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }

Event::Status PollForStatus() override {
SE_EventStatus event_status =
stream_executor_->get_event_status(device_, event_handle_);

switch (event_status) {
case SE_EVENT_ERROR:
return Event::Status::kError;
case SE_EVENT_PENDING:
return Event::Status::kPending;
case SE_EVENT_COMPLETE:
return Event::Status::kComplete;
default:
return Event::Status::kUnknown;
}
}

absl::Status Create() {
tensorflow::TF_StatusPtr c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
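With this change, event-status polling moves off the StreamExecutor interface (the PollForEventStatus override removed above) and onto the Event object itself as PollForStatus(). A minimal sketch of a polling call site after the migration; the WaitUntilComplete helper and the include paths are assumptions for illustration, not part of this commit:

#include "absl/status/status.h"
#include "xla/stream_executor/event.h"  // assumed include path for Event

// Spin until the event settles; kPending and kUnknown mean "not yet".
absl::Status WaitUntilComplete(stream_executor::Event* event) {
  while (true) {
    switch (event->PollForStatus()) {
      case stream_executor::Event::Status::kComplete:
        return absl::OkStatus();
      case stream_executor::Event::Status::kError:
        return absl::InternalError("event reported an error");
      default:
        break;  // kPending or kUnknown: keep polling
    }
  }
}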
@@ -495,9 +495,9 @@ func.func @decompose_resource_gather_op(%indices : tensor<?xi32>) -> tensor<*xi3
%resource = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf_type.resource<tensor<*xi32>>>

// CHECK-DAG: [[READVAR:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) <{batch_dims = 0 : i64}> : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
// CHECK: [[GATHER:%.+]] = "tf.GatherV2"([[READVAR]], [[INDEX]], [[ZERO]]) <{batch_dims = 0 : i64}> {_xla_outside_compilation = "0"} : (tensor<*xi32>, tensor<?xi32>, tensor<i64>) -> tensor<*xi32>
// CHECK: return [[GATHER]]
%1 = "tf.ResourceGather"(%resource, %indices) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<?xi32>) -> (tensor<*xi32>)
%1 = "tf.ResourceGather"(%resource, %indices) {_xla_outside_compilation = "0"} : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<?xi32>) -> (tensor<*xi32>)
tf_device.return %1 : tensor<*xi32>
}) : () -> (tensor<*xi32>)
func.return %0: tensor<*xi32>
@@ -55,6 +55,9 @@ def Clamp: NativeCodeCall<
" $0.getLoc(),"
" $2.getType(), $2, $1, $3)">;

def CopyXlaOutsideCompilationAttr: NativeCodeCallVoid<
"CopyXlaOutsideCompilationAttributesAdaptor($0, $1)">;

def DecomposeAssignAddVariableOp :
Pat<
(TF_AssignAddVariableOp:$src_op $resource, $value),
@@ -338,12 +341,12 @@ def DecomposeResourceApplyAdamNesterov :
def DecomposeResourceGather : Pat<
(TF_ResourceGatherOp:$old_result
$resource, $indices, $batch_dims, $validate_indices),
(TF_GatherV2Op
(TF_GatherV2Op:$dest
(CreateTFReadVariableOp $old_result, $old_result, $resource),
$indices,
(TF_ConstOp $batch_dims), // axis
$batch_dims
)>;
), [], [(CopyXlaOutsideCompilationAttr $old_result, $dest)]>;

// Pattern to decompose tf.ResourceScatterAdd into tf.ReadVariable,
// tf.TensorScatterAdd, and tf.AssignVariable.
15 changes: 15 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc
@@ -67,5 +67,20 @@ void CopyDeviceAndUnderscoredAttributesAdaptor(mlir::Operation *src,
mlir::Operation *dest) {
CopyDeviceAndUnderscoredAttributes(src, dest);
}

void CopyXlaOutsideCompilationAttributesAdaptor(mlir::OpResult src,
mlir::OpResult dest) {
CopyXlaOutsideCompilationAttributesAdaptor(src.getOwner(), dest.getOwner());
}

void CopyXlaOutsideCompilationAttributesAdaptor(mlir::Operation *src,
mlir::OpResult dest) {
CopyXlaOutsideCompilationAttributesAdaptor(src, dest.getOwner());
}

void CopyXlaOutsideCompilationAttributesAdaptor(mlir::Operation *src,
mlir::Operation *dest) {
CopyXlaOutsideCompilationAttributes(src, dest);
}
} // namespace TF
} // namespace mlir
9 changes: 9 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h
@@ -78,6 +78,15 @@ void CopyDeviceAndUnderscoredAttributesAdaptor(mlir::Operation *src,
mlir::OpResult dest);
void CopyDeviceAndUnderscoredAttributesAdaptor(mlir::Operation *src,
mlir::Operation *dest);

// Wrappers for CopyXlaOutsideCompilationAttributes
void CopyXlaOutsideCompilationAttributesAdaptor(mlir::OpResult src,
mlir::OpResult dest);
void CopyXlaOutsideCompilationAttributesAdaptor(mlir::Operation *src,
mlir::OpResult dest);
void CopyXlaOutsideCompilationAttributesAdaptor(mlir::Operation *src,
mlir::Operation *dest);

} // namespace TF
} // namespace mlir

9 changes: 9 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h
@@ -133,6 +133,7 @@ inline const char kMlirPh1BridgeCounterV1[] = "v1";
inline const char kMlirPh1BridgeCounterV2[] = "v2";
inline const char kMlirPh1BridgeCounterTpu[] = "tpu";
inline const char kMlirPh1BridgeCounterNonTpu[] = "cpu/gpu";
inline const char kXlaOutsideCompilation[] = "_xla_outside_compilation";

// Copies attributes that satisfy the given predicate from `from` to `to`.
template <typename Predicate>
@@ -148,6 +149,14 @@ inline void CopyUnderscoredAttributes(Operation *from, Operation *to) {
});
}

// Copies outside compilation attribute from `from` to `to`.
inline void CopyXlaOutsideCompilationAttributes(Operation *from,
Operation *to) {
CopyAttributes(from, to, [](const NamedAttribute &attr) {
return attr.getName().strref() == kXlaOutsideCompilationAttr;
});
}

// Copies attributes that are either `device` or whose name begins with an _
// from `from` to `to`.
// TODO(b/158769932): This should be a general feature instead post some policy
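The new helper reuses the predicate-based CopyAttributes idiom shown above for CopyUnderscoredAttributes: only the attribute named _xla_outside_compilation is transferred from the source op to the destination op. A standalone sketch of the same effect against the plain MLIR API (the free-function name is illustrative, not from the commit):

#include "mlir/IR/Operation.h"

// Copy just the outside-compilation marker, if present, from `from` to `to`.
void CopyOutsideCompilationMarker(mlir::Operation *from, mlir::Operation *to) {
  if (mlir::Attribute attr = from->getAttr("_xla_outside_compilation"))
    to->setAttr("_xla_outside_compilation", attr);
}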
10 changes: 7 additions & 3 deletions tensorflow/core/kernels/BUILD
@@ -3560,7 +3560,10 @@ cc_library(
tf_kernel_library(
name = "aggregate_ops",
prefix = "aggregate_ops",
deps = MATH_DEPS + [":variant_ops_util"],
deps = MATH_DEPS + [
":variant_ops_util",
"@com_google_absl//absl/status",
],
)

tf_kernel_library(
@@ -3648,7 +3651,7 @@ tf_kernel_library(
features = if_cuda(["-layering_check"]),
gpu_srcs = ["gpu_device_array.h"],
prefix = "bucketize_op",
deps = ARRAY_DEPS,
deps = ARRAY_DEPS + ["@com_google_absl//absl/status"],
)

tf_kernel_library(
@@ -4464,11 +4467,12 @@ tf_kernel_library(
prefix = "bincount_op",
deps = [
":fill_functor",
":gpu_prim_hdrs",
":sparse_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@eigen_archive//:eigen3",
],
11 changes: 7 additions & 4 deletions tensorflow/core/kernels/aggregate_ops.cc
@@ -19,6 +19,9 @@ limitations under the License.

#include "tensorflow/core/kernels/aggregate_ops.h"

#include <utility>

#include "absl/status/status.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
@@ -52,7 +55,7 @@ class AddNOp : public OpKernel {

// Try to forward and accumulate the result in one of the input buffers.
int reused_input = -1;
gtl::InlinedVector<int, 8> input_indices(num);
absl::InlinedVector<int, 8> input_indices(num);
std::iota(input_indices.begin(), input_indices.end(), 0);
Tensor* output = nullptr;
for (int input_idx = 0; input_idx < num; ++input_idx) {
@@ -172,8 +175,8 @@ class AddNOp<Device, Variant> : public OpKernel {
// the inputs into temp at the lowest levels of the summation tree.
static inline Status AddVariantTo(OpKernelContext* ctx, const int lhs_ix,
const int rhs_ix,
gtl::InlinedVector<Variant, 4>* temp,
gtl::InlinedVector<bool, 4>* temp_filled) {
absl::InlinedVector<Variant, 4>* temp,
absl::InlinedVector<bool, 4>* temp_filled) {
Variant tmp;
if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
const Variant& a = temp_filled->at(lhs_ix)
@@ -186,7 +189,7 @@ class AddNOp<Device, Variant> : public OpKernel {
TF_RETURN_IF_ERROR(
BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
temp_filled->at(lhs_ix) = true;
return OkStatus();
return absl::OkStatus();
}
};

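The kernel edits above all follow one mechanical pattern: the deprecated gtl::InlinedVector alias becomes absl::InlinedVector, OkStatus() becomes absl::OkStatus(), and the matching Abseil headers and Bazel deps are added. A minimal standalone illustration of the replacement spellings (not taken from the diff):

#include <numeric>

#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"

// The first 8 elements are stored inline; larger sizes spill to the heap.
// Behavior is identical to the old gtl::InlinedVector alias.
absl::Status FillIota(int n, absl::InlinedVector<int, 8>* out) {
  out->resize(n);
  std::iota(out->begin(), out->end(), 0);
  return absl::OkStatus();
}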
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/batchtospace_op.cc
@@ -17,6 +17,7 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -64,8 +65,8 @@ static void BatchToSpaceOpCompute(OpKernelContext* context,
orig_crops.shape().DebugString()));
// To avoid out-of-bounds access in the case that the block_shape and/or
// crops tensors are concurrently modified, we must copy the values.
gtl::InlinedVector<int64_t, 4> block_shape;
gtl::InlinedVector<int64_t, 8> crops;
absl::InlinedVector<int64_t, 4> block_shape;
absl::InlinedVector<int64_t, 8> crops;
internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

4 changes: 2 additions & 2 deletions tensorflow/core/kernels/bcast_ops.cc
@@ -31,7 +31,7 @@ class BCastArgsOp : public OpKernel {
OP_REQUIRES(
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
gtl::InlinedVector<BCast::Vec, 4> shapes;
absl::InlinedVector<BCast::Vec, 4> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& in = ctx->input(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
@@ -81,7 +81,7 @@ class BCastGradArgsOp : public OpKernel {
OP_REQUIRES(
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
gtl::InlinedVector<BCast::Vec, 4> shapes;
absl::InlinedVector<BCast::Vec, 4> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& in = ctx->input(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
20 changes: 12 additions & 8 deletions tensorflow/core/kernels/bincount_op.cc
@@ -15,19 +15,23 @@ limitations under the License.

// See docs in ../ops/math_ops.cc.

#include <atomic>

#include "tensorflow/core/platform/errors.h"
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bincount_op.h"

#include <atomic>
#include <cstdint>

#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bincount_op.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"

@@ -81,7 +85,7 @@ struct BincountFunctor<CPUDevice, Tidx, T, true> {
Eigen::array<int, 1> reduce_dim({0});
output.device(context->eigen_cpu_device()) =
partial_bins.any(reduce_dim).cast<T>();
return OkStatus();
return absl::OkStatus();
}
};

@@ -164,7 +168,7 @@ struct BincountFunctor<CPUDevice, Tidx, T, false> {
Eigen::array<int, 1> reduce_dim({0});
output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dim);
}
return OkStatus();
return absl::OkStatus();
}
};

@@ -209,7 +213,7 @@ struct BincountReduceFunctor<CPUDevice, Tidx, T, binary_output> {
static_cast<int>(err_neg_val)));
}

return OkStatus();
return absl::OkStatus();
}
};

@@ -325,7 +329,7 @@ class DenseBincountOp : public OpKernel {
const int64_t num_rows = data.dim_size(0);
auto weight_matrix =
(weights.NumElements() == 0)
? weights.shaped<T, 2>(gtl::InlinedVector<int64_t, 2>(2, 0))
? weights.shaped<T, 2>(absl::InlinedVector<int64_t, 2>(2, 0))
: weights.matrix<T>();
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t));
4 changes: 3 additions & 1 deletion tensorflow/core/kernels/bucketize_op.cc
@@ -16,6 +16,8 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.

#include "tensorflow/core/kernels/bucketize_op.h"

#include "absl/status/status.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -44,7 +46,7 @@ struct BucketizeFunctor<CPUDevice, T> {
output(i) = first_bigger_it - boundaries_vector.begin();
}

return OkStatus();
return absl::OkStatus();
}
};

2 changes: 1 addition & 1 deletion tensorflow/core/kernels/conv_grad_shape_utils.h
@@ -44,7 +44,7 @@ struct ConvBackpropSpatialDimension {
// Computed dimensions for a backwards convolution.
struct ConvBackpropDimensions {
// Information about each spatial dimension.
gtl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims;
absl::InlinedVector<ConvBackpropSpatialDimension, 3> spatial_dims;

// Batch size.
int64_t batch_size;
6 changes: 1 addition & 5 deletions third_party/xla/xla/backends/interpreter/executor.h
@@ -121,10 +121,6 @@ class XlaInterpreterExecutor : public StreamExecutor {
return absl::Status{absl::StatusCode::kUnimplemented, "WaitForEvent"};
}

Event::Status PollForEventStatus(Event *event) override {
return Event::Status::kError;
}

void DeallocateStream(Stream *stream) override {}
bool CreateStreamDependency(Stream *dependent, Stream *other) override;

@@ -150,7 +146,7 @@ class XlaInterpreterExecutor : public StreamExecutor {
return true;
}
absl::StatusOr<std::unique_ptr<Event>> CreateEvent() override {
return std::make_unique<Event>(this);
return std::make_unique<Event>();
}

absl::StatusOr<std::unique_ptr<Stream>> CreateStream(
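The interpreter backend mirrors the CStreamExecutor change at the top of this commit: Event no longer takes a StreamExecutorInterface*, so a plain std::make_unique<Event>() suffices and the result comes back through absl::StatusOr. A hedged sketch of a call site unpacking the new CreateEvent() signature; the helper name and error message are illustrative, and the StreamExecutor/Event headers for your build are assumed:

#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Create an event and verify it did not immediately report an error.
absl::Status CreateAndCheckEvent(stream_executor::StreamExecutor* executor) {
  absl::StatusOr<std::unique_ptr<stream_executor::Event>> event =
      executor->CreateEvent();
  if (!event.ok()) return event.status();
  if ((*event)->PollForStatus() == stream_executor::Event::Status::kError) {
    return absl::InternalError("newly created event reported an error");
  }
  return absl::OkStatus();
}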
