# torch.topk の動作チェック

In [4]:
import torch

pred = torch.tensor([0.05, 0.3, 0.2, 0.1, 0.05, 0.1, 0.05, 0.15])

# torch.topk returns a named tuple: the k largest entries (sorted, descending)
# and the positions where they occur in `pred`.
topk_result = torch.topk(pred, k=3)
values = topk_result.values
indices = topk_result.indices

print(values)
print(indices)

tensor([0.3000, 0.2000, 0.1500])
tensor([1, 2, 7])


In [9]:
# (2022.11.14メモ) top-k の正解率を求める
# https://github.com/pytorch/examples/blob/main/imagenet/main.py
# 上記のページの解説：https://zenn.dev/nnabeyang/articles/8b643ca99ddab2a568e0
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Adapted from the PyTorch ImageNet example (pytorch/examples, imagenet/main.py).

    Args:
        output: 2-D score tensor indexed as (batch_size, num_classes); the
            highest-scoring classes are treated as the predictions.
        target: Ground-truth class indices, one per sample (batch_size elements).
        topk: Iterable of k values to evaluate. Any k larger than num_classes
            is treated as num_classes (i.e. every class is a candidate).

    Returns:
        A list with one entry per k in ``topk`` (same order). Each entry is a
        1-element float tensor holding the percentage of samples whose target
        is among their k highest-scoring classes.

    Raises:
        ZeroDivisionError: if the batch is empty (``target.size(0) == 0``).
    """
    with torch.no_grad():
        # Tensor.topk raises when k exceeds the size of the dimension. Any
        # k >= num_classes already includes every class, so clamp instead.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # Indices of the maxk best classes per sample, sorted by descending
        # score: shape (batch, maxk), transposed to (maxk, batch).
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        # correct[i, j] is True when the (i+1)-th best guess for sample j equals
        # its target. reshape (not view) also handles non-contiguous targets.
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # topk yields distinct class indices, so each sample has at most one
            # True across rows; summing the first k rows counts top-k hits.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

# Example batch: 3 samples x 8 classes (one row of class scores per sample).
pred = torch.tensor(
    [
        [0.05, 0.30, 0.20, 0.10, 0.05, 0.10, 0.05, 0.15],
        [0.05, 0.20, 0.20, 0.10, 0.35, 0.10, 0.05, 0.15],
        [0.05, 0.05, 0.20, 0.05, 0.05, 0.10, 0.05, 0.45],
    ]
)

# Ground-truth class index for each sample.
target = torch.tensor([1, 2, 7])

# Top-1: samples 0 and 2 are correct (2/3); top-5: all three targets are covered.
accuracy(pred, target, topk=(1, 5))

[tensor([66.6667]), tensor([100.])]