In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # 添加这行导入语句
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Preprocessing pipeline and training data loader.
# ToTensor turns each image into a tensor; Normalize then applies
# (x - 0.5) / 0.5 to each of the three channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split stored under ./data (download=True fetches it if needed).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 512 samples, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=512,
    shuffle=True,
    num_workers=2,
)

# CNN model definition
class Net(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images.

    Architecture: two (3x3 conv -> ReLU -> 2x2 max-pool) stages with 32 and
    64 output channels, followed by a 128-unit hidden linear layer and a
    10-way linear output layer. With 32x32 inputs the second pooling stage
    yields 64 feature maps of size 6x6, which is what ``fc1`` expects.
    """

    def __init__(self):
        """Create the convolution, pooling and fully connected layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, 3, 32, 32). A single unbatched image
                of shape (3, 32, 32) is also accepted and treated as a
                batch of one.

        Returns:
            Tensor of shape (batch, 10) holding unnormalised class scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                64 * 6 * 6 features per sample that ``fc1`` expects.
        """
        if x.dim() == 3:
            # Treat a lone (C, H, W) image as a batch of one.
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 64 * 6 * 6), this never lets the
        # batch dimension absorb a feature-size mismatch, so wrongly sized
        # inputs fail loudly at fc1 instead of silently changing batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Build the CNN, the loss function and the optimizer.
cnn_model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn_model.parameters(), lr=0.01)

# Train the CNN for 10 passes over the training loader.
cnn_model.train()  # make the training mode explicit in case this cell is re-run after evaluation
for epoch in range(10):
    running_loss = 0.0
    samples_seen = 0
    for inputs, labels in trainloader:
        optimizer.zero_grad()

        outputs = cnn_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction) returns the mean over the batch.
        # Weight it by the batch size so a smaller final batch does not skew
        # the reported per-sample epoch loss.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
    print(f'Epoch {epoch + 1}, Loss: {running_loss / samples_seen}')

print('Finished Training')

# Evaluate the CNN on the CIFAR-10 test split.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

correct = 0
total = 0

# Switch to inference mode. This keeps the cell correct if mode-dependent
# layers (dropout, batch norm) are ever added to the model.
cnn_model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in testloader:
        outputs = cnn_model(inputs)
        # Predicted class = index of the largest logit. Using argmax avoids
        # binding `_`, which IPython reserves for the last cell output.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report two decimals; the former '%d' silently truncated the percentage.
print('Accuracy on the test set: %.2f %%' % (100 * correct / total))


Files already downloaded and verified
Epoch 1, Loss: 2.1010858869308704


KeyboardInterrupt: 

In [4]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
# Evaluate the CNN on the test split and report classification metrics.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

correct = 0
total = 0
all_predicted = []
all_labels = []

# Switch to inference mode. This keeps the cell correct if mode-dependent
# layers (dropout, batch norm) are ever added to the model.
cnn_model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in testloader:
        outputs = cnn_model(inputs)
        # Predicted class = index of the largest logit. Using argmax avoids
        # binding `_`, which IPython reserves for the last cell output.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Keep per-sample predictions and targets for the sklearn metrics below.
        all_predicted.extend(predicted.tolist())
        all_labels.extend(labels.tolist())

accuracy = 100 * correct / total
# Macro averaging: each class contributes equally regardless of its support.
precision = precision_score(all_labels, all_predicted, average='macro')
recall = recall_score(all_labels, all_predicted, average='macro')
f1 = f1_score(all_labels, all_predicted, average='macro')
# Rows are true labels, columns are predicted labels.
conf_matrix = confusion_matrix(all_labels, all_predicted)

print(f'Accuracy on the test set: {accuracy:.2f}%')
print(f'Precision on the test set: {precision*100:.2f}%')
print(f'Recall on the test set: {recall*100:.2f}%')
print(f'F1 score on the test set: {f1*100:.2f}%')
print('Confusion Matrix:')
print(conf_matrix)


Files already downloaded and verified
Accuracy on the test set: 57.00%
Precision on the test set: 57.42%
Recall on the test set: 57.00%
F1 score on the test set: 55.82%
Confusion Matrix:
[[747  41  28   7  14   9  16   8 114  16]
 [ 64 803   6   7   3   2   8   8  62  37]
 [149  29 390  60 108  61 102  53  42   6]
 [ 72  28  66 348  80 139 130  74  47  16]
 [ 73  19 123  46 429  31 136 114  27   2]
 [ 46   7  79 173  56 419  61 130  27   2]
 [ 28  21  53  48  48   9 744  24  18   7]
 [ 50  12  19  45  68  49  30 687  14  26]
 [159  60   2  11   2   6   6  12 726  16]
 [ 83 331   9  17   6   2  19  39  87 407]]
