In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Configure matplotlib so Chinese text is displayed (font fallback list) and
# work around the minus-sign rendering problem by disabling the Unicode minus.
plt.rcParams.update({
    "font.family": ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"],
    "axes.unicode_minus": False,
})

# 读取MNIST数据集的函数
def read_mnist_images(filename):
    """Read an MNIST IDX3 image file into a writable uint8 array.

    The file is expected to start with four big-endian 32-bit integers
    (magic number, image count, rows, columns) followed by one unsigned
    byte per pixel.

    Args:
        filename: Path to an uncompressed IDX3 image file.

    Returns:
        A writable ``numpy.ndarray`` of dtype ``uint8`` and shape
        ``(num_images, rows, cols)``.

    Raises:
        ValueError: If the header is truncated, the magic number is not
            2051, or the payload size does not match the header.
    """
    header_size = 16
    expected_magic = 2051  # IDX3 magic number for unsigned-byte image files

    with open(filename, 'rb') as f:
        header = f.read(header_size)
        payload = f.read()

    if len(header) != header_size:
        raise ValueError(
            f"{filename!r}: truncated IDX3 header "
            f"({len(header)} of {header_size} bytes)"
        )

    magic, num_images, rows, cols = (
        int.from_bytes(header[offset:offset + 4], byteorder='big')
        for offset in range(0, header_size, 4)
    )
    if magic != expected_magic:
        raise ValueError(
            f"{filename!r}: unexpected magic number {magic}, "
            f"expected {expected_magic} for an IDX3 image file"
        )

    expected_bytes = num_images * rows * cols
    if len(payload) != expected_bytes:
        raise ValueError(
            f"{filename!r}: header announces {expected_bytes} pixel bytes "
            f"but the file contains {len(payload)}"
        )

    # np.frombuffer over bytes yields a read-only view; copy so the result is
    # writable (PyTorch warns when handed non-writable arrays).
    pixels = np.frombuffer(payload, dtype=np.uint8)
    return pixels.reshape(num_images, rows, cols).copy()

def read_mnist_labels(filename):
    """Read an MNIST IDX1 label file into a writable uint8 array.

    The file is expected to start with two big-endian 32-bit integers
    (magic number, label count) followed by one unsigned byte per label.

    Args:
        filename: Path to an uncompressed IDX1 label file.

    Returns:
        A writable one-dimensional ``numpy.ndarray`` of dtype ``uint8``
        holding ``num_labels`` entries.

    Raises:
        ValueError: If the header is truncated, the magic number is not
            2049, or the number of label bytes differs from the header.
    """
    header_size = 8
    expected_magic = 2049  # IDX1 magic number for unsigned-byte label files

    with open(filename, 'rb') as f:
        header = f.read(header_size)
        payload = f.read()

    if len(header) != header_size:
        raise ValueError(
            f"{filename!r}: truncated IDX1 header "
            f"({len(header)} of {header_size} bytes)"
        )

    magic = int.from_bytes(header[:4], byteorder='big')
    num_labels = int.from_bytes(header[4:], byteorder='big')
    if magic != expected_magic:
        raise ValueError(
            f"{filename!r}: unexpected magic number {magic}, "
            f"expected {expected_magic} for an IDX1 label file"
        )
    if len(payload) != num_labels:
        raise ValueError(
            f"{filename!r}: header announces {num_labels} labels "
            f"but the file contains {len(payload)}"
        )

    # Copy so the result is writable, matching what the image reader returns.
    return np.frombuffer(payload, dtype=np.uint8).copy()

# 自定义数据集类
class MNISTDataset(Dataset):
    """Map-style dataset that pairs indexable images with indexable labels.

    Args:
        images: Indexable, sized collection of images (e.g. a NumPy array
            whose first axis enumerates samples).
        labels: Indexable, sized collection of labels, one per image.
        transform: Optional callable applied to each image on access. Labels
            are returned untouched.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels, transform=None):
        # Fail fast: a mismatch would otherwise surface as an IndexError in
        # the middle of an epoch, or silently ignore surplus labels.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, transforming the image if configured."""
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# 定义简单的卷积神经网络
class SimpleCNN(nn.Module):
    """Minimal CNN: 3x3 conv (1 -> 16 channels) -> 2x2 max-pool -> ReLU -> linear (10 outputs).

    The linear layer is sized for 16 x 14 x 14 pooled features, i.e. for
    single-channel 28 x 28 inputs. Every forward pass caches the intermediate
    tensors in ``layer_outputs`` under the keys ``'input'``, ``'conv1'``,
    ``'pool'``, ``'relu'`` and ``'output'`` so they can be inspected afterwards.
    """

    # Per-sample feature shape that the linear layer was sized for.
    _POOLED_SHAPE = (16, 14, 14)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(16 * 14 * 14, 10)
        self.layer_outputs = {}  # most recent activations, keyed by layer name

    def forward(self, x):
        """Run the network and cache each stage's activation.

        Args:
            x: Input tensor; the layers are sized for single-channel
                28 x 28 images.

        Returns:
            The raw (un-normalized) class scores produced by the linear layer.

        Raises:
            ValueError: If the pooled feature map is not 16 x 14 x 14 per
                sample, which means the input was not 28 x 28.
        """
        # Cached copies are detached: they exist for inspection only, so they
        # should not keep the autograd graph reachable nor block ``.numpy()``.
        self.layer_outputs['input'] = x.detach()
        x = self.conv1(x)
        self.layer_outputs['conv1'] = x.detach()
        x = self.pool(x)
        self.layer_outputs['pool'] = x.detach()
        x = self.relu(x)
        self.layer_outputs['relu'] = x.detach()

        # Without this check, view(-1, ...) would silently fold a wrong-sized
        # feature map into extra batch rows whenever the sizes happen to divide.
        if tuple(x.shape[-3:]) != self._POOLED_SHAPE:
            raise ValueError(
                f"expected pooled features of shape {self._POOLED_SHAPE} per "
                f"sample, got tensor of shape {tuple(x.shape)}"
            )
        x = x.view(-1, 16 * 14 * 14)
        x = self.fc(x)
        self.layer_outputs['output'] = x.detach()
        return x

# 可视化每一层的输出（优化资源释放）
def visualize_layers(model, image, label, idx):
    """Plot the per-layer activations of ``model`` for a single sample.

    Runs one forward pass in eval mode without gradients, then draws a 2 x 3
    figure: the input image, the conv / pool / ReLU feature maps (each shown
    as the mean of its first four channels), a bar chart of the output scores
    and a text panel comparing predicted and true label. The cached tensor
    shapes are printed afterwards.

    Plotting is best-effort: any exception is reported with ``print`` and
    swallowed so that a failed visualization does not abort the caller.

    Args:
        model: Network exposing a ``layer_outputs`` dict with the keys
            ``'input'``, ``'conv1'``, ``'pool'``, ``'relu'`` and ``'output'``
            after a forward pass.
        image: Single-sample tensor without batch dimension; a batch
            dimension is added before the forward pass.
        label: True label, used only in titles and printed text.
        idx: Sample index, used only in titles and printed text.

    Side effects:
        Leaves ``model`` in eval mode, shows a matplotlib figure and closes
        that figure before returning.
    """
    model.eval()
    fig = None
    try:
        with torch.no_grad():
            model(image.unsqueeze(0))  # add the batch dimension; results are read from layer_outputs

        # Keep the figure small to limit memory use.
        fig, axes = plt.subplots(2, 3, figsize=(12, 8))
        fig.suptitle(f'样本 {idx} - 真实标签: {label}', fontsize=14)

        # 1. Input image
        axes[0, 0].imshow(image.squeeze().numpy(), cmap='gray')
        axes[0, 0].set_title('输入图像')
        axes[0, 0].axis('off')

        # 2-4. Conv / pool / ReLU feature maps, each as the mean of the first 4 channels
        feature_panels = (
            (axes[0, 1], 'conv1', '卷积层输出'),
            (axes[0, 2], 'pool', '池化层输出'),
            (axes[1, 0], 'relu', '激活层输出'),
        )
        for ax, layer_key, title in feature_panels:
            feature_maps = model.layer_outputs[layer_key].squeeze().numpy()
            ax.imshow(np.mean(feature_maps[:4], axis=0), cmap='viridis')
            ax.set_title(title)
            ax.axis('off')

        # 5. Output-layer scores
        scores = model.layer_outputs['output'].squeeze().numpy()
        axes[1, 1].bar(range(10), scores)
        axes[1, 1].set_title('预测得分')
        axes[1, 1].set_xticks(range(10))
        axes[1, 1].set_xlabel('数字')

        # 6. Predicted vs. true label
        predicted = np.argmax(scores)
        axes[1, 2].text(0.5, 0.5, f'预测: {predicted}\n真实: {label}',
                        fontsize=14, ha='center', va='center')
        axes[1, 2].set_title('结果对比')
        axes[1, 2].axis('off')

        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.show()

        # Print the cached tensor shapes
        print(f"\n===== 样本 {idx} 层信息 =====")
        print(f"输入形状: {model.layer_outputs['input'].shape}")
        print(f"卷积层输出: {model.layer_outputs['conv1'].shape}")
        print(f"池化层输出: {model.layer_outputs['pool'].shape}")
        print(f"激活层输出: {model.layer_outputs['relu'].shape}")
        print(f"输出层输出: {model.layer_outputs['output'].shape}")
        print(f"真实标签: {label}, 预测标签: {predicted}")
        print("===========================\n")

    except Exception as e:
        # Deliberate best-effort boundary: report and continue.
        print(f"可视化出错: {str(e)}")
    finally:
        # Release only the figure created here; closing 'all' would also
        # destroy figures owned by the caller.
        if fig is not None:
            plt.close(fig)

def main():
    """Train SimpleCNN on the local MNIST training files and visualize samples.

    Loads the IDX image/label files, trains for a fixed number of epochs with
    Adam and cross-entropy loss, prints the mean loss every 200 batches and,
    after each epoch, visualizes the layer outputs of two random samples.
    """
    # Load the raw training data.
    print("正在读取MNIST数据集...")
    train_images = read_mnist_images('数据集/train-images.idx3-ubyte')
    train_labels = read_mnist_labels('数据集/train-labels.idx1-ubyte')
    print(f"训练集大小: 图像 {train_images.shape}, 标签 {train_labels.shape}")

    # Per-image preprocessing: tensor conversion followed by normalization.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # A small batch size keeps memory usage low.
    train_dataset = MNISTDataset(train_images, train_labels, transform=preprocess)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Model, loss and optimizer.
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # A short run, enough to check that training is stable.
    num_epochs = 2
    print(f"\n开始训练模型，共 {num_epochs} 个epoch...")

    total_steps = len(train_loader)
    for epoch_no in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0

        for step, (images, labels) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Report the mean loss once per 200 batches, then reset the window.
            if step % 200 != 0:
                continue
            print(f'Epoch [{epoch_no}/{num_epochs}], Step [{step}/{total_steps}], Loss: {running_loss/200:.4f}')
            running_loss = 0.0

        # Visualize two randomly chosen training samples after every epoch.
        print(f"\nEpoch {epoch_no} 训练完成，可视化样本...")
        model.eval()
        with torch.no_grad():
            for idx in np.random.choice(len(train_dataset), 2, replace=False):
                image, label = train_dataset[idx]
                visualize_layers(model, image, label, idx)

    print("\n模型训练完成！")

# Run the training/visualization pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()

正在读取MNIST数据集...
训练集大小: 图像 (60000, 28, 28), 标签 (60000,)

开始训练模型，共 2 个epoch...
Epoch [1/2], Step [200/1875], Loss: 0.5218
Epoch [1/2], Step [400/1875], Loss: 0.2598
Epoch [1/2], Step [600/1875], Loss: 0.2072
Epoch [1/2], Step [800/1875], Loss: 0.1566
Epoch [1/2], Step [1000/1875], Loss: 0.1392
Epoch [1/2], Step [1200/1875], Loss: 0.1348
Epoch [1/2], Step [1400/1875], Loss: 0.1192
Epoch [1/2], Step [1600/1875], Loss: 0.0981
Epoch [1/2], Step [1800/1875], Loss: 0.1119


: 