Skip to content

Commit

Permalink
Fix multi gpu lora merging
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Jun 14, 2024
1 parent 33bb49d commit 0c9c01a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from utils import (
maybe_set_custom_tempdir,
maybe_set_torch_max_memory,
temporarily_unset_accelerate_envs,
temporarily_unset_distributed_envs,
try_cleanup_gpus,
)

Expand Down Expand Up @@ -252,7 +252,7 @@ def _train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
model_dir = cfg.output_dir
cleanup_checkpoints(output_dir=cfg.output_dir)
if cfg.adapter in {"lora", "qlora"}:
with temporarily_unset_accelerate_envs():
with temporarily_unset_distributed_envs():
axolotl_merge_lora_cli(config=axolotl_config, device_map="auto")
model_dir = os.path.join(model_dir, "merged")
model_parent_dir = os.path.dirname(model_dir)
Expand Down
10 changes: 5 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ def maybe_set_torch_max_memory(device: int):


@contextlib.contextmanager
def temporarily_unset_accelerate_envs():
accelerate_envs = {}
def temporarily_unset_distributed_envs():
old_envs = {}
for key in os.environ:
if key.startswith("ACCELERATE_"):
accelerate_envs[key] = os.environ.pop(key)
if key.startswith("ACCELERATE_") or key in {"WORLD_SIZE"}:
old_envs[key] = os.environ.pop(key)
yield
os.environ.update(accelerate_envs)
os.environ.update(old_envs)


# Notebook Utils
Expand Down

0 comments on commit 0c9c01a

Please sign in to comment.