torch.load() fails on MPS backend ("don't know how to restore data location") #79384

Closed
Birch-san opened this issue Jun 12, 2022 · 4 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework)
module: serialization (Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Birch-san commented Jun 12, 2022

🐛 Describe the bug

# warning: 5.8GB file
wget https://huggingface.co/Cene655/ImagenT5-3B/resolve/main/model.pt
import torch
torch.load('./model.pt', map_location='mps')

Error thrown from serialization.py:

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
don't know how to restore data location of torch.storage._UntypedStorage (tagged with mps)
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 178, in default_restore_location
    raise RuntimeError("don't know how to restore data location of "
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 970, in restore_location
    return default_restore_location(storage, map_location)
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 1001, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 1019, in persistent_load
    load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/Users/birch/git/imagen-pytorch-cene/repro.py", line 2, in <module>
    torch.load('./ImagenT5-3B/model.pt', map_location='mps')
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 268, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,

I think the solution will involve adding a register_package() entry for the mps backend.
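To illustrate why a missing entry produces this exact error, here is a minimal, self-contained Python model of a priority-ordered registry of (tagger, deserializer) pairs, loosely patterned on torch.serialization.register_package(). It is a hypothetical sketch, not PyTorch's code: a "storage" is a plain dict, and the helper _backend() and the priority values are made up for illustration. As the trace above shows (restore_location calling default_restore_location(storage, map_location)), a string map_location is passed straight through as the location, and each deserializer is tried in priority order; if none accepts the location, the "don't know how to restore data location" error is raised.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical, simplified model of a priority-ordered (tagger, deserializer)
# registry, loosely patterned on torch.serialization.register_package().
# This is NOT PyTorch's implementation: a "storage" here is just a plain dict.
Storage = Dict[str, str]
Tagger = Callable[[Storage], Optional[str]]
Deserializer = Callable[[Storage, str], Optional[Storage]]

_package_registry: List[Tuple[int, Tagger, Deserializer]] = []


def register_package(priority: int, tagger: Tagger, deserializer: Deserializer) -> None:
    """Register a backend; entries with lower priority values are consulted first."""
    _package_registry.append((priority, tagger, deserializer))
    _package_registry.sort(key=lambda entry: entry[0])


def location_tag(storage: Storage) -> str:
    """Save side: return the first location tag (e.g. 'cpu') a tagger reports."""
    for _, tagger, _ in _package_registry:
        tag = tagger(storage)
        if tag is not None:
            return tag
    raise RuntimeError(f"don't know how to determine data location of {storage!r}")


def default_restore_location(storage: Storage, location: str) -> Storage:
    """Load side: return the first non-None result of the registered deserializers."""
    for _, _, deserializer in _package_registry:
        result = deserializer(storage, location)
        if result is not None:
            return result
    raise RuntimeError(
        "don't know how to restore data location of "
        f"{type(storage).__name__} (tagged with {location})"
    )


def _backend(name: str) -> Tuple[Tagger, Deserializer]:
    """Hypothetical helper: build a tagger/deserializer pair for one device name."""

    def tagger(storage: Storage) -> Optional[str]:
        return name if storage["device"] == name else None

    def deserializer(storage: Storage, location: str) -> Optional[Storage]:
        if location.startswith(name):  # accepts 'mps' as well as 'mps:0'
            return {**storage, "device": name}  # stand-in for copying to the device
        return None

    return tagger, deserializer


# With only a CPU entry registered, an 'mps' location cannot be restored.
register_package(10, *_backend("cpu"))
try:
    default_restore_location({"device": "cpu"}, "mps")
except RuntimeError as err:
    print("before:", err)

# Adding an 'mps' entry lets the same call succeed.
register_package(21, *_backend("mps"))
print("after:", default_restore_location({"device": "cpu"}, "mps"))
```

In this model the fix is purely additive: existing CPU handling is untouched, and the new entry only claims locations that start with "mps".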

Versions

PyTorch version: 1.13.0.dev20220610
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.9.12 (main, Jun  1 2022, 06:34:44)  [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] imagen-pytorch==0.0.0
[pip3] numpy==1.22.4
[pip3] torch==1.13.0.dev20220610
[pip3] torchaudio==0.14.0.dev20220603
[pip3] torchvision==0.14.0.dev20220609
[conda] numpy                     1.23.0rc2                pypi_0    pypi
[conda] torch                     1.13.0.dev20220606          pypi_0    pypi
[conda] torchaudio                0.14.0.dev20220603          pypi_0    pypi
[conda] torchvision               0.14.0a0+f9f721d          pypi_0    pypi

cc @mruberry @kulinseth @albanD

@mikaylagawarecki added the module: mps and triaged labels Jun 13, 2022
@albanD added the module: serialization label Jun 13, 2022
@albanD self-assigned this Jun 13, 2022
facebook-github-bot pushed a commit that referenced this issue Jun 16, 2022
Summary:
Fix #79384

Pull Request resolved: #79465
Approved by: https://github.com/kulinseth, https://github.com/malfet

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/64c2a275c4d463b936b9469da948a666e016bbb8

Reviewed By: osalpekar

Differential Revision: D37156509

Pulled By: malfet

fbshipit-source-id: 3c7ece64b0b519662bc7e5f19873bf579c6ffd93
aw632 commented Jul 6, 2022

@albanD The commit seems to have been reverted (ce6ce74).
Are there still any plans to fix this issue?

albanD (Collaborator) commented Jul 8, 2022

Hi,

This was reverted and re-landed again in 0a651a2 so this should be properly fixed on master. Do you have any issue with this on latest nightly?

aw632 commented Jul 8, 2022

I'm unable to get the nightly build installed with the command conda update pytorch torchvision torchaudio -c pytorch-nightly; stable works fine, though. When I run it, conda reports that the packages are already installed ("# All requested packages already installed."), but I'm still hitting the same MPS issue, so I don't think the nightly packages were actually installed: torch.__version__ is 1.12.0.

Edit: I created a new conda environment and installed everything except torchaudio (which I can't install from nightly because it lacks an osx-arm64 build). I can confirm it works on nightly!
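One way to check whether an environment is really running a nightly build is to inspect torch.__version__: nightly versions carry a .devYYYYMMDD suffix (for example 1.13.0.dev20220610 in the original report), whereas a plain 1.12.0 is a stable release. The sketch below assumes that version format; nightly_date() is a hypothetical helper, and torch.backends.mps is looked up defensively because builds without the MPS backend do not have it.

```python
import re
from typing import Optional


def nightly_date(version: str) -> Optional[str]:
    """Return the YYYYMMDD build date of a nightly version string, or None.

    Assumes nightly versions look like '1.13.0.dev20220610' (optionally
    followed by a local suffix such as '+cpu'); stable releases have no '.dev'.
    """
    match = re.search(r"\.dev(\d{8})", version)
    return match.group(1) if match else None


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        date = nightly_date(torch.__version__)
        kind = f"nightly built {date}" if date else "not a nightly build"
        print(f"torch {torch.__version__}: {kind}")
        # torch.backends.mps does not exist on builds without the MPS backend.
        mps = getattr(torch.backends, "mps", None)
        print("MPS available:", bool(mps and mps.is_available()))
```

Running this inside the activated conda environment shows which build the interpreter actually imports, which is what matters when conda reports that all packages are already installed.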

@Ziyan-swu

鎴睆2022-08-06 17 24 31

I ran into the same problem, and I don't know how to deal with it.
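On a build without the fix, a commonly used workaround is to deserialize onto the CPU with map_location='cpu', which the CPU entry can always restore, and then move the loaded tensors to MPS. The sketch below assumes the checkpoint is a tensor, a module, or nested dicts, lists, and tuples of them; move_to() and load_onto_mps() are hypothetical helper names, and torch is imported inside load_onto_mps() so that move_to() has no torch dependency of its own.

```python
from collections import OrderedDict
from typing import Any


def move_to(obj: Any, device: str) -> Any:
    """Recursively move anything with a .to() method (tensors, modules) to device."""
    if isinstance(obj, OrderedDict):  # e.g. a state_dict
        moved = OrderedDict((k, move_to(v, device)) for k, v in obj.items())
        # torch state_dicts carry a _metadata attribute; keep it if present.
        if hasattr(obj, "_metadata"):
            moved._metadata = obj._metadata
        return moved
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        return tuple(move_to(v, device) for v in obj)
    to = getattr(obj, "to", None)
    return to(device) if callable(to) else obj  # leave ints, strings, etc. as-is


def load_onto_mps(path: str) -> Any:
    """Load a checkpoint onto the CPU, then move it to MPS if it is available."""
    import torch

    checkpoint = torch.load(path, map_location="cpu")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    return move_to(checkpoint, device)
```

For a plain state_dict it is often enough to load it on the CPU and pass it to load_state_dict() of a model that is already on MPS, since the values are copied into the model's existing parameters.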

5 participants