diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index a7f0f88cd9a..387d45bd3f5 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -106,7 +106,7 @@ def __init__(

         # Note: import this after portable_lib
         from executorch.extension.llm.custom_ops import (  # noqa
-            sdpa_with_kv_cache,  # usort: skip
+            custom_ops,  # usort: skip
         )
         from executorch.kernels import quantized  # noqa
diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py
index 62757506f3b..a6b055ced95 100644
--- a/examples/models/llama/runner/native.py
+++ b/examples/models/llama/runner/native.py
@@ -23,7 +23,7 @@
 from executorch.examples.models.llama.runner.generation import LlamaRunner

 # Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index f8362648f32..65c5b68f7ad 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -99,7 +99,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):


 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa

     _replace_sdpa_with_custom_op(module)
     return module
diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py
index 2e50bcecf49..5fd60399415 100644
--- a/examples/models/llava/test/test_llava.py
+++ b/examples/models/llava/test/test_llava.py
@@ -18,7 +18,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa # usort: skip

 logging.basicConfig(level=logging.INFO)
diff --git a/examples/models/llava/test/test_pte.py b/examples/models/llava/test/test_pte.py
index 003b2b56755..f12d72f854b 100644
--- a/examples/models/llava/test/test_pte.py
+++ b/examples/models/llava/test/test_pte.py
@@ -14,7 +14,7 @@
 from PIL import Image

 # Custom ops has to be loaded after portable_lib.
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa # usort: skip

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
diff --git a/extension/llm/README.md b/extension/llm/README.md
index ad504966824..0f71088eea1 100644
--- a/extension/llm/README.md
+++ b/extension/llm/README.md
@@ -38,7 +38,7 @@ A sampler class in C++ to sample the logistics given some hyperparameters.
 ## custom_ops
 Contains custom op, such as:
 - custom sdpa: implements CPU flash attention and avoids copies by taking the kv cache as one of its arguments.
-  - _sdpa_with_kv_cache.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
+  - _custom_ops.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
   - _op_sdpa.cpp_: the optimized operator implementation and registration of _sdpa_with_kv_cache.out_.

 ## runner
diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/custom_ops.py
similarity index 99%
rename from extension/llm/custom_ops/sdpa_with_kv_cache.py
rename to extension/llm/custom_ops/custom_ops.py
index be71425582c..26dac551a30 100644
--- a/extension/llm/custom_ops/sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/custom_ops.py
@@ -17,7 +17,6 @@

 from torch.library import impl

-# TODO rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index bb59f48a279..e3e8b30520f 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -81,7 +81,7 @@ def define_common_targets():
     runtime.python_library(
         name = "custom_ops_aot_py",
         srcs = [
-            "sdpa_with_kv_cache.py",
+            "custom_ops.py",
         ],
         visibility = [
             "//executorch/...",
diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
index bfd64cb8975..9c8029c7b70 100644
--- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -11,7 +11,7 @@
 import torch
 import torch.nn.functional as F

-from .sdpa_with_kv_cache import custom_ops_lib  # noqa
+from .custom_ops import custom_ops_lib  # noqa


 def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
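
For downstream callers, only the Python module name changes; the operator itself is still registered as `llama.sdpa_with_kv_cache`. The following is a minimal sketch of the post-rename import pattern used across the touched files, assuming an ExecuTorch build where the portable pybindings and the LLM custom-ops library are available (module paths are taken from the hunks above):

```python
# Sketch of the import order the updated files rely on (assumption: executorch
# is installed with the portable pybindings and the LLM custom-ops library built).
import torch

# Load the portable runtime bindings first ...
from executorch.extension.pybindings.portable_lib import (  # noqa
    _load_for_executorch_from_buffer,
)

# ... then the renamed custom-ops module and the quantized kernels, which
# register their operators against the already-loaded runtime.
from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
from executorch.kernels import quantized  # noqa # usort: skip

# Only the Python file was renamed; the op keeps its original name.
assert torch.ops.llama.sdpa_with_kv_cache.default is not None
```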