In [1]:
import torch
from torch import nn,optim
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
from torchinfo import summary

In [2]:
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


In [3]:
torch.set_printoptions(linewidth=1000)

In [4]:
class Config:
    LearningRate = 1e-2
    Epochs = 5
    BatchSize = 64
    Transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 255)
    ])
    TestTransform = transforms.Compose([
        transforms.Lambda(lambda x: ImageOps.invert(x)), #反色处理
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 255)
    ])

In [5]:
#参数设置对象创建
config = Config()

In [6]:
#通过PyTorch下载数据集
train_dataset = datasets.MNIST(root='',train=True,download=True,transform=config.Transform)
test_dataset = datasets.MNIST(root='',train=False,download=True,transform=config.Transform)

In [7]:
train_dataset_size = len(train_dataset)
test_dataset_size = len(test_dataset)
print(train_dataset_size)
print(test_dataset_size)
#shuffle=True可以打乱数据集，batch_size=64将会让这个数据生成器每次给我们64个数据,drop_last=True会把不够64一组的舍去（影响不大）。
train_loader = DataLoader(dataset=train_dataset,batch_size=config.BatchSize,shuffle=True,drop_last=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=config.BatchSize,shuffle=True,drop_last=True)

60000
10000


In [8]:
'''这里先单独获取一次DataLoader的数据，用来观察数据结构'''
#enumerate将可迭代对象组合为索引序列，同时列出数据和数据下标。

for index,data in enumerate(train_loader):
    inputs, labels = data
    with open("sample.txt","w") as f:
        f.write(str(list(inputs[0])))
    print(inputs.shape)
    print(labels)
    print(labels.shape)
    break

torch.Size([64, 1, 28, 28])
tensor([6, 9, 5, 5, 3, 5, 8, 9, 6, 0, 3, 1, 3, 7, 9, 2, 8, 9, 7, 0, 8, 2, 3, 9, 9, 2, 3, 8, 9, 7, 3, 4, 2, 8, 4, 8, 3, 1, 6, 9, 7, 7, 0, 7, 2, 1, 6, 8, 0, 8, 1, 8, 7, 0, 4, 9, 7, 0, 9, 5, 0, 6, 6, 6])
torch.Size([64])


In [9]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),  # [1,28,28] -> [6,24,24]
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # [6,24,24] -> [6,12,12]
            nn.Conv2d(in_channels=6, out_channels=28, kernel_size=5, stride=1, padding=0),  # [6,12,12] -> [28,8,8]
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # [28,8,8] -> [28,4,4]
            nn.Flatten(),  # 展平为28*4*4=448维向量
            nn.Linear(in_features=28*4*4, out_features=120),  # 修改输入维度为448
            nn.ReLU(),
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),
            nn.Linear(in_features=84, out_features=10)
        )

    def forward(self,x):
        x = self.model(x)
        return x
        

In [10]:
#神经网络模型对象创建
net = Net()
net = net.to(device)
print(summary(net))

Layer (type:depth-idx)                   Param #
Net                                      --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       156
│    └─ReLU: 2-2                         --
│    └─MaxPool2d: 2-3                    --
│    └─Conv2d: 2-4                       4,228
│    └─ReLU: 2-5                         --
│    └─MaxPool2d: 2-6                    --
│    └─Flatten: 2-7                      --
│    └─Linear: 2-8                       53,880
│    └─ReLU: 2-9                         --
│    └─Linear: 2-10                      10,164
│    └─ReLU: 2-11                        --
│    └─Linear: 2-12                      850
Total params: 69,278
Trainable params: 69,278
Non-trainable params: 0


In [11]:
#设置损失函数和优化模式
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
optim = optim.SGD(net.parameters(),config.LearningRate)

In [12]:
def train():
    for index,data in enumerate(train_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        loss = loss_fn(outputs,labels)
        optim.zero_grad()
        loss.backward()
        optim.step()

def test():
    times = 0
    for index,data in enumerate(test_loader):
        times += 1
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        accuracy = (outputs.argmax(axis=1) == labels).sum()
        if times % 30 == 0:
            print("Test accuracy:{0}".format(accuracy/len(labels)))

In [13]:
for epoch in range(config.Epochs):
    train()
    if epoch % 5 == 0:
        print("epoch {0}".format(epoch))
        test()
print("Final accuracy")
test()

epoch 0
Test accuracy:0.96875
Test accuracy:1.0
Test accuracy:0.96875
Test accuracy:1.0
Test accuracy:0.9375
Final accuracy
Test accuracy:0.984375
Test accuracy:0.96875
Test accuracy:0.953125
Test accuracy:1.0
Test accuracy:1.0


In [None]:
#手写数字检验模型效果
from PIL import ImageOps, ImageGrab
import tkinter as tk
from tkinter import ttk

class DigitRecognizerApp:
    def __init__(self, root):
        self.root = root
        root.title("手写数字识别器")
        root.geometry("500x600")
        root.resizable(False, False)
        
        # 创建顶部标题
        title_frame = ttk.Frame(root)
        title_frame.pack(pady=10)
        ttk.Label(title_frame, text="MNIST手写数字识别", font=("Arial", 16, "bold")).pack()
        
        # 创建画布区域
        canvas_frame = ttk.LabelFrame(root, text="手写区域")
        canvas_frame.pack(pady=10, padx=20, fill="both", expand=True)
        
        self.canvas = tk.Canvas(canvas_frame, width=280, height=280, bg="white", cursor="pencil")
        self.canvas.pack(pady=10)
        
        # 绑定鼠标事件
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<ButtonRelease-1>", self.predict)
        
        # 创建按钮区域
        button_frame = ttk.Frame(root)
        button_frame.pack(pady=10)
        
        self.clear_btn = ttk.Button(button_frame, text="清除画板", command=self.clear_canvas)
        self.clear_btn.pack(side="left", padx=10)
        
        # 创建预测结果显示区域
        result_frame = ttk.LabelFrame(root, text="识别结果")
        result_frame.pack(pady=10, padx=20, fill="both", expand=True)
        
        self.result_label = ttk.Label(result_frame, text="请手写一个数字...", font=("Arial", 24))
        self.result_label.pack(pady=20)
        
        self.confidence_label = ttk.Label(result_frame, text="", font=("Arial", 14))
        self.confidence_label.pack(pady=5)
        
        # 初始化变量
        self.last_x = None
        self.last_y = None
        self.line_width = 15
        
        # 存储画布位置信息
        self.canvas.update_idletasks()
        self.canvas_x = self.root.winfo_rootx() + self.canvas.winfo_x()
        self.canvas_y = self.root.winfo_rooty() + self.canvas.winfo_y()
        self.canvas_width = self.canvas.winfo_width()
        self.canvas_height = self.canvas.winfo_height()
    
    def paint(self, event):
        x, y = event.x, event.y
        if self.last_x and self.last_y:
            self.canvas.create_line(self.last_x, self.last_y, x, y, 
                                    width=self.line_width, fill="black", 
                                    capstyle=tk.ROUND, smooth=True)
        self.last_x = x
        self.last_y = y
    
    def clear_canvas(self):
        self.canvas.delete("all")
        # 确保背景是纯白色
        self.canvas.create_rectangle(0, 0, 
        self.canvas.winfo_width(), 
        self.canvas.winfo_height(),fill="#ffffff", outline="#ffffff")
        self.canvas.delete("all")
        self.result_label.config(text="请手写一个数字...")
        self.confidence_label.config(text="")
        self.last_x = None
        self.last_y = None
    
    def predict(self, event):
        self.canvas.update_idletasks()  # 确保UI更新
        self.canvas_x = self.root.winfo_rootx() + self.canvas.winfo_x()
        self.canvas_y = self.root.winfo_rooty() + self.canvas.winfo_y()
        # 重置最后位置
        self.last_x = None
        self.last_y = None
        
        # 获取画布在屏幕上的位置
        self.canvas.update_idletasks()
        self.canvas_x = self.root.winfo_rootx() + self.canvas.winfo_x()
        self.canvas_y = self.root.winfo_rooty() + self.canvas.winfo_y()
        
        # 使用ImageGrab直接截取画布区域
        bbox = (
            self.canvas_x, 
            self.canvas_y, 
            self.canvas_x + self.canvas.winfo_width(), 
            self.canvas_y + self.canvas.winfo_height()
        )
        
        img = ImageGrab.grab(bbox=bbox)
        
        # 转换为模型需要的格式
        img = img.convert('L')  # 转换为灰度图
        img = img.resize((28, 28))
        
        # 转换为张量并进行预测
        img_tensor = config.TestTransform(img).unsqueeze(0).to(device)
        with open("test_sample.txt","w") as f:
            f.write(str(list(img_tensor)))
        print(img_tensor)
        
        with torch.no_grad():
            output = net(img_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            pred = output.argmax(dim=1, keepdim=True).item()
            confidence = probabilities[pred].item()
        
        # 显示结果
        self.result_label.config(text=f"识别结果: {pred}")
        self.confidence_label.config(text=f"置信度: {confidence:.2%}")
        
        # 可选：显示处理后的图像用于调试
        #img.show(title="处理后的图像")
        
root = tk.Tk()
app = DigitRecognizerApp(root)
root.mainloop()


tensor([[[[ 15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  13.,  15.,  13.,  15.,  16.,  18.,  23.,  17.,  18.,  19.,  26.,  25.,  24.,  28.,  22.,  16.],
          [ 15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,   9., 104., 126., 138., 111., 112., 128.,  78., 127.,  57.,  99.,  33.,  94.,  90.,  54., 128., 125.,  53.],
          [ 15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  10., 111., 163., 140., 153., 131., 146.,  74.,  83.,  41., 108.,  29.,  55., 102.,  43., 123., 127.,  43.],
          [ 15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  23.,  27.,  25.,  24.,  27.,  35.,  20.,  22.,  17.,  29.,  13.,  14.,  28.,  18.,  35.,  31.,  22.],
          [ 15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  14.,  13.,  13.,  13.,  13.,  12.,  14.,  14.,  15.,  13.,  15.,  15.,  13.,  15.,  12.,  13.,  14.],
          [ 13.,  12.,  14.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,  15.,