Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class {type_name} : public AtenXlaType {{

_XLA_FUNCTIONS = {
'empty': 'bridge::CreateEmptyTensor',
'randn': 'bridge::CreateRandTensor',
}

_RESULT_NAME = 'x_result'
Expand Down
10 changes: 3 additions & 7 deletions test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/autograd/variable.h"
#include "torch_util.h"

namespace torch_xla {
namespace cpp_test {

at::Tensor ToTensor(const at::Tensor& tensor) {
return tensor.is_variable() ? torch::autograd::as_variable_ref(tensor).data()
: tensor;
}

at::Tensor ToTensor(XLATensor& xla_tensor) {
return ToTensor(xla_tensor.ToTensor());
return torch_xla::ToTensor(xla_tensor.ToTensor());
}

at::Tensor ToCpuTensor(const at::Tensor& t) {
at::Tensor tensor = ToTensor(t);
at::Tensor tensor = torch_xla::ToTensor(t);
XLATensorImpl* impl =
dynamic_cast<XLATensorImpl*>(tensor.unsafeGetTensorImpl());
return impl != nullptr ? ToTensor(impl->tensor()) : tensor;
Expand Down
2 changes: 0 additions & 2 deletions test/cpp/cpp_test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ namespace cpp_test {
// tensor, it will be returned.
at::Tensor ToCpuTensor(const at::Tensor& t);

at::Tensor ToTensor(const at::Tensor& tensor);

at::Tensor ToTensor(XLATensor& xla_tensor);

bool EqualValues(at::Tensor a, at::Tensor b);
Expand Down
11 changes: 11 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)


def _xla_device(n=0):
return torch.device('xla:{}'.format(n))


class Holder(object):
pass

Expand Down Expand Up @@ -1353,6 +1357,13 @@ def test_log_softmax(self):
self.assertEqualRel(out.data, expected.data)


class TestAtenXlaTensor(XlaTestCase):

def test_size(self):
x = _gen_tensor(4, 2, device=_xla_device())
torch_xla._XLAC._get_xla_tensor(x)


if __name__ == '__main__':
torch.set_default_tensor_type('torch.FloatTensor')
torch.manual_seed(42)
Expand Down
27 changes: 27 additions & 0 deletions torch_patches/16844.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
commit 55fbb131fd1676365af72417aeb46292251075f3
Author: Davide Libenzi <dlibenzi@google.com>
Date: Thu Feb 7 08:35:20 2019 -0800

Add recognition for XLA device types.

diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp
index 1d2d1ec0b..79ee6251e 100644
--- a/c10/core/Device.cpp
+++ b/c10/core/Device.cpp
@@ -13,7 +13,7 @@
namespace c10 {
namespace {
DeviceType parse_type(const std::string& device_string) {
- static const std::array<std::pair<std::string, DeviceType>, 8> types = {{
+ static const std::array<std::pair<std::string, DeviceType>, 9> types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"mkldnn", DeviceType::MKLDNN},
@@ -22,6 +22,7 @@ DeviceType parse_type(const std::string& device_string) {
{"ideep", DeviceType::IDEEP},
{"hip", DeviceType::HIP},
{"msnpu", DeviceType::MSNPU},
+ {"xla", DeviceType::XLA},
}};
auto device = std::find_if(
types.begin(),
2 changes: 2 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ def _torch_xla_zeros_like(p):


torch.zeros_like = _torch_xla_zeros_like

_XLAC._register_aten_types()
14 changes: 13 additions & 1 deletion torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,22 @@ at::Tensor CreateEmptyTensor(at::IntList size,
return at::empty(size, options.device(at::kCPU));
}

at::Tensor CreateRandTensor(at::IntArrayRef size,
at::Generator* generator,
const at::TensorOptions& options) {
return at::randn(size, generator, options.device(at::DeviceType::CPU));
}

at::Tensor CreateRandTensor(at::IntArrayRef size,
const at::TensorOptions& options) {
return at::randn(size, options.device(at::DeviceType::CPU));
}

XLATensor& GetXlaTensor(const at::Tensor& tensor) {
XLATensorImpl* impl =
dynamic_cast<XLATensorImpl*>(tensor.unsafeGetTensorImpl());
XLA_CHECK(impl != nullptr);
XLA_CHECK(impl != nullptr)
<< "Input tensor is not an XLA tensor: " << tensor.toString();
return impl->tensor();
}

Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ namespace bridge {
at::Tensor CreateEmptyTensor(at::IntList size,
const at::TensorOptions& options);

// Helper function which creates a random CPU ATEN tensor.
at::Tensor CreateRandTensor(at::IntArrayRef size,
at::Generator* generator,
const at::TensorOptions& options);
at::Tensor CreateRandTensor(at::IntArrayRef size,
const at::TensorOptions& options);

// Extracts the XLATensor out of our version of at::Tensor. Throws an exception
// if tensor is not an XLA tensor.
XLATensor& GetXlaTensor(const at::Tensor& tensor);
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "init_python_bindings.h"

#include "aten_xla_bridge.h"
#include "aten_xla_type.h"
#include "module.h"
#include "passes/eval_static_size.h"
#include "passes/replace_in_place_ops.h"
Expand Down Expand Up @@ -58,6 +60,10 @@ void InitXlaModuleBindings(py::module m) {
.def("parameters_buffers", [](XlaModule& xla_module) {
return xla_module.parameters_buffers();
});
m.def("_register_aten_types", []() { AtenXlaType::RegisterAtenTypes(); });
m.def("_get_xla_tensor", [](const at::Tensor& tensor) -> XLATensor {
return bridge::GetXlaTensor(ToTensor(tensor));
});
m.def("_xla_sync_multi", [](std::vector<XLATensor>& tensors) {
NoGilSection nogil;
XLATensor::ApplyPendingGraph(&tensors, /*apply_context=*/nullptr);
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/torch_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <ATen/ATen.h>
#include "module.h"
#include "tensor.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/pybind_utils.h"

namespace torch_xla {
Expand All @@ -21,4 +22,9 @@ static inline at::Tensor CopyTensor(const at::Tensor& ref) {
return ref.to(ref.options(), /*non_blocking=*/false, /*copy=*/true);
}

static inline at::Tensor ToTensor(const at::Tensor& tensor) {
return tensor.is_variable() ? torch::autograd::as_variable_ref(tensor).data()
: tensor;
}

} // namespace torch_xla