Running Llama 2 on Apple Silicon GPUs - missing MPS types and operators #105665

@Samyak2

Description

🚀 The feature, motivation and pitch

I have attempted to run Llama 2 on M-series (M1/M2) Mac GPUs here: https://github.com/Samyak2/llama-mps

Current status

The model loads correctly, but inference fails because:

- `aten::polar.out` is not implemented on the MPS backend (it falls back to the CPU with a warning)
- `aten::view_as_complex` has no MPS implementation, and since it is a view operator it cannot fall back to the CPU
- the `ComplexFloat` dtype is not supported by the MPS backend, which terminates the run with a `c10::TypeError`

There may be more operators and types that need to be supported. I have not dug further, since the run crashes as soon as `ComplexFloat` is hit.
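For reference, a minimal CPU-side sketch (not taken from the issue) of the three operations involved — this is the complex-number path that Llama's rotary embeddings exercise, and each step below corresponds to one of the failures above when run on an MPS device:

```python
# Hypothetical minimal reproduction of the complex-tensor path that fails
# on MPS. On CPU all three steps succeed, which shows the intended
# semantics; on an MPS device they hit the missing aten::polar.out,
# aten::view_as_complex, and ComplexFloat support respectively.
import torch

freqs = torch.arange(4, dtype=torch.float32)
# aten::polar.out -- on MPS this falls back to the CPU with a warning
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

x = torch.randn(2, 4)
# aten::view_as_complex -- a view op, so it cannot fall back to CPU at all
xc = torch.view_as_complex(x.reshape(2, 2, 2))
assert xc.dtype == torch.complex64

# Multiplying complex tensors (as the rotary embedding does) then needs
# the ComplexFloat dtype on the device, which MPS lacks -> c10::TypeError.
out = xc * freqs_cis[:2].unsqueeze(0)
```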

Alternatives

There have been forks of Llama that make it work on the CPU instead, for example: https://github.com/b0kch01/llama-cpu
These leave a lot of performance on the table, though.
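Another avenue (used by some ports that avoid complex dtypes entirely, though not confirmed by this issue) is to rewrite the rotary embedding in real arithmetic: multiplying by a unit complex number is mathematically the same as a 2D rotation of the (real, imag) pair, so no `ComplexFloat` tensors are needed. A minimal sketch of that identity in plain Python:

```python
import cmath
import math

def rotate_complex(a: float, b: float, theta: float) -> tuple:
    """The complex formulation, as in Llama's model.py:
    treat (a, b) as a + bi and multiply by e^{i*theta}."""
    z = complex(a, b) * cmath.exp(1j * theta)
    return z.real, z.imag

def rotate_real(a: float, b: float, theta: float) -> tuple:
    """The same step using only real arithmetic -- the form an MPS
    port could use, since it needs no complex dtype."""
    c, s = math.cos(theta), math.sin(theta)
    return a * c - b * s, a * s + b * c

# The two formulations agree for any angle.
for theta in (0.0, 0.5, 2.1):
    x1, y1 = rotate_complex(1.0, 2.0, theta)
    x2, y2 = rotate_real(1.0, 2.0, theta)
    assert math.isclose(x1, x2, abs_tol=1e-12)
    assert math.isclose(y1, y2, abs_tol=1e-12)
```

Applying this elementwise over the (…, -1, 2) pairs that `view_as_complex` would have produced sidesteps all three unsupported operations at once.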

Additional context

Failure logs for context (from https://github.com/Samyak2/llama-mps):

<redacted>/llama/llama/model.py:55: UserWarning: The operator 'aten::polar.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
  freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
Loaded in 11.68 seconds
<redacted>/llama/llama/model.py:72: UserWarning: 0The operator aten::view_as_complex appears to be a view operator, but it has no implementation for the backend "mps:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/CPUFallback.cpp:181.)
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
<redacted>/llama/llama/model.py:73: UserWarning: 0The operator aten::view_as_complex appears to be a view operator, but it has no implementation for the backend "mps:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/CPUFallback.cpp:181.)
  xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
libc++abi: terminating due to uncaught exception of type c10::TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.
Exception raised from getMPSScalarType at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/OperationUtils.mm:91 (most recent call first):
frame #0: at::native::mps::getMPSScalarType(c10::ScalarType) + 180 (0x116dc5954 in libtorch_cpu.dylib)
frame #1: invocation function for block in at::native::mps::binaryOpTensor(at::Tensor const&, at::Tensor const&, c10::Scalar const&, at::Tensor const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, MPSGraphTensor* (at::native::mps::BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*) block_pointer) + 108 (0x116de0814 in libtorch_cpu.dylib)
frame #2: invocation function for block in at::native::mps::MPSGraphCache::CreateCachedGraph(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, at::native::mps::MPSCachedGraph* () block_pointer) + 216 (0x116ddb8d4 in libtorch_cpu.dylib)
frame #3: _dispatch_client_callout + 20 (0x185114400 in libdispatch.dylib)
frame #4: _dispatch_lane_barrier_sync_invoke_and_complete + 56 (0x18512397c in libdispatch.dylib)
frame #5: at::native::mps::MPSGraphCache::CreateCachedGraph(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, at::native::mps::MPSCachedGraph* () block_pointer) + 160 (0x116dc99e0 in libtorch_cpu.dylib)
frame #6: at::native::mps::binaryOpTensor(at::Tensor const&, at::Tensor const&, c10::Scalar const&, at::Tensor const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, MPSGraphTensor* (at::native::mps::BinaryOpCachedGraph*, MPSGraphTensor*, MPSGraphTensor*) block_pointer) + 2352 (0x116ddf898 in libtorch_cpu.dylib)
frame #7: at::native::structured_mul_out_mps::impl(at::Tensor const&, at::Tensor const&, at::Tensor const&) + 128 (0x116de33f0 in libtorch_cpu.dylib)
frame #8: at::(anonymous namespace)::wrapper_MPS_mul_Tensor(at::Tensor const&, at::Tensor const&) + 140 (0x11457fea8 in libtorch_cpu.dylib)
frame #9: at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&) + 284 (0x1133bd898 in libtorch_cpu.dylib)
frame #10: torch::autograd::THPVariable_mul(_object*, _object*, _object*) + 396 (0x10726c2dc in libtorch_python.dylib)
frame #11: _object* torch::autograd::TypeError_to_NotImplemented_<&torch::autograd::THPVariable_mul(_object*, _object*, _object*)>(_object*, _object*, _object*) + 12 (0x1071c8330 in libtorch_python.dylib)
<omitting python frames>

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Metadata

Assignees: no one assigned

Labels: module: mps (Related to Apple Metal Performance Shaders framework), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
