28 changes: 28 additions & 0 deletions test/test_operations.py
@@ -2581,6 +2581,34 @@ def test_random__raises_error_on_value_out_of_type_value_range(self):
                       "than the upper bound.")
     self.assertEqual(str(e), expected_error)
 
+  def test_mm_raises_error_on_non_matrix_input(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 2, 2, device=device)
+    b = torch.rand(2, 2, device=device)
+
+    try:
+      torch.mm(a, b)
+    except RuntimeError as e:
+      expected_error = (
+          "mm(): expected the first input tensor f32[2,2,2] to be a "
+          "matrix (i.e. a 2D tensor).")
+      self.assertEqual(str(e), expected_error)
+
+  def test_mm_raises_error_on_incompatible_shapes(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+    b = torch.rand(8, 2, device=device)
+
+    try:
+      torch.mm(a, b)
+    except RuntimeError as e:
+      expected_error = (
+          "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
+          "Expected the size of dimension 1 of the first input tensor (5) "
+          "to be equal the size of dimension 0 of the second input "
+          "tensor (8).")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):
 
4 changes: 3 additions & 1 deletion torch_xla/csrc/aten_xla_type.cpp
@@ -2495,7 +2495,9 @@ at::Tensor XLANativeFunctions::mm(const at::Tensor& self,
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2));
-  return bridge::AtenFromXlaTensor(tensor_methods::mm(xla_self, xla_mat2));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr output,
+                      tensor_methods::mm(xla_self, xla_mat2));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self,
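The hunk above shows the StatusOr-to-exception bridge: tensor_methods::mm now returns absl::StatusOr<XLATensorPtr>, and XLA_ASSIGN_OR_THROW unwraps it at the ATen boundary, converting a failed status into the C++ exception that reaches Python as the RuntimeError the new tests assert on. Below is a minimal sketch of this unwrap-or-throw idiom in plain Abseil; ValueOrThrow, CheckedDivide, and Divide are hypothetical stand-ins for illustration, not torch_xla APIs:

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical analogue of XLA_ASSIGN_OR_THROW: yield the wrapped value,
// or surface the error status as an exception.
template <typename T>
T ValueOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    throw std::runtime_error(std::string(result.status().message()));
  }
  return *std::move(result);
}

// A fallible operation, analogous to tensor_methods::mm.
absl::StatusOr<int> CheckedDivide(int a, int b) {
  if (b == 0) {
    return absl::InvalidArgumentError("division by zero");
  }
  return a / b;
}

// The boundary function throws rather than returning a status, matching
// the shape of XLANativeFunctions::mm above.
int Divide(int a, int b) { return ValueOrThrow(CheckedDivide(a, b)); }

The std::move(output) in the real caller plays the same role as the move out of the StatusOr here: once the status has been checked, the XLATensorPtr is transferred rather than copied.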
48 changes: 40 additions & 8 deletions torch_xla/csrc/tensor_methods.cpp
@@ -11,6 +11,7 @@
 #include <iterator>
 
 #include "absl/log/absl_check.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "torch_xla/csrc/LazyIr.h"
@@ -453,14 +454,14 @@ absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
   return absl::OkStatus();
 }
 
-// Checks that all index dimensions are smaller or equal to those of input,
-// except on dimension canonical_dim.
-absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
-                                                const XLATensorPtr& index,
-                                                int64_t canonical_dim) {
+// Checks that all index dimension sizes are smaller or equal to those of
+// input, except on dimension canonical_dim.
+absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
+                                           const XLATensorPtr& index,
+                                           int64_t canonical_dim) {
   // Dimensions that fail the "smaller or equal" condition.
   std::vector<int64_t> bad_dims;
-  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+  for (int64_t dim = 0; dim < input->shape().get().dimensions().size(); dim++) {
     if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
       bad_dims.push_back(dim);
     }
@@ -478,6 +479,33 @@ absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
   return absl::OkStatus();
 }
 
+absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
+                                  const std::string_view arg) {
+  xla::Shape shape = mat->shape();
+  if (shape.dimensions().size() != 2) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("mm(): expected the ", arg, " input tensor ",
+                     shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
+                                             const XLATensorPtr& mat2) {
+  xla::Shape shape1 = mat1->shape();
+  xla::Shape shape2 = mat2->shape();
+  if (shape1.dimensions(1) != shape2.dimensions(0)) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "mm(): cannot matrix-multiply tensors ", shape1.ToString(), " and ",
+        shape2.ToString(),
+        ". Expected the size of dimension 1 of the first input tensor (",
+        shape1.dimensions(1),
+        ") to be equal the size of dimension 0 of the second input tensor (",
+        shape2.dimensions(0), ").")));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1844,7 +1872,7 @@ absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
       dim, input->shape().get().dimensions_size());
   XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
   XLA_RETURN_IF_ERROR(
-      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
+      CheckGatherSizesAreCompatible(input, index, canonical_dim));
   return input->CreateFrom(torch_xla::MakeNode<Gather>(
       input->GetIrValue(), canonical_dim, index->GetIrValue()));
 }
@@ -2349,7 +2377,11 @@ XLATensorPtr mish(const XLATensorPtr& input) {
       tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
 }
 
-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight) {
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight) {
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
+  XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
   return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
 }
 
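The CheckMM* helpers follow the validation idiom this file already uses for gather: each precondition is a small function returning absl::Status, and mm chains them with XLA_RETURN_IF_ERROR so the first failure short-circuits out before any IR node is built. Here is a self-contained sketch of that idiom, with a plain dimension vector standing in for xla::Shape; Dims, CheckIsMatrix, CheckSizesAreCompatible, and MmOutputShape are illustrative names, not torch_xla APIs:

#include <cstdint>
#include <string_view>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

using Dims = std::vector<int64_t>;

absl::Status CheckIsMatrix(const Dims& dims, std::string_view arg) {
  if (dims.size() != 2) {
    return absl::InvalidArgumentError(absl::StrCat(
        "mm(): expected the ", arg, " input tensor to be a matrix."));
  }
  return absl::OkStatus();
}

absl::Status CheckSizesAreCompatible(const Dims& a, const Dims& b) {
  if (a[1] != b[0]) {
    return absl::InvalidArgumentError(
        absl::StrCat("mm(): cannot matrix-multiply: inner dimensions ", a[1],
                     " and ", b[0], " differ."));
  }
  return absl::OkStatus();
}

// Analogous to tensor_methods::mm: run each check in order, short-circuit
// on the first error, and only then compute the result.
absl::StatusOr<Dims> MmOutputShape(const Dims& a, const Dims& b) {
  if (absl::Status s = CheckIsMatrix(a, "first"); !s.ok()) return s;
  if (absl::Status s = CheckIsMatrix(b, "second"); !s.ok()) return s;
  if (absl::Status s = CheckSizesAreCompatible(a, b); !s.ok()) return s;
  return Dims{a[0], b[1]};
}

For example, MmOutputShape({2, 5}, {8, 2}) returns the incompatible-dimensions error, while MmOutputShape({2, 5}, {5, 3}) returns {2, 3}, mirroring the behavior the new tests pin down.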
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor_methods.h
@@ -646,7 +646,8 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
 
 XLATensorPtr mish(const XLATensorPtr& input);
 
-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight);
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight);
 
 XLATensorPtr mse_loss(const XLATensorPtr& input, const XLATensorPtr& target,
                       int64_t reduction);