diff --git a/references/detection/train.py b/references/detection/train.py index 722f4b4f72c..40459d110c3 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -108,7 +108,7 @@ def main(args): # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) - + if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model'])