diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 2d396a296bd..7742cd53c2d 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -9,7 +9,7 @@ import typing from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Set import torch from executorch.backends.aoti.passes.replace_view_copy_with_view import ( @@ -91,39 +91,24 @@ def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str] ) def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - debug_handle: Optional[int] = None, - ): + self, kernel: str, *args: Any, **kwargs: Any + ) -> None: if kernel not in supported_kernels: missing_fallback_kernels.add(kernel) - original_generate_c_shim_extern_kernel_call( - self, - kernel, - args, - device, - debug_args=debug_args, - debug_handle=debug_handle, + return original_generate_c_shim_extern_kernel_call( + self, kernel, *args, **kwargs ) def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( - self, - op_overload, - raw_args, - output_args, - raw_outputs, - ): + self, op_overload: Any, *args: Any, **kwargs: Any + ) -> None: kernel_name = getattr(op_overload, "_name", str(op_overload)) if kernel_name not in supported_kernels: missing_fallback_kernels.add(kernel_name) - original_generate_fallback_kernel_with_runtime_lookup_aot( - self, op_overload, raw_args, output_args, raw_outputs + return original_generate_fallback_kernel_with_runtime_lookup_aot( + self, op_overload, *args, **kwargs ) CppWrapperCpu.generate_c_shim_extern_kernel_call = (