diff --git a/test/test_operations.py b/test/test_operations.py
index a544d9ba19a..a4e5b2e1044 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2521,6 +2521,39 @@ def test_full_raises_error_on_negative_size(self):
           f"positive values. However found negative ones: {shape}.")
       self.assertEqual(str(e), expected_error)
 
+  def test_gather_raises_error_on_rank_mismatch(self):
+    S = 2
+
+    input = torch.arange(4, device=torch_xla.device()).view(S, S)
+    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
+    dim = 1
+
+    try:
+      torch.gather(input, dim, index)
+    except RuntimeError as e:
+      expected_error = (
+          "gather(): expected rank of input (2) and index (3) tensors "
+          "to be the same.")
+      self.assertEqual(str(e), expected_error)
+
+  def test_gather_raises_error_on_invalid_index_size(self):
+    S = 2
+    X = S + 2
+
+    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
+    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
+    dim = 1
+
+    try:
+      torch.gather(input, dim, index)
+    except RuntimeError as e:
+      expected_error = (
+          f"gather(): expected sizes of index [{X}, {S}, {X}, {S}] to be "
+          f"smaller or equal those of input [{S}, {S}, {S}, {S}] on all "
+          f"dimensions, except on dimension {dim}. "
+          "However, that's not true on dimensions [0, 2].")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index ceacd59603e..5a75936b0c3 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1865,9 +1865,9 @@ at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
                                       const at::Tensor& index,
                                       bool /* sparse_grad */) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
+  return bridge::AtenFromXlaTensor(GetValueOrThrow(
       tensor_methods::gather(GetValueOrThrow(bridge::GetXlaTensor(self)), dim,
-                             GetValueOrThrow(bridge::GetXlaTensor(index))));
+                             GetValueOrThrow(bridge::GetXlaTensor(index)))));
 }
 
 at::Tensor XLANativeFunctions::gelu(const at::Tensor& self,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 7534191f042..9c50db2f7bd 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -442,6 +442,43 @@ absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
       });
 }
 
+absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
+                                      const XLATensorPtr& index) {
+  int64_t input_rank = input->shape().get().dimensions_size();
+  int64_t index_rank = index->shape().get().dimensions_size();
+  if (input_rank != index_rank) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "gather(): expected rank of input (", input_rank, ") and index (",
+        index_rank, ") tensors to be the same.")));
+  }
+  return absl::OkStatus();
+}
+
+// Checks that all index dimensions are smaller or equal to those of input,
+// except on dimension canonical_dim.
+absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
+                                                const XLATensorPtr& index,
+                                                int64_t canonical_dim) {
+  // Dimensions that fail the "smaller or equal" condition.
+  std::vector<int64_t> bad_dims;
+  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+    if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
+      bad_dims.push_back(dim);
+    }
+  }
+  if (!bad_dims.empty()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "gather(): expected sizes of index [",
+        absl::StrJoin(index->shape().get().dimensions(), /* sep= */ ", "),
+        "] to be smaller or equal those of input [",
+        absl::StrJoin(input->shape().get().dimensions(), /* sep= */ ", "),
+        "] on all dimensions, except on dimension ", canonical_dim,
+        ". However, that's not true on dimensions [",
+        absl::StrJoin(bad_dims, /* sep= */ ", "), "].")));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1838,18 +1875,14 @@ absl::StatusOr<XLATensorPtr> full_symint(
       device, scalar_type);
 }
 
-XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
-                    const XLATensorPtr& index) {
-  xla::Shape input_shape = input->shape();
-  xla::Shape index_shape = index->shape();
-  XLA_CHECK_EQ(input_shape.dimensions_size(), index_shape.dimensions_size());
+absl::StatusOr<XLATensorPtr> gather(const XLATensorPtr& input,
+                                    int64_t dim,
+                                    const XLATensorPtr& index) {
   int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
-      dim, input_shape.dimensions_size());
-  for (size_t dim = 0; dim < input_shape.dimensions_size(); dim++) {
-    if (dim != canonical_dim) {
-      XLA_CHECK_LE(index->size(dim), input->size(dim));
-    }
-  }
+      dim, input->shape().get().dimensions_size());
+  XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
+  XLA_RETURN_IF_ERROR(
+      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
   return input->CreateFrom(torch_xla::MakeNode<Gather>(
       input->GetIrValue(), canonical_dim, index->GetIrValue()));
 }
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 869dcaa8dff..3c957083357 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -470,8 +470,9 @@ absl::StatusOr<XLATensorPtr> full_symint(
     at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
 
-XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
-                    const XLATensorPtr& index);
+absl::StatusOr<XLATensorPtr> gather(const XLATensorPtr& input,
+                                    int64_t dim,
+                                    const XLATensorPtr& index);
 
 XLATensorPtr ge(const XLATensorPtr& input, const at::Scalar& other);
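
Usage sketch (not part of the patch): how the new checks surface from Python, mirroring the added tests. The shapes below are illustrative, and it assumes a torch_xla build with this patch applied; previously the same conditions were enforced via XLA_CHECK_EQ / XLA_CHECK_LE instead of a structured error.

import torch
import torch_xla

device = torch_xla.device()
input = torch.arange(4, device=device).view(2, 2)      # rank-2 input
index = torch.randint(0, 2, (2, 2, 2), device=device)  # rank-3 index

try:
  torch.gather(input, 1, index)  # rank mismatch: 2 vs. 3
except RuntimeError as e:
  # With this patch, the failure propagates as a RuntimeError carrying the
  # message built in CheckGatherRanksAreEqual, e.g.:
  # "gather(): expected rank of input (2) and index (3) tensors to be the same."
  print(e)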