Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2554,6 +2554,33 @@ def test_gather_raises_error_on_invalid_index_size(self):
"However, that's not true on dimensions [0, 2].")
self.assertEqual(str(e), expected_error)

def test_random__raises_error_on_empty_interval(self):
a = torch.empty(10, device=torch_xla.device())
from_ = 3
to_ = 1

try:
a.random_(from_, to_)
except RuntimeError as e:
expected_error = (
f"random_(): expected `from` ({from_}) to be smaller than "
f"`to` ({to_}).")
self.assertEqual(str(e), expected_error)

def test_random__raises_error_on_value_out_of_type_value_range(self):
a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
from_ = 3
to_ = 65504 + 1

try:
a.random_(from_, to_)
except RuntimeError as e:
expected_error = (
f"random_(): expected `to` to be within the range "
f"[-65504, 65504]. However got value {to_}, which is greater "
"than the upper bound.")
self.assertEqual(str(e), expected_error)


class MNISTComparator(nn.Module):

Expand Down
54 changes: 35 additions & 19 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <optional>

#include "absl/log/absl_check.h"
#include "status.h"
#include "torch/csrc/lazy/core/helpers.h"
#include "torch/csrc/lazy/core/shape_inference.h"
#include "torch/csrc/lazy/core/tensor_util.h"
Expand Down Expand Up @@ -317,18 +318,27 @@ int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) {
}
}

void CheckRangeValues(torch::ScalarType dtype, int64_t from, int64_t to) {
XlaHelpers::MinMax min_max;
// Bound the min_max by int64_t since types of "from" and "to" are int64.
if (IsTypeWithLargerRangeThanLong(dtype)) {
min_max = XlaHelpers::MinMaxValues(xla::PrimitiveType::S64);
} else {
min_max = XlaHelpers::MinMaxValues(XlaTypeFromTorchType(dtype));
absl::Status CheckValueWithinTypeRange(const std::string_view op,
const std::string_view arg,
torch::ScalarType dtype, int64_t value) {
xla::PrimitiveType type = IsTypeWithLargerRangeThanLong(dtype)
? xla::PrimitiveType::S64
: XlaTypeFromTorchType(dtype);

XlaHelpers::MinMax mm = XlaHelpers::MinMaxValues(type);
int64_t min = mm.min.toLong();
int64_t max = mm.max.toLong();

if (value < min || value > max) {
const std::string_view comparison = value < min ? "lower" : "greater";
const std::string_view bound = value < min ? "lower bound" : "upper bound";
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat(op, "(): expected `", arg, "` to be within the range [",
min, ", ", max, "]. However got value ", value,
", which is ", comparison, " than the ", bound, ".")));
}
XLA_CHECK_GE(from, min_max.min.toLong());
XLA_CHECK_LE(from, min_max.max.toLong());
XLA_CHECK_GE(to, min_max.min.toLong());
XLA_CHECK_LE(to, min_max.max.toLong());

return absl::OkStatus();
}

std::pair<XLATensorPtr, XLATensorPtr> GetBinaryOperands(
Expand Down Expand Up @@ -3025,12 +3035,14 @@ at::Tensor& XLANativeFunctions::random_(
}
XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
at::ScalarType dtype = self_tensor->dtype();

// Prevent "to_val" from overflowing with at::ScalarType::Long.
int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1;
int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType(dtype) + inc;
XLA_CHECK_LE(from, to_val);
CheckRangeValues(self_tensor->dtype(), from, to_val - 1);
tensor_methods::random_(self_tensor, from, to_val);

OkOrThrow(CheckValueWithinTypeRange("random_", "from", dtype, from));
OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to_val - 1));
OkOrThrow(tensor_methods::random_(self_tensor, from, to_val));
return self;
}

Expand All @@ -3043,10 +3055,12 @@ at::Tensor& XLANativeFunctions::random_(
ATEN_OP2(random_, to)>::call(self, to,
generator);
}

XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
XLA_CHECK_GT(to, 0);
CheckRangeValues(self_tensor->dtype(), 0, to - 1);
tensor_methods::random_(self_tensor, 0, to);
at::ScalarType dtype = self_tensor->dtype();

OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to - 1));
OkOrThrow(tensor_methods::random_(self_tensor, 0, to));
return self;
}

Expand All @@ -3060,10 +3074,12 @@ at::Tensor& XLANativeFunctions::random_(
}
XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
at::ScalarType dtype = self_tensor->dtype();

// Prevent "to_val" from overflowing with at::ScalarType::Long.
int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1;
tensor_methods::random_(self_tensor, 0,
GetIntegerUpperLimitForType(dtype) + inc);
int64_t to_val = GetIntegerUpperLimitForType(dtype) + inc;

OkOrThrow(tensor_methods::random_(self_tensor, 0, to_val));
return self;
}

Expand Down
9 changes: 7 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2922,15 +2922,20 @@ XLATensorPtr dynamic_view(const XLATensorPtr& input,

//////////////////////////////////////////////////////////////////////////////

void random_(XLATensorPtr& input, int64_t from, int64_t to) {
XLA_CHECK_LE(from, to);
absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to) {
if (from >= to) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat("random_(): expected `from` (", from,
") to be smaller than `to` (", to, ").")));
}
auto input_shape = input->shape();
input->SetInPlaceIrValue(torch_xla::MakeNode<DiscreteUniform>(
XLAGraphExecutor::Get()->GetIrValueForScalar(
from, xla::PrimitiveType::S64, input->GetDevice()),
XLAGraphExecutor::Get()->GetIrValueForScalar(to, xla::PrimitiveType::S64,
input->GetDevice()),
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
return absl::OkStatus();
}

XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ void put_(XLATensorPtr& input, const XLATensorPtr& index,

std::tuple<XLATensorPtr, XLATensorPtr> qr(const XLATensorPtr& input, bool some);

void random_(XLATensorPtr& input, int64_t from, int64_t to);
absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to);

XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type);
Expand Down