diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index 7e1fa66ac7c..b150c68fe6d 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -32,8 +32,69 @@
 // Declare the global mapping from et_metal.mm
 extern std::unordered_map<void*, id<MTLBuffer>> ptr_to_mtl_buffer;
 
+// =======================
+// MPSGraph Caching Infrastructure
+// =======================
+
 namespace {
 
+// Cache key structure for different operations
+struct GraphCacheKey {
+  std::string op_name;
+  std::vector<int64_t> shape_params;
+  int32_t dtype;
+  bool transpose_flag;
+
+  bool operator==(const GraphCacheKey& other) const {
+    return op_name == other.op_name &&
+           shape_params == other.shape_params &&
+           dtype == other.dtype &&
+           transpose_flag == other.transpose_flag;
+  }
+};
+
+// Hash function for GraphCacheKey
+struct GraphCacheKeyHash {
+  std::size_t operator()(const GraphCacheKey& key) const {
+    std::size_t hash = std::hash<std::string>{}(key.op_name);
+    for (auto val : key.shape_params) {
+      hash ^= std::hash<int64_t>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+    }
+    hash ^= std::hash<int32_t>{}(key.dtype) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+    hash ^= std::hash<bool>{}(key.transpose_flag) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+    return hash;
+  }
+};
+
+// Struct to store both the compiled graph and its tensors for reuse
+struct CachedGraph {
+  MPSGraph* graph;
+  MPSGraphTensor* input1;
+  MPSGraphTensor* input2;
+  MPSGraphTensor* input3; // Optional (e.g., bias, mask)
+  MPSGraphTensor* output;
+};
+
+// Global cache for compiled MPSGraphs
+// These graphs are never released - they're reused across calls
+static std::unordered_map<GraphCacheKey, CachedGraph, GraphCacheKeyHash> graph_cache;
+
+// Statistics for monitoring cache effectiveness
+struct CacheStats {
+  size_t hits = 0;
+  size_t misses = 0;
+
+  void logStats() {
+    if ((hits + misses) % 100 == 0 && (hits + misses) > 0) {
+      double hit_rate = 100.0 * hits / (hits + misses);
+      ET_LOG(Debug, "MPSGraph cache stats: %zu hits, %zu misses (%.1f%% hit rate)",
+             hits, misses, hit_rate);
+    }
+  }
+};
+
+static CacheStats cache_stats;
+
 // Helper function to get Metal buffer from the global mapping
 static id<MTLBuffer> get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) {
   void* data_ptr = tensor->mutable_data_ptr();
@@ -61,7 +122,7 @@
   return it->second;
 }
 
-} // namespace
+} // anonymous namespace
 
 extern "C" {
 
@@ -180,13 +241,8 @@ AOTITorchError aoti_torch_mps_mm_out(
   ET_LOG(Debug, "aoti_torch_mps_mm_out: dtype=%d, element_size=%zu", dtype, element_size);
   ET_LOG(Debug, "aoti_torch_mps_mm_out: M=%lld, K=%lld, N=%lld", M, K, N);
 
-  // Create MPSGraph for matrix multiplication
-  MPSGraph* mpsGraph = [MPSGraph new];
-  ET_LOG(Debug, "aoti_torch_mps_mm_out: Created MPSGraph instance");
-
-  // Define tensor shapes for placeholders
+  // Define tensor shapes for placeholders (needed for both cache hit and miss)
   NSArray* selfShape = @[@(M), @(K)];
-  NSArray* outShape = @[@(M), @(N)];
 
   // For mat2, we need to handle both contiguous and transposed cases
   // If mat2 is transposed, its physical layout in memory is [N, K] (column-major)
@@ -202,43 +258,91 @@ AOTITorchError aoti_torch_mps_mm_out(
     ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (contiguous): [%d,%d]", (int)K, (int)N);
   }
 
-  ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]",
-         (int)M, (int)K,
-         mat2_is_transposed ? (int)N : (int)K,
-         mat2_is_transposed ? (int)K : (int)N);
+  // Create cache key for this matrix multiplication
+  GraphCacheKey cache_key;
+  cache_key.op_name = "mm";
+  cache_key.shape_params = {M, K, N};
+  cache_key.dtype = dtype;
+  cache_key.transpose_flag = mat2_is_transposed;
+
+  // Check if we have a cached graph
+  MPSGraph* mpsGraph = nullptr;
+  MPSGraphTensor* mmOutput = nil;
+  MPSGraphTensor* selfPlaceholder = nil;
+  MPSGraphTensor* mat2Placeholder = nil;
+
+  auto cache_it = graph_cache.find(cache_key);
+  if (cache_it != graph_cache.end()) {
+    // Cache hit - reuse compiled graph and tensor references
+    CachedGraph& cached = cache_it->second;
+    mpsGraph = cached.graph;
+    selfPlaceholder = cached.input1;
+    mat2Placeholder = cached.input2;
+    mmOutput = cached.output;
+
+    cache_stats.hits++;
+    cache_stats.logStats();
+    ET_LOG(Debug, "aoti_torch_mps_mm_out: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
 
-  // Create placeholders for input tensors
-  MPSGraphTensor* selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
-                                                          dataType:mps_dtype
-                                                              name:@"self"];
-  MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape
-                                                          dataType:mps_dtype
-                                                              name:@"mat2_physical"];
+  } else {
+    // Cache miss - create and compile new graph
+    mpsGraph = [MPSGraph new];
+    cache_stats.misses++;
+    cache_stats.logStats();
+    ET_LOG(Debug, "aoti_torch_mps_mm_out: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);
+
+    ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]",
+           (int)M, (int)K,
+           mat2_is_transposed ? (int)N : (int)K,
+           mat2_is_transposed ? (int)K : (int)N);
+
+    // Create placeholders for input tensors
+    selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
+                                            dataType:mps_dtype
+                                                name:@"self"];
+    mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape
+                                            dataType:mps_dtype
+                                                name:@"mat2_physical"];
+
+    ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders");
+
+    // If mat2 is transposed, apply transpose operation in the graph to get the logical shape
+    MPSGraphTensor* mat2Logical;
+    if (mat2_is_transposed) {
+      // Transpose from physical [N, K] to logical [K, N]
+      // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors
+      mat2Logical = [mpsGraph transposeTensor:mat2Placeholder
+                                    dimension:-2
+                                withDimension:-1
+                                         name:@"mat2_transposed"];
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph");
+    } else {
+      // No transpose needed, use placeholder directly
+      mat2Logical = mat2Placeholder;
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)");
+    }
 
-  ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders");
+    // Perform matrix multiplication using MPSGraph with the logical mat2 tensor
+    mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
+                                               secondaryTensor:mat2Logical
+                                                          name:@"matrix_multiplication"];
 
-  // If mat2 is transposed, apply transpose operation in the graph to get the logical shape
-  MPSGraphTensor* mat2Logical;
-  if (mat2_is_transposed) {
-    // Transpose from physical [N, K] to logical [K, N]
-    // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors
-    mat2Logical = [mpsGraph transposeTensor:mat2Placeholder
-                                  dimension:-2
-                              withDimension:-1
-                                       name:@"mat2_transposed"];
-    ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph");
-  } else {
-    // No transpose needed, use placeholder directly
-    mat2Logical = mat2Placeholder;
ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)"); - } + ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor"); + + // Cache the compiled graph and tensor references for reuse + CachedGraph cached_graph; + cached_graph.graph = mpsGraph; + cached_graph.input1 = selfPlaceholder; + cached_graph.input2 = mat2Placeholder; + cached_graph.input3 = nil; + cached_graph.output = mmOutput; + graph_cache[cache_key] = cached_graph; - // Perform matrix multiplication using MPSGraph with the logical mat2 tensor - MPSGraphTensor* mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder - secondaryTensor:mat2Logical - name:@"matrix_multiplication"]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Cached compiled MPSGraph for future reuse"); + } // End of cache miss/hit block - ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor"); + // Define output shape + NSArray* outShape = @[@(M), @(N)]; // Create feeds dictionary for graph execution NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; @@ -279,10 +383,6 @@ AOTITorchError aoti_torch_mps_mm_out( ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); - // Release MPSGraph to prevent memory leak - [mpsGraph release]; - mpsGraph = nil; - [selfData release]; [mat2Data release]; [outputData release]; @@ -502,106 +602,150 @@ AOTITorchError aoti_torch_mps_convolution( ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); - // Create MPSGraph for convolution - MPSGraph* mpsGraph = [MPSGraph new]; - ET_LOG(Debug, "aoti_torch_mps_convolution: Created MPSGraph instance"); - - // Define tensor shapes for placeholders (always 4D NCHW for MPSGraph) + // Define tensor shapes for placeholders (needed for both cache hit and miss) NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)]; NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)]; - ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]", - (int)N, (int)C_in, (int)H_in, (int)W_in, - (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w); - - // Create placeholders for input tensors - MPSGraphTensor* inputPlaceholder = [mpsGraph placeholderWithShape:inputShape - dataType:mps_dtype - name:@"input"]; - MPSGraphTensor* weightPlaceholder = [mpsGraph placeholderWithShape:weightShape - dataType:mps_dtype - name:@"weight"]; - - ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders"); - - // Create convolution descriptor - MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w - strideInY:stride_h - dilationRateInX:dil_w - dilationRateInY:dil_h - groups:groups - paddingLeft:pad_w - paddingRight:pad_w - paddingTop:pad_h - paddingBottom:pad_h - paddingStyle:MPSGraphPaddingStyleExplicit - dataLayout:MPSGraphTensorNamedDataLayoutNCHW - weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; - - ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld", - stride_w, stride_h, pad_w, pad_h, dil_w, dil_h, groups); - - // Perform convolution using MPSGraph + // Create cache key for this convolution + GraphCacheKey cache_key; + cache_key.op_name = "conv"; + cache_key.shape_params = {N, C_in, H_in, W_in, C_out, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dil_h, 
+  cache_key.dtype = dtype;
+  cache_key.transpose_flag = (transposed != 0);
+
+  // Check if we have a cached graph
+  MPSGraph* mpsGraph = nullptr;
   MPSGraphTensor* convOutput = nil;
-  if (transposed) {
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution");
-    // For transposed convolution, we need to handle output padding
-    int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0;
-    int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? output_padding[1] : 0;
-
-    // For transposed convolution, we need to adjust the padding calculation
-    // In transposed convolution, the effective padding is typically negative
-    // and we use output_padding to control the final output size
-    int64_t transposed_pad_h = pad_h - output_pad_h;
-    int64_t transposed_pad_w = pad_w - output_pad_w;
-
-    // Create transposed convolution descriptor with adjusted padding
-    MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
-                                                                                                             strideInY:stride_h
-                                                                                                       dilationRateInX:dil_w
-                                                                                                       dilationRateInY:dil_h
-                                                                                                                groups:groups
-                                                                                                           paddingLeft:transposed_pad_w
-                                                                                                          paddingRight:transposed_pad_w
-                                                                                                            paddingTop:transposed_pad_h
-                                                                                                         paddingBottom:transposed_pad_h
-                                                                                                          paddingStyle:MPSGraphPaddingStyleExplicit
-                                                                                                            dataLayout:MPSGraphTensorNamedDataLayoutNCHW
-                                                                                                         weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
-
-    convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
-                                           weightsTensor:weightPlaceholder
-                                              descriptor:transposedConvDesc
-                                                    name:@"transposed_convolution"];
+  MPSGraphTensor* finalOutput = nil;
+  MPSGraphTensor* inputPlaceholder = nil;
+  MPSGraphTensor* weightPlaceholder = nil;
+  MPSGraphTensor* biasPlaceholder = nil;
+  bool has_bias = (bias_tensor != nullptr);
+
+  auto cache_it = graph_cache.find(cache_key);
+  if (cache_it != graph_cache.end()) {
+    // Cache hit - reuse compiled graph and tensor references
+    CachedGraph& cached = cache_it->second;
+    mpsGraph = cached.graph;
+    inputPlaceholder = cached.input1;
+    weightPlaceholder = cached.input2;
+    biasPlaceholder = cached.input3; // May be nil if no bias
+    finalOutput = cached.output;
+
+    cache_stats.hits++;
+    cache_stats.logStats();
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
+  } else {
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution");
-    convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
-                                           weightsTensor:weightPlaceholder
-                                              descriptor:convDesc
-                                                    name:@"convolution"];
-  }
+    // Cache miss - create and compile new graph
+    mpsGraph = [MPSGraph new];
+    cache_stats.misses++;
+    cache_stats.logStats();
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);
+
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]",
+           (int)N, (int)C_in, (int)H_in, (int)W_in,
+           (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w);
+
+    // Create placeholders for input tensors
+    inputPlaceholder = [mpsGraph placeholderWithShape:inputShape
+                                             dataType:mps_dtype
+                                                 name:@"input"];
+    weightPlaceholder = [mpsGraph placeholderWithShape:weightShape
+                                              dataType:mps_dtype
+                                                  name:@"weight"];
+
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders");
+
+    // Create convolution descriptor
+    MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
+                                                                                                   strideInY:stride_h
+                                                                                             dilationRateInX:dil_w
+                                                                                             dilationRateInY:dil_h
+                                                                                                      groups:groups
+                                                                                                 paddingLeft:pad_w
+                                                                                                paddingRight:pad_w
+                                                                                                  paddingTop:pad_h
+                                                                                               paddingBottom:pad_h
+                                                                                                paddingStyle:MPSGraphPaddingStyleExplicit
+                                                                                                  dataLayout:MPSGraphTensorNamedDataLayoutNCHW
+                                                                                               weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
+
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld",
+           stride_w, stride_h, pad_w, pad_h, dil_w, dil_h, groups);
+
+    // Perform convolution using MPSGraph
+    if (transposed) {
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution");
+      // For transposed convolution, we need to handle output padding
+      int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0;
+      int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? output_padding[1] : 0;
+
+      // For transposed convolution, we need to adjust the padding calculation
+      // In transposed convolution, the effective padding is typically negative
+      // and we use output_padding to control the final output size
+      int64_t transposed_pad_h = pad_h - output_pad_h;
+      int64_t transposed_pad_w = pad_w - output_pad_w;
+
+      // Create transposed convolution descriptor with adjusted padding
+      MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
+                                                                                                               strideInY:stride_h
+                                                                                                         dilationRateInX:dil_w
+                                                                                                         dilationRateInY:dil_h
+                                                                                                                  groups:groups
+                                                                                                             paddingLeft:transposed_pad_w
+                                                                                                            paddingRight:transposed_pad_w
+                                                                                                              paddingTop:transposed_pad_h
+                                                                                                           paddingBottom:transposed_pad_h
+                                                                                                            paddingStyle:MPSGraphPaddingStyleExplicit
+                                                                                                              dataLayout:MPSGraphTensorNamedDataLayoutNCHW
+                                                                                                           weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
+
+      convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
+                                             weightsTensor:weightPlaceholder
+                                                descriptor:transposedConvDesc
+                                                      name:@"transposed_convolution"];
+    } else {
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution");
+      convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
+                                             weightsTensor:weightPlaceholder
+                                                descriptor:convDesc
+                                                      name:@"convolution"];
+    }
 
-  ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor");
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor");
 
-  // Handle bias if provided
-  MPSGraphTensor* finalOutput = convOutput;
-  MPSGraphTensor* biasPlaceholder = nil;
-  if (bias_tensor) {
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output");
+    // Handle bias if provided
+    if (bias_tensor) {
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output");
 
-    // Create bias placeholder
-    NSArray* biasShape = @[@(C_out)];
-    biasPlaceholder = [mpsGraph placeholderWithShape:biasShape
-                                            dataType:mps_dtype
-                                                name:@"bias"];
+      // Create bias placeholder
+      NSArray* biasShape = @[@(C_out)];
+      biasPlaceholder = [mpsGraph placeholderWithShape:biasShape
+                                              dataType:mps_dtype
+                                                  name:@"bias"];
 
-    // Add bias to convolution output
-    finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput
-                                      secondaryTensor:biasPlaceholder
-                                                 name:@"add_bias"];
+      // Add bias to convolution output
+      finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput
+                                        secondaryTensor:biasPlaceholder
+                                                   name:@"add_bias"];
 
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph");
-  }
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph");
+    } else {
+      finalOutput = convOutput;
+    }
+
+    // Cache the compiled graph and tensor references for reuse
+    CachedGraph cached_graph;
+    cached_graph.graph = mpsGraph;
+    cached_graph.input1 = inputPlaceholder;
+    cached_graph.input2 = weightPlaceholder;
+    cached_graph.input3 = biasPlaceholder; // May be nil if no bias
+    cached_graph.output = finalOutput;
+    graph_cache[cache_key] = cached_graph;
+
+    ET_LOG(Debug, "aoti_torch_mps_convolution: Cached compiled MPSGraph for future reuse");
+  } // End of cache miss block
 
   // Create feeds dictionary for graph execution
   NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
@@ -748,15 +892,12 @@ AOTITorchError aoti_torch_mps_convolution(
   // Store the tensor handle - mark that we own the memory since we manually allocated it
   *ret0 = output_tensor_handle;
 
+  // Mark that we own the memory for these tensors
   // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
   // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
   extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
   memory_to_n_tensor[tensor_data] = 1;
 
-  // Release MPSGraph to prevent memory leak
-  [mpsGraph release];
-  mpsGraph = nil;
-
   [inputData release];
   [weightData release];
   if (biasData) [biasData release];
@@ -1012,6 +1153,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
 
   @try {
     // Create MPSGraph for scaled dot product attention
+    // TODO: Implement caching for attention operation similar to mm and convolution
     MPSGraph* mpsGraph = [MPSGraph new];
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
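
Note on the key/hash scheme: GraphCacheKeyHash folds each field into the running hash with the boost-style hash_combine constant 0x9e3779b9, and lookups rely on operator== for exact shape/dtype/flag equality, so a hash collision only costs an extra comparison. Below is a minimal standalone sketch of the same pattern, in standard C++ only; Key, KeyHash, and the int payload are illustrative stand-ins for GraphCacheKey, GraphCacheKeyHash, and CachedGraph, not part of this diff.

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for GraphCacheKey: op name + shape params + dtype + flag.
struct Key {
  std::string op_name;
  std::vector<int64_t> shape_params;
  int32_t dtype;
  bool transpose_flag;

  bool operator==(const Key& other) const {
    return op_name == other.op_name && shape_params == other.shape_params &&
        dtype == other.dtype && transpose_flag == other.transpose_flag;
  }
};

// Same boost-style hash_combine pattern as GraphCacheKeyHash.
struct KeyHash {
  std::size_t operator()(const Key& key) const {
    std::size_t h = std::hash<std::string>{}(key.op_name);
    auto mix = [&h](std::size_t v) { h ^= v + 0x9e3779b9 + (h << 6) + (h >> 2); };
    for (int64_t v : key.shape_params) {
      mix(std::hash<int64_t>{}(v));
    }
    mix(std::hash<int32_t>{}(key.dtype));
    mix(std::hash<bool>{}(key.transpose_flag));
    return h;
  }
};

int main() {
  // int stands in for CachedGraph; entries are inserted once and reused forever.
  std::unordered_map<Key, int, KeyHash> cache;

  Key mm{"mm", {64, 128, 256}, /*dtype=*/6, /*transpose_flag=*/false};
  cache[mm] = 42;  // first call: cache miss, compile and insert

  Key same{"mm", {64, 128, 256}, 6, false};
  Key other{"mm", {64, 128, 512}, 6, false};
  std::cout << "same shapes hit:  " << (cache.count(same) == 1) << "\n";   // 1
  std::cout << "new shape misses: " << (cache.count(other) == 0) << "\n";  // 1
  return 0;
}

One consequence worth noting: because keys are exact shape tuples, every new M/K/N (or conv shape) combination compiles and permanently retains a fresh graph, which is why the [mpsGraph release] calls are removed above - cached graphs are deliberately kept alive for reuse.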
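
The TODO in aoti_torch_mps__scaled_dot_product_attention_math_for_mps points at extending the same mechanism to attention. A hedged sketch of what the key construction might look like; every local here (batch, num_heads, seq_len_q, seq_len_kv, head_dim, has_mask) is a hypothetical name for illustration, not something this diff defines:

// Hypothetical sketch only - mirrors the mm/conv caching above.
GraphCacheKey cache_key;
cache_key.op_name = "sdpa";
// Assumed local shape variables for the attention inputs.
cache_key.shape_params = {batch, num_heads, seq_len_q, seq_len_kv, head_dim};
cache_key.dtype = dtype;
cache_key.transpose_flag = has_mask;  // reuse the flag bit to separate masked graphs

auto cache_it = graph_cache.find(cache_key);
if (cache_it != graph_cache.end()) {
  // Cache hit: input3 would carry the optional mask placeholder, matching
  // the "Optional (e.g., bias, mask)" comment on CachedGraph::input3.
} else {
  // Cache miss: build Q/K/V (and mask) placeholders, assemble the attention
  // subgraph once, then store graph + placeholders in graph_cache as above.
}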