From efbabe0b624d8b71fee59be53cc9b4f191b9143a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 19 Aug 2025 16:25:08 -0300 Subject: [PATCH] Improve error handling for `random_` operation. --- test/test_operations.py | 27 ++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 54 ++++++++++++++++++++----------- torch_xla/csrc/tensor_methods.cpp | 9 ++++-- torch_xla/csrc/tensor_methods.h | 2 +- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a4e5b2e1044..9d377083da5 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2554,6 +2554,33 @@ def test_gather_raises_error_on_invalid_index_size(self): "However, that's not true on dimensions [0, 2].") self.assertEqual(str(e), expected_error) + def test_random__raises_error_on_empty_interval(self): + a = torch.empty(10, device=torch_xla.device()) + from_ = 3 + to_ = 1 + + try: + a.random_(from_, to_) + except RuntimeError as e: + expected_error = ( + f"random_(): expected `from` ({from_}) to be smaller than " + f"`to` ({to_}).") + self.assertEqual(str(e), expected_error) + + def test_random__raises_error_on_value_out_of_type_value_range(self): + a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16) + from_ = 3 + to_ = 65504 + 1 + + try: + a.random_(from_, to_) + except RuntimeError as e: + expected_error = ( + f"random_(): expected `to` to be within the range " + f"[-65504, 65504]. 
However got value {to_}, which is greater " + "than the upper bound.") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 5a75936b0c3..9606a989f83 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -18,6 +18,7 @@ #include #include "absl/log/absl_check.h" +#include "status.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/shape_inference.h" #include "torch/csrc/lazy/core/tensor_util.h" @@ -317,18 +318,27 @@ int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) { } } -void CheckRangeValues(torch::ScalarType dtype, int64_t from, int64_t to) { - XlaHelpers::MinMax min_max; - // Bound the min_max by int64_t since types of "from" and "to" are int64. - if (IsTypeWithLargerRangeThanLong(dtype)) { - min_max = XlaHelpers::MinMaxValues(xla::PrimitiveType::S64); - } else { - min_max = XlaHelpers::MinMaxValues(XlaTypeFromTorchType(dtype)); +absl::Status CheckValueWithinTypeRange(const std::string_view op, + const std::string_view arg, + torch::ScalarType dtype, int64_t value) { + xla::PrimitiveType type = IsTypeWithLargerRangeThanLong(dtype) + ? xla::PrimitiveType::S64 + : XlaTypeFromTorchType(dtype); + + XlaHelpers::MinMax mm = XlaHelpers::MinMaxValues(type); + int64_t min = mm.min.toLong(); + int64_t max = mm.max.toLong(); + + if (value < min || value > max) { + const std::string_view comparison = value < min ? "lower" : "greater"; + const std::string_view bound = value < min ? "lower bound" : "upper bound"; + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat(op, "(): expected `", arg, "` to be within the range [", + min, ", ", max, "]. 
However got value ", value, + ", which is ", comparison, " than the ", bound, "."))); } - XLA_CHECK_GE(from, min_max.min.toLong()); - XLA_CHECK_LE(from, min_max.max.toLong()); - XLA_CHECK_GE(to, min_max.min.toLong()); - XLA_CHECK_LE(to, min_max.max.toLong()); + + return absl::OkStatus(); } std::pair GetBinaryOperands( @@ -3025,12 +3035,14 @@ at::Tensor& XLANativeFunctions::random_( } XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); at::ScalarType dtype = self_tensor->dtype(); + // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType(dtype) + inc; - XLA_CHECK_LE(from, to_val); - CheckRangeValues(self_tensor->dtype(), from, to_val - 1); - tensor_methods::random_(self_tensor, from, to_val); + + OkOrThrow(CheckValueWithinTypeRange("random_", "from", dtype, from)); + OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to_val - 1)); + OkOrThrow(tensor_methods::random_(self_tensor, from, to_val)); return self; } @@ -3043,10 +3055,12 @@ at::Tensor& XLANativeFunctions::random_( ATEN_OP2(random_, to)>::call(self, to, generator); } + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLA_CHECK_GT(to, 0); - CheckRangeValues(self_tensor->dtype(), 0, to - 1); - tensor_methods::random_(self_tensor, 0, to); + at::ScalarType dtype = self_tensor->dtype(); + + OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to - 1)); + OkOrThrow(tensor_methods::random_(self_tensor, 0, to)); return self; } @@ -3060,10 +3074,12 @@ at::Tensor& XLANativeFunctions::random_( } XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); at::ScalarType dtype = self_tensor->dtype(); + // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 
0 : 1; - tensor_methods::random_(self_tensor, 0, - GetIntegerUpperLimitForType(dtype) + inc); + int64_t to_val = GetIntegerUpperLimitForType(dtype) + inc; + + OkOrThrow(tensor_methods::random_(self_tensor, 0, to_val)); return self; } diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 9c50db2f7bd..2786ca1718b 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2922,8 +2922,12 @@ XLATensorPtr dynamic_view(const XLATensorPtr& input, ////////////////////////////////////////////////////////////////////////////// -void random_(XLATensorPtr& input, int64_t from, int64_t to) { - XLA_CHECK_LE(from, to); +absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to) { + if (from >= to) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("random_(): expected `from` (", from, + ") to be smaller than `to` (", to, ")."))); + } auto input_shape = input->shape(); input->SetInPlaceIrValue(torch_xla::MakeNode( XLAGraphExecutor::Get()->GetIrValueForScalar( @@ -2931,6 +2935,7 @@ void random_(XLATensorPtr& input, int64_t from, int64_t to) { XLAGraphExecutor::Get()->GetIrValueForScalar(to, xla::PrimitiveType::S64, input->GetDevice()), XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape)); + return absl::OkStatus(); } XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 3c957083357..c28d7f2165e 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -776,7 +776,7 @@ void put_(XLATensorPtr& input, const XLATensorPtr& index, std::tuple qr(const XLATensorPtr& input, bool some); -void random_(XLATensorPtr& input, int64_t from, int64_t to); +absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to); XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);