Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] support _scaled_dot_product_flash_attention fallback #110085

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 37 additions & 2 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,14 +551,49 @@ def forward(self, x, y):
constraints=constraints,
)

# scaled_dot_product_flash_attention
def test_sdpa(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, q, k, v):
return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]

example_inputs = (
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
)
self.check_model(Repro(), example_inputs)

def test_sdpa_2(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, q, k, v, x):
t = torch.nn.functional.scaled_dot_product_attention(
q, k, v, is_causal=True
)[0]
return x + t

example_inputs = (
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
)
self.check_model(Repro(), example_inputs)


class AOTInductorTestABICompatibile(TestCase):
class AOTInductorTestABICompatible(TestCase):
abi_compatible = True
check_model = check_model
check_model_with_multiple_inputs = check_model_with_multiple_inputs


copy_tests(AOTInductorTestsTemplate, AOTInductorTestABICompatibile, "abi_compatible")
copy_tests(AOTInductorTestsTemplate, AOTInductorTestABICompatible, "abi_compatible")


class AOTInductorTestNonABICompatible(TestCase):
Expand Down
76 changes: 59 additions & 17 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ def __init__(self):
self.last_seen_device_guard_index = None
self.supports_intermediate_hooks = True
self.expr_printer = pexpr
# Not all the dynamic symbols will be used in the generated code. This
# set contains those actually being defined by something like
# "{self.declare_shape} s0 = ...". It ensures that we are not going to
# emit queries for undefined symbols.
self.defined_symbols = set()

self.write_header()
self.write_prefix()
Expand Down Expand Up @@ -587,6 +592,7 @@ def is_expr(x):
for name, shape in graph_inputs_expr:
shape = V.graph.sizevars.simplify(shape)
if shape in needed:
self.defined_symbols.add(shape)
needed.remove(shape)
code.writeline(f"{self.declare}{shape} = {name}{self.ending}")

Expand All @@ -595,6 +601,7 @@ def is_expr(x):
for dim, shape in enumerate(shapes):
shape = V.graph.sizevars.simplify(shape)
if shape in needed:
self.defined_symbols.add(shape)
needed.remove(shape)
code.writeline(
f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
Expand All @@ -605,6 +612,7 @@ def is_expr(x):
for dim, shape in enumerate(shapes):
shape = V.graph.sizevars.simplify(shape)
if shape in needed:
self.defined_symbols.add(shape)
needed.remove(shape)
code.writeline(
f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
Expand All @@ -623,7 +631,7 @@ def codegen_python_sizevar(self, x: Expr) -> str:
def codegen_sizevar(self, x: Expr) -> str:
return self.codegen_python_sizevar(x)

def codegen_tuple_access(self, basename: str, index: str) -> str:
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
return f"{basename}[{index}]"

def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
Expand All @@ -646,6 +654,9 @@ def codegen_reinterpret_view(self, name, size, stride, offset, writer) -> str:
def codegen_device_copy(self, src, dst):
self.writeline(f"{dst}.copy_({src})")

def codegen_multi_output(self, name, value):
self.writeline(f"{self.declare}{name} = {value}{self.ending}")

def benchmark_compiled_module(self, output):
def add_fake_input(name, shape, stride, device, dtype):
output.writeline(
Expand Down Expand Up @@ -1191,7 +1202,11 @@ def write_wrapper_decl(self):

if V.graph.aot_mode:
self.prefix.writeline("inputs.clear();")
dynamic_symbols = V.graph.sizevars.free_symbols()
dynamic_symbols = [
s
for s in V.graph.sizevars.free_symbols()
if s in self.defined_symbols
]
for dim in dynamic_symbols:
self.prefix.writeline(
f'auto dim_{dim} = find_dynamic_dim("{dim}");'
Expand Down Expand Up @@ -1411,21 +1426,38 @@ def generate_c_shim_extern_kernel_call(self, kernel, args):
kernel = "aoti_torch_" + kernel.split("::")[-1]
self.writeline(f"AOTI_TORCH_ERROR_CODE_CHECK({kernel}({', '.join(args)}));")

def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
output_args = []
output_raii_handles = []
output_name_base = extern_kernel.get_name()
for idx, output in enumerate(extern_kernel.outputs):
if isinstance(output, ir.MultiOutput):
name = f"{output.get_name()}"
output_handle_name = f"{name}_handle"
assert (
output.indices[0][1] == idx
), f"expected {output.indices[1]=} == {idx=} for {output_name_base=}"
self.writeline(f"AtenTensorHandle {output_handle_name};")
output_args.append(f"&{output_handle_name}")
output_raii_handles.append(
f"RAIIAtenTensorHandle {name}({output_handle_name});"
)
elif isinstance(output, int):
output_name = f"{output_name_base}_{idx}"
self.writeline(f"int64_t {output_name} = {output};")
output_args.append(f"&{output_name}")
elif output is None:
output_args.append("nullptr")
else:
raise NotImplementedError("unsupported type of {output=}")
args = args + output_args
self.generate_c_shim_extern_kernel_call(extern_kernel.kernel, args)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)

def generate_extern_kernel_alloc(self, extern_kernel, args):
if V.graph.aot_mode and config.aot_inductor.abi_compatible:
output_name = extern_kernel.get_name()
self.writeline(f"AtenTensorHandle {output_name};")
kernel = extern_kernel.kernel
size = self.codegen_shape_tuple(tuple(extern_kernel.get_size()))
stride = self.codegen_shape_tuple(tuple(extern_kernel.get_stride()))
args = [
f"&{output_name}",
str(len(extern_kernel.get_size())), # ndim
self.codegen_int_array_var(size),
self.codegen_int_array_var(stride),
] + args
# TODO: support extern kernel that allocates
self.generate_c_shim_extern_kernel_call(kernel, args)
self.generate_c_shim_extern_kernel_alloc_call(extern_kernel, args)
else:
super().generate_extern_kernel_alloc(extern_kernel, args)

Expand Down Expand Up @@ -1470,8 +1502,12 @@ def add_benchmark_harness(self, output):
def codegen_sizevar(self, x: Expr) -> str:
return self.expr_printer(V.graph.sizevars.simplify(x))

def codegen_tuple_access(self, basename: str, index: str) -> str:
return f"std::get<{index}>({basename})"
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
if V.graph.aot_mode and config.aot_inductor.abi_compatible:
# in the abi_compatible mode, outputs are returned via arguments
return name
else:
return f"std::get<{index}>({basename})"

def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
parts = list(map(self.codegen_sizevar, shape))
Expand Down Expand Up @@ -1605,6 +1641,12 @@ def codegen_device_copy(self, src, dst):
else:
self.writeline(f"{dst}.copy_({src});")

def codegen_multi_output(self, name, value):
# in the abi_compatible mode, outputs are retrieved by passing
# output pointers, so we skip its codegen here.
if not config.aot_inductor.abi_compatible:
super().codegen_multi_output(name, value)

def generate_extern_kernel_args_decl_if_needed(
self, op_overload, raw_args, output_args
):
Expand Down
17 changes: 11 additions & 6 deletions torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3626,6 +3626,10 @@ def __init__(
tuple(tensor_args),
tuple(nontensor_args),
)
# We need output buffers for generating kernel arguments in the
# abi-compatible mode, where we retrieve outputs by pass each individual
# output through the abi-compatible interface.
self.outputs = []
self.use_cpp_op_schema = False

self.op_overload = kernel
Expand Down Expand Up @@ -3878,7 +3882,8 @@ def generate_output(output, indices):
assert output is None, "FallbackKernel output type is not supported"
return None

return generate_output(example_output, [])
packed.outputs = generate_output(example_output, [])
return packed.outputs

def apply_constraint(self):
return super().apply_constraint()
Expand All @@ -3898,7 +3903,7 @@ def codegen_list_tuple_access(self, basename, indices):
elif itype == tuple:
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = V.graph.wrapper_code.codegen_tuple_access(
basename, str(i)
basename, self.get_name(), str(i)
)
return self.codegen_list_tuple_access(tuple_access, indices[1:])
else:
Expand All @@ -3907,10 +3912,10 @@ def codegen_list_tuple_access(self, basename, indices):
return basename

def codegen(self, wrapper):
line = wrapper.declare
line += f"{self.get_name()} = {self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices)}"
line += wrapper.ending
wrapper.writeline(line)
wrapper.codegen_multi_output(
self.get_name(),
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
)
self.codegen_size_asserts(wrapper)

def __init__(self, layout, input, indices: List[Tuple[Any, ...]]):
Expand Down
19 changes: 19 additions & 0 deletions torch/csrc/inductor/aoti_torch/c/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,25 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
bool is_causal,
bool return_debug_mask,
double scale,
AtenTensorHandle* ret0, // returns new reference
AtenTensorHandle* ret1, // returns new reference
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);

Expand Down
57 changes: 57 additions & 0 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#else

#include <ATen/ops/_addmm_activation.h>
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/bmm.h>
Expand Down Expand Up @@ -182,6 +183,62 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
});
}

AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
bool is_causal,
bool return_debug_mask,
double scale,
AtenTensorHandle* ret0, // returns new reference
AtenTensorHandle* ret1, // returns new reference
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] =
at::_scaled_dot_product_flash_attention(
*query_tensor,
*key_tensor,
*value_tensor,
dropout_p,
is_causal,
return_debug_mask,
scale);

at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
at::Tensor* ret1_tensor = new at::Tensor(std::move(r1));
*ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
// ret2 and ret3 may be null
if (ret2) {
at::Tensor* ret2_tensor = new at::Tensor(std::move(r2));
*ret2 = tensor_pointer_to_tensor_handle(ret2_tensor);
}
if (ret3) {
at::Tensor* ret3_tensor = new at::Tensor(std::move(r3));
*ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
}
*ret4 = r4;
*ret5 = r5;
at::Tensor* ret6_tensor = new at::Tensor(std::move(r6));
*ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
at::Tensor* ret7_tensor = new at::Tensor(std::move(r7));
*ret7 = tensor_pointer_to_tensor_handle(ret7_tensor);
at::Tensor* ret8_tensor = new at::Tensor(std::move(r8));
*ret8 = tensor_pointer_to_tensor_handle(ret8_tensor);
});
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_tensor_copy_(
AtenTensorHandle src,
Expand Down