Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void WithAllDevices(
const std::function<void(const std::vector<Device>&)>& devfn) {
std::vector<Device> devices;
for (const auto& device_str :
xla::ComputationClient::Get()->GetAvailableDevices()) {
xla::ComputationClient::Get()->GetLocalDevices()) {
Device device(device_str);
if (device.hw_type == device_type) {
devices.push_back(device);
Expand Down
62 changes: 44 additions & 18 deletions third_party/xla_client/computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "tensorflow/compiler/xla/xla_client/xrt_computation_client.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace xla {
namespace {

// Builds the process-wide ComputationClient singleton instance.
// Create() returns a std::unique_ptr directly (it CHECK-fails internally on
// error rather than returning a StatusOr), so ownership is simply released
// to the caller; the client lives for the lifetime of the process.
ComputationClient* CreateClient() {
  return ComputationClient::Create().release();
}

string GetTpuClusterConfigPath() {
Expand Down Expand Up @@ -51,6 +52,32 @@ string MakeGrpcEndPoint(const string& server) {
: absl::StrCat("grpc://", server);
}

// Fills options->devices with the subset of the global device map which is
// handled by this process. The XRT_LOCAL_WORKER environment variable, when
// set, has the form "<job_name>:<task_no>" and restricts the local devices
// to the ones exposed by that worker; when unset, every device in the
// global map is considered local.
void PopulateLocalDevices(XrtComputationClient::Options* options) {
  const string local_worker = sys_util::GetEnvString("XRT_LOCAL_WORKER", "");
  string worker_name;
  int task_no = -1;
  if (!local_worker.empty()) {
    std::vector<string> parts = absl::StrSplit(local_worker, ':');
    XLA_CHECK_EQ(parts.size(), 2) << local_worker;
    worker_name = std::move(parts[0]);
    task_no = std::stoi(parts[1]);
  }
  for (const auto& entry : options->global_device_map) {
    bool is_local = true;
    if (!worker_name.empty()) {
      // The TF device string must fully parse, since we need its job/task
      // coordinates to decide locality.
      tensorflow::DeviceNameUtils::ParsedName parsed;
      XLA_CHECK(tensorflow::DeviceNameUtils::ParseFullName(entry.second,
                                                           &parsed) &&
                parsed.has_job && parsed.has_task && parsed.has_id &&
                parsed.has_type)
          << entry.second;
      is_local = (parsed.job == worker_name && parsed.task == task_no);
    }
    if (is_local) {
      options->devices.insert(entry.first);
    }
  }
}

void AddXrtHostDevices(const string& worker_name, int task_no,
const string& server,
std::map<string, int>* device_ordinals,
Expand All @@ -74,82 +101,81 @@ void AddXrtHostDevices(const string& worker_name, int task_no,
string xrt_device_name =
absl::StrCat("/job:", worker_name, "/replica:0/task:", task_no,
"/device:", tf_device_name);
options->device_map.emplace(device_name, xrt_device_name);
options->global_device_map.emplace(device_name, xrt_device_name);
}
}
}

StatusOr<bool> ParseEnvBasedTpuClusterConfig(
XrtComputationClient::Options* options) {
bool ParseEnvBasedTpuClusterConfig(XrtComputationClient::Options* options) {
string tpu_config = sys_util::GetEnvString("XRT_TPU_CONFIG", "");
if (tpu_config.empty()) {
return false;
}
std::map<string, int> device_ordinals;
std::vector<string> spec_parts = absl::StrSplit(tpu_config, '|');
TF_RET_CHECK(!spec_parts.empty()) << tpu_config;
XLA_CHECK(!spec_parts.empty()) << tpu_config;
for (const auto& spec : spec_parts) {
std::vector<string> host_parts = absl::StrSplit(spec, ';');
TF_RET_CHECK(host_parts.size() == 3) << spec;
XLA_CHECK_EQ(host_parts.size(), 3) << spec;
AddXrtHostDevices(host_parts[0], std::stoi(host_parts[1]), host_parts[2],
&device_ordinals, options);
}
PopulateLocalDevices(options);
options->default_device = "TPU:0";
return true;
}

Status ParseTpuClusterConfig(const string& xrt_config_path,
XrtComputationClient::Options* options) {
void ParseTpuClusterConfig(const string& xrt_config_path,
XrtComputationClient::Options* options) {
std::map<string, int> device_ordinals;
std::ifstream config_file(xrt_config_path);
string line;
while (std::getline(config_file, line)) {
if (line.compare(0, 7, "worker:") == 0) {
std::vector<string> parts =
absl::StrSplit(line.substr(7), ' ', absl::SkipWhitespace());
TF_RET_CHECK(parts.size() >= 2) << line;
XLA_CHECK_GE(parts.size(), 2) << line;
const string& worker_name = parts[0];
for (std::size_t i = 1; i < parts.size(); ++i) {
AddXrtHostDevices(worker_name, i - 1, parts[i], &device_ordinals,
options);
}
}
}
PopulateLocalDevices(options);
options->default_device = "TPU:0";
return Status::OK();
}

} // namespace

StatusOr<std::unique_ptr<ComputationClient>> ComputationClient::Create() {
std::unique_ptr<ComputationClient> ComputationClient::Create() {
XrtComputationClient::Options options;
string xrt_config_path;
if (HasXrtConfigFile(&xrt_config_path)) {
TF_LOG(INFO) << "Loading XRT configuration from " << xrt_config_path;
TF_RETURN_IF_ERROR(ParseTpuClusterConfig(xrt_config_path, &options));
ParseTpuClusterConfig(xrt_config_path, &options);
} else {
TF_ASSIGN_OR_RETURN(bool configured,
ParseEnvBasedTpuClusterConfig(&options));
if (!configured) {
if (!ParseEnvBasedTpuClusterConfig(&options)) {
string device_spec = sys_util::GetEnvString(
"XRT_DEVICE_MAP",
"TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0");
for (const auto& device_target : absl::StrSplit(device_spec, '|')) {
std::vector<string> parts = absl::StrSplit(device_target, ';');
TF_RET_CHECK(parts.size() == 2) << device_target;
XLA_CHECK_EQ(parts.size(), 2) << device_target;
if (options.default_device.empty()) {
options.default_device = parts[0];
}
options.device_map.emplace(parts[0], parts[1]);
options.global_device_map.emplace(parts[0], parts[1]);
}
string workers_spec = sys_util::GetEnvString(
"XRT_WORKERS", "tpu_worker:0;grpc://localhost:51000");
for (const auto& name_target : absl::StrSplit(workers_spec, '|')) {
std::vector<string> parts = absl::StrSplit(name_target, ';');
TF_RET_CHECK(parts.size() == 2);
XLA_CHECK_EQ(parts.size(), 2) << name_target;
options.workers_map.emplace(ParseWorker(parts[0]),
MakeGrpcEndPoint(parts[1]));
}
PopulateLocalDevices(&options);
}
}
return std::unique_ptr<ComputationClient>(new XrtComputationClient(options));
Expand Down
6 changes: 4 additions & 2 deletions third_party/xla_client/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class ComputationClient {
std::vector<Input> inputs;
};

static StatusOr<std::unique_ptr<ComputationClient>> Create();
static std::unique_ptr<ComputationClient> Create();

virtual ~ComputationClient() {}

Expand Down Expand Up @@ -220,7 +220,9 @@ class ComputationClient {

virtual size_t GetNumDevices() const = 0;

virtual std::vector<string> GetAvailableDevices() const = 0;
virtual std::vector<string> GetLocalDevices() const = 0;

virtual std::vector<string> GetAllDevices() const = 0;

virtual void SetRngSeed(size_t seed) = 0;

Expand Down
64 changes: 41 additions & 23 deletions third_party/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,25 @@ XrtComputationClient::XrtComputationClient(
alloc_session_cache_ = absl::make_unique<XrtSessionCache>(config, nullptr);

auto default_device_target =
options_.device_map.find(options_.default_device);
XLA_CHECK(default_device_target != options_.device_map.end());
for (const auto& dev_target : options_.device_map) {
TF_LOG(INFO) << "XRT device " << dev_target.first << " -> "
options_.global_device_map.find(options_.default_device);
XLA_CHECK(default_device_target != options_.global_device_map.end());
for (auto& device : options_.devices) {
XLA_CHECK(options_.global_device_map.find(device) !=
options_.global_device_map.end())
<< "Missing device in global map: " << device;
}
for (const auto& dev_target : options_.global_device_map) {
const char* tag =
options_.devices.count(dev_target.first) > 0 ? "LOCAL" : "REMOTE";
TF_LOG(INFO) << "XRT device (" << tag << ") " << dev_target.first << " -> "
<< dev_target.second;
}
TF_LOG(INFO) << "XRT default device: " << default_device_target->first;
for (auto& worker_target : options_.workers_map) {
TF_LOG(INFO) << "Worker " << worker_target.second
<< " for /job:" << worker_target.first.name
<< "/replica:0/task:" << worker_target.first.task_no;
}
MaybeCreateLocalService(options_);
InitializeDevices();
StartHandleReleaser();
Expand Down Expand Up @@ -301,7 +313,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
util::MultiWait mwait(instances.size());
std::vector<ProgramShape> program_shapes(instances.size());
std::vector<ComputationPtr> results(instances.size());
std::vector<string> serialized_computations(instances.size());
std::vector<CompilationCacheKey> cache_keys(instances.size());
XrtSessionCache::SessionMap session_map;
std::map<XrtSession*, SessionWork> session_work_map;
for (size_t i = 0; i < instances.size(); ++i) {
Expand All @@ -310,11 +322,12 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
std::unique_ptr<xrt::XLAComputation> xrt_computation =
CreateXrtComputation(instance.computation, instance.devices,
instance.output_shape);
string serialized_computation = xrt_computation->SerializeAsString();

auto computation_ptr = compilation_cache_.Get(serialized_computation);
CompilationCacheKey cache_key(
GetResourceDomain(instance.compilation_device),
xrt_computation->SerializeAsString());
auto computation_ptr = compilation_cache_.Get(cache_key);
if (computation_ptr == nullptr) {
serialized_computations[i] = std::move(serialized_computation);
cache_keys[i] = std::move(cache_key);
program_shapes[i] =
ProgramShape(xrt_computation->config().program_shape());

Expand All @@ -330,7 +343,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
const XrtSession::CachedNode& cached_node = GetCompileNode(
session, device_scope, instance.compilation_device);
session_work->feed_inputs.insert(
{cached_node.holders[0], serialized_computations[i]});
{cached_node.holders[0], cache_keys[i].serialized_computation});
session_work->outputs_handles.push_back(cached_node.outputs[0]);
session_work->index_mapping.push_back(i);
}
Expand Down Expand Up @@ -365,8 +378,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
instance->compilation_device);
++output_index;

compilation_cache_.Add(std::move(serialized_computations[li]),
results[li]);
compilation_cache_.Add(std::move(cache_keys[li]), results[li]);
CreateCompileHandlesCounter()->AddValue(1);
}
};
Expand Down Expand Up @@ -768,8 +780,9 @@ string XrtComputationClient::GetEffectiveDevice(const string& device) const {

// Maps a PyTorch device string (e.g. "TPU:0") to the full TF device name
// (e.g. "/job:tpu_worker/replica:0/task:0/device:TPU:0") through the global
// device map. CHECK-fails if the device is unknown. Returns a reference
// into options_.global_device_map, valid for the client's lifetime.
const string& XrtComputationClient::TorchDeviceToXrtDevice(
    const string& device) const {
  auto device_target =
      options_.global_device_map.find(GetEffectiveDevice(device));
  XLA_CHECK(device_target != options_.global_device_map.end())
      << "Unable to find device: " << device;
  return device_target->second;
}
Expand Down Expand Up @@ -933,7 +946,7 @@ void XrtComputationClient::ReleaseHandles(

// Starts the background task which releases XRT handles. The thread count
// defaults to the number of LOCAL devices (options_.devices), overridable
// via the XLA_HANDLE_RELEASE_THREADS environment variable.
void XrtComputationClient::StartHandleReleaser() {
  int64 num_threads = sys_util::GetEnvInt("XLA_HANDLE_RELEASE_THREADS",
                                          options_.devices.size());
  triggered_task_.reset(
      new util::TriggeredTask([this]() { HandleReleaser(); }, num_threads));
}
Expand Down Expand Up @@ -1045,13 +1058,14 @@ tensorflow::tpu::TopologyProto XrtComputationClient::InitializeAndFetchTopology(

void XrtComputationClient::InitializeDevices() {
std::set<Worker> tpu_workers;
for (const auto& dev_target : options_.device_map) {
for (const auto& device : options_.devices) {
const string& xrt_device = TorchDeviceToXrtDevice(device);
tensorflow::DeviceNameUtils::ParsedName parsed_device;
XLA_CHECK(tensorflow::DeviceNameUtils::ParseFullName(dev_target.second,
XLA_CHECK(tensorflow::DeviceNameUtils::ParseFullName(xrt_device,
&parsed_device) &&
parsed_device.has_job && parsed_device.has_task &&
parsed_device.has_id && parsed_device.has_type)
<< dev_target.second;
<< xrt_device;
if (parsed_device.type == "TPU") {
tpu_workers.emplace(parsed_device.job, parsed_device.task);
}
Expand All @@ -1070,7 +1084,7 @@ void XrtComputationClient::InitializeDevices() {
}
}

for (const auto& dev_target : options_.device_map) {
for (const auto& dev_target : options_.global_device_map) {
tensorflow::DeviceNameUtils::ParsedName parsed_device;
XLA_CHECK(tensorflow::DeviceNameUtils::ParseFullName(dev_target.second,
&parsed_device) &&
Expand Down Expand Up @@ -1129,12 +1143,16 @@ string XrtComputationClient::GetDefaultDevice() const {
}

// Returns the number of devices handled by this client instance, i.e. the
// LOCAL devices only (options_.devices), not the whole global device map.
size_t XrtComputationClient::GetNumDevices() const {
  return options_.devices.size();
}

std::vector<string> XrtComputationClient::GetLocalDevices() const {
return std::vector<string>(options_.devices.begin(), options_.devices.end());
}

std::vector<string> XrtComputationClient::GetAvailableDevices() const {
std::vector<string> XrtComputationClient::GetAllDevices() const {
std::vector<string> devices;
for (const auto& dev_target : options_.device_map) {
for (const auto& dev_target : options_.global_device_map) {
devices.push_back(dev_target.first);
}
return devices;
Expand All @@ -1156,7 +1174,7 @@ void XrtComputationClient::InitSession(XrtSession* session) const {
{16, &XrtComputationClient::GetReleaseCompileHandleNode},
{16, &XrtComputationClient::GetSubTupleNode},
};
auto devices = GetAvailableDevices();
auto devices = GetLocalDevices();
for (auto& device : devices) {
// HACK: The XRT ops on the remote GRPC service has only recently been
// enabled, so until TF 1.14 is out, we cannot add XRT ops on CPU.
Expand Down
42 changes: 38 additions & 4 deletions third_party/xla_client/xrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,13 @@ class XrtComputationClient : public ComputationClient {
// Maps a PyTorch device ID (example, "GPU:0", "TPU:0") to the full
// coordinates in TF device format
// (ie, /job:tpu_worker/replica:0/task:0/device:TPU:0), of the worker
// exposing that device.
std::map<string, string> device_map;
// exposing that device. These devices are all the devices present within
// the TPU mesh.
std::map<string, string> global_device_map;
// These are the devices that this instance of PyTorch is handling. These
// devices are in the form of "CPU:0", "TPU:3", ... For each of these
// devices, there is an entry within the global_device_map.
std::set<string> devices;
// Maps a TPU Worker with an EndPoint.
std::map<Worker, string> workers_map;
};
Expand Down Expand Up @@ -162,11 +167,40 @@ class XrtComputationClient : public ComputationClient {

size_t GetNumDevices() const override;

std::vector<string> GetAvailableDevices() const override;
std::vector<string> GetLocalDevices() const override;

std::vector<string> GetAllDevices() const override;

void SetRngSeed(size_t seed) override;

private:
// The data structure used for the key in the compilation cache. Compilation
// handles are valid within a given domain (essentially the host+port worker
// endpoints), so the key must include the domain.
struct CompilationCacheKey {
// Hash functor so the key can be used with hash-based containers (see the
// compilation_cache_ member below).
struct Hash {
size_t operator()(const CompilationCacheKey& entry) const {
// Hash the (potentially large) serialized computation via the partial
// hasher, then mix in the domain bytes with tensorflow::Hash64.
util::PartialHasher<string, 4096> hasher;
return tensorflow::Hash64(entry.domain.data(), entry.domain.size(),
hasher(entry.serialized_computation));
}
};

CompilationCacheKey(string domain, string serialized_computation)
: domain(std::move(domain)),
serialized_computation(std::move(serialized_computation)) {}
// Move-only (plus default-construction): keys are moved into the cache,
// never copied, since serialized_computation can be large.
CompilationCacheKey() = default;
CompilationCacheKey(CompilationCacheKey&&) = default;
CompilationCacheKey& operator=(CompilationCacheKey&&) = default;
bool operator==(const CompilationCacheKey& rhs) const {
return domain == rhs.domain &&
serialized_computation == rhs.serialized_computation;
}

// Resource domain within which the compiled handle is valid.
string domain;
// Serialized xrt::XLAComputation proto identifying the computation.
string serialized_computation;
};

// When we split a batch operation into per-session batches, we use this data
// structure to collect the per-session work.
struct SessionWork {
Expand Down Expand Up @@ -411,7 +445,7 @@ class XrtComputationClient : public ComputationClient {
std::unique_ptr<XrtSessionCache> session_cache_;
std::unique_ptr<XrtSessionCache> alloc_session_cache_;
std::unique_ptr<util::TriggeredTask> triggered_task_;
util::Cache<string, Computation, util::PartialHasher<string, 4096>>
util::Cache<CompilationCacheKey, Computation, CompilationCacheKey::Hash>
compilation_cache_;
std::atomic<size_t> rng_seed_;
// Access to the following members must be done while holding lock_.
Expand Down
Loading