From 962ec0d913bbf6c30496560f12b4726445dce7da Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 31 Oct 2024 19:30:21 -0700 Subject: [PATCH 1/4] [AOTI] Remove the original model weights in Python deployment Summary: Fixes https://github.com/pytorch/torchchat/issues/1302. Because AOTI-compiled model contains a copy of model weights, we need to release the corresponding eager model weights in the Python deployment path. --- torchchat/cli/builder.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 511cf1f35..2f087a072 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -544,6 +544,19 @@ def _initialize_model( # attributes will NOT be seen on by AOTI-compiled forward # function, e.g. calling model.setup_cache will NOT touch # AOTI compiled and maintained model buffers such as kv_cache. + # Using cpp runner to run AOTI compiled model is recommended. + # + # Released the loaded model to free up device memory. + # The AOTI-compiled model contains a copy of the model weights. + model.model = None + import gc + gc.collect() + torch.cuda.empty_cache() + + def do_nothing(max_batch_size, max_seq_length): + pass + model.setup_caches = do_nothing + model.forward = torch._export.aot_load( str(builder_args.dso_path.absolute()), builder_args.device ) From e3acb5cb21315a0222ffb800bf08bc7955d79b4c Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 5 Nov 2024 09:04:08 -0500 Subject: [PATCH 2/4] Revert "[AOTI] Remove the original model weights in Python deployment" This reverts commit 962ec0d913bbf6c30496560f12b4726445dce7da. --- torchchat/cli/builder.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 47a16b7a4..17bc219f8 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -558,19 +558,6 @@ def _initialize_model( # attributes will NOT be seen on by AOTI-compiled forward # function, e.g. calling model.setup_cache will NOT touch # AOTI compiled and maintained model buffers such as kv_cache. - # Using cpp runner to run AOTI compiled model is recommended. - # - # Released the loaded model to free up device memory. - # The AOTI-compiled model contains a copy of the model weights. - model.model = None - import gc - gc.collect() - torch.cuda.empty_cache() - - def do_nothing(max_batch_size, max_seq_length): - pass - model.setup_caches = do_nothing - model.forward = torch._export.aot_load( str(builder_args.dso_path.absolute()), builder_args.device ) From 96718101a63c326a858c6dadd10be8a6af644323 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 5 Nov 2024 09:32:22 -0500 Subject: [PATCH 3/4] Refactor the code --- torchchat/cli/builder.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 17bc219f8..9da561d44 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -510,6 +510,15 @@ def _load_model(builder_args: BuilderArgs) -> Model: model = _load_model_default(builder_args) # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) + if builder_args.dso_path or builder_args.aoti_package_path: + # AOTI-compoiled model will load its own weights. + # Release weights here to avoid OOM + import gc + if hasattr(model, "model"): + model.model = None + gc.collect() + torch.cuda.empty_cache() + model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() @@ -558,6 +567,12 @@ def _initialize_model( # attributes will NOT be seen on by AOTI-compiled forward # function, e.g. calling model.setup_cache will NOT touch # AOTI compiled and maintained model buffers such as kv_cache. + # Using cpp runner to run AOTI compiled model is recommended. + + def do_nothing(max_batch_size, max_seq_length): + pass + model.setup_caches = do_nothing + model.forward = torch._export.aot_load( str(builder_args.dso_path.absolute()), builder_args.device ) From 978baa3bb66d678a5e0636f1efdb37a1c5eeb4ff Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 5 Nov 2024 16:17:11 -0500 Subject: [PATCH 4/4] Add setup_cache for aoti_package_path --- torchchat/cli/builder.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f2a85f4ff..fb2bfb299 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -632,6 +632,11 @@ def do_nothing(max_batch_size, max_seq_length): aoti_compiled_model = load_package( str(builder_args.aoti_package_path.absolute()) ) + + def do_nothing(max_batch_size, max_seq_length): + pass + model.setup_caches = do_nothing + model.forward = aoti_compiled_model metadata = aoti_compiled_model.get_metadata() builder_args.device = metadata["AOTI_DEVICE_KEY"]