# LeNet-5 手写数字识别 (PyTorch + TensorBoard)

In [None]:
# Imports
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter


In [None]:
# === Global hyper-parameters and configuration ===
# Optimisation settings used by the training cell.
BATCH_SIZE = 64       # mini-batch size for both DataLoaders
LEARNING_RATE = 0.01  # initial SGD learning rate
MOMENTUM = 0.9        # SGD momentum
EPOCHS = 5            # number of full passes over the training set

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"当前运行设备: {DEVICE}")

## 计算公式与关键参数速览（LeNet-5 相关）
- 卷积输出尺寸（单维度，高或宽相同计算）
  - 公式：$H_{out} = \left\lfloor \frac{H_{in} + 2P - D\cdot(K-1) - 1}{S} + 1 \right\rfloor$
  - 符号：$K$=kernel_size，$S$=stride，$P$=padding，$D$=dilation。
- 池化输出尺寸（平均/最大池化同理）
  - 公式：$H_{out} = \left\lfloor \frac{H_{in} + 2P - (K-1) - 1}{S} + 1 \right\rfloor$
- 参数量（Parameter Count）估算
  - Conv2d：$\text{params} = C_{out} \times \left(\frac{C_{in}}{\text{groups}} \times K_h \times K_w\right) + (\text{bias? } C_{out}:0)$
  - Linear：$\text{params} = \text{in\_features} \times \text{out\_features} + (\text{bias? } \text{out\_features}:0)$
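- 交叉熵损失 CrossEntropyLoss（训练单元使用）
  - 公式（单个样本，logits 为 $z \in \mathbb{R}^{C}$，真实类别为 $y$）：$\ell(z, y) = -\log \frac{e^{z_y}}{\sum_{j=1}^{C} e^{z_j}} = -z_y + \log \sum_{j=1}^{C} e^{z_j}$
  - 默认 `reduction='mean'`，即对一个 batch 内各样本的损失取平均。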


In [None]:
class LeNet5(nn.Module):
    """Classic LeNet-5 for 1x32x32 grayscale inputs (MNIST padded from 28x28 to 32x32).

    Layer order is significant: the Sequential indices determine the state_dict
    keys (``features.0/3/6.*``, ``classifier.0/2.*``) stored in the checkpoint.

        features:   conv5x5 1->6, tanh, avgpool2, conv5x5 6->16, tanh,
                    avgpool2, conv5x5 16->120, tanh
        classifier: linear 120->84, tanh, linear 84->num_classes

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_tanh(c_in, c_out):
            # 5x5 convolution without padding, followed by a tanh activation.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=5, stride=1, dilation=1, groups=1),
                nn.Tanh(),
            ]

        def avg_pool():
            # 2x2 non-overlapping average pooling; padding is 0, so
            # count_include_pad has no effect here.
            return nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False)

        self.features = nn.Sequential(
            *conv_tanh(1, 6), avg_pool(),
            *conv_tanh(6, 16), avg_pool(),
            *conv_tanh(16, 120),
        )
        self.classifier = nn.Sequential(
            nn.Linear(120, 84),  # fully connected layer
            nn.Tanh(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        # x: (N, 1, 32, 32) -> 28 -> 14 -> 10 -> 5 -> 1 spatially, i.e. (N, 120, 1, 1)
        feats = self.features(x)
        # Keep the batch dimension and merge everything else: (N, 120)
        flat = feats.view(feats.size(0), -1)
        # Raw class logits: (N, num_classes)
        return self.classifier(flat)

model = LeNet5()
print(model)


In [None]:
# Preprocessing: pad the 28x28 MNIST digits by 2 px per side to the 32x32 input
# LeNet-5 expects, then ToTensor (PIL uint8 [0, 255] -> float [0.0, 1.0]).
# Normalize((0.1307,), (0.3081,)) is intentionally NOT applied, so any
# inference-time preprocessing must also feed raw [0, 1] pixels.
transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
])

# root='.' -> torchvision looks under ./MNIST for raw/processed files;
# download is disabled, so the data must already be on disk.
train_dataset, test_dataset = [
    datasets.MNIST(root='.', train=is_train, download=False, transform=transform)
    for is_train in (True, False)
]

# Both loaders use the global BATCH_SIZE; only the training set is shuffled.
train_loader, test_loader = [
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2)
    for ds, shuffle in ((train_dataset, True), (test_dataset, False))
]

print('Train samples:', len(train_dataset), 'Test samples:', len(test_dataset))

In [None]:
# Pull one batch from the training loader and print the type, shape and dtype
# of each element it yields (inputs and targets).
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print("Batch类型:", type(first_batch))
    print("Batch长度:", len(first_batch))
    for idx, element in enumerate(first_batch):
        print(f"第{idx}个元素类型: {type(element)}")
        print(f"第{idx}个元素形状: {element.shape}")
        print(f"第{idx}个元素数据类型: {element.dtype}")


In [None]:
def train(model, device, train_loader, optimizer, loss_function, epoch, writer):
    """Run one training epoch, logging per-batch loss/accuracy.

    Per batch, writes 'metrics/train/loss' and 'metrics/train/accuracy' to
    `writer`, prints a progress line and flushes `writer` so a running
    TensorBoard sees the point immediately.

    Args:
        model: network to train (put into train mode here).
        device: device the inputs/targets are moved to.
        train_loader: iterable of (inputs, targets) batches with a len().
        optimizer: optimizer stepping `model`'s parameters.
        loss_function: criterion; assumed to return the batch-mean loss
            (CrossEntropyLoss's default reduction).
        epoch: 1-based epoch index, used to compute the global step.
        writer: SummaryWriter-like object with add_scalar() and flush().

    Returns:
        (avg_loss, accuracy) over the whole epoch, in the same form as test():
        sample-weighted mean loss and accuracy in percent. Both are NaN if
        the loader yielded no samples.
    """
    model.train()
    num_batches = len(train_loader)
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Batches are numbered from 1 for display.
    for batch_idx, (inputs, targets) in enumerate(train_loader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        # Single host sync for the loss value, reused for logging and printing.
        loss_value = loss.item()
        bsize = targets.size(0)
        correct = (outputs.argmax(dim=1) == targets).sum().item()
        accuracy = 100.0 * correct / bsize

        total_loss += loss_value * bsize
        total_correct += correct
        total_samples += bsize

        # 0-based global step, continuous across epochs.
        global_step = (epoch - 1) * num_batches + (batch_idx - 1)
        writer.add_scalar('metrics/train/loss', loss_value, global_step)
        writer.add_scalar('metrics/train/accuracy', accuracy, global_step)

        print(f"Epoch {epoch} [{batch_idx}/{num_batches}]  Loss: {loss_value:.4f}  Acc: {accuracy:.2f}%")

        # Flush every batch so the event file on disk is always up to date.
        writer.flush()

    if total_samples == 0:
        return float('nan'), float('nan')
    return total_loss / total_samples, 100.0 * total_correct / total_samples


def test(model, device, test_loader, loss_function, epoch, writer, *, log_to_writer=False):
    """Evaluate `model` on the full `test_loader` and print/return the metrics.

    Args:
        model: network to evaluate (put into eval mode here).
        device: device the inputs/targets are moved to.
        test_loader: iterable of (inputs, targets) batches.
        loss_function: criterion; assumed to return the batch-mean loss
            (CrossEntropyLoss's default reduction), which is re-weighted by
            batch size to obtain an exact per-sample average.
        epoch: epoch index, used as the TensorBoard step when logging.
        writer: SummaryWriter-like object; only used when `log_to_writer`.
        log_to_writer: opt-in switch (default False, matching the previous
            behaviour where this logging was disabled) that writes
            'metrics/test/loss' and 'metrics/test/accuracy' at step `epoch`.
            Note that train() logs at batch-level global steps instead.

    Returns:
        (avg_loss, accuracy): sample-weighted mean loss and accuracy in percent.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            bsize = targets.size(0)
            total_loss += loss_function(outputs, targets).item() * bsize
            total_correct += (outputs.argmax(dim=1) == targets).sum().item()
            total_samples += bsize

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute test metrics")

    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples
    print(f"\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {total_correct}/{total_samples} ({accuracy:.2f}%)\n")

    if log_to_writer:
        writer.add_scalar('metrics/test/loss', avg_loss, epoch)
        writer.add_scalar('metrics/test/accuracy', accuracy, epoch)
        writer.flush()

    return avg_loss, accuracy

In [None]:
import os
import shutil
import subprocess
import time
import webbrowser
from tensorboard import program

# 1. Log directory watched by TensorBoard. It is resolved to an absolute path
#    (as intended), so the printed location is unambiguous. It is still the same
#    ./runs/lenet_mnist directory, relative to the current working directory,
#    that the training cell's SummaryWriter writes to.
tb_logdir = os.path.abspath("./runs/lenet_mnist")
print(f"TensorBoard 监控目录: {tb_logdir}")

# 2. On Windows, terminate whatever process owns local port 6006 (normally a
#    TensorBoard left over from a previous run). This releases its lock on the
#    event files so the log directory can be cleared below.
def pids_bound_to_port(netstat_output, port=6006):
    """Return the PIDs (as strings) of TCP rows whose LOCAL address uses `port`.

    `netstat_output` is the text of ``netstat -ano``, whose TCP rows have the
    columns: Proto, Local Address, Foreign Address, State, PID.

    Matching the exact ``:<port>`` suffix of the local address skips two kinds
    of rows:
      * rows where only the foreign port is `port`, e.g. a browser's client
        connection;
      * ports that merely contain the same digits, e.g. 60061.

    PID "0" is never a user process (netstat reports it e.g. for TIME_WAIT
    sockets), so it is skipped rather than passed to taskkill.
    """
    suffix = f":{port}"
    pids = set()
    for line in netstat_output.splitlines():
        parts = line.split()
        if len(parts) >= 5 and parts[1].endswith(suffix) and parts[-1] != "0":
            pids.add(parts[-1])
    return pids


print("正在检查端口 6006...")
if os.name != "nt":
    # netstat -ano / taskkill are Windows-specific; nothing is killed elsewhere.
    print("非 Windows 系统，跳过端口检查。")
else:
    try:
        # Filter in Python instead of piping through a shell to findstr.
        # errors="replace": netstat's header lines on localized Windows are
        # not necessarily UTF-8 (e.g. code page 936).
        netstat_text = subprocess.check_output(["netstat", "-ano"]).decode(errors="replace")
        port_pids = pids_bound_to_port(netstat_text, 6006)
        if port_pids:
            print("发现旧 TensorBoard 进程，正在终止...")
            for pid in port_pids:
                subprocess.run(["taskkill", "/F", "/PID", pid],
                               stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            time.sleep(2)  # give the OS time to release the port and file handles
            print("旧 TensorBoard 已关闭。")
        else:
            print("端口 6006 未被占用。")
    except (OSError, subprocess.CalledProcessError) as e:
        # Best effort: failing to clean up the port must not stop the notebook.
        print(f"尝试关闭进程时出错 (可忽略): {e}")

# 3. Remove the previous run's event files so TensorBoard starts without stale
#    curves, then (re)create the empty directory. Failure to delete is reported
#    but not fatal.
if os.path.exists(tb_logdir):
    try:
        shutil.rmtree(tb_logdir)
    except Exception as err:
        print(f"警告: 无法完全清空目录: {err}")
    else:
        print("已清空历史日志目录。")

os.makedirs(tb_logdir, exist_ok=True)

# 4. Start TensorBoard in-process on 127.0.0.1:6006 and open it in the browser.
#    --reload_interval 1 makes the backend re-read the log directory every
#    second (default is 5 s); the web page still needs auto-refresh enabled.
tb_args = [None, '--logdir', tb_logdir, '--port', '6006', '--host', '127.0.0.1', '--reload_interval', '1']
tb = program.TensorBoard()
tb.configure(argv=tb_args)
url = tb.launch()

webbrowser.open(url)

In [None]:
# Fresh model, SGD with momentum, and cross-entropy on raw logits.
model = LeNet5().to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
loss_function = nn.CrossEntropyLoss()

# Same ./runs/lenet_mnist directory the TensorBoard cell monitors.
writer = SummaryWriter("./runs/lenet_mnist")
try:
    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        train(model, DEVICE, train_loader, optimizer, loss_function, epoch, writer)
        test(model, DEVICE, test_loader, loss_function, epoch, writer)
        print(f'Epoch {epoch} finished in {time.time() - t0:.1f}s')
finally:
    # Close the writer even if training is interrupted (e.g. KeyboardInterrupt
    # in the notebook). Pending events get flushed, and the event-file handle
    # no longer blocks the log-directory cleanup on the next run.
    writer.close()

# Only reached when training completed; persists the learned weights.
model_path = './lenet_mnist.pth'
torch.save(model.state_dict(), model_path)
print('Model saved to', model_path)

In [None]:
# Example: reload the saved weights on the CPU and plot predictions for the
# first (up to) 8 images of the first test batch.
model = LeNet5()
model.load_state_dict(torch.load('./lenet_mnist.pth', map_location='cpu'))
model.eval()

examples, labels = [], []
first_batch = next(iter(test_loader), None)
if first_batch is not None:
    data, target = first_batch
    with torch.no_grad():
        preds = model(data).argmax(dim=1).numpy()
    n_show = min(8, data.size(0))
    examples = [data[j].squeeze().numpy() for j in range(n_show)]
    # (ground truth, prediction) pairs
    labels = [(int(target[j].item()), int(preds[j].item())) for j in range(n_show)]

# 2x4 grid of the selected images, titled with ground truth vs prediction.
plt.figure(figsize=(12, 6))
for pos, (img, (gt, pr)) in enumerate(zip(examples, labels), start=1):
    plt.subplot(2, 4, pos)
    plt.imshow(img, cmap='gray')
    plt.title(f'GT:{gt} Pred:{pr}')
    plt.axis('off')
plt.tight_layout()
plt.show()

## 本地手写窗口（Tkinter）

- 运行下一个单元将打开一个本地窗口。
- 操作：按住左键在白板上书写；点击“识别”进行推断；“清空”重置画布；“退出”关闭窗口。
- 说明：此窗口使用 Tkinter（Windows 通常自带）。若环境未安装 Tk 支持，可能无法启动。

In [None]:
# 本地手写窗口：Tkinter 画布 + LeNet-5 推理
import os
import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tkinter can be missing from some Python installations: report the reason,
# then re-raise so the cell stops here.
try:
    import tkinter as tk
except Exception as import_err:
    print("未能导入 Tkinter：", import_err)
    raise

# Reuse the LeNet5 class defined in an earlier cell; this cell does not define
# it. Fail fast here with a warning plus the original NameError. Otherwise the
# warning would be printed and the cell would crash on `LeNet5()` a few lines
# below anyway.
try:
    LeNet5
except NameError:
    print("警告：LeNet-5 模型未定义")
    raise

# Load the trained model for CPU inference.
_device = torch.device('cpu')
weights_path = './lenet_mnist.pth'
if os.path.exists(weights_path):
    _model = LeNet5().to(_device)
    _model.load_state_dict(torch.load(weights_path, map_location=_device))
    _model.eval()
else:
    # An untrained network would show meaningless probabilities. A None model
    # makes _predict_from_pil return all-zero probabilities instead.
    _model = None
    print('警告：未找到模型权重 ./lenet_mnist.pth，请先运行训练单元保存模型。')

# MNIST pixel mean/std, kept for reference. The training transform
# (Pad(2) + ToTensor) does not apply Normalize, so inference must not either.
_MEAN, _STD = 0.1307, 0.3081


def _preprocess_pil(pil_img: Image.Image) -> torch.Tensor:
    """Convert a drawing into a 1x1x32x32 float tensor matching the training input.

    Training sees ``Pad(2)`` + ``ToTensor()`` output: a white digit (1.0) on a
    black background (0.0), 28x28 zero-padded to 32x32, with no normalization.
    The same pipeline is reproduced here: grayscale -> 28x28 -> [0, 1] ->
    invert -> zero-pad by 2.
    """
    if pil_img.mode != 'L':
        pil_img = pil_img.convert('L')
    # BILINEAR averages over the source area when downscaling in Pillow. This
    # keeps gray anti-aliased stroke edges like MNIST's, instead of sampling a
    # single pixel per 10x10 block as NEAREST does.
    pil_img = pil_img.resize((28, 28), Image.BILINEAR)
    arr = np.array(pil_img).astype(np.float32) / 255.0
    arr = 1.0 - arr  # canvas is black-on-white; MNIST is white-on-black
    # Zero padding equals the background value because no normalization is used.
    arr = np.pad(arr, pad_width=((2, 2), (2, 2)), mode='constant', constant_values=0.0)
    ten = torch.from_numpy(arr)[None, None, :, :].to(_device)
    return ten


def _predict_from_pil(pil_img: Image.Image):
    """Return a {'0': p0, ..., '9': p9} dict of softmax probabilities for a drawing.

    If no model is loaded (`_model` is None), every probability is 0.0.
    """
    if _model is None:
        return dict.fromkeys(map(str, range(10)), 0.0)
    batch = _preprocess_pil(pil_img)
    with torch.no_grad():
        probs = F.softmax(_model(batch), dim=1).cpu().numpy()[0]
    return {str(digit): float(probs[digit]) for digit in range(10)}


# === Tkinter drawing window ===
CANVAS_SIZE = 280  # side length of the square drawing canvas (enlarged vs. 28x28)
BRUSH_WIDTH = 20   # stroke width, used for both the canvas and the inference buffer

root = tk.Tk()
root.title('MNIST 手写数字识别 (LeNet-5) - 本地窗口')

canvas = tk.Canvas(root, width=CANVAS_SIZE, height=CANVAS_SIZE, bg='white')
canvas.pack(padx=8, pady=8)

# Off-screen grayscale copy of the drawing (white background = 255); every
# stroke is drawn here as well, and this image is what gets classified.
buffer_img = Image.new('L', (CANVAS_SIZE, CANVAS_SIZE), 255)
buffer_draw = ImageDraw.Draw(buffer_img)

# Previous pointer position during a stroke; None means no stroke in progress.
last_pos = dict(x=None, y=None)


def on_button_press(event):
    """Begin a stroke by remembering where the left mouse button went down."""
    last_pos.update(x=event.x, y=event.y)


def on_move(event):
    """Extend the current stroke to the pointer position while the left button is held.

    The segment is drawn on the visible Tk canvas and mirrored into
    `buffer_img`, the image used for inference. The Tk line uses round caps;
    PIL's ``ImageDraw.line`` has none. A filled disc is therefore stamped at
    both segment ends in the buffer so it matches what the user sees. Otherwise
    sharp turns leave notches that exist only in the classified image.
    """
    lx, ly = last_pos['x'], last_pos['y']
    x, y = event.x, event.y
    if lx is None or ly is None:
        # No stroke in progress yet: just record the starting point.
        last_pos['x'], last_pos['y'] = x, y
        return
    canvas.create_line(lx, ly, x, y, width=BRUSH_WIDTH, fill='black', capstyle=tk.ROUND, smooth=True)
    buffer_draw.line([lx, ly, x, y], fill=0, width=BRUSH_WIDTH)
    radius = BRUSH_WIDTH / 2
    for cx, cy in ((lx, ly), (x, y)):
        buffer_draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=0)
    last_pos['x'], last_pos['y'] = x, y


def on_button_release(event):
    """End the current stroke so the next drag does not connect to it."""
    last_pos.update(x=None, y=None)


def clear_canvas():
    """Erase the canvas and the inference buffer, then reset the result label."""
    canvas.delete('all')
    # Repaint the whole buffer white (255) to match the blank canvas.
    buffer_draw.rectangle(((0, 0), (CANVAS_SIZE, CANVAS_SIZE)), fill=255)
    result_var.set('结果：')


def predict_canvas():
    """Classify the current drawing and show the three most likely digits."""
    ranked = sorted(_predict_from_pil(buffer_img).items(), key=lambda kv: kv[1], reverse=True)
    parts = [f'{int(digit)}: {prob*100:.2f}%' for digit, prob in ranked[:3]]
    result_var.set('结果：' + '  '.join(parts))


btn_frame = tk.Frame(root)
btn_frame.pack(fill='x', padx=8, pady=4)

# (label, callback, pack side), created in this order.
for btn_text, btn_command, btn_side in (
    ('识别', predict_canvas, 'left'),
    ('清空', clear_canvas, 'left'),
    ('退出', root.destroy, 'right'),
):
    tk.Button(btn_frame, text=btn_text, command=btn_command).pack(side=btn_side, padx=4)

result_var = tk.StringVar(value='结果：')
result_label = tk.Label(root, textvariable=result_var, anchor='w')
result_label.pack(fill='x', padx=8, pady=4)

# Left mouse button: press starts a stroke, drag extends it, release ends it.
for event_sequence, event_handler in (
    ('<ButtonPress-1>', on_button_press),
    ('<B1-Motion>', on_move),
    ('<ButtonRelease-1>', on_button_release),
):
    canvas.bind(event_sequence, event_handler)

# Blocks until the window is closed; this is expected in Jupyter too.
try:
    root.mainloop()
except Exception as e:
    print('Tkinter 主循环启动失败：', e)
