In [1]:
import torch
import torchvision
import os
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 64 x 32 x 32

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 128 x 16 x 16

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # output: 256 x 8 x 8

            nn.Flatten(), 
            nn.Linear(16384, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 20))
        
    def forward(self, input):
        return self.network(input)

model = torch.load('t_max.pth')
print(model)

CNN(
  (network): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Flatten(start_dim=1, end_dim=-1)
    (16): Linear(in_features=16384, out_features=1024, bias=True)
    (17): ReLU()
    (18): Linear(in_features=1024, 

In [2]:
test_data = ImageFolder("../data/data/test", transform=ToTensor())
test_loader = DataLoader(test_data, 128*2, pin_memory=True)
test_data_size = len(test_data)

In [3]:
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()

model.eval()
total_test_loss = 0
total_accuracy = 0
for data in test_loader:
    imgs, targets = data
    if torch.cuda.is_available():
        imgs = imgs.cuda()
        targets = targets.cuda()
    output = model(imgs)
    loss = loss_fn(output, targets)
    total_test_loss += loss.item()
    accuracy = (output.argmax(1) == targets).sum()
    total_accuracy += accuracy

print("整体测试集上的Loss: {}".format(total_test_loss))
print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))


整体测试集上的Loss: 12.96743711325098
整体测试集上的正确率: 0.9579623937606812


In [4]:
import cv2

imgs = cv2.imread("mzd1.jpg")
imgs = cv2.resize(imgs, (64,64))
imgs = cv2.cvtColor(imgs, cv2.COLOR_BGR2RGB)


error: OpenCV(4.5.5) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\resize.cpp:4052: error: (-215:Assertion failed) !ssize.empty() in function 'cv::resize'


In [None]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
img = transform(imgs)
img = torch.reshape(img, (1, 3, 64, 64))
model.eval()
with torch.no_grad():
    output = model(img)
print(output)
flag = int(output.argmax(1))

flags = {0: 'bdsr', 1: 'csl', 2: 'fwq', 3: 'gj', 4: 'htj', 5: 'hy', 6: 'lgq', 7: 'lqs', 8: 'lx', 9: 'mf', 10: 'mzd', 11: 'oyx', 12: 'sgt', 13: 'shz', 14: 'smh', 15: 'wxz', 16: 'wzm', 17: 'yyr', 18: 'yzq', 19: 'zmf'}
print(flags[flag])


tensor([[ 0.9330, -6.4784,  2.0502, -7.3038, -0.2205, -3.2098, -6.6562,  0.3002,
         -3.1503,  4.5722, -2.6175, -7.4439, -2.8055, -8.3417, -3.4978,  2.9400,
         -9.9634,  1.0477, -4.9583, -0.8024]])
mf
