Skip to content

Conversation

ysiraichi
Copy link
Collaborator

This PR refactors the mm operation implementation by improving its error message, and returning a status type value.

Key Changes:

  • Make tensor_methods::mm return Status
  • Refactor XLANativeFunctions::mm overloads to handle the status values
  • Improve error messages and error handling

Example 1: input is not a matrix

a = torch.rand(2, 4, 8, device=device)
b = torch.rand(8, 2, device=device)
torch.mm(a, b)

Before:

Traceback (most recent call last):
  File "examples/matmul.py", line 25, in <module>
    torch.mm(a, b)
RuntimeError: Cannot infer shape for dot operation: f32[2,4,8] <dot> f32[8,2]. Contracting dimension sizes are not compatible.

Status Propagation Trace:
    From: ShapeOfXlaOp at torch_xla/csrc/shape_helper.cpp:9

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/matmul.py", line 25, in <module>
    torch.mm(a, b)
RuntimeError: mm(): expected the first input tensor f32[2,4,8] to be a matrix (i.e. a 2D tensor).

Status Propagation Trace:
    From: CheckMMInputIsMatrix at torch_xla/csrc/tensor_methods.cpp:486 (error: mm(): expected the first input tensor f32[2,4,8] to be a matrix (i.e. a 2D tensor).)
    From: mm at torch_xla/csrc/tensor_methods.cpp:2381
    From: mm at torch_xla/csrc/aten_xla_type.cpp:2498

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Example 2: incompatible shapes

a = torch.rand(2, 4, device=device)
b = torch.rand(8, 2, device=device)
torch.mm(a, b)

Before:

Traceback (most recent call last):
  File "examples/matmul.py", line 25, in <module>
    torch.mm(a, b)
RuntimeError: Cannot infer shape for dot operation: f32[2,4] <dot> f32[8,2]. Contracting dimension sizes are not compatible.

Status Propagation Trace:
    From: ShapeOfXlaOp at torch_xla/csrc/shape_helper.cpp:9

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/matmul.py", line 25, in <module>
    torch.mm(a, b)
RuntimeError: mm(): cannot matrix-multiply tensors f32[2,4] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (4) to be equal the size of dimension 0 of the second input tensor (8).

Status Propagation Trace:
    From: CheckMMMatrixSizesAreCompatible at torch_xla/csrc/tensor_methods.cpp:498 (error: mm(): cannot matrix-multiply tensors f32[2,4] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (4) to be equal the size of dimension 0 of the second input tensor (8).)
    From: mm at torch_xla/csrc/tensor_methods.cpp:2383
    From: mm at torch_xla/csrc/aten_xla_type.cpp:2498

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

@ysiraichi ysiraichi merged commit 92dcabc into master Sep 5, 2025
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants