torch.load() fails on MPS backend ("don't know how to restore data location") #79384

@Birch-san

🐛 Describe the bug

# warning: 5.8GB file
wget https://huggingface.co/Cene655/ImagenT5-3B/resolve/main/model.pt

import torch
torch.load('./model.pt', map_location='mps')

Error thrown from serialization.py:

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
don't know how to restore data location of torch.storage._UntypedStorage (tagged with mps)
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 178, in default_restore_location
    raise RuntimeError("don't know how to restore data location of "
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 970, in restore_location
    return default_restore_location(storage, map_location)
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 1001, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 1019, in persistent_load
    load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/Users/birch/git/imagen-pytorch-cene/venv/lib/python3.9/site-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/Users/birch/git/imagen-pytorch-cene/repro.py", line 2, in <module>
    torch.load('./ImagenT5-3B/model.pt', map_location='mps')
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 268, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,

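I think the fix will involve registering the mps backend via register_package() in torch/serialization.py, so that default_restore_location() has a deserializer that accepts the 'mps' location.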

Versions

PyTorch version: 1.13.0.dev20220610
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.9.12 (main, Jun  1 2022, 06:34:44)  [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] imagen-pytorch==0.0.0
[pip3] numpy==1.22.4
[pip3] torch==1.13.0.dev20220610
[pip3] torchaudio==0.14.0.dev20220603
[pip3] torchvision==0.14.0.dev20220609
[conda] numpy                     1.23.0rc2                pypi_0    pypi
[conda] torch                     1.13.0.dev20220606          pypi_0    pypi
[conda] torchaudio                0.14.0.dev20220603          pypi_0    pypi
[conda] torchvision               0.14.0a0+f9f721d          pypi_0    pypi

cc @mruberry @kulinseth @albanD

Labels:
- module: mps (Related to Apple Metal Performance Shaders framework)
- module: serialization (Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)