From 4e24578abf9222a32039b19fcfb06397390db744 Mon Sep 17 00:00:00 2001
From: MultiK <596286458@qq.com>
Date: Mon, 2 Dec 2019 20:48:30 +0800
Subject: [PATCH 1/2] Fix a small bug in resume

When resuming, we need to start from the last epoch, not from 0.
---
 references/detection/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 507d4faebae..7097f5e71d9 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -108,20 +108,20 @@ def main(args):
 
     # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
-
+    last_epoch = 0
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-
+        last_epoch = lr_scheduler.last_epoch
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
         return
 
     print("Start training")
     start_time = time.time()
-    for epoch in range(args.epochs):
+    for epoch in range(last_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)

From 7bb3ed196278230042df7bd262840a9267742884 Mon Sep 17 00:00:00 2001
From: MultiK <596286458@qq.com>
Date: Tue, 3 Dec 2019 09:55:04 +0800
Subject: [PATCH 2/2] A second way to resume

Store the current epoch in the checkpoint and restore it as args.start_epoch when resuming.
---
 references/detection/train.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 7097f5e71d9..722f4b4f72c 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -108,20 +108,21 @@ def main(args):
 
     # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
-    last_epoch = 0
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        last_epoch = lr_scheduler.last_epoch
+        args.start_epoch = checkpoint['epoch'] + 1
+
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
         return
 
     print("Start training")
     start_time = time.time()
-    for epoch in range(last_epoch, args.epochs):
+    for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
@@ -131,7 +132,8 @@ def main(args):
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
-                'args': args},
+                'args': args,
+                'epoch': epoch},
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
 
         # evaluate after every epoch
@@ -171,6 +173,7 @@ def main(args):
     parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
     parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
     parser.add_argument(
         "--test-only",
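
For reference, the resume pattern introduced by the second patch (save the epoch in the checkpoint, restore it as the start epoch) can be illustrated outside the torchvision script. The sketch below is illustrative only: the helper names save_checkpoint/resume, the checkpoint path, and the dummy model are assumptions, not part of the patch.

import torch

# Hypothetical checkpoint path; the real script writes args.output_dir/model_{epoch}.pth.
CKPT_PATH = 'checkpoint.pth'


def save_checkpoint(model, optimizer, lr_scheduler, epoch, path=CKPT_PATH):
    # Store the epoch alongside model/optimizer/scheduler state,
    # mirroring the 'epoch': epoch entry added in the second patch.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }, path)


def resume(model, optimizer, lr_scheduler, path=CKPT_PATH):
    # Restore all state and return the epoch to resume from
    # (last completed epoch + 1), as the patch does with args.start_epoch.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    return checkpoint['epoch'] + 1


if __name__ == '__main__':
    # Tiny self-contained example with a dummy model; only the loop bounds
    # matter for demonstrating the resume logic.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2], gamma=0.1)

    save_checkpoint(model, optimizer, lr_scheduler, epoch=0)
    start_epoch = resume(model, optimizer, lr_scheduler)
    total_epochs = 3
    for epoch in range(start_epoch, total_epochs):
        # training for one epoch would happen here
        lr_scheduler.step()
        save_checkpoint(model, optimizer, lr_scheduler, epoch)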