From 388a062d1599d01a9b446d0948286d7960279e8a Mon Sep 17 00:00:00 2001 From: drisspg Date: Sun, 18 Aug 2024 17:32:02 -0700 Subject: [PATCH] Update fused kernels and call _safe_softmax from SDPA (#4772) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4772 X-link: https://github.com/pytorch/pytorch/pull/131863 cc ezyang gchanan jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 Valentine233 imported-using-ghimport Test Plan: Imported from OSS Reviewed By: Chillee Differential Revision: D61418679 Pulled By: drisspg --- .../coreml/test/test_coreml_partitioner.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index 45c468e450b..34cf531b261 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -68,15 +68,23 @@ def test_vit_skip_conv(self): ) ) + conv_block = ["aten.convolution.default", "executorch_call_delegate"] + safe_softmax_block = [ + "getitem", + "getitem", + "getitem", + "getitem", + "aten.any.dim", + "executorch_call_delegate", + ] + final_block = ["getitem"] + total = conv_block + 12 * safe_softmax_block + final_block + assert [ node.target.__name__ for node in delegated_program_manager.exported_program().graph.nodes if node.op == "call_function" - ] == [ - "aten.convolution.default", - "executorch_call_delegate", - "getitem", - ] + ] == total if __name__ == "__main__":