
Some operations are not implemented when using the MPS backend #77754

Closed
thipokKub opened this issue May 18, 2022 · 19 comments
Labels
feature A request for a proper, new feature. module: mps Related to Apple Metal Performance Shaders framework triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@thipokKub

🐛 Describe the bug

Recently, PyTorch added support for the Metal backend (see #47702 (comment)), but it seems like some operations are missing. For example:

NotImplementedError: Could not run 'aten::bitwise_xor.Tensor_out' with arguments from the 'MPS' backend

and

NotImplementedError: Could not run 'aten::_index_put_impl_' with arguments from the 'MPS' backend

To reproduce

import torch
import numpy as np
from torchvision.models.resnet import ResNet, BasicBlock
from pytorch_metric_learning import miners, losses
from pytorch_metric_learning.distances import CosineSimilarity

device = torch.device("mps")

model = ResNet(BasicBlock, [1, 1, 1, 1], num_classes=64).to(device)
distance = CosineSimilarity()

miner = miners.MultiSimilarityMiner()
loss_func = losses.SoftTripleLoss(
    num_classes=10,
    embedding_size=64,
    distance=distance,
).to(device)

X = torch.rand(16, 3, 64, 64).to(device)
y = torch.from_numpy(np.random.choice(np.arange(10), size=16)).to(device)
embeddings = model(X)

# NotImplementedError: Could not run 'aten::bitwise_xor.Tensor_out' with arguments from the 'MPS' backend
hard_pairs = miner(embeddings, y)
loss = loss_func(embeddings, y, hard_pairs)

# NotImplementedError: Could not run 'aten::_index_put_impl_' with arguments from the 'MPS' backend
loss = loss_func(embeddings, y)
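
For reference, a minimal sketch that should trigger the same two ATen ops without the third-party libraries (my assumption of the dispatch path, based on the tracebacks):

import torch

device = torch.device("mps")

# XOR on boolean tensors dispatches to aten::bitwise_xor.Tensor_out
a = torch.tensor([True, False, True], device=device)
b = torch.tensor([False, False, True], device=device)
c = a ^ b

# in-place indexed assignment dispatches to aten::_index_put_impl_
t = torch.zeros(4, device=device)
t[torch.tensor([0, 2], device=device)] = 1.0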

Versions

PyTorch version: 1.12.0.dev20220518
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.3 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:25:14) [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-12.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] pytorch-lightning==1.6.3
[pip3] pytorch-metric-learning==1.3.0
[pip3] torch==1.12.0.dev20220518
[pip3] torchaudio==0.11.0
[pip3] torchinfo==1.6.6
[pip3] torchmetrics==0.8.2
[pip3] torchvision==0.12.0
[conda] numpy 1.21.6 pypi_0 pypi
[conda] pytorch-lightning 1.6.3 pypi_0 pypi
[conda] pytorch-metric-learning 1.3.0 pypi_0 pypi
[conda] torch 1.12.0.dev20220518 pypi_0 pypi
[conda] torchaudio 0.11.0 pypi_0 pypi
[conda] torchinfo 1.6.6 pypi_0 pypi
[conda] torchmetrics 0.8.2 pypi_0 pypi
[conda] torchvision 0.12.0 pypi_0 pypi

@albanD albanD added feature A request for a proper, new feature. triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: mps Related to Apple Metal Performance Shaders framework labels May 18, 2022
@singularity-s0

singularity-s0 commented May 18, 2022

For me, even executing sample code like

mps_device = torch.device("mps")
z = torch.ones(5, device=mps_device)

results in

NotImplementedError: Could not run 'aten::empty.memory_format' with arguments from the 'MPS' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::empty.memory_format' is only available for these backends: [Dense, Conjugate, Negative, VmapMode, FuncTorchGradWrapper, MPS, UNKNOWN_TENSOR_TYPE_ID, QuantizedXPU, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseCPU, SparseCUDA, SparseHIP, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseXPU, UNKNOWN_TENSOR_TYPE_ID, SparseVE, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, NestedTensorCUDA, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID].

FYI, I'm using an Intel Mac with an AMD graphics card.
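
If it helps triage, the torch.backends.mps flags added in the 1.12 nightlies distinguish a binary built without MPS from a machine that can't use it:

import torch

print(torch.backends.mps.is_built())      # was this binary compiled with MPS support?
print(torch.backends.mps.is_available())  # can this machine actually use MPS?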

@kulinseth
Collaborator

For me, even executing sample code like

mps_device = torch.device("mps")
z = torch.ones(5, device=mps_device)

results in

NotImplementedError: Could not run 'aten::empty.memory_format' with arguments from the 'MPS' backend. ...

FYI, I'm using an Intel Mac with an AMD graphics card.

I am curious: are you building PyTorch yourself, or is it based on the nightly binaries? My understanding was that the runners currently don't have MPS support enabled. Also, what's the OS version?

@gautierdag

Recently, PyTorch added support for the Metal backend (see #47702 (comment)), but it seems like some operations are missing. For example:

NotImplementedError: Could not run 'aten::bitwise_xor.Tensor_out' with arguments from the 'MPS' backend

I'm running into a similar issue trying to run transformers code, where the cumsum operation does not yet exist:

NotImplementedError: Could not run 'aten::cumsum.out' with arguments from the 'MPS' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::cumsum.out' is only available for these backends: [Dense, Conjugate, UNKNOWN_TENSOR_TYPE_ID, QuantizedXPU, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseCPU, SparseCUDA, SparseHIP, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseXPU, UNKNOWN_TENSOR_TYPE_ID, SparseVE, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, NestedTensorCUDA, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID].
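
A minimal repro of the same failure without transformers (assuming it hits the same kernel):

import torch

x = torch.arange(6, dtype=torch.float32, device="mps")
print(torch.cumsum(x, dim=0))  # raises NotImplementedError on builds lacking the MPS cumsum kernel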

For me, even executing sample code like
...
FYI, I'm using an Intel Mac with an AMD graphics card.

Check that you are using the right Python version (built for ARM and not x86): #77748 (comment)

@albanD
Collaborator

albanD commented May 18, 2022

@singularity-s0 I am afraid the x86+AMD GPU builds are not fully finalized right now (we plan to get them ready as soon as possible). You can build from source to get an MPS-enabled build, or wait for #77662 to land, which will enable MPS in the nightly Intel build.

@albanD
Collaborator

albanD commented May 18, 2022

@thipokKub thanks for the report, and for sharing the exact model you're working with!

I created a tracking issue for this, #77764, so that we have a centralized place to track who is working on what.

@singularity-s0

singularity-s0 commented May 18, 2022

Oh... I didn't realize MPS was (currently) M1-only. The blog post mentioned "GPU-accelerated PyTorch training on Mac", so I tried the nightly version on my Intel Mac (macOS 12.4). Thanks for the clarification.

@aidanalphafund

Another unsupported operation:

Could not run 'aten::multinomial' with arguments from the 'MPS' backend

Reproduction:

import torch

dist = torch.distributions.Categorical(torch.tensor([0.5, 0.5]).to('mps'))
print(dist.sample())
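
One possible workaround until the op lands: keep the distribution on the CPU and move only the sampled result to MPS. A sketch, not a fix:

import torch

dist = torch.distributions.Categorical(torch.tensor([0.5, 0.5]))  # stays on CPU
sample = dist.sample().to("mps")  # only the result is moved to MPS
print(sample)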

@johnnynunez

johnnynunez commented May 20, 2022

Oh... I didn't realize MPS was (currently) M1-only. The blog post mentioned "GPU-accelerated PyTorch training on Mac", so I tried the nightly version on my Intel Mac (macOS 12.4). Thanks for the clarification.

For Intel, it's better to use https://github.com/oneapi-src/oneDNN
https://gist.github.com/mingfeima/363a9ab850be54d5837f9cc542ad2b38

@kulinseth
Collaborator

Thanks @thipokKub for reporting the ops. These ops are captured as part of #77764.

Closing this issue; please re-open or comment on the linked issue for any other ops.

@kulinseth
Collaborator

Oh... I didn't realize MPS was (currently) M1-only. The blog post mentioned "GPU-accelerated PyTorch training on Mac", so I tried the nightly version on my Intel Mac (macOS 12.4). Thanks for the clarification.

For Intel, it's better to use https://github.com/oneapi-src/oneDNN https://gist.github.com/mingfeima/363a9ab850be54d5837f9cc542ad2b38

#80760

@matrix2vector

NotImplementedError: The operator 'aten::cumsum.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on #77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

@Joheun-Kang

Tried to fine-tune a transformer on an M2, but got

NotImplementedError: The operator 'aten::cumsum.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on #77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS

Any solution yet?

@nigelparsad

Running on an M1 Max:

NotImplementedError: The operator 'aten::upsample_linear1d.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on #77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

@kulinseth
Collaborator

Tried to fine-tune a transformer on an M2, but got

NotImplementedError: The operator 'aten::cumsum.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on #77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS

Any solution yet?

cumsum has been implemented in MPS and is in v2.0. Which PyTorch version are you using? Also, can you try the nightlies or the v2.0 release?
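
For anyone checking their own install, a quick sanity test:

import torch

print(torch.__version__)
if torch.backends.mps.is_available():
    x = torch.arange(5, dtype=torch.float32, device="mps")
    print(torch.cumsum(x, dim=0))  # succeeds where the MPS cumsum kernel exists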

@kulinseth
Collaborator

Thanks @nigelparsad, we will look into adding this op. Can you please provide more information about the use case or application you are targeting?

@nigelparsad

nigelparsad commented Aug 5, 2023

@kulinseth Thank you for looking into this. I have run into this issue using NeuralForecast, specifically their NHITS model:
https://nixtla.github.io/neuralforecast/models.nhits.html#nhits

At that link, the NHITS model's last parameter is **trainer_kwargs, the keyword trainer arguments inherited from PyTorch Lightning's trainer.

I pass Lightning's accelerator='mps' argument here, resulting in the error listed above. Please let me know if you need any more information about this particular use case.

@karppmik

karppmik commented Mar 2, 2024

Getting this error with torch 2.2.1 and torchvision 0.17.1:

NotImplementedError: The operator 'torchvision::nms' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on #77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

@yaniv92648

yaniv92648 commented Mar 17, 2024

Getting the following error

NotImplementedError: The operator 'aten::nanmedian.dim_values' is not currently implemented for the MPS device. 
If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. 
As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. 
WARNING: this will be slower than running natively on MPS.

when running

model = NeuralForecast(models=[AutoLSTM(h=30, loss=MAPE(), num_samples=1)], freq='D')
model.fit(df=train_df)
predictions = model.predict()

train_df is configured according to Nixtla's API.
Running on a MacBook Pro M2 Max.
Setting os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" in the code still gives the error.
Exporting PYTORCH_ENABLE_MPS_FALLBACK=1 in the environment still gives the error.
Does anyone have an idea what to do here?
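
My understanding (an assumption about when the variable is read, not something I have confirmed in the source) is that PYTORCH_ENABLE_MPS_FALLBACK must be set before torch is first imported; setting it later in the same process has no effect. A sketch:

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # must run before `import torch`

import torch  # the variable is read when torch initializes

Equivalently, export it in the shell before launching Python. If a notebook kernel or another module imports torch before the assignment runs, the variable is ignored.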

@Smendowski

@yaniv92648, I am facing the same issue with AutoTCN on an M1 Pro.
