1 change: 1 addition & 0 deletions setup.py
@@ -162,6 +162,7 @@ def make_relative_rpath(path):


extra_compile_args = [
'-std=c++14',
'-Wno-sign-compare',
'-Wno-deprecated-declarations',
'-Wno-return-type',
1 change: 1 addition & 0 deletions test/cpp/CMakeLists.txt
@@ -56,6 +56,7 @@ set(TORCH_XLA_TEST_SOURCES
test_mayberef.cpp
test_replication.cpp
test_tensor.cpp
test_xla_util_cache.cpp
torch_xla_test.cpp
)

6 changes: 3 additions & 3 deletions test/cpp/test_replication.cpp
@@ -49,14 +49,14 @@ void TestSingleReplication(const std::vector<Device>& devices) {
}
auto tensors_data = CreateTensorsData(tensors, device_strings);

std::vector<std::vector<std::shared_ptr<xla::ComputationClient::Data>>>
results(device_strings.size());
std::vector<std::vector<xla::ComputationClient::DataPtr>> results(
device_strings.size());
xla::xla_util::MultiWait mwait(device_strings.size());
xla::ComputationClient::ExecuteComputationOptions exec_options;
for (size_t i = 0; i < device_strings.size(); ++i) {
auto executor = [&, i]() {
results[i] = xla::ComputationClient::Get()->ExecuteComputation(
*compiled_computations[i], {tensors_data[i].get()}, device_strings[i],
*compiled_computations[i], {tensors_data[i]}, device_strings[i],
exec_options);
};
xla::xla_env::ScheduleIoClosure(mwait.Completer(std::move(executor)));
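Note: the raw-pointer argument {tensors_data[i].get()} becomes {tensors_data[i]} because ExecuteComputation now takes the shared_ptr alias xla::ComputationClient::DataPtr, introduced in computation_client.h below.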
33 changes: 33 additions & 0 deletions test/cpp/test_xla_util_cache.cpp
@@ -0,0 +1,33 @@
#include <gtest/gtest.h>

#include <string>

#include "cpp_test_util.h"
#include "tensorflow/compiler/xla/xla_client/cache.h"
#include "tensorflow/compiler/xla/xla_client/util.h"

namespace torch_xla {
namespace cpp_test {

TEST(XlaUtilCacheTest, BasicTest) {
static const size_t kMaxSize = 64;
xla::util::Cache<int, std::string> cache(kMaxSize);

for (int i = 0; i < 2 * kMaxSize; ++i) {
std::string istr = std::to_string(i);
auto ptr = cache.Add(i, std::make_shared<std::string>(istr));
ASSERT_NE(ptr, nullptr);
EXPECT_EQ(*ptr, istr);

ptr = cache.Get(i);
ASSERT_NE(ptr, nullptr);
EXPECT_EQ(*ptr, istr);
}
for (int i = 0; i < kMaxSize - 1; ++i) {
auto ptr = cache.Get(i);
EXPECT_EQ(ptr, nullptr);
}
}

} // namespace cpp_test
} // namespace torch_xla
4 changes: 4 additions & 0 deletions test/test_operations.py
@@ -634,6 +634,8 @@ def loop_fn(model, loader):
model_parallel = dp.DataParallel(
XlaMNIST, train_loader, loop_fn, device_ids=devices)
model_parallel()
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
print(torch_xla._XLAC._xla_metrics_report())


class TestParallelTensorResnet18(XlaTestCase):
@@ -665,6 +667,8 @@ def loop_fn(model, loader):
model_parallel = dp.DataParallel(
torchvision.models.resnet18, train_loader, loop_fn, device_ids=devices)
model_parallel()
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
print(torch_xla._XLAC._xla_metrics_report())


class AxPlusB(nn.Module):
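Note: when METRICS_DEBUG is set in the environment (e.g. METRICS_DEBUG=1), these data-parallel tests now print the XLA metrics report after the run; when unset, behavior is unchanged.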
69 changes: 39 additions & 30 deletions third_party/xla_client/cache.h
@@ -3,72 +3,57 @@

#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

namespace xla {
namespace util {

// Generic key and object cache with LRU expiration policy.
// Generic key and object cache with LRU expiration policy. The objects of type
// T will be stored as std::shared_ptr<T> and taken and returned as such by the
// cache API.
template <typename K, typename T, typename H = std::hash<K>,
typename E = std::equal_to<K>>
class Cache {
using Element = std::pair<K, T>;
using ElementList = std::list<Element>;

struct Hasher {
size_t operator()(const K* key) const { return hasher(*key); }

H hasher;
};

struct Equaler {
bool operator()(const K* k1, const K* k2) const {
return equaler(*k1, *k2);
}

E equaler;
};

using ElementMap =
std::unordered_map<const K*, typename ElementList::iterator, Hasher,
Equaler>;

public:
using TypePtr = std::shared_ptr<T>;
using Element = std::pair<K, TypePtr>;

explicit Cache(size_t max_size) : max_size_(max_size) {}

// Adds an object to the cache, unless it already exists. If the cache grows
// beyond the limit set during construction, the oldest used object will be
// removed from the cache.
void Add(K key, T object) {
TypePtr Add(K key, TypePtr object) {
std::lock_guard<std::mutex> slock(lock_);
element_list_.emplace_front(Element(std::move(key), std::move(object)));
auto it = element_list_.begin();
if (!element_map_.emplace(&it->first, it).second) {
auto emplace_result = element_map_.emplace(&it->first, it);
if (!emplace_result.second) {
element_list_.erase(it);
DoLRU(emplace_result.first->second);
} else if (element_list_.size() > max_size_) {
Element* last = &element_list_.back();
element_map_.erase(&last->first);
element_list_.pop_back();
}
return emplace_result.first->second->second;
}

// Retrieves the existing object if it exists. If it does, its position in
// the LRU list gets moved to the head of the list.
// Returns nullptr if no object with the specified key is found within the
// cache.
const T* Get(const K& key) {
TypePtr Get(const K& key) {
std::lock_guard<std::mutex> slock(lock_);
auto it = element_map_.find(&key);
if (it == element_map_.end()) {
return nullptr;
}
if (it->second != element_list_.begin()) {
// LRU re-positioning.
element_list_.splice(element_list_.begin(), element_list_, it->second);
}
return &it->second->second;
DoLRU(it->second);
return it->second->second;
}

bool Erase(const K& key) {
@@ -90,6 +75,30 @@ class Cache {
}

private:
using ElementList = std::list<Element>;

struct Hasher {
size_t operator()(const K* key) const { return hasher(*key); }

H hasher;
};

struct Equaler {
bool operator()(const K* k1, const K* k2) const {
return equaler(*k1, *k2);
}

E equaler;
};

using ElementMap =
std::unordered_map<const K*, typename ElementList::iterator, Hasher,
Equaler>;

void DoLRU(typename ElementList::iterator it) {
element_list_.splice(element_list_.begin(), element_list_, it);
}

std::mutex lock_;
size_t max_size_ = 0;
ElementList element_list_;
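As a reading aid, here is a minimal sketch of the reworked shared_ptr-based API (the keys, values, and max_size are illustrative, not part of the diff):

    xla::util::Cache<int, std::string> cache(/*max_size=*/2);
    auto one = cache.Add(1, std::make_shared<std::string>("one"));
    // A duplicate Add() no longer silently drops the new entry: it returns
    // the existing value and refreshes its LRU position.
    auto dup = cache.Add(1, std::make_shared<std::string>("uno"));  // *dup == "one"
    cache.Add(2, std::make_shared<std::string>("two"));
    cache.Get(1);                                          // key 1 moves to the front
    cache.Add(3, std::make_shared<std::string>("three"));  // evicts key 2, the LRU entry
    // cache.Get(2) == nullptr; Get() now returns a shared_ptr rather than a raw pointer.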
8 changes: 5 additions & 3 deletions third_party/xla_client/computation_client.cc
@@ -52,10 +52,11 @@ void AddXrtHostDevices(const string& worker_name, int task_no,
XrtComputationClient::Options* options) {
struct Devices {
const char* name;
const char* tf_name;
int count;
} const devices[] = {
{"TPU", sys_util::GetEnvInt("TPU_NUM_DEVICES", 8)},
{"CPU", sys_util::GetEnvInt("CPU_NUM_DEVICES", 1)},
{"TPU", "TPU", sys_util::GetEnvInt("TPU_NUM_DEVICES", 8)},
{"CPU", "XLA_CPU", sys_util::GetEnvInt("CPU_NUM_DEVICES", 1)},
};
string host_port = server.compare(0, 7, "grpc://") == 0
? server
@@ -66,9 +67,10 @@ void AddXrtHostDevices(const string& worker_name, int task_no,
int& device_ordinal = (*device_ordinals)[device.name];
for (int j = 0; j < device.count; ++j, ++device_ordinal) {
string device_name = absl::StrCat(device.name, ":", device_ordinal);
string tf_device_name = absl::StrCat(device.tf_name, ":", device_ordinal);
string xrt_device_name =
absl::StrCat("/job:", worker_name, "/replica:0/task:", task_no,
"/device:", device_name);
"/device:", tf_device_name);
options->device_map.emplace(device_name, xrt_device_name);
}
}
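Note: with the new tf_name field, a local device such as CPU:0 now maps to an XRT device string of the form /job:<worker_name>/replica:0/task:<task_no>/device:XLA_CPU:0, i.e. the TF device type XLA_CPU rather than CPU; TPU mappings are unchanged since name and tf_name coincide for TPU.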
30 changes: 19 additions & 11 deletions third_party/xla_client/computation_client.h
@@ -31,12 +31,16 @@ class ComputationClient {

const Shape& shape() const { return shape_; }

virtual void Swap(Data* data) = 0;

private:
int64 unique_id_ = 0;
string device_;
Shape shape_;
};

using DataPtr = std::shared_ptr<Data>;

class Computation {
public:
Computation(XlaComputation computation, ProgramShape program_shape,
@@ -105,14 +109,18 @@

virtual ~ComputationClient() {}

// Creates a Data object with no actual device handle in it. The device handle
// will be populated in an asynchronous fashion.
virtual DataPtr CreateDataPlaceholder(string device, Shape shape) = 0;

// Transfers local tensor values to the TPU servers and fetches the handles.
virtual std::vector<std::shared_ptr<Data>> TransferToServer(
virtual std::vector<DataPtr> TransferToServer(
tensorflow::gtl::ArraySlice<const TensorSource> tensors) = 0;

// Reads the tensor literal values stored at TPU server sites, behind the
// supplied handles.
virtual std::vector<Literal> TransferFromServer(
tensorflow::gtl::ArraySlice<const std::shared_ptr<Data>> handles) = 0;
tensorflow::gtl::ArraySlice<const DataPtr> handles) = 0;

// Compiles a set of computations.
virtual std::vector<std::shared_ptr<Computation>> Compile(
@@ -122,10 +130,10 @@
// The passed device must match the common device of the arguments Data.
// If options.explode_tuple is true, the output tuple will be decomposed into
// its single elements.
virtual std::vector<std::shared_ptr<Data>> ExecuteComputation(
virtual std::vector<DataPtr> ExecuteComputation(
const Computation& computation,
tensorflow::gtl::ArraySlice<Data*> arguments, const string& device,
const ExecuteComputationOptions& options) = 0;
tensorflow::gtl::ArraySlice<const DataPtr> arguments,
const string& device, const ExecuteComputationOptions& options) = 0;

// Executes the computation in replicated mode.
// The size of the arguments vector is the number of replicas to execute,
@@ -138,9 +146,9 @@
// The result[i], a vector itself, will be the result of the computation fed
// with arguments[i]. If options.explode_tuple is true, the output tuples will
// be decomposed into their single elements.
virtual std::vector<std::vector<std::shared_ptr<Data>>> ExecuteReplicated(
virtual std::vector<std::vector<DataPtr>> ExecuteReplicated(
const Computation& computation,
const std::vector<std::vector<Data*>>& arguments,
const std::vector<std::vector<DataPtr>>& arguments,
tensorflow::gtl::ArraySlice<const string> devices,
const ExecuteReplicatedOptions& options) = 0;

@@ -151,14 +159,14 @@
// Returns a vector of vectors of device side Data object, with result[i]
// being the return value of computations[i]. If options.explode_tuple is
// true, the output tuples will be decomposed into their single elements.
virtual std::vector<std::vector<std::shared_ptr<Data>>> ExecuteParallel(
virtual std::vector<std::vector<DataPtr>> ExecuteParallel(
tensorflow::gtl::ArraySlice<const Computation* const> computations,
const std::vector<std::vector<Data*>>& arguments,
const std::vector<std::vector<DataPtr>>& arguments,
tensorflow::gtl::ArraySlice<const string> devices,
const ExecuteParallelOptions& options) = 0;

virtual std::vector<std::vector<std::shared_ptr<Data>>> DeconstructTuple(
tensorflow::gtl::ArraySlice<const std::shared_ptr<Data>> tuples) = 0;
virtual std::vector<std::vector<DataPtr>> DeconstructTuple(
tensorflow::gtl::ArraySlice<const DataPtr> tuples) = 0;

virtual string GetDefaultDevice() const = 0;

5 changes: 5 additions & 0 deletions third_party/xla_client/metrics.h
@@ -170,6 +170,11 @@ class TimedSection {
int64 start_;
};

#define XLA_TIMED(name) \
static xla::metrics::Metric* timed_metric = \
new xla::metrics::Metric(name, xla::metrics::MetricFnTime); \
xla::metrics::TimedSection timed_section(timed_metric)

} // namespace metrics
} // namespace xla

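A minimal usage sketch of the new macro (function and metric names here are hypothetical): the static Metric is allocated once per call site, and the TimedSection times the enclosing scope via RAII:

    void CompileGraph() {
      XLA_TIMED("CompileGraphTime");  // hypothetical metric name
      // ... all work until the end of this scope is accounted to the metric ...
    }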
2 changes: 1 addition & 1 deletion third_party/xla_client/multi_wait.cc
@@ -36,7 +36,7 @@ void MultiWait::Reset(size_t count) {
}

std::function<void()> MultiWait::Completer(std::function<void()> func) {
auto completer = [this, func{std::move(func)}]() {
auto completer = [this, func = std::move(func)]() {
try {
func();
Done();
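Note: [func = std::move(func)] is the standard C++14 init-capture spelling; together with the -std=c++14 flag added in setup.py above, it replaces the brace-initialized capture, whose type deduction varies across compilers.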
31 changes: 31 additions & 0 deletions third_party/xla_client/util.h
@@ -8,12 +8,43 @@
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {
namespace util {

class Cleanup {
public:
explicit Cleanup(std::function<void(Status)> func) : func_(std::move(func)) {}
Cleanup(Cleanup&& ref) : func_(std::move(ref.func_)) {}
Cleanup(const Cleanup&) = delete;

~Cleanup() {
if (func_ != nullptr) {
func_(std::move(status_));
}
}

Cleanup& operator=(const Cleanup&) = delete;

Cleanup& operator=(Cleanup&& ref) {
if (this != &ref) {
func_ = std::move(ref.func_);
}
return *this;
}

void Release() { func_ = nullptr; }

void SetStatus(Status status) { status_ = std::move(status); }

private:
std::function<void(Status)> func_;
Status status_;
};

// Allows APIs which might return const references and values, to not be forced
// to return values in the signature.
template <typename T>
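A minimal sketch of how the new Cleanup guard is meant to be used (Allocate, Deallocate, Use, and Handle are hypothetical): the stored function runs at destruction with the recorded Status, unless Release() is called on the success path:

    xla::Status DoWork() {
      Handle handle = Allocate();  // hypothetical resource acquisition
      xla::util::Cleanup cleanup([handle](xla::Status status) {
        // Runs if the scope is left without Release(); 'status' is whatever
        // was recorded via SetStatus() (default-constructed OK otherwise).
        Deallocate(handle);
      });
      xla::Status status = Use(handle);
      if (!status.ok()) {
        cleanup.SetStatus(status);  // hand the failure to the cleanup function
        return status;              // 'cleanup' fires here and releases 'handle'
      }
      cleanup.Release();  // success: disarm the guard and keep the resource
      return xla::Status::OK();
    }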