[MPS] load checkpoints gives zero weights when map_location is mps #85230

Open
m-reuter opened this issue Sep 18, 2022 · 4 comments
Labels
has workaround
module: correctness (silent): issue that returns an incorrect result silently
module: mps: Related to Apple Metal Performance Shaders framework
triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@m-reuter

m-reuter commented Sep 18, 2022

🐛 Describe the bug

When a checkpoint of a larger model is loaded directly to MPS, the weights are all zero. Loading to CPU works. I also tested with a tiny model, where it seems to work, so to replicate you need to download our weights (see code). It may be related to #79384 and #78551. As a workaround, move the model to CPU, load the weights there, and then move everything to MPS. Here is an example to replicate:

import torch
import requests
# Download checkpoint file
ckpt='FastSurferVINN_training_state_coronal.pkl'
fileurl='https://b2share.fz-juelich.de/api/files/0114331a-f788-48d2-9d09-f85d7494ed48/FastSurferVINN_training_state_coronal.pkl'
response = requests.get(fileurl, verify=False)
with open(ckpt, 'wb') as f:
    f.write(response.content)

# CPU load works:
model_state = torch.load(ckpt, map_location="cpu")
print(model_state["model_state"]["inp_block.bn0.weight"])
# output: tensor([2.0432, 1.2577, 4.1133, 7.4062, 3.9921, 1.8011, 2.0956])

# MPS load gives zeros:
model_state = torch.load(ckpt, map_location="mps")
print(model_state["model_state"]["inp_block.bn0.weight"])
# output: tensor([0., 0., 0., 0., 0., 0., 0.], device='mps:0')
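
To see how many tensors are affected rather than checking a single key, the two loads can be compared entry by entry. The following is only a sketch: `find_zeroed_tensors` is a hypothetical helper, not part of PyTorch, and it assumes both state dicts are flat mappings from names to tensors, as `model_state` appears to be in the example above.

```python
import torch


def find_zeroed_tensors(reference, candidate):
    """Return keys whose tensor is all zeros in `candidate` but not in `reference`.

    Hypothetical helper for illustration. `reference` is expected to be the
    CPU-loaded state dict, `candidate` the one loaded with another map_location.
    """
    zeroed = []
    for key, ref in reference.items():
        cand = candidate.get(key)
        # Skip non-tensor entries and keys missing from the candidate.
        if not (torch.is_tensor(ref) and torch.is_tensor(cand)):
            continue
        # Move the candidate to the CPU so the check does not rely on MPS kernels.
        ref_nonzero = ref.cpu().count_nonzero().item()
        cand_nonzero = cand.cpu().count_nonzero().item()
        if ref_nonzero > 0 and cand_nonzero == 0:
            zeroed.append(key)
    return zeroed


# Usage sketch (needs an MPS device and the checkpoint downloaded above):
# cpu_state = torch.load(ckpt, map_location="cpu")["model_state"]
# mps_state = torch.load(ckpt, map_location="mps")["model_state"]
# print(find_zeroed_tensors(cpu_state, mps_state))
```

Tensors that are legitimately all zeros in the CPU copy are not reported, so a non-empty result points at values lost during the load.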

Versions

Collecting environment information...
PyTorch version: 1.13.0.dev20220917
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.102)
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.6 (default, Aug 5 2022, 15:21:02) [Clang 14.0.0 (clang-1400.0.29.102)] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.2
[pip3] torch==1.13.0.dev20220917
[pip3] torchio==0.18.83
[pip3] torchvision==0.14.0.dev20220916
[conda] Could not collect

cc @kulinseth @albanD

malfet added the labels module: mps, has workaround, and module: correctness (silent) on Sep 20, 2022
@albanD
Collaborator

albanD commented Sep 20, 2022

Hi,

What is the location of the original saved weights?
It does work for the simple examples I tried before and we even have a test for that:

pytorch/test/test_mps.py

Lines 6430 to 6441 in a4dca98

def test_serialization_map_location(self):
    # Ensures that cpu Tensor can be loaded on mps
    with tempfile.NamedTemporaryFile() as f:
        x = torch.rand(2)
        torch.save(x, f)
        f.seek(0)
        x2 = torch.load(f, map_location="mps")
        self.assertEqual(x, x2)
        self.assertEqual(x2.device.type, "mps")

soulitzer added the triaged label on Sep 20, 2022
@m-reuter
Author

@albanD thanks for looking into this. I'm not sure what you are asking. I think the weights came from training on a CUDA device. You can replicate the problem by downloading the weights and then loading them (see example). Thanks.

@09wakharet

This issue still occurs when trying to load large models (200 MB+) with map_location="mps". It fails silently and sets all model weights to zero. Sending the model to the CPU, loading the model weights on the CPU, and then sending the model back to MPS works.

@efeakaroz13

@09wakharet I'm having the same problem. Can you tell me how to send the model back to MPS?
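
A minimal sketch of the workaround described in this thread: load the checkpoint with `map_location="cpu"`, apply the weights while the model is on the CPU, then move the model with `.to("mps")`. The helper names `pick_device` and `load_checkpoint_via_cpu` are made up for illustration, and the `"model_state"` key is an assumption taken from the checkpoint in the original report; other checkpoints may store the state dict under a different key.

```python
import torch


def pick_device():
    """Return "mps" when that backend is available, otherwise "cpu"."""
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def load_checkpoint_via_cpu(model, ckpt_path, device=None, state_key="model_state"):
    """Load weights on the CPU first, then move the whole model to `device`.

    Hypothetical helper; `state_key` is assumed to match the checkpoint layout.
    """
    device = device or pick_device()
    # Deserialize onto the CPU instead of passing map_location="mps".
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.to("cpu")
    model.load_state_dict(checkpoint[state_key])
    # Only now move the parameters to the target device.
    return model.to(device)
```

With a model instance `model` and the checkpoint path `ckpt` from the first post, the call would be `model = load_checkpoint_via_cpu(model, ckpt)`. Printing one of the parameters afterwards is a quick way to confirm that the values are no longer zero.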
