In [1]:
import torch
from torch import nn

"""
如果外部一直说 device有问题，我们可以就在模型中加入device
"""


class CnnNet(nn.Module):
    def __init__(self, device):
        super(CnnNet, self).__init__()
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self.device = device
            raise IOError("设备不可以为空")
        # 3 140 440 作为输入  (140 --> 128 = 12)
        self.conv1 = nn.Conv2d(3, 16, 7, 1, device=device)  # 134 434
        self.conv2 = nn.Conv2d(16, 16, 7, 1, device=device)  # 128 428
        self.conv3 = nn.Conv2d(16, 64, 3, 1, 1, device=device)  # 64 214
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 1, device=device)  # 32 107
        self.conv5 = nn.Conv2d(128, 256, 3, 1, 1, device=device)  # 16 54
        self.conv6 = nn.Conv2d(256, 512, 3, 1, 1, device=device)  # 8 27
        self.conv7 = nn.Conv2d(512, 1024, 3, 1, 1, device=device)  # 4 13

        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(4 * 13 * 1024, 1024, device=device)
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(1024, 65 * 7, device=device)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.max_pool(self.relu(self.conv3(x)))
        x = self.max_pool(self.relu(self.conv4(x)))
        x = self.max_pool(self.relu(self.conv5(x)))
        x = self.max_pool(self.relu(self.conv6(x)))
        x = self.max_pool(self.relu(self.conv7(x)))

        x = self.flatten(x)

        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        x = x.view(-1, 7, 65)
        return self.softmax(x)

# if __name__ == '__main__':
#     model = CNN()
#     images = torch.randn(1, 3, 140, 440)
#     result = model(images)
#     print(result.shape)


In [2]:
pip install pynvml

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting pynvml
  Downloading http://mirrors.aliyun.com/pypi/packages/54/5b/16e50abf152be7f18120f11dfff495014a9eaff7b764626e1656f04ad262/pynvml-11.5.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pynvml
Successfully installed pynvml-11.5.3
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
from torch import nn
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pynvml


class CustomDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


total_dataset = CustomDataset(torch.load("./data/cnn/cnn_images.pth", weights_only=False))
train_dataset = total_dataset[:800]
valid_dataset = total_dataset[800:]
# 加载每个车牌的信息
total_labels = torch.load("./data/cnn/labels.pth", weights_only=False)
train_labels = torch.tensor(total_labels[:800])
valid_labels = torch.tensor(total_labels[800:])

train_loader = DataLoader(train_dataset, batch_size=60, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=60, shuffle=False)
train_labels_loader = DataLoader(train_labels, batch_size=60, shuffle=False)
valid_labels_loader = DataLoader(valid_labels, batch_size=60, shuffle=False)

device = torch.device("cuda:0")
model = CnnNet(device)
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# 显存监控（假如20batch_size 显存未满，则向上调解)）
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# 训练模型
num_epochs = 10000
# 最佳最准确率
best_acc = -1
for epoch in range(num_epochs):
    total_loss = 0.0
    total_acc = []
    for images, labels in zip(train_loader, train_labels_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # 前向传播
        outputs = model(images.float())
        # 如果数据不对成，就要对labels加入one_hot
        labels = F.one_hot(labels,65)
        loss = criterion(outputs, labels.float())

        # 反向传播和优化
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    with torch.no_grad():
        for images, labels in zip(valid_loader, valid_labels_loader):
            val_images = images.to(device)
            val_labels = labels.to(device)

            val_outputs = model(val_images.float())
            predicted = torch.argmax(val_outputs,dim=-1)
            
            total_acc.append(torch.sum(predicted == val_labels) / labels.shape[0])

    acc = 100 * (sum(total_acc) / len(total_acc))
    # 打印训练信息
    print('Epoch [{}/{}], Loss: {:.4f}, acc: {:.2f}%'
          .format(epoch + 1, num_epochs, total_loss / len(train_loader), acc))

    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), 'best.pth')

    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
    print(f'GPU in use {(meminfo.used / 1024 ** 3):.4f}/{meminfo.total / 1024 ** 3} G')


Epoch [1/10000], Loss: 0.2098, acc: 18.33%
GPU in use 8.8900/10.0 G
Epoch [2/10000], Loss: 0.2096, acc: 15.83%
GPU in use 8.8900/10.0 G
Epoch [3/10000], Loss: 0.2093, acc: 18.33%
GPU in use 8.8900/10.0 G
Epoch [4/10000], Loss: 0.2094, acc: 17.50%
GPU in use 8.8900/10.0 G
Epoch [5/10000], Loss: 0.2095, acc: 17.92%
GPU in use 8.8900/10.0 G
Epoch [6/10000], Loss: 0.2095, acc: 17.50%
GPU in use 8.8900/10.0 G
Epoch [7/10000], Loss: 0.2095, acc: 17.50%
GPU in use 8.8900/10.0 G
Epoch [8/10000], Loss: 0.2097, acc: 17.50%
GPU in use 8.8900/10.0 G
Epoch [9/10000], Loss: 0.2096, acc: 20.42%
GPU in use 8.8900/10.0 G
Epoch [10/10000], Loss: 0.2093, acc: 14.58%
GPU in use 8.8900/10.0 G
Epoch [11/10000], Loss: 0.2094, acc: 17.08%
GPU in use 8.8900/10.0 G
Epoch [12/10000], Loss: 0.2096, acc: 16.67%
GPU in use 8.8900/10.0 G
Epoch [13/10000], Loss: 0.2096, acc: 16.67%
GPU in use 8.8900/10.0 G
Epoch [14/10000], Loss: 0.2093, acc: 19.17%
GPU in use 8.8900/10.0 G
Epoch [15/10000], Loss: 0.2098, acc: 16.67%

KeyboardInterrupt: 