From 7faf105f3d45af63b36767eb9844ac79cf25685f Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Fri, 3 Mar 2023 18:29:43 +0000
Subject: [PATCH] [reland][inductor] Add an AOT compilation mode for Inductor
 CPP backend

Summary: This is a reland of https://github.com/pytorch/pytorch/pull/94822

[ghstack-poisoned]
---
 benchmarks/dynamo/common.py          |   1 +
 test/inductor/aot/cpp/CMakeLists.txt |  23 +++++
 test/inductor/aot/cpp/test.cpp       |  41 ++++++++
 test/inductor/aot/cpp/test.py        |  22 +++++
 test/inductor/aot/cpp/test.sh        |   8 ++
 torch/_inductor/__init__.py          |  21 +++++
 torch/_inductor/codecache.py         |  46 +++++++++
 torch/_inductor/codegen/cpp.py       |  16 +++-
 torch/_inductor/codegen/cpp_prefix.h |   1 +
 torch/_inductor/codegen/wrapper.py   | 134 +++++++++++++++++++--------
 torch/_inductor/compile_fx.py        |  21 +++++
 torch/_inductor/config.py            |   3 +
 torch/_inductor/debug.py             |   4 +-
 torch/_inductor/graph.py             |  24 ++++-
 14 files changed, 314 insertions(+), 51 deletions(-)
 create mode 100644 test/inductor/aot/cpp/CMakeLists.txt
 create mode 100644 test/inductor/aot/cpp/test.cpp
 create mode 100644 test/inductor/aot/cpp/test.py
 create mode 100755 test/inductor/aot/cpp/test.sh

diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 238bc10a0c948..d94f5620784de 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -175,6 +175,7 @@ class CI(NamedTuple):
     # TIMM
     "cait_m36_384",  # Accuracy
     "pnasnet5large",  # OOM
+    "xcit_large_24_p8_224",  # OOM https://github.com/pytorch/pytorch/issues/95984
 ]
 
 CI_SKIP[CI("inductor", training=True)] = [
diff --git a/test/inductor/aot/cpp/CMakeLists.txt b/test/inductor/aot/cpp/CMakeLists.txt
new file mode 100644
index 0000000000000..f62dc3a164410
--- /dev/null
+++ b/test/inductor/aot/cpp/CMakeLists.txt
@@ -0,0 +1,23 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+project(test)
+
+set(Torch_DIR "../../../../torch/share/cmake/Torch")
+find_package(Torch REQUIRED)
+
+add_executable(test test.cpp ${CMAKE_BINARY_DIR}/aot_inductor_output.h)
+
+add_custom_command(
+  OUTPUT ${CMAKE_BINARY_DIR}/aot_inductor_output.h
+  COMMAND python ${CMAKE_SOURCE_DIR}/test.py
+  DEPENDS ${CMAKE_SOURCE_DIR}/test.py
+)
+add_custom_target(generate_header ALL
+  DEPENDS ${CMAKE_BINARY_DIR}/aot_inductor_output.h)
+
+add_library(aot_inductor_output SHARED IMPORTED)
+set_property(TARGET aot_inductor_output PROPERTY
+  IMPORTED_LOCATION ${CMAKE_BINARY_DIR}/aot_inductor_output.so)
+
+target_link_libraries(test "${TORCH_LIBRARIES}" aot_inductor_output)
+
+set_property(TARGET test PROPERTY CXX_STANDARD 17)
diff --git a/test/inductor/aot/cpp/test.cpp b/test/inductor/aot/cpp/test.cpp
new file mode 100644
index 0000000000000..3ccf3b2466d9c
--- /dev/null
+++ b/test/inductor/aot/cpp/test.cpp
@@ -0,0 +1,41 @@
+//#include
+#include
+
+#include "build/aot_inductor_output.h"
+
+/*
+class Net(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.ones(32, 64)
+
+    def forward(self, x):
+        x = torch.relu(x + self.weight)
+        return x
+*/
+struct Net : torch::nn::Module {
+  Net() {
+    weight = register_parameter("weight", torch::ones({32, 64}));
+  }
+  torch::Tensor forward(torch::Tensor input) {
+    return torch::relu(input + weight);
+  }
+  torch::Tensor weight;
+};
+
+int main() {
+  torch::Tensor x = at::randn({32, 64});
+  Net net;
+  torch::Tensor results_ref = net.forward(x);
+
+  // TODO: we need to provide an API to concatenate args and weights
+  std::vector inputs = {x};
+  for (const auto& pair : net.named_parameters()) {
+    inputs.push_back(pair.value());
+  }
+  torch::Tensor results_opt = aot_inductor_entry(inputs);
+
+  assert(torch::allclose(results_ref, results_opt));
+  printf("PASS\n");
+  return 0;
+}
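The C++ harness above has to pass the model weights by hand (hence the TODO about an API to concatenate args and weights): aot_inductor_entry receives the example input followed by every value yielded by named_parameters(), in that order. A sketch of the same argument list built from the Python side, purely to mirror test.cpp (net and x are illustrative names, not part of the patch):

    import torch

    def build_entry_args(net: torch.nn.Module, x: torch.Tensor):
        # Mirror test.cpp: the runtime input first, then every named parameter,
        # in the order named_parameters() yields them.
        inputs = [x]
        for _, param in net.named_parameters():
            inputs.append(param)
        return inputs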
diff --git a/test/inductor/aot/cpp/test.py b/test/inductor/aot/cpp/test.py
new file mode 100644
index 0000000000000..fc04172ea5f8d
--- /dev/null
+++ b/test/inductor/aot/cpp/test.py
@@ -0,0 +1,22 @@
+import torch
+import torch._dynamo
+import torch._inductor
+import torch._inductor.config
+
+torch._inductor.config.aot_codegen_output_prefix = "aot_inductor_output"
+
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.ones(32, 64)
+
+    def forward(self, x):
+        x = torch.relu(x + self.weight)
+        return x
+
+
+inp = torch.randn((32, 64), device="cpu")
+module, _ = torch._dynamo.export(Net(), inp)
+so_path = torch._inductor.aot_compile(module, [inp])
+print(so_path)
diff --git a/test/inductor/aot/cpp/test.sh b/test/inductor/aot/cpp/test.sh
new file mode 100755
index 0000000000000..b00f384e37928
--- /dev/null
+++ b/test/inductor/aot/cpp/test.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -euxo pipefail
+
+mkdir -p build
+cd build
+cmake ..
+make
+./test
diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py
index ceadaac7472e6..b35921762e9ec 100644
--- a/torch/_inductor/__init__.py
+++ b/torch/_inductor/__init__.py
@@ -25,3 +25,24 @@ def compile(
     from .compile_fx import compile_fx
 
     return compile_fx(gm, example_inputs, config_patches=options)
+
+
+def aot_compile(
+    gm: torch.fx.GraphModule,
+    example_inputs: List[torch.Tensor],
+    options: Optional[Dict[str, Any]] = None,
+) -> str:
+    """
+    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
+
+    Args:
+        gm: The FX graph to compile.
+        example_inputs: List of tensor inputs.
+        options: Optional dict of config options. See `torch._inductor.config`.
+
+    Returns:
+        Path to the generated shared library.
+    """
+    from .compile_fx import compile_fx
+
+    return compile_fx(gm, example_inputs, config_patches=options, aot_mode=True)()
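The new aot_compile entry point mirrors compile, but hands back the filesystem path of a compiled shared library instead of a Python callable. A minimal usage sketch, following the flow of test/inductor/aot/cpp/test.py above (the module, shapes, and variable names here are illustrative, not part of the patch):

    import torch
    import torch._dynamo
    import torch._inductor


    class TinyNet(torch.nn.Module):  # illustrative module
        def forward(self, x):
            return torch.relu(x + 1.0)


    inp = torch.randn(8, 16, device="cpu")
    # aot_compile expects an FX GraphModule, so export the eager module first.
    gm, _ = torch._dynamo.export(TinyNet(), inp)
    # Optionally set torch._inductor.config.aot_codegen_output_prefix to control
    # the generated file names; by default the .so lands next to the .cpp source.
    so_path = torch._inductor.aot_compile(gm, [inp])
    print(so_path)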
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 34206b9cfee3d..085d3ecbaa8fa 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -534,6 +534,52 @@ def cpp_compile_command(
     ).strip()
 
 
+class AotCodeCache:
+    cache = dict()
+    clear = staticmethod(cache.clear)
+
+    @classmethod
+    def compile(cls, source_code):
+        from .codegen.wrapper import CppWrapperCodeGen
+
+        # TODO: update cpp_compile_command for different platforms
+        picked_vec_isa = pick_vec_isa()
+        key, input_path = write(
+            source_code,
+            "cpp",
+            code_hash(repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa))),
+        )
+        if key not in cls.cache:
+            from filelock import FileLock
+
+            lock_dir = get_lock_dir()
+            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
+            with lock:
+                output_so = (
+                    os.path.join(os.getcwd(), f"{config.aot_codegen_output_prefix}.so")
+                    if config.aot_codegen_output_prefix
+                    else f"{input_path[:-3]}.so"
+                )
+
+                output_header = f"{output_so[:-3]}.h"
+                with open(output_header, "w") as header_file:
+                    header_file.writelines("#include \n\n")
+                    header_file.writelines(f"{CppWrapperCodeGen.decl_str};\n")
+
+                log.info(f"AOT-Inductor compiles code into: {output_so}")
+                if not os.path.exists(output_so):
+                    cmd = cpp_compile_command(
+                        input=input_path, output=output_so, vec_isa=picked_vec_isa
+                    ).split(" ")
+                    try:
+                        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+                    except subprocess.CalledProcessError as e:
+                        raise exc.CppCompileError(cmd, e.output) from e
+
+                cls.cache[key] = output_so
+        return cls.cache[key]
+
+
 class CppCodeCache:
     cache = dict()
     clear = staticmethod(cache.clear)
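AotCodeCache follows the same compile-once discipline as the existing code caches: the generated C++ source is hashed into a key, a file lock serializes concurrent builds, and the compiler only runs if the .so is not already on disk. A stripped-down sketch of that pattern, with simplified names and a made-up command template (an illustration of the caching scheme, not the literal implementation):

    import hashlib
    import os
    import subprocess

    from filelock import FileLock


    def compile_once(source_code: str, workdir: str, cmd_template: str) -> str:
        # Key the artifact on the source text so identical graphs share one .so.
        key = hashlib.sha256(source_code.encode("utf-8")).hexdigest()
        input_path = os.path.join(workdir, f"{key}.cpp")
        output_so = os.path.join(workdir, f"{key}.so")
        # The lock keeps two processes from compiling the same key at once.
        with FileLock(output_so + ".lock", timeout=600):
            if not os.path.exists(output_so):
                with open(input_path, "w") as f:
                    f.write(source_code)
                cmd = cmd_template.format(input=input_path, output=output_so).split()
                subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        return output_so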
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index e9390d4aab182..df8246182f1d3 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -2040,7 +2040,12 @@ def codegen_define_and_call(self, wrapper):
         )
         if enable_kernel_profile:
             code.writelines(["#include "])
-        code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})'])
+        kernel_decl_name = kernel_name if V.graph.aot_mode else "kernel"
+
+        if not V.graph.aot_mode or self.count == 1:
+            code.writeline(cpp_prefix())
+
+        code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
         with code.indent():
             if enable_kernel_profile:
                 graph_id = V.graph.graph_id
@@ -2055,9 +2060,12 @@ def codegen_define_and_call(self, wrapper):
             code.splice(self.loops_code)
 
         codecache_def = IndentedBuffer()
-        codecache_def.writeline("async_compile.cpp('''")
-        codecache_def.splice(code)
-        codecache_def.writeline("''')")
+        if V.graph.aot_mode:
+            codecache_def.splice(code)
+        else:
+            codecache_def.writeline("async_compile.cpp('''")
+            codecache_def.splice(code)
+            codecache_def.writeline("''')")
         codecache_str = codecache_def.getvalue()
         # TODO(voz): Ostensibly, we should not need this.
         # But there are cases where C++ codegen does
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index 08321da5ce95a..1ba1ea452dc1c 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -5,6 +5,7 @@
 #include
 #include
 
+#include
 #include
 #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
 #include
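The cpp.py change above amounts to one branching decision: in AOT mode every kernel keeps its own symbol name and is emitted as plain C++ source (with the shared prefix emitted only once per library), while the default JIT path still names each kernel "kernel" and wraps it in an async_compile.cpp('''...''') call for the Python wrapper. A simplified sketch of that branch (the function and flag names here are illustrative, not from the patch):

    def emit_kernel(kernel_body: str, kernel_name: str, aot_mode: bool, is_first_kernel: bool) -> str:
        lines = []
        # The shared prefix is emitted once for an AOT library, but for every
        # JIT kernel, because each JIT kernel becomes its own extension module.
        if not aot_mode or is_first_kernel:
            lines.append('#include "cpp_prefix.h"')
        decl_name = kernel_name if aot_mode else "kernel"
        lines.append(f'extern "C" void {decl_name}(...)' + " {")
        lines.append(kernel_body)
        lines.append("}")
        code = "\n".join(lines)
        if aot_mode:
            # Spliced directly into the AOT C++ wrapper source.
            return code
        # JIT path: the Python wrapper compiles this string at import time.
        return f"async_compile.cpp('''\n{code}\n''')"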
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 60c698cdd02b3..db351e7ca4fb8 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -269,6 +269,33 @@ def __init__(self):
         self.wrapper_call = IndentedBuffer()
         self.kernels = {}
         self.lines = []
+
+        self.set_header()
+        self.write_prefix()
+
+        for name, value in V.graph.constants.items():
+            # include a hash so our code cache gives different constants different files
+            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
+            self.header.writeline(f"{name} = None  # {hashed}")
+
+        self.allocated = set()
+        self.freed = set()
+
+        # maps from reusing buffer to reused buffer
+        self.reuses = dict()
+
+        self.write_get_cuda_stream = functools.lru_cache(None)(
+            self.write_get_cuda_stream
+        )
+
+        @functools.lru_cache(None)
+        def add_import_once(line):
+            self.header.writeline(line)
+
+        self.add_import_once = add_import_once
+        self._metas = {}
+
+    def set_header(self):
         self.header.splice(
             f"""
                 from ctypes import c_void_p, c_long
@@ -296,30 +323,6 @@ def __init__(self):
             """
         )
 
-        self.write_prefix()
-
-        for name, value in V.graph.constants.items():
-            # include a hash so our code cache gives different constants different files
-            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
-            self.header.writeline(f"{name} = None  # {hashed}")
-
-        self.allocated = set()
-        self.freed = set()
-
-        # maps from reusing buffer to reused buffer
-        self.reuses = dict()
-
-        self.write_get_cuda_stream = functools.lru_cache(None)(
-            self.write_get_cuda_stream
-        )
-
-        @functools.lru_cache(None)
-        def add_import_once(line):
-            self.header.writeline(line)
-
-        self.add_import_once = add_import_once
-        self._metas = {}
-
     def add_meta_once(self, meta):
         meta = repr(meta)
         if meta not in self._metas:
@@ -629,6 +632,7 @@ class CppWrapperCodeGen(WrapperCodeGen):
     """
 
     call_func_id = count()
+    decl_str = None
 
     def __init__(self):
         self._call_func_id = next(CppWrapperCodeGen.call_func_id)
@@ -648,7 +652,7 @@ def has_cpp_codegen_func(x):
             for x in V.graph.graph_outputs
         ]
 
-    def write_prefix(self):
+    def write_prefix_header(self):
         self.prefix.splice(
             """
             async_compile.wait(globals())
@@ -670,21 +674,30 @@ def write_prefix(self):
             """
         )
-        with self.wrapper_call.indent():
-            inputs_len = len(V.graph.graph_inputs.keys())
-            output_refs = self.get_output_refs()
-            if output_refs:
-                if len(output_refs) == 1:
-                    output_types = "at::Tensor"
-                else:
-                    output_types = "std::vector"
+
+    def call_func_name(self):
+        return f"call_{self._call_func_id}"
+
+    def write_prefix(self):
+        self.write_prefix_header()
+
+        inputs_len = len(V.graph.graph_inputs.keys())
+        output_refs = self.get_output_refs()
+        if output_refs:
+            if len(output_refs) == 1:
+                output_types = "at::Tensor"
             else:
-                output_types = "void"
+                output_types = "std::vector"
+        else:
+            output_types = "void"
 
-            inputs_types = "std::vector"
-            self.wrapper_call.writeline(
-                f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
-            )
+        inputs_types = "std::vector"
+
+        CppWrapperCodeGen.decl_str = (
+            f"{output_types} {self.call_func_name()}({inputs_types} args)"
+        )
+
+        self.prefix.splice(f"{CppWrapperCodeGen.decl_str} {{")
+        with self.wrapper_call.indent():
             if inputs_len != 0:
                 inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
                 self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
@@ -746,18 +759,24 @@ def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = No
     def wrap_kernel_call(self, name, call_args):
         return "{}({});".format(name, ", ".join(call_args))
 
+    def return_end_str(self):
+        return "\n}\n'''\n)"
+
     def generate_return(self, output_refs):
         if output_refs:
             if len(output_refs) == 1:
-                self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
+                self.wrapper_call.writeline(
+                    f"return {output_refs[0]};{self.return_end_str()}"
+                )
             else:
                 self.wrapper_call.writeline(
                     "return std::vector({"
                     + ", ".join(output_refs)
-                    + "}); }''' )"
+                    + "});"
+                    + self.return_end_str()
                 )
         else:
-            self.wrapper_call.writeline("return; }''' )")
+            self.wrapper_call.writeline(f"return;{self.return_end_str()}")
 
     def generate_end(self, result):
         shared = codecache.get_shared()
@@ -807,3 +826,36 @@ def generate_extern_kernel_out(
         else:
             args.insert(0, f"{codegen_reference}")
         self.writeline(f"{cpp_kernel}({', '.join(args)});")
+
+
+class CppAotWrapperCodeGen(CppWrapperCodeGen):
+    """
+    The AOT-version outer wrapper that calls the kernels in C++
+    """
+
+    def set_header(self):
+        return
+
+    def write_prefix_header(self):
+        return
+
+    def call_func_name(self):
+        return "aot_inductor_entry"
+
+    def define_kernel(self, name: str, kernel: str):
+        self.header.splice(f"\n{kernel}\n")
+
+    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
+        return
+
+    def wrap_kernel_call(self, name, call_args):
+        return f"{name}({', '.join(call_args)});"
+
+    def return_end_str(self):
+        return "\n}"
+
+    def generate_end(self, result):
+        return
+
+    def add_benchmark_harness(self, output):
+        return
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 64ae64f480f9f..f705fc1ca73f6 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -138,6 +138,7 @@ def compile_fx_inner(
     num_fixed=0,
     is_backward=False,
     graph_id=None,
+    aot_mode=False,
 ):
     if is_tf32_warning_applicable(gm):
         _warn_tf32_disabled()
@@ -174,10 +175,13 @@ def compile_fx_inner(
         shape_env=shape_env,
         num_static_inputs=num_fixed,
         graph_id=graph_id,
+        aot_mode=aot_mode,
     )
     with V.set_graph_handler(graph):
         graph.run(*example_inputs)
         compiled_fn = graph.compile_to_fn()
+        if aot_mode:
+            return compiled_fn
 
     if cudagraphs:
         complex_memory_overlap_inputs = any(
@@ -399,6 +403,7 @@ def compile_fx(
     inner_compile=compile_fx_inner,
     config_patches: Optional[Dict[str, Any]] = None,
     decompositions: Optional[Dict[OpOverload, Callable]] = None,
+    aot_mode=False,
 ):
     """Main entrypoint to a compile given FX graph"""
     if config_patches:
@@ -409,7 +414,23 @@ def compile_fx(
             # need extra layer of patching as backwards is compiled out of scope
             inner_compile=config.patch(config_patches)(inner_compile),
             decompositions=decompositions,
+            aot_mode=aot_mode,
         )
+
+    if aot_mode:
+        aot_config_patches = {
+            "cpp_wrapper": True,
+            "debug": True,
+            "triton.cudagraphs": False,
+        }
+        with config.patch(aot_config_patches):
+            return compile_fx(
+                model_,
+                example_inputs_,
+                inner_compile=functools.partial(inner_compile, aot_mode=aot_mode),
+                decompositions=decompositions,
+            )
+
     recursive_compile_fx = functools.partial(
         compile_fx,
         inner_compile=inner_compile,
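aot_mode reuses compile_fx's existing habit of re-entering itself under config.patch: the first pass only pins the settings AOT needs (the C++ wrapper on, debug output on, CUDA graphs off) and threads aot_mode into inner_compile. The effect is roughly the sketch below; this illustrates the internal flow under those assumptions, it is not a supported public entry point, and aot_compile_fx is a made-up name:

    import functools

    import torch._inductor.config as inductor_config
    from torch._inductor.compile_fx import compile_fx, compile_fx_inner


    def aot_compile_fx(gm, example_inputs):
        # The same patches compile_fx applies when called with aot_mode=True.
        patches = {"cpp_wrapper": True, "debug": True, "triton.cudagraphs": False}
        with inductor_config.patch(patches):
            return compile_fx(
                gm,
                example_inputs,
                inner_compile=functools.partial(compile_fx_inner, aot_mode=True),
            )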
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index f85ef3f135894..8ac48a89d3f81 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -15,6 +15,9 @@
 # limit lines of inner_fn() when printing IR
 debug_max_lines = int(os.environ.get("TORCHINDUCTOR_DEBUG_MAX_LINES", "10"))
 
+# Name for generated .h and .so files
+aot_codegen_output_prefix = None
+
 # use cpp wrapper instead of python wrapper
 cpp_wrapper = False
 
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index 89edaabff995c..1e9cbc8c4c4b2 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -206,7 +206,7 @@ def enable_aot_logging():
     stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
     stack.enter_context(patch("functorch.compile.config.debug_joint", True))
 
-    path = os.path.join(get_debug_dir(), "aot_torchinductor")
+    path = os.path.join(get_debug_dir(), "torchinductor")
     if not os.path.exists(path):
         os.makedirs(path)
 
@@ -245,7 +245,7 @@ def create_debug_dir(folder_name):
     for n in DebugContext._counter:
         dirname = os.path.join(
             get_debug_dir(),
-            "aot_torchinductor",
+            "torchinductor",
             f"{folder_name}.{n}",
         )
         if not os.path.exists(dirname):
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 1465b98423617..a7b1f40e2342b 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -23,7 +23,7 @@
 from .._dynamo import config as dynamo_config
 
 from . import config, ir
-from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen
+from .codegen.wrapper import CppAotWrapperCodeGen, CppWrapperCodeGen, WrapperCodeGen
 from .exc import (
     LoweringException,
     MissingOperatorWithDecomp,
@@ -114,6 +114,7 @@ def __init__(
         shape_env=None,
         num_static_inputs=None,
         graph_id=None,
+        aot_mode=False,
     ):
         super().__init__(gm)
         self.extra_traceback = False  # we do our own error wrapping
@@ -143,6 +144,7 @@ def __init__(
         self.creation_time = time.time()
         self.name = "GraphLowering"
         self._can_use_cpp_wrapper = config.cpp_wrapper
+        self.aot_mode = aot_mode
         self.graph_id = graph_id
         self.scheduler = None
         self._warned_fallback = {"aten.convolution_backward"}
@@ -539,10 +541,13 @@ def init_wrapper_code(self):
             self.check_cpp_wrapper()
             if self._can_use_cpp_wrapper:
                 self.sizevars = CppSizeVarAllocator(self._shape_env)
-                self.wrapper_code = CppWrapperCodeGen()
+                self.wrapper_code = (
+                    CppAotWrapperCodeGen() if self.aot_mode else CppWrapperCodeGen()
+                )
                 return
+            else:
+                assert not self.aot_mode, "Model does not support AOT compilation"
         self.wrapper_code = WrapperCodeGen()
-        return
 
     def codegen(self):
         from .scheduler import Scheduler
@@ -615,7 +620,18 @@ def compile_to_module(self):
         return mod
 
     def compile_to_fn(self):
-        return self.compile_to_module().call
+        if self.aot_mode:
+            from .codecache import AotCodeCache
+
+            code = self.codegen()
+            if config.debug:
+                print(code)
+
+            # return the generated .so file path
+            output_path = AotCodeCache.compile(code)
+            return lambda dummy: output_path
+        else:
+            return self.compile_to_module().call
 
     def get_output_names(self):
         assert self.graph_outputs is not None
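Tying the pieces together in graph.py: AOT compilation requires the C++ wrapper, so the AOT-specific codegen is only selected when the C++ wrapper applies, and a model that would fall back to the Python wrapper is rejected outright. A small sketch of that selection logic, restating the diff above in a standalone form (the function name is made up):

    def pick_wrapper_codegen(can_use_cpp_wrapper: bool, aot_mode: bool) -> str:
        if can_use_cpp_wrapper:
            # AOT mode swaps in the wrapper that emits a plain C++ entry function
            # (aot_inductor_entry) instead of an async_compile'd Python wrapper.
            return "CppAotWrapperCodeGen" if aot_mode else "CppWrapperCodeGen"
        # The Python wrapper cannot be packaged into a standalone .so,
        # so AOT mode refuses to fall back to it.
        assert not aot_mode, "Model does not support AOT compilation"
        return "WrapperCodeGen"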