From 13aa6cbb624a84845c4acf64bfbc4cc828230d44 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 6 Nov 2025 16:58:08 -0300 Subject: [PATCH 1/4] Refactor cat and broadcast_tensors. --- torch_xla/csrc/tensor_methods.cpp | 90 ++++++++++++++++++++----------- torch_xla/csrc/tensor_ops.cpp | 7 ++- 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 0e55ea519327..046479224eb8 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -721,6 +721,27 @@ absl::StatusOr> CustomCallImpl( return outputs; } +absl::Status CheckCatCompatibleShapes(xla::Shape s1, xla::Shape s2, + int64_t dim) { + xla::Shape s1_without_dim = s1; + xla::Shape s2_without_dim = s2; + + dim = torch::lazy::GetCanonicalDimensionIndex(dim, s1.dimensions().size()); + s1_without_dim.DeleteDimension(dim); + s2_without_dim.DeleteDimension(dim); + + if (!xla::ShapeUtil::CompatibleIgnoringElementType(s1_without_dim, + s2_without_dim)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "cat(): cannot concatenate tensors of shape ", s1.ToString(), " with ", + s2.ToString(), " at dimension ", dim, + ". Expected shapes to be equal (except at dimension ", dim, + ") or that either of them was a 1D empty tensor of size (0,)."))); + } + + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1473,15 +1494,17 @@ absl::StatusOr bmm(const XLATensorPtr& input, return matmul(input, mat2); } -std::vector broadcast_tensors( - absl::Span tensors) { - XLA_CHECK(!tensors.empty()) << "broadcast_tensors cannot take an empty list"; - std::vector tensor_ir_values; - for (const auto& tensor : tensors) { - tensor_ir_values.push_back(tensor->GetIrValue()); - } - torch::lazy::NodePtr node = BroadcastTensors(tensor_ir_values); - return tensors.front()->MakeOutputTensors(node); +absl::StatusOr> broadcast_tensors( + absl::Span tensors) { + XLA_RETURN_IF_ERROR(CheckNonEmptyInputs("broadcast_tensors()", tensors)); + + std::vector values(tensors.size()); + std::transform( + tensors.begin(), tensors.end(), values.begin(), + [](const XLATensorPtr& tensor) { return tensor->GetIrValue(); }); + + torch::lazy::NodePtr node = BroadcastTensors(values); + return tensors.front()->MakeOutputTensors(std::move(node)); } absl::StatusOr cat( @@ -1494,39 +1517,44 @@ absl::StatusOr cat( // - If empty dimension, other dimensions must be the same. // e.g. ([4, 0, 32, 32], [4, 2, 32, 32], dim=1) passes. // ([4, 0, 32, 32], [4, 2, 31, 32], dim=1) throws. - ABSL_CHECK(tensors.size() > 0); + XLA_RETURN_IF_ERROR(CheckNonEmptyInputs("cat()", tensors)); + + // Lazy ir values of all tensors that are not empty std::vector values; - std::vector shapes; - size_t last_tensor_index; + // Index of the last non-empty tensor. + std::size_t last_tensor_index = -1; + + // Gather the lazy ir value of all non-empty tensor, and check that + // all of them have the same shape. for (size_t i = 0; i < tensors.size(); ++i) { xla::Shape tensor_shape = tensors[i]->shape(); - if (tensor_shape.dimensions_size() == 1 && - tensor_shape.dimensions()[0] == 0) { + + // Ignore empty tensors. + if (tensor_shape.dimensions().size() == 1 && + tensor_shape.dimensions(0) == 0) { continue; } - dim = torch::lazy::GetCanonicalDimensionIndex( - dim, tensor_shape.dimensions_size()); - tensor_shape.DeleteDimension(dim); - if (!shapes.empty() && !xla::ShapeUtil::CompatibleIgnoringElementType( - shapes.back(), tensor_shape)) { - auto last_tensor = tensors[last_tensor_index]; - auto tensor = tensors[i]; - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( - "cat(): cannot concatenate tensors of shape ", - last_tensor->shape().get().ToString(), " with ", - tensor->shape().get().ToString(), " at dimension ", dim, - ". Expected shapes to be equal (except at dimension ", dim, - ") or that either of them was a 1D empty tensor of size (0,)."))); + + // Check that the current tensor has compatible shapes with the + // previously found non-empty tensors. + if (last_tensor_index != -1) { + xla::Shape last_tensor_shape = tensors[last_tensor_index]->shape(); + XLA_RETURN_IF_ERROR( + CheckCatCompatibleShapes(tensor_shape, last_tensor_shape, dim)); } - shapes.push_back(tensor_shape); - values.push_back(tensors[i]->GetIrValue()); + last_tensor_index = i; + values.push_back(tensors[i]->GetIrValue()); } + + // If there are no non-empty tensors, just return an empty tensor. + // e.g. the first one from the list. if (values.empty()) { return tensors[0]; } - return tensors[0]->CreateFrom(torch_xla::MakeNode(values, dim, dtype), - dtype); + + torch::lazy::NodePtr node = torch_xla::MakeNode(values, dim, dtype); + return tensors[0]->CreateFrom(std::move(node), dtype); } XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2, diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index ef74063540c2..47e69dbd183d 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -3,6 +3,7 @@ #include #include +#include "absl/base/nullability.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -82,7 +83,8 @@ XLATensorPtr MakeMatrixWithDiagonal(const XLATensorPtr& input, XLATensorPtr SmoothL1Loss(const XLATensorPtr& input, const XLATensorPtr& target, ReductionMode reduction, double beta) { torch::lazy::ScopePusher ir_scope(at::aten::smooth_l1_loss.toQualString()); - auto broadcasted_inputs = tensor_methods::broadcast_tensors({input, target}); + XLA_ASSIGN_OR_THROW(std::vector broadcasted_inputs, + tensor_methods::broadcast_tensors({input, target})); XLA_CHECK_EQ(broadcasted_inputs.size(), 2); const XLATensorPtr& broadcasted_input = broadcasted_inputs[0]; const XLATensorPtr& broadcasted_target = broadcasted_inputs[1]; @@ -121,7 +123,8 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output, ReductionMode reduction, double beta) { torch::lazy::ScopePusher ir_scope( at::aten::smooth_l1_loss_backward.toQualString()); - auto broadcasted_inputs = tensor_methods::broadcast_tensors({input, target}); + XLA_ASSIGN_OR_THROW(std::vector broadcasted_inputs, + tensor_methods::broadcast_tensors({input, target})); XLA_CHECK_EQ(broadcasted_inputs.size(), 2); const XLATensorPtr& broadcasted_input = broadcasted_inputs[0]; const XLATensorPtr& broadcasted_target = broadcasted_inputs[1]; From 6b5bb8a0c55446083c76f437b540c58cae911f48 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 6 Nov 2025 17:36:11 -0300 Subject: [PATCH 2/4] Update header. --- torch_xla/csrc/tensor_methods.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 7c5f2848891c..7b18ff37ec99 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -305,8 +305,8 @@ absl::StatusOr bmm(const XLATensorPtr& input, const XLATensorPtr& mat2); // Broadcasts the given tensors according to broadcasting semantics. -std::vector broadcast_tensors( - absl::Span tensors); +absl::StatusOr> broadcast_tensors( + absl::Span tensors); absl::StatusOr cat( absl::Span tensors, int64_t dim, From 01c47fd4f44fd467d123347e31225dc9d3d1572c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 7 Nov 2025 14:26:56 -0300 Subject: [PATCH 3/4] Fix non-canonical dim issue. --- torch_xla/csrc/tensor_methods.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 046479224eb8..03865b29dccd 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -726,7 +726,6 @@ absl::Status CheckCatCompatibleShapes(xla::Shape s1, xla::Shape s2, xla::Shape s1_without_dim = s1; xla::Shape s2_without_dim = s2; - dim = torch::lazy::GetCanonicalDimensionIndex(dim, s1.dimensions().size()); s1_without_dim.DeleteDimension(dim); s2_without_dim.DeleteDimension(dim); @@ -1523,6 +1522,9 @@ absl::StatusOr cat( std::vector values; // Index of the last non-empty tensor. std::size_t last_tensor_index = -1; + // Cache the canonical dimension, so that we won't have to recompute + // it every time. + std::optional cannonical_dim; // Gather the lazy ir value of all non-empty tensor, and check that // all of them have the same shape. @@ -1535,12 +1537,17 @@ absl::StatusOr cat( continue; } + if (!cannonical_dim.has_value()) { + cannonical_dim = torch::lazy::GetCanonicalDimensionIndex( + dim, tensor_shape.dimensions().size()); + } + // Check that the current tensor has compatible shapes with the // previously found non-empty tensors. if (last_tensor_index != -1) { xla::Shape last_tensor_shape = tensors[last_tensor_index]->shape(); - XLA_RETURN_IF_ERROR( - CheckCatCompatibleShapes(tensor_shape, last_tensor_shape, dim)); + XLA_RETURN_IF_ERROR(CheckCatCompatibleShapes( + tensor_shape, last_tensor_shape, *cannonical_dim)); } last_tensor_index = i; @@ -1553,7 +1560,8 @@ absl::StatusOr cat( return tensors[0]; } - torch::lazy::NodePtr node = torch_xla::MakeNode(values, dim, dtype); + torch::lazy::NodePtr node = + torch_xla::MakeNode(values, *cannonical_dim, dtype); return tensors[0]->CreateFrom(std::move(node), dtype); } From c49579c35ee84cb3e566d5f8e78a7fa02aa1625d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 7 Nov 2025 15:19:38 -0300 Subject: [PATCH 4/4] Fix test. --- test/test_ops_error_message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index 6c361706d137..ff92a5807e7f 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -52,7 +52,7 @@ def test(): self.assertExpectedRaisesInline( exc_type=RuntimeError, callable=test, - expect="""cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,).""" + expect="""cat(): cannot concatenate tensors of shape f32[5,1] with f32[2,2] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,).""" ) def test_div_raises_error_on_invalid_rounding_mode(self):