In [1]:
import torchvision

# Open the ImageNet validation split with no transform attached.
# NOTE(review): `teseset` is not referenced by the other visible cells -- presumably a
# quick check that the dataset can be read from this root; confirm before removing.
teseset = torchvision.datasets.ImageNet(
    root='../data/torch_imagenet',
    split='val',
)

In [2]:
import torch
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms

In [3]:
def count_topk_correct(outputs, labels, k=5):
    """Count the samples whose true label is among the ``k`` highest-scoring classes.

    Args:
        outputs: 2-D score tensor, one row per sample and one column per class.
        labels: 1-D tensor of class indices, one per row of ``outputs``.
        k: how many top-scoring classes to accept per sample. Clamped to the
            number of classes so a small class count does not make ``topk`` fail.

    Returns:
        int: number of rows whose label appears among that row's top-``k`` indices.
    """
    k = min(k, outputs.size(1))
    _, topk_idx = outputs.topk(k, dim=1, largest=True, sorted=True)    # [B, k]
    # topk indices are distinct within a row, so each row matches its label at most once.
    hits = topk_idx.eq(labels.view(-1, 1))                             # [B, k]
    return int(hits.any(dim=1).sum().item())


if __name__ == "__main__":
    # Evaluate a pretrained ResNet-18 on the ImageNet validation split and report
    # running and final top-1 / top-5 accuracy.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Channel-wise mean/std that torchvision documents for its ImageNet-pretrained models.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         normalize,
         ])

    test_set = torchvision.datasets.ImageNet(root='../data/torch_imagenet', transform=transform, split='val')
    # Shuffling does not change the final totals (every sample is counted once); it only
    # changes the order in which batches contribute to the running percentages below.
    test_loader = data.DataLoader(test_set, batch_size=100, shuffle=True, num_workers=4)
    num_steps = len(test_loader)    # true batch count, also correct when the last batch is short

    model = torchvision.models.resnet18(pretrained=True).to(device)
    model.eval()    # inference mode for batch-norm / dropout layers

    correct_top1 = 0
    correct_top5 = 0
    total = 0

    with torch.no_grad():
        for idx, (images, labels) in enumerate(test_loader):

            images = images.to(device)      # [B, 3, 224, 224]
            labels = labels.to(device)      # [B]
            outputs = model(images)

            total += labels.size(0)

            # ------------------------------------------------------------------------------
            # rank 1
            _, pred = torch.max(outputs, 1)
            correct_top1 += (pred == labels).sum().item()

            # ------------------------------------------------------------------------------
            # rank 5
            correct_top5 += count_topk_correct(outputs, labels, k=5)

            print("step : {} / {}".format(idx + 1, num_steps))
            print("top-1 percentage :  {0:0.2f}%".format(correct_top1 / total * 100))
            print("top-5 percentage :  {0:0.2f}%".format(correct_top5 / total * 100))

    print("top-1 percentage :  {0:0.2f}%".format(correct_top1 / total * 100))
    print("top-5 percentage :  {0:0.2f}%".format(correct_top5 / total * 100))

step : 1 / 500.0
top-1 percentage :  70.00%
top-5 percentage :  87.00%
step : 2 / 500.0
top-1 percentage :  68.00%
top-5 percentage :  88.00%
step : 3 / 500.0
top-1 percentage :  67.33%
top-5 percentage :  88.67%
step : 4 / 500.0
top-1 percentage :  67.50%
top-5 percentage :  89.50%
step : 5 / 500.0
top-1 percentage :  68.20%
top-5 percentage :  90.00%
step : 6 / 500.0
top-1 percentage :  68.00%
top-5 percentage :  89.17%
step : 7 / 500.0
top-1 percentage :  68.43%
top-5 percentage :  89.43%
step : 8 / 500.0
top-1 percentage :  68.38%
top-5 percentage :  89.12%
step : 9 / 500.0
top-1 percentage :  69.44%
top-5 percentage :  89.33%
step : 10 / 500.0
top-1 percentage :  70.10%
top-5 percentage :  90.10%
step : 11 / 500.0
top-1 percentage :  70.73%
top-5 percentage :  90.00%
step : 12 / 500.0
top-1 percentage :  70.50%
top-5 percentage :  90.08%
step : 13 / 500.0
top-1 percentage :  70.00%
top-5 percentage :  90.00%
step : 14 / 500.0
top-1 percentage :  70.00%
top-5 percentage :  89.93%
s