Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions imagenet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use.')

best_prec1 = 0
best_acc1 = 0


def main():
global args, best_prec1
global args, best_acc1
args = parser.parse_args()

if args.seed is not None:
Expand Down Expand Up @@ -122,7 +122,7 @@ def main():
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
best_acc1 = checkpoint['best_acc1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
Expand Down Expand Up @@ -179,16 +179,16 @@ def main():
train(train_loader, model, criterion, optimizer, epoch)

# evaluate on validation set
prec1 = validate(val_loader, model, criterion)
acc1 = validate(val_loader, model, criterion)

# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'best_acc1': best_acc1,
'optimizer' : optimizer.state_dict(),
}, is_best)

Expand Down Expand Up @@ -217,10 +217,10 @@ def train(train_loader, model, criterion, optimizer, epoch):
loss = criterion(output, target)

# measure accuracy and record loss
prec1, prec5 = accuracy(output, target, topk=(1, 5))
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
top1.update(acc1[0], input.size(0))
top5.update(acc5[0], input.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
Expand All @@ -236,8 +236,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))

Expand All @@ -263,10 +263,10 @@ def validate(val_loader, model, criterion):
loss = criterion(output, target)

# measure accuracy and record loss
prec1, prec5 = accuracy(output, target, topk=(1, 5))
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
top1.update(acc1[0], input.size(0))
top5.update(acc5[0], input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
Expand All @@ -276,12 +276,12 @@ def validate(val_loader, model, criterion):
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))

print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))

return top1.avg
Expand Down