From 9255becd9597c75a3d7f4cfba8fc2ccf335d0e73 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 4 Dec 2023 23:40:04 +0000 Subject: [PATCH 1/5] Not creating the coordinator service for single process. --- torch_xla/csrc/runtime/xla_coordinator.cc | 9 +++++++-- torch_xla/csrc/runtime/xla_coordinator.h | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc b/torch_xla/csrc/runtime/xla_coordinator.cc index 72855d8681ea..4840d6aeb7bc 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -9,6 +9,9 @@ namespace runtime { XlaCoordinator::XlaCoordinator(int global_rank, int world_size, std::string master_addr, std::string port) { + if (world_size <= 1) { + return; + } std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":"); if (global_rank == 0) { xla::CoordinationServiceImpl::Options service_options; @@ -43,8 +46,10 @@ XlaCoordinator::~XlaCoordinator() { } std::shared_ptr XlaCoordinator::GetClient() { - XLA_CHECK(dist_runtime_client_ != nullptr) - << "distributed runtime client is null."; + if (world_size_ > 1) { + XLA_CHECK(dist_runtime_client_ != nullptr) + << "distributed runtime client is null."; + } return dist_runtime_client_; } diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h index ae85c79a9416..fd4721d6398a 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.h +++ b/torch_xla/csrc/runtime/xla_coordinator.h @@ -45,6 +45,7 @@ class XlaCoordinator { std::unique_ptr dist_runtime_service_; std::shared_ptr dist_runtime_client_; std::unique_ptr preemption_sync_manager_; + int world_size_; }; } // namespace runtime From 2c6c93b8b726343eea2d8a2efc20a952bf854232 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 6 Dec 2023 23:33:15 +0000 Subject: [PATCH 2/5] fix comments --- .../csrc/runtime/pjrt_computation_client.cc | 40 ++++++++++--------- torch_xla/csrc/runtime/xla_coordinator.cc | 7 ---- 
torch_xla/csrc/runtime/xla_coordinator.h | 1 - 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 871760d4802a..2772cea32dad 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -177,27 +177,29 @@ PjRtComputationClient::PjRtComputationClient() { std::string port = runtime::sys_util::GetEnvString( "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - // Use the XlaCoordinator as the distributed key-value store. - coordinator_ = std::make_unique( - global_process_rank, global_world_size, master_addr, port); - std::shared_ptr distributed_client = - coordinator_->GetClient(); - auto allowed_devices = - std::make_optional>(std::set{local_process_rank}); xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; - if (distributed_client != nullptr) { - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; + auto allowed_devices = + std::make_optional>(std::set{local_process_rank}); + if (global_world_size > 1) { + // Use the XlaCoordinator as the distributed key-value store. 
+ coordinator_ = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator_->GetClient(); + if (distributed_client != nullptr) { + std::string key_prefix = "gpu:"; + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + }; + } } TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" << global_process_rank << ", num_nodes=" << global_world_size; diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc b/torch_xla/csrc/runtime/xla_coordinator.cc index 4840d6aeb7bc..dd33bc11e033 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -9,9 +9,6 @@ namespace runtime { XlaCoordinator::XlaCoordinator(int global_rank, int world_size, std::string master_addr, std::string port) { - if (world_size <= 1) { - return; - } std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":"); if (global_rank == 0) { xla::CoordinationServiceImpl::Options service_options; @@ -46,10 +43,6 @@ XlaCoordinator::~XlaCoordinator() { } std::shared_ptr XlaCoordinator::GetClient() { - if (world_size_ > 1) { - XLA_CHECK(dist_runtime_client_ != nullptr) - << "distributed runtime client is null."; - } return dist_runtime_client_; } diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h index fd4721d6398a..ae85c79a9416 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.h +++ b/torch_xla/csrc/runtime/xla_coordinator.h @@ -45,7 +45,6 @@ class XlaCoordinator { std::unique_ptr dist_runtime_service_; std::shared_ptr dist_runtime_client_; std::unique_ptr preemption_sync_manager_; - int world_size_; }; } 
// namespace runtime From 9f8c2a87965c725ae4de0f5829012e22e68e12d4 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 6 Dec 2023 23:46:37 +0000 Subject: [PATCH 3/5] revert an unwanted change --- torch_xla/csrc/runtime/xla_coordinator.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc b/torch_xla/csrc/runtime/xla_coordinator.cc index dd33bc11e033..72855d8681ea 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -43,6 +43,8 @@ XlaCoordinator::~XlaCoordinator() { } std::shared_ptr XlaCoordinator::GetClient() { + XLA_CHECK(dist_runtime_client_ != nullptr) + << "distributed runtime client is null."; return dist_runtime_client_; } From 8b0f9f5756b1ed3f0785ebf59605ad78da9e4ab5 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 7 Dec 2023 06:12:48 +0000 Subject: [PATCH 4/5] fix linter --- .../csrc/runtime/pjrt_computation_client.cc | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 2772cea32dad..949a5af56ace 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -187,19 +187,18 @@ PjRtComputationClient::PjRtComputationClient() { global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = coordinator_->GetClient(); - if (distributed_client != nullptr) { - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; - } + XLA_CHECK(distributed_client != nullptr); + 
std::string key_prefix = "gpu:"; + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + }; } TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" << global_process_rank << ", num_nodes=" << global_world_size; From a5ec7e8345f5d684c428c355f15876a0be141334 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 7 Dec 2023 21:39:25 +0000 Subject: [PATCH 5/5] fix one last comment --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 949a5af56ace..187c2f35c7e1 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -187,7 +187,6 @@ PjRtComputationClient::PjRtComputationClient() { global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = coordinator_->GetClient(); - XLA_CHECK(distributed_client != nullptr); std::string key_prefix = "gpu:"; kv_get = [distributed_client, key_prefix]( std::string_view k,