From c19e785fad0b51edb0de71988f2ebd887a95ca83 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 26 Nov 2025 15:18:36 -0300 Subject: [PATCH 1/3] Improve error handling and message of lowering helper functions. --- torch_xla/csrc/helpers.cpp | 358 ++++++++++++++++++++++++++++++------- torch_xla/csrc/helpers.h | 81 +++++++-- 2 files changed, 365 insertions(+), 74 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index f3a9730a5dc..cbd29bd4019 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -1,20 +1,42 @@ #include "torch_xla/csrc/helpers.h" +#include +#include +#include +#include +#include #include #include - -#include -#include - +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/primitive_util.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" #include "torch_xla/csrc/convert_ops.h" -#include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/status.h" @@ -72,6 +94,152 @@ xla::XlaComputation CreateMinMaxComputation(const std::string& name, return min_max_computation; } +// Joins the given `span` into a `std::string`. +// +// For each element, this function will append to the output string: +// 1. A star (*), if `needs_star(el)` returns true +// 2. The return value of `get_printable_data(el)` +template +std::string SpanToStringWithStar( + absl::Span span, FStar&& needs_star, + FPrint&& get_printable_data = [](T tee) -> T { return tee; }) { + return absl::StrJoin( + span, /* separator= */ ", ", [&](std::string* out, const T el) { + absl::StrAppend(out, needs_star(el) ? "*" : "", get_printable_data(el)); + }); +} + +// Joins the given `sizes` into a `std::string`. +// +// Similarly to `SpanToStringWithStar` function above, this function will also +// append to the output string the 2 described items. However, with the +// following changes: +// 1. The `needs_star` function will also depend on the index of the element +// 2. It will always print the element, instead of the index (i.e. the +// `get_printable_data` is already set) +template +std::string EnumeratedSizesToStringWithStar(absl::Span sizes, + F&& needs_star) { + std::vector indices(sizes.size()); + std::iota(indices.begin(), indices.end(), 0); + + return SpanToStringWithStar( + absl::MakeConstSpan(indices), /* needs_star= */ + [&, sizes](size_t i) -> bool { return needs_star(i, sizes[i]); }, + /* get_printable_data= */ + [&, sizes](size_t i) -> int64_t { return sizes[i]; }); +} + +bool IsUnboundedDynamicSize(const int64_t size) { + return size == xla::Shape::kUnboundedSize; +} + +// Checks that none of the `output_sizes` are unbounded dynamic sizes. +// +// This function is exclusively called by `SafeDynamicReshape()`. +absl::Status CheckDynamicReshapeAnyOfOutputSizesIsUnboundedDynamic( + absl::Span output_sizes) { + if (std::any_of(output_sizes.begin(), output_sizes.end(), + IsUnboundedDynamicSize)) { + std::string output_sizes_with_unbounded_mark_str = SpanToStringWithStar( + output_sizes, /* needs_star= */ [](int64_t size) -> bool { + return IsUnboundedDynamicSize(size); + }); + + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "Error when calling DynamicReshape() in the lower phase: expected " + "output sizes [", + output_sizes_with_unbounded_mark_str, + "] to have no unbounded dynamic dimensions (*)."))); + } + + return absl::OkStatus(); +} + +// Checks that when mapping the `input_dynamic_dimension`, we won't split a +// dynamic dimension in the output. +// +// This function is exclusively called by +// `GetDynamicReshapeInfoOutputDynamicDimension()`. +template +absl::Status CheckGetDynamicReshapeInfoDynamicDimensionNotSplit( + F&& get_error_prefix, int64_t input_dynamic_dimension, size_t i, + int64_t output_elements, int64_t input_elements_before_dynamic_dimension) { + if (output_elements > input_elements_before_dynamic_dimension) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + get_error_prefix(), + " Expected the number of output elements before dimension ", i, + " (which was ", output_elements, + ") to be less-than or equal the number of input elements before the " + "input dynamic dimension ", + input_dynamic_dimension, " (which was ", + input_elements_before_dynamic_dimension, ")."))); + } + return absl::OkStatus(); +} + +// Checks that we were able to find a dynamic dimension in the output. +// +// This function is exclusively called by +// `GetDynamicReshapeInfoOutputDynamicDimension()`. +template +absl::Status CheckGetDynamicReshapeInfoFoundOutputDynamicDimension( + F&& get_error_prefix, std::optional output_dynamic_dimension) { + if (!output_dynamic_dimension.has_value()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + get_error_prefix(), + " The mapped dynamic dimension of the output was not found."))); + } + return absl::OkStatus(); +} + +// Tries to map the `input_dynamic_dimension` onto `output_sizes`. If +// successful, returns the mapped output dynamic dimension. +// +// This function is exclusively called by `GetDynamicReshapeInfo()`. +absl::StatusOr GetDynamicReshapeInfoOutputDynamicDimension( + absl::Span input_sizes, + absl::Span output_sizes, int64_t input_dynamic_dimension) { + // Function for building the error prefix, which prints the `input_sizes` + // (with the dynamic dimension), and `output_sizes`. + auto get_error_prefix = [=]() -> std::string { + std::string input_sizes_str = EnumeratedSizesToStringWithStar( + input_sizes, + [=](size_t i, int64_t) { return i == input_dynamic_dimension; }); + std::string output_sizes_str = + absl::StrJoin(output_sizes, /* separator= */ ", "); + return absl::StrCat( + "Error when calling GetDynamicReshapeInfo() in the lower phase: unable " + "to map dynamic dimension when reshaping input [", + input_sizes_str, "] into output [", output_sizes_str, "]."); + }; + + std::optional output_dynamic_dimension; + + int64_t input_elements_before_dynamic_dimension = std::accumulate( + input_sizes.begin(), input_sizes.begin() + input_dynamic_dimension, 1, + std::multiplies<>()); + int64_t input_elements_including_dynamic_dimension = + input_elements_before_dynamic_dimension * + input_sizes[input_dynamic_dimension]; + int64_t output_elements = 1; + + for (size_t i = 0; i < output_sizes.size(); ++i) { + XLA_RETURN_IF_ERROR(CheckGetDynamicReshapeInfoDynamicDimensionNotSplit( + get_error_prefix, input_dynamic_dimension, i, output_elements, + input_elements_before_dynamic_dimension)); + output_elements *= output_sizes[i]; + if (output_elements >= input_elements_including_dynamic_dimension) { + output_dynamic_dimension = i; + break; + } + } + + XLA_RETURN_IF_ERROR(CheckGetDynamicReshapeInfoFoundOutputDynamicDimension( + get_error_prefix, output_dynamic_dimension)); + return *output_dynamic_dimension; +} + } // namespace xla::PrecisionConfig::Precision XlaHelpers::s_mat_mul_precision = @@ -112,16 +280,58 @@ xla::XlaOp XlaHelpers::CreateReturnValue( } int64_t XlaHelpers::GetDynamicDimension(const xla::Shape& shape) { - int64_t dynamic_dimension = -1; - for (int64_t i = 0; i < shape.dimensions_size(); ++i) { - if (shape.is_dynamic_dimension(i)) { - XLA_CHECK(dynamic_dimension < 0) - << "Only one dynamic dimension is supported: " << i << " and " - << dynamic_dimension << " in " << shape; - dynamic_dimension = i; - } + XLA_ASSIGN_OR_THROW(std::optional dynamic_dimension, + CheckAtMostOneDynamicDimension(shape)); + return dynamic_dimension.value_or(-1); +} + +absl::StatusOr> +XlaHelpers::CheckAtMostOneDynamicDimension(const xla::Shape& shape) { + // Function for conveniently checking whether dimension `i` is dynamic. + auto check_is_dynamic_dimension = [&shape](size_t i) { + return shape.is_dynamic_dimension(i); + }; + + // Indices representing each dimension of `shape`. We shall use this to find + // out the number of dynamic dimensions and, if needed, build an error + // message. + std::vector indices(shape.dimensions().size()); + std::iota(indices.begin(), indices.end(), 0); + + // Dynamic dimensions should be placed first. The returned iterator points + // after the last element of the dynamic dimensions. So, taking its distance + // should give us the number of dynamic dimensions. + std::vector::iterator end_of_dynamic_dimensions = + std::stable_partition(indices.begin(), indices.end(), + check_is_dynamic_dimension); + + size_t dynamic_dimensions_number = + std::distance(indices.begin(), end_of_dynamic_dimensions); + + switch (dynamic_dimensions_number) { + case 0: + return std::nullopt; + case 1: + return indices.front(); } - return dynamic_dimension; + + // From this point onwards, we know that there are more than 1 dynamic + // dimension. Therefore, we shall return an error status. + + std::string shape_str = EnumeratedSizesToStringWithStar( + shape.dimensions(), /* needs_star= */ [&](size_t i, int64_t) { + return check_is_dynamic_dimension(i); + }); + + std::string dynamic_dimensions_str = absl::StrJoin( + indices.begin(), end_of_dynamic_dimensions, /* separator= */ ", "); + + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "Error when calling GetSingleDynamicDimension() in the lower phase: " + "expected shape [", + shape_str, "] to have a single dynamic dimension (*). However, found ", + dynamic_dimensions_number, " dynamic dimensions at indices [", + dynamic_dimensions_str, "]."))); } XlaHelpers::DynamicSize XlaHelpers::GetDimensionsSize( @@ -151,7 +361,7 @@ XlaHelpers::DynamicSize XlaHelpers::GetDimensionsSize( } } } - absl::optional scalar_size; + std::optional scalar_size; if (size_scalar >= 0) { scalar_size = size_scalar; } @@ -288,41 +498,42 @@ xla::XlaOp XlaHelpers::ReshapeToRank(xla::XlaOp input, int64_t expected_rank, return xla::Reshape(input, dimensions); } -absl::optional -XlaHelpers::GetDynamicReshapeInfo(const xla::Shape& input_shape, - absl::Span output_sizes) { - int64_t input_dyndim_idx = GetDynamicDimension(input_shape); - if (input_dyndim_idx < 0) { - return absl::nullopt; +std::optional XlaHelpers::GetDynamicReshapeInfo( + const xla::Shape& input_shape, absl::Span output_sizes) { + XLA_ASSIGN_OR_THROW(std::optional info, + SafeGetDynamicReshapeInfo(input_shape, output_sizes)); + return info; +} + +absl::StatusOr> +XlaHelpers::SafeGetDynamicReshapeInfo(const xla::Shape& input_shape, + absl::Span output_sizes) { + // Make sure `input_shape` has up to a single dynamic dimension. + XLA_ASSIGN_OR_RETURN(std::optional opt_input_dynamic_dimension, + CheckAtMostOneDynamicDimension(input_shape)); + // Short-circuit if no dynamic dimension was found. + if (!opt_input_dynamic_dimension.has_value()) { + return std::nullopt; } + + // From this point onwards, we know that `input_shape` has exactly 1 dynamic + // dimension. + int64_t input_dynamic_dimension = opt_input_dynamic_dimension.value(); + DynamicReshapeInfo info; - info.output_shape = - xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes); - if (info.output_shape.dimensions_size() > 0) { - int64_t size_prod_until_dyndim = 1; - for (int64_t i = 0; i <= input_dyndim_idx; ++i) { - size_prod_until_dyndim *= input_shape.dimensions(i); - } - int64_t dynamic_dimension = -1; - int64_t out_size = 1; - for (int64_t i = 0; i < output_sizes.size(); ++i) { - XLA_CHECK_LE(out_size, size_prod_until_dyndim / - input_shape.dimensions(input_dyndim_idx)) - << "Unable to map dynamic dimension of shape " << input_shape - << " to output sizes (" << absl::StrJoin(output_sizes, ", ") << ")"; - out_size *= output_sizes[i]; - if (out_size >= size_prod_until_dyndim) { - dynamic_dimension = i; - break; - } - } - XLA_CHECK(dynamic_dimension >= 0) - << "Unable to map dynamic dimension of shape " << input_shape - << " to output sizes (" << absl::StrJoin(output_sizes, ", ") << ")"; - info.dynamic_dimension = dynamic_dimension; + XLA_ASSIGN_OR_RETURN(info.output_shape, + xla::ShapeUtil::MakeValidatedShape( + input_shape.element_type(), output_sizes)); + + if (info.output_shape.dimensions().size() > 0) { + XLA_ASSIGN_OR_RETURN( + info.dynamic_dimension, + GetDynamicReshapeInfoOutputDynamicDimension( + input_shape.dimensions(), output_sizes, input_dynamic_dimension)); info.output_shape.set_dynamic_dimension(info.dynamic_dimension, true); } - return std::move(info); + + return info; } xla::Shape XlaHelpers::GetDynamicReshape( @@ -336,17 +547,25 @@ xla::Shape XlaHelpers::GetDynamicReshape( xla::XlaOp XlaHelpers::DynamicReshape(xla::XlaOp input, absl::Span output_sizes) { - const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - bool is_output_sizes_unbounded_dynamic = std::any_of( - output_sizes.begin(), output_sizes.end(), - [](int64_t size) { return size == xla::Shape::kUnboundedSize; }); - XLA_CHECK(!is_output_sizes_unbounded_dynamic) - << "reshape operation does not support unbounded dynamic output shape."; - if (output_sizes == input_shape.dimensions()) { + XLA_ASSIGN_OR_THROW(xla::XlaOp output, + SafeDynamicReshape(input, output_sizes)); + return output; +} + +absl::StatusOr XlaHelpers::SafeDynamicReshape( + xla::XlaOp input, absl::Span output_sizes) { + XLA_CHECK_OK( + CheckDynamicReshapeAnyOfOutputSizesIsUnboundedDynamic(output_sizes)); + + XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull input_shape, + GetShape(input)); + if (output_sizes == input_shape->dimensions()) { return input; } - auto info = GetDynamicReshapeInfo(input_shape, output_sizes); - if (info) { + + XLA_ASSIGN_OR_RETURN(std::optional info, + SafeGetDynamicReshapeInfo(*input_shape, output_sizes)); + if (info.has_value()) { return xla::ReshapeWithInferredDimension(input, output_sizes, info->dynamic_dimension); } @@ -431,14 +650,27 @@ bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1, shape1.dimensions() == shape2.dimensions(); } -xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* input_shape) { - runtime::util::MaybePtr input_shape_tmp(input_shape); - *input_shape_tmp = ShapeHelper::ShapeOfXlaOp(input); - if (input_shape_tmp->dimensions_size() == 1) { +xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* shape) { + if (shape != nullptr) { + XLA_ASSIGN_OR_THROW(const xla::Shape* absl_nonnull input_shape, + GetShape(input)); + *shape = *input_shape; + } + XLA_ASSIGN_OR_THROW(xla::XlaOp output, SafeFlatten(input)); + return output; +} + +absl::StatusOr XlaHelpers::SafeFlatten(xla::XlaOp input) { + XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull shape, GetShape(input)); + + if (shape->dimensions().size() == 1) { return input; } - int64_t input_elements = xla::ShapeUtil::ElementsIn(*input_shape_tmp); - return DynamicReshape(input, {input_elements}); + + XLA_ASSIGN_OR_RETURN( + xla::XlaOp output, + SafeDynamicReshape(input, {xla::ShapeUtil::ElementsIn(*shape)})); + return output; } xla::XlaOp XlaHelpers::FlattenDimRange(xla::XlaOp input, int64_t start, @@ -1074,7 +1306,7 @@ std::vector XlaHelpers::ExtractInputShardings( // entry computation is always the last computation for (const xla::HloInstructionProto& instr : computation.proto().computations().rbegin()->instructions()) { - if (instr.opcode() == HloOpcodeString(xla::HloOpcode::kParameter)) { + if (instr.opcode() == xla::HloOpcodeString(xla::HloOpcode::kParameter)) { const int64_t index = instr.parameter_number(); // we assume that parameter is ordered. XLA_CHECK_EQ(index, param_shardings.size()); diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index c01820dc27a..83d43a64672 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -1,24 +1,29 @@ #ifndef XLA_TORCH_XLA_CSRC_HELPERS_H_ #define XLA_TORCH_XLA_CSRC_HELPERS_H_ +#include +#include #include #include #include +#include #include #include #include #include -#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tsl/platform/bfloat16.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/permutation_util.h" +#include "xla/shape.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" @@ -36,7 +41,7 @@ class XlaHelpers { struct DynamicSize { xla::XlaOp size; - absl::optional scalar_size; + std::optional scalar_size; }; struct DynamicReshapeInfo { @@ -153,14 +158,46 @@ class XlaHelpers { shape.dimensions(), builder); } - static absl::optional GetDynamicReshapeInfo( - const xla::Shape& input_shape, absl::Span output_sizes); + // Computes the necessary information for reshaping `input_shape` into + // `output_sizes`, propagating the dynamic dimension (only one allowed), if + // necessary. + [[deprecated("Use SafeGetDynamicReshapeInfo for better error handling.")]] // + static std::optional + GetDynamicReshapeInfo(const xla::Shape& input_shape, + absl::Span output_sizes); + // Computes the necessary information for reshaping `input_shape` into + // `output_sizes`, propagating the dynamic dimension (only one allowed), + // if necessary. + // + // This function shall return an error status if: + // 1. `input_shape` has more than 1 dynamic dimension + // + // 2. The product of `output_sizes` overflows + // + // 3. In the presence of a dynamic shape in the input, we are unable to map + // it to any of the dimensions of the output + // + static absl::StatusOr> + SafeGetDynamicReshapeInfo(const xla::Shape& input_shape, + absl::Span output_sizes); static xla::Shape GetDynamicReshape(const xla::Shape& input_shape, absl::Span output_sizes); - static xla::XlaOp DynamicReshape(xla::XlaOp input, - absl::Span output_sizes); + // Reshapes `input`, so that its shape dimensions becomes `output_sizes`. + [[deprecated("Use SafeDynamicReshape for better error handling.")]] // + static xla::XlaOp + DynamicReshape(xla::XlaOp input, absl::Span output_sizes); + // Reshapes `input`, so that its shape dimensions becomes `output_sizes`. + // + // This function shall return an error status if: + // 1. There was a lowering error in the last `XlaBuilder::` call, where + // the `XlaBuilder` was used to create `input` + // + // 2. `SafeGetDynamicReshapeInfo()` call fails + // + static absl::StatusOr SafeDynamicReshape( + xla::XlaOp input, absl::Span output_sizes); static bool IsUnboundedDynamic(const xla::Shape& shape); @@ -210,7 +247,17 @@ class XlaHelpers { absl::Span padding); // Retrieves the dynamic dimension of an input shape, or returns -1 if none. - static int64_t GetDynamicDimension(const xla::Shape& shape); + [[deprecated( + "Use CheckAtMostOneDynamicDimension for better error " + "handling")]] // + static int64_t + GetDynamicDimension(const xla::Shape& shape); + // Check if `shape` has at most 1 dynamic dimension, and retrieves it. + // + // It shall return an error status if there's 2 or more dynamic dimensions. If + // `shape` has no dynamic dimensions, it returns a `std::nullopt`. + static absl::StatusOr> CheckAtMostOneDynamicDimension( + const xla::Shape& shape); static DynamicSize GetDimensionsSize(absl::Span inputs, absl::Span dimensions); @@ -242,8 +289,20 @@ class XlaHelpers { static xla::XlaOp ReshapeToRank(xla::XlaOp input, int64_t expected_rank, int64_t offset = 0); - static xla::XlaOp Flatten(xla::XlaOp input, - xla::Shape* input_shape = nullptr); + // Reshapes `input` into a flattened 1-dimensional tensor. + // Deprecated: if not null, `shape` is set to the shape of `input`. + [[deprecated("Use SafeFlatten for better error handling.")]] // + static xla::XlaOp + Flatten(xla::XlaOp input, xla::Shape* shape = nullptr); + // Reshapes `input` into a flattened 1-dimensional tensor. + // + // This function shall return an error status if: + // 1. There was a lowering error in the last `XlaBuilder::` call, where + // the `XlaBuilder` was used to create `input` + // + // 2. `SafeDynamicReshape()` function call fails. + // + static absl::StatusOr SafeFlatten(xla::XlaOp input); static xla::XlaOp FlattenDimRange(xla::XlaOp input, int64_t start, int64_t range, From 1528223314f3245fc8886f15993a5823f5d3d681 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Dec 2025 09:46:05 -0300 Subject: [PATCH 2/3] Address reviews. --- torch_xla/csrc/helpers.cpp | 140 ++++++++++++++++++++++--------------- torch_xla/csrc/helpers.h | 7 +- 2 files changed, 88 insertions(+), 59 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index cbd29bd4019..46d3f45bb74 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -98,14 +98,14 @@ xla::XlaComputation CreateMinMaxComputation(const std::string& name, // // For each element, this function will append to the output string: // 1. A star (*), if `needs_star(el)` returns true -// 2. The return value of `get_printable_data(el)` -template -std::string SpanToStringWithStar( - absl::Span span, FStar&& needs_star, - FPrint&& get_printable_data = [](T tee) -> T { return tee; }) { +// 2. The return value of `format_element(el)` +template +std::string SpanToStringWithStar(absl::Span span, + const NeedsStar& needs_star, + const FormatElement& format_element) { return absl::StrJoin( span, /* separator= */ ", ", [&](std::string* out, const T el) { - absl::StrAppend(out, needs_star(el) ? "*" : "", get_printable_data(el)); + absl::StrAppend(out, needs_star(el) ? "*" : "", format_element(el)); }); } @@ -114,22 +114,31 @@ std::string SpanToStringWithStar( // Similarly to `SpanToStringWithStar` function above, this function will also // append to the output string the 2 described items. However, with the // following changes: -// 1. The `needs_star` function will also depend on the index of the element -// 2. It will always print the element, instead of the index (i.e. the -// `get_printable_data` is already set) -template +// +// 1. `needs_star` function has 2 parameters: index and element. +// +// 2. There's no `format_element` parameter. This function provides one, by +// default, that prints the underlying size element. +template std::string EnumeratedSizesToStringWithStar(absl::Span sizes, - F&& needs_star) { + const NeedsStar& needs_star) { std::vector indices(sizes.size()); std::iota(indices.begin(), indices.end(), 0); return SpanToStringWithStar( absl::MakeConstSpan(indices), /* needs_star= */ [&, sizes](size_t i) -> bool { return needs_star(i, sizes[i]); }, - /* get_printable_data= */ - [&, sizes](size_t i) -> int64_t { return sizes[i]; }); + /* format_element= */ + [&, sizes](size_t i) -> std::string { return std::to_string(sizes[i]); }); } +// Explicit `std::string(*)(int64_t)` function for converting `int64_t` to +// +// `std::string`. This is needed so that `SpanToStringWithStar` manages to +// deduce the `FormatElement` type, since `std::string` has no overload for +// `int64_t`. +std::string Int64ToString(int64_t i) { return std::to_string(i); } + bool IsUnboundedDynamicSize(const int64_t size) { return size == xla::Shape::kUnboundedSize; } @@ -137,14 +146,13 @@ bool IsUnboundedDynamicSize(const int64_t size) { // Checks that none of the `output_sizes` are unbounded dynamic sizes. // // This function is exclusively called by `SafeDynamicReshape()`. -absl::Status CheckDynamicReshapeAnyOfOutputSizesIsUnboundedDynamic( +absl::Status CheckNoOutputSizesAreUnbounded( absl::Span output_sizes) { if (std::any_of(output_sizes.begin(), output_sizes.end(), IsUnboundedDynamicSize)) { std::string output_sizes_with_unbounded_mark_str = SpanToStringWithStar( - output_sizes, /* needs_star= */ [](int64_t size) -> bool { - return IsUnboundedDynamicSize(size); - }); + output_sizes, /* needs_star= */ IsUnboundedDynamicSize, + /* format_element= */ Int64ToString); return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( "Error when calling DynamicReshape() in the lower phase: expected " @@ -156,15 +164,37 @@ absl::Status CheckDynamicReshapeAnyOfOutputSizesIsUnboundedDynamic( return absl::OkStatus(); } +// Checks that `shape` has exactly 1 dynamic dimension. +// +// This function is exclusively called by `SafeGetDynamicReshapeInfo()`. +absl::StatusOr CheckAndGetExactlyOneDynamicDimension( + const xla::Shape& shape) { + XLA_ASSIGN_OR_RETURN(std::optional opt_dynamic_dimension, + XlaHelpers::CheckAtMostOneDynamicDimension(shape)); + + if (!opt_dynamic_dimension.has_value()) { + std::string shape_str = EnumeratedSizesToStringWithStar( + shape.dimensions(), /* needs_star= */ [&](size_t i, int64_t) { + return shape.is_dynamic_dimension(i); + }); + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "Error when calling GetDynamicReshapeInfo() in the lower phase: " + "expected exactly 1 dynamic dimension in the input shape ", + shape_str, "."))); + } + + return *opt_dynamic_dimension; +} + // Checks that when mapping the `input_dynamic_dimension`, we won't split a // dynamic dimension in the output. // -// This function is exclusively called by -// `GetDynamicReshapeInfoOutputDynamicDimension()`. -template -absl::Status CheckGetDynamicReshapeInfoDynamicDimensionNotSplit( - F&& get_error_prefix, int64_t input_dynamic_dimension, size_t i, - int64_t output_elements, int64_t input_elements_before_dynamic_dimension) { +// This function is exclusively called by `GetOutputDynamicDimension()`. +template +absl::Status CheckDynamicDimensionNotSplit( + const ErrorPrefix& get_error_prefix, int64_t input_dynamic_dimension, + size_t i, int64_t output_elements, + int64_t input_elements_before_dynamic_dimension) { if (output_elements > input_elements_before_dynamic_dimension) { return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( get_error_prefix(), @@ -180,11 +210,11 @@ absl::Status CheckGetDynamicReshapeInfoDynamicDimensionNotSplit( // Checks that we were able to find a dynamic dimension in the output. // -// This function is exclusively called by -// `GetDynamicReshapeInfoOutputDynamicDimension()`. -template -absl::Status CheckGetDynamicReshapeInfoFoundOutputDynamicDimension( - F&& get_error_prefix, std::optional output_dynamic_dimension) { +// This function is exclusively called by `GetOutputDynamicDimension()`. +template +absl::Status CheckFoundOutputDynamicDimension( + const ErrorPrefix& get_error_prefix, + std::optional output_dynamic_dimension) { if (!output_dynamic_dimension.has_value()) { return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( get_error_prefix(), @@ -196,8 +226,8 @@ absl::Status CheckGetDynamicReshapeInfoFoundOutputDynamicDimension( // Tries to map the `input_dynamic_dimension` onto `output_sizes`. If // successful, returns the mapped output dynamic dimension. // -// This function is exclusively called by `GetDynamicReshapeInfo()`. -absl::StatusOr GetDynamicReshapeInfoOutputDynamicDimension( +// This function is exclusively called by `SafeGetDynamicReshapeInfo()`. +absl::StatusOr GetOutputDynamicDimension( absl::Span input_sizes, absl::Span output_sizes, int64_t input_dynamic_dimension) { // Function for building the error prefix, which prints the `input_sizes` @@ -205,7 +235,9 @@ absl::StatusOr GetDynamicReshapeInfoOutputDynamicDimension( auto get_error_prefix = [=]() -> std::string { std::string input_sizes_str = EnumeratedSizesToStringWithStar( input_sizes, - [=](size_t i, int64_t) { return i == input_dynamic_dimension; }); + /* needs_star= */ [=](size_t i, int64_t) { + return i == input_dynamic_dimension; + }); std::string output_sizes_str = absl::StrJoin(output_sizes, /* separator= */ ", "); return absl::StrCat( @@ -225,7 +257,7 @@ absl::StatusOr GetDynamicReshapeInfoOutputDynamicDimension( int64_t output_elements = 1; for (size_t i = 0; i < output_sizes.size(); ++i) { - XLA_RETURN_IF_ERROR(CheckGetDynamicReshapeInfoDynamicDimensionNotSplit( + XLA_RETURN_IF_ERROR(CheckDynamicDimensionNotSplit( get_error_prefix, input_dynamic_dimension, i, output_elements, input_elements_before_dynamic_dimension)); output_elements *= output_sizes[i]; @@ -235,7 +267,7 @@ absl::StatusOr GetDynamicReshapeInfoOutputDynamicDimension( } } - XLA_RETURN_IF_ERROR(CheckGetDynamicReshapeInfoFoundOutputDynamicDimension( + XLA_RETURN_IF_ERROR(CheckFoundOutputDynamicDimension( get_error_prefix, output_dynamic_dimension)); return *output_dynamic_dimension; } @@ -500,36 +532,33 @@ xla::XlaOp XlaHelpers::ReshapeToRank(xla::XlaOp input, int64_t expected_rank, std::optional XlaHelpers::GetDynamicReshapeInfo( const xla::Shape& input_shape, absl::Span output_sizes) { - XLA_ASSIGN_OR_THROW(std::optional info, - SafeGetDynamicReshapeInfo(input_shape, output_sizes)); - return info; + if (input_shape.is_dynamic()) { + XLA_ASSIGN_OR_THROW(DynamicReshapeInfo info, + SafeGetDynamicReshapeInfo(input_shape, output_sizes)); + return info; + } + return std::nullopt; } -absl::StatusOr> +absl::StatusOr XlaHelpers::SafeGetDynamicReshapeInfo(const xla::Shape& input_shape, absl::Span output_sizes) { - // Make sure `input_shape` has up to a single dynamic dimension. - XLA_ASSIGN_OR_RETURN(std::optional opt_input_dynamic_dimension, - CheckAtMostOneDynamicDimension(input_shape)); - // Short-circuit if no dynamic dimension was found. - if (!opt_input_dynamic_dimension.has_value()) { - return std::nullopt; - } - - // From this point onwards, we know that `input_shape` has exactly 1 dynamic - // dimension. - int64_t input_dynamic_dimension = opt_input_dynamic_dimension.value(); - DynamicReshapeInfo info; XLA_ASSIGN_OR_RETURN(info.output_shape, xla::ShapeUtil::MakeValidatedShape( input_shape.element_type(), output_sizes)); + // Make sure `input_shape` has exactly 1 dynamic dimension. + absl::StatusOr exactly_1_dynamic_dimension_check = + CheckAndGetExactlyOneDynamicDimension(input_shape); + XLA_CHECK_OK(exactly_1_dynamic_dimension_check); + if (info.output_shape.dimensions().size() > 0) { XLA_ASSIGN_OR_RETURN( info.dynamic_dimension, - GetDynamicReshapeInfoOutputDynamicDimension( - input_shape.dimensions(), output_sizes, input_dynamic_dimension)); + GetOutputDynamicDimension(input_shape.dimensions(), output_sizes, + /* input_dynamic_dimension= */ + exactly_1_dynamic_dimension_check.value())); info.output_shape.set_dynamic_dimension(info.dynamic_dimension, true); } @@ -554,8 +583,7 @@ xla::XlaOp XlaHelpers::DynamicReshape(xla::XlaOp input, absl::StatusOr XlaHelpers::SafeDynamicReshape( xla::XlaOp input, absl::Span output_sizes) { - XLA_CHECK_OK( - CheckDynamicReshapeAnyOfOutputSizesIsUnboundedDynamic(output_sizes)); + XLA_CHECK_OK(CheckNoOutputSizesAreUnbounded(output_sizes)); XLA_ASSIGN_OR_RETURN(const xla::Shape* absl_nonnull input_shape, GetShape(input)); @@ -563,9 +591,9 @@ absl::StatusOr XlaHelpers::SafeDynamicReshape( return input; } - XLA_ASSIGN_OR_RETURN(std::optional info, - SafeGetDynamicReshapeInfo(*input_shape, output_sizes)); - if (info.has_value()) { + if (input_shape->is_dynamic()) { + XLA_ASSIGN_OR_RETURN(DynamicReshapeInfo info, + SafeGetDynamicReshapeInfo(*input_shape, output_sizes)); return xla::ReshapeWithInferredDimension(input, output_sizes, info->dynamic_dimension); } diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 83d43a64672..9715b259334 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -177,9 +177,10 @@ class XlaHelpers { // 3. In the presence of a dynamic shape in the input, we are unable to map // it to any of the dimensions of the output // - static absl::StatusOr> - SafeGetDynamicReshapeInfo(const xla::Shape& input_shape, - absl::Span output_sizes); + // Precondition: `input_shape` should have, at least, one dynamic dimension. + // Otherwise, it will crash! + static absl::StatusOr SafeGetDynamicReshapeInfo( + const xla::Shape& input_shape, absl::Span output_sizes); static xla::Shape GetDynamicReshape(const xla::Shape& input_shape, absl::Span output_sizes); From c29a3367b3923f1e3e89928ad5436333c47f0f25 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Dec 2025 13:25:18 -0300 Subject: [PATCH 3/3] Fix compilation error. --- torch_xla/csrc/helpers.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 46d3f45bb74..c2608bb0fe4 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -595,7 +595,7 @@ absl::StatusOr XlaHelpers::SafeDynamicReshape( XLA_ASSIGN_OR_RETURN(DynamicReshapeInfo info, SafeGetDynamicReshapeInfo(*input_shape, output_sizes)); return xla::ReshapeWithInferredDimension(input, output_sizes, - info->dynamic_dimension); + info.dynamic_dimension); } return xla::Reshape(input, output_sizes); }