🚀 The feature, motivation and pitch
I have attempted to run Llama 2 on M-series (M1/M2) Mac GPUs here: https://github.com/Samyak2/llama-mps
Current status
The model loads correctly but inference fails because:
- The ComplexFloat dtype is not supported in MPS yet (closest existing issue I found: FFT operators are not supported on MPS device #78044)
- The aten::view_as_complex operator is not supported in MPS yet (General MPS op coverage tracking issue #77764)
- The aten::polar.out operator is not supported in MPS yet. This can be worked around by setting PYTORCH_ENABLE_MPS_FALLBACK=1, which runs the operator on the CPU instead. For full performance, this operator would need to be supported too.
There may be more operators and dtypes that need to be supported. I have not dug further, since inference crashes as soon as ComplexFloat is hit. A minimal repro of the three failures above is sketched below.
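For reference, here is a small repro sketch of my own (not taken from the llama-mps fork) showing the three failures in isolation. It assumes an Apple-silicon Mac and a PyTorch build with the MPS backend from around the time this issue was filed; newer builds may behave differently.

```python
import torch

# Minimal repro sketch of the unsupported-op failures listed above.
assert torch.backends.mps.is_available()
dev = torch.device("mps")

# 1. ComplexFloat is not a supported MPS dtype: creating (or converting)
#    a complex64 tensor on the MPS device fails at the time of writing.
try:
    torch.ones(4, dtype=torch.complex64, device=dev)
except Exception as e:
    print("complex64 on mps:", e)

# 2. aten::view_as_complex has no MPS kernel; being a view op it also cannot
#    use the CPU fallback, so it warns and the result is unusable on MPS.
try:
    torch.view_as_complex(torch.randn(4, 2, device=dev))
except Exception as e:
    print("view_as_complex on mps:", e)

# 3. aten::polar.out is unimplemented on MPS. Launching the process with
#    PYTORCH_ENABLE_MPS_FALLBACK=1 in the environment makes it run on the
#    CPU instead (slower, but it works).
try:
    angles = torch.randn(4, device=dev)
    torch.polar(torch.ones_like(angles), angles)
except Exception as e:
    print("polar on mps:", e)
```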
Alternatives
There have been forks of Llama that make it work on the CPU instead, for example https://github.com/b0kch01/llama-cpu.
These leave a lot of performance on the table, though.
Additional context
Failure logs for context (from https://github.com/Samyak2/llama-mps):
<redacted>/llama/llama/model.py:55: UserWarning: The operator 'aten::polar.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
Loaded in 11.68 seconds
<redacted>/llama/llama/model.py:72: UserWarning: 0The operator aten::view_as_complex appears to be a view operator, but it has no implementation for the backend "mps:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/CPUFallback.cpp:181.)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
<redacted>/llama/llama/model.py:73: UserWarning: 0The operator aten::view_as_complex appears to be a view operator, but it has no implementation for the backend "mps:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/CPUFallback.cpp:181.)
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
libc++abi: terminating due to uncaught exception of type c10::TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.
Exception raised from getMPSScalarType at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/OperationUtils.mm:91 (most recent call first):
frame #0: at::native::mps::getMPSScalarType(c10::ScalarType) + 180 (0x116dc5954 in libtorch_cpu.dylib)
frame #1: invocation function for block in at::native::mps::binaryOpTensor(at::Tensor const&, at::Tensor const&, c10::Scalar const&, at::Tensor const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, MPSGraphTensor* (at::native::mps::BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*) block_pointer) + 108 (0x116de0814 in libtorch_cpu.dylib)
frame #2: invocation function for block in at::native::mps::MPSGraphCache::CreateCachedGraph(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, at::native::mps::MPSCachedGraph* () block_pointer) + 216 (0x116ddb8d4 in libtorch_cpu.dylib)
frame #3: _dispatch_client_callout + 20 (0x185114400 in libdispatch.dylib)
frame #4: _dispatch_lane_barrier_sync_invoke_and_complete + 56 (0x18512397c in libdispatch.dylib)
frame #5: at::native::mps::MPSGraphCache::CreateCachedGraph(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, at::native::mps::MPSCachedGraph* () block_pointer) + 160 (0x116dc99e0 in libtorch_cpu.dylib)
frame #6: at::native::mps::binaryOpTensor(at::Tensor const&, at::Tensor const&, c10::Scalar const&, at::Tensor const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, MPSGraphTensor* (at::native::mps::BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*) block_pointer) + 2352 (0x116ddf898 in libtorch_cpu.dylib)
frame #7: at::native::structured_mul_out_mps::impl(at::Tensor const&, at::Tensor const&, at::Tensor const&) + 128 (0x116de33f0 in libtorch_cpu.dylib)
frame #8: at::(anonymous namespace)::wrapper_MPS_mul_Tensor(at::Tensor const&, at::Tensor const&) + 140 (0x11457fea8 in libtorch_cpu.dylib)
frame #9: at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&) + 284 (0x1133bd898 in libtorch_cpu.dylib)
frame #10: torch::autograd::THPVariable_mul(_object*, _object*, _object*) + 396 (0x10726c2dc in libtorch_python.dylib)
frame #11: _object* torch::autograd::TypeError_to_NotImplemented_<&torch::autograd::THPVariable_mul(_object*, _object*, _object*)>(_object*, _object*, _object*) + 12 (0x1071c8330 in libtorch_python.dylib)
<omitting python frames>
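The failing lines in model.py belong to Llama's rotary position embeddings, which build a complex64 tensor via torch.polar and torch.view_as_complex. Until those land on MPS, one possible workaround is to express the rotation with real-valued cos/sin tables only. The sketch below is my own and is not what the Llama code does; the function names and the assumed input shape (batch, seq_len, head_dim) are illustrative.

```python
import torch

def precompute_cos_sin(head_dim: int, seq_len: int, theta: float = 10000.0):
    # Same frequency schedule Llama uses, but kept as real tensors.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # (seq_len, head_dim/2)
    return angles.cos(), angles.sin()

def apply_rotary_emb_real(x, cos, sin):
    # Treat adjacent pairs along the last dim as (real, imag) components.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # Complex multiply (x1 + i*x2) * (cos + i*sin), written with real ops only.
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    # Interleave the two halves back into the original pair layout.
    return torch.stack((out1, out2), dim=-1).flatten(-2)
```

Because everything stays float32, the tensors can live entirely on the MPS device, at the cost of diverging from the upstream model code.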
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev