RuntimeError: Undefined type BFloat16 from matmul above certain sizes (maybe MPS only) #121583

Closed
Vargol opened this issue Mar 9, 2024 · 2 comments
Labels
module: bfloat16 module: mps Related to Apple Metal Performance Shaders framework triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


Vargol commented Mar 9, 2024

🐛 Describe the bug

When using torch.matmul (and the @ operator) on MPS bfloat16 tensors above a certain size, I get the error:
RuntimeError: Undefined type BFloat16

Example code

import torch

sdxl_latent_rgb_factors = torch.tensor(
    [
        #   R        G        B
        [0.3816, 0.4930, 0.5320],
        [-0.3753, 0.1631, 0.1739],
        [0.1770, 0.3588, -0.2048],
        [-0.4350, -0.2644, -0.4289],
    ],
    dtype=torch.bfloat16,
    device='mps',
)

print('-------181,181,4--works----------')

samples = torch.zeros(181, 181, 4,
                      dtype=torch.bfloat16,
                      device='mps',
)

x = torch.matmul(samples, sdxl_latent_rgb_factors)

print('-------182,182,4--fails--------')

samples2 = torch.zeros(182, 182, 4,
                       dtype=torch.bfloat16,
                       device='mps',
)

x = torch.matmul(samples2, sdxl_latent_rgb_factors)

results in

 python ../Diffusers/matmul2.py 
-------181,181,4--works----------
-------182,182,4--fails--------
Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Cascade/../Diffusers/matmul2.py", line 31, in <module>
    x = torch.matmul(samples2, sdxl_latent_rgb_factors)
RuntimeError: Undefined type BFloat16
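Until a build with the fix is available, one possible workaround is to run the product in float32 and cast the result back. This is a minimal sketch, not something suggested in this thread: it assumes the float32 path is unaffected (the error names only BFloat16) and that the extra memory of the upcast is acceptable. The helper name `matmul_bf16_safe` is hypothetical, and the script falls back to CPU when MPS is unavailable so it still runs elsewhere.

```python
import torch


def matmul_bf16_safe(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: do bf16 products in float32, return bf16.

    Assumption: only the bfloat16 path is affected, so upcasting the
    operands avoids the failing kernel at the cost of extra memory.
    """
    if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16:
        out = torch.matmul(a.float(), b.float())
        return out.to(torch.bfloat16)
    return torch.matmul(a, b)


# Use MPS when present; otherwise run on CPU so the sketch still executes.
device = "mps" if torch.backends.mps.is_available() else "cpu"

factors = torch.tensor(
    [
        [0.3816, 0.4930, 0.5320],
        [-0.3753, 0.1631, 0.1739],
        [0.1770, 0.3588, -0.2048],
        [-0.4350, -0.2644, -0.4289],
    ],
    dtype=torch.bfloat16,
    device=device,
)
samples = torch.zeros(182, 182, 4, dtype=torch.bfloat16, device=device)
x = matmul_bf16_safe(samples, factors)
print(x.shape, x.dtype)  # torch.Size([182, 182, 3]) torch.bfloat16
```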

Versions

Collecting environment information...
PyTorch version: 2.3.0.dev20240309
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.2.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.24.4
Libc version: N/A

Python version: 3.10.13 (main, Nov  9 2023, 13:59:31) [Clang 15.0.0 (clang-1500.0.40.1)] (64-bit runtime)
Python platform: macOS-14.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.0.dev20240309
[pip3] torchaudio==2.2.0.dev20240301
[pip3] torchvision==0.18.0.dev20240309
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr

albanD added the triaged, module: bfloat16 and module: mps labels Mar 11, 2024
malfet self-assigned this Mar 11, 2024

malfet (Contributor) commented Mar 11, 2024

OK, I know the problem: MPS supports BFloat16, but Metal does not. Because MPS matmul is buggy for large matrices, there is a naive Metal matmul implementation, and it does not yet support bf16.
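The 181/182 boundary in the repro is consistent with a size cutoff of 2**15 = 32768 rows. As a worked check, assuming `torch.matmul` folds the leading (H, W) dimensions of the (H, W, 4) tensor into H*W rows before the 2-D product, and treating 32768 as a hypothetical threshold (the thread does not state the exact condition that selects the naive Metal kernel):

```python
# Hypothesis, not confirmed in this thread: the naive Metal kernel is
# chosen once a matmul dimension exceeds 2**15. A (H, W, 4) @ (4, 3)
# product is assumed to be folded into (H*W, 4) @ (4, 3).
HYPOTHETICAL_LIMIT = 2 ** 15  # assumed threshold, illustration only


def folded_rows(shape: tuple[int, ...]) -> int:
    """Row count after folding every dimension except the last."""
    rows = 1
    for dim in shape[:-1]:
        rows *= dim
    return rows


for side in (181, 182):
    rows = folded_rows((side, side, 4))
    print(side, rows, rows > HYPOTHETICAL_LIMIT)
# 181 32761 False
# 182 33124 True
```

181 × 181 = 32,761 stays just under 32,768, while 182 × 182 = 33,124 crosses it, which would explain why a one-pixel change flips the behaviour.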

malfet added commits that referenced this issue Mar 12, 2024 (twice) and Mar 13, 2024:
Will only work on MacOS14 or newer, so compile the shader with `MTLLanguageVersion_3_1` when appropriate

Fixes #121583
pytorchbot pushed a commit that referenced this issue Apr 3, 2024
Will only work on MacOS14 or newer, so compile the shader with `MTLLanguageVersion_3_1` when appropriate

Fixes #121583
Pull Request resolved: #121731
Approved by: https://github.com/albanD

(cherry picked from commit 5498804)
atalman pushed a commit that referenced this issue Apr 4, 2024: the same change as above (Pull Request #121731, approved by albanD, cherry picked from commit 5498804), with the added trailer

Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
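The commit message describes a version gate: request Metal Shading Language 3.1 only where it exists (macOS 14 or newer). The actual change lives in PyTorch's Objective-C++ MPS backend and is not reproduced here; the following Python sketch only illustrates the decision, and `parse_major` and `shader_language_version` are hypothetical names. Returning None stands for keeping the compiler default, in which case no bf16 kernel would be built.

```python
import platform
from typing import Optional


def parse_major(version: str) -> int:
    """Major component of a version string such as '14.2.1'; 0 if empty."""
    return int(version.split(".")[0]) if version else 0


def shader_language_version(macos_version: str) -> Optional[str]:
    """Illustrative gate (hypothetical): ask for Metal 3.1 only on macOS 14+.

    None means "keep the compiler default", i.e. no bf16 kernel.
    """
    if parse_major(macos_version) >= 14:
        return "MTLLanguageVersion_3_1"
    return None


# platform.mac_ver()[0] is '' on non-macOS systems, which maps to None.
print(shader_language_version(platform.mac_ver()[0]))
print(shader_language_version("14.2.1"))  # MTLLanguageVersion_3_1
print(shader_language_version("13.6"))    # None
```

Per the commit message, the practical consequence is that on macOS 13 and earlier the fix does not apply, so large bf16 matmuls on MPS should not be expected to work there even with 2.3.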
atalman added this to the 2.3.0 milestone Apr 10, 2024

atalman (Contributor) commented Apr 11, 2024

Validated with 2.3 (the script now runs to completion; "fails" is only the label printed by the repro):

-------181,181,4--works----------
-------182,182,4--fails--------
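To re-check the fix on a particular machine, a slightly stronger test than the repro compares the bf16 result against a float32 reference instead of only checking that no exception is raised. This is a sketch under two assumptions of its own: the inputs are small integers in [-4, 4], so every product and partial sum (at most 64 in magnitude) is exactly representable in bfloat16 and the comparison can be exact; and a RuntimeError (such as the one in this issue) is recorded as a failure rather than aborting the loop. It falls back to CPU when MPS is unavailable.

```python
import torch

# Prefer MPS; on other machines this only exercises the CPU path.
device = "mps" if torch.backends.mps.is_available() else "cpu"

torch.manual_seed(0)
# Small integers keep every value exact in bfloat16, so equality is valid.
factors = torch.randint(-4, 5, (4, 3)).float()

results = {}
for side in (181, 182):
    samples = torch.randint(-4, 5, (side, side, 4)).float()
    expected = samples @ factors  # float32 reference on CPU
    try:
        got = samples.to(device, torch.bfloat16) @ factors.to(device, torch.bfloat16)
        results[side] = torch.equal(got.float().cpu(), expected)
    except RuntimeError as err:  # e.g. "Undefined type BFloat16" on affected builds
        print(side, "raised:", err)
        results[side] = False
    print(side, results[side])
```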
