diff --git a/third_party/xla_client/xrt_computation_client.cc b/third_party/xla_client/xrt_computation_client.cc index 49532d0e1acb..885c65ec0fa1 100644 --- a/third_party/xla_client/xrt_computation_client.cc +++ b/third_party/xla_client/xrt_computation_client.cc @@ -1030,14 +1030,18 @@ const std::vector& XrtComputationClient::GetDeviceMeshCoords( } tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology( - const string& job, int task_no, const string& worker_host_port) { + const string& job, int task_no, const string& worker_host_port, + const tensorflow::ConfigProto& config) { + tensorflow::SessionOptions session_options; + session_options.env = tensorflow::Env::Default(); + session_options.target = worker_host_port; + session_options.config = config; + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + tensorflow::ClientSession session(root, session_options); string system_device = absl::StrCat("/job:", job, "/replica:0/task:", task_no, "/device:TPU_SYSTEM:0"); - XrtSessionCache::SessionMap session_map; - XrtSession* session = - GetSessionForTarget(session_cache_.get(), worker_host_port, &session_map); - tensorflow::Scope tpu_system_scope = - session->root()->WithDevice(system_device); + tensorflow::Scope tpu_system_scope = root.WithDevice(system_device); const auto unique_name = tpu_system_scope.GetUniqueNameForOp("ConfigureDistributedTPU"); auto builder = tensorflow::NodeBuilder(unique_name, "ConfigureDistributedTPU") @@ -1047,15 +1051,13 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology( tpu_system_scope.UpdateBuilder(&builder); tensorflow::Node* result; - session->root()->UpdateStatus( - builder.Finalize(tpu_system_scope.graph(), &result)); + root.UpdateStatus(builder.Finalize(tpu_system_scope.graph(), &result)); XLA_CHECK_OK(tpu_system_scope.status()); - session->root()->UpdateStatus(tpu_system_scope.DoShapeInference(result)); + root.UpdateStatus(tpu_system_scope.DoShapeInference(result)); std::vector outputs; - XLA_CHECK_OK(session->root()->status()); - XLA_CHECK_OK( - session->session()->Run({tensorflow::Output(result, 0)}, &outputs)); + XLA_CHECK_OK(root.status()); + XLA_CHECK_OK(session.Run({tensorflow::Output(result, 0)}, &outputs)); XLA_CHECK_EQ(outputs.size(), 1); tensorflow::tpu::TopologyProto topology_proto; @@ -1075,22 +1077,24 @@ void XrtComputationClient::InitializeDevices( tpu_workers.emplace(parsed_device.job, parsed_device.task); } } - for (auto& worker : tpu_workers) { + if (!tpu_workers.empty()) { + const Worker& worker = *tpu_workers.begin(); auto it = options_.workers_map.find(worker); XLA_CHECK(it != options_.workers_map.end()); - TF_LOG(INFO) << "Configuring TPU for worker " << worker.name << ":" + TF_LOG(INFO) << "Configuring TPU for master worker " << worker.name << ":" << worker.task_no << " at " << it->second; tensorflow::tpu::TopologyProto worker_topology_proto = - InitializeAndFetchTopology(worker.name, worker.task_no, it->second); + InitializeAndFetchTopology(worker.name, worker.task_no, it->second, + session_cache_->GetConfig()); if (topology_proto == nullptr) { topology_proto = absl::make_unique( std::move(worker_topology_proto)); } } - } - if (topology_proto != nullptr) { - TF_LOG(INFO) << "TPU topology: " << topology_proto->DebugString(); + if (topology_proto != nullptr) { + TF_LOG(INFO) << "TPU topology: " << topology_proto->DebugString(); + } } for (const auto& dev_target : options_.global_device_map) { tensorflow::DeviceNameUtils::ParsedName parsed_device = diff --git a/third_party/xla_client/xrt_computation_client.h b/third_party/xla_client/xrt_computation_client.h index 0b0e733673e4..3121bc8b77fd 100644 --- a/third_party/xla_client/xrt_computation_client.h +++ b/third_party/xla_client/xrt_computation_client.h @@ -291,9 +291,6 @@ class XrtComputationClient : public ComputationClient { // Retrieves the mesh coordinates of a given XRT device. const std::vector& GetDeviceMeshCoords(const string& xrt_device) const; - tensorflow::tpu::TopologyProto InitializeAndFetchTopology( - const string& job, int task_no, const string& worker_host_port); - void InitializeDevices( std::unique_ptr topology_proto); @@ -445,6 +442,10 @@ class XrtComputationClient : public ComputationClient { static tensorflow::ConfigProto CreateConfigProto(const Options& options); + static tensorflow::tpu::TopologyProto InitializeAndFetchTopology( + const string& job, int task_no, const string& worker_host_port, + const tensorflow::ConfigProto& config); + // Checks whether a local GRPC service is required, and starts it if need it. static void MaybeCreateLocalService( const XrtComputationClient::Options& options); diff --git a/third_party/xla_client/xrt_session_cache.h b/third_party/xla_client/xrt_session_cache.h index 81b25bfcda10..2bb1fdc68f5c 100644 --- a/third_party/xla_client/xrt_session_cache.h +++ b/third_party/xla_client/xrt_session_cache.h @@ -68,6 +68,8 @@ class XrtSessionCache { XrtSessionCache(tensorflow::ConfigProto config, std::function initfn); + const tensorflow::ConfigProto& GetConfig() const { return config_; } + // Retrieves a new session reference, for which the caller will have exclusive // access. Once the reference object is destroyed, the session will be // returned to the cache.