Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,17 +246,17 @@ void WithAllDevices(
}

std::string GetTensorTextGraph(at::Tensor tensor) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
return DumpUtil::ToText({xtensor->GetIrValue().node.get()});
}

std::string GetTensorDotGraph(at::Tensor tensor) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
return DumpUtil::ToDot({xtensor->GetIrValue().node.get()});
}

std::string GetTensorHloGraph(at::Tensor tensor) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
return DumpUtil::ToHlo({xtensor->GetIrValue()}, xtensor->GetDevice());
}

Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ TEST_F(AtenXlaTensorTest, TestStorage) {
torch::Tensor a = torch::tensor({0.0});
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
XLATensorPtr xla_tensor_a = bridge::GetXlaTensor(xla_a);
XLATensorPtr xla_tensor_a = GetValueOrThrow(bridge::GetXlaTensor(xla_a));
EXPECT_EQ(xla_a.device(), xla_tensor_a->Storage().device());
AllClose(a, xla_a);
});
Expand Down
22 changes: 14 additions & 8 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "torch_xla/csrc/aten_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/tensor_methods.h"
#include "torch_xla/csrc/torch_util.h"

Expand All @@ -33,7 +34,8 @@ torch::Tensor EinsumAutogradFunction::forward(
}
ctx->save_for_backward(vars);

std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(tensors);
std::vector<XLATensorPtr> xla_tensors =
GetValueOrThrow(bridge::GetXlaTensors(tensors));
XLATensorPtr output = tensor_methods::einsum(eq_str, xla_tensors);
return bridge::AtenFromXlaTensor(output);
}
Expand All @@ -43,11 +45,13 @@ torch::autograd::variable_list EinsumAutogradFunction::backward(
torch::autograd::variable_list grad_output) {
std::string equation = ctx->saved_data["equation"].toString()->string();
torch::autograd::variable_list tensors = ctx->get_saved_variables();
std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(tensors);
std::vector<XLATensorPtr> xla_tensors =
GetValueOrThrow(bridge::GetXlaTensors(tensors));

std::tuple<XLATensorPtr, XLATensorPtr> outputs =
tensor_methods::einsum_backward(bridge::GetXlaTensor(grad_output[0]),
xla_tensors, equation);
tensor_methods::einsum_backward(
GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), xla_tensors,
equation);

// For both einsum and max pool, we use "undef" as a placeholder for the
// non-tensor grad inputs, in this case the equation string.
Expand Down Expand Up @@ -190,7 +194,7 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
}
ctx->save_for_backward({self});
auto outputs = tensor_methods::max_pool_nd(
bridge::GetXlaTensor(self), /*spatial_dim_count=*/3,
GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode);
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
Expand Down Expand Up @@ -218,7 +222,8 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
ceil_mode, indices);
}
grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
bridge::GetXlaTensor(grad_output[0]), bridge::GetXlaTensor(self),
GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])),
GetValueOrThrow(bridge::GetXlaTensor(self)),
/*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));

Expand All @@ -234,7 +239,7 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode) {
auto outputs = tensor_methods::max_pool_nd(
bridge::GetXlaTensor(self), /*spatial_dim_count=*/2,
GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode);
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
Expand All @@ -245,7 +250,8 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode) {
auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
GetValueOrThrow(bridge::GetXlaTensor(grad_output)),
GetValueOrThrow(bridge::GetXlaTensor(self)),
/*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));
return grad;
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/aten_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ static bool validate_tensor_list(const c10::List<at::Tensor>& tensorlist) {

// Retrieve the inner XLATensorPtr, and check it lives inside CUDA.
static XLATensorPtr get_xla_cuda_tensor(const at::Tensor& tensor) {
XLATensorPtr xla_tensor = bridge::GetXlaTensor(tensor);
XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
const torch::lazy::BackendDevice& device = xla_tensor->GetDevice();
TORCH_CHECK(device.type() == static_cast<int8_t>(XlaDeviceType::CUDA),
"OpenXLA CUDA fallback only supports XLA:CUDA tensors. Found a "
Expand Down
125 changes: 64 additions & 61 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/status.h"
#include "torch_xla/csrc/tensor_impl.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"
Expand Down Expand Up @@ -72,72 +74,68 @@ AtenXlaDeviceMapper* AtenXlaDeviceMapper::Get() {
return device_mapper;
}

XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
static absl::StatusOr<XLATensorImpl * absl_nonnull> GetXlaTensorImpl(
const at::Tensor& tensor) {
auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
return dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
XLATensorImpl* impl =
dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
if (impl == nullptr) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
"Input tensor is not an XLA tensor: ", tensor.toString())));
}
return impl;
}

} // namespace

XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
return GetXlaTensor(tensor).value_or(XLATensorPtr{});
}

absl::StatusOr<absl_nonnull XLATensorPtr> GetXlaTensor(
const at::Tensor& tensor) {
if (tensor.defined() &&
at::functionalization::impl::isFunctionalTensor(tensor)) {
// To make sure we have the most updated version of tensor.
at::functionalization::impl::sync(tensor);
}
XLATensorImpl* impl = GetXlaTensorImpl(tensor);
if (impl == nullptr) {
return XLATensorPtr();
}
XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
return impl->tensor();
}

std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors) {
std::vector<XLATensorPtr> xla_tensors;
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> GetXlaTensors(
const at::ITensorListRef& tensors) {
std::vector<absl_nonnull XLATensorPtr> xla_tensors;
xla_tensors.reserve(tensors.size());
for (const auto& tensor : tensors) {
xla_tensors.push_back(bridge::TryGetXlaTensor(tensor));
XLA_ASSIGN_OR_RETURN(XLATensorPtr ptr, bridge::GetXlaTensor(tensor));
xla_tensors.push_back(std::move(ptr));
}
return xla_tensors;
}

bool IsXlaTensor(const at::Tensor& tensor) {
return GetXlaTensorImpl(tensor) != nullptr;
}

XLATensorPtr GetXlaTensor(const at::Tensor& tensor) {
auto xtensor = TryGetXlaTensor(tensor);
XLA_CHECK(xtensor) << "Input tensor is not an XLA tensor: "
<< tensor.toString();
return xtensor;
return GetXlaTensorImpl(tensor).ok();
}

void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
XLATensorImpl* impl =
dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
XLA_CHECK(impl != nullptr)
<< "Input tensor is not an XLA tensor: " << inner_tensor.toString();
absl::Status ReplaceXlaTensor(const at::Tensor& tensor,
XLATensorPtr new_xla_tensor) {
XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
impl->set_tensor(std::move(new_xla_tensor));
return absl::OkStatus();
}

void ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
const std::vector<XLATensorPtr> new_xla_tensors) {
XLA_CHECK(tensors.size() == new_xla_tensors.size())
<< "The size of tensors and new_xla_tensors are not equal: "
<< tensors.size() << " vs. " << new_xla_tensors.size();
for (size_t i = 0; i < tensors.size(); ++i) {
ReplaceXlaTensor(tensors[i], new_xla_tensors[i]);
absl::Status ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
const std::vector<XLATensorPtr> new_xla_tensors) {
if (tensors.size() != new_xla_tensors.size()) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat("The size of tensors and new_xla_tensors are not equal: ",
tensors.size(), " vs. ", new_xla_tensors.size())));
}
}

std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
std::vector<XLATensorPtr> xla_tensors;
xla_tensors.reserve(tensors.size());
for (const auto& tensor : tensors) {
xla_tensors.push_back(bridge::GetXlaTensor(tensor));
for (size_t i = 0; i < tensors.size(); ++i) {
XLA_RETURN_IF_ERROR(ReplaceXlaTensor(tensors[i], new_xla_tensors[i]));
}
return xla_tensors;
return absl::OkStatus();
}

torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
Expand All @@ -146,7 +144,7 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
(tensor.dim() == 0 && tensor.numel() == 1)) {
return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
} else {
return torch_xla::bridge::GetXlaTensor(tensor);
return GetValueOrThrow(torch_xla::bridge::GetXlaTensor(tensor));
}
}

Expand All @@ -155,22 +153,23 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
if (!tensor.defined()) {
return XLATensorPtr();
}

auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
if (!inner_tensor.defined()) {
return XLATensorPtr();
}
auto xtensor = TryGetXlaTensor(tensor);
return xtensor ? xtensor : XLATensor::Create(inner_tensor, device);

auto xtensor = GetXlaTensor(tensor);
return xtensor.ok() ? xtensor.value()
: XLATensor::Create(inner_tensor, device);
}

XLATensorPtr GetOrCreateXlaTensor(const std::optional<at::Tensor>& tensor,
const torch::lazy::BackendDevice& device) {
if (!IsDefined(tensor)) {
if (!tensor.has_value()) {
return XLATensorPtr();
}
auto xtensor = TryGetXlaTensor(*tensor);
auto inner_tensor = torch::lazy::maybe_unwrap_functional(*tensor);
return xtensor ? xtensor : XLATensor::Create(inner_tensor, device);
return GetOrCreateXlaTensor(*tensor, device);
}

std::vector<XLATensorPtr> GetOrCreateXlaTensors(
Expand Down Expand Up @@ -199,10 +198,10 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
continue;
}

auto xtensor = TryGetXlaTensor(tensor);
if (xtensor) {
auto xtensor_status = GetXlaTensor(tensor);
if (xtensor_status.ok()) {
to_translate[ix] = true;
xla_tensors.push_back(xtensor);
xla_tensors.push_back(xtensor_status.value());
} else {
aten_xla_tensors[ix] = tensor;
}
Expand Down Expand Up @@ -253,13 +252,14 @@ void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,
for (auto index : indices) {
at::Tensor dest = dest_xla_tensors.at(index);
at::Tensor source = source_cpu_tensors.at(index);
XLATensorImpl* dest_impl = GetXlaTensorImpl(dest);
if (dest_impl != nullptr) {
auto xla_source = TryGetXlaTensor(source);
if (!xla_source) {
dest_impl->tensor()->UpdateFromTensorOut(source);
auto dest_impl_status = GetXlaTensorImpl(dest);
if (dest_impl_status.ok()) {
auto dest_impl = std::move(dest_impl_status).value();
auto xla_source_status = GetXlaTensor(source);
if (xla_source_status.ok()) {
dest_impl->tensor()->UpdateFromTensorOut(xla_source_status.value());
} else {
dest_impl->tensor()->UpdateFromTensorOut(xla_source);
dest_impl->tensor()->UpdateFromTensorOut(source);
}
dest_impl->force_refresh_sizes();
} else {
Expand All @@ -270,11 +270,11 @@ void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,

std::optional<torch::lazy::BackendDevice> GetXlaDevice(
const at::Tensor& tensor) {
auto xtensor = TryGetXlaTensor(tensor);
if (!xtensor) {
auto xtensor_status = GetXlaTensor(tensor);
if (!xtensor_status.ok()) {
return std::nullopt;
}
return xtensor->GetDevice();
return xtensor_status.value()->GetDevice();
}

std::optional<torch::lazy::BackendDevice> GetXlaDevice(
Expand Down Expand Up @@ -469,12 +469,15 @@ std::vector<at::Tensor> CreateXlaTensors(
}

const at::Tensor& GetRootBase(const at::Tensor& tensor) {
auto xla_tensor = TryGetXlaTensor(tensor);
if (xla_tensor && xla_tensor->Base().defined()) {
return GetRootBase(xla_tensor->Base());
} else {
auto xla_tensor_status = GetXlaTensor(tensor);
if (!xla_tensor_status.ok()) {
return tensor;
}
auto xla_tensor = std::move(xla_tensor_status).value();
if (!xla_tensor->Base().defined()) {
return tensor;
}
return GetRootBase(xla_tensor->Base());
}

XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) {
Expand Down
Loading
Loading