From 970d04d82f1133fbe36fb7ae49325c493099bdfa Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 8 Aug 2025 13:52:55 -0300
Subject: [PATCH 1/3] Improve error messages for `flip`.

---
 torch_xla/csrc/aten_xla_type.cpp  |  6 +++--
 torch_xla/csrc/tensor_methods.cpp | 43 ++++++++++++++++++++++++++++---
 torch_xla/csrc/tensor_methods.h   |  3 ++-
 3 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 4d5286c2b04..7d8ba352059 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1804,8 +1804,10 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self,
 at::Tensor XLANativeFunctions::flip(const at::Tensor& self,
                                     at::IntArrayRef dims) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(tensor_methods::flip(
-      GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims)));
+  auto xself = GetValueOrThrow(bridge::GetXlaTensor(self));
+  auto output =
+      GetValueOrThrow(tensor_methods::flip(xself, XlaHelpers::I64List(dims)));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index ffeff7bab88..ce25bd31794 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -8,6 +8,7 @@
 
 #include <algorithm>
 #include <functional>
+#include <map>
 
 #include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
@@ -1680,12 +1681,48 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) {
   input->SetInPlaceIrValue(std::move(constant));
 }
 
-XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims) {
+absl::StatusOr<XLATensorPtr> flip(const XLATensorPtr& input,
+                                  absl::Span<const int64_t> dims) {
   auto dimensions = torch::lazy::GetCanonicalDimensionIndices(
       torch_xla::runtime::util::ToVector<int64_t>(dims),
       input->shape().get().dimensions_size());
-  std::set<int64_t> unique_dims(dimensions.begin(), dimensions.end());
-  XLA_CHECK_EQ(unique_dims.size(), dimensions.size());
+
+  // Count the number of times each dimension appears.
+  std::map<int64_t, int64_t> dim_count;
+  for (auto dim : dimensions) {
+    int64_t count = dim_count.find(dim) == dim_count.end() ? 0 : dim_count[dim];
+    dim_count[dim] = count + 1;
+  }
+
+  // If the number of uniquely counted dimensions is not the same as the number
+  // of given dimensions, it means that there were some dimensions that were
+  // given more than once.
+  if (dim_count.size() < dimensions.size()) {
+    // Collect `dim_count` keys to suggest a corresponding `dims` value that
+    // wouldn't error.
+    std::vector<int64_t> dims_suggestion;
+    std::transform(dim_count.begin(), dim_count.end(),
+                   std::back_inserter(dims_suggestion),
+                   [](auto& pair) { return pair.first; });
+    // Collect the bad dimensions, i.e. those that appeared more than once.
+    std::vector<std::pair<int64_t, int64_t>> bad_dim_count;
+    std::copy_if(dim_count.begin(), dim_count.end(),
+                 std::back_inserter(bad_dim_count),
+                 [](auto& pair) { return pair.second > 1; });
+    auto bad_dimensions_str = absl::StrJoin(
+        bad_dim_count, /* sep= */ ", ",
+        [](std::string* out, const std::pair<int64_t, int64_t>& pair) {
+          absl::StrAppend(out, pair.first, " (", pair.second, " times)");
+        });
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("flip(): expected each dimension to appear at most once. "
+                     "Found dimensions: ",
+                     bad_dimensions_str, ". Consider changing dims from [",
+                     absl::StrJoin(dims, /* sep= */ ", "), "] to [",
+                     absl::StrJoin(dims_suggestion, /* sep= */ ", "), "].")));
+  }
+
+  ABSL_CHECK(dim_count.size() == dimensions.size());
   return input->CreateFrom(
       torch_xla::MakeNode<Flip>(input->GetIrValue(), dimensions));
 }
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 395768fc867..fb7eae93f8d 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -450,7 +450,8 @@ void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols);
 void fill_(XLATensorPtr& input, const at::Scalar& value);
 
 // Flips (reverses) the values in the dimensions of the input tensor.
-XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims);
+absl::StatusOr<XLATensorPtr> flip(const XLATensorPtr& input,
+                                  absl::Span<const int64_t> dims);
 
 XLATensorPtr fmod(
     const XLATensorPtr& input, const XLATensorPtr& other,

From 667ab2e70ff02175ae78dd8c799efd6938be203d Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 8 Aug 2025 14:55:12 -0300
Subject: [PATCH 2/3] Add test.

---
 test/test_operations.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/test/test_operations.py b/test/test_operations.py
index 4c0395ff286..cb790a07414 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2497,6 +2497,20 @@ def test_div_raises_error_on_invalid_rounding_mode(self):
                       "'trunc', 'floor', or be left unspecified.")
     self.assertEqual(str(e), expected_error)
 
+  def test_flip_raises_error_on_duplicated_dims(self):
+    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
+    dims = [0, 0, 0, 1, 2, 3, -1]
+    dims_suggestion = [0, 1, 2, 3]
+
+    try:
+      torch.flip(a, dims=dims)
+    except RuntimeError as e:
+      expected_error = (
+          "flip(): expected each dimension to appear at most once. Found "
+          "dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
+          f"from {dims} to {dims_suggestion}.")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):

From 3d78f321d2a5de8d8b40a33dae3aaea399b33c25 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Mon, 11 Aug 2025 14:20:30 -0300
Subject: [PATCH 3/3] Extract function for checking canonical dimensions are
 unique.

---
 torch_xla/csrc/tensor_methods.cpp | 106 +++++++++++++++++++-----------
 1 file changed, 66 insertions(+), 40 deletions(-)

diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index ce25bd31794..4a749d50ac7 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -346,6 +346,69 @@ XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
   return XLATensor::Create(node, input->GetDevice(), at::ScalarType::Bool);
 }
 
+// Checks that the canonical dimensions out of the given dimensions are unique
+// for the `flip` operation.
+//
+// This function fails if any canonical dimension appears more than once.
+// Notice that its error message is specialized for the `flip` operation.
+//
+// @param rank Input rank
+// @param dims (Error Message) `flip` operation original `dims` argument
+// @param canonical_dims (Error Message) Canonical dimensions extracted from
+//        the `dims` argument
+absl::Status CheckFlipDimensionsAreUnique(
+    int64_t rank, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> canonical_dims) {
+  // Counter that maps each given dimension to the number of times it has
+  // appeared.
+  std::vector<int64_t> count(rank, 0);
+
+  // Count the number of times each dimension appears.
+  for (auto dim : canonical_dims) {
+    count[dim] += 1;
+  }
+
+  bool any_dimension_appears_more_than_once = std::any_of(
+      count.begin(), count.end(), [](const auto n) { return n > 1; });
+
+  if (any_dimension_appears_more_than_once) {
+    // Suggestion for the value of dims that wouldn't raise an error.
+    std::vector<int64_t> dims_suggestion;
+    // Each "bad" dimension is represented as a string of the form:
+    //
+    //     <dimension> (<count> times)
+    //
+    // To be later joined with commas.
+    std::vector<std::string> bad_count_str;
+
+    // Iterates each dimension, populating both `dims_suggestion` and
+    // `bad_count_str`.
+    for (int64_t i : c10::irange(rank)) {
+      // Dimension does not appear. Do nothing.
+      if (count[i] == 0) {
+        continue;
+      }
+
+      // Dimension appears in `dims`. Add it to the suggestion list.
+      dims_suggestion.push_back(i);
+
+      // Dimension appears more than once. Add it to the "bad" list.
+      if (count[i] > 1) {
+        bad_count_str.push_back(absl::StrCat(i, " (", count[i], " times)"));
+      }
+    }
+
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "flip(): expected each dimension to appear at most once. Found "
+        "dimensions: ",
+        absl::StrJoin(bad_count_str, /* sep= */ ", "),
+        ". Consider changing dims from [", absl::StrJoin(dims, /* sep= */ ", "),
+        "] to [", absl::StrJoin(dims_suggestion, /* sep= */ ", "), "].")));
+  }
+
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1683,46 +1746,9 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) {
 
 absl::StatusOr<XLATensorPtr> flip(const XLATensorPtr& input,
                                   absl::Span<const int64_t> dims) {
-  auto dimensions = torch::lazy::GetCanonicalDimensionIndices(
-      torch_xla::runtime::util::ToVector<int64_t>(dims),
-      input->shape().get().dimensions_size());
-
-  // Count the number of times each dimension appears.
-  std::map<int64_t, int64_t> dim_count;
-  for (auto dim : dimensions) {
-    int64_t count = dim_count.find(dim) == dim_count.end() ? 0 : dim_count[dim];
-    dim_count[dim] = count + 1;
-  }
-
-  // If the number of uniquely counted dimensions is not the same as the number
-  // of given dimensions, it means that there were some dimensions that were
-  // given more than once.
-  if (dim_count.size() < dimensions.size()) {
-    // Collect `dim_count` keys to suggest a corresponding `dims` value that
-    // wouldn't error.
-    std::vector<int64_t> dims_suggestion;
-    std::transform(dim_count.begin(), dim_count.end(),
-                   std::back_inserter(dims_suggestion),
-                   [](auto& pair) { return pair.first; });
-    // Collect the bad dimensions, i.e. those that appeared more than once.
-    std::vector<std::pair<int64_t, int64_t>> bad_dim_count;
-    std::copy_if(dim_count.begin(), dim_count.end(),
-                 std::back_inserter(bad_dim_count),
-                 [](auto& pair) { return pair.second > 1; });
-    auto bad_dimensions_str = absl::StrJoin(
-        bad_dim_count, /* sep= */ ", ",
-        [](std::string* out, const std::pair<int64_t, int64_t>& pair) {
-          absl::StrAppend(out, pair.first, " (", pair.second, " times)");
-        });
-    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
-        absl::StrCat("flip(): expected each dimension to appear at most once. "
-                     "Found dimensions: ",
-                     bad_dimensions_str, ". Consider changing dims from [",
-                     absl::StrJoin(dims, /* sep= */ ", "), "] to [",
-                     absl::StrJoin(dims_suggestion, /* sep= */ ", "), "].")));
-  }
-
-  ABSL_CHECK(dim_count.size() == dimensions.size());
+  auto rank = input->shape().get().dimensions_size();
+  auto dimensions = torch::lazy::GetCanonicalDimensionIndices(dims, rank);
+  XLA_RETURN_IF_ERROR(CheckFlipDimensionsAreUnique(rank, dims, dimensions));
   return input->CreateFrom(
       torch_xla::MakeNode<Flip>(input->GetIrValue(), dimensions));
 }