diff --git a/imagenet/main.py b/imagenet/main.py
index f0f196a1e5..2a1540ba13 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -71,7 +71,14 @@ def main():
     else:
         model = torch.nn.DataParallel(model).cuda()
 
-    # optionally resume from a checkpoint
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
@@ -79,6 +86,7 @@ def main():
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
@@ -112,13 +120,6 @@ def main():
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True)
 
-    # define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
-
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
-
     if args.evaluate:
         validate(val_loader, model, criterion)
         return
@@ -140,6 +141,7 @@ def main():
             'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
+            'optimizer' : optimizer.state_dict(),
         }, is_best)
 
 
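The diff moves the criterion/optimizer construction above the resume block so that `optimizer.load_state_dict(checkpoint['optimizer'])` has an optimizer to restore into, and it adds the optimizer state to the saved checkpoint so momentum buffers survive a restart. Below is a minimal standalone sketch (not part of this patch) of that save/resume round trip; the toy `nn.Linear` model and the `checkpoint.pth.tar` path are placeholders, not the ImageNet example's actual setup.

```python
# Sketch of the checkpoint round trip enabled by the patch; model and path are
# toy placeholders, not taken from imagenet/main.py.
import os
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
# The optimizer must be constructed before its state can be restored, which is
# why the patch moves this above the resume block.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

ckpt_path = "checkpoint.pth.tar"  # hypothetical path

# Save model and optimizer state together, mirroring the patched save_checkpoint call.
torch.save({
    'epoch': 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, ckpt_path)

# Resume: restore model weights and optimizer buffers (e.g. SGD momentum).
if os.path.isfile(ckpt_path):
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
```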