diff --git a/test/test_operations.py b/test/test_operations.py
index 9d377083da5..00329ede588 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2581,6 +2581,34 @@ def test_random__raises_error_on_value_out_of_type_value_range(self):
                       "than the upper bound.")
     self.assertEqual(str(e), expected_error)
 
+  def test_mm_raises_error_on_non_matrix_input(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 2, 2, device=device)
+    b = torch.rand(2, 2, device=device)
+
+    try:
+      torch.mm(a, b)
+    except RuntimeError as e:
+      expected_error = (
+          "mm(): expected the first input tensor f32[2,2,2] to be a "
+          "matrix (i.e. a 2D tensor).")
+      self.assertEqual(str(e), expected_error)
+
+  def test_mm_raises_error_on_incompatible_shapes(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+    b = torch.rand(8, 2, device=device)
+
+    try:
+      torch.mm(a, b)
+    except RuntimeError as e:
+      expected_error = (
+          "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
+          "Expected the size of dimension 1 of the first input tensor (5) "
+          "to be equal the size of dimension 0 of the second input "
+          "tensor (8).")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 86b4cb84707..c042d703aa3 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2495,7 +2495,9 @@ at::Tensor XLANativeFunctions::mm(const at::Tensor& self,
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2));
-  return bridge::AtenFromXlaTensor(tensor_methods::mm(xla_self, xla_mat2));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr output,
+                      tensor_methods::mm(xla_self, xla_mat2));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index e7814ce517d..21f4db59713 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -11,6 +11,7 @@
 #include
 
 #include "absl/log/absl_check.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "torch_xla/csrc/LazyIr.h"
@@ -453,14 +454,14 @@ absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
   return absl::OkStatus();
 }
 
-// Checks that all index dimensions are smaller or equal to those of input,
-// except on dimension canonical_dim.
-absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
-                                                const XLATensorPtr& index,
-                                                int64_t canonical_dim) {
+// Checks that all index dimension sizes are smaller or equal to those of
+// input, except on dimension canonical_dim.
+absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
+                                           const XLATensorPtr& index,
+                                           int64_t canonical_dim) {
   // Dimensions that fail the "smaller or equal" condition.
   std::vector<int64_t> bad_dims;
-  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+  for (int64_t dim = 0; dim < input->shape().get().dimensions().size(); dim++) {
     if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
       bad_dims.push_back(dim);
     }
@@ -478,6 +479,33 @@ absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
   return absl::OkStatus();
 }
 
+absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
+                                  const std::string_view arg) {
+  xla::Shape shape = mat->shape();
+  if (shape.dimensions().size() != 2) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("mm(): expected the ", arg, " input tensor ",
+                     shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
+                                             const XLATensorPtr& mat2) {
+  xla::Shape shape1 = mat1->shape();
+  xla::Shape shape2 = mat2->shape();
+  if (shape1.dimensions(1) != shape2.dimensions(0)) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "mm(): cannot matrix-multiply tensors ", shape1.ToString(), " and ",
+        shape2.ToString(),
+        ". Expected the size of dimension 1 of the first input tensor (",
+        shape1.dimensions(1),
+        ") to be equal the size of dimension 0 of the second input tensor (",
+        shape2.dimensions(0), ").")));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1844,7 +1872,7 @@ absl::StatusOr<XLATensorPtr> gather(const XLATensorPtr& input,
                                  dim, input->shape().get().dimensions_size());
   XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
   XLA_RETURN_IF_ERROR(
-      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
+      CheckGatherSizesAreCompatible(input, index, canonical_dim));
   return input->CreateFrom(torch_xla::MakeNode<Gather>(
       input->GetIrValue(), canonical_dim, index->GetIrValue()));
 }
@@ -2349,7 +2377,11 @@ XLATensorPtr mish(const XLATensorPtr& input) {
                                tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
 }
 
-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight) {
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight) {
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
+  XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
   return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
 }
 
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 597640bf4c4..b25b423d49c 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -646,7 +646,8 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
 
 XLATensorPtr mish(const XLATensorPtr& input);
 
-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight);
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight);
 
 XLATensorPtr mse_loss(const XLATensorPtr& input, const XLATensorPtr& target,
                       int64_t reduction);
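
Note (not part of the patch): a minimal usage sketch of how the new mm() shape
checks surface to Python callers as a RuntimeError, assuming a working torch_xla
installation; the shapes below are arbitrary examples chosen to trigger the
incompatible-shapes path, mirroring test_mm_raises_error_on_incompatible_shapes.

  import torch
  import torch_xla

  device = torch_xla.device()
  a = torch.rand(2, 5, device=device)  # 2x5 matrix
  b = torch.rand(8, 2, device=device)  # 8x2 matrix; inner sizes 5 and 8 differ

  try:
    torch.mm(a, b)
  except RuntimeError as e:
    # Message produced by CheckMMMatrixSizesAreCompatible, e.g.
    # "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. ..."
    print(e)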