In [1]:
import copy
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def setup_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(20)

In [3]:
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=1, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )  
        self.fc1 = nn.Linear(64, 64)  
        self.fc2 = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)  
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.softmax(x)
        return F.log_softmax(x, dim=1)

In [4]:
def get_val_loss(model, Val):
    model.eval()
    criterion = nn.CrossEntropyLoss().to(device)
    val_loss = []
    for (data, target) in Val:
        data, target = data.to(device), target.long().to(device)
        output = model(data)
        loss = criterion(output, target)
        val_loss.append(loss.cpu().item())

    return np.mean(val_loss)


In [9]:
def train():
    writer = SummaryWriter("log/")
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
    print('train...')
    epoch_num = 50
    best_model = None
    min_epochs = 5
    min_val_loss = np.inf
    model = cnn().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0008)
    criterion = nn.CrossEntropyLoss().to(device)
    for epoch in tqdm(range(epoch_num), ascii=True):
        train_loss = []
        for batch_idx, (data, target) in enumerate(train_data_loader):
            data, target = data.to(device), target.long().to(device)
            model.train()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.cpu().item())
        val_loss = get_val_loss(model, val_data_loader)
        writer.add_scalar("val_loss", val_loss, epoch)
        if epoch + 1 > min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = copy.deepcopy(model)
        tqdm.write('Epoch {:03d} train_loss {:.5f} val_loss {:.5f}'.format(epoch, np.mean(train_loss), val_loss))
    torch.save(best_model.state_dict(), "G:/资料/model/cnn.pkl")

In [10]:
train()

train...


  2%|2         | 1/50 [00:12<10:12, 12.51s/it]

Epoch 000 train_loss 1.75925 val_loss 1.67382


  4%|4         | 2/50 [00:24<09:51, 12.33s/it]

Epoch 001 train_loss 1.66943 val_loss 1.66236


  6%|6         | 3/50 [00:36<09:31, 12.15s/it]

Epoch 002 train_loss 1.66210 val_loss 1.65963


  8%|8         | 4/50 [00:48<09:11, 12.00s/it]

Epoch 003 train_loss 1.65761 val_loss 1.65710


 10%|#         | 5/50 [01:00<08:56, 11.92s/it]

Epoch 004 train_loss 1.65487 val_loss 1.65508


 12%|#2        | 6/50 [01:11<08:41, 11.86s/it]

Epoch 005 train_loss 1.65260 val_loss 1.65372


 14%|#4        | 7/50 [01:23<08:28, 11.81s/it]

Epoch 006 train_loss 1.65125 val_loss 1.65384


 16%|#6        | 8/50 [01:35<08:16, 11.82s/it]

Epoch 007 train_loss 1.65054 val_loss 1.65284


 18%|#8        | 9/50 [01:47<08:03, 11.80s/it]

Epoch 008 train_loss 1.64936 val_loss 1.65369


 20%|##        | 10/50 [01:59<07:51, 11.79s/it]

Epoch 009 train_loss 1.64862 val_loss 1.65157


 22%|##2       | 11/50 [02:10<07:40, 11.82s/it]

Epoch 010 train_loss 1.64844 val_loss 1.65170


 24%|##4       | 12/50 [02:23<07:36, 12.01s/it]

Epoch 011 train_loss 1.64670 val_loss 1.65201


 26%|##6       | 13/50 [02:35<07:24, 12.01s/it]

Epoch 012 train_loss 1.64646 val_loss 1.65071


 28%|##8       | 14/50 [02:47<07:09, 11.93s/it]

Epoch 013 train_loss 1.64626 val_loss 1.64838


 30%|###       | 15/50 [02:59<06:58, 11.96s/it]

Epoch 014 train_loss 1.64524 val_loss 1.64865


 32%|###2      | 16/50 [03:11<06:48, 12.02s/it]

Epoch 015 train_loss 1.64446 val_loss 1.65019


 34%|###4      | 17/50 [03:23<06:37, 12.04s/it]

Epoch 016 train_loss 1.64458 val_loss 1.64876


 36%|###6      | 18/50 [03:35<06:26, 12.07s/it]

Epoch 017 train_loss 1.64388 val_loss 1.64670


 38%|###8      | 19/50 [03:47<06:12, 12.01s/it]

Epoch 018 train_loss 1.64311 val_loss 1.64852


 40%|####      | 20/50 [03:59<05:56, 11.90s/it]

Epoch 019 train_loss 1.64302 val_loss 1.64872


 42%|####2     | 21/50 [04:10<05:41, 11.79s/it]

Epoch 020 train_loss 1.64290 val_loss 1.64640


 44%|####4     | 22/50 [04:22<05:27, 11.69s/it]

Epoch 021 train_loss 1.64307 val_loss 1.64625


 46%|####6     | 23/50 [04:33<05:15, 11.67s/it]

Epoch 022 train_loss 1.64213 val_loss 1.64733


 48%|####8     | 24/50 [04:45<05:08, 11.85s/it]

Epoch 023 train_loss 1.64201 val_loss 1.64677


 50%|#####     | 25/50 [04:57<04:56, 11.87s/it]

Epoch 024 train_loss 1.64142 val_loss 1.64795


 52%|#####2    | 26/50 [05:09<04:44, 11.84s/it]

Epoch 025 train_loss 1.64113 val_loss 1.64502


 54%|#####4    | 27/50 [05:21<04:31, 11.82s/it]

Epoch 026 train_loss 1.64111 val_loss 1.64523


 56%|#####6    | 28/50 [05:33<04:22, 11.92s/it]

Epoch 027 train_loss 1.64084 val_loss 1.64787


 58%|#####8    | 29/50 [05:47<04:23, 12.57s/it]

Epoch 028 train_loss 1.64082 val_loss 1.64690


 60%|######    | 30/50 [05:59<04:08, 12.43s/it]

Epoch 029 train_loss 1.64110 val_loss 1.64780


 62%|######2   | 31/50 [06:11<03:53, 12.27s/it]

Epoch 030 train_loss 1.64059 val_loss 1.64624


 64%|######4   | 32/50 [06:23<03:39, 12.20s/it]

Epoch 031 train_loss 1.63977 val_loss 1.64590


 66%|######6   | 33/50 [06:35<03:25, 12.11s/it]

Epoch 032 train_loss 1.64035 val_loss 1.64739


 68%|######8   | 34/50 [06:47<03:12, 12.04s/it]

Epoch 033 train_loss 1.64028 val_loss 1.64806


 70%|#######   | 35/50 [06:59<02:59, 11.98s/it]

Epoch 034 train_loss 1.63975 val_loss 1.64534


 72%|#######2  | 36/50 [07:11<02:47, 11.98s/it]

Epoch 035 train_loss 1.63896 val_loss 1.64597


 74%|#######4  | 37/50 [07:23<02:35, 12.00s/it]

Epoch 036 train_loss 1.63935 val_loss 1.64716


 76%|#######6  | 38/50 [07:35<02:23, 11.96s/it]

Epoch 037 train_loss 1.63955 val_loss 1.64771


 78%|#######8  | 39/50 [07:47<02:11, 11.95s/it]

Epoch 038 train_loss 1.63981 val_loss 1.64457


 80%|########  | 40/50 [07:58<01:59, 11.93s/it]

Epoch 039 train_loss 1.63912 val_loss 1.64476


 82%|########2 | 41/50 [08:11<01:48, 12.00s/it]

Epoch 040 train_loss 1.63892 val_loss 1.64408


 84%|########4 | 42/50 [08:23<01:35, 11.97s/it]

Epoch 041 train_loss 1.63881 val_loss 1.64586


 86%|########6 | 43/50 [08:34<01:23, 11.96s/it]

Epoch 042 train_loss 1.63872 val_loss 1.64633


 88%|########8 | 44/50 [08:46<01:11, 11.95s/it]

Epoch 043 train_loss 1.63846 val_loss 1.64622


 90%|######### | 45/50 [08:58<00:59, 11.91s/it]

Epoch 044 train_loss 1.63829 val_loss 1.64455


 92%|#########2| 46/50 [09:10<00:47, 11.88s/it]

Epoch 045 train_loss 1.63791 val_loss 1.64580


 94%|#########3| 47/50 [09:22<00:35, 11.89s/it]

Epoch 046 train_loss 1.63826 val_loss 1.64557


 96%|#########6| 48/50 [09:34<00:23, 11.91s/it]

Epoch 047 train_loss 1.63823 val_loss 1.64379


 98%|#########8| 49/50 [09:46<00:11, 11.88s/it]

Epoch 048 train_loss 1.63798 val_loss 1.64598


100%|##########| 50/50 [09:57<00:00, 11.96s/it]

Epoch 049 train_loss 1.63777 val_loss 1.64355





In [14]:
def test():
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)
    test_data_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  # 创建测试数据加载器

    # 加载训练好的模型参数
    model = cnn().to(device)
    if not os.path.exists("G:/资料/model/cnn.pkl"):
        print("Model file does not exist. Please check the path.")
        return
    model.load_state_dict(torch.load("G:/资料/model/cnn.pkl", map_location=device))
    model.eval()  # 设置为评估模式

    test_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():  # 在测试过程中不计算梯度
        for data, target in test_data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # 获取最大log-probability的索引
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            all_preds.extend(pred.view(-1).cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    test_loss /= total
    accuracy = 100. * correct / total
    precision = precision_score(all_targets, all_preds, average='macro')
    recall = recall_score(all_targets, all_preds, average='macro')
    f1 = f1_score(all_targets, all_preds, average='macro')

    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')



In [15]:
test()

  model.load_state_dict(torch.load("G:/资料/model/cnn.pkl", map_location=device))


Test set: Average loss: 1.6434, Accuracy: 7888/10000 (78.88%)
Precision: 0.7105, Recall: 0.7877, F1 Score: 0.7332


  _warn_prf(average, modifier, msg_start, len(result))
