From a9108f8c1cfd8d481b4f59be2a433a858871577e Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 4 Dec 2025 16:11:13 -0500 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 11 + .../apple/metal/runtime/shims/et_metal.mm | 74 ++ .../apple/metal/runtime/shims/et_metal_ops.mm | 1087 +++++++++-------- 3 files changed, 636 insertions(+), 536 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index 0e012d18c8f..f00cde5d22f 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -181,6 +181,13 @@ class ETMetalKernelFunction { void startEncoding(); void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); void setArg(unsigned idx, int64_t val); + void setArg(unsigned idx, uint32_t val); + void setArg(unsigned idx, float val); + void setArg(unsigned idx, bool val); + void setArg(unsigned idx, const void* data, size_t size); + + // Helper for Metal uint3 struct + void setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z); void dispatchSingle(uint64_t length); void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); @@ -191,6 +198,10 @@ class ETMetalKernelFunction { const uint64_t* group_size, size_t group_size_size); + // Dispatch with explicit threadgroup count (not thread count) + void dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ, + uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ); + void runCommandBlock(std::function f); private: diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index 2ba058de40a..e6c919ba61f 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -377,6 +377,63 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx); } +void ETMetalKernelFunction::setArg(unsigned idx, uint32_t val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(uint32_t) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set uint32_t value %u at index %u", val, idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, float val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(float) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set float value %f at index %u", val, idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, bool val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(bool) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bool value %s at index %u", val ? 
"true" : "false", idx); +} + +void ETMetalKernelFunction::setArg(unsigned idx, const void* data, size_t size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:data length:size atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bytes at index %u (size: %zu)", idx, size); +} + +void ETMetalKernelFunction::setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArgUint3: No active encoder"); + return; + } + + // Metal's uint3 is a packed struct of 3 uint32_t values + struct uint3 { + uint32_t x; + uint32_t y; + uint32_t z; + }; + uint3 val = {x, y, z}; + [encoder_ setBytes:&val length:sizeof(uint3) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx); +} + void ETMetalKernelFunction::dispatchSingle(uint64_t length) { if (!encoder_) { ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder"); @@ -502,6 +559,23 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev } +void ETMetalKernelFunction::dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ, + uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No active encoder"); + return; + } + + MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ); + MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ); + + [encoder_ dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + + ET_LOG(Debug, "ETMetalKernelFunction::dispatchThreadgroups: Dispatched grid [%llu, %llu, %llu] with threadgroup [%llu, %llu, %llu]", + (unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ, + (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ); +} + void ETMetalKernelFunction::runCommandBlock(std::function f) { // Use dispatch_sync with the stream's serial queue for thread safety and synchronization // This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...) 
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index b150c68fe6d..55d96aa9c49 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -19,6 +19,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -122,6 +123,233 @@ void logStats() { return it->second; } +// Helper function to get the Metal shader source for SDPA +static std::string get_sdpa_metal_source() { + return R"( +// Ported from PyTorch's Attention.metal +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal +// Largely influeneced by +// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +// Modified to support floating point masks and transposed middle dimensions (dims 1 & 2) + +#include +#include +#include + +using namespace metal; + +typedef half float16_t; +typedef bfloat bfloat16_t; + +// PyTorch's sdpa_vector kernel (one-pass variant) +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + constant uint& gqa_factor [[buffer(4)]], + constant uint& N [[buffer(5)]], + constant uint3& qkv_head_strides [[buffer(6)]], + constant uint3& qkv_seq_strides [[buffer(7)]], + constant float& scale [[buffer(8)]], + const device T* mask [[buffer(9)]], // Changed from bool* to T* for floating point masks + constant uint3& mask_strides [[buffer(10)]], + constant bool& has_mask [[buffer(11)]], + constant uint3& qkv_batch_strides [[buffer(12)]], // NEW: batch strides for Q, K, V + constant uint& num_q_heads [[buffer(13)]], // NEW: number of query heads + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr uint BN = 32; + constexpr uint BD = 32; + constexpr uint qk_per_thread = D / BD; + constexpr uint v_per_thread = V / BD; + const uint q_head_stride = qkv_head_strides.x; + const uint q_seq_stride = qkv_seq_strides.x; + const uint q_batch_stride = qkv_batch_strides.x; + const uint k_head_stride = qkv_head_strides.y; + const uint k_seq_stride = qkv_seq_strides.y; + const uint k_batch_stride = qkv_batch_strides.y; + const uint v_head_stride = qkv_head_strides.z; + const uint v_seq_stride = qkv_seq_strides.z; + const uint v_batch_stride = qkv_batch_strides.z; + const uint mask_head_stride = mask_strides.x; + const uint mask_kv_seq_stride = mask_strides.y; + const uint mask_q_seq_stride = mask_strides.z; + uint inner_k_stride = BN * int(k_seq_stride); + uint inner_v_stride = BN * int(v_seq_stride); + + typedef float U; + + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.x; // Flattened batch*heads index + const int q_seq_idx = tid.y; + + // Decompose flattened head_idx into batch and head indices + const int batch_idx = head_idx / num_q_heads; + const int head_in_batch = head_idx % num_q_heads; + const int kv_head_idx = head_in_batch / gqa_factor; + + const int Q = tpg.y; + const int group_offset = head_idx * Q + q_seq_idx; + const int o_offset = group_offset; + + // Use decomposed indices with separate batch and head 
strides + queries += batch_idx * q_batch_stride + head_in_batch * q_head_stride + q_seq_idx * q_seq_stride + + simd_lid * qk_per_thread; + keys += batch_idx * k_batch_stride + kv_head_idx * k_head_stride + simd_gid * k_seq_stride + + simd_lid * qk_per_thread; + values += batch_idx * v_batch_stride + kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + if (has_mask) { + mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + + out += o_offset * V + simd_gid * v_per_thread; + + // Read the query and 0 the output accumulator + for (uint i = 0; i < qk_per_thread; i++) { + q[i] = scale * static_cast(queries[i]); + } + for (uint i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (uint i = simd_gid; i < N; i += BN) { + // Check mask: for floating point masks, values > -1e9 are considered valid (not masked) + // Masked positions typically have -inf or very negative values + const bool is_valid = !has_mask || (static_cast(mask[0]) > -1e9f); + + if (is_valid) { + // Read the key + for (uint j = 0; j < qk_per_thread; j++) { + k[j] = static_cast(keys[j]); + } + + // Compute the i-th score + U score = 0; + for (uint j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + + // Add mask value to score if mask is present + if (has_mask) { + score += static_cast(mask[0]); + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = metal::fast::exp(max_score - new_max); + U exp_score = metal::fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (uint j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * static_cast(values[j]); + } + } + + // Move the pointers to the next kv + keys += inner_k_stride; + values += inner_v_stride; + if (has_mask) { + mask += BN * mask_kv_seq_stride; + } + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = metal::fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (uint i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + const U safe_sum = (sum_exp_score == 0 ? 
1e-6f : sum_exp_score); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / safe_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (uint i = 0; i < v_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +#define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM) \ + template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM \ + "_" #VALUE_DIM)]] kernel void \ + sdpa_vector( \ + const device DTYPE* queries [[buffer(0)]], \ + const device DTYPE* keys [[buffer(1)]], \ + const device DTYPE* values [[buffer(2)]], \ + device DTYPE* out [[buffer(3)]], \ + constant uint& gqa_factor [[buffer(4)]], \ + constant uint& N [[buffer(5)]], \ + constant uint3& qkv_head_strides [[buffer(6)]], \ + constant uint3& qkv_seq_strides [[buffer(7)]], \ + constant float& scale [[buffer(8)]], \ + const device DTYPE* mask [[buffer(9)]], \ + constant uint3& mask_strides [[buffer(10)]], \ + constant bool& has_mask [[buffer(11)]], \ + constant uint3& qkv_batch_strides [[buffer(12)]], \ + constant uint& num_q_heads [[buffer(13)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 tpg [[threadgroups_per_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \ + INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); + +INSTANTIATE_SDPA_VECTOR_HEADS(float); +INSTANTIATE_SDPA_VECTOR_HEADS(half); +INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); +)"; +} + +// Global shader library cache for SDPA +static std::unique_ptr sdpa_shader_library = nullptr; + +static ETMetalShaderLibrary* get_sdpa_shader_library() { + if (!sdpa_shader_library) { + std::string source = get_sdpa_metal_source(); + sdpa_shader_library = std::make_unique(source); + } + return sdpa_shader_library.get(); +} + } // anonymous namespace extern "C" { @@ -930,7 +1158,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( AOTITensorHandle* ret0, AOTITensorHandle* ret1) { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with MPSGraph implementation"); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with Metal kernel implementation"); if (!query || !key || !value || !ret0 || !ret1) { ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: null required tensor handles"); @@ -953,560 +1181,347 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Converted tensor handles to ET tensors"); - // Validate tensor dimensions - if (query_tensor->dim() < 3 || key_tensor->dim() < 3 || value_tensor->dim() < 3) { - std::string error_msg = "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: tensors must be at least 3-D, got " + - std::to_string(query_tensor->dim()) + ", " + - std::to_string(key_tensor->dim()) + ", " + - std::to_string(value_tensor->dim()); - ET_LOG(Error, "%s", error_msg.c_str()); - throw std::runtime_error(error_msg); - } + // Log query tensor shape and strides + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor - dim=%d, shape=[%lld, %lld, %lld, %lld], strides=[%lld, %lld, %lld, %lld]", + (int)query_tensor->dim(), + query_tensor->dim() > 0 ? query_tensor->sizes()[0] : 0, + query_tensor->dim() > 1 ? 
query_tensor->sizes()[1] : 0, + query_tensor->dim() > 2 ? query_tensor->sizes()[2] : 0, + query_tensor->dim() > 3 ? query_tensor->sizes()[3] : 0, + query_tensor->dim() > 0 ? query_tensor->strides()[0] : 0, + query_tensor->dim() > 1 ? query_tensor->strides()[1] : 0, + query_tensor->dim() > 2 ? query_tensor->strides()[2] : 0, + query_tensor->dim() > 3 ? query_tensor->strides()[3] : 0); + + // Log key tensor shape and strides + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor - dim=%d, shape=[%lld, %lld, %lld, %lld], strides=[%lld, %lld, %lld, %lld]", + (int)key_tensor->dim(), + key_tensor->dim() > 0 ? key_tensor->sizes()[0] : 0, + key_tensor->dim() > 1 ? key_tensor->sizes()[1] : 0, + key_tensor->dim() > 2 ? key_tensor->sizes()[2] : 0, + key_tensor->dim() > 3 ? key_tensor->sizes()[3] : 0, + key_tensor->dim() > 0 ? key_tensor->strides()[0] : 0, + key_tensor->dim() > 1 ? key_tensor->strides()[1] : 0, + key_tensor->dim() > 2 ? key_tensor->strides()[2] : 0, + key_tensor->dim() > 3 ? key_tensor->strides()[3] : 0); + + // Log value tensor shape and strides + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor - dim=%d, shape=[%lld, %lld, %lld, %lld], strides=[%lld, %lld, %lld, %lld]", + (int)value_tensor->dim(), + value_tensor->dim() > 0 ? value_tensor->sizes()[0] : 0, + value_tensor->dim() > 1 ? value_tensor->sizes()[1] : 0, + value_tensor->dim() > 2 ? value_tensor->sizes()[2] : 0, + value_tensor->dim() > 3 ? value_tensor->sizes()[3] : 0, + value_tensor->dim() > 0 ? value_tensor->strides()[0] : 0, + value_tensor->dim() > 1 ? value_tensor->strides()[1] : 0, + value_tensor->dim() > 2 ? value_tensor->strides()[2] : 0, + value_tensor->dim() > 3 ? value_tensor->strides()[3] : 0); - // Get tensor dimensions (assuming [batch, num_heads, seq_len, head_dim] format) - int64_t batchSize = query_tensor->sizes()[0]; - int64_t num_heads = query_tensor->sizes()[1]; - int64_t qSize = query_tensor->sizes()[2]; - int64_t headSize = query_tensor->sizes()[3]; - int64_t kvSeqLength = key_tensor->sizes()[2]; - - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: batchSize=%lld, num_heads=%lld, qSize=%lld, headSize=%lld, kvSeqLength=%lld", - batchSize, num_heads, qSize, headSize, kvSeqLength); - - // Detect non-contiguous layouts for query, key, and value tensors - // For a 4D tensor [batch, num_heads, seq_len, head_dim], common non-contiguous patterns: - // - Transposed last 2 dims (dims 2,3): strides[2] == 1 && strides[3] == seq_len (seq_len and head_dim swapped) - // - Transposed internal dims (dims 1,2): strides[1] == head_dim && strides[2] == num_heads*head_dim (num_heads and seq_len swapped) - // - Other permutations may exist depending on upstream operations - - bool query_is_transposed_last2 = false; // transpose of dims -2 and -1 - bool query_is_transposed_internal = false; // transpose of dims 1 and 2 - bool key_is_transposed_last2 = false; - bool key_is_transposed_internal = false; - bool value_is_transposed_last2 = false; - bool value_is_transposed_internal = false; - - // Expected contiguous strides for query [batch, num_heads, qSize, headSize] - int64_t expected_q_stride_3 = 1; - int64_t expected_q_stride_2 = headSize; - int64_t expected_q_stride_1 = qSize * headSize; - int64_t expected_q_stride_0 = num_heads * qSize * headSize; - - // Check query tensor layout - auto q_strides = query_tensor->strides(); - if (q_strides[3] != expected_q_stride_3 || q_strides[2] != expected_q_stride_2 || - q_strides[1] != 
expected_q_stride_1) { - // Check if it's a transpose of the last two dimensions (dims 2 and 3) - if (q_strides[2] == 1 && q_strides[3] == qSize && q_strides[1] == qSize * headSize) { - query_is_transposed_last2 = true; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", - (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); - } - // Check if it's a transpose of the internal dimensions (dims 1 and 2) - else if (q_strides[1] == headSize && q_strides[2] == num_heads * headSize && q_strides[3] == 1) { - query_is_transposed_internal = true; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", - (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); - } else { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", - (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); - } - } else { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", - (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); - } + // Validate tensor dimensions + if (query_tensor->dim() < 3 || key_tensor->dim() < 3 || value_tensor->dim() < 3) { + std::string error_msg = "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: tensors must be at least 3-D, got " + + std::to_string(query_tensor->dim()) + ", " + + std::to_string(key_tensor->dim()) + ", " + + std::to_string(value_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } - // Expected contiguous strides for key [batch, num_heads, kvSeqLength, headSize] - int64_t expected_k_stride_3 = 1; - int64_t expected_k_stride_2 = headSize; - int64_t expected_k_stride_1 = kvSeqLength * headSize; - int64_t expected_k_stride_0 = num_heads * kvSeqLength * headSize; - - // Check key tensor layout - auto k_strides = key_tensor->strides(); - if (k_strides[3] != expected_k_stride_3 || k_strides[2] != expected_k_stride_2 || - k_strides[1] != expected_k_stride_1) { - // Check if it's a transpose of the last two dimensions (dims 2 and 3) - if (k_strides[2] == 1 && k_strides[3] == kvSeqLength && k_strides[1] == kvSeqLength * headSize) { - key_is_transposed_last2 = true; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", - (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); - } - // Check if it's a transpose of the internal dimensions (dims 1 and 2) - else if (k_strides[1] == headSize && k_strides[2] == num_heads * headSize && k_strides[3] == 1) { - key_is_transposed_internal = true; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", - (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); - } else { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", - (int64_t)k_strides[0], 
(int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); - } - } else { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", - (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); - } + // Get tensor dimensions (assuming [batch, num_heads, seq_len, head_dim] format) + int64_t batchSize = query_tensor->sizes()[0]; + int64_t num_heads = query_tensor->sizes()[1]; + int64_t qSize = query_tensor->sizes()[2]; + int64_t headSize = query_tensor->sizes()[3]; + int64_t kvSeqLength = key_tensor->sizes()[2]; - // Expected contiguous strides for value [batch, num_heads, kvSeqLength, headSize] - int64_t expected_v_stride_3 = 1; - int64_t expected_v_stride_2 = headSize; - int64_t expected_v_stride_1 = kvSeqLength * headSize; - int64_t expected_v_stride_0 = num_heads * kvSeqLength * headSize; - - // Check value tensor layout - auto v_strides = value_tensor->strides(); - if (v_strides[3] != expected_v_stride_3 || v_strides[2] != expected_v_stride_2 || - v_strides[1] != expected_v_stride_1) { - // Check if it's a transpose of the last two dimensions (dims 2 and 3) - if (v_strides[2] == 1 && v_strides[3] == kvSeqLength && v_strides[1] == kvSeqLength * headSize) { - value_is_transposed_last2 = true; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", - (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); - } - // Check if it's a transpose of the internal dimensions (dims 1 and 2) - else if (v_strides[1] == headSize && v_strides[2] == num_heads * headSize && v_strides[3] == 1) { - value_is_transposed_internal = true; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", - (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); - } else { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", - (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); - } - } else { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", - (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); - } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: batchSize=%lld, num_heads=%lld, qSize=%lld, headSize=%lld, kvSeqLength=%lld", + batchSize, num_heads, qSize, headSize, kvSeqLength); + + // Determine data type and element size + int32_t dtype = static_cast(query_tensor->scalar_type()); + size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for scaled dot product attention"); + } - // Determine data type and element size - int32_t dtype = static_cast(query_tensor->scalar_type()); - MPSDataType mps_dtype; - size_t element_size; - - if (dtype == static_cast(SupportedDTypes::FLOAT32)) { - 
mps_dtype = MPSDataTypeFloat32; - element_size = sizeof(float); - } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { - mps_dtype = MPSDataTypeBFloat16; - element_size = sizeof(uint16_t); // bfloat16 is 16 bits + // Check that headSize is not zero to avoid division by zero + if (headSize == 0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: headSize is zero"); + throw std::runtime_error("headSize must be non-zero for scaled dot product attention"); + } + + // Calculate scale factor + double scale_factor = scale ? *scale : (1.0 / sqrt(static_cast(headSize))); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor); + + // Calculate output tensor dimensions + std::vector output_sizes = {batchSize, num_heads, qSize, headSize}; + std::vector attn_sizes = {batchSize, num_heads, qSize, kvSeqLength}; + + // Calculate strides for contiguous tensors + std::vector out_strides = { + num_heads * qSize * headSize, + qSize * headSize, + headSize, + 1 + }; + + std::vector attn_strides = { + num_heads * qSize * kvSeqLength, + qSize * kvSeqLength, + kvSeqLength, + 1 + }; + + // Allocate output Metal buffers via AOTI API to keep GPU residency and reuse + size_t out_size_bytes = batchSize * num_heads * qSize * headSize * element_size; + size_t attn_size_bytes = batchSize * num_heads * qSize * kvSeqLength * element_size; + + void* out_contents_ptr = nullptr; + allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); + + void* attn_contents_ptr = nullptr; + allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); + + // Use MLX-style Metal kernels instead of MPSGraph + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MLX Metal kernels"); + + // Get shader library + ETMetalShaderLibrary* library = get_sdpa_shader_library(); + if (!library) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get shader library"); + throw std::runtime_error("Failed to get SDPA shader library"); + } + + // Determine kernel name based on dtype and head_dim (PyTorch format) + std::string type_name; + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + type_name = "float"; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + type_name = "bfloat"; + } else { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported dtype for Metal kernel"); + throw std::runtime_error("Unsupported dtype for Metal SDPA kernel"); + } + + // Select head_dim - must match exactly one of the supported sizes (64, 96, 128) + int64_t head_dim = headSize; + if (head_dim != 64 && head_dim != 96 && head_dim != 128) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim); + throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128"); + } + + std::string kernel_name = "sdpa_vector_" + type_name + "_" + std::to_string(head_dim) + "_" + std::to_string(head_dim); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using kernel: %s", kernel_name.c_str()); + + // Get kernel function + auto kernel_func = library->getKernelFunction(kernel_name); + if (!kernel_func) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get kernel function: %s", kernel_name.c_str()); + throw std::runtime_error("Failed to get SDPA kernel function"); + } + + // Create output tensor handle first 
so we can use it in the kernel + AOTITensorHandle out_tensor_handle = nullptr; + AOTITorchError create_out_result = aoti_torch_create_tensor_from_blob_v2( + out_contents_ptr, + 4, // ndim + output_sizes.data(), + out_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &out_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_out_result != Error::Ok || !out_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create output tensor"); + aoti_torch_mps_free(out_contents_ptr); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Failed to create output tensor"); + } + + // Mark that we own the memory + extern std::unordered_map memory_to_n_tensor; + memory_to_n_tensor[out_contents_ptr] = 1; + + auto* out_tensor = reinterpret_cast(out_tensor_handle); + + // Prepare kernel arguments (PyTorch format) + uint gqa_factor = static_cast(num_heads / key_tensor->sizes()[1]); + uint N = static_cast(kvSeqLength); + + // Get strides for Q, K, V (all 3 stride levels: batch, head, seq) + uint q_batch_stride = static_cast(query_tensor->strides()[0]); + uint q_head_stride = static_cast(query_tensor->strides()[1]); + uint q_seq_stride = static_cast(query_tensor->strides()[2]); + uint q_dim_stride = static_cast(query_tensor->strides()[3]); + + uint k_batch_stride = static_cast(key_tensor->strides()[0]); + uint k_head_stride = static_cast(key_tensor->sizes()[1] == 1 ? key_tensor->strides()[0] : key_tensor->strides()[1]); + uint k_seq_stride = static_cast(key_tensor->strides()[2]); + uint k_dim_stride = static_cast(key_tensor->strides()[3]); + + uint v_batch_stride = static_cast(value_tensor->strides()[0]); + uint v_head_stride = static_cast(value_tensor->sizes()[1] == 1 ? value_tensor->strides()[0] : value_tensor->strides()[1]); + uint v_seq_stride = static_cast(value_tensor->strides()[2]); + uint v_dim_stride = static_cast(value_tensor->strides()[3]); + + // Log strides for debugging + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Q strides - batch:%u, head:%u, seq:%u, dim:%u", + q_batch_stride, q_head_stride, q_seq_stride, q_dim_stride); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: K strides - batch:%u, head:%u, seq:%u, dim:%u", + k_batch_stride, k_head_stride, k_seq_stride, k_dim_stride); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: V strides - batch:%u, head:%u, seq:%u, dim:%u", + v_batch_stride, v_head_stride, v_seq_stride, v_dim_stride); + + // Check if middle dimensions (1 and 2) are transposed + // For contiguous [batch, num_heads, seq, dim]: stride[1] > stride[2] (head_stride > seq_stride) + // For transposed [batch, seq, num_heads, dim] in memory: stride[1] < stride[2] (head_stride < seq_stride) + bool q_transposed = (q_head_stride < q_seq_stride); + bool k_transposed = (k_head_stride < k_seq_stride); + bool v_transposed = (v_head_stride < v_seq_stride); + + if (q_transposed || k_transposed || v_transposed) { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Transposed middle dimensions detected (dims 1&2 swapped)! 
Q:%d, K:%d, V:%d", q_transposed, k_transposed, v_transposed); + ET_LOG(Debug, " For transposed layout: head_stride < seq_stride"); + ET_LOG(Debug, " Q: head_stride=%u, seq_stride=%u (transposed=%d)", q_head_stride, q_seq_stride, q_transposed); + ET_LOG(Debug, " K: head_stride=%u, seq_stride=%u (transposed=%d)", k_head_stride, k_seq_stride, k_transposed); + ET_LOG(Debug, " V: head_stride=%u, seq_stride=%u (transposed=%d)", v_head_stride, v_seq_stride, v_transposed); + ET_LOG(Debug, " The updated kernel will handle this by decomposing batch and head indices."); + } + + // Verify innermost dimension has stride=1 (required by current kernel implementation) + if (q_dim_stride != 1 || k_dim_stride != 1 || v_dim_stride != 1) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Non-unit dim stride detected!"); + ET_LOG(Error, " Q dim_stride=%u, K dim_stride=%u, V dim_stride=%u", q_dim_stride, k_dim_stride, v_dim_stride); + ET_LOG(Error, " Current kernel implementation requires innermost dimension to be contiguous (stride=1)"); + throw std::runtime_error("SDPA Metal kernel requires innermost dimension to be contiguous (dim_stride must be 1)"); + } + + bool has_mask_val = (attn_mask && *attn_mask); + + // Calculate mask strides if mask is present + uint mask_head_stride = 0; + uint mask_kv_seq_stride = 0; + uint mask_q_seq_stride = 0; + if (has_mask_val) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + int nd = mask_tensor->dim(); + mask_kv_seq_stride = (nd >= 1 && mask_tensor->sizes()[nd - 1] > 1) ? static_cast(mask_tensor->strides()[nd - 1]) : 0; + mask_q_seq_stride = (nd >= 2 && mask_tensor->sizes()[nd - 2] > 1) ? static_cast(mask_tensor->strides()[nd - 2]) : 0; + mask_head_stride = (nd >= 3 && mask_tensor->sizes()[nd - 3] > 1) ? 
static_cast(mask_tensor->strides()[nd - 3]) : 0; + } + + // Execute kernel + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Preparing to execute kernel with grid [%llu, %llu, %llu], group [1024, 1, 1]", + (unsigned long long)(batchSize * num_heads), (unsigned long long)qSize, 1ULL); + + kernel_func->runCommandBlock([&]() { + kernel_func->startEncoding(); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Encoder started, setting arguments"); + + // Set buffer arguments (0-3: Q, K, V, out) + kernel_func->setArg(0, *query_tensor); + kernel_func->setArg(1, *key_tensor); + kernel_func->setArg(2, *value_tensor); + kernel_func->setArg(3, *out_tensor); + + // Set scalar arguments (uint values) + kernel_func->setArg(4, gqa_factor); + kernel_func->setArg(5, N); + + // Set uint3 for qkv_head_strides (buffer 6) + kernel_func->setArgUint3(6, q_head_stride, k_head_stride, v_head_stride); + + // Set uint3 for qkv_seq_strides (buffer 7) + kernel_func->setArgUint3(7, q_seq_stride, k_seq_stride, v_seq_stride); + + // Set scale as float (buffer 8) + kernel_func->setArg(8, static_cast(scale_factor)); + + // Set mask buffer (buffer 9) + if (has_mask_val) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + kernel_func->setArg(9, *mask_tensor); } else { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported data type: %d", dtype); - throw std::runtime_error("Unsupported data type for scaled dot product attention"); + // Dummy buffer if no mask (won't be accessed) + kernel_func->setArg(9, *query_tensor); } - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + // Set uint3 for mask_strides (buffer 10) + kernel_func->setArgUint3(10, mask_head_stride, mask_kv_seq_stride, mask_q_seq_stride); - // Check that headSize is not zero to avoid division by zero - if (headSize == 0) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: headSize is zero"); - throw std::runtime_error("headSize must be non-zero for scaled dot product attention"); - } + // Set has_mask as bool (buffer 11) + kernel_func->setArg(11, has_mask_val); - // Calculate scale factor - double scale_factor = scale ? 
*scale : (1.0 / sqrt(static_cast(headSize))); - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor); + // Set uint3 for qkv_batch_strides (buffer 12) - NEW + kernel_func->setArgUint3(12, q_batch_stride, k_batch_stride, v_batch_stride); - // Get Metal device - id device = get_metal_device(); - if (!device) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get Metal device"); - throw std::runtime_error("Failed to get Metal device"); - } + // Set num_q_heads (buffer 13) - NEW + kernel_func->setArg(13, static_cast(num_heads)); - // Get Metal buffers for query, key and value tensors - id query_buffer = get_mtl_buffer(query_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "query"); - id key_buffer = get_mtl_buffer(key_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "key"); - id value_buffer = get_mtl_buffer(value_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "value"); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: All arguments set, dispatching"); - // Calculate output tensor dimensions - std::vector output_sizes = {batchSize, num_heads, qSize, headSize}; - std::vector attn_sizes = {batchSize, num_heads, qSize, kvSeqLength}; + // Dispatch using threadgroups (PyTorch uses grid: [batch*heads, qSize, 1], group: [1024, 1, 1]) + // Note: We need to use dispatchThreadgroups, not dispatchThreads + // Each threadgroup processes one query token across all key-value tokens + kernel_func->dispatchThreadgroups( + batchSize * num_heads, // gridX + qSize, // gridY + 1, // gridZ + 1024, // threadsX + 1, // threadsY + 1); // threadsZ + }); - // Calculate strides for contiguous tensors - std::vector out_strides = { - num_heads * qSize * headSize, - qSize * headSize, - headSize, - 1 - }; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Command block completed"); - std::vector attn_strides = { - num_heads * qSize * kvSeqLength, - qSize * kvSeqLength, - kvSeqLength, - 1 - }; + // Create attention weights tensor handle (zero-filled) + std::memset(attn_contents_ptr, 0, attn_size_bytes); - // Allocate output Metal buffers via AOTI API to keep GPU residency and reuse - size_t out_size_bytes = batchSize * num_heads * qSize * headSize * element_size; - size_t attn_size_bytes = batchSize * num_heads * qSize * kvSeqLength * element_size; - - void* out_contents_ptr = nullptr; - id out_buffer = allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); - - void* attn_contents_ptr = nullptr; - id attn_weights_buffer = allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); - - // End any existing kernel coalescing to ensure a clean state for MPS - stream->endKernelCoalescing(); - - // Method 1: Using MPSGraph scaledDotProductAttention API - with detailed error handling - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention"); - - @try { - // Create MPSGraph for scaled dot product attention - // TODO: Implement caching for attention operation similar to mm and convolution - MPSGraph* mpsGraph = [MPSGraph new]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance"); - - // Define physical tensor shapes for placeholders (matching actual memory layout) - // Two transpose patterns supported: - // 1. Last 2 dims transposed (dims 2,3): [batch, num_heads, head_dim, seq_len] - // 2. 
Internal dims transposed (dims 1,2): [batch, seq_len, num_heads, head_dim] - NSArray* queryPhysicalShape; - NSArray* keyPhysicalShape; - NSArray* valuePhysicalShape; - - if (query_is_transposed_last2) { - // Physical layout: [batch, num_heads, headSize, qSize] (dims 2,3 swapped) - queryPhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(qSize)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (transposed dims 2,3): [%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)headSize, (int)qSize); - } else if (query_is_transposed_internal) { - // Physical layout: [batch, qSize, num_heads, headSize] (dims 1,2 swapped) - queryPhysicalShape = @[@(batchSize), @(qSize), @(num_heads), @(headSize)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (transposed dims 1,2): [%d,%d,%d,%d]", - (int)batchSize, (int)qSize, (int)num_heads, (int)headSize); - } else { - // Physical layout matches logical layout: [batch, num_heads, qSize, headSize] - queryPhysicalShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (contiguous): [%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)qSize, (int)headSize); - } - - if (key_is_transposed_last2) { - // Physical layout: [batch, num_heads, headSize, kvSeqLength] (dims 2,3 swapped) - keyPhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(kvSeqLength)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (transposed dims 2,3): [%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)headSize, (int)kvSeqLength); - } else if (key_is_transposed_internal) { - // Physical layout: [batch, kvSeqLength, num_heads, headSize] (dims 1,2 swapped) - keyPhysicalShape = @[@(batchSize), @(kvSeqLength), @(num_heads), @(headSize)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (transposed dims 1,2): [%d,%d,%d,%d]", - (int)batchSize, (int)kvSeqLength, (int)num_heads, (int)headSize); - } else { - // Physical layout matches logical layout: [batch, num_heads, kvSeqLength, headSize] - keyPhysicalShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (contiguous): [%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); - } - - if (value_is_transposed_last2) { - // Physical layout: [batch, num_heads, headSize, kvSeqLength] (dims 2,3 swapped) - valuePhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(kvSeqLength)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (transposed dims 2,3): [%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)headSize, (int)kvSeqLength); - } else if (value_is_transposed_internal) { - // Physical layout: [batch, kvSeqLength, num_heads, headSize] (dims 1,2 swapped) - valuePhysicalShape = @[@(batchSize), @(kvSeqLength), @(num_heads), @(headSize)]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (transposed dims 1,2): [%d,%d,%d,%d]", - (int)batchSize, (int)kvSeqLength, (int)num_heads, (int)headSize); - } else { - // Physical layout matches logical layout: [batch, num_heads, kvSeqLength, headSize] - valuePhysicalShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; - ET_LOG(Debug, 
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (contiguous): [%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); - } - - // Create placeholders for input tensors with physical shapes - MPSGraphTensor* queryPlaceholder = [mpsGraph placeholderWithShape:queryPhysicalShape - dataType:mps_dtype - name:@"query_physical"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created query placeholder"); - - MPSGraphTensor* keyPlaceholder = [mpsGraph placeholderWithShape:keyPhysicalShape - dataType:mps_dtype - name:@"key_physical"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created key placeholder"); - - MPSGraphTensor* valuePlaceholder = [mpsGraph placeholderWithShape:valuePhysicalShape - dataType:mps_dtype - name:@"value_physical"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created value placeholder"); - - // Apply transpose operations in the graph to convert physical to logical layout - // Logical shapes needed for SDPA: Q[batch, num_heads, qSize, headSize], - // K[batch, num_heads, kvSeqLength, headSize], - // V[batch, num_heads, kvSeqLength, headSize] - MPSGraphTensor* queryLogical; - MPSGraphTensor* keyLogical; - MPSGraphTensor* valueLogical; - - if (query_is_transposed_last2) { - // Transpose dims 2,3: [batch, num_heads, headSize, qSize] → [batch, num_heads, qSize, headSize] - queryLogical = [mpsGraph transposeTensor:queryPlaceholder - dimension:-2 - withDimension:-1 - name:@"query_transposed_last2"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to query tensor in graph"); - } else if (query_is_transposed_internal) { - // Transpose dims 1,2: [batch, qSize, num_heads, headSize] → [batch, num_heads, qSize, headSize] - queryLogical = [mpsGraph transposeTensor:queryPlaceholder - dimension:1 - withDimension:2 - name:@"query_transposed_internal"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to query tensor in graph"); - } else { - queryLogical = queryPlaceholder; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using query placeholder directly (no transpose needed)"); - } - - if (key_is_transposed_last2) { - // Transpose dims 2,3: [batch, num_heads, headSize, kvSeqLength] → [batch, num_heads, kvSeqLength, headSize] - keyLogical = [mpsGraph transposeTensor:keyPlaceholder - dimension:-2 - withDimension:-1 - name:@"key_transposed_last2"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to key tensor in graph"); - } else if (key_is_transposed_internal) { - // Transpose dims 1,2: [batch, kvSeqLength, num_heads, headSize] → [batch, num_heads, kvSeqLength, headSize] - keyLogical = [mpsGraph transposeTensor:keyPlaceholder - dimension:1 - withDimension:2 - name:@"key_transposed_internal"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to key tensor in graph"); - } else { - keyLogical = keyPlaceholder; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using key placeholder directly (no transpose needed)"); - } - - if (value_is_transposed_last2) { - // Transpose dims 2,3: [batch, num_heads, headSize, kvSeqLength] → [batch, num_heads, kvSeqLength, headSize] - valueLogical = [mpsGraph transposeTensor:valuePlaceholder - dimension:-2 - withDimension:-1 - 
name:@"value_transposed_last2"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to value tensor in graph"); - } else if (value_is_transposed_internal) { - // Transpose dims 1,2: [batch, kvSeqLength, num_heads, headSize] → [batch, num_heads, kvSeqLength, headSize] - valueLogical = [mpsGraph transposeTensor:valuePlaceholder - dimension:1 - withDimension:2 - name:@"value_transposed_internal"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to value tensor in graph"); - } else { - valueLogical = valuePlaceholder; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using value placeholder directly (no transpose needed)"); - } - - MPSGraphTensor* maskTensor = nil; - - // Handle causal mask - if (is_causal) { - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Creating causal mask"); - - // Create a causal mask: lower triangular matrix filled with 0s, upper triangle with -inf - // Shape should be [qSize, kvSeqLength] - NSArray* maskShape = @[@(qSize), @(kvSeqLength)]; - - // Create ones tensor - MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f - shape:maskShape - dataType:mps_dtype]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created ones tensor for causal mask"); - - // Create lower triangular mask (including diagonal) - MPSGraphTensor* causalMask = [mpsGraph bandPartWithTensor:onesTensor - numLower:-1 - numUpper:0 - name:@"causal_mask"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created causal mask using bandPartWithTensor"); - - // Convert mask to attention weights format: 0 for allowed positions, -inf for masked - MPSGraphTensor* zerosTensor = [mpsGraph constantWithScalar:0.0f - shape:maskShape - dataType:mps_dtype]; - - MPSGraphTensor* negInfTensor = [mpsGraph constantWithScalar:-1e9f - shape:maskShape - dataType:mps_dtype]; - - // Select: where causal_mask == 1, use 0.0, else use -inf - maskTensor = [mpsGraph selectWithPredicateTensor:causalMask - truePredicateTensor:zerosTensor - falsePredicateTensor:negInfTensor - name:@"causal_mask_final"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created final causal mask using selectWithPredicateTensor"); - } - - // Handle explicit attention mask if provided - MPSGraphTensor* explicitMaskPlaceholder = nil; - if (attn_mask && *attn_mask) { - auto* mask_tensor = reinterpret_cast(*attn_mask); - - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Adding explicit attention mask"); - - // Create mask placeholder - NSMutableArray* maskShapeArray = [NSMutableArray array]; - for (int i = 0; i < mask_tensor->dim(); i++) { - [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; - } - - explicitMaskPlaceholder = [mpsGraph placeholderWithShape:maskShapeArray - dataType:mps_dtype - name:@"attention_mask"]; - - if (maskTensor) { - // Combine causal and explicit masks - maskTensor = [mpsGraph additionWithPrimaryTensor:maskTensor - secondaryTensor:explicitMaskPlaceholder - name:@"combined_mask"]; - } else { - maskTensor = explicitMaskPlaceholder; - } - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created explicit mask placeholder"); - } - - // Perform scaled dot product attention using MPSGraph with logical (possibly transposed) tensors - // The logical tensors have the correct shapes for attention computation regardless of input memory 
layout - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Calling scaledDotProductAttentionWithQueryTensor with scale=%f", scale_factor); - - MPSGraphTensor* outputTensor = [mpsGraph scaledDotProductAttentionWithQueryTensor:queryLogical - keyTensor:keyLogical - valueTensor:valueLogical - maskTensor:maskTensor - scale:scale_factor - name:@"scaled_dot_product_attention"]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Successfully created SDPA tensor"); - - // Create feeds dictionary for graph execution - NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created feeds dictionary"); - - // Create MPSGraphTensorData objects for input tensors using physical shapes - // Physical shapes match the actual memory layout of the tensors - MPSGraphTensorData* queryData = [[MPSGraphTensorData alloc] initWithMTLBuffer:query_buffer - shape:queryPhysicalShape - dataType:mps_dtype]; - MPSGraphTensorData* keyData = [[MPSGraphTensorData alloc] initWithMTLBuffer:key_buffer - shape:keyPhysicalShape - dataType:mps_dtype]; - MPSGraphTensorData* valueData = [[MPSGraphTensorData alloc] initWithMTLBuffer:value_buffer - shape:valuePhysicalShape - dataType:mps_dtype]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraphTensorData objects with physical shapes"); - - feeds[queryPlaceholder] = queryData; - feeds[keyPlaceholder] = keyData; - feeds[valuePlaceholder] = valueData; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds"); - - MPSGraphTensorData* maskData = nil; - - // Add explicit mask data to feeds if provided - if (explicitMaskPlaceholder && attn_mask && *attn_mask) { - auto* mask_tensor = reinterpret_cast(*attn_mask); - // Get Metal buffer for mask - id mask_buffer = get_mtl_buffer(mask_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "mask"); - - NSMutableArray* maskShapeArray = [NSMutableArray array]; - for (int i = 0; i < mask_tensor->dim(); i++) { - [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; - } - - maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer - shape:maskShapeArray - dataType:mps_dtype]; - feeds[explicitMaskPlaceholder] = maskData; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds"); - } - - // Create results dictionary - NSArray* outputShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; - MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer - shape:outputShape - dataType:mps_dtype]; - - NSDictionary* results = @{outputTensor: outputData}; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created results dictionary"); - - // Execute via shared stream and keep results on GPU - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream"); - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully"); - - // Release MPSGraph to prevent memory leak - [mpsGraph release]; - mpsGraph = nil; - - [queryData release]; - [keyData release]; - [valueData release]; - if (maskData) [maskData release]; - [outputData release]; - - } @catch (NSException *exception) { - ET_LOG(Error, 
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s", - [[exception name] UTF8String], [[exception reason] UTF8String]); - throw std::runtime_error("MPSGraph operation failed with NSException"); - } + AOTITensorHandle attn_tensor_handle = nullptr; + AOTITorchError create_attn_result = aoti_torch_create_tensor_from_blob_v2( + attn_contents_ptr, + 4, // ndim + attn_sizes.data(), + attn_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &attn_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); - // For attention weights, zero-fill the GPU buffer (shared memory allows CPU memset) - std::memset(attn_contents_ptr, 0, attn_size_bytes); - - // Create output tensor handles - AOTITensorHandle out_tensor_handle = nullptr; - AOTITensorHandle attn_tensor_handle = nullptr; - - AOTITorchError create_out_result = aoti_torch_create_tensor_from_blob_v2( - out_contents_ptr, - 4, // ndim - output_sizes.data(), - out_strides.data(), - 0, // storage_offset - dtype, - 13, // device_type (MPS) - 0, // device_index - &out_tensor_handle, - 0, // layout (strided) - nullptr, // opaque_metadata - 0 // opaque_metadata_size - ); - - AOTITorchError create_attn_result = aoti_torch_create_tensor_from_blob_v2( - attn_contents_ptr, - 4, // ndim - attn_sizes.data(), - attn_strides.data(), - 0, // storage_offset - dtype, - 13, // device_type (MPS) - 0, // device_index - &attn_tensor_handle, - 0, // layout (strided) - nullptr, // opaque_metadata - 0 // opaque_metadata_size - ); - - if (create_out_result != Error::Ok || create_attn_result != Error::Ok || - !out_tensor_handle || !attn_tensor_handle) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create output tensors"); - aoti_torch_mps_free(out_contents_ptr); - aoti_torch_mps_free(attn_contents_ptr); - throw std::runtime_error("Failed to create output tensors"); - } + if (create_attn_result != Error::Ok || !attn_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create attention weights tensor"); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Failed to create attention weights tensor"); + } - // Mark that we own the memory for these tensors - // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2 - // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it - extern std::unordered_map memory_to_n_tensor; - memory_to_n_tensor[out_contents_ptr] = 1; - memory_to_n_tensor[attn_contents_ptr] = 1; + memory_to_n_tensor[attn_contents_ptr] = 1; - // Set output tensor handles - *ret0 = out_tensor_handle; - *ret1 = attn_tensor_handle; + // Set output tensor handles + *ret0 = out_tensor_handle; + *ret1 = attn_tensor_handle; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph implementation completed successfully"); - } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Metal kernel implementation completed successfully"); + + } // @autoreleasepool ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executed successfully"); return Error::Ok; From 7bf57300e22abb41ce67006f805d22a7628af5af Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 8 Dec 2025 15:30:44 -0500 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 9 +++- 
.../apple/metal/runtime/shims/et_metal.mm | 29 +++++++++---- .../apple/metal/runtime/shims/et_metal_ops.mm | 43 +++++++++++++------ 3 files changed, 57 insertions(+), 24 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index f00cde5d22f..1c61499b242 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -199,8 +199,13 @@ class ETMetalKernelFunction { size_t group_size_size); // Dispatch with explicit threadgroup count (not thread count) - void dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ, - uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ); + void dispatchThreadgroups( + uint64_t gridX, + uint64_t gridY, + uint64_t gridZ, + uint64_t threadsX, + uint64_t threadsY, + uint64_t threadsZ); void runCommandBlock(std::function f); diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index e6c919ba61f..f7d37c152ce 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -10,6 +10,7 @@ #import #import #import +#include #include #include #include @@ -423,14 +424,9 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev return; } - // Metal's uint3 is a packed struct of 3 uint32_t values - struct uint3 { - uint32_t x; - uint32_t y; - uint32_t z; - }; - uint3 val = {x, y, z}; - [encoder_ setBytes:&val length:sizeof(uint3) atIndex:idx]; + // Use SIMD library's uint3 type which matches Metal shader's uint3 layout + simd_uint3 val = {x, y, z}; + [encoder_ setBytes:&val length:sizeof(simd_uint3) atIndex:idx]; ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx); } @@ -566,6 +562,23 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev return; } + if (!cps_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No compute pipeline state"); + return; + } + + // Calculate total threads per threadgroup + uint64_t totalThreads = threadsX * threadsY * threadsZ; + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + // Validate total thread count + if (totalThreads > maxThreadsPerGroup) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Requested %llu total threads per threadgroup exceeds device maximum of %llu", + (unsigned long long)totalThreads, (unsigned long long)maxThreadsPerGroup); + return; + } + MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ); MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ); diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index 55d96aa9c49..da54dafb334 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -128,7 +128,7 @@ void logStats() { return R"( // Ported from PyTorch's Attention.metal // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mps/kernels/Attention.metal -// Largely influeneced by +// Largely influenced by // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/scaled_dot_product_attention.metal // Modified to support floating point masks and transposed middle dimensions (dims 1 & 2) @@ -138,9 +138,6 @@ void logStats() { using namespace metal; -typedef half float16_t; -typedef bfloat bfloat16_t; - // PyTorch's sdpa_vector kernel 
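A note on the dispatchThreadgroups() validation added above: the call takes a threadgroup count plus a threads-per-threadgroup size, and now rejects requests whose total threads per group exceed the pipeline's maxTotalThreadsPerThreadgroup. A standalone sketch of the caller-side arithmetic, assuming a hypothetical 1024-thread limit and made-up work sizes:

#include <cstdint>
#include <cstdio>

// Ceiling division: how many threadgroups of `group` threads cover `n` items.
static uint64_t ceil_div(uint64_t n, uint64_t group) {
  return (n + group - 1) / group;
}

int main() {
  const uint64_t max_threads_per_group = 1024;  // device/pipeline dependent
  const uint64_t threadsX = 32, threadsY = 8, threadsZ = 1;

  if (threadsX * threadsY * threadsZ > max_threads_per_group) {
    std::printf("request would be rejected, pick a smaller threadgroup\n");
    return 1;
  }
  const uint64_t rows = 4096;  // hypothetical number of rows to process
  std::printf("threadgroups along Y: %llu\n",
              (unsigned long long)ceil_div(rows, threadsY));
  return 0;
}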
(one-pass variant) template [[kernel]] void sdpa_vector( @@ -334,7 +331,6 @@ void logStats() { INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128); INSTANTIATE_SDPA_VECTOR_HEADS(float); -INSTANTIATE_SDPA_VECTOR_HEADS(half); INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); )"; } @@ -342,11 +338,13 @@ void logStats() { // Global shader library cache for SDPA static std::unique_ptr sdpa_shader_library = nullptr; +static std::once_flag sdpa_shader_library_once_flag; + static ETMetalShaderLibrary* get_sdpa_shader_library() { - if (!sdpa_shader_library) { + std::call_once(sdpa_shader_library_once_flag, []() { std::string source = get_sdpa_metal_source(); sdpa_shader_library = std::make_unique(source); - } + }); return sdpa_shader_library.get(); } @@ -1165,6 +1163,19 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( return Error::InvalidArgument; } + if (is_causal) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: is_causal=True not implemented"); + return Error::NotImplemented; + } + if (dropout_p != 0.0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: dropout_p != 0 not implemented (dropout_p=%f)", dropout_p); + return Error::NotImplemented; + } + if (dropout_mask && *dropout_mask) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: dropout_mask provided not implemented"); + return Error::NotImplemented; + } + // Use the same dispatch pattern as other MPS operations for consistent synchronization ETMetalStream* stream = getCurrentMetalStream(); if (!stream) { @@ -1182,7 +1193,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Converted tensor handles to ET tensors"); // Log query tensor shape and strides - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor - dim=%d, shape=[%lld, %lld, %lld, %lld], strides=[%lld, %lld, %lld, %lld]", + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor - dim=%d, shape=[%d, %d, %d, %d], strides=[%d, %d, %d, %d]", (int)query_tensor->dim(), query_tensor->dim() > 0 ? query_tensor->sizes()[0] : 0, query_tensor->dim() > 1 ? query_tensor->sizes()[1] : 0, @@ -1194,7 +1205,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( query_tensor->dim() > 3 ? query_tensor->strides()[3] : 0); // Log key tensor shape and strides - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor - dim=%d, shape=[%lld, %lld, %lld, %lld], strides=[%lld, %lld, %lld, %lld]", + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor - dim=%d, shape=[%d, %d, %d, %d], strides=[%d, %d, %d, %d]", (int)key_tensor->dim(), key_tensor->dim() > 0 ? key_tensor->sizes()[0] : 0, key_tensor->dim() > 1 ? key_tensor->sizes()[1] : 0, @@ -1206,7 +1217,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( key_tensor->dim() > 3 ? key_tensor->strides()[3] : 0); // Log value tensor shape and strides - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor - dim=%d, shape=[%lld, %lld, %lld, %lld], strides=[%lld, %lld, %lld, %lld]", + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor - dim=%d, shape=[%d, %d, %d, %d], strides=[%d, %d, %d, %d]", (int)value_tensor->dim(), value_tensor->dim() > 0 ? value_tensor->sizes()[0] : 0, value_tensor->dim() > 1 ? 
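On the std::call_once change above: the previous null check of the cached SDPA shader library was not thread-safe if two threads raced on the first call, whereas call_once guarantees the construction runs exactly once and later callers just get the cached pointer. A self-contained sketch of the same pattern with placeholder types (ShaderLibrary and its source string stand in for ETMetalShaderLibrary and the real Metal source):

#include <memory>
#include <mutex>
#include <string>

struct ShaderLibrary {
  explicit ShaderLibrary(std::string src) : source(std::move(src)) {}
  std::string source;
};

static std::unique_ptr<ShaderLibrary> g_library;
static std::once_flag g_library_once;

// Thread-safe lazy initialization: the lambda body runs at most once.
static ShaderLibrary* get_library() {
  std::call_once(g_library_once, []() {
    g_library = std::make_unique<ShaderLibrary>("/* metal source */");
  });
  return g_library.get();
}

int main() {
  return get_library() == get_library() ? 0 : 1;
}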
value_tensor->sizes()[1] : 0, @@ -1256,6 +1267,13 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( throw std::runtime_error("headSize must be non-zero for scaled dot product attention"); } + // Validate key tensor head dimension to avoid division by zero in gqa_factor calculation + int64_t key_num_heads = key_tensor->sizes()[1]; + if (key_num_heads == 0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: key tensor head dimension (sizes()[1]) is zero"); + throw std::runtime_error("key tensor must have non-zero head dimension for scaled dot product attention"); + } + // Calculate scale factor double scale_factor = scale ? *scale : (1.0 / sqrt(static_cast(headSize))); ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor); @@ -1416,7 +1434,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( uint mask_kv_seq_stride = 0; uint mask_q_seq_stride = 0; if (has_mask_val) { - auto* mask_tensor = reinterpret_cast(*attn_mask); + auto* mask_tensor = reinterpret_cast(*attn_mask); int nd = mask_tensor->dim(); mask_kv_seq_stride = (nd >= 1 && mask_tensor->sizes()[nd - 1] > 1) ? static_cast(mask_tensor->strides()[nd - 1]) : 0; mask_q_seq_stride = (nd >= 2 && mask_tensor->sizes()[nd - 2] > 1) ? static_cast(mask_tensor->strides()[nd - 2]) : 0; @@ -1488,9 +1506,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Command block completed"); - // Create attention weights tensor handle (zero-filled) - std::memset(attn_contents_ptr, 0, attn_size_bytes); - AOTITensorHandle attn_tensor_handle = nullptr; AOTITorchError create_attn_result = aoti_torch_create_tensor_from_blob_v2( attn_contents_ptr,
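For reference on the two guards above (non-zero headSize and non-zero key head count): the kernel presumably derives a grouped-query factor from the ratio of query heads to key/value heads, and falls back to 1/sqrt(headSize) when no explicit scale is passed, so both denominators must be non-zero. A standalone sketch with made-up head counts; the exact gqa_factor expression is an assumption here, not taken from this patch.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_heads = 32;      // query heads (example values)
  const int64_t key_num_heads = 8;   // key/value heads
  const int64_t headSize = 128;

  if (key_num_heads == 0 || headSize == 0) {
    std::printf("invalid attention configuration\n");
    return 1;
  }
  // Assumed grouping: how many query heads share one key/value head.
  const int64_t gqa_factor = num_heads / key_num_heads;
  // Default softmax scale when the caller supplies no explicit scale.
  const double scale_factor = 1.0 / std::sqrt(static_cast<double>(headSize));
  std::printf("gqa_factor=%lld scale=%f\n", (long long)gqa_factor, scale_factor);
  return 0;
}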