Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS] Add PSO caching for advanced indexing kernels #99855

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
149 changes: 113 additions & 36 deletions aten/src/ATen/mps/IndexKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,42 @@ static const char * indexing_metal_shaders = R"INDEX_METAL(

using namespace metal;

constant uint32_t num_indices [[function_constant(0)]];

#if __METAL_VERSION__ < 300
struct IndexAB {
// Allow up to 16 indices
metal::array<constant void *, 16> indexArray [[ id(0) ]];
};
#else
struct IndexAB {
constant int64_t* indexArray;
};

#endif

template<typename T>
kernel void index_select(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
Expand All @@ -42,19 +57,30 @@ kernel void index_select(

template<typename T>
kernel void index_put(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {

constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];

if (index < 0) {
index += index_sizes[i];
}
Expand All @@ -65,6 +91,7 @@ kernel void index_put(
*out = *in;
}

#if __METAL_VERSION__ < 300
#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
Expand All @@ -75,7 +102,22 @@ kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
constant uint3 * offsets [[buffer(3)]], \
constant void * inputData [[buffer(4)]], \
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_INDEX_OP(DTYPE_SIZE, DTYPE, INDEX_OP_TYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE>( \
constant IndexAB * indexAB [[buffer(0)]], \
constant void * indexSizes [[buffer(1)]], \
constant void * indexStrides [[buffer(2)]], \
constant uint3 * offsets [[buffer(3)]], \
constant void * inputData [[buffer(4)]], \
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#endif

#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
REGISTER_INDEX_OP(8bit, char, INDEX_OP_TYPE); \
Expand All @@ -92,29 +134,40 @@ kernel void kernel_index_offsets(constant packed_uint3 * strides [[buffer(0)]],
constant uint & num_dimensions [[buffer(3)]],
constant uint & num_offsets [[buffer(4)]],
uint thread_index [[thread_position_in_grid]]) {
data_offsets[thread_index] = 0;
uint32_t idx = thread_index;
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
uint32_t remainder = idx % iter_shape[dim];
idx /= iter_shape[dim];

for (uint32_t offset = 0; offset < num_offsets; offset++)
data_offsets[thread_index][offset] += remainder * strides[dim][offset];
data_offsets[thread_index] += remainder * strides[dim];
}
}

template<typename T, typename E>
kernel void index_put_accumulate_native_dtypes(constant IndexAB & indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
Expand All @@ -136,18 +189,29 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a
}

template<typename T>
kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
kernel void atomic_index_put_accumulate(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
int64_t index = ((constant int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
Expand All @@ -160,22 +224,35 @@ kernel void atomic_index_put_accumulate(constant IndexAB & indexAB [[b

template
[[host_name("index_put_accumulate_32bit_float")]]
kernel void atomic_index_put_accumulate<float>(constant IndexAB & indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
kernel void atomic_index_put_accumulate<float>(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int>(constant IndexAB & indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
kernel void index_put_accumulate_native_dtypes<atomic_int, int>(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";

static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
Expand Down
11 changes: 6 additions & 5 deletions aten/src/ATen/mps/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLFunction> MTLFunction_t;
typedef MTLFunctionConstantValues* MTLFunctionConstantValues_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
typedef id<MTLLibrary> MTLLibrary_t;
#else
typedef void* MTLDevice;
typedef void* MTLDevice_t;
typedef void* MTLLibrary_t;
typedef void* MTLFunction_t;
typedef void* MTLFunctionConstantValues_t;
typedef void* MTLComputePipelineState_t;
typedef void* MTLLibrary_t;
#endif

using namespace std;
Expand Down Expand Up @@ -66,7 +66,8 @@ class TORCH_API MPSDevice {
*/
bool isMacOS13Plus(MacOSVersion version) const;

MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues);
MTLComputePipelineState_t metalIndexingFunction(const std::string &kernel);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @DenisVieriu97, do you think it would be better to rename it as metalIndexingPSO? If so, I can propose a PR for it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qqaatw this sounds good! Please go ahead with the change

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DenisVieriu97 here is the PR :) #101156

MTLLibrary_t getMetalIndexingLibrary();

~MPSDevice();

Expand Down
36 changes: 24 additions & 12 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
static std::unique_ptr<MPSDevice> mps_device;
static c10::once_flag mpsdev_init;

static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
// host_name attribute needs at least Metal 2.2
MTLLanguageVersion languageVersion = MTLLanguageVersion2_2;
#if defined(__MAC_13_0)
if (macOS13Plus) {
languageVersion = MTLLanguageVersion3_0;
}
#endif

TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
return languageVersion;
Expand All @@ -27,37 +32,44 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
return mps_device.get();
}

id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
id<MTLLibrary> MPSDevice::getMetalIndexingLibrary() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
[options setFastMathEnabled:YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}
return _mtl_indexing_library;
}

id<MTLFunction> indexFunction = nil;
if (constantValues) {
indexFunction = [[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]
constantValues:constantValues
error:&error] autorelease];
} else {
indexFunction =
[[_mtl_indexing_library newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
id<MTLComputePipelineState> MPSDevice::metalIndexingFunction(const std::string& kernel) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
NSError* error = nil;
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLLibrary> indexing_lib = getMetalIndexingLibrary();
id<MTLComputePipelineState> state = psoCache[kernel];
if (state) {
return state;
}

id<MTLFunction> indexFunction =
[[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
TORCH_CHECK(indexFunction,
"Failed to create specialized function state object: ",
kernel,
", error: ",
[[error description] UTF8String]);

return indexFunction;
state = [_mtl_device newComputePipelineStateWithFunction:indexFunction error:&error];
TORCH_CHECK(state, error.localizedDescription.UTF8String);
psoCache[kernel] = state;
return state;
}

MPSDevice::~MPSDevice() {
Expand Down
4 changes: 1 addition & 3 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,8 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
}
}

id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
Expand Down
4 changes: 1 addition & 3 deletions aten/src/ATen/native/mps/operations/CrossKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,8 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
}
}

id<MTLFunction> kernelDataOffsetsFunction =
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO =
[[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction error:&error] autorelease];
MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
options:0] autorelease];
TORCH_CHECK(
Expand Down