diff --git a/references/classification/train.py b/references/classification/train.py index 0b855d105c9..29de4fce91c 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -325,6 +325,8 @@ def main(args): args.start_epoch = checkpoint["epoch"] + 1 if model_ema: model_ema.load_state_dict(checkpoint["model_ema"]) + if scaler: + scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: # We disable the cudnn benchmarking because it can noticeably affect the accuracy @@ -356,6 +358,8 @@ def main(args): } if model_ema: checkpoint["model_ema"] = model_ema.state_dict() + if scaler: + checkpoint["scaler"] = scaler.state_dict() utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))