diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 60208a5882d37..996f42867023b 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -1147,9 +1147,9 @@ def load(cls, model, example_inputs, eager_forward):
         # Use a utility function for easier benchmarking
         source = """
-        #include <torch/csrc/inductor/aot_inductor_model.h>
+        #include <torch/csrc/inductor/aot_inductor_model_container.h>

-        torch::aot_inductor::AOTInductorModel model;
+        torch::aot_inductor::AOTInductorModelContainer model(1);

         void run(
             const std::vector<at::Tensor>& input_tensors,
diff --git a/test/cpp/aot_inductor/test.cpp b/test/cpp/aot_inductor/test.cpp
index d143719ae5dd5..fbf6cfef12949 100644
--- a/test/cpp/aot_inductor/test.cpp
+++ b/test/cpp/aot_inductor/test.cpp
@@ -23,17 +23,28 @@ TEST(AotInductorTest, BasicTest) {
   Net net;
   net.to(torch::kCUDA);

+  // Fix the weights here.
+  // They must match exactly the ones set in test.py.
+  torch::Tensor weights =
+      at::arange(640, at::dtype(at::kFloat).device(at::kCUDA));
+  weights = at::reshape(weights, {10, 64});
+  torch::Tensor bias = at::zeros({10}, at::dtype(at::kFloat).device(at::kCUDA));
+
+  for (const auto& pair : net.named_parameters()) {
+    if (pair.key().find("weight") != std::string::npos) {
+      pair.value().copy_(weights);
+    } else if (pair.key().find("bias") != std::string::npos) {
+      pair.value().copy_(bias);
+    }
+  }
+
   torch::Tensor x = at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
   torch::Tensor y = at::randn({32, 64}, at::dtype(at::kFloat).device(at::kCUDA));
   torch::Tensor results_ref = net.forward(x, y);

-  // TODO: we need to provide an API to concatenate args and weights
   std::vector<torch::Tensor> inputs;
-  for (const auto& pair : net.named_parameters()) {
-    inputs.push_back(pair.value());
-  }
   inputs.push_back(x);
   inputs.push_back(y);
diff --git a/test/cpp/aot_inductor/test.py b/test/cpp/aot_inductor/test.py
index 2de38993e2910..d07dfb07f82ed 100644
--- a/test/cpp/aot_inductor/test.py
+++ b/test/cpp/aot_inductor/test.py
@@ -8,6 +8,12 @@ class Net(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.fc = torch.nn.Linear(64, 10)
+        weights = torch.arange(640)
+        weights = torch.reshape(weights, (10, 64))
+
+        with torch.no_grad():
+            self.fc.weight.copy_(weights)
+            self.fc.bias.copy_(torch.zeros(10))

     def forward(self, x, y):
         return self.fc(torch.sin(x) + torch.cos(y))
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 7f98637e86fab..aa5a1ac920f93 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -37,9 +37,9 @@ def load(cls, model, example_inputs, example_outputs, options=None):
         # Use a utility function for easier testing
         source = """
-        #include <torch/csrc/inductor/aot_inductor_model.h>
+        #include <torch/csrc/inductor/aot_inductor_model_container.h>

-        torch::aot_inductor::AOTInductorModel model;
+        torch::aot_inductor::AOTInductorModelContainer model(1);

         void run(
             const std::vector<at::Tensor>& input_tensors,
@@ -63,12 +63,10 @@ def run(cls, model, example_inputs, example_outputs, options=None):
         optimized, exported, output_tensors, output_spec = AOTInductorModelRunner.load(
             model, example_inputs, example_outputs, options
         )
-        param_buffer_values = list(exported.state_dict.values())
         flat_example_inputs = fx_pytree.tree_flatten_spec(
             example_inputs, exported.call_spec.in_spec
         )
-        all_args = (*param_buffer_values, *flat_example_inputs)
-        optimized(all_args, output_tensors)
+        optimized(flat_example_inputs, output_tensors)
         return pytree.tree_unflatten(output_tensors, output_spec)


@@ -91,6 +89,47 @@ def forward(self, x, y):
         actual = AOTInductorModelRunner.run(model, example_inputs, expected)
         self.assertTrue(same(actual, expected))

+    def test_large(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.randn(250112, 512, device="cuda")
+
+            def forward(self, x, y):
+                return x + torch.nn.functional.linear(y, self.weight)
+
+        model = Repro()
+        example_inputs = (
+            torch.randn(1, 250112, device="cuda"),
+            torch.randn(1, 512, device="cuda"),
+        )
+        expected = model(*example_inputs)
+        actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+        self.assertTrue(same(actual, expected))
+
+    def test_with_offset(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.orig_tensor = torch.randn(2, 15, 10, device="cuda")[0]
+                self.tensor = self.orig_tensor[5:, :]
+
+            def forward(self, x, y):
+                return (
+                    x
+                    + torch.nn.functional.linear(y, self.orig_tensor[:10, :])
+                    + self.tensor
+                )
+
+        model = Repro()
+        example_inputs = (
+            torch.randn(10, 10, device="cuda"),
+            torch.randn(10, 10, device="cuda"),
+        )
+        expected = model(*example_inputs)
+        actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+        self.assertTrue(same(actual, expected))
+
     def test_missing_output(self):
         class Repro(torch.nn.Module):
             def __init__(self):
diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py
index 1ceec7b130f50..843b4eb7a609e 100644
--- a/test/inductor/test_inductor_freezing.py
+++ b/test/inductor/test_inductor_freezing.py
@@ -307,8 +307,8 @@ def foo(mod, x):
         # we unfuse the conv bias, but it should only have one constant in the kernel
         if self.device == "cuda":
             FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
-                "constant"
-            ).check_not("constant").check_next("return").run(code[0])
+                "frozen_param"
+            ).check_not("frozen_param").check_next("return").run(code[0])

         self.assertEqual(
             out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 1cdc6833fb408..610e376e89113 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -1,3 +1,4 @@
+import copy
 import dataclasses
 import io
 import re
@@ -627,7 +628,6 @@ def aot_compile(
     Returns:
         Path to the generated shared library, and the exported program
     """
-    from torch._inductor.compile_fx import compile_fx_aot
     from torch._inductor.decomposition import select_decomp_table

     global DECOMP_TABLE
@@ -636,11 +636,46 @@ def aot_compile(
     # Reset the global value
     DECOMP_TABLE = core_aten_decompositions()

-    param_buffer_values = list(ep.state_dict.values())
     flat_example_inputs = fx_pytree.tree_flatten_spec(
         combine_args_kwargs(args, kwargs), ep.call_spec.in_spec  # type: ignore[arg-type]
     )
-    all_args = (*param_buffer_values, *flat_example_inputs)

-    so_path = torch._inductor.aot_compile(ep.graph_module, list(all_args), options)
-    return so_path, ep
+    unlifted_module = ep.module()
+    unlifted_module.graph.set_codegen(torch.fx.CodeGen())  # type: ignore[attr-defined]
+    unlifted_module.recompile()
+    options = (
+        {"from_export": True}
+        if options is None
+        else {**options, "from_export": True}
+    )
+    so_path = torch._inductor.aot_compile(unlifted_module, flat_example_inputs, options)  # type: ignore[arg-type]
+
+    user_inputs = []
+    user_outputs = []
+    for node in unlifted_module.graph.nodes:
+        if node.op == "placeholder":
+            user_inputs.append(node.name)
+        elif node.op == "output":
+            user_outputs = [arg.name for arg in node.args[0]]
+
+    unlifted_ep = ExportedProgram(
+        unlifted_module,
+        unlifted_module.graph,
+        ExportGraphSignature(
+            [],
+            [],
+            user_inputs,
+            user_outputs,
+            {},
+            {},
+            {},
+            None,
+        ),
+        call_spec=copy.deepcopy(ep.call_spec),
+        state_dict={},
+        range_constraints=copy.deepcopy(ep.range_constraints),
+        equality_constraints=copy.deepcopy(ep.equality_constraints),
+        module_call_graph=ep.module_call_graph,
+    )
+
+    return so_path, unlifted_ep
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index f256e39b99df0..21ce158873220 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -302,13 +302,13 @@ def get_lock_dir():
     return lock_dir


-def code_hash(code, extra: str = ""):
-    hashing_str = code
+def code_hash(code: Union[str, bytes], extra: str = ""):
+    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
     if extra != "":
-        hashing_str = hashing_str + "||" + extra
+        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
     return (
         "c"
-        + base64.b32encode(hashlib.sha256(hashing_str.encode("utf-8")).digest())[:51]
+        + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
         .decode("utf-8")
         .lower()
     )
@@ -656,6 +656,10 @@ def pick_vec_isa():
     return invalid_vec_isa


+def get_compile_only(compile_only=True):
+    return "-c" if compile_only else ""
+
+
 def get_shared(shared=True):
     return "-shared -fPIC" if shared else ""
@@ -888,6 +892,7 @@ def cpp_compile_command(
     vec_isa: VecISA = invalid_vec_isa,
     cuda=False,
     aot_mode=False,
+    compile_only=False,
 ):
     ipaths, lpaths, libs, macros = get_include_and_linking_paths(
         include_pytorch, vec_isa, cuda, aot_mode
@@ -918,11 +923,20 @@ def cpp_compile_command(
                 {use_custom_generated_macros()}
                 {use_fb_internal_macros()}
                 {use_standard_sys_dir_headers()}
+                {get_compile_only(compile_only)}
                 -o {out_name}
         """,
     ).strip()


+def run_command_and_check(cmd: str):
+    cmd = shlex.split(cmd)
+    try:
+        subprocess.check_call(cmd)
+    except subprocess.CalledProcessError as e:
+        raise exc.CppCompileError(cmd, e.output) from e
+
+
 class CudaKernelParamCache:
     cache = dict()
     clear = staticmethod(cache.clear)
@@ -956,12 +970,41 @@ def compile(cls, graph, source_code, serialized_extern_kernel_nodes, cuda):
                 "i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode
             )
         )
+        if config.is_fbcode():
+            ld_command = build_paths.ld()
+            objcopy_command = build_paths.objcopy()
+        else:
+            ld_command = "ld"
+            objcopy_command = "objcopy"
         key, input_path = write(
             source_code,
             "cpp",
             extra=cpp_command,
             specified_dir=config.aot_inductor.output_path,
         )
+
+        def _to_bytes(t: torch.Tensor) -> bytes:
+            # This serializes the tensor's untyped_storage to bytes by accessing
+            # the raw data of the underlying structure.
+            import ctypes
+
+            t_cpu = t.untyped_storage().cpu()
+            raw_array = ctypes.cast(
+                t_cpu.data_ptr(), ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes())
+            )
+
+            return bytes(raw_array.contents)
+
+        aot_constants = b""
+        for idx, tensor in enumerate(graph.constants.values()):
+            aot_constants += _to_bytes(tensor)
+
+        consts_key, consts_path = write(
+            aot_constants,
+            "bin",
+            specified_dir=config.aot_inductor.output_path,
+        )
+
         if key not in cls.cache:
             from filelock import FileLock
@@ -978,20 +1021,61 @@ def compile(cls, graph, source_code, serialized_extern_kernel_nodes, cuda):
                 output_so = os.path.splitext(input_path)[0] + ".so"
                 if not os.path.exists(output_so):
-                    cmd = shlex.split(
-                        cpp_compile_command(
-                            input=input_path,
-                            output=output_so,
-                            vec_isa=picked_vec_isa,
-                            cuda=cuda,
-                            aot_mode=graph.aot_mode,
-                        )
+                    output_o = os.path.splitext(input_path)[0] + ".o"
+                    cmd = cpp_compile_command(
+                        input=input_path,
+                        output=output_o,
+                        vec_isa=picked_vec_isa,
+                        cuda=cuda,
+                        aot_mode=graph.aot_mode,
+                        compile_only=True,
+                    )
+                    log.debug("aot compilation command: %s", cmd)
+                    run_command_and_check(cmd)
+
+                    consts_o = os.path.splitext(consts_path)[0] + ".o"
+                    cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
+                    run_command_and_check(cmd)
+                    log.debug("aot constant binary command: %s", cmd)
+
+                    cmd = (
+                        f"{objcopy_command} --rename-section"
+                        " .data=.lrodata,alloc,load,readonly,data,contents"
+                        f" {consts_o} {consts_o}"
+                    )
+                    log.debug("aot constant obj command: %s", cmd)
+                    run_command_and_check(cmd)
+
+                    cmd = f"rm {consts_path}"
+                    log.debug("aot constant bin removal command: %s", cmd)
+                    run_command_and_check(cmd)
+
+                    body = re.sub(r"[\W_]+", "_", consts_path)
+                    symbol_list = []
+                    symbol_list.append(
+                        f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
+                    )
+                    symbol_list.append(
+                        f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}"
+                    )
+                    symbol_list.append(
+                        f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
+                    )
+                    log.debug(
+                        "aot constant binary redefine symbol: %s", " ".join(symbol_list)
+                    )
+                    for cmd in symbol_list:
+                        run_command_and_check(cmd)
+
+                    cmd = cpp_compile_command(
+                        input=f"{output_o} {consts_o}",
+                        output=output_so,
+                        vec_isa=picked_vec_isa,
+                        cuda=cuda,
+                        aot_mode=graph.aot_mode,
                     )
-                    log.debug("aot compilation command: %s", " ".join(cmd))
-                    try:
-                        subprocess.check_call(cmd)
-                    except subprocess.CalledProcessError as e:
-                        raise exc.CppCompileError(cmd, e.output) from e
+                    log.debug("aot linkage command: %s", cmd)
+                    run_command_and_check(cmd)
                 else:
                     log.debug(
                         "aot_inductor dynamic library already exist: %s", output_so
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index b6cb6eb1a7b54..1ec81557a2fef 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -294,9 +294,10 @@ def __init__(self):
         self.write_header()
         self.write_prefix()

-        for name, hashed in V.graph.constant_reprs.items():
-            # include a hash so our code cache gives different constants different files
-            self.write_constant(name, hashed)
+        if not V.graph.aot_mode:
+            for name, hashed in V.graph.constant_reprs.items():
+                # include a hash so our code cache gives different constants different files
+                self.write_constant(name, hashed)

         self.allocated = set()
         self.freed = set()
@@ -1010,10 +1011,16 @@ def write_wrapper_decl(self):
                 isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
             ), "Expect all constants to be Tensor"
             for idx, constants_key in enumerate(V.graph.constants.keys()):
-                constants_idx = inputs_len + idx
-                self.prefix.writeline(
-                    f"at::Tensor {constants_key} = args[{constants_idx}];"
-                )
+                if V.graph.aot_mode:
+                    self.prefix.writeline(
+                        f"""at::Tensor {constants_key} = constants_->at("{constants_key}");"""
+                    )
+                else:
+                    # Append constants as inputs to the graph
+                    constants_idx = inputs_len + idx
+                    self.prefix.writeline(
+                        f"at::Tensor {constants_key} = args[{constants_idx}];"
+                    )

         self.codegen_inputs(self.prefix, V.graph.graph_inputs)

@@ -1045,14 +1052,17 @@ def codegen_model_constructor(self):
         """
         num_inputs = len(V.graph.graph_inputs)
         num_outputs = len(V.graph.graph_outputs)
+        num_constants = len(V.graph.constants)
         self.prefix.splice(
             f"""
-            AOTInductorModel::AOTInductorModel()
-                : AOTInductorModelBase({num_inputs}, {num_outputs}) {{
+            AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map)
+                : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}) {{
             """
         )

         with self.prefix.indent():
+            from .cpp import DTYPE_TO_ATEN
+
             for idx, name in enumerate(V.graph.graph_inputs.keys()):
                 # TODO: handle symbolic expressions later.
                 assert not isinstance(V.graph.graph_inputs[name], sympy.Expr)
@@ -1071,6 +1081,29 @@ def codegen_model_constructor(self):
                     f"inputs_info_[{idx}].shape.emplace_back({size}, {size}, nullptr);"
                 )

+            for idx, (name, tensor) in enumerate(V.graph.constants.items()):
+                assert isinstance(tensor, torch.Tensor)
+                self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].dtype = {DTYPE_TO_ATEN[tensor.dtype]};"
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
+                )
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
+                )
+
+                size_str = ", ".join([str(s) for s in tensor.size()])
+                self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")
+
+                stride_str = ", ".join([str(s) for s in tensor.stride()])
+                self.prefix.writeline(
+                    f"constants_info_[{idx}].stride = {{{stride_str}}};"
+                )
+
+            self.prefix.writeline("constants_ = constants_map;")
+
             for idx, output in enumerate(V.graph.graph_outputs):
                 # TODO: handle symbolic expressions later.
                 assert not isinstance(output, sympy.Expr)
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index cf617153fa09e..0a2821caeca1c 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -573,6 +573,9 @@ def fx_codegen_and_compile(
                 context.output_strides.append(None)

         compiled_fn = graph.compile_to_fn()

+        if V.aot_compilation is True:
+            return compiled_fn
+
         if graph.disable_cudagraphs:
             BoxedBool.disable(cudagraphs)
@@ -846,9 +849,6 @@ def is_saved_tensor(x):
     return len(static_arg_idxs)


-_in_aot_compilation = BoxedBool(False)
-
-
 def compile_fx_aot(
     model_: torch.fx.GraphModule,
     example_inputs_: List[torch.Tensor],
@@ -870,7 +870,7 @@ def compile_fx_aot(
     }
     extern_node_serializer = config_patches.pop("extern_node_serializer", None)

-    with mock.patch.object(_in_aot_compilation, "value", True):
+    with V.set_aot_compilation(True):
         return compile_fx(
             model_,
             example_inputs_,
@@ -949,7 +949,7 @@ def fw_compiler_freezing(

     # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
     # that drops constant-ified params
-    if _in_aot_compilation:
+    if V.aot_compilation is True:
         return optimized_function

     def wrapper(args):
@@ -1161,6 +1161,10 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
         torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode)
     )

+    if config.from_export and V.aot_compilation is True:
+        with V.set_fake_mode(fake_mode), compiled_autograd.disable():
+            return inference_compiler(model_, example_inputs_)
+
     with V.set_fake_mode(fake_mode), torch._guards.tracing(  # type: ignore[call-arg]
         tracing_context
     ), compiled_autograd.disable():
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index e165da64d9da8..88dc08c04690d 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -68,6 +68,9 @@
 # (if force_mixed_mm is true, the use_mixed_mm flag will be ignored)
 force_mixed_mm = False

+# TODO: capture whether the graph is from export
+from_export = False
+
 # enable slow autotuning passes to select algorithms
 max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py
index d734a793f421e..22d63b72b44c5 100644
--- a/torch/_inductor/freezing.py
+++ b/torch/_inductor/freezing.py
@@ -73,13 +73,19 @@ def freeze(
     # See the details in fx_codegen_and_compile of compile_fx.py.
     view_to_reshape(aot_autograd_gm)

-    fw_metadata = torch._guards.TracingContext.get().fw_metadata
-    params_flat = torch._guards.TracingContext.get().params_flat
-    assert fw_metadata is not None and params_flat is not None
+    if torch._guards.TracingContext.get():
+        fw_metadata = torch._guards.TracingContext.get().fw_metadata
+        params_flat = torch._guards.TracingContext.get().params_flat
+        assert fw_metadata is not None and params_flat is not None

-    preserved_arg_indices = replace_params_with_constants(
-        aot_autograd_gm, params_flat, fw_metadata
-    )
+        preserved_arg_indices = replace_params_with_constants(
+            aot_autograd_gm, params_flat, fw_metadata
+        )
+    else:
+        inputs = [
+            node for node in aot_autograd_gm.graph.nodes if node.op == "placeholder"
+        ]
+        preserved_arg_indices = list(range(len(inputs)))

     # TODO - further restrict cse ? right now needed to dedup aliasing ops
     cse_graph = fx_graph_cse(aot_autograd_gm.graph)
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index e5bad8898f6bc..404816ad8b3c3 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -5,7 +5,7 @@
 import re
 import sys
 import time
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from contextlib import contextmanager
 from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple

@@ -190,7 +190,7 @@ def __init__(
         self.device_idxs: Set[int] = set()
         self.cuda = False
         self.buffers: List[ir.ComputedBuffer] = []
-        self.constants: Dict[str, torch.Tensor] = {}
+        self.constants: OrderedDict[str, torch.Tensor] = OrderedDict()
         self.constant_reprs: Dict[str, str] = {}
         self.removed_buffers: Set[str] = set()
         self.removed_inplace_buffers: Set[str] = set()
@@ -508,9 +508,9 @@ def mark_buffer_mutated(self, name: str):
         for user in self.name_to_users[name]:
             user.realize()

-    def add_tensor_constant(self, data):
-        def allocate():
-            for name, value in self.constants.items():
+    def add_tensor_constant(self, data, name=None):
+        def allocate(name):
+            for constant_name, value in self.constants.items():
                 if (
                     not data.is_mkldnn
                     and data.size() == value.size()
@@ -519,17 +519,21 @@ def allocate():
                     and data.device == value.device
                     and torch.eq(data, value).all()
                 ):
-                    return name
-            name = f"constant{len(self.constants)}"
+                    return constant_name
+
+            if name is None:
+                name = f"constant{len(self.constants)}"
             self.constants[name] = data
             self.constant_reprs[name] = hashlib.sha256(
                 repr(data).encode("utf-8")
             ).hexdigest()
             return name

+        name = allocate(name)
+
         return TensorBox.create(
             ir.ConstantBuffer(
-                allocate(),
+                name,
                 FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
             )
         )
@@ -624,7 +628,7 @@ def get_attr(self, target, args, kwargs):
         value = getattr(self.module, target)

         if unsupported_output_tensor(value):
-            return self.add_tensor_constant(value)
+            return self.add_tensor_constant(value, target)

         with no_dispatch():
             if value.shape == ():
@@ -635,7 +639,7 @@ def get_attr(self, target, args, kwargs):
                 return tensor(value.tolist(), dtype=value.dtype, device=value.device)

-        return self.add_tensor_constant(value)
+        return self.add_tensor_constant(value, target)

     def call_module(self, target, args, kwargs):
         raise AssertionError()
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index 502c8903ac58d..57ecc2704e2c3 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -168,6 +168,7 @@ def __getattr__(self, item):
 _kernel = Virtualized("kernel", NullHandler)
 _debug = Virtualized("debug", NullHandler)
 _interpreter = Virtualized("interpreter", NullHandler)
+_aot_compilation = Virtualized("aot_compilation", NullHandler)


 class OpsValue:
@@ -272,6 +273,8 @@ class _V:
     set_kernel_handler = _kernel._set_handler
     set_debug_handler = _debug._set_handler
     set_interpreter_handler = _interpreter._set_handler
+    set_aot_compilation = _aot_compilation._set_handler
+    get_aot_compilation = _aot_compilation._get_handler

     @property
     def ops(self) -> MockHandler:  # type: ignore[valid-type]
@@ -306,5 +309,9 @@ def debug(self):
     def interpreter(self):
         return _interpreter._get_handler()

+    @property
+    def aot_compilation(self):
+        return _aot_compilation._get_handler()
+

 V = _V()
diff --git a/torch/csrc/inductor/aot_inductor_model.h b/torch/csrc/inductor/aot_inductor_model.h
index 467a2c2d1e6b7..501c0ae0e491f 100644
--- a/torch/csrc/inductor/aot_inductor_model.h
+++ b/torch/csrc/inductor/aot_inductor_model.h
@@ -23,6 +23,8 @@
 namespace torch {
 namespace aot_inductor {

+using ConstantMap = std::unordered_map<std::string, at::Tensor>;
+
 // Defines the base class for AOTInductorModel, which is generated by the
 // AOTInductor cpp codegen. Since we do not need dynamic dispatch, we rely
 // on curiously recurring template pattern (CRTP) to save some runtime
@@ -32,8 +34,13 @@ namespace aot_inductor {
 template <typename Model>
 class AOTInductorModelBase {
  public:
-  AOTInductorModelBase(size_t num_inputs, size_t num_outputs)
-      : inputs_info_(num_inputs), outputs_info_(num_outputs) {
+  AOTInductorModelBase(
+      size_t num_inputs,
+      size_t num_outputs,
+      size_t num_constants)
+      : inputs_info_(num_inputs),
+        outputs_info_(num_outputs),
+        constants_info_(num_constants) {
     C10_CUDA_CHECK(cudaEventCreate(&run_finished_));
   }

@@ -70,6 +77,10 @@ class AOTInductorModelBase {
     return outputs_info_.size();
   }

+  size_t num_constants() const {
+    return constants_info_.size();
+  }
+
   const char* input_name(int64_t idx) const {
     return inputs_info_.at(idx).name;
   }
@@ -86,6 +97,10 @@ class AOTInductorModelBase {
     return outputs_info_.at(idx).dtype;
   }

+  const char* constant_name(int64_t idx) const {
+    return constants_info_.at(idx).name;
+  }
+
   std::vector<int64_t> max_input_shape(int64_t idx) const {
     return max_shape(inputs_info_, idx);
   }

@@ -94,6 +109,26 @@ class AOTInductorModelBase {
     return max_shape(outputs_info_, idx);
   }

+  std::vector<int64_t> constant_shape(int64_t idx) const {
+    return constants_info_.at(idx).shape;
+  }
+
+  std::vector<int64_t> constant_stride(int64_t idx) const {
+    return constants_info_.at(idx).stride;
+  }
+
+  c10::ScalarType constant_type(int64_t idx) const {
+    return constants_info_.at(idx).dtype;
+  }
+
+  size_t constant_offset(int64_t idx) const {
+    return constants_info_.at(idx).offset;
+  }
+
+  size_t constant_data_size(int64_t idx) const {
+    return constants_info_.at(idx).data_size;
+  }
+
   /// Returns true if the model is complete.
   bool is_finished() {
     auto event_status = cudaEventQuery(run_finished_);
@@ -151,8 +186,20 @@ class AOTInductorModelBase {
     std::vector shape;
   };

+  struct ConstInfo {
+    const char* name = nullptr;
+    std::vector<int64_t> shape;
+    std::vector<int64_t> stride;
+    c10::ScalarType dtype;
+    int64_t offset;
+    size_t data_size;
+  };
+
   std::vector inputs_info_;
   std::vector outputs_info_;
+  std::vector<ConstInfo> constants_info_;
+
+  std::shared_ptr<ConstantMap> constants_;

   // Record if the model finishes an inference run so that its owning
   // AOTModelContainer can re-use this instance.
@@ -175,7 +222,7 @@ class AOTInductorModelBase {

 class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
  public:
-  AOTInductorModel();
+  AOTInductorModel(std::shared_ptr<ConstantMap>);

   void run_impl(
       const std::vector<at::Tensor>& inputs,
       cudaStream_t stream,
       ProxyExecutor* proxy_executor = nullptr);

-  static std::unique_ptr<AOTInductorModel> Create() {
-    return std::make_unique<AOTInductorModel>();
+  static std::unique_ptr<AOTInductorModel> Create(
+      std::shared_ptr<ConstantMap> constants) {
+    return std::make_unique<AOTInductorModel>(constants);
   }
 };
diff --git a/torch/csrc/inductor/aot_inductor_model_container.h b/torch/csrc/inductor/aot_inductor_model_container.h
index 688753ff0eb71..9a290429ed592 100644
--- a/torch/csrc/inductor/aot_inductor_model_container.h
+++ b/torch/csrc/inductor/aot_inductor_model_container.h
@@ -7,6 +7,30 @@
 #include
 #include

+// At codegen time, we write out a binary file called constants.bin.
+// We then turn the raw binary into an object file that exposes this
+// symbol and link it into the final .so.
+// For information on the binary format, see `man objcopy`, under
+// the "binary-architecture" flag:
+// https://man7.org/linux/man-pages/man1/objcopy.1.html
+// todo: use #embed in C++23 once available
+extern const uint8_t _binary_constants_bin_start[];
+extern const uint8_t _binary_constants_bin_end[];
+
+#define AOT_CONST_GPU_ALIGNMENT 64
+
+namespace {
+
+using CUDAPtr = std::unique_ptr<void, std::function<void(void*)>>;
+
+CUDAPtr RAII_cudaMalloc(size_t num_bytes) {
+  void* data_ptr;
+  C10_CUDA_CHECK(cudaMalloc((void**)&data_ptr, num_bytes));
+  auto deleter = [](void* ptr) { C10_CUDA_CHECK(cudaFree(ptr)); };
+  return CUDAPtr(data_ptr, deleter);
+}
+} // anonymous namespace
+
 namespace torch {
 namespace aot_inductor {

@@ -17,10 +41,11 @@ class AOTInductorModelContainer {
             << " model instances";
     TORCH_CHECK(num_models > 0, "expected num_models to be larger than 0");

+    constants_ = std::make_shared<ConstantMap>();
     models_.reserve(num_models);
     available_models_.reserve(num_models);
     for (size_t i = 0; i < num_models; ++i) {
-      models_.push_back(AOTInductorModel::Create());
+      models_.push_back(AOTInductorModel::Create(constants_));
       available_models_.push_back(models_.back().get());
     }

@@ -52,6 +77,51 @@ class AOTInductorModelContainer {
       output_dtypes_.push_back(model->get_output_dtype(i));
       max_output_shapes_.emplace_back(model->max_output_shape(i));
     }
+
+    size_t num_constants = model->num_constants();
+    std::vector<size_t> constants_internal_offset(num_constants);
+    // Compute required blob size with 64-byte alignment
+    size_t max_blob = 0;
+    for (size_t i = 0; i < num_constants; i++) {
+      size_t data_size = model->constant_data_size(i);
+      if (data_size % AOT_CONST_GPU_ALIGNMENT) {
+        data_size = AOT_CONST_GPU_ALIGNMENT +
+            (data_size / AOT_CONST_GPU_ALIGNMENT) * AOT_CONST_GPU_ALIGNMENT;
+      }
+      constants_internal_offset[i] = max_blob;
+      max_blob += data_size;
+    }
+    constant_blob_ = RAII_cudaMalloc(max_blob);
+
+    constants_->reserve(num_constants);
+    auto* constants_ptr = static_cast<uint8_t*>(constant_blob_.get());
+    size_t bytes_read = 0;
+    for (size_t i = 0; i < num_constants; i++) {
+      std::string name = model->constant_name(i);
+      size_t data_size = model->constant_data_size(i);
+      auto* internal_ptr = constants_ptr + constants_internal_offset[i];
+      // Copy data to GPU memory
+      // TODO: Handle shared storage case.
+      C10_CUDA_CHECK(cudaMemcpy(
+          internal_ptr,
+          _binary_constants_bin_start + bytes_read,
+          data_size,
+          cudaMemcpyHostToDevice));
+      bytes_read += data_size;
+
+      // Create at::Tensor from copied memory.
+      auto dtype = model->constant_type(i);
+      auto size = model->constant_shape(i);
+      auto stride = model->constant_stride(i);
+      auto offset = model->constant_offset(i);
+
+      auto tensor = at::for_blob(internal_ptr, size)
+                        .strides(stride)
+                        .storage_offset(offset)
+                        .options(at::device(at::kCUDA).dtype(dtype))
+                        .make_tensor();
+      constants_->emplace(std::move(name), tensor);
+    }
   }

   void run(
@@ -124,6 +194,13 @@ class AOTInductorModelContainer {
   // Holds the upper-bound value for each dimension of any output shape.
   std::vector<std::vector<int64_t>> max_output_shapes_;

+  // Holds the blob storage backing the constants' at::Tensors.
+  CUDAPtr constant_blob_;
+
+  // Holds the mapping of constants to at::Tensor.
+  // The underlying data of at::Tensor is in constant_blob_.
+  std::shared_ptr<ConstantMap> constants_;
+
   // Holds all the AOTInductorModel instances owned by this container.
   std::vector<std::unique_ptr<AOTInductorModel>> models_;
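For context on how the embedded constants end up in the final shared library: the `ld -r -b binary` and `objcopy` steps added in codecache.py turn constants.bin into an object file whose payload is addressable through the `_binary_constants_bin_start`/`_binary_constants_bin_end` symbols declared above. The standalone sketch below (hypothetical file names; it is not part of this patch) exercises the same mechanism in isolation. Because the input file here is literally named constants.bin, ld already emits symbols with that name, so the `--redefine-sym` step used by codecache.py for its hashed paths is not needed.

// read_constants.cpp -- standalone illustration only (not part of this patch).
// Build, mirroring the AotCodeCache.compile steps above:
//   ld -r -b binary -o constants.o constants.bin
//   objcopy --rename-section .data=.lrodata,alloc,load,readonly,data,contents constants.o constants.o
//   g++ read_constants.cpp constants.o -o read_constants
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Provided by the object file generated from constants.bin.
extern const uint8_t _binary_constants_bin_start[];
extern const uint8_t _binary_constants_bin_end[];

int main() {
  const auto nbytes =
      static_cast<size_t>(_binary_constants_bin_end - _binary_constants_bin_start);
  std::printf("embedded constants blob: %zu bytes\n", nbytes);
  // In AOTInductorModelContainer above, this region is copied with cudaMemcpy
  // into a single 64-byte-aligned device blob, and each constant is then
  // wrapped into an at::Tensor via at::for_blob().
  return 0;
}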