From 6bd003a00deb20d251ed502c88733aff96bab5e5 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Thu, 7 Feb 2019 10:01:49 -0800 Subject: [PATCH] Wire in ATEN XLA Python bindings and added simple end-to-end test. --- scripts/gen.py | 1 + test/cpp/cpp_test_util.cpp | 10 +++------ test/cpp/cpp_test_util.h | 2 -- test/test_operations.py | 11 ++++++++++ torch_patches/16844.diff | 27 +++++++++++++++++++++++++ torch_xla/__init__.py | 2 ++ torch_xla/csrc/aten_xla_bridge.cpp | 14 ++++++++++++- torch_xla/csrc/aten_xla_bridge.h | 7 +++++++ torch_xla/csrc/init_python_bindings.cpp | 6 ++++++ torch_xla/csrc/torch_util.h | 6 ++++++ 10 files changed, 76 insertions(+), 10 deletions(-) create mode 100644 torch_patches/16844.diff diff --git a/scripts/gen.py b/scripts/gen.py index 3cc53a731a3c..8166bdfc579f 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -217,6 +217,7 @@ class {type_name} : public AtenXlaType {{ _XLA_FUNCTIONS = { 'empty': 'bridge::CreateEmptyTensor', + 'randn': 'bridge::CreateRandTensor', } _RESULT_NAME = 'x_result' diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 5188090d5940..4c8b3adab5db 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -6,21 +6,17 @@ #include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "torch/csrc/autograd/variable.h" +#include "torch_util.h" namespace torch_xla { namespace cpp_test { -at::Tensor ToTensor(const at::Tensor& tensor) { - return tensor.is_variable() ? torch::autograd::as_variable_ref(tensor).data() - : tensor; -} - at::Tensor ToTensor(XLATensor& xla_tensor) { - return ToTensor(xla_tensor.ToTensor()); + return torch_xla::ToTensor(xla_tensor.ToTensor()); } at::Tensor ToCpuTensor(const at::Tensor& t) { - at::Tensor tensor = ToTensor(t); + at::Tensor tensor = torch_xla::ToTensor(t); XLATensorImpl* impl = dynamic_cast(tensor.unsafeGetTensorImpl()); return impl != nullptr ? ToTensor(impl->tensor()) : tensor; diff --git a/test/cpp/cpp_test_util.h b/test/cpp/cpp_test_util.h index 1d6060975131..1560937d55b0 100644 --- a/test/cpp/cpp_test_util.h +++ b/test/cpp/cpp_test_util.h @@ -18,8 +18,6 @@ namespace cpp_test { // tensor, it will be returned. at::Tensor ToCpuTensor(const at::Tensor& t); -at::Tensor ToTensor(const at::Tensor& tensor); - at::Tensor ToTensor(XLATensor& xla_tensor); bool EqualValues(at::Tensor a, at::Tensor b); diff --git a/test/test_operations.py b/test/test_operations.py index b02242906b1c..5edc9e634e5c 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -38,6 +38,10 @@ def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) +def _xla_device(n=0): + return torch.device('xla:{}'.format(n)) + + class Holder(object): pass @@ -1353,6 +1357,13 @@ def test_log_softmax(self): self.assertEqualRel(out.data, expected.data) +class TestAtenXlaTensor(XlaTestCase): + + def test_size(self): + x = _gen_tensor(4, 2, device=_xla_device()) + torch_xla._XLAC._get_xla_tensor(x) + + if __name__ == '__main__': torch.set_default_tensor_type('torch.FloatTensor') torch.manual_seed(42) diff --git a/torch_patches/16844.diff b/torch_patches/16844.diff new file mode 100644 index 000000000000..f470b08044ce --- /dev/null +++ b/torch_patches/16844.diff @@ -0,0 +1,27 @@ +commit 55fbb131fd1676365af72417aeb46292251075f3 +Author: Davide Libenzi +Date: Thu Feb 7 08:35:20 2019 -0800 + + Add recognition for XLA device types. + +diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp +index 1d2d1ec0b..79ee6251e 100644 +--- a/c10/core/Device.cpp ++++ b/c10/core/Device.cpp +@@ -13,7 +13,7 @@ + namespace c10 { + namespace { + DeviceType parse_type(const std::string& device_string) { +- static const std::array, 8> types = {{ ++ static const std::array, 9> types = {{ + {"cpu", DeviceType::CPU}, + {"cuda", DeviceType::CUDA}, + {"mkldnn", DeviceType::MKLDNN}, +@@ -22,6 +22,7 @@ DeviceType parse_type(const std::string& device_string) { + {"ideep", DeviceType::IDEEP}, + {"hip", DeviceType::HIP}, + {"msnpu", DeviceType::MSNPU}, ++ {"xla", DeviceType::XLA}, + }}; + auto device = std::find_if( + types.begin(), diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 668aa537f30e..b97abb4ff742 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -12,3 +12,5 @@ def _torch_xla_zeros_like(p): torch.zeros_like = _torch_xla_zeros_like + +_XLAC._register_aten_types() diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 1b046593991f..026693326173 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -12,10 +12,22 @@ at::Tensor CreateEmptyTensor(at::IntList size, return at::empty(size, options.device(at::kCPU)); } +at::Tensor CreateRandTensor(at::IntArrayRef size, + at::Generator* generator, + const at::TensorOptions& options) { + return at::randn(size, generator, options.device(at::DeviceType::CPU)); +} + +at::Tensor CreateRandTensor(at::IntArrayRef size, + const at::TensorOptions& options) { + return at::randn(size, options.device(at::DeviceType::CPU)); +} + XLATensor& GetXlaTensor(const at::Tensor& tensor) { XLATensorImpl* impl = dynamic_cast(tensor.unsafeGetTensorImpl()); - XLA_CHECK(impl != nullptr); + XLA_CHECK(impl != nullptr) + << "Input tensor is not an XLA tensor: " << tensor.toString(); return impl->tensor(); } diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h index 0bddd20aad19..2a8f41edc466 100644 --- a/torch_xla/csrc/aten_xla_bridge.h +++ b/torch_xla/csrc/aten_xla_bridge.h @@ -16,6 +16,13 @@ namespace bridge { at::Tensor CreateEmptyTensor(at::IntList size, const at::TensorOptions& options); +// Helper function which creates a random CPU ATEN tensor. +at::Tensor CreateRandTensor(at::IntArrayRef size, + at::Generator* generator, + const at::TensorOptions& options); +at::Tensor CreateRandTensor(at::IntArrayRef size, + const at::TensorOptions& options); + // Extracts the XLATensor out of our version of at::Tensor. Throws an exception // if tensor is not an XLA tensor. XLATensor& GetXlaTensor(const at::Tensor& tensor); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b1588dfe320c..52ffbd5106e4 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1,5 +1,7 @@ #include "init_python_bindings.h" +#include "aten_xla_bridge.h" +#include "aten_xla_type.h" #include "module.h" #include "passes/eval_static_size.h" #include "passes/replace_in_place_ops.h" @@ -58,6 +60,10 @@ void InitXlaModuleBindings(py::module m) { .def("parameters_buffers", [](XlaModule& xla_module) { return xla_module.parameters_buffers(); }); + m.def("_register_aten_types", []() { AtenXlaType::RegisterAtenTypes(); }); + m.def("_get_xla_tensor", [](const at::Tensor& tensor) -> XLATensor { + return bridge::GetXlaTensor(ToTensor(tensor)); + }); m.def("_xla_sync_multi", [](std::vector& tensors) { NoGilSection nogil; XLATensor::ApplyPendingGraph(&tensors, /*apply_context=*/nullptr); diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h index bf7cdce05d9c..a6111507747d 100644 --- a/torch_xla/csrc/torch_util.h +++ b/torch_xla/csrc/torch_util.h @@ -6,6 +6,7 @@ #include #include "module.h" #include "tensor.h" +#include "torch/csrc/autograd/variable.h" #include "torch/csrc/jit/pybind_utils.h" namespace torch_xla { @@ -21,4 +22,9 @@ static inline at::Tensor CopyTensor(const at::Tensor& ref) { return ref.to(ref.options(), /*non_blocking=*/false, /*copy=*/true); } +static inline at::Tensor ToTensor(const at::Tensor& tensor) { + return tensor.is_variable() ? torch::autograd::as_variable_ref(tensor).data() + : tensor; +} + } // namespace torch_xla