2 changes: 1 addition & 1 deletion test/test_ops_error_message.py
@@ -52,7 +52,7 @@ def test():
self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=test,
expect="""cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,)."""
expect="""cat(): cannot concatenate tensors of shape f32[5,1] with f32[2,2] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,)."""
)

def test_div_raises_error_on_invalid_rounding_mode(self):
98 changes: 67 additions & 31 deletions torch_xla/csrc/tensor_methods.cpp
@@ -721,6 +721,26 @@ absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> CustomCallImpl(
return outputs;
}

absl::Status CheckCatCompatibleShapes(xla::Shape s1, xla::Shape s2,
int64_t dim) {
xla::Shape s1_without_dim = s1;
xla::Shape s2_without_dim = s2;

s1_without_dim.DeleteDimension(dim);
s2_without_dim.DeleteDimension(dim);

if (!xla::ShapeUtil::CompatibleIgnoringElementType(s1_without_dim,
s2_without_dim)) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
"cat(): cannot concatenate tensors of shape ", s1.ToString(), " with ",
s2.ToString(), " at dimension ", dim,
". Expected shapes to be equal (except at dimension ", dim,
") or that either of them was a 1D empty tensor of size (0,).")));
}

return absl::OkStatus();
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
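For reference, here is a minimal standalone sketch of the check the new CheckCatCompatibleShapes helper performs, instantiated with the shapes from the updated test (f32[5,1] and f32[2,2] at dim 0). It only uses xla::ShapeUtil calls that already appear in this diff; the include path and the example function name are assumptions, not part of the PR.

#include "xla/shape_util.h"  // assumed include path for xla::ShapeUtil

// Sketch only (not part of this PR): reproduces the compatibility check for
// the shapes exercised by test_ops_error_message.py: f32[5,1] vs. f32[2,2]
// concatenated at dim = 0.
bool CatShapesCompatibleExample() {
  xla::Shape s1 = xla::ShapeUtil::MakeShape(xla::F32, {5, 1});
  xla::Shape s2 = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  xla::Shape s1_without_dim = s1;
  xla::Shape s2_without_dim = s2;
  s1_without_dim.DeleteDimension(0);  // f32[1]
  s2_without_dim.DeleteDimension(0);  // f32[2]
  // f32[1] vs. f32[2] are not compatible, so cat() reports the new
  // "cannot concatenate tensors of shape ..." error for these inputs.
  return xla::ShapeUtil::CompatibleIgnoringElementType(s1_without_dim,
                                                       s2_without_dim);
}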
@@ -1473,15 +1493,17 @@ absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
return matmul(input, mat2);
}

std::vector<XLATensorPtr> broadcast_tensors(
absl::Span<const XLATensorPtr> tensors) {
XLA_CHECK(!tensors.empty()) << "broadcast_tensors cannot take an empty list";
std::vector<torch::lazy::Value> tensor_ir_values;
for (const auto& tensor : tensors) {
tensor_ir_values.push_back(tensor->GetIrValue());
}
torch::lazy::NodePtr node = BroadcastTensors(tensor_ir_values);
return tensors.front()->MakeOutputTensors(node);
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> broadcast_tensors(
absl::Span<const absl_nonnull XLATensorPtr> tensors) {
XLA_RETURN_IF_ERROR(CheckNonEmptyInputs("broadcast_tensors()", tensors));

std::vector<torch::lazy::Value> values(tensors.size());
std::transform(
tensors.begin(), tensors.end(), values.begin(),
[](const XLATensorPtr& tensor) { return tensor->GetIrValue(); });

torch::lazy::NodePtr node = BroadcastTensors(values);
return tensors.front()->MakeOutputTensors(std::move(node));
}
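Because broadcast_tensors now returns absl::StatusOr instead of a plain vector, call sites have to unwrap the result. Below is a minimal sketch of that calling pattern, using the XLA_ASSIGN_OR_THROW macro the same way this PR does in tensor_ops.cpp; the helper name is hypothetical and not part of the PR.

// Sketch only: unwrapping the StatusOr-returning broadcast_tensors at a call
// site that cannot itself propagate absl::Status.
XLATensorPtr BroadcastAndTakeFirst(const XLATensorPtr& a,
                                   const XLATensorPtr& b) {
  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> broadcasted,
                      tensor_methods::broadcast_tensors({a, b}));
  return broadcasted.front();
}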

absl::StatusOr<absl_nonnull XLATensorPtr> cat(
@@ -1494,39 +1516,53 @@ absl::StatusOr<absl_nonnull XLATensorPtr> cat(
// - If empty dimension, other dimensions must be the same.
// e.g. ([4, 0, 32, 32], [4, 2, 32, 32], dim=1) passes.
// ([4, 0, 32, 32], [4, 2, 31, 32], dim=1) throws.
ABSL_CHECK(tensors.size() > 0);
XLA_RETURN_IF_ERROR(CheckNonEmptyInputs("cat()", tensors));

// Lazy ir values of all tensors that are not empty
std::vector<torch::lazy::Value> values;
std::vector<xla::Shape> shapes;
size_t last_tensor_index;
// Index of the last non-empty tensor.
std::size_t last_tensor_index = -1;
// Cache the canonical dimension, so that we won't have to recompute
// it every time.
std::optional<int64_t> cannonical_dim;

// Gather the lazy ir value of all non-empty tensor, and check that
// all of them have the same shape.
for (size_t i = 0; i < tensors.size(); ++i) {
xla::Shape tensor_shape = tensors[i]->shape();
if (tensor_shape.dimensions_size() == 1 &&
tensor_shape.dimensions()[0] == 0) {

// Ignore empty tensors.
if (tensor_shape.dimensions().size() == 1 &&
tensor_shape.dimensions(0) == 0) {
continue;
}
dim = torch::lazy::GetCanonicalDimensionIndex(
dim, tensor_shape.dimensions_size());
tensor_shape.DeleteDimension(dim);
if (!shapes.empty() && !xla::ShapeUtil::CompatibleIgnoringElementType(
shapes.back(), tensor_shape)) {
auto last_tensor = tensors[last_tensor_index];
auto tensor = tensors[i];
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
"cat(): cannot concatenate tensors of shape ",
last_tensor->shape().get().ToString(), " with ",
tensor->shape().get().ToString(), " at dimension ", dim,
". Expected shapes to be equal (except at dimension ", dim,
") or that either of them was a 1D empty tensor of size (0,).")));

if (!cannonical_dim.has_value()) {
cannonical_dim = torch::lazy::GetCanonicalDimensionIndex(
dim, tensor_shape.dimensions().size());
}
shapes.push_back(tensor_shape);
values.push_back(tensors[i]->GetIrValue());

// Check that the current tensor has compatible shapes with the
// previously found non-empty tensors.
if (last_tensor_index != -1) {
xla::Shape last_tensor_shape = tensors[last_tensor_index]->shape();
XLA_RETURN_IF_ERROR(CheckCatCompatibleShapes(
tensor_shape, last_tensor_shape, *cannonical_dim));
}

last_tensor_index = i;
values.push_back(tensors[i]->GetIrValue());
}

// If there are no non-empty tensors, just return an empty tensor.
// e.g. the first one from the list.
if (values.empty()) {
return tensors[0];
}
return tensors[0]->CreateFrom(torch_xla::MakeNode<Cat>(values, dim, dtype),
dtype);

torch::lazy::NodePtr node =
torch_xla::MakeNode<Cat>(values, *cannonical_dim, dtype);
return tensors[0]->CreateFrom(std::move(node), dtype);
}
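The canonical dimension cached above is produced by torch::lazy::GetCanonicalDimensionIndex, which maps a possibly negative dim into the range [0, rank) using the rank of the first non-empty tensor. A small sketch of that mapping follows; the include and the example function are assumptions for illustration only.

#include <torch/csrc/lazy/core/helpers.h>  // torch::lazy::GetCanonicalDimensionIndex

// Sketch only: for a rank-4 tensor, dim = -3 canonicalizes to 1, while
// non-negative dims that are already in range are returned unchanged.
int64_t CanonicalDimExample() {
  return torch::lazy::GetCanonicalDimensionIndex(/*dim=*/-3, /*rank=*/4);
}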

XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2,
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor_methods.h
@@ -305,8 +305,8 @@ absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
const XLATensorPtr& mat2);

// Broadcasts the given tensors according to broadcasting semantics.
std::vector<XLATensorPtr> broadcast_tensors(
absl::Span<const XLATensorPtr> tensors);
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> broadcast_tensors(
absl::Span<const absl_nonnull XLATensorPtr> tensors);

absl::StatusOr<absl_nonnull XLATensorPtr> cat(
absl::Span<const absl_nonnull XLATensorPtr> tensors, int64_t dim,
7 changes: 5 additions & 2 deletions torch_xla/csrc/tensor_ops.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/util.h>

#include "absl/base/nullability.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/runtime/computation_client.h"
@@ -82,7 +83,8 @@ XLATensorPtr MakeMatrixWithDiagonal(const XLATensorPtr& input,
XLATensorPtr SmoothL1Loss(const XLATensorPtr& input, const XLATensorPtr& target,
ReductionMode reduction, double beta) {
torch::lazy::ScopePusher ir_scope(at::aten::smooth_l1_loss.toQualString());
auto broadcasted_inputs = tensor_methods::broadcast_tensors({input, target});
XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> broadcasted_inputs,
tensor_methods::broadcast_tensors({input, target}));
XLA_CHECK_EQ(broadcasted_inputs.size(), 2);
const XLATensorPtr& broadcasted_input = broadcasted_inputs[0];
const XLATensorPtr& broadcasted_target = broadcasted_inputs[1];
@@ -121,7 +123,8 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output,
ReductionMode reduction, double beta) {
torch::lazy::ScopePusher ir_scope(
at::aten::smooth_l1_loss_backward.toQualString());
auto broadcasted_inputs = tensor_methods::broadcast_tensors({input, target});
XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> broadcasted_inputs,
tensor_methods::broadcast_tensors({input, target}));
XLA_CHECK_EQ(broadcasted_inputs.size(), 2);
const XLATensorPtr& broadcasted_input = broadcasted_inputs[0];
const XLATensorPtr& broadcasted_target = broadcasted_inputs[1];