-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
150 lines (126 loc) · 5.54 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import torch
import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from nets.WCCNN import *
import matplotlib.pyplot as plt
from utils import read_split_data, plot_data_loader_image
from my_dataset import MyDataSet
from prettytable import PrettyTable
from tqdm import tqdm
import numpy as np
import json
class ConfusionMatrix(object):
    """Accumulate and report a multi-class confusion matrix.

    ``matrix[p, t]`` counts the samples that were predicted as class ``p``
    while their true label was ``t`` (rows = predictions, columns = truth).

    Notes:
        * If the plotted figure is truncated, it is a matplotlib version
          issue; the original author reports correct rendering with
          matplotlib 3.2.1 on Windows and Ubuntu.
        * ``summary`` requires the ``prettytable`` package.
    """

    def __init__(self, num_classes: int, labels: list):
        """Create an empty ``num_classes`` x ``num_classes`` matrix.

        Args:
            num_classes: Number of classes (side length of the matrix).
            labels: Display names, indexed by class id.
        """
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        """Add one batch of results to the matrix.

        Args:
            preds: Iterable of predicted class indices.
            labels: Iterable of true class indices, paired with ``preds``.
        """
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def _accuracy(self):
        """Return overall accuracy, or 0.0 when nothing has been recorded."""
        total = np.sum(self.matrix)
        if total == 0:
            # Avoid 0/0 -> nan (and a RuntimeWarning) on an empty matrix.
            return 0.0
        return np.trace(self.matrix) / total

    def _class_metrics(self):
        """Return ``(precision, recall, specificity)`` for every class.

        Each value is rounded to 3 decimals; a metric whose denominator is
        zero is reported as 0.
        """
        total = np.sum(self.matrix)
        metrics = []
        for i in range(self.num_classes):
            tp = self.matrix[i, i]
            fp = np.sum(self.matrix[i, :]) - tp  # predicted i, truth differs
            fn = np.sum(self.matrix[:, i]) - tp  # truth i, predicted otherwise
            tn = total - tp - fp - fn
            precision = round(tp / (tp + fp), 3) if tp + fp != 0 else 0.
            recall = round(tp / (tp + fn), 3) if tp + fn != 0 else 0.
            specificity = round(tn / (tn + fp), 3) if tn + fp != 0 else 0.
            metrics.append((precision, recall, specificity))
        return metrics

    def summary(self):
        """Print the overall accuracy and a per-class metrics table."""
        print("the model accuracy is ", self._accuracy())

        # precision, recall, specificity
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]
        for i, (precision, recall, specificity) in enumerate(self._class_metrics()):
            table.add_row([self.labels[i], precision, recall, specificity])
        print(table)

    def plot(self):
        """Print the raw matrix and show it as an annotated heat map."""
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # Tick labels for the x axis
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # Tick labels for the y axis
        plt.yticks(range(self.num_classes), self.labels)
        # Show the colorbar
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # Annotate every cell with its count; cells darker than half of the
        # maximum get white text so the number stays readable.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # Note: this is matrix[y, x], not matrix[x, y]
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()
if __name__ == '__main__':
    # Path of the trained model used for prediction. Built with os.path.join
    # so the string contains no invalid "\m" / "\W" escape sequences and the
    # script also runs on non-Windows systems.
    model_path = os.path.join(".", "models", "WCCNN_A_Dropout0.4_1000.pth")

    # Select the compute device, falling back to CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Root directory of the custom test set.
    # Alternative subsets: "./testset/0", "./testset/1", "./testset/2".
    root = "./testset/3"
    # Only the validation split returned by read_split_data is evaluated here.
    _, _, val_images_path, val_images_label = read_split_data(root)

    test_data_set = MyDataSet(images_path=val_images_path,
                              images_class=val_images_label,
                              transform="1")

    # Build the data loader.
    batch_size = 64
    test_dataloader = torch.utils.data.DataLoader(test_data_set,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=0,
                                                  collate_fn=test_data_set.collate_fn)

    # Load the network. map_location lets a checkpoint saved on a GPU be
    # loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Recent PyTorch releases may additionally require
    # weights_only=False to load a whole pickled model - confirm against the
    # installed version.
    model = torch.load(model_path, map_location=device)
    model = model.to(device)

    # Read class_indices.json and collect the class labels.
    json_label_path = './class_indices.json'
    if not os.path.exists(json_label_path):
        raise FileNotFoundError("cannot find {} file".format(json_label_path))
    with open(json_label_path, 'r') as json_file:
        class_indict = json.load(json_file)
    labels = list(class_indict.values())

    confusion = ConfusionMatrix(num_classes=13, labels=labels)  # number of classes

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        for imgs, targets in tqdm(test_dataloader):
            # Cast to float32 on the selected device; the previous
            # torch.cuda.FloatTensor cast failed when running on CPU.
            imgs = imgs.to(device=device, dtype=torch.float32)
            outputs = model(imgs)
            outputs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)
            confusion.update(preds.to("cpu").numpy(), targets.to("cpu").numpy())

    confusion.plot()
    confusion.summary()