Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 9 additions & 24 deletions backends/aoti/aoti_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import typing
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Set

import torch
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
Expand Down Expand Up @@ -91,39 +91,24 @@ def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]
)

def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
self,
kernel: str,
args: list[str],
device: str,
*,
debug_args: Optional[list[str]] = None,
debug_handle: Optional[int] = None,
):
self, kernel: str, *args: Any, **kwargs: Any
) -> None:
if kernel not in supported_kernels:
missing_fallback_kernels.add(kernel)

original_generate_c_shim_extern_kernel_call(
self,
kernel,
args,
device,
debug_args=debug_args,
debug_handle=debug_handle,
return original_generate_c_shim_extern_kernel_call(
self, kernel, *args, **kwargs
)

def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
self,
op_overload,
raw_args,
output_args,
raw_outputs,
):
self, op_overload: Any, *args: Any, **kwargs: Any
) -> None:
kernel_name = getattr(op_overload, "_name", str(op_overload))
if kernel_name not in supported_kernels:
missing_fallback_kernels.add(kernel_name)

original_generate_fallback_kernel_with_runtime_lookup_aot(
self, op_overload, raw_args, output_args, raw_outputs
return original_generate_fallback_kernel_with_runtime_lookup_aot(
self, op_overload, *args, **kwargs
)

CppWrapperCpu.generate_c_shim_extern_kernel_call = (
Expand Down
Loading