Skip to content

Commit

Permalink
[MPS][BE] Introduce MetalShaderLibary class (#125550)
Browse files Browse the repository at this point in the history
That factors out a repeated pattern of creating a library/fetching a func from source

Typical use case:
```cpp
static MetalShaderLibrary lib(SHADER_SOURCE);
...

id<MTLComputePipelineState> cplState = lib.getPipelineStateForFunc("kernel_name");
```
- Make it possible to use with templated sources
- Add `scalarToMetalTypeString(const Tensor&)` variant to avoid repeated `scalarToMetalTypeString(t.scalar_type())` calls in the code

That is, it makes no functional changes, but it reduces the MPS codebase size by 365 lines.
Pull Request resolved: #125550
Approved by: https://github.com/kulinseth
  • Loading branch information
malfet authored and pytorchmergebot committed May 6, 2024
1 parent 7bf6ed0 commit e30e6d3
Show file tree
Hide file tree
Showing 14 changed files with 163 additions and 528 deletions.
28 changes: 28 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include <initializer_list>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
Expand Down Expand Up @@ -71,6 +72,9 @@ static inline std::string getMPSTypeString(const Tensor& t, bool short_name = fa
return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
// Convenience overload of scalarToMetalTypeString(const c10::ScalarType&):
// lets callers pass a Tensor directly instead of spelling out
// scalarToMetalTypeString(t.scalar_type()) at every call site.
static inline std::string scalarToMetalTypeString(const Tensor& t) {
return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
Expand Down Expand Up @@ -329,6 +333,30 @@ inline bool is_dense_in_storage(const at::Tensor& t) {
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}


// Caches a Metal library compiled from a shader source string, together with
// per-function compute pipeline state objects (PSOs).
//
// The source may be a plain Metal program (nparams_ == 0) or an fmt-style
// template with nparams_ placeholders ({0}, {1}, ...); in the templated case
// each distinct parameter combination gets its own compiled library, cached
// in libMap. Typically declared `static` at file scope so each shader source
// is compiled at most once per process.
class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {}
// Non-copyable: the instance owns caches of Metal objects keyed by identity.
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
// Returns (compiling and caching on first use) the PSO for function `fname`
// from the non-templated library. Requires nparams == 0.
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname);
}
// Same, but for a templated source instantiated with `params`;
// params.size() must equal nparams.
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname);
}
private:
id<MTLComputePipelineState> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
id<MTLLibrary> getLibrary();
id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);

id<MTLLibrary> compileLibrary(const std::string& src);
std::string shaderSource;
unsigned nparams;
id<MTLLibrary> library = nil; // lazily compiled non-templated library
std::unordered_map<std::string, id<MTLLibrary>> libMap; // templated: param key -> library
std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap; // "libptr:fname" -> PSO
};

static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
[encoder setBuffer:getMTLBufferStorage(t)
offset:t.storage_offset() * t.element_size()
Expand Down
71 changes: 71 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
Expand Down Expand Up @@ -616,4 +617,74 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {}
return kernelDataOffsets;
}

// Returns the library for a non-templated shader source, compiling it lazily
// on first use and caching the result in the `library` member.
id<MTLLibrary> MetalShaderLibrary::getLibrary() {
  if (library) {
    return library;
  }
  // A templated source (nparams > 0) must be requested via getLibrary(params).
  TORCH_INTERNAL_ASSERT(nparams == 0);
  library = compileLibrary(shaderSource);
  return library;
}

// Returns the library for a templated shader source instantiated with
// `params`, compiling it on first use and caching it in libMap keyed by the
// concatenated parameter list. Supports 1-3 substitution parameters.
id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::string>& params) {
  TORCH_INTERNAL_ASSERT(nparams == params.size());
  std::string key = "";
  // Iterate by const reference to avoid copying each parameter string.
  for (const auto& p : params) {
    key += ":" + p;
  }
  // Use find() so a cache miss does not insert a spurious nil entry.
  const auto found = libMap.find(key);
  if (found != libMap.end()) {
    return found->second;
  }
  id<MTLLibrary> lib = nil;
  auto it = params.begin();
  switch (nparams) {
    case 1:
      lib = compileLibrary(fmt::format(shaderSource, *it));
      break;
    case 2: {
      auto& first = *it++;
      auto& second = *it;
      lib = compileLibrary(fmt::format(shaderSource, first, second));
      break;
    }
    case 3: {
      auto& first = *it++;
      auto& second = *it++;
      auto& third = *it;
      lib = compileLibrary(fmt::format(shaderSource, first, second, third));
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported number of parameters ", nparams);
  }
  return libMap[key] = lib;
}

// Compiles `src` into a fresh MTLLibrary on the shared MPS device, throwing
// via TORCH_CHECK on compilation failure.
//
// NOTE: the result is returned through a local rather than being assigned to
// the `library` member (as the original did). Writing to `library` here would
// clobber the cached non-templated library whenever a templated variant is
// compiled, making a later getLibrary() call silently return the wrong
// library without tripping its nparams assertion.
id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  // Pick the newer Metal language version when running on macOS 14+.
  [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
                                                                                      : MTLLanguageVersion2_3];
  // NOTE(review): NSASCIIStringEncoding yields nil for non-ASCII shader
  // source; presumably all embedded shaders are ASCII — confirm.
  auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
  auto device = MPSDevice::getInstance()->device();
  id<MTLLibrary> lib = [device newLibraryWithSource:str options:options error:&error];
  TORCH_CHECK(lib, "Failed to create metal library, error: ", [[error description] UTF8String]);
  return lib;
}

// Returns the compute pipeline state for function `fname` in library `lib`,
// creating and caching it on first use. The cache key combines the library
// pointer with the function name, so the same function name in different
// (e.g. differently-templated) libraries gets distinct entries.
id<MTLComputePipelineState> MetalShaderLibrary::getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname) {
  const auto cacheKey = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
  auto& cached = cplMap[cacheKey];
  if (cached) {
    return cached;
  }

  NSError* error = nil;
  auto func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
  TORCH_CHECK(func, "Failed to create function state object for: ", fname);
  cached = [[lib device] newComputePipelineStateWithFunction:func error:&error];
  TORCH_CHECK(cached, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

  return cached;
}

} // namespace at::native::mps
45 changes: 4 additions & 41 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace at::native {
namespace mps {

static const char* METAL_BINARY = R"BINARY_METAL(
static MetalShaderLibrary lib(R"BINARY_METAL(
#include <metal_stdlib>
using namespace metal;
Expand Down Expand Up @@ -252,44 +252,7 @@ kernel void complex_kernel(constant void * real_ [[buffer(0)]],
REGISTER_COMPLEX_OUT_OP(float);
REGISTER_COMPLEX_OUT_OP(half);
)BINARY_METAL";

using namespace mps;

static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> binaryLibrary = nil;
if (binaryLibrary) {
return binaryLibrary;
}

NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BINARY encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
return binaryLibrary;
}

static id<MTLComputePipelineState> binaryPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}

NSError* error = nil;
id<MTLLibrary> binaryLib = compileBinaryOpsLibrary(device);
id<MTLFunction> binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

psoCache[kernel] = pso;
return pso;
}
)BINARY_METAL");

static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
Expand All @@ -306,10 +269,10 @@ static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_nam
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
const std::string kernel = func_name + "_" + scalarToMetalTypeString(input.scalar_type());
const std::string kernel = func_name + "_" + scalarToMetalTypeString(input);
auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);

id<MTLComputePipelineState> binaryPSO = binaryPipelineState(device, kernel);
id<MTLComputePipelineState> binaryPSO = lib.getPipelineStateForFunc(kernel);

// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other});
Expand Down
60 changes: 11 additions & 49 deletions aten/src/ATen/native/mps/operations/BitwiseOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace at::native {
namespace mps {
static const char* BITWISE_OPS_TEMPLATE = R"METAL(
static MetalShaderLibrary lib(R"METAL(
kernel void bitwise_and_tensor(constant uint& length [[buffer(0)]],
device {0} *out [[buffer(1)]],
Expand Down Expand Up @@ -90,7 +90,8 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
}}
out[offset] = ~a[offset];
}}
)METAL";
)METAL",
3);

static const std::string& getMetalType(const c10::ScalarType& t) {
// Mapping from c10::ScalarType to integral type that can be used for bitwise ops
Expand All @@ -117,48 +118,12 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
return getMetalType(s.type());
}

static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
const std::string& t3) {
auto key = t1 + t2 + t3;
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
auto it = libMap.find(key);
if (it != libMap.end()) {
return it->second;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
return rc;
}

static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
const std::string& t3,
template <typename ScalarOrTensor>
static id<MTLComputePipelineState> getCPLState(const Tensor& t1,
const Tensor& t2,
const ScalarOrTensor& t3,
const std::string& fname) {
auto key = t1 + t2 + t3 + fname;
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
auto it = cplMap.find(key);
if (it != cplMap.end()) {
return it->second;
}
NSError* error = nil;
auto library = compileBitwiseOpsLibrary(device, t1, t2, t3);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func != nil, "Can't get function ", fname);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
return lib.getPipelineStateForFunc(fname, {getMetalType(t1), getMetalType(t2), getMetalType(t3)});
}

static void handle_tensor_tensor_binary_op(const Tensor& self,
Expand All @@ -167,8 +132,7 @@ static void handle_tensor_tensor_binary_op(const Tensor& self,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
auto cplState = getCPLState(output, self, other, kernel_name);
uint32_t length = output.numel();
if (length == 0) {
return;
Expand Down Expand Up @@ -198,8 +162,7 @@ static void handle_tensor_scalar_binary_op(const Tensor& self,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
auto cplState = getCPLState(output, self, other, kernel_name);
uint64_t sval = other.to<int64_t>();
uint32_t length = output.numel();
if (length == 0) {
Expand Down Expand Up @@ -296,8 +259,7 @@ static void _bitwise_not_out_mps(const Tensor& self, const Tensor& output_) {
}
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(self), "bitwise_not");
auto cplState = getCPLState(output, self, self, "bitwise_not");
dispatch_sync(stream->queue(), ^() {
getMPSProfiler().beginProfileKernel(cplState, "bitwise_not", {self});

Expand Down
48 changes: 5 additions & 43 deletions aten/src/ATen/native/mps/operations/Bucketization.mm
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace at::native {
namespace mps {

static const char* METAL_BUCKETIZATION = R"BUCKETIZE_METAL(
static MetalShaderLibrary lib(R"BUCKETIZE_METAL(
#include <metal_stdlib>
using namespace metal;
Expand Down Expand Up @@ -194,44 +194,7 @@ kernel void searchsorted(
REGISTER_SEARCHSORTED_OP(long, int);
REGISTER_SEARCHSORTED_OP(long, long);
)BUCKETIZE_METAL";

static id<MTLLibrary> compileBucketizationOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> bucketizationLibrary = nil;
if (bucketizationLibrary) {
return bucketizationLibrary;
}

NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
bucketizationLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BUCKETIZATION
encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(
bucketizationLibrary, "Failed to create metal bucketization library, error: ", [[error description] UTF8String]);
return bucketizationLibrary;
}

static id<MTLComputePipelineState> bucketizationPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}

NSError* error = nil;
id<MTLLibrary> bucketizationLib = compileBucketizationOpsLibrary(device);
id<MTLFunction> bucketizationFunc =
[bucketizationLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(bucketizationFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:bucketizationFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

psoCache[kernel] = pso;
return pso;
}
)BUCKETIZE_METAL");

static void searchsorted_mps_contiguous(Tensor& result,
const Tensor& input,
Expand All @@ -250,15 +213,14 @@ static void searchsorted_mps_contiguous(Tensor& result,
int64_t right_i64 = right;
int64_t is_1d_boundaries = boundaries.dim() == 1;

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();

const std::string kernel = "searchsorted_" + scalarToMetalTypeString(input.scalar_type()) + "_" +
scalarToMetalTypeString(result.scalar_type()) + (sorter.defined() ? "_sorter" : "");
id<MTLComputePipelineState> bucketizationPSO = mps::bucketizationPipelineState(device, kernel);
const std::string kernel = "searchsorted_" + scalarToMetalTypeString(input) + "_" +
scalarToMetalTypeString(result) + (sorter.defined() ? "_sorter" : "");
id<MTLComputePipelineState> bucketizationPSO = lib.getPipelineStateForFunc(kernel);

// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(bucketizationPSO, kernel, {input, boundaries, sorter});
Expand Down

0 comments on commit e30e6d3

Please sign in to comment.