Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions third_party/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1030,14 +1030,18 @@ const std::vector<int>& XrtComputationClient::GetDeviceMeshCoords(
}

tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology(
const string& job, int task_no, const string& worker_host_port) {
const string& job, int task_no, const string& worker_host_port,
const tensorflow::ConfigProto& config) {
tensorflow::SessionOptions session_options;
session_options.env = tensorflow::Env::Default();
session_options.target = worker_host_port;
session_options.config = config;

tensorflow::Scope root = tensorflow::Scope::NewRootScope();
tensorflow::ClientSession session(root, session_options);
string system_device = absl::StrCat("/job:", job, "/replica:0/task:", task_no,
"/device:TPU_SYSTEM:0");
XrtSessionCache::SessionMap session_map;
XrtSession* session =
GetSessionForTarget(session_cache_.get(), worker_host_port, &session_map);
tensorflow::Scope tpu_system_scope =
session->root()->WithDevice(system_device);
tensorflow::Scope tpu_system_scope = root.WithDevice(system_device);
const auto unique_name =
tpu_system_scope.GetUniqueNameForOp("ConfigureDistributedTPU");
auto builder = tensorflow::NodeBuilder(unique_name, "ConfigureDistributedTPU")
Expand All @@ -1047,15 +1051,13 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology(
tpu_system_scope.UpdateBuilder(&builder);

tensorflow::Node* result;
session->root()->UpdateStatus(
builder.Finalize(tpu_system_scope.graph(), &result));
root.UpdateStatus(builder.Finalize(tpu_system_scope.graph(), &result));
XLA_CHECK_OK(tpu_system_scope.status());
session->root()->UpdateStatus(tpu_system_scope.DoShapeInference(result));
root.UpdateStatus(tpu_system_scope.DoShapeInference(result));

std::vector<tensorflow::Tensor> outputs;
XLA_CHECK_OK(session->root()->status());
XLA_CHECK_OK(
session->session()->Run({tensorflow::Output(result, 0)}, &outputs));
XLA_CHECK_OK(root.status());
XLA_CHECK_OK(session.Run({tensorflow::Output(result, 0)}, &outputs));
XLA_CHECK_EQ(outputs.size(), 1);

tensorflow::tpu::TopologyProto topology_proto;
Expand All @@ -1075,22 +1077,24 @@ void XrtComputationClient::InitializeDevices(
tpu_workers.emplace(parsed_device.job, parsed_device.task);
}
}
for (auto& worker : tpu_workers) {
if (!tpu_workers.empty()) {
const Worker& worker = *tpu_workers.begin();
auto it = options_.workers_map.find(worker);
XLA_CHECK(it != options_.workers_map.end());

TF_LOG(INFO) << "Configuring TPU for worker " << worker.name << ":"
TF_LOG(INFO) << "Configuring TPU for master worker " << worker.name << ":"
<< worker.task_no << " at " << it->second;
tensorflow::tpu::TopologyProto worker_topology_proto =
InitializeAndFetchTopology(worker.name, worker.task_no, it->second);
InitializeAndFetchTopology(worker.name, worker.task_no, it->second,
session_cache_->GetConfig());
if (topology_proto == nullptr) {
topology_proto = absl::make_unique<tensorflow::tpu::TopologyProto>(
std::move(worker_topology_proto));
}
}
}
if (topology_proto != nullptr) {
TF_LOG(INFO) << "TPU topology: " << topology_proto->DebugString();
if (topology_proto != nullptr) {
TF_LOG(INFO) << "TPU topology: " << topology_proto->DebugString();
}
}
for (const auto& dev_target : options_.global_device_map) {
tensorflow::DeviceNameUtils::ParsedName parsed_device =
Expand Down
7 changes: 4 additions & 3 deletions third_party/xla_client/xrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,6 @@ class XrtComputationClient : public ComputationClient {
// Retrieves the mesh coordinates of a given XRT device.
const std::vector<int>& GetDeviceMeshCoords(const string& xrt_device) const;

tensorflow::tpu::TopologyProto InitializeAndFetchTopology(
const string& job, int task_no, const string& worker_host_port);

void InitializeDevices(
std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto);

Expand Down Expand Up @@ -445,6 +442,10 @@ class XrtComputationClient : public ComputationClient {

static tensorflow::ConfigProto CreateConfigProto(const Options& options);

static tensorflow::tpu::TopologyProto InitializeAndFetchTopology(
const string& job, int task_no, const string& worker_host_port,
const tensorflow::ConfigProto& config);

// Checks whether a local GRPC service is required, and starts it if need it.
static void MaybeCreateLocalService(
const XrtComputationClient::Options& options);
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla_client/xrt_session_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class XrtSessionCache {
XrtSessionCache(tensorflow::ConfigProto config,
std::function<void(XrtSession*)> initfn);

const tensorflow::ConfigProto& GetConfig() const { return config_; }

// Retrieves a new session reference, for which the caller will have exclusive
// access. Once the reference object is destroyed, the session will be
// returned to the cache.
Expand Down