🐛 Describe the bug
Bug Description
torch.nn.functional.linear produces incorrect numerical results on MPS when the weight tensor is non-contiguous, but works correctly on CPU.
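"Non-contiguous" here is a statement about strides. In a contiguous (row-major) tensor, each dimension's stride is the product of the sizes after it. Assuming einops lowers the pattern `"h d m -> m (h d)"` used in the reproduction below to a permute of `W` from `(12, 64, 768)` to `(768, 12, 64)` followed by a view-compatible merge of the last two dims, the resulting weight has strides `(1, 768)` instead of the contiguous `(768, 1)`. The stdlib-only sketch below traces that bookkeeping. The helpers `contiguous_strides`, `permute`, `merge_last_two` and `is_contiguous` are hypothetical illustrations of PyTorch's stride rules, not PyTorch or einops functions.

```python
# Hypothetical helpers that mimic PyTorch's stride bookkeeping in plain Python;
# they are illustrations only, not PyTorch or einops APIs.

def contiguous_strides(shape):
    """Row-major strides (in elements) for a tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, order):
    """Reorder dimensions without moving any data, like Tensor.permute."""
    return tuple(shape[i] for i in order), tuple(strides[i] for i in order)


def merge_last_two(shape, strides):
    """Flatten the last two dims as a view; requires them to be adjacent in memory."""
    (a, b), (sa, sb) = shape[-2:], strides[-2:]
    if sa != sb * b:
        raise ValueError("last two dims cannot be merged without a copy")
    return shape[:-2] + (a * b,), strides[:-2] + (sb,)


def is_contiguous(shape, strides):
    """True when strides match row-major layout (size-1 dims are ignored)."""
    expected = contiguous_strides(shape)
    return all(s == e for n, s, e in zip(shape, strides, expected) if n != 1)


# W from the reproduction: shape (h, d, m) = (12, 64, 768), freshly allocated, so contiguous.
shape = (12, 64, 768)
strides = contiguous_strides(shape)

# Assumed lowering of "h d m -> m (h d)": permute to (m, h, d), then merge h and d.
shape, strides = permute(shape, strides, (2, 0, 1))
shape, strides = merge_last_two(shape, strides)

print(shape, strides, is_contiguous(shape, strides))
# (768, 768) (1, 768) False
```

In other words, `w_noncontig` is a transposed view of contiguous memory, while `.contiguous()` copies it into the `(768, 1)` layout that `w_contig` has.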
To Reproduce
import torch
import einops
# Create test tensors on MPS
device = 'mps'
W = torch.randn(12, 64, 768, device=device)
x = torch.randn(1, 3, 768, device=device)
bias = torch.randn(768, device=device)
# Create non-contiguous weight via rearrange
w_noncontig = einops.rearrange(W, "h d m -> m (h d)")
w_contig = w_noncontig.contiguous()
print(f"Weight contiguous: {w_contig.is_contiguous()}")
print(f"Weight non-contiguous: {w_noncontig.is_contiguous()}")
# These should be identical but aren't on MPS
result1 = torch.nn.functional.linear(x, w_noncontig, bias)
result2 = torch.nn.functional.linear(x, w_contig, bias)
print(f"Results match: {torch.allclose(result1, result2, atol=1e-5)}")
print(f"Max difference: {torch.abs(result1 - result2).max()}")
# Compare with CPU (works correctly)
result_cpu_noncontig = torch.nn.functional.linear(x.cpu(), w_noncontig.cpu(), bias.cpu())
result_cpu_contig = torch.nn.functional.linear(x.cpu(), w_contig.cpu(), bias.cpu())
print(f"CPU contiguous vs non-contiguous match: {torch.allclose(result_cpu_noncontig, result_cpu_contig, atol=1e-5)}")
Versions
PyTorch version: 2.8.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.6.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.6.1-arm64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Pro
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.8.0
[pip3] torchvision==0.23.0
[conda] complexpytorch 0.4 pypi_0 pypi
[conda] numpy 1.24.3 py311hb57d4eb_0
[conda] numpy-base 1.24.3 py311h1d85a46_0
[conda] numpydoc 1.5.0 py311hca03da5_0
[conda] pytorch 2.2.2 py3.11_0 pytorch
[conda] tbb 2021.8.0 h48ca7d4_0
[conda] torch 2.2.2 pypi_0 pypi
[conda] torchvision 0.17.2 py311_cpu pytorch
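For reference, `torch.nn.functional.linear` computes y = x Wᵀ + b with W of shape `(out_features, in_features)`, and a correct kernel must return the same values whatever W's strides are. That invariant is what the contiguous vs non-contiguous comparison in the reproduction checks. The stdlib-only sketch below uses a hypothetical helper, `strided_linear` (not a PyTorch function), that reads W from a flat buffer through explicit strides. It shows that a row-major copy and a column-major (transposed-view) layout of the same logical matrix give identical outputs.

```python
# Hypothetical reference kernel for illustration; not part of PyTorch.
def strided_linear(x, w_buf, w_shape, w_strides, b):
    """y[i] = sum_k x[k] * W[i, k] + b[i], reading W from a flat buffer via its strides."""
    out_features, in_features = w_shape
    s0, s1 = w_strides
    return [
        sum(x[k] * w_buf[i * s0 + k * s1] for k in range(in_features)) + b[i]
        for i in range(out_features)
    ]


# Logical weight W with out_features=2, in_features=3, plus an input and a bias.
W = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
x = [0.5, -1.0, 2.0]
b = [0.1, -0.2]

# Contiguous (row-major) storage of W: strides (3, 1).
row_major = [v for row in W for v in row]
# The same logical W stored column-major, like a transposed view: strides (1, 2).
col_major = [W[i][k] for k in range(3) for i in range(2)]

y_contig = strided_linear(x, row_major, (2, 3), (3, 1), b)
y_strided = strided_linear(x, col_major, (2, 3), (1, 2), b)
print(y_contig, y_strided)
```

Both calls perform the same floating-point operations in the same order, so the outputs match exactly. A backend that ignores or mishandles the weight's strides would break this equality, which matches the symptom reported above for MPS.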