[MPS] .backward() encounters old nn.Embedding() weights #101198

Closed
rsbf opened this issue May 11, 2023 · 2 comments
Labels
module: mps Related to Apple Metal Performance Shaders framework triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

rsbf commented May 11, 2023

🐛 Describe the bug

Code to reproduce:

import torch
from torch import nn
device = torch.device('mps')
torch.set_default_device(device)


input_batch = torch.tensor([[1, 3, 4, 4, 8],
                            [2, 5, 7, 7, 9]], device=device, dtype=torch.int64)

# forward and backward for embedding_dim == 36
emb_36_a = nn.Embedding(embedding_dim=36, num_embeddings=100, device=device)
res_36_a = emb_36_a(input_batch)
loss_36_a = torch.sum(res_36_a) - 1
loss_36_a.backward()

# forward and backward again with a new nn.Embedding() that has the same embedding_dim == 36
# (and presumably new weights)
emb_36_b = nn.Embedding(embedding_dim=36, num_embeddings=100, device=device)
res_36_b = emb_36_b(input_batch)
loss_36_b = torch.sum(res_36_b) - 1
loss_36_b.backward()

# forward and backward with a new nn.Embedding() and a larger embedding_dim == 48
emb_48 = nn.Embedding(embedding_dim=48, num_embeddings=100, device=device)
res_48 = emb_48(input_batch)
loss_48 = torch.sum(res_48) - 1
loss_48.backward()  # <--- Error

The error (some directory information omitted):

-:8:10: error: invalid input tensor shape: updates tensor shape and data tensor shape must match along inner dimensions
-:8:10: note: see current operation: %5 = "mps.scatter_nd"(%0, %arg0, %4) {batch_dims = 0 : ui32, mode = 0 : i32} : (tensor<100x36xf32>, tensor<2x5x48xf32>, tensor<2x5x1xi64>) -> tensor<100x36xf32>
/<directory>/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1267: failed assertion `Error: MLIR pass manager failed'

Process finished with exit code 134 (interrupted by signal 6: SIGABRT)

It works fine when device = torch.device('cpu').
Note the second line of the error: the underlying MPS operation appears to be working with a tensor of size <100x36> (i.e., the dimensions of the weights of emb_36_a and emb_36_b), whereas I'd expect <100x48> (the dimensions of emb_48's weights).
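To make the shape mismatch in the error concrete, here is a minimal plain-Python sketch of the shape rule that scatter_nd-style operations conventionally enforce. It assumes the MPS `scatter_nd` follows the usual convention (updates shape equals the leading index dimensions plus the trailing data dimensions); the helper names `expected_updates_shape` and `shapes_compatible` are made up for illustration and are not part of PyTorch or MPS.

```python
def expected_updates_shape(data_shape, indices_shape):
    """Updates shape required by the conventional scatter_nd rule (assumed here)."""
    # The last axis of `indices` says how many leading data axes each index addresses.
    index_depth = indices_shape[-1]
    # Updates carry one slice per index, shaped like the remaining data axes.
    return tuple(indices_shape[:-1]) + tuple(data_shape[index_depth:])


def shapes_compatible(data_shape, updates_shape, indices_shape):
    """True when the updates match the data along the inner dimensions."""
    return tuple(updates_shape) == expected_updates_shape(data_shape, indices_shape)


# Shapes taken from the error message: data 100x36, updates 2x5x48, indices 2x5x1.
stale = shapes_compatible((100, 36), (2, 5, 48), (2, 5, 1))
# Shapes one would expect for emb_48: data 100x48 with the same updates and indices.
fresh = shapes_compatible((100, 48), (2, 5, 48), (2, 5, 1))
print(stale, fresh)
```

Under this assumption, the 2x5x48 gradient only fits a 100x48 weight tensor, which is consistent with the operation having been built for the earlier 36-dimensional embeddings.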

Versions

PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.6 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.8.11 (default, Feb 22 2022, 14:24:07) [Clang 13.0.0 (clang-1300.0.27.3)] (64-bit runtime)
Python platform: macOS-12.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.0.1
[conda] No relevant packages

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

cpuhrsch added the triaged and module: mps labels May 12, 2023
rsbf (Author) commented May 16, 2023

Same error on macOS 13.3.1, except the tensor dims aren't mentioned in the error message.

qqaatw (Collaborator) commented May 19, 2023

It's a graph cache problem. Thanks for reporting.
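For readers unfamiliar with this class of bug, the following is a toy illustration in plain Python of how a graph cache can hand back a stale graph when its key omits a shape-relevant field. It is not the PyTorch MPS implementation: `ToyGraphCache`, `build_backward_graph`, `incomplete_key`, and `complete_key` are hypothetical names, and the choice of `embedding_dim` as the omitted field is an assumption made only to mirror the reproduction above.

```python
class ToyGraphCache:
    """Minimal stand-in for a compiled-graph cache (hypothetical)."""

    def __init__(self):
        self._graphs = {}

    def lookup_or_build(self, key, build):
        # Reuse a previously built graph whenever the key matches.
        if key not in self._graphs:
            self._graphs[key] = build()
        return self._graphs[key]


def build_backward_graph(num_embeddings, embedding_dim):
    # A "graph" here just records the weight shape it was specialized for.
    return {"weight_shape": (num_embeddings, embedding_dim)}


def incomplete_key(num_embeddings, embedding_dim, indices_shape):
    # Assumed buggy key: embedding_dim is left out, so different widths collide.
    return ("embedding_backward", num_embeddings, indices_shape)


def complete_key(num_embeddings, embedding_dim, indices_shape):
    # Key that includes every shape the graph depends on.
    return ("embedding_backward", num_embeddings, embedding_dim, indices_shape)


def run(key_fn):
    """Replay the issue's sequence (dims 36, 36, 48) and report the graph used each time."""
    cache = ToyGraphCache()
    used = []
    for dim in (36, 36, 48):
        key = key_fn(100, dim, (2, 5))
        graph = cache.lookup_or_build(key, lambda: build_backward_graph(100, dim))
        used.append(graph["weight_shape"])
    return used


buggy = run(incomplete_key)
fixed = run(complete_key)
print(buggy)
print(fixed)
```

With the incomplete key, the third call reuses the graph built for a 100x36 weight, matching the symptom in the error message; with the complete key, a new graph is built for 100x48.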

qqaatw added a commit that referenced this issue May 19, 2023
qqaatw added a commit that referenced this issue May 21, 2023