From 0c9c01a5826750c3131ed28525db398cd5f171e6 Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Fri, 14 Jun 2024 13:46:00 +0000
Subject: [PATCH] Fix multi gpu lora merging

---
 train.py |  4 ++--
 utils.py | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index 394815a..aadb7e1 100644
--- a/train.py
+++ b/train.py
@@ -30,7 +30,7 @@ from utils import (
     maybe_set_custom_tempdir,
     maybe_set_torch_max_memory,
-    temporarily_unset_accelerate_envs,
+    temporarily_unset_distributed_envs,
     try_cleanup_gpus,
 )
 
 
@@ -252,7 +252,7 @@ def _train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
     model_dir = cfg.output_dir
     cleanup_checkpoints(output_dir=cfg.output_dir)
     if cfg.adapter in {"lora", "qlora"}:
-        with temporarily_unset_accelerate_envs():
+        with temporarily_unset_distributed_envs():
             axolotl_merge_lora_cli(config=axolotl_config, device_map="auto")
         model_dir = os.path.join(model_dir, "merged")
     model_parent_dir = os.path.dirname(model_dir)

diff --git a/utils.py b/utils.py
index 8fc26a8..e4bb7ee 100644
--- a/utils.py
+++ b/utils.py
@@ -104,13 +104,13 @@ def maybe_set_torch_max_memory(device: int):
 
 
 @contextlib.contextmanager
-def temporarily_unset_accelerate_envs():
-    accelerate_envs = {}
+def temporarily_unset_distributed_envs():
+    old_envs = {}
     for key in os.environ:
-        if key.startswith("ACCELERATE_"):
-            accelerate_envs[key] = os.environ.pop(key)
+        if key.startswith("ACCELERATE_") or key in {"WORLD_SIZE"}:
+            old_envs[key] = os.environ.pop(key)
     yield
-    os.environ.update(accelerate_envs)
+    os.environ.update(old_envs)
 
 
 # Notebook Utils
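
Note: below is a minimal standalone sketch of the env-stashing pattern this
patch adopts, for readers who want the idea in isolation. It is not the exact
code from utils.py: the explicit key-list copy and the try/finally restore are
defensive additions for illustration only.

    import contextlib
    import os

    @contextlib.contextmanager
    def temporarily_unset_distributed_envs():
        # Stash and remove env vars that make downstream tooling believe this
        # is a multi-process run (ACCELERATE_* from the accelerate launcher,
        # WORLD_SIZE from torchrun), so the LoRA merge can run as a single
        # process with device_map="auto".
        saved = {}
        for key in list(os.environ):  # copy keys before popping
            if key.startswith("ACCELERATE_") or key == "WORLD_SIZE":
                saved[key] = os.environ.pop(key)
        try:
            yield
        finally:
            # Restore the original environment even if the body raises.
            os.environ.update(saved)

    # Usage mirroring the train.py hunk above:
    # with temporarily_unset_distributed_envs():
    #     axolotl_merge_lora_cli(config=axolotl_config, device_map="auto")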