Skip to content

Commit

Permalink
[MPS] Fix naive matmul for BFloat16
Browse files Browse the repository at this point in the history
Will only work on MacOS14 or newer, so compile the shader with `MTLLanguageVersion_3_1` when appropriate

Fixes #121583
  • Loading branch information
malfet committed Mar 12, 2024
1 parent edf22f3 commit 520986e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.mm
Expand Up @@ -210,6 +210,9 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
return "float";
case ScalarType::Half:
return "half";
case ScalarType::BFloat16:
checkSupportsBFloat16();
return "bfloat";
case ScalarType::Int:
return "int";
case ScalarType::Long:
Expand Down
10 changes: 8 additions & 2 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
Expand Up @@ -4,6 +4,8 @@
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
// For MTLLanguageVersion_3_1
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
Expand All @@ -29,7 +31,7 @@
using namespace metal;
template<typename T>
T dot_product(constant T *v1, constant T* v2, ulong2 strides, uint32_t size) {
T rc = 0.0;
T rc = T(0.0);
for (uint32_t i = 0; i < size; ++i) {
rc += v1[i * strides.x] * v2[i * strides.y];
}
Expand Down Expand Up @@ -69,6 +71,9 @@ kernel void naive_matmul(
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_NAIVE_MM(bfloat);
#endif
)MATMUL_METAL";

id<MTLLibrary> compileLinalgOpLibrary(id<MTLDevice> device) {
Expand All @@ -79,7 +84,8 @@ kernel void naive_matmul(

NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion2_3];
linalgLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_LINALG encoding:NSASCIIStringEncoding]
options:options
error:&error];
Expand Down
8 changes: 8 additions & 0 deletions test/test_mps.py
Expand Up @@ -6931,6 +6931,14 @@ def compare_mm(m, n, k):
# see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
compare_mm(32769, 1, 1025)

if product_version >= 14.0:
# Test bfloat16 mm
x = torch.rand(182, 182, 4, dtype=torch.bfloat16, device='mps')
y = torch.rand(4, 3, dtype=torch.bfloat16, device='mps')
z = torch.matmul(x, y).cpu()
z_cpu = torch.matmul(x.cpu(), y.cpu())
self.assertEqual(z, z_cpu)

# Test flip
def test_flip(self):
def helper(shape, dims):
Expand Down

0 comments on commit 520986e

Please sign in to comment.