diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 89d83f9f1..5c34d99df 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -37,7 +37,6 @@ configure_logging() LOG = get_logger("axolotl.train") - @dataclass class TrainDatasetMeta: """ @@ -139,10 +138,10 @@ def terminate_handler(_, __, model_weakref): sys.exit(0) _model_weakref = weakref.ref(model) - signal.signal( - signal.SIGINT, - lambda signum, frame: terminate_handler(signum, frame, _model_weakref), - ) + # signal.signal( + # signal.SIGINT, + # lambda signum, frame: terminate_handler(signum, frame, _model_weakref), + # ) badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" @@ -234,6 +233,11 @@ def terminate_handler(_, __, model_weakref): # defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + if cfg.deepspeed: + trainer.deepspeed.destroy() + trainer.accelerator.free_memory() + trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None + return model, tokenizer