[MPS] Add PSO caching for advanced indexing kernels (#99855)
Use bindless Argument Buffers (unbounded arrays) for the advanced indexing kernels. This allows the PSOs to be cached: we no longer have to query the main Metal function for the argument buffer size, since the buffer is now filled in directly on the CPU.
Pull Request resolved: #99855
Approved by: https://github.com/kulinseth
DenisVieriu97 authored and pytorchmergebot committed Apr 24, 2023
1 parent 09b189e commit dcd686f
Showing 6 changed files with 192 additions and 91 deletions.
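
The mechanics, as a minimal host-side sketch (illustrative only — the real encoding lives in the indexing op implementation, which is part of this commit but not excerpted below; the helper name and the indexBuffers parameter are assumptions): on Metal 3.0+, each IndexAB entry is just the 64-bit GPU address of an index tensor's buffer, so the CPU can fill the argument buffer with plain writes, and the buffer's layout never depends on the specialized function.

#include <Metal/Metal.h>
#include <vector>

// Minimal sketch, assuming Metal 3.0+; helper name and parameters are
// illustrative, not the commit's actual code.
static id<MTLBuffer> makeIndexArgumentBuffer(id<MTLDevice> device,
                                             const std::vector<id<MTLBuffer>>& indexBuffers) {
  // One entry per index tensor, matching the Metal-3 declaration
  //   struct IndexAB { constant int64_t* indexArray; };
  // i.e. a raw 64-bit GPU virtual address written directly on the CPU.
  id<MTLBuffer> argBuffer = [device newBufferWithLength:indexBuffers.size() * sizeof(uint64_t)
                                                options:MTLResourceStorageModeShared];
  uint64_t* entries = static_cast<uint64_t*>(argBuffer.contents);
  for (size_t i = 0; i < indexBuffers.size(); i++) {
    entries[i] = indexBuffers[i].gpuAddress; // bindless: no MTLArgumentEncoder involved
  }
  return argBuffer;
}

Since the GPU dereferences these addresses without explicit bindings, the compute encoder must still mark each index buffer resident, e.g. [encoder useResource:buf usage:MTLResourceUsageRead].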
149 changes: 113 additions & 36 deletions aten/src/ATen/mps/IndexKernels.h
@@ -9,27 +9,42 @@ static const char * indexing_metal_shaders = R"INDEX_METAL(
 using namespace metal;
-constant uint32_t num_indices [[function_constant(0)]];
+#if __METAL_VERSION__ < 300
 struct IndexAB {
     // Allow up to 16 indices
     metal::array<constant void *, 16> indexArray [[ id(0) ]];
 };
+#else
+struct IndexAB {
+    constant int64_t* indexArray;
+};
+#endif
 template<typename T>
 kernel void index_select(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
     constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
     constant void     * indexSizes        [[buffer(1)]],
     constant void     * indexStrides      [[buffer(2)]],
     constant uint3    * offsets           [[buffer(3)]],
     constant void     * inputData         [[buffer(4)]],
     device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
     uint thread_index [[thread_position_in_grid]]) {
     constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
     constant int64_t * index_strides = (constant int64_t *)indexStrides;
     int64_t offset = 0;
     for (uint32_t i = 0; i < num_indices; i++) {
-        int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
         if (index < 0) {
             index += index_sizes[i];
         }
@@ -42,19 +57,30 @@ kernel void index_select(
 template<typename T>
 kernel void index_put(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
     constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
     constant void     * indexSizes        [[buffer(1)]],
     constant void     * indexStrides      [[buffer(2)]],
     constant uint3    * offsets           [[buffer(3)]],
     constant void     * inputData         [[buffer(4)]],
     device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
     uint thread_index [[thread_position_in_grid]]) {
     constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
     constant int64_t * index_strides = (constant int64_t *)indexStrides;
     int64_t offset = 0;
     for (uint32_t i = 0; i < num_indices; i++) {
-        int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
         if (index < 0) {
             index += index_sizes[i];
         }
@@ -65,6 +91,7 @@ kernel void index_put(
     *out = *in;
 }
+#if __METAL_VERSION__ < 300
 #define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE)     \
 template                                                        \
 [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]]          \
@@ -75,7 +102,22 @@ kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
     constant uint3    * offsets           [[buffer(3)]],        \
     constant void     * inputData         [[buffer(4)]],        \
     device   void     * outputData        [[buffer(5)]],        \
+    constant uint32_t & num_indices       [[buffer(6)]],        \
     uint thread_index [[thread_position_in_grid]]);
+#else
+#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE)     \
+template                                                        \
+[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]]          \
+kernel void index_ ## INDEX_OP_TYPE<DTYPE>(                     \
+    constant IndexAB  * indexAB           [[buffer(0)]],        \
+    constant void     * indexSizes        [[buffer(1)]],        \
+    constant void     * indexStrides      [[buffer(2)]],        \
+    constant uint3    * offsets           [[buffer(3)]],        \
+    constant void     * inputData         [[buffer(4)]],        \
+    device   void     * outputData        [[buffer(5)]],        \
+    constant uint32_t & num_indices       [[buffer(6)]],        \
+    uint thread_index [[thread_position_in_grid]]);
+#endif
 #define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)    \
     REGISTER_INDEX_OP(8bit, char, INDEX_OP_TYPE);      \
@@ -92,29 +134,40 @@ kernel void kernel_index_offsets(constant packed_uint3 * strides [[buffe
                                  constant uint         & num_dimensions  [[buffer(3)]],
                                  constant uint         & num_offsets     [[buffer(4)]],
                                  uint thread_index [[thread_position_in_grid]]) {
-    data_offsets[thread_index] = 0;
     uint32_t idx = thread_index;
     for (uint32_t dim = 0; dim < num_dimensions; dim++) {
         uint32_t remainder = idx % iter_shape[dim];
         idx /= iter_shape[dim];
-        for (uint32_t offset = 0; offset < num_offsets; offset++)
-            data_offsets[thread_index][offset] += remainder * strides[dim][offset];
+        data_offsets[thread_index] += remainder * strides[dim];
     }
 }
 template<typename T, typename E>
-kernel void index_put_accumulate_native_dtypes(constant IndexAB & indexAB      [[buffer(0)]],
-                                               constant void    * indexSizes   [[buffer(1)]],
-                                               constant void    * indexStrides [[buffer(2)]],
-                                               constant uint3   * offsets      [[buffer(3)]],
-                                               constant void    * inputData    [[buffer(4)]],
-                                               device   void    * outputData   [[buffer(5)]],
-                                               uint thread_index [[thread_position_in_grid]]) {
+kernel void index_put_accumulate_native_dtypes(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant uint3    * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]) {
     constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
     constant int64_t * index_strides = (constant int64_t *)indexStrides;
     int64_t offset = 0;
     for (uint32_t i = 0; i < num_indices; i++) {
-        int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
         if (index < 0) {
             index += index_sizes[i];
         }
@@ -136,18 +189,29 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a
 }
 template<typename T>
-kernel void atomic_index_put_accumulate(constant IndexAB & indexAB      [[buffer(0)]],
-                                        constant void    * indexSizes   [[buffer(1)]],
-                                        constant void    * indexStrides [[buffer(2)]],
-                                        constant uint3   * offsets      [[buffer(3)]],
-                                        constant void    * inputData    [[buffer(4)]],
-                                        device   void    * outputData   [[buffer(5)]],
-                                        uint thread_index [[thread_position_in_grid]]) {
+kernel void atomic_index_put_accumulate(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant uint3    * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]) {
     constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
     constant int64_t * index_strides = (constant int64_t *)indexStrides;
     int64_t offset = 0;
     for (uint32_t i = 0; i < num_indices; i++) {
-        int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
         if (index < 0) {
             index += index_sizes[i];
         }
@@ -160,22 +224,35 @@ kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[b
 template
 [[host_name("index_put_accumulate_32bit_float")]]
-kernel void atomic_index_put_accumulate<float>(constant IndexAB & indexAB      [[buffer(0)]],
-                                               constant void    * indexSizes   [[buffer(1)]],
-                                               constant void    * indexStrides [[buffer(2)]],
-                                               constant uint3   * offsets      [[buffer(3)]],
-                                               constant void    * inputData    [[buffer(4)]],
-                                               device   void    * outputData   [[buffer(5)]],
-                                               uint thread_index [[thread_position_in_grid]]);
+kernel void atomic_index_put_accumulate<float>(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant uint3    * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]);
 template
 [[host_name("index_put_accumulate_32bit_int")]]
-kernel void index_put_accumulate_native_dtypes<atomic_int, int>(constant IndexAB & indexAB      [[buffer(0)]],
-                                                                constant void    * indexSizes   [[buffer(1)]],
-                                                                constant void    * indexStrides [[buffer(2)]],
-                                                                constant uint3   * offsets      [[buffer(3)]],
-                                                                constant void    * inputData    [[buffer(4)]],
-                                                                device   void    * outputData   [[buffer(5)]],
-                                                                uint thread_index [[thread_position_in_grid]]);
+kernel void index_put_accumulate_native_dtypes<atomic_int, int>(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant uint3    * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]);
 )INDEX_METAL";

 static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
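For contrast, a rough sketch of the pre-Metal-3 path that the __METAL_VERSION__ < 300 branch above still serves (the helper is hypothetical): the argument buffer layout is only known to the compiled function, so the host needs an MTLFunction — previously one specialized via the num_indices function constant — just to obtain an argument encoder. That per-call dependence on the function is what prevented caching the pipeline state.

#include <Metal/Metal.h>
#include <vector>

// Rough sketch of the legacy (< Metal 3.0) encoding path; hypothetical helper.
static id<MTLBuffer> makeIndexArgumentBufferLegacy(id<MTLDevice> device,
                                                   id<MTLFunction> indexFunction,
                                                   const std::vector<id<MTLBuffer>>& indexBuffers) {
  // The encoder - and therefore the argument buffer size - comes from the
  // function itself, which is why the function had to be queried every call.
  id<MTLArgumentEncoder> encoder = [indexFunction newArgumentEncoderWithBufferIndex:0];
  id<MTLBuffer> argBuffer = [device newBufferWithLength:encoder.encodedLength
                                                options:MTLResourceStorageModeShared];
  [encoder setArgumentBuffer:argBuffer offset:0];
  for (size_t i = 0; i < indexBuffers.size(); i++) {
    [encoder setBuffer:indexBuffers[i] offset:0 atIndex:i]; // slots of indexArray [[id(0)]]
  }
  return argBuffer;
}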
11 changes: 6 additions & 5 deletions aten/src/ATen/mps/MPSDevice.h
@@ -12,14 +12,14 @@
 #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
 typedef id<MTLDevice> MTLDevice_t;
+typedef id<MTLLibrary> MTLLibrary_t;
 typedef id<MTLFunction> MTLFunction_t;
 typedef MTLFunctionConstantValues* MTLFunctionConstantValues_t;
 typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
-typedef id<MTLLibrary> MTLLibrary_t;
 #else
-typedef void* MTLDevice;
+typedef void* MTLDevice_t;
+typedef void* MTLLibrary_t;
 typedef void* MTLFunction_t;
 typedef void* MTLFunctionConstantValues_t;
 typedef void* MTLComputePipelineState_t;
-typedef void* MTLLibrary_t;
 #endif

 using namespace std;
@@ -66,7 +66,8 @@ class TORCH_API MPSDevice {
   */
  bool isMacOS13Plus(MacOSVersion version) const;

- MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues);
+ MTLComputePipelineState_t metalIndexingFunction(const std::string &kernel);
+ MTLLibrary_t getMetalIndexingLibrary();

  ~MPSDevice();

36 changes: 24 additions & 12 deletions aten/src/ATen/mps/MPSDevice.mm
@@ -13,10 +13,15 @@
 static std::unique_ptr<MPSDevice> mps_device;
 static c10::once_flag mpsdev_init;

-static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
+static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
   // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
   // host_name attribute needs at least Metal 2.2
   MTLLanguageVersion languageVersion = MTLLanguageVersion2_2;
+#if defined(__MAC_13_0)
+  if (macOS13Plus) {
+    languageVersion = MTLLanguageVersion3_0;
+  }
+#endif

   TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
   return languageVersion;
@@ -27,37 +32,44 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   return mps_device.get();
 }

-id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
+id<MTLLibrary> MPSDevice::getMetalIndexingLibrary() {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
   NSError* error = nil;
   if (!_mtl_indexing_library) {
     MTLCompileOptions* options = [MTLCompileOptions new];
-    [options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
+    [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
     [options setFastMathEnabled:YES];
     _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
                                                                                  encoding:NSASCIIStringEncoding]
                                                       options:options
                                                         error:&error];
     TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
   }
+  return _mtl_indexing_library;
+}

-  id<MTLFunction> indexFunction = nil;
-  if (constantValues) {
-    indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]
-                                                 constantValues:constantValues
-                                                          error:&error] autorelease];
-  } else {
-    indexFunction =
-        [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
+id<MTLComputePipelineState> MPSDevice::metalIndexingFunction(const std::string& kernel) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
+  NSError* error = nil;
+  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+  id<MTLLibrary> indexing_lib = getMetalIndexingLibrary();
+  id<MTLComputePipelineState> state = psoCache[kernel];
+  if (state) {
+    return state;
   }

+  id<MTLFunction> indexFunction =
+      [[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
   TORCH_CHECK(indexFunction,
               "Failed to create specialized function state object: ",
               kernel,
               ", error: ",
               [[error description] UTF8String]);

-  return indexFunction;
+  state = [_mtl_device newComputePipelineStateWithFunction:indexFunction error:&error];
+  TORCH_CHECK(state, error.localizedDescription.UTF8String);
+  psoCache[kernel] = state;
+  return state;
 }

 MPSDevice::~MPSDevice() {
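With the cache in place, call sites reduce to a single lookup, as the next two files show. A hedged dispatch sketch (computeEncoder and numThreads stand in for state owned by the surrounding op, and the kernel's buffer bindings are omitted):

#include <ATen/mps/MPSDevice.h>
#include <algorithm>

static void dispatchIndexOffsets(id<MTLComputeCommandEncoder> computeEncoder, uint32_t numThreads) {
  // The first call compiles the PSO; subsequent calls are served from psoCache.
  id<MTLComputePipelineState> pso =
      at::mps::MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
  [computeEncoder setComputePipelineState:pso];
  NSUInteger tgSize = std::min<NSUInteger>(pso.maxTotalThreadsPerThreadgroup, numThreads);
  [computeEncoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
            threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
}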
4 changes: 1 addition & 3 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -197,10 +197,8 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
     }
   }

-  id<MTLFunction> kernelDataOffsetsFunction =
-      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
   id<MTLComputePipelineState> kernelDataOffsetsPSO =
-      [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
+      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
   id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
                                                          options:0] autorelease];
   TORCH_CHECK(
4 changes: 1 addition & 3 deletions aten/src/ATen/native/mps/operations/CrossKernel.mm
@@ -155,10 +155,8 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
     }
   }

-  id<MTLFunction> kernelDataOffsetsFunction =
-      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
   id<MTLComputePipelineState> kernelDataOffsetsPSO =
-      [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
+      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
   id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
                                                          options:0] autorelease];
   TORCH_CHECK(
