# MNIST数据加载和预处理

## Notebook运行提示
- 代码已拆分为多个小单元，按顺序运行即可在每一步观察输出与中间变量。
- 涉及 `Path(__file__)` 或相对路径的脚本会自动注入 `__file__` 解析逻辑，Notebook 环境下也能引用原项目资源。
- 可在每个单元下追加说明或参数试验记录，以跟踪核心算法和数据处理步骤。


In [None]:
# Notebook路径自适应处理
import pathlib as _nb_pathlib
def _nb_resolve_file_path():
    """Define a module-level ``__file__`` when the environment lacks one.

    Does nothing if ``__file__`` is already present in the module globals.
    Otherwise the current working directory and each of its parents are
    searched, nearest first, for the original script's relative path; the
    first existing match is stored as ``__file__``. If no match exists,
    the path is resolved against the working directory and stored anyway.
    """
    module_globals = globals()
    if '__file__' in module_globals:
        return

    relative_script = '09-practical-projects/02_计算机视觉项目/01_MNIST手写数字识别_CNN入门/src/data.py'
    start_dir = _nb_pathlib.Path.cwd().resolve()

    # Walk upwards from the working directory until the script is found.
    for base_dir in [start_dir, *start_dir.parents]:
        script_path = base_dir / relative_script
        if script_path.exists():
            module_globals['__file__'] = str(script_path)
            return

    # Nothing found: fall back to a path rooted at the working directory.
    module_globals['__file__'] = str((start_dir / relative_script).resolve())
_nb_resolve_file_path()
del _nb_pathlib


In [None]:

import sys
from pathlib import Path
import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split

# Add the project root to the import search path.
# The root is taken to be five directory levels above this file; it is put at
# the front of sys.path, presumably so the `utils.common` import that follows
# resolves against the project's own `utils` package (verify the layout).
project_root = Path(__file__).parent.parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from utils.common import set_seed

In [None]:


def load_mnist_data(normalize=True):
    """Load the MNIST dataset through Keras and add a channel axis.

    The images returned by ``keras.datasets.mnist.load_data()`` are reshaped
    to ``(-1, 28, 28, 1)``. Labels are returned untouched.

    Args:
        normalize: When true, images are cast to ``float32`` and divided by
            255.0. When false, the reshaped images are returned as loaded.

    Returns:
        ``(X_train, y_train), (X_test, y_test)``
    """
    print("正在加载MNIST数据集...")

    # Fetch both splits from Keras, then unpack images and labels.
    train_split, test_split = keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    print(f"✓ 训练集: {train_images.shape}, 标签: {train_labels.shape}")
    print(f"✓ 测试集: {test_images.shape}, 标签: {test_labels.shape}")

    # Append a trailing single-channel dimension to every image.
    with_channel = (-1, 28, 28, 1)
    train_images = train_images.reshape(with_channel)
    test_images = test_images.reshape(with_channel)

    if not normalize:
        return (train_images, train_labels), (test_images, test_labels)

    # Scale pixel values by 1/255 as float32.
    train_images = train_images.astype('float32') / 255.0
    test_images = test_images.astype('float32') / 255.0
    print("✓ 数据已归一化到[0, 1]")

    return (train_images, train_labels), (test_images, test_labels)

In [None]:


def prepare_data(test_size=0.1, random_state=42):
    """Build the training, validation and test splits.

    Seeds the run with ``set_seed(random_state)``, loads the data through
    ``load_mnist_data()`` and carves a validation set out of the training
    data with a split stratified on the training labels. The test split is
    passed through unchanged.

    Args:
        test_size: Share of the training data moved to the validation set.
        random_state: Seed given to ``set_seed`` and ``train_test_split``.

    Returns:
        ``(X_train, y_train), (X_val, y_val), (X_test, y_test)``
    """
    set_seed(random_state)

    # Load the splits.
    full_train, held_out = load_mnist_data()
    all_images, all_labels = full_train
    test_images, test_labels = held_out

    # Stratified train/validation split.
    split_options = {
        'test_size': test_size,
        'random_state': random_state,
        'stratify': all_labels,
    }
    fit_images, val_images, fit_labels, val_labels = train_test_split(
        all_images, all_labels, **split_options
    )

    print(f"\n数据划分:")
    print(f"  训练集: {fit_images.shape}")
    print(f"  验证集: {val_images.shape}")
    print(f"  测试集: {test_images.shape}")

    return (fit_images, fit_labels), (val_images, val_labels), (test_images, test_labels)

In [None]:


def create_data_augmentation():
    """Create the data augmentation pipeline.

    The pipeline chains, in order, ``RandomRotation(0.1)``,
    ``RandomTranslation(0.1, 0.1)`` and ``RandomZoom(0.1)``.

    Returns:
        A ``keras.Sequential`` model named ``'data_augmentation'``.
    """
    layers = keras.layers
    augmentation_steps = [
        layers.RandomRotation(0.1),
        layers.RandomTranslation(0.1, 0.1),
        layers.RandomZoom(0.1),
    ]
    return keras.Sequential(augmentation_steps, name='data_augmentation')

In [None]:


def get_class_distribution(y):
    """Count the occurrences of each label and print a summary table.

    Args:
        y: Array of labels.

    Returns:
        dict: Maps each distinct label to the number of times it occurs.
    """
    labels, frequencies = np.unique(y, return_counts=True)
    distribution = {label: freq for label, freq in zip(labels, frequencies)}

    print("\n类别分布:")
    n_samples = len(y)
    # Labels are unique, so ordering by key matches ordering by (key, value).
    for label in sorted(distribution):
        count = distribution[label]
        share = count / n_samples * 100
        print(f"  {label}: {count:5d} ({share:5.2f}%)")

    return distribution

In [None]:


def print_data_info(X, y, name='数据集'):
    """Print a summary of a dataset followed by its class distribution.

    The summary lists the shape, dtype and min/max of ``X``, the shape of
    ``y`` and the number of distinct labels in ``y``.

    Args:
        X: Feature data.
        y: Label data.
        name: Dataset name shown in the header.
    """
    separator = '=' * 60
    print(f"\n{separator}")
    print(f"{name}信息")
    print(f"{separator}")

    print(f"形状: {X.shape}")
    print(f"数据类型: {X.dtype}")
    lowest, highest = X.min(), X.max()
    print(f"值范围: [{lowest:.3f}, {highest:.3f}]")
    print(f"标签形状: {y.shape}")
    n_classes = len(np.unique(y))
    print(f"类别数: {n_classes}")

    # The per-class table is printed by the helper; its return value is unused.
    get_class_distribution(y)

In [None]:


# Smoke test: build all three splits and print a summary of each one.
if __name__ == '__main__':
    print("=" * 60)
    print("MNIST数据加载测试")
    print("=" * 60)

    # Load the data and split it into training, validation and test sets
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = prepare_data()

    # Print the summary for each split
    print_data_info(X_train, y_train, '训练集')
    print_data_info(X_val, y_val, '验证集')
    print_data_info(X_test, y_test, '测试集')

    print("\n✓ 数据加载测试完成！")