From 738348f0add194ce53d88ff69e4f3b8ca3c9e607 Mon Sep 17 00:00:00 2001 From: Davide Libenzi Date: Tue, 24 Sep 2019 08:01:24 -0700 Subject: [PATCH] Drop the unwrap business and handle any sort of tensor type coming in where they should have not. --- torch_xla/csrc/aten_xla_bridge.cpp | 21 +++++++++++---------- torch_xla/csrc/aten_xla_bridge.h | 4 ---- torch_xla/csrc/aten_xla_type.cpp | 4 ++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 393f95492100..2b740b21f1de 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -74,10 +74,6 @@ std::vector GetXlaTensors( return xla_tensors; } -XLATensor GetXlaTensorUnwrap(const at::Tensor& tensor) { - return GetXlaTensor(tensor); -} - XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device) { if (!tensor.defined()) { return XLATensor(); @@ -95,11 +91,12 @@ std::vector XlaCreateTensorList(const at::TensorList& tensors) { for (size_t i = 0; i < tensors.size(); ++i) { const at::Tensor& tensor = tensors[i]; if (tensor.defined()) { - if (tensor.device().is_cpu()) { - aten_xla_tensors[i] = tensor; - } else { + auto xtensor = TryGetXlaTensor(tensor); + if (xtensor) { to_translate[i] = true; - xla_tensors.push_back(GetXlaTensorUnwrap(tensor)); + xla_tensors.push_back(*xtensor); + } else { + aten_xla_tensors[i] = tensor; } } } @@ -119,8 +116,12 @@ void XlaUpdateTensors( tensorflow::gtl::ArraySlice source_cpu_tensors, tensorflow::gtl::ArraySlice indices) { for (auto index : indices) { - XLATensor xtensor = GetXlaTensorUnwrap(dest_xla_tensors.at(index)); - xtensor.UpdateFromTensor(source_cpu_tensors.at(index)); + auto xtensor = TryGetXlaTensor(dest_xla_tensors.at(index)); + if (xtensor) { + xtensor->UpdateFromTensor(source_cpu_tensors.at(index)); + } else { + dest_xla_tensors.at(index).copy_(source_cpu_tensors.at(index)); + } } } diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h index 65c7711bebbd..d5e7812e9d87 100644 --- a/torch_xla/csrc/aten_xla_bridge.h +++ b/torch_xla/csrc/aten_xla_bridge.h @@ -23,10 +23,6 @@ XLATensor GetXlaTensor(const at::Tensor& tensor); std::vector GetXlaTensors( tensorflow::gtl::ArraySlice tensors); -// Like GetXlaTensor(), but if tensor is a variable, unwraps it and access the -// underline tensor. -XLATensor GetXlaTensorUnwrap(const at::Tensor& tensor); - // If tensor is an XLA tensor type, returns the XLATensor embedded within it, // otherwise creates a new XLA tensor type with tensor as data. XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d10cf3d75926..caf4167981f0 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1273,14 +1273,14 @@ at::Tensor AtenXlaType::full(at::IntArrayRef size, at::Scalar fill_value, at::Tensor AtenXlaType::full_like(const at::Tensor& self, at::Scalar fill_value) { - XLATensor self_tensor = bridge::GetXlaTensorUnwrap(self); + XLATensor self_tensor = bridge::GetXlaTensor(self); return bridge::AtenFromXlaTensor(XLATensor::full_like( self_tensor, fill_value, self_tensor.GetDevice(), c10::nullopt)); } at::Tensor AtenXlaType::full_like(const at::Tensor& self, at::Scalar fill_value, const at::TensorOptions& options) { - XLATensor self_tensor = bridge::GetXlaTensorUnwrap(self); + XLATensor self_tensor = bridge::GetXlaTensor(self); XlaOptions xla_options(options, self_tensor.GetDevice()); return bridge::AtenFromXlaTensor( XLATensor::full_like(self_tensor, fill_value, xla_options.get_device(),