From 35c1fc9347fb59661c8d15b27f1b35edd2ff3827 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Mon, 25 Feb 2019 19:37:49 -0800 Subject: [PATCH] Handle copy() operation directly. --- torch_xla/csrc/aten_xla_type.cpp | 10 ++++++++++ torch_xla/csrc/aten_xla_type.h | 3 +++ 2 files changed, 13 insertions(+) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 842a3a1e10ce..51eb75ef2020 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -57,6 +57,16 @@ int64_t AtenXlaType::numel(const at::Tensor& self) const { return xla::ShapeUtil::ElementsIn(self_tensor.shape()); } +at::Tensor AtenXlaType::copy(const at::Tensor& src, bool /* non_blocking */, + at::optional to_device) const { + std::vector tensors = {src}; + std::vector writeables = {false}; + auto xla_tensors = bridge::XlaCreateTensorList(tensors, &writeables); + Device device = to_device ? bridge::AtenDeviceToXlaDevice(*to_device) + : *GetDefaultDevice(); + return bridge::CreateXlaTensor(xla_tensors.front(), device); +} + at::Tensor AtenXlaType::_s_copy_from(const at::Tensor& self, const at::Tensor& dst, bool /* non_blocking */) const { diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index 5f3bc96b0ae8..6beef4ecedcb 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -11,6 +11,9 @@ class AtenXlaType : public AtenXlaTypeBase { int64_t numel(const at::Tensor& self) const override; + at::Tensor copy(const at::Tensor& src, bool non_blocking, + at::optional to_device) const override; + at::Tensor _s_copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) const override;