Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions backends/apple/metal/runtime/shims/et_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ class ETMetalKernelFunction {
void startEncoding();
void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor);
void setArg(unsigned idx, int64_t val);
// Scalar overloads: copy `val` (sizeof its C++ type) into argument slot `idx`
// of the active compute encoder. When no encoder is active they log an error
// and bind nothing.
void setArg(unsigned idx, uint32_t val);
void setArg(unsigned idx, float val);
void setArg(unsigned idx, bool val);
// Raw-bytes overload: binds `size` bytes starting at `data` to slot `idx`.
// The caller is responsible for the byte layout matching what the kernel
// expects at that slot.
void setArg(unsigned idx, const void* data, size_t size);

// Helper for Metal uint3 arguments: packs x, y, z into a simd_uint3 and binds
// sizeof(simd_uint3) bytes to slot `idx`.
void setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z);

void dispatchSingle(uint64_t length);
void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size);
Expand All @@ -191,6 +198,15 @@ class ETMetalKernelFunction {
const uint64_t* group_size,
size_t group_size_size);

// Dispatch with explicit threadgroup count (not thread count).
//   gridX/gridY/gridZ:          number of threadgroups along each grid axis.
//   threadsX/threadsY/threadsZ: threads per threadgroup along each axis; the
//                               request is rejected (error log, no dispatch)
//                               when their product exceeds the pipeline
//                               state's maxTotalThreadsPerThreadgroup.
// Also logs an error and returns without dispatching when there is no active
// encoder or no compute pipeline state.
void dispatchThreadgroups(
uint64_t gridX,
uint64_t gridY,
uint64_t gridZ,
uint64_t threadsX,
uint64_t threadsY,
uint64_t threadsZ);

void runCommandBlock(std::function<void(void)> f);

private:
Expand Down
87 changes: 87 additions & 0 deletions backends/apple/metal/runtime/shims/et_metal.mm
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#import <Foundation/Foundation.h>
#include <simd/simd.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/backends/apple/metal/runtime/shims/et_metal.h>
Expand Down Expand Up @@ -377,6 +378,58 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx);
}

// Binds a 32-bit unsigned scalar to argument slot `idx` of the active encoder.
// The bytes are handed to -setBytes:length:atIndex:, so `val` does not need to
// outlive this call. Without an active encoder the call only logs an error.
void ETMetalKernelFunction::setArg(unsigned idx, uint32_t val) {
  if (encoder_) {
    const void* payload = &val;
    [encoder_ setBytes:payload length:sizeof(val) atIndex:idx];
    ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set uint32_t value %u at index %u", val, idx);
    return;
  }

  ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
}

// Binds a single-precision float to argument slot `idx` of the active encoder.
// Logs an error and binds nothing when no encoder is active.
void ETMetalKernelFunction::setArg(unsigned idx, float val) {
  if (!encoder_) {
    ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
    return;
  }

  // Exactly sizeof(float) bytes are copied from the by-value parameter.
  constexpr size_t kPayloadSize = sizeof(float);
  const float* payload = &val;
  [encoder_ setBytes:payload length:kPayloadSize atIndex:idx];

  ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set float value %f at index %u", val, idx);
}

// Binds a boolean to argument slot `idx` of the active encoder, copying
// sizeof(bool) bytes from the by-value parameter. Logs an error and binds
// nothing when no encoder is active.
void ETMetalKernelFunction::setArg(unsigned idx, bool val) {
  if (encoder_) {
    [encoder_ setBytes:&val length:sizeof(val) atIndex:idx];

    // Render the flag as text for the debug trace.
    const char* shown = val ? "true" : "false";
    ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bool value %s at index %u", shown, idx);
  } else {
    ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
  }
}

// Binds `size` raw bytes starting at `data` to argument slot `idx` of the
// active encoder via -setBytes:length:atIndex:.
//
//   idx:  buffer argument index in the kernel's argument table.
//   data: pointer to the bytes to copy; must not be null.
//   size: number of bytes to copy; must be non-zero. Apple documents
//         -setBytes:length:atIndex: for small, single-use data (under 4 KB);
//         larger payloads should go through a buffer instead.
//
// The call logs an error and binds nothing when there is no active encoder or
// when `data`/`size` do not describe a non-empty byte range.
void ETMetalKernelFunction::setArg(unsigned idx, const void* data, size_t size) {
  if (!encoder_) {
    ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
    return;
  }

  // Unlike the scalar overloads, the pointer and length come straight from
  // the caller. Reject an empty or null range here instead of handing it to
  // Metal, where it is an API misuse.
  if (data == nullptr || size == 0) {
    ET_LOG(Error, "ETMetalKernelFunction::setArg: Invalid bytes argument at index %u (data: %p, size: %zu)", idx, data, size);
    return;
  }

  [encoder_ setBytes:data length:size atIndex:idx];
  ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bytes at index %u (size: %zu)", idx, size);
}

// Binds a three-component unsigned vector to argument slot `idx` of the
// active encoder. The components are packed into a simd_uint3 and the whole
// sizeof(simd_uint3) value is copied, so padding follows the SIMD type rather
// than a tightly packed uint32_t[3]. Logs an error and binds nothing when no
// encoder is active.
void ETMetalKernelFunction::setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z) {
  if (encoder_) {
    // simd_uint3 is used so the host-side layout matches the shader's uint3.
    const simd_uint3 packed = {x, y, z};
    [encoder_ setBytes:&packed length:sizeof(packed) atIndex:idx];
    ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx);
    return;
  }

  ET_LOG(Error, "ETMetalKernelFunction::setArgUint3: No active encoder");
}

void ETMetalKernelFunction::dispatchSingle(uint64_t length) {
if (!encoder_) {
ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder");
Expand Down Expand Up @@ -502,6 +555,40 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev

}

// Encodes a dispatch with an explicit number of threadgroups (not threads).
//
//   gridX/gridY/gridZ:          threadgroups along each grid axis; each must
//                               be non-zero.
//   threadsX/threadsY/threadsZ: threads per threadgroup along each axis; each
//                               must be non-zero and their product must not
//                               exceed the pipeline state's
//                               maxTotalThreadsPerThreadgroup.
//
// On any validation failure (no encoder, no pipeline state, zero-sized
// dimension, or too many threads per threadgroup) an error is logged and
// nothing is encoded.
void ETMetalKernelFunction::dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ,
                                                 uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ) {
  if (!encoder_) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No active encoder");
    return;
  }

  if (!cps_) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No compute pipeline state");
    return;
  }

  // A threadgroup with a zero-sized axis contains no threads and is not a
  // valid dispatch. It would also slip through a "product > max" check.
  if (threadsX == 0 || threadsY == 0 || threadsZ == 0) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Threadgroup dimensions must be non-zero, got [%llu, %llu, %llu]",
           (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ);
    return;
  }

  // Likewise, a grid with a zero-sized axis has nothing to execute.
  if (gridX == 0 || gridY == 0 || gridZ == 0) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Grid dimensions must be non-zero, got [%llu, %llu, %llu]",
           (unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ);
    return;
  }

  const auto maxThreadsPerGroup = static_cast<uint64_t>([cps_ maxTotalThreadsPerThreadgroup]);

  // Validate the total thread count without computing threadsX * threadsY *
  // threadsZ up front: that product can wrap around in uint64_t and make an
  // oversized request look small. Each step divides instead, and the partial
  // product on the last line is only evaluated once it is known to be
  // <= maxThreadsPerGroup (short-circuit evaluation).
  const bool exceedsLimit =
      threadsX > maxThreadsPerGroup ||
      threadsY > maxThreadsPerGroup / threadsX ||
      threadsZ > maxThreadsPerGroup / (threadsX * threadsY);

  if (exceedsLimit) {
    // Report the individual dimensions; the product may not be representable.
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Requested threadgroup [%llu, %llu, %llu] exceeds maximum of %llu total threads per threadgroup",
           (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ,
           (unsigned long long)maxThreadsPerGroup);
    return;
  }

  MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ);
  MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ);

  [encoder_ dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];

  ET_LOG(Debug, "ETMetalKernelFunction::dispatchThreadgroups: Dispatched grid [%llu, %llu, %llu] with threadgroup [%llu, %llu, %llu]",
         (unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ,
         (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ);
}

void ETMetalKernelFunction::runCommandBlock(std::function<void(void)> f) {
// Use dispatch_sync with the stream's serial queue for thread safety and synchronization
// This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...)
Expand Down
Loading
Loading