Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test/test_mp_rendezvous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import re
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
ordinal = xm.get_ordinal()
print('Core {} waiting for rendezvous ...'.format(ordinal))
data = xmp.rendezvous('rendezvous_test', 'ORD={}'.format(ordinal))
print('Core {} got rendezvous!'.format(ordinal))
for i in range(0, len(data)):
m = re.match(r'ORD=(\d+)', data[i])
assert m, 'Bad payload format: {}'.format(data[i])
xordinal = int(m.group(1))
assert i == xordinal, 'Payload {} got ordinal {}'.format(i, xordinal)


if __name__ == '__main__':
xmp.spawn(_mp_fn, args=())
89 changes: 70 additions & 19 deletions third_party/xla_client/mesh_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <mutex>
#include <unordered_map>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/mesh_service.grpc.pb.h"
Expand Down Expand Up @@ -65,34 +66,70 @@ class MeshServiceImpl : public grpc::MeshService::Service {
grpc::RendezvousResponse* response) override;

private:
struct RendezvousData {
explicit RendezvousData(size_t count) : mwait(count), release_count(0) {}
class RendezvousData {
public:
explicit RendezvousData(size_t count)
: mwait_(count), release_count_(0), payloads_(count) {}

util::MultiWait mwait;
std::atomic<size_t> release_count;
bool Release() { return release_count_.fetch_add(1) == 0; }

void SetPayload(size_t ordinal, std::string payload) {
std::lock_guard<std::mutex> lock(lock_);
if (ordinal >= payloads_.size()) {
status_ = ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT,
absl::StrCat("Invalid ordinal: ", ordinal));
} else {
payloads_[ordinal] = std::move(payload);
}
}

::grpc::Status Wait() {
::grpc::Status status =
ToGrpcStatus(xla::util::CheckedCall([&]() { mwait_.Wait(); }));
if (status.ok()) {
std::lock_guard<std::mutex> lock(lock_);
status = status_;
}
return status;
}

void Done() { mwait_.Done(); }

const std::vector<std::string>& Payloads() const { return payloads_; };

private:
std::mutex lock_;
util::MultiWait mwait_;
std::atomic<size_t> release_count_;
std::vector<std::string> payloads_;
::grpc::Status status_;
};

std::shared_ptr<RendezvousData> GetRendezvous(const std::string& tag) {
std::lock_guard<std::mutex> lock(lock_);
auto it = rendezvous_map_.find(tag);
if (it == rendezvous_map_.end()) {
it =
rendezvous_map_
.emplace(tag,
std::make_shared<RendezvousData>(config_.workers_size()))
.first;
it = rendezvous_map_
.emplace(tag, std::make_shared<RendezvousData>(
GetTopologyDeviceCount()))
.first;
}
return it->second;
}

void ReleaseRendezvous(const std::string& tag,
const std::shared_ptr<RendezvousData>& rendezvous) {
if (rendezvous->release_count.fetch_add(1) == 0) {
if (rendezvous->Release()) {
std::lock_guard<std::mutex> lock(lock_);
rendezvous_map_.erase(tag);
}
}

size_t GetTopologyDeviceCount() const {
return config_.proto().num_tasks() *
config_.proto().num_tpu_devices_per_task();
}

std::mutex lock_;
grpc::Config config_;
std::unordered_map<std::string, std::shared_ptr<RendezvousData>>
Expand All @@ -110,13 +147,19 @@ ::grpc::Status MeshServiceImpl::Rendezvous(
::grpc::ServerContext* context, const grpc::RendezvousRequest* request,
grpc::RendezvousResponse* response) {
auto rendezvous = GetRendezvous(request->tag());
rendezvous->mwait.Done();
TF_VLOG(3) << "Entering rendezvous: tag=" << request->tag()
<< " peer=" << context->peer();
::grpc::Status status =
ToGrpcStatus(xla::util::CheckedCall([&]() { rendezvous->mwait.Wait(); }));
TF_VLOG(3) << "Exiting rendezvous: tag=" << request->tag()
<< " peer=" << context->peer() << " status=" << status;
rendezvous->SetPayload(request->ordinal(), request->payload());
rendezvous->Done();
TF_VLOG(3) << "Entering rendezvous: ordinal=" << request->ordinal()
<< " tag=" << request->tag() << " peer=" << context->peer();
::grpc::Status status = rendezvous->Wait();
TF_VLOG(3) << "Exiting rendezvous: ordinal=" << request->ordinal()
<< " tag=" << request->tag() << " peer=" << context->peer()
<< " status=" << status;
if (status.ok()) {
for (auto& payload : rendezvous->Payloads()) {
response->add_payloads(payload);
}
}
ReleaseRendezvous(request->tag(), rendezvous);
return status;
}
Expand Down Expand Up @@ -190,17 +233,25 @@ grpc::Config MeshClient::GetConfig() const {
return std::move(*response.mutable_config());
}

void MeshClient::Rendezvous(const std::string& tag) const {
std::vector<std::string> MeshClient::Rendezvous(
int ordinal, const std::string& tag, const std::string& payload) const {
::grpc::ClientContext context;
grpc::RendezvousRequest reqeust;
grpc::RendezvousResponse response;
reqeust.set_tag(tag);
TF_VLOG(3) << "Waiting for rendezvous: " << tag;
reqeust.set_payload(payload);
reqeust.set_ordinal(ordinal);
TF_VLOG(3) << "Waiting for rendezvous: ordinal=" << ordinal << " tag=" << tag;
::grpc::Status status = impl_->stub->Rendezvous(&context, reqeust, &response);
TF_VLOG(3) << "Rendezvous wait complete: " << tag;
if (!status.ok()) {
XLA_ERROR() << "Failed to meet rendezvous '" << tag << "': " << status;
}
std::vector<std::string> rv_payloads;
for (auto& rv_payload : response.payloads()) {
rv_payloads.push_back(rv_payload);
}
return rv_payloads;
}

} // namespace service
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla_client/mesh_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_client/mesh_service.pb.h"
Expand Down Expand Up @@ -32,7 +33,8 @@ class MeshClient {

grpc::Config GetConfig() const;

void Rendezvous(const std::string& tag) const;
std::vector<std::string> Rendezvous(int ordinal, const std::string& tag,
const std::string& payload) const;

private:
MeshClient(const std::string& address);
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla_client/mesh_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ message GetConfigResponse {

message RendezvousRequest {
required string tag = 1;
required bytes payload = 2;
required uint32 ordinal = 3;
}

message RendezvousResponse {
repeated bytes payloads = 1;
}

service MeshService {
Expand Down
15 changes: 15 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>

#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/mesh_service.h"
#include "tensorflow/compiler/xla/xla_client/metrics.h"
#include "tensorflow/compiler/xla/xla_client/metrics_reader.h"
#include "tensorflow/compiler/xla/xla_client/multi_wait.h"
Expand Down Expand Up @@ -257,6 +258,16 @@ py::object GetRevisions() {
return py_dict;
}

std::vector<std::string> Rendezvous(int ordinal, const std::string& tag,
const std::string& payload) {
xla::service::MeshClient* mesh_client = xla::service::MeshClient::Get();
std::vector<std::string> payloads;
if (mesh_client != nullptr) {
payloads = mesh_client->Rendezvous(ordinal, tag, payload);
}
return payloads;
}

std::shared_ptr<xla::util::RecordReader> CreateRecordReader(
std::string path, const std::string& compression, xla::int64 buffer_size) {
return std::make_shared<xla::util::RecordReader>(std::move(path), compression,
Expand Down Expand Up @@ -490,6 +501,10 @@ void InitXlaModuleBindings(py::module m) {
m.def("_xla_get_replication_devices_count", []() {
return xla::ComputationClient::Get()->GetReplicationDevices().size();
});
m.def("_xla_rendezvous",
[](int ordinal, const std::string& tag, const std::string& payload) {
return Rendezvous(ordinal, tag, payload);
});

py::class_<ir::Value, std::shared_ptr<ir::Value>>(m, "IrValue");
m.def("_xla_create_token", []() {
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/distributed/xla_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,13 @@ def spawn(fn,
join=join,
daemon=daemon,
start_method=start_method)


def rendezvous(tag, payload=''):
"""Waits for all the mesh clients to reach the named rendezvous.

Args:
tag (string): The name of the rendezvous to join.
payload (string, optional): The payload to be sent to the rendezvous.
"""
return torch_xla._XLAC._xla_rendezvous(xm.get_ordinal(), tag, payload)