MPS memory issue, MPS backend out of memory, but works if I empty the MPS cache #105839

Open
Vargol opened this issue Jul 24, 2023 · 12 comments
Labels
module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: mps Related to Apple Metal Performance Shaders framework triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


@Vargol

Vargol commented Jul 24, 2023

🐛 Describe the bug

There appears to be something wrong with the MPS cache. Either it is not releasing memory when it ideally should, or the freeable memory in the cache is not taken into account when the check for available space occurs.
The issue occurs on the current nightly (see Versions) and on 2.0.1.

This issue hurts performance at best and terminates the application at worst.

Here's an example...

from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
from torch import mps
import torch
import fp16fixes
import gc

fp16fixes.fp16_fixes()

pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
pipe_prior.to("mps")
prompt = "A car exploding into colorful dust"
out = pipe_prior(prompt)
image_emb = out.image_embeds
zero_image_emb = out.negative_image_embeds

# Drop the prior pipeline and hand its cached MPS memory back before loading the decoder
pipe_prior = None
gc.collect()
mps.empty_cache()

pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.to("mps")
pipe.enable_attention_slicing()

image = pipe(
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=1024,
    width=1024,
    num_inference_steps=30,
).images

image[0].save("cat.png")

This works on an 8GB M1 Mac Mini without issue; the two models run at:

100%|████████| 25/25 [00:07<00:00,  3.15it/s]
100%|████████| 30/30 [04:24<00:00,  8.82s/it]

Remove the mps.empty_cache() call and it fails during the second model run:

  0%|                                                                                                                                    | 0/30 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/8GB_M1_Diffusers_Scripts/sag/k2img.py", line 25, in <module>
    image = pipe(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py", line 272, in __call__
    noise_pred = self.unet(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 905, in forward
    sample, res_samples = downsample_block(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 1662, in forward
    hidden_states = attn(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 321, in forward
    return self.processor(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1590, in __call__
    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 374, in get_attention_scores
    attention_probs = attention_scores.softmax(dim=-1)
RuntimeError: MPS backend out of memory (MPS allocated: 3.90 GB, other allocations: 4.94 GB, max allowed: 9.07 GB). Tried to allocate 387.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

If I reduce the height and width to 512, it runs to completion, but the second model runs at 40 seconds per iteration with heavy swap-file access. With the cache emptied manually, it runs at around 2 seconds per iteration.
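The manual workaround used in the script (drop the reference, run gc.collect(), then call mps.empty_cache()) can be wrapped in a small helper that also reports how much memory the driver was holding. This is an illustrative sketch, not part of the original report: the helper name free_mps_memory is hypothetical, and it assumes a PyTorch build recent enough to provide torch.mps.driver_allocated_memory() and torch.mps.current_allocated_memory(). It returns an empty dict when torch or MPS is unavailable.

```python
import gc


def free_mps_memory() -> dict:
    """Collect garbage, then return cached MPS blocks to the system.

    Hypothetical helper: returns byte counts when MPS is available,
    or an empty dict when torch is missing or MPS is unavailable.
    """
    # Release unreachable Python objects first so their MPS buffers become freeable.
    gc.collect()
    try:
        import torch
    except ImportError:
        return {}
    if not torch.backends.mps.is_available():
        return {}
    # Memory the Metal driver holds for this process, including cached blocks.
    before = torch.mps.driver_allocated_memory()
    torch.mps.empty_cache()
    after = torch.mps.driver_allocated_memory()
    return {
        "driver_bytes_before": before,
        "driver_bytes_after": after,
        # Memory still occupied by live tensors (empty_cache does not free these).
        "tensor_bytes": torch.mps.current_allocated_memory(),
    }
```

In the script above, calling free_mps_memory() right after pipe_prior = None would do the same job as the gc.collect() / mps.empty_cache() pair while also showing how much the cache had been holding.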

The fp16fixes file is required to work around some issues with fp16 on MPS: without it, the script fails with a broadcast error on 2.0.1 and produces a bad image on the nightly I'm currently using. If I remove it, the issue still occurs on the nightly.

% cat fp16fixes.py 
import torch

def fp16_fixes():
  if torch.backends.mps.is_available():
      torch.empty = torch.zeros

  _torch_layer_norm = torch.nn.functional.layer_norm
  def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
      if input.device.type == "mps" and input.dtype == torch.float16:
          input = input.float()
          if weight is not None:
              weight = weight.float()
          if bias is not None:
              bias = bias.float()
          return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
      else:
          return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

  torch.nn.functional.layer_norm = new_layer_norm


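  def new_torch_tensor_permute(input, *dims):
      # Accept both x.permute(0, 2, 1) and x.permute((0, 2, 1)).
      if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
          dims = dims[0]
      result = torch.permute(input, tuple(dims))
      # Compare device.type: a torch.device does not compare equal to a plain string.
      if input.device.type == "mps" and input.dtype == torch.float16:
          result = result.contiguous()
      return result

  torch.Tensor.permute = new_torch_tensor_permute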

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230724
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.4.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: version 3.24.4
Libc version: N/A

Python version: 3.10.11 (main, Apr 8 2023, 02:11:11) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-13.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] numpy==1.25.1
[pip3] torch==2.1.0.dev20230724
[pip3] torchvision==0.15.2
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

@janeyx99 janeyx99 added module: memory usage PyTorch is using more memory than it should, or it is leaking memory triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module module: mps Related to Apple Metal Performance Shaders framework high priority labels Jul 24, 2023
@zou3519
Contributor

zou3519 commented Aug 7, 2023

This feels like an edge case; we may not do anything here.

@gernophil

gernophil commented Oct 25, 2023

Since switching to torchaudio 2.1.0, I also frequently get OOM errors:

Hey everyone,
I have a small app that uses torchaudio. Previously I used torchaudio 2.0.2 with MPS, and it was able to process audio files longer than 30 minutes on both my 16GB and 8GB RAM machines. The results were not good, though.

I updated to torchaudio 2.1.0, and on my 16GB RAM machine it runs really well and the results are far better than with 2.0.2, but on my 8GB RAM machine I now get the above-mentioned OOM error:

MPS backend out of memory (MPS allocated: 1.45 GB, other allocations: 7.42 GB, max allowed: 9.07 GB). Tried to allocate 563.62 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

Why is there such a big difference in memory allocation between 2.0.2 and 2.1.0? Since the results changed, a lot must have changed in the MPS backend, but does anyone have a workaround for this?
(originally from https://discuss.pytorch.org/t/mps-backend-out-of-memory/183879/5?u=gernophil)

@Vargol
Author

Vargol commented Oct 25, 2023

Something is using 7.42 GB of memory on your 8 GB machine, while MPS is using 1.45 GB.
The high watermark is set to allow 9.07 GB of memory; you need to set PYTORCH_MPS_HIGH_WATERMARK_RATIO to a higher value between 0.0 and 2.0.
The default is 1.4 IIRC, so setting it higher allows more memory.

If 2.0 isn't enough, setting it to 0.0 will allow torch to use as much memory as needed, but any total memory usage (including other applications) beyond 8GB will come from swap. This will increase wear and tear on your system SSD (I'm assuming you're on an Apple Silicon Mac) and could potentially crash the OS.

Having said that, I've used it set to 0.0 a fair bit on my 8GB M1, and it hasn't caused a system crash since the watermark system was added to PyTorch.
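The environment variable named in the error message can also be set from Python instead of the shell. This is a sketch under stated assumptions, not an official API: configure_mps_watermark is a hypothetical helper, the 0.0 to 2.0 range is taken from the comment above, and it is safest to call it before torch is imported so the value is already in place when the MPS allocator starts up.

```python
import os

ENV_VAR = "PYTORCH_MPS_HIGH_WATERMARK_RATIO"


def configure_mps_watermark(ratio: float) -> str:
    """Set the MPS high-watermark ratio via the environment (hypothetical helper).

    0.0 disables the upper limit entirely (risking heavy swapping or a system
    failure, as the error message warns); higher values allow more memory.
    """
    if not 0.0 <= ratio <= 2.0:
        raise ValueError(f"{ENV_VAR} must be between 0.0 and 2.0, got {ratio}")
    os.environ[ENV_VAR] = str(ratio)
    return os.environ[ENV_VAR]


if __name__ == "__main__":
    configure_mps_watermark(2.0)  # call this before `import torch`
```

Exporting the variable in the shell before launching the script has the same effect.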

@gernophil

When I check the RAM usage right after I get this error, it tells me only 2GB of my system memory is in use. There should be enough left.

@raffetazarius

I'm also running into this error a lot trying to do SDXL 1.0 generations on a Mac Pro 2019 Intel with an AMD 6900XT 16GB GPU. I can do about 5-10 generations and then get the "MPS backend out of memory" error.

See AUTOMATIC1111/stable-diffusion-webui#5461 (reply in thread) for more context.

It would be great to get this fixed in the next version of PyTorch so that Mac users can run SD!

@raffetazarius

Tested a1111 with PyTorch 2.3.0.dev20240103 today on my aforementioned Mac Intel + AMD GPU rig and am no longer getting this MPS Out of Memory error! Yay!

@gernophil

gernophil commented Jan 9, 2024

Any chance we will also get this fix in the stable 2.2.1?

@gernophil

I still get this error using torch==2.3.0.dev20240212 (torchaudio is still at 2.2.0 in the dev branch: torchaudio==2.2.0.dev20240212).

@tonytorm

Having this problem too; clearing the cache and trying to be parsimonious with memory allocation changes nothing.

@shubham-attri

I am having the same issue while running it in a Jupyter Notebook locally on an M2 Mac. There is a similar issue at Apple.

@ratkins

ratkins commented Mar 23, 2024

@shubham-attri You mean you get a similar error to the one people are complaining about in this thread when you try and follow that Apple tutorial?

@smartsastram

To follow

Projects
None yet
Development

No branches or pull requests

9 participants