From 52ab3c1a9d66f79b16a686eb2fd4ca6e0c3f4961 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Fri, 12 Nov 2021 16:00:36 +0800 Subject: [PATCH] save grad_scaler if use amp for better resume --- references/classification/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/references/classification/train.py b/references/classification/train.py index 0b855d105c9..29de4fce91c 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -325,6 +325,8 @@ def main(args): args.start_epoch = checkpoint["epoch"] + 1 if model_ema: model_ema.load_state_dict(checkpoint["model_ema"]) + if scaler: + scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: # We disable the cudnn benchmarking because it can noticeably affect the accuracy @@ -356,6 +358,8 @@ def main(args): } if model_ema: checkpoint["model_ema"] = model_ema.state_dict() + if scaler: + checkpoint["scaler"] = scaler.state_dict() utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))