diff --git a/tensorflow/core/data/service/common.cc b/tensorflow/core/data/service/common.cc
index 86ff05fb7aed37..cf0d352baf5a81 100644
--- a/tensorflow/core/data/service/common.cc
+++ b/tensorflow/core/data/service/common.cc
@@ -64,7 +64,7 @@ Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) {
   return absl::OkStatus();
 }
 
-StatusOr<AutoShardPolicy> ToAutoShardPolicy(
+absl::StatusOr<AutoShardPolicy> ToAutoShardPolicy(
     const ProcessingModeDef::ShardingPolicy sharding_policy) {
   switch (sharding_policy) {
     case ProcessingModeDef::FILE:
@@ -87,7 +87,7 @@ StatusOr<AutoShardPolicy> ToAutoShardPolicy(
   }
 }
 
-StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
+absl::StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
   std::string str_upper = absl::AsciiStrToUpper(s);
   if (str_upper.empty() || str_upper == kAuto) {
     return TARGET_WORKERS_AUTO;
@@ -115,7 +115,7 @@ std::string TargetWorkersToString(TargetWorkers target_workers) {
   }
 }
 
-StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s) {
+absl::StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s) {
   std::string str_upper = absl::AsciiStrToUpper(s);
   if (str_upper == kColocated) {
     return DEPLOYMENT_MODE_COLOCATED;
diff --git a/tensorflow/core/data/service/common.h b/tensorflow/core/data/service/common.h
index 873c8361325840..550cffeb7b9558 100644
--- a/tensorflow/core/data/service/common.h
+++ b/tensorflow/core/data/service/common.h
@@ -71,19 +71,19 @@ Status ValidateProcessingMode(const ProcessingModeDef& processing_mode);
 
 // Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an
 // internal error if `sharding_policy` is not supported.
-StatusOr<AutoShardPolicy> ToAutoShardPolicy(
+absl::StatusOr<AutoShardPolicy> ToAutoShardPolicy(
     ProcessingModeDef::ShardingPolicy sharding_policy);
 
 // Parses a string representing a `TargetWorkers` (case-insensitive).
 // Returns InvalidArgument if the string is not recognized.
-StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s);
+absl::StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s);
 
 // Converts a `TargetWorkers` enum to string.
 std::string TargetWorkersToString(TargetWorkers target_workers);
 
 // Parses a string representing a `DeploymentMode` (case-insensitive).
 // Returns InvalidArgument if the string is not recognized.
-StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s);
+absl::StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s);
 
 // Returns true if `status` is a retriable error that indicates preemption.
 bool IsPreemptedError(const Status& status);
diff --git a/tensorflow/core/data/service/cross_trainer_cache_test.cc b/tensorflow/core/data/service/cross_trainer_cache_test.cc
index 9d4aff79f6f2e6..a0359b5135266e 100644
--- a/tensorflow/core/data/service/cross_trainer_cache_test.cc
+++ b/tensorflow/core/data/service/cross_trainer_cache_test.cc
@@ -51,7 +51,7 @@ using ::testing::UnorderedElementsAreArray;
 class InfiniteRange : public CachableSequence<int64_t> {
  public:
-  StatusOr<int64_t> GetNext() override { return next_++; }
+  absl::StatusOr<int64_t> GetNext() override { return next_++; }
 
   size_t GetElementSizeBytes(const int64_t& element) const override {
     return sizeof(element);
   }
@@ -63,7 +63,7 @@ class TensorDataset : public CachableSequence<Tensor> {
  public:
-  StatusOr<Tensor> GetNext() override { return Tensor("Test Tensor"); }
+  absl::StatusOr<Tensor> GetNext() override { return Tensor("Test Tensor"); }
 
   size_t GetElementSizeBytes(const Tensor& element) const override {
     return element.TotalBytes();
   }
@@ -73,7 +73,7 @@ class SlowDataset : public CachableSequence<Tensor> {
  public:
   explicit SlowDataset(absl::Duration delay) : delay_(delay) {}
 
-  StatusOr<Tensor> GetNext() override {
+  absl::StatusOr<Tensor> GetNext() override {
     Env::Default()->SleepForMicroseconds(absl::ToInt64Microseconds(delay_));
     return Tensor("Test Tensor");
   }
@@ -369,7 +369,7 @@ TEST(CrossTrainerCacheTest, Cancel) {
         /*thread_options=*/{}, /*name=*/absl::StrCat("Trainer_", i),
         [&cache, &status, &mu]() {
           for (int j = 0; true; ++j) {
-            StatusOr<std::shared_ptr<const Tensor>> tensor =
+            absl::StatusOr<std::shared_ptr<const Tensor>> tensor =
                 cache.Get(absl::StrCat("Trainer_", j % 1000));
             {
               mutex_lock l(mu);
diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc
index f506b3bcb13b54..52e07993195a61 100644
--- a/tensorflow/core/data/service/data_service_test.cc
+++ b/tensorflow/core/data/service/data_service_test.cc
@@ -235,7 +235,7 @@ TEST(DataServiceTest, GcMissingClientsWithSmallTimeout) {
   TF_ASSERT_OK(dataset_client.GetTasks(iteration_client_id).status());
   // Iteration should be garbage collected within 10 seconds.
   absl::Time wait_start = absl::Now();
-  TF_ASSERT_OK(WaitWhile([&]() -> StatusOr<bool> {
+  TF_ASSERT_OK(WaitWhile([&]() -> absl::StatusOr<bool> {
     TF_ASSIGN_OR_RETURN(size_t num_iterations, cluster.NumActiveIterations());
     return num_iterations > 0;
   }));
diff --git a/tensorflow/core/data/service/data_transfer.h b/tensorflow/core/data/service/data_transfer.h
index f50fa12b4105c9..788dd241f185b7 100644
--- a/tensorflow/core/data/service/data_transfer.h
+++ b/tensorflow/core/data/service/data_transfer.h
@@ -91,7 +91,7 @@ class DataTransferClient {
 
   // Returns a string describing properties of the client relevant for checking
   // compatibility with a server for a given protocol.
-  virtual StatusOr<std::string> GetCompatibilityInfo() const {
+  virtual absl::StatusOr<std::string> GetCompatibilityInfo() const {
     return std::string();
   }
 
@@ -130,7 +130,7 @@ class DataTransferServer {
 
   // Returns a string describing properties of the server relevant for checking
   // compatibility with a client for a given protocol.
-  virtual StatusOr<std::string> GetCompatibilityInfo() const {
+  virtual absl::StatusOr<std::string> GetCompatibilityInfo() const {
     return std::string();
   }
 };
diff --git a/tensorflow/core/data/service/dispatcher_client.cc b/tensorflow/core/data/service/dispatcher_client.cc
index ee1a2b859b0fc3..ff3b165899f562 100644
--- a/tensorflow/core/data/service/dispatcher_client.cc
+++ b/tensorflow/core/data/service/dispatcher_client.cc
@@ -84,7 +84,8 @@ Status DataServiceDispatcherClient::Initialize() {
   return absl::OkStatus();
 }
 
-StatusOr<WorkerHeartbeatResponse> DataServiceDispatcherClient::WorkerHeartbeat(
+absl::StatusOr<WorkerHeartbeatResponse>
+DataServiceDispatcherClient::WorkerHeartbeat(
     const WorkerHeartbeatRequest& request) {
   WorkerHeartbeatResponse response;
   grpc::ClientContext client_ctx;
diff --git a/tensorflow/core/data/service/dispatcher_client.h b/tensorflow/core/data/service/dispatcher_client.h
index 40385928482040..9f521bd210ac6a 100644
--- a/tensorflow/core/data/service/dispatcher_client.h
+++ b/tensorflow/core/data/service/dispatcher_client.h
@@ -50,7 +50,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
   // registered with the dispatcher, this will register the worker. The
   // dispatcher will report which new tasks the worker should run, and which
   // tasks it should delete.
-  StatusOr<WorkerHeartbeatResponse> WorkerHeartbeat(
+  absl::StatusOr<WorkerHeartbeatResponse> WorkerHeartbeat(
       const WorkerHeartbeatRequest& request);
 
   // Updates the dispatcher with information about the worker's state.
diff --git a/tensorflow/core/data/service/dispatcher_client_test.cc b/tensorflow/core/data/service/dispatcher_client_test.cc
index 2cc4d15a5ae59a..64cf5c3c76e360 100644
--- a/tensorflow/core/data/service/dispatcher_client_test.cc
+++ b/tensorflow/core/data/service/dispatcher_client_test.cc
@@ -89,7 +89,7 @@ class DispatcherClientTest : public ::testing::Test {
   }
 
   // Creates a dataset and returns the dataset ID.
-  StatusOr<std::string> RegisterDataset(
+  absl::StatusOr<std::string> RegisterDataset(
       const DatasetDef& dataset, const DataServiceMetadata& metadata,
       const std::optional<std::string>& requested_dataset_id = std::nullopt) {
     std::string dataset_id;
@@ -99,7 +99,7 @@ class DispatcherClientTest : public ::testing::Test {
   }
 
   // Starts snapshots and returns the directories.
-  StatusOr<absl::flat_hash_set<std::string>> StartDummySnapshots(
+  absl::StatusOr<absl::flat_hash_set<std::string>> StartDummySnapshots(
       int64_t num_snapshots) {
     DistributedSnapshotMetadata metadata =
         CreateDummyDistributedSnapshotMetadata();
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index cd5a5b608a5926..73ee07a6c34e64 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -622,7 +622,8 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
   return absl::OkStatus();
 }
 
-StatusOr<std::optional<std::string>> DataServiceDispatcherImpl::FindDataset(
+absl::StatusOr<std::optional<std::string>>
+DataServiceDispatcherImpl::FindDataset(
     const GetOrRegisterDatasetRequest& request)
     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::shared_ptr<const Dataset> existing_dataset;
diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
index 3ece2684c95dab..5f1f31315a49fd 100644
--- a/tensorflow/core/data/service/dispatcher_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -217,7 +217,7 @@ class DataServiceDispatcherImpl {
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Finds the dataset ID with the requested dataset ID.
   // Returns nullptr if no such dataset exists.
-  StatusOr<std::optional<std::string>> FindDataset(
+  absl::StatusOr<std::optional<std::string>> FindDataset(
      const GetOrRegisterDatasetRequest& request);
   // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
   // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc
index 5383beaba644d7..22ab9ff2aeb988 100644
--- a/tensorflow/core/data/service/dispatcher_state.cc
+++ b/tensorflow/core/data/service/dispatcher_state.cc
@@ -467,7 +467,7 @@ Status DispatcherState::ValidateWorker(absl::string_view worker_address) const {
   return worker_index_resolver_.ValidateWorker(worker_address);
 }
 
-StatusOr<int64_t> DispatcherState::GetWorkerIndex(
+absl::StatusOr<int64_t> DispatcherState::GetWorkerIndex(
     absl::string_view worker_address) const {
   return worker_index_resolver_.GetWorkerIndex(worker_address);
 }
diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h
index 39d10453763251..e64b48771400ad 100644
--- a/tensorflow/core/data/service/dispatcher_state.h
+++ b/tensorflow/core/data/service/dispatcher_state.h
@@ -288,7 +288,8 @@ class DispatcherState {
   // If the dispatcher config specifies worker addresses, `GetWorkerIndex`
   // returns the worker index according to the list. This is useful for
   // deterministically sharding a dataset among a fixed set of workers.
-  StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const;
+  absl::StatusOr<int64_t> GetWorkerIndex(
+      absl::string_view worker_address) const;
 
   // Returns the paths of all snapshots initiated during the lifetime of this
   // journal.
diff --git a/tensorflow/core/data/service/graph_rewriters.cc b/tensorflow/core/data/service/graph_rewriters.cc
index bf7066e5f51600..af2059ae89c707 100644
--- a/tensorflow/core/data/service/graph_rewriters.cc
+++ b/tensorflow/core/data/service/graph_rewriters.cc
@@ -93,7 +93,7 @@ bool ShouldReplaceDynamicPort(absl::string_view config_address,
 }
 }  // namespace
 
-StatusOr<GraphDef>
+absl::StatusOr<GraphDef>
 RemoveCompressionMapRewriter::ApplyRemoveCompressionMapRewrite(
     const GraphDef& graph_def) {
   grappler::RemoveCompressionMap remove_compression_map;
@@ -122,7 +122,8 @@ RemoveCompressionMapRewriter::GetRewriteConfig() const {
   return config;
 }
 
-StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
+absl::StatusOr<AutoShardRewriter> AutoShardRewriter::Create(
+    const TaskDef& task_def) {
   TF_ASSIGN_OR_RETURN(
       AutoShardPolicy auto_shard_policy,
       ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy()));
@@ -130,7 +131,7 @@ StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
                            task_def.worker_index());
 }
 
-StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
+absl::StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
     const GraphDef& graph_def) {
   if (auto_shard_policy_ == AutoShardPolicy::OFF) {
     return graph_def;
@@ -214,7 +215,7 @@ void WorkerIndexResolver::AddWorker(absl::string_view worker_address) {
   }
 }
 
-StatusOr<int64_t> WorkerIndexResolver::GetWorkerIndex(
+absl::StatusOr<int64_t> WorkerIndexResolver::GetWorkerIndex(
    absl::string_view worker_address) const {
   const auto it = absl::c_find(worker_addresses_, worker_address);
   if (it == worker_addresses_.cend()) {
diff --git a/tensorflow/core/data/service/graph_rewriters.h b/tensorflow/core/data/service/graph_rewriters.h
index 7c0c347a836b1d..84c43a4f29d579 100644
--- a/tensorflow/core/data/service/graph_rewriters.h
+++ b/tensorflow/core/data/service/graph_rewriters.h
@@ -37,7 +37,7 @@ namespace data {
 class RemoveCompressionMapRewriter {
  public:
   // Returns `graph_def` with the compression map removed.
-  StatusOr<GraphDef> ApplyRemoveCompressionMapRewrite(
+  absl::StatusOr<GraphDef> ApplyRemoveCompressionMapRewrite(
       const GraphDef& graph_def);
 
  private:
@@ -49,11 +49,11 @@ class AutoShardRewriter {
  public:
   // Creates an `AutoShardRewriter` according to `task_def`. Returns an error if
   // the sharding policy is not a valid auto-shard policy.
-  static StatusOr<AutoShardRewriter> Create(const TaskDef& task_def);
+  static absl::StatusOr<AutoShardRewriter> Create(const TaskDef& task_def);
 
   // Applies auto-sharding to `graph_def`. If auto-shard policy is OFF, returns
   // the same graph as `graph_def`. Otherwise, returns the re-written graph.
-  StatusOr<GraphDef> ApplyAutoShardRewrite(const GraphDef& graph_def);
+  absl::StatusOr<GraphDef> ApplyAutoShardRewrite(const GraphDef& graph_def);
 
  private:
   AutoShardRewriter(AutoShardPolicy auto_shard_policy, int64_t num_workers,
@@ -97,7 +97,8 @@ class WorkerIndexResolver {
 
   // Returns the worker index for the worker at `worker_address`. Returns a
   // NotFound error if the worker is not registered.
-  StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const;
+  absl::StatusOr<int64_t> GetWorkerIndex(
+      absl::string_view worker_address) const;
 
  private:
   std::vector<std::string> worker_addresses_;
diff --git a/tensorflow/core/data/service/graph_rewriters_test.cc b/tensorflow/core/data/service/graph_rewriters_test.cc
index fe52fe9e4f38cc..a549c548353276 100644
--- a/tensorflow/core/data/service/graph_rewriters_test.cc
+++ b/tensorflow/core/data/service/graph_rewriters_test.cc
@@ -49,7 +49,8 @@ using ::tensorflow::testing::StatusIs;
 using ::testing::HasSubstr;
 using ::testing::SizeIs;
 
-StatusOr<NodeDef> GetNode(const GraphDef& graph_def, absl::string_view name) {
+absl::StatusOr<NodeDef> GetNode(const GraphDef& graph_def,
+                                absl::string_view name) {
   for (const NodeDef& node : graph_def.node()) {
     if (node.name() == name) {
       return node;
@@ -59,7 +60,8 @@ StatusOr<NodeDef> GetNode(const GraphDef& graph_def, absl::string_view name) {
                           name, graph_def.ShortDebugString()));
 }
 
-StatusOr<int64_t> GetValue(const GraphDef& graph_def, absl::string_view name) {
+absl::StatusOr<int64_t> GetValue(const GraphDef& graph_def,
+                                 absl::string_view name) {
   for (const NodeDef& node : graph_def.node()) {
     if (node.name() == name) {
       return node.attr().at("value").tensor().int64_val()[0];
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index 7dcc57dc2fdc0f..bfb4b3474e00de 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -212,7 +212,7 @@ void WorkerGrpcDataServer::MaybeStartAlternativeDataTransferServer(
       str_util::StringReplace(config_.data_transfer_address(), kPortPlaceholder,
                               absl::StrCat(transfer_server_->Port()),
                               /*replace_all=*/false));
-  StatusOr<std::string> compatibility_info =
+  absl::StatusOr<std::string> compatibility_info =
       transfer_server_->GetCompatibilityInfo();
   if (!compatibility_info.ok()) {
    LOG(ERROR)
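
Call sites are not expected to change behavior with this rename: given the mechanical 1:1 replacement above, the bare `StatusOr` being removed is presumably the TSL alias for `absl::StatusOr`, and the diff only spells out the canonical name. Below is a minimal sketch of the call-site pattern these signatures imply; it is illustrative only, and the helper `DescribeTargetWorkers` is hypothetical, not part of this change.

// Illustrative sketch (not part of the diff): consuming the renamed
// absl::StatusOr return types with TF_ASSIGN_OR_RETURN, the same pattern
// already used in data_service_test.cc and graph_rewriters.cc above.
#include <string>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

namespace tensorflow {
namespace data {

// Parses `s` into a TargetWorkers value and converts it back to its canonical
// string form. An InvalidArgument error from ParseTargetWorkers() propagates
// to the caller through TF_ASSIGN_OR_RETURN.
absl::StatusOr<std::string> DescribeTargetWorkers(absl::string_view s) {
  TF_ASSIGN_OR_RETURN(TargetWorkers target_workers, ParseTargetWorkers(s));
  return TargetWorkersToString(target_workers);
}

}  // namespace data
}  // namespace tensorflow

Because `absl::StatusOr` exposes the same ok()/value()/status() surface that callers already rely on, only declarations that spell the return type (as in the files touched above) need edits.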