Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ std::vector<XLATensor> GetXlaTensors(
return xla_tensors;
}

// Returns the XLATensor backing `tensor`. Kept as a thin alias of
// GetXlaTensor(); per the header comment this entry point existed to unwrap
// variables first, but GetXlaTensor() now covers that case directly.
XLATensor GetXlaTensorUnwrap(const at::Tensor& tensor) {
  XLATensor unwrapped = GetXlaTensor(tensor);
  return unwrapped;
}

XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device) {
if (!tensor.defined()) {
return XLATensor();
Expand All @@ -95,11 +91,12 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::TensorList& tensors) {
for (size_t i = 0; i < tensors.size(); ++i) {
const at::Tensor& tensor = tensors[i];
if (tensor.defined()) {
if (tensor.device().is_cpu()) {
aten_xla_tensors[i] = tensor;
} else {
auto xtensor = TryGetXlaTensor(tensor);
if (xtensor) {
to_translate[i] = true;
xla_tensors.push_back(GetXlaTensorUnwrap(tensor));
xla_tensors.push_back(*xtensor);
} else {
aten_xla_tensors[i] = tensor;
}
}
}
Expand All @@ -119,8 +116,12 @@ void XlaUpdateTensors(
tensorflow::gtl::ArraySlice<const at::Tensor> source_cpu_tensors,
tensorflow::gtl::ArraySlice<const size_t> indices) {
for (auto index : indices) {
XLATensor xtensor = GetXlaTensorUnwrap(dest_xla_tensors.at(index));
xtensor.UpdateFromTensor(source_cpu_tensors.at(index));
auto xtensor = TryGetXlaTensor(dest_xla_tensors.at(index));
if (xtensor) {
xtensor->UpdateFromTensor(source_cpu_tensors.at(index));
} else {
dest_xla_tensors.at(index).copy_(source_cpu_tensors.at(index));
}
}
}

Expand Down
4 changes: 0 additions & 4 deletions torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ XLATensor GetXlaTensor(const at::Tensor& tensor);
std::vector<XLATensor> GetXlaTensors(
tensorflow::gtl::ArraySlice<const at::Tensor> tensors);

// Like GetXlaTensor(), but if the tensor is a variable, unwraps it and
// accesses the underlying tensor.
XLATensor GetXlaTensorUnwrap(const at::Tensor& tensor);

// If tensor is an XLA tensor type, returns the XLATensor embedded within it,
// otherwise creates a new XLA tensor type with tensor as data.
XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device);
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1273,14 +1273,14 @@ at::Tensor AtenXlaType::full(at::IntArrayRef size, at::Scalar fill_value,

at::Tensor AtenXlaType::full_like(const at::Tensor& self,
at::Scalar fill_value) {
XLATensor self_tensor = bridge::GetXlaTensorUnwrap(self);
XLATensor self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(XLATensor::full_like(
self_tensor, fill_value, self_tensor.GetDevice(), c10::nullopt));
}

at::Tensor AtenXlaType::full_like(const at::Tensor& self, at::Scalar fill_value,
const at::TensorOptions& options) {
XLATensor self_tensor = bridge::GetXlaTensorUnwrap(self);
XLATensor self_tensor = bridge::GetXlaTensor(self);
XlaOptions xla_options(options, self_tensor.GetDevice());
return bridge::AtenFromXlaTensor(
XLATensor::full_like(self_tensor, fill_value, xla_options.get_device(),
Expand Down