tensor serialization on mps requires detach() (after cpu()) #90532

@HannesGitH

🐛 Describe the bug

Serializing a tensor that lives on an MPS device fails with RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In my case I was running textual inversion on Stable Diffusion (webui), and training a new embedding fails with:

Traceback (most recent call last):
  File "stable-diffusion-webui/modules/textual_inversion/textual_inversion.py", line 357, in train_embedding
    save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
  File "stable-diffusion-webui/modules/textual_inversion/textual_inversion.py", line 470, in save_embedding
    embedding.save(filename)
  File "stable-diffusion-webui/modules/textual_inversion/textual_inversion.py", line 40, in save
    torch.save(embedding_data, filename)
  File "stable-diffusion-webui/venv/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "stable-diffusion-webui/venv/lib/python3.8/site-packages/torch/serialization.py", line 589, in _save
    pickler.dump(obj)
  File "stable-diffusion-webui/venv/lib/python3.8/site-packages/torch/_tensor.py", line 177, in __reduce_ex__
    return self._reduce_ex_internal(proto)
  File "stable-diffusion-webui/venv/lib/python3.8/site-packages/torch/_tensor.py", line 223, in _reduce_ex_internal
    return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
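
The last frame of the traceback calls self.cpu().numpy() on the tensor being pickled. A minimal sketch of why that raises, runnable without an MPS device: a CPU tensor with requires_grad=True stands in here for the result of calling .cpu() on the MPS tensor (an assumption made for illustration, since .cpu() keeps the requires_grad flag). Detaching first lets the conversion succeed.

```python
import torch

# Stand-in for `mps_tensor.cpu()`: a CPU tensor that still requires grad.
t = torch.ones(3, requires_grad=True)

try:
    # Same call pattern as the failing frame in torch/_tensor.py.
    t.cpu().numpy()
    failed = False
except RuntimeError as exc:
    failed = True
    print(exc)

# Detaching drops the autograd history, so the NumPy conversion works.
arr = t.detach().cpu().numpy()
print(arr)
```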

I'm not certain whether this is an issue with PyTorch, Stable Diffusion, or my own setup, but I was able to fix it within PyTorch by doing what the error message suggests. I'll open a PR for that shortly, in case you also think it is PyTorch's fault.
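
Until a fix lands in PyTorch itself, a user-side workaround is to detach the tensors and move them to the CPU before calling torch.save. The following is only a sketch under stated assumptions: detached_for_saving is a hypothetical helper, and the embedding_data layout (the "vectors" and "step" keys) is made up for illustration rather than taken from the webui code. It falls back to the CPU when MPS is not available, and saves to an in-memory buffer instead of a file.

```python
import io
import torch

# Use MPS when present; otherwise fall back to CPU so the sketch still runs.
mps = getattr(torch.backends, "mps", None)
device = "mps" if mps is not None and mps.is_available() else "cpu"


def detached_for_saving(obj):
    """Hypothetical helper: return a copy of obj with tensors detached on CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: detached_for_saving(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Plain lists and tuples only; named tuples are not handled here.
        return type(obj)(detached_for_saving(v) for v in obj)
    return obj


# Illustrative payload; the keys are assumptions, not the webui format.
embedding_data = {
    "vectors": torch.zeros(2, 4, device=device, requires_grad=True),
    "step": 10,
}

buffer = io.BytesIO()
torch.save(detached_for_saving(embedding_data), buffer)

buffer.seek(0)
loaded = torch.load(buffer)
print(loaded["vectors"].shape, loaded["vectors"].requires_grad)
```

Note that the saved tensors no longer carry requires_grad, so it has to be set again after loading if training should continue from the checkpoint.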

(The version output below is from within the venv, of course.)

Versions

PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.24.1
Libc version: N/A

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:14)  [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-13.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.3
[pip3] open-clip-torch==2.7.0
[pip3] pytorch-lightning==1.7.6
[pip3] torch==1.12.1
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.11.0
[pip3] torchsde==0.2.5
[pip3] torchvision==0.13.1
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] open-clip-torch           2.7.0                    pypi_0    pypi
[conda] pytorch-lightning         1.7.7                    pypi_0    pypi
[conda] torch                     1.12.1                   pypi_0    pypi
[conda] torchdiffeq               0.2.3                    pypi_0    pypi
[conda] torchmetrics              0.11.0                   pypi_0    pypi
[conda] torchsde                  0.2.5                    pypi_0    pypi
[conda] torchvision               0.13.1                   pypi_0    pypi

cc @mruberry @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Metadata

    Labels

    module: mps: Related to Apple Metal Performance Shaders framework
    module: serialization: Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects
    triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module
