Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions third_party/xla_client/multi_wait.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void MultiWait::Done() {
{
std::lock_guard<std::mutex> lock(mutex_);
completed_count_ += 1;
notify = completed_count_ >= count_;
notify = completed_count_ == count_;
}
if (notify) {
cv_.notify_all();
Expand Down Expand Up @@ -45,17 +45,27 @@ void MultiWait::Reset(size_t count) {
}

std::function<void()> MultiWait::Completer(std::function<void()> func) {
auto completer = [this, func = std::move(func)]() {
try {
func();
} catch (...) {
std::lock_guard<std::mutex> lock(mutex_);
exptr_ = std::current_exception();
}
Done();
auto completer = [this, func = std::move(func)]() { Complete(func); };
return completer;
}

// Builds a completer closure that co-owns the MultiWait via a shared_ptr,
// so the returned function stays safe to call even after the caller's
// reference goes away. The closure runs `func` and signals completion
// through Complete().
std::function<void()> MultiWait::Completer(std::shared_ptr<MultiWait> mwait,
                                           std::function<void()> func) {
  return [mwait = std::move(mwait), func = std::move(func)]() {
    mwait->Complete(func);
  };
}

// Runs `func`, capturing any exception it throws instead of letting it
// propagate, then marks this unit of work as done via Done(). The captured
// exception is stored in exptr_ — presumably rethrown to the thread blocked
// in Wait(); TODO(review): confirm against multi_wait.h.
void MultiWait::Complete(const std::function<void()>& func) {
  try {
    func();
  } catch (...) {
    // Record the in-flight exception under mutex_, since multiple completers
    // may run concurrently on different threads.
    std::lock_guard<std::mutex> lock(mutex_);
    exptr_ = std::current_exception();
  }
  // Signal completion unconditionally — even on failure — so waiters are
  // still woken once the expected count is reached.
  Done();
}

} // namespace util
} // namespace xla
12 changes: 11 additions & 1 deletion third_party/xla_client/multi_wait.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>

#include "tensorflow/compiler/xla/types.h"
Expand Down Expand Up @@ -31,10 +32,19 @@ class MultiWait {

// Creates a completer functor which signals the multi wait object once func
// has completed. Handles exceptions by signaling the multi wait with the
// proper status value.
// proper status value. This API returns a function which captures a MultiWait
// reference, so care must be taken such that the reference remains valid for
// the whole lifetime of the returned function.
std::function<void()> Completer(std::function<void()> func);

// Similar to the above API, but with explicit capture of the MultiWait shared
// pointer.
static std::function<void()> Completer(std::shared_ptr<MultiWait> mwait,
std::function<void()> func);

private:
void Complete(const std::function<void()>& func);

std::mutex mutex_;
std::condition_variable cv_;
size_t count_ = 0;
Expand Down
48 changes: 27 additions & 21 deletions third_party/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
}
XLA_COUNTER("XrtPartitionedTransferToServer", 1);

util::MultiWait mwait(partitions.size());
auto mwait = std::make_shared<util::MultiWait>(partitions.size());
std::vector<DataPtr> results(tensors.size());
for (size_t i = 0; i < partitions.size(); ++i) {
auto sender = [&, i]() {
Expand All @@ -316,9 +316,10 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
results[base_index + r] = std::move(partitions_results[r]);
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(sender)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(sender)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand All @@ -330,7 +331,7 @@ XrtComputationClient::TransferToServerInternal(
std::mutex lock;
XrtSessionCache::SessionMap session_map;
int64 total_size = 0;
util::MultiWait mwait(tensors.size());
auto mwait = std::make_shared<util::MultiWait>(tensors.size());
std::map<XrtSession*, SessionWork> session_work_map;
{
metrics::TimedSection timed(TransferToServerTransformMetric());
Expand Down Expand Up @@ -363,13 +364,14 @@ XrtComputationClient::TransferToServerInternal(
total_size += tdata.size();
}
};
env::ScheduleClosure(mwait.Completer(std::move(converter)));
env::ScheduleClosure(
util::MultiWait::Completer(mwait, std::move(converter)));
}
mwait.Wait();
mwait->Wait();
}
OutboundDataMetric()->AddSample(total_size);

mwait.Reset(session_work_map.size());
mwait->Reset(session_work_map.size());
std::vector<DataPtr> results(tensors.size());
for (auto& session_session_work : session_work_map) {
XrtSession* session = session_session_work.first;
Expand All @@ -388,9 +390,10 @@ XrtComputationClient::TransferToServerInternal(
}
CreateDataHandlesCounter()->AddValue(outputs.size());
};
env::ScheduleIoClosure(mwait.Completer(std::move(runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(runner)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand Down Expand Up @@ -426,7 +429,7 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
session_work->index_mapping.push_back(i);
}

util::MultiWait mwait(session_work_map.size());
auto mwait = std::make_shared<util::MultiWait>(session_work_map.size());
std::atomic<int64> total_size(0);
std::vector<Literal> results(handles.size());
for (auto& session_session_work : session_work_map) {
Expand All @@ -446,9 +449,10 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
total_size += results[li].size_bytes();
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(runner)));
}
mwait.Wait();
mwait->Wait();
InboundDataMetric()->AddSample(total_size.load());
return results;
}
Expand All @@ -458,7 +462,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
metrics::TimedSection timed(CompileMetric());

std::mutex lock;
util::MultiWait mwait(instances.size());
auto mwait = std::make_shared<util::MultiWait>(instances.size());
std::vector<ProgramShape> program_shapes(instances.size());
std::vector<ComputationPtr> results(instances.size());
std::vector<CompilationCacheKey> cache_keys(instances.size());
Expand Down Expand Up @@ -499,10 +503,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
results[i] = computation_ptr;
}
};
env::ScheduleClosure(mwait.Completer(std::move(builder)));
env::ScheduleClosure(util::MultiWait::Completer(mwait, std::move(builder)));
}
mwait.Wait();
mwait.Reset(session_work_map.size());
mwait->Wait();
mwait->Reset(session_work_map.size());

for (auto& session_and_work : session_work_map) {
XrtSession* session = session_and_work.first;
Expand Down Expand Up @@ -532,9 +536,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
CreateCompileHandlesCounter()->AddValue(1);
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(session_runner)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand Down Expand Up @@ -626,7 +631,7 @@ XrtComputationClient::RunComputations(
}
XLA_CHECK_EQ(computations.size(), devices.size());

util::MultiWait mwait(session_replicas.size());
auto mwait = std::make_shared<util::MultiWait>(session_replicas.size());
std::vector<std::vector<DataPtr>> results(devices.size());
for (auto& sess_replica : session_replicas) {
XrtSession* session = sess_replica.first;
Expand Down Expand Up @@ -655,9 +660,10 @@ XrtComputationClient::RunComputations(
GetEffectiveDevice(devices[replica]));
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(session_runner)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand Down
7 changes: 4 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset,
std::min<size_t>(num_threads, std::thread::hardware_concurrency());
size_t block_size = size / num_threads;

xla::util::MultiWait mwait(num_threads);
auto mwait = std::make_shared<xla::util::MultiWait>(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
auto reader = [&, i]() {
uint64_t base = static_cast<uint64_t>(i) * block_size;
Expand All @@ -491,9 +491,10 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset,
XLA_CHECK_OK(
file->Read(offset + base, tsize, &result, buffer.get() + base));
};
xla::env::ScheduleIoClosure(mwait.Completer(std::move(reader)));
xla::env::ScheduleIoClosure(
xla::util::MultiWait::Completer(mwait, std::move(reader)));
}
mwait.Wait();
mwait->Wait();
}
return py::bytes(buffer.get(), size);
}
Expand Down
7 changes: 4 additions & 3 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,15 +423,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape,
std::vector<xla::int64> iter_dims = GetIterationDimensions(dest_shape);
std::vector<CopyPartition> parts =
CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front());
xla::util::MultiWait mwait(parts.size());
auto mwait = std::make_shared<xla::util::MultiWait>(parts.size());
for (size_t i = 0; i < parts.size(); ++i) {
auto copy_fn = [&, i]() {
SlicedCopy<SType, DType>(dest_shape.dimensions(), src_data, src_strides,
dest_data, dest_strides, iter_dims, parts[i]);
};
xla::env::ScheduleClosure(mwait.Completer(std::move(copy_fn)));
xla::env::ScheduleClosure(
xla::util::MultiWait::Completer(mwait, std::move(copy_fn)));
}
mwait.Wait();
mwait->Wait();
}
}

Expand Down