Skip to content

Commit

Permalink
Fix memory leak from the SIGINT signal handler and DeepSpeed
Browse files: browse the repository at this point in the history
  • Loading branch information
chiragjn committed Apr 29, 2024
1 parent 5294653 commit 7ac62f5
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
configure_logging()
LOG = get_logger("axolotl.train")


@dataclass
class TrainDatasetMeta:
"""
Expand Down Expand Up @@ -126,15 +125,18 @@ def train(

# In case we want to stop early with Ctrl+C, it is nice to save the pretrained model before exiting
if cfg.local_rank == 0:

def terminate_handler(_, __, model):
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
import weakref
def terminate_handler(_, __, model_weakref):
if model_weakref():
_model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)

_model_weakref = weakref.ref(model)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, _model_weakref)
)

badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
Expand Down Expand Up @@ -216,6 +218,11 @@ def terminate_handler(_, __, model):
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()

if cfg.deepspeed:
trainer.deepspeed.destroy()
trainer.accelerator.free_memory()
trainer.model, trainer.optimizer = None, None

return model, tokenizer


Expand Down

0 comments on commit 7ac62f5

Please sign in to comment.