# 测试集图像分类预测结果

使用训练好的图像分类模型，预测测试集的所有图像，得到预测结果表格。

同济子豪兄：https://space.bilibili.com/1900783

[代码运行云GPU环境](https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1)：GPU RTX 3060、CUDA v11.2

## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# # 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
# train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#                                      ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）

In [3]:
# 数据集文件夹路径
dataset_dir = 'fruit30_split'
test_path = os.path.join(dataset_dir, 'val')
from torchvision import datasets
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
# 载入类别名称 和 ID索引号 的映射字典
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
# 获得类别名称
classes = list(idx_to_labels.values())
print(classes)

测试集图像数量 1078
类别个数 30
各类别名称 ['哈密瓜', '圣女果', '山竹', '杨梅', '柚子', '柠檬', '桂圆', '梨', '椰子', '榴莲', '火龙果', '猕猴桃', '石榴', '砂糖橘', '胡萝卜', '脐橙', '芒果', '苦瓜', '苹果-红', '苹果-青', '草莓', '荔枝', '菠萝', '葡萄-白', '葡萄-红', '西瓜', '西红柿', '车厘子', '香蕉', '黄瓜']
['哈密瓜', '圣女果', '山竹', '杨梅', '柚子', '柠檬', '桂圆', '梨', '椰子', '榴莲', '火龙果', '猕猴桃', '石榴', '砂糖橘', '胡萝卜', '脐橙', '芒果', '苦瓜', '苹果-红', '苹果-青', '草莓', '荔枝', '菠萝', '葡萄-白', '葡萄-红', '西瓜', '西红柿', '车厘子', '香蕉', '黄瓜']


## 导入训练好的模型

In [4]:
model = torch.load('checkpoints/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)

## 表格A-测试集图像路径及标注

In [5]:
test_dataset.imgs[:10]

[('fruit30_split/val/哈密瓜/106.jpg', 0),
 ('fruit30_split/val/哈密瓜/109.jpg', 0),
 ('fruit30_split/val/哈密瓜/114.jpg', 0),
 ('fruit30_split/val/哈密瓜/116.jpg', 0),
 ('fruit30_split/val/哈密瓜/118.png', 0),
 ('fruit30_split/val/哈密瓜/123.jpg', 0),
 ('fruit30_split/val/哈密瓜/127.jpg', 0),
 ('fruit30_split/val/哈密瓜/129.jpg', 0),
 ('fruit30_split/val/哈密瓜/131.jpg', 0),
 ('fruit30_split/val/哈密瓜/133.jpg', 0)]

In [6]:
img_paths = [each[0] for each in test_dataset.imgs]

In [7]:
df = pd.DataFrame()
df['图像路径'] = img_paths
df['标注类别ID'] = test_dataset.targets
df['标注类别名称'] = [idx_to_labels[ID] for ID in test_dataset.targets]

In [8]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,fruit30_split/val/哈密瓜/106.jpg,0,哈密瓜
1,fruit30_split/val/哈密瓜/109.jpg,0,哈密瓜
2,fruit30_split/val/哈密瓜/114.jpg,0,哈密瓜
3,fruit30_split/val/哈密瓜/116.jpg,0,哈密瓜
4,fruit30_split/val/哈密瓜/118.png,0,哈密瓜
...,...,...,...
1073,fruit30_split/val/黄瓜/87.jpg,29,黄瓜
1074,fruit30_split/val/黄瓜/9.jpg,29,黄瓜
1075,fruit30_split/val/黄瓜/91.png,29,黄瓜
1076,fruit30_split/val/黄瓜/94.jpg,29,黄瓜


## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [9]:
# 记录 top-n 预测结果
n = 3

In [10]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = idx_to_labels[pred_ids[i-1]]
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred.append(pred_dict, ignore_index=True)

  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred_dict, ignore_index=True)
  df_pred = df_pred.append(pred

In [11]:
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,哈密瓜-预测置信度,圣女果-预测置信度,山竹-预测置信度,...,草莓-预测置信度,荔枝-预测置信度,菠萝-预测置信度,葡萄-白-预测置信度,葡萄-红-预测置信度,西瓜-预测置信度,西红柿-预测置信度,车厘子-预测置信度,香蕉-预测置信度,黄瓜-预测置信度
0,4,柚子,5,柠檬,7,梨,False,0.0023508405,4.9450587e-06,4.040416e-07,...,1.8150844e-07,1.2343869e-06,3.243423e-06,1.1201442e-05,6.4479555e-06,0.00011587413,0.00012866968,4.1425835e-07,4.6851837e-06,6.217669e-07
1,6,桂圆,0,哈密瓜,8,椰子,True,0.38112196,2.2620793e-07,5.739677e-06,...,7.8048956e-08,1.3498897e-06,9.750311e-07,0.0015106951,4.2909367e-05,0.00015741796,6.638699e-07,3.0484532e-06,3.217701e-05,2.3868986e-06
2,0,哈密瓜,26,西红柿,23,葡萄-白,True,0.53620106,0.0065545,0.006593606,...,0.00933481,0.0071764397,0.0010388161,0.037527516,0.034991886,0.0015775696,0.26540196,0.00016203609,0.005668558,0.001115545
3,0,哈密瓜,16,芒果,4,柚子,True,0.7525962,7.142698e-05,2.3597856e-06,...,3.1976517e-05,0.00025365953,6.0032762e-05,0.0015842838,3.3540405e-06,0.00027976301,0.00072566525,2.2601515e-07,0.02193577,0.00038454978
4,4,柚子,11,猕猴桃,23,葡萄-白,False,0.0050168657,0.00012685276,3.7899506e-05,...,0.0007075434,6.7957946e-05,7.4083924e-05,0.115253285,0.00076184527,0.00039999862,0.0028934702,2.9521214e-08,0.00033459536,0.0004361433
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1073,29,黄瓜,25,西瓜,17,苦瓜,True,5.93891e-06,1.0501529e-08,9.4484065e-09,...,2.6475979e-08,2.7889127e-10,1.8306792e-09,1.6883534e-05,4.3673448e-10,0.00083294965,6.809828e-07,3.9479784e-09,2.2288005e-07,0.99909365
1074,29,黄瓜,17,苦瓜,19,苹果-青,True,1.1680233e-07,1.3517191e-08,1.0803882e-09,...,1.07798314e-07,3.7485877e-09,9.369248e-09,3.7866417e-07,2.3112412e-08,1.9973343e-06,6.2446254e-08,2.1707587e-08,7.969958e-07,0.9997627
1075,29,黄瓜,17,苦瓜,23,葡萄-白,True,0.0010911732,0.0001325374,3.1810266e-06,...,0.00028670047,0.00068920944,0.00012004605,0.0062547335,0.0005051993,1.2841905e-05,0.00014438466,7.0835677e-07,0.00039330628,0.9471092
1076,29,黄瓜,17,苦瓜,10,火龙果,True,1.3076195e-06,2.0403022e-05,2.5345668e-07,...,0.00047119765,2.4162005e-06,0.00011817078,0.0013546849,0.00046067874,0.0014860139,4.860827e-05,4.6543205e-06,0.00011344947,0.9820547


## 拼接AB两张表格

In [12]:
df = pd.concat([df, df_pred], axis=1)

In [13]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,...,草莓-预测置信度,荔枝-预测置信度,菠萝-预测置信度,葡萄-白-预测置信度,葡萄-红-预测置信度,西瓜-预测置信度,西红柿-预测置信度,车厘子-预测置信度,香蕉-预测置信度,黄瓜-预测置信度
0,fruit30_split/val/哈密瓜/106.jpg,0,哈密瓜,4,柚子,5,柠檬,7,梨,False,...,1.8150844e-07,1.2343869e-06,3.243423e-06,1.1201442e-05,6.4479555e-06,0.00011587413,0.00012866968,4.1425835e-07,4.6851837e-06,6.217669e-07
1,fruit30_split/val/哈密瓜/109.jpg,0,哈密瓜,6,桂圆,0,哈密瓜,8,椰子,True,...,7.8048956e-08,1.3498897e-06,9.750311e-07,0.0015106951,4.2909367e-05,0.00015741796,6.638699e-07,3.0484532e-06,3.217701e-05,2.3868986e-06
2,fruit30_split/val/哈密瓜/114.jpg,0,哈密瓜,0,哈密瓜,26,西红柿,23,葡萄-白,True,...,0.00933481,0.0071764397,0.0010388161,0.037527516,0.034991886,0.0015775696,0.26540196,0.00016203609,0.005668558,0.001115545
3,fruit30_split/val/哈密瓜/116.jpg,0,哈密瓜,0,哈密瓜,16,芒果,4,柚子,True,...,3.1976517e-05,0.00025365953,6.0032762e-05,0.0015842838,3.3540405e-06,0.00027976301,0.00072566525,2.2601515e-07,0.02193577,0.00038454978
4,fruit30_split/val/哈密瓜/118.png,0,哈密瓜,4,柚子,11,猕猴桃,23,葡萄-白,False,...,0.0007075434,6.7957946e-05,7.4083924e-05,0.115253285,0.00076184527,0.00039999862,0.0028934702,2.9521214e-08,0.00033459536,0.0004361433
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1073,fruit30_split/val/黄瓜/87.jpg,29,黄瓜,29,黄瓜,25,西瓜,17,苦瓜,True,...,2.6475979e-08,2.7889127e-10,1.8306792e-09,1.6883534e-05,4.3673448e-10,0.00083294965,6.809828e-07,3.9479784e-09,2.2288005e-07,0.99909365
1074,fruit30_split/val/黄瓜/9.jpg,29,黄瓜,29,黄瓜,17,苦瓜,19,苹果-青,True,...,1.07798314e-07,3.7485877e-09,9.369248e-09,3.7866417e-07,2.3112412e-08,1.9973343e-06,6.2446254e-08,2.1707587e-08,7.969958e-07,0.9997627
1075,fruit30_split/val/黄瓜/91.png,29,黄瓜,29,黄瓜,17,苦瓜,23,葡萄-白,True,...,0.00028670047,0.00068920944,0.00012004605,0.0062547335,0.0005051993,1.2841905e-05,0.00014438466,7.0835677e-07,0.00039330628,0.9471092
1076,fruit30_split/val/黄瓜/94.jpg,29,黄瓜,29,黄瓜,17,苦瓜,10,火龙果,True,...,0.00047119765,2.4162005e-06,0.00011817078,0.0013546849,0.00046067874,0.0014860139,4.860827e-05,4.6543205e-06,0.00011344947,0.9820547


## 导出完整表格

In [14]:
df.to_csv('测试集预测结果.csv', index=False)