diff --git a/references/detection/train.py b/references/detection/train.py
index 507d4faebae..722f4b4f72c 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -108,20 +108,21 @@ def main(args):
 
     # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
-
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-
+        args.start_epoch = checkpoint['epoch'] + 1
+
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
         return
 
     print("Start training")
     start_time = time.time()
-    for epoch in range(args.epochs):
+    for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
@@ -131,7 +132,8 @@ def main(args):
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
-                'args': args},
+                'args': args,
+                'epoch': epoch},
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
 
         # evaluate after every epoch
@@ -171,6 +173,7 @@ def main(args):
     parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
     parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
     parser.add_argument(
         "--test-only",
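
Below is a minimal, self-contained sketch of the save/resume round trip this patch enables. It is an illustration only: the tiny model, optimizer, milestone values, and file names are placeholders, not the real training script.

    import torch

    # Stand-in model/optimizer/scheduler (hypothetical; the real script builds a detection model).
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 11], gamma=0.1)

    # Save side: what train.py now writes after each epoch, including the new 'epoch' key.
    epoch = 3
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,  # new key added by this patch
    }, 'model_{}.pth'.format(epoch))

    # Resume side: what the --resume branch now does.
    checkpoint = torch.load('model_3.pth', map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    start_epoch = checkpoint['epoch'] + 1  # resume after the last completed epoch

    # Training loop picks up where it left off (epoch 4 here).
    for epoch in range(start_epoch, 13):
        optimizer.step()       # placeholder for train_one_epoch(...)
        lr_scheduler.step()

With the patch applied, an interrupted run can be continued with e.g. python train.py --resume model_3.pth: the 'epoch' key stored in the checkpoint lets the script compute args.start_epoch itself, while the new --start_epoch flag still allows setting the starting epoch manually when no checkpoint is loaded.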