Skip to content

Commit

Permalink
[AOTInductor] Use ProxyExecutor for aten op if c-shim is missing (#11…
Browse files Browse the repository at this point in the history
…3918)

Summary:

As discussed in the meeting, we are inverting the policy on the use of proxy executor for aten fallbacks.
By default, aten fallback ops will use proxy executor, unless a c-shim is available.

Added test for aten.sort and aten.index.Tensor, as they are now runnable with proxy executor.

Test Plan: CIs

Reviewed By: chenyang78

Differential Revision: D51417683
  • Loading branch information
SherlockNoMad authored and facebook-github-bot committed Nov 17, 2023
1 parent 6849d75 commit ccb4dcf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ def forward(self, x):
example_inputs = (torch.randn(8, 4, 4, device=self.device),)
self.check_model(Model(), example_inputs)

@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_dup_unbacked_sym_decl(self):
class Model(torch.nn.Module):
def __init__(self):
Expand Down
18 changes: 14 additions & 4 deletions torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4189,8 +4189,14 @@ class ExternKernelNode:
node: export_schema.Node


fbcode_use_proxy_executor = {
torch.ops.aten._scaled_dot_product_efficient_attention.default,
has_c_shim = {
aten._scaled_dot_product_flash_attention.default,
aten.addmm.out,
aten.bmm.out,
aten.mm.out,
aten._scaled_mm.default,
aten.repeat_interleave.Tensor,
aten.nonzero.default,
}


Expand Down Expand Up @@ -4394,7 +4400,7 @@ def handle_single_output(return_type, output):
node = ExternKernelNode(
name=self.get_name(),
node=export_schema.Node(
target=self.cpp_kernel,
target=self.op_overload.name(),
inputs=named_arguments,
outputs=output_arguments,
metadata={},
Expand All @@ -4413,7 +4419,11 @@ def codegen(self, wrapper):
op_base_name = kernel.__name__.split(".")[0]

if V.graph.cpp_wrapper:
if config.is_fbcode() and kernel in fbcode_use_proxy_executor:
if config.is_fbcode() and kernel not in has_c_shim:
log.warning(
"%s is missing a c-shim implementation, using proxy executor as fallback",
kernel,
)
self.use_runtime_dispatch = True
self.set_cpp_kernel(kernel)
else:
Expand Down

0 comments on commit ccb4dcf

Please sign in to comment.