From 40a0fc0d7caba2e0b13b35a48180e049e56f8748 Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Tue, 26 Sep 2023 09:48:20 -0700
Subject: [PATCH] [inductor] support _scaled_dot_product_flash_attention fallback

Summary:
This PR supports the _scaled_dot_product_flash_attention fallback
kernel. Note that in the abi_compatible mode, we retrieve outputs by
passing output argument pointers rather than relying on std::get.

It also fixes an issue related to dynamic shapes, where we wrongly
queried undefined dynamic symbols.

Test Plan: ci

Reviewed By: frank-wei

Differential Revision: D49620191
---
 test/inductor/test_aot_inductor.py            | 39 +++++++++-
 torch/_inductor/codegen/wrapper.py            | 76 ++++++++++++++-----
 torch/_inductor/ir.py                         | 17 +++--
 torch/csrc/inductor/aoti_torch/c/shim.h       | 19 +++++
 .../csrc/inductor/aoti_torch/shim_common.cpp  | 57 ++++++++++++++
 5 files changed, 183 insertions(+), 25 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index ed635c5ee61f..ce40566bf95c 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -551,14 +551,49 @@ def forward(self, x, y):
             constraints=constraints,
         )
 
+    # scaled_dot_product_flash_attention
+    def test_sdpa(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]
+
+        example_inputs = (
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+        )
+        self.check_model(Repro(), example_inputs)
+
+    def test_sdpa_2(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, q, k, v, x):
+                t = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, is_causal=True
+                )[0]
+                return x + t
+
+        example_inputs = (
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+            torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
+        )
+        self.check_model(Repro(), example_inputs)
+
 
 
-class AOTInductorTestABICompatibile(TestCase):
+class AOTInductorTestABICompatible(TestCase):
     abi_compatible = True
     check_model = check_model
     check_model_with_multiple_inputs = check_model_with_multiple_inputs
 
 
-copy_tests(AOTInductorTestsTemplate, AOTInductorTestABICompatibile, "abi_compatible")
+copy_tests(AOTInductorTestsTemplate, AOTInductorTestABICompatible, "abi_compatible")
 
 
 class AOTInductorTestNonABICompatible(TestCase):
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 5a4ed2502386..eaa9a9000b5f 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -294,6 +294,11 @@ def __init__(self):
         self.last_seen_device_guard_index = None
         self.supports_intermediate_hooks = True
         self.expr_printer = pexpr
+        # Not all dynamic symbols will be used in the generated code. This
+        # set contains the symbols that are actually defined by something
+        # like "{self.declare}s0 = ...". It ensures that we do not emit
+        # queries for undefined symbols.
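+        # For example (hypothetical names): if graph input "arg0_1" has a
+        # dynamic dimension "s0", the graph-input size/stride codegen defines
+        # "s0" and adds it to this set, and write_wrapper_decl() later emits
+        # find_dynamic_dim("s0") only for symbols that were actually defined.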
+        self.defined_symbols = set()
 
         self.write_header()
         self.write_prefix()
@@ -587,6 +592,7 @@ def is_expr(x):
         for name, shape in graph_inputs_expr:
             shape = V.graph.sizevars.simplify(shape)
             if shape in needed:
+                self.defined_symbols.add(shape)
                 needed.remove(shape)
                 code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
 
@@ -595,6 +601,7 @@ def is_expr(x):
             for dim, shape in enumerate(shapes):
                 shape = V.graph.sizevars.simplify(shape)
                 if shape in needed:
+                    self.defined_symbols.add(shape)
                     needed.remove(shape)
                     code.writeline(
                         f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
@@ -605,6 +612,7 @@ def is_expr(x):
             for dim, shape in enumerate(shapes):
                 shape = V.graph.sizevars.simplify(shape)
                 if shape in needed:
+                    self.defined_symbols.add(shape)
                     needed.remove(shape)
                     code.writeline(
                         f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
@@ -623,7 +631,7 @@ def codegen_python_sizevar(self, x: Expr) -> str:
     def codegen_sizevar(self, x: Expr) -> str:
         return self.codegen_python_sizevar(x)
 
-    def codegen_tuple_access(self, basename: str, index: str) -> str:
+    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
         return f"{basename}[{index}]"
 
     def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
@@ -646,6 +654,9 @@ def codegen_reinterpret_view(self, name, size, stride, offset, writer) -> str:
     def codegen_device_copy(self, src, dst):
         self.writeline(f"{dst}.copy_({src})")
 
+    def codegen_multi_output(self, name, value):
+        self.writeline(f"{self.declare}{name} = {value}{self.ending}")
+
     def benchmark_compiled_module(self, output):
         def add_fake_input(name, shape, stride, device, dtype):
             output.writeline(
@@ -1191,7 +1202,11 @@ def write_wrapper_decl(self):
             if V.graph.aot_mode:
                 self.prefix.writeline("inputs.clear();")
 
-            dynamic_symbols = V.graph.sizevars.free_symbols()
+            dynamic_symbols = [
+                s
+                for s in V.graph.sizevars.free_symbols()
+                if s in self.defined_symbols
+            ]
             for dim in dynamic_symbols:
                 self.prefix.writeline(
                     f'auto dim_{dim} = find_dynamic_dim("{dim}");'
@@ -1411,21 +1426,38 @@ def generate_c_shim_extern_kernel_call(self, kernel, args):
         kernel = "aoti_torch_" + kernel.split("::")[-1]
         self.writeline(f"AOTI_TORCH_ERROR_CODE_CHECK({kernel}({', '.join(args)}));")
 
+    def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
+        output_args = []
+        output_raii_handles = []
+        output_name_base = extern_kernel.get_name()
+        for idx, output in enumerate(extern_kernel.outputs):
+            if isinstance(output, ir.MultiOutput):
+                name = f"{output.get_name()}"
+                output_handle_name = f"{name}_handle"
+                assert (
+                    output.indices[0][1] == idx
+                ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
+                self.writeline(f"AtenTensorHandle {output_handle_name};")
+                output_args.append(f"&{output_handle_name}")
+                output_raii_handles.append(
+                    f"RAIIAtenTensorHandle {name}({output_handle_name});"
+                )
+            elif isinstance(output, int):
+                output_name = f"{output_name_base}_{idx}"
+                self.writeline(f"int64_t {output_name} = {output};")
+                output_args.append(f"&{output_name}")
+            elif output is None:
+                output_args.append("nullptr")
+            else:
+                raise NotImplementedError(f"unsupported type of {output=}")
+        args = args + output_args
+        self.generate_c_shim_extern_kernel_call(extern_kernel.kernel, args)
+        for raii_handle in output_raii_handles:
+            self.writeline(raii_handle)
+
     def generate_extern_kernel_alloc(self, extern_kernel, args):
         if V.graph.aot_mode and config.aot_inductor.abi_compatible:
-            output_name = extern_kernel.get_name()
-            self.writeline(f"AtenTensorHandle {output_name};")
-            kernel = extern_kernel.kernel
-            size = self.codegen_shape_tuple(tuple(extern_kernel.get_size()))
-            stride = self.codegen_shape_tuple(tuple(extern_kernel.get_stride()))
-            args = [
-                f"&{output_name}",
-                str(len(extern_kernel.get_size())),  # ndim
-                self.codegen_int_array_var(size),
-                self.codegen_int_array_var(stride),
-            ] + args
-            # TODO: support extern kernel that allocates
-            self.generate_c_shim_extern_kernel_call(kernel, args)
+            self.generate_c_shim_extern_kernel_alloc_call(extern_kernel, args)
         else:
             super().generate_extern_kernel_alloc(extern_kernel, args)
 
@@ -1470,8 +1502,12 @@ def add_benchmark_harness(self, output):
     def codegen_sizevar(self, x: Expr) -> str:
         return self.expr_printer(V.graph.sizevars.simplify(x))
 
-    def codegen_tuple_access(self, basename: str, index: str) -> str:
-        return f"std::get<{index}>({basename})"
+    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
+        if V.graph.aot_mode and config.aot_inductor.abi_compatible:
+            # in the abi_compatible mode, outputs are returned via arguments
+            return name
+        else:
+            return f"std::get<{index}>({basename})"
 
     def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
         parts = list(map(self.codegen_sizevar, shape))
@@ -1605,6 +1641,12 @@ def codegen_device_copy(self, src, dst):
         else:
             self.writeline(f"{dst}.copy_({src});")
 
+    def codegen_multi_output(self, name, value):
+        # in the abi_compatible mode, outputs are retrieved by passing
+        # output pointers, so we skip its codegen here.
+        if not config.aot_inductor.abi_compatible:
+            super().codegen_multi_output(name, value)
+
     def generate_extern_kernel_args_decl_if_needed(
         self, op_overload, raw_args, output_args
     ):
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index f002fb118389..7330ecf5fa45 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3626,6 +3626,10 @@ def __init__(
             tuple(tensor_args),
             tuple(nontensor_args),
         )
+        # We need output buffers for generating kernel arguments in the
+        # abi-compatible mode, where we retrieve outputs by passing each
+        # individual output through the abi-compatible interface.
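+        # For a multi-output fallback such as _scaled_dot_product_flash_attention,
+        # each entry is a MultiOutput node (tensor output), a plain int (integer
+        # output), or None; see generate_output further down in this file and
+        # generate_c_shim_extern_kernel_alloc_call in wrapper.py, which consumes
+        # this list.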
+ self.outputs = [] self.use_cpp_op_schema = False self.op_overload = kernel @@ -3878,7 +3882,8 @@ def generate_output(output, indices): assert output is None, "FallbackKernel output type is not supported" return None - return generate_output(example_output, []) + packed.outputs = generate_output(example_output, []) + return packed.outputs def apply_constraint(self): return super().apply_constraint() @@ -3898,7 +3903,7 @@ def codegen_list_tuple_access(self, basename, indices): elif itype == tuple: # cpp wrapper code needs to use std::get<> to access a tuple tuple_access = V.graph.wrapper_code.codegen_tuple_access( - basename, str(i) + basename, self.get_name(), str(i) ) return self.codegen_list_tuple_access(tuple_access, indices[1:]) else: @@ -3907,10 +3912,10 @@ def codegen_list_tuple_access(self, basename, indices): return basename def codegen(self, wrapper): - line = wrapper.declare - line += f"{self.get_name()} = {self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices)}" - line += wrapper.ending - wrapper.writeline(line) + wrapper.codegen_multi_output( + self.get_name(), + self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), + ) self.codegen_size_asserts(wrapper) def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 54158aab3476..55dc54fff974 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -152,6 +152,25 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( AtenTensorHandle* ret // returns new reference ); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + AOTI_TORCH_EXPORT AOTITorchError aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 64a690108575..d6c95893eca4 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -16,6 +16,7 @@ #else #include +#include #include #include #include @@ -182,6 +183,62 @@ AOTITorchError aoti_torch_create_tensor_from_blob( }); } +AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query); + at::Tensor* key_tensor = 
tensor_handle_to_tensor_pointer(key); + at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value); + auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] = + at::_scaled_dot_product_flash_attention( + *query_tensor, + *key_tensor, + *value_tensor, + dropout_p, + is_causal, + return_debug_mask, + scale); + + at::Tensor* ret0_tensor = new at::Tensor(std::move(r0)); + *ret0 = tensor_pointer_to_tensor_handle(ret0_tensor); + at::Tensor* ret1_tensor = new at::Tensor(std::move(r1)); + *ret1 = tensor_pointer_to_tensor_handle(ret1_tensor); + // ret2 and ret3 may be null + if (ret2) { + at::Tensor* ret2_tensor = new at::Tensor(std::move(r2)); + *ret2 = tensor_pointer_to_tensor_handle(ret2_tensor); + } + if (ret3) { + at::Tensor* ret3_tensor = new at::Tensor(std::move(r3)); + *ret3 = tensor_pointer_to_tensor_handle(ret3_tensor); + } + *ret4 = r4; + *ret5 = r5; + at::Tensor* ret6_tensor = new at::Tensor(std::move(r6)); + *ret6 = tensor_pointer_to_tensor_handle(ret6_tensor); + at::Tensor* ret7_tensor = new at::Tensor(std::move(r7)); + *ret7 = tensor_pointer_to_tensor_handle(ret7_tensor); + at::Tensor* ret8_tensor = new at::Tensor(std::move(r8)); + *ret8 = tensor_pointer_to_tensor_handle(ret8_tensor); + }); +} + // TODO: implement a more efficient version instead of calling into aten AOTITorchError aoti_torch_tensor_copy_( AtenTensorHandle src,