# NeuroTrain Dataset模块完整教程

本Notebook提供NeuroTrain Dataset模块的完整使用教程，包括：

1. 基础数据集加载
2. 数据增强
3. DataLoader使用
4. 混合数据集
5. 数据集统计分析
6. 总结

## 环境准备

首先确保已安装所有依赖：

```bash
conda activate ntrain
uv pip install -e '.[cu128]'
```


In [None]:
# Import necessary libraries
import sys
from pathlib import Path

# Put the parent of the current working directory at the front of the module
# search path before anything is imported from `src`.
# NOTE(review): presumably the notebook is launched from a subdirectory of the
# repository root, which would make the `src` package importable — confirm.
sys.path.insert(0, str(Path.cwd().parent))

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from src.dataset import (
    get_dataset,
    get_train_dataset,
    get_test_dataset,
    get_train_valid_test_dataloader,
    random_sample,
)

# Quick sanity report: reaching this point means every import above succeeded.
print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. 基础数据集加载

### 1.1 CIFAR-10数据集

CIFAR-10是一个经典的图像分类数据集，包含10个类别的60000张32x32彩色图像。


In [None]:
# Configure CIFAR-10 dataset
config_cifar = {
    "dataset": {
        "name": "cifar10",
        "root_dir": "../data/cifar10",
        "train": True,
        "download": True,
    }
}

# Load dataset
cifar_dataset = get_dataset(config_cifar)

# `classes` is treated as optional everywhere below: look it up once so the
# summary line and the plot titles cannot disagree about whether it exists.
class_names = getattr(cifar_dataset, "classes", None)

print(f"Dataset size: {len(cifar_dataset)}")
print(f"Classes: {class_names if class_names is not None else 'N/A'}")

# Visualize samples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    ax.axis("off")
    if i >= len(cifar_dataset):
        # Fewer samples than grid cells: leave the remaining cells blank.
        continue
    image, label = cifar_dataset[i]
    if isinstance(image, torch.Tensor):
        # Move the first axis last, the layout imshow expects for colour images.
        image = image.permute(1, 2, 0).numpy()
    ax.imshow(image)
    # Fall back to the raw label when no class-name list is available.
    shown_label = class_names[label] if class_names is not None else label
    ax.set_title(f"Class: {shown_label}")
plt.suptitle("CIFAR-10 Sample Images", fontsize=16)
plt.tight_layout()
plt.show()

### 1.2 医学图像数据集 (DRIVE)

DRIVE数据集用于视网膜血管分割任务。


In [None]:
# Configure DRIVE dataset
config_drive = {
    "dataset": {
        "name": "drive",
        "root_dir": "../data/drive",
        "is_rgb": True,
        "train_split": 0.8,
        "image_size": [512, 512],
    }
}

try:
    # Load training and test datasets
    train_dataset = get_train_dataset(config_drive)
    test_dataset = get_test_dataset(config_drive)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Visualize a sample
    if len(train_dataset) > 0:
        image, mask = train_dataset[0]

        # Convert both items to NumPy arrays up front. Each branch assigns the
        # array, so the plotting calls below never hit an undefined name when
        # the dataset returns something other than a torch.Tensor.
        if isinstance(image, torch.Tensor):
            # Move the first axis last, the layout imshow expects.
            img_np = image.permute(1, 2, 0).numpy()
        else:
            img_np = np.asarray(image)

        if isinstance(mask, torch.Tensor):
            mask_np = mask.squeeze().numpy()
        else:
            mask_np = np.squeeze(np.asarray(mask))

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # Original image
        axes[0].imshow(img_np)
        axes[0].set_title("Retinal Image", fontsize=14)
        axes[0].axis("off")

        # Mask
        axes[1].imshow(mask_np, cmap="gray")
        axes[1].set_title("Vessel Mask", fontsize=14)
        axes[1].axis("off")

        # Overlay: draw the mask semi-transparently on top of the image
        axes[2].imshow(img_np)
        axes[2].imshow(mask_np, alpha=0.5, cmap="Reds")
        axes[2].set_title("Overlay", fontsize=14)
        axes[2].axis("off")

        plt.suptitle("DRIVE Dataset Sample", fontsize=16)
        plt.tight_layout()
        plt.show()

except Exception as e:
    # Best-effort tutorial cell: report the failure instead of aborting the
    # notebook when the dataset has not been downloaded.
    print(f"DRIVE dataset not available: {e}")
    print("Please ensure the DRIVE dataset is in the correct directory.")

## 2. 数据增强

数据增强通过对训练样本施加随机变换来提升模型的泛化能力。


In [None]:
# Augmentation options passed to the dataset factory alongside the usual
# CIFAR-10 settings.
augmentation_options = {
    "random_flip": True,
    "random_rotation": True,
    "rotation_range": 15,
    "brightness_range": [0.8, 1.2],
    "color_jitter": True,
}

config_aug = {
    "dataset": {
        "name": "cifar10",
        "root_dir": "../data/cifar10",
        "train": True,
        "download": True,
        "augmentation": augmentation_options,
    }
}

dataset_aug = get_dataset(config_aug)

# Fetch one fixed index repeatedly; each fetch is drawn in its own grid cell
# so the results can be compared side by side.
sample_idx = 0
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# axes.flat walks the 2x4 grid row by row, so cell k holds the k-th fetch.
for shot_number, ax in enumerate(axes.flat, start=1):
    image, label = dataset_aug[sample_idx]
    if isinstance(image, torch.Tensor):
        # Move the first axis last, the layout imshow expects.
        image = image.permute(1, 2, 0).numpy()

    ax.imshow(image)
    ax.set_title(f"Augmentation {shot_number}")
    ax.axis("off")

plt.suptitle("Data Augmentation Examples (Same Original Image)", fontsize=16)
plt.tight_layout()
plt.show()

## 3. DataLoader使用

DataLoader提供批量数据加载、并行处理和数据混洗功能。


In [None]:
# Configure training parameters
config_loader = {
    "dataset": {
        "name": "cifar10",
        "root_dir": "../data/cifar10",
        "train": True,
        "download": True,
    },
    "training": {"batch_size": 64, "num_workers": 2},
}

# Get DataLoaders
train_loader, valid_loader, test_loader = get_train_valid_test_dataloader(config_loader)

print(f"Number of training batches: {len(train_loader)}")
if valid_loader:
    print(f"Number of validation batches: {len(valid_loader)}")
if test_loader:
    print(f"Number of test batches: {len(test_loader)}")

# Inspect only the first batch; `None` means the loader yielded nothing.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print("\nBatch information:")
    print(f"  Images shape: {images.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Image dtype: {images.dtype}")
    print(f"  Image value range: [{images.min():.3f}, {images.max():.3f}]")

    # Visualize a batch (at most 4 x 8 = 32 images)
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    for ax in axes.flat:
        # Hide every frame up front so cells beyond the batch stay blank.
        ax.axis("off")
    for ax, img_tensor, label in zip(axes.flat, images, labels):
        img = img_tensor.permute(1, 2, 0).numpy()
        # Min-max normalize for display. A constant image has zero range, so
        # show it as all zeros instead of dividing by zero and plotting NaNs.
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            img = np.zeros_like(img)
        ax.imshow(img)
        ax.set_title(f"Label: {label.item()}", fontsize=8)
    plt.suptitle("Batch Visualization", fontsize=16)
    plt.tight_layout()
    plt.show()

## 4. 混合数据集

混合数据集允许同时使用多个数据集进行训练，支持不同的采样策略。


In [None]:
# Per-source settings, keyed below by the same names listed under "datasets".
drive_source_cfg = {"root_dir": "../data/drive", "is_rgb": True}
chasedb1_source_cfg = {"root_dir": "../data/chasedb1", "is_rgb": True}

# Hybrid dataset configuration: two sources combined with weighted sampling.
config_hybrid = {
    "dataset": {
        "name": "enhanced_hybrid",
        "datasets": ["drive", "medical/chasedb1"],
        "sampling_strategy": "weighted",
        "ratios": [0.6, 0.4],
        "weights": [1.0, 1.2],
        "drive": drive_source_cfg,
        "medical/chasedb1": chasedb1_source_cfg,
    }
}

try:
    hybrid_dataset = get_dataset(config_hybrid)
    hybrid_settings = config_hybrid["dataset"]
    print(f"Hybrid dataset total samples: {len(hybrid_dataset)}")
    print(f"Sampling strategy: {hybrid_settings['sampling_strategy']}")
    print(f"Dataset ratios: {hybrid_settings['ratios']}")
    print(f"Sample weights: {hybrid_settings['weights']}")

    # Preview at most the first five samples and report where each came from.
    print("\nSampling examples:")
    preview_count = min(5, len(hybrid_dataset))
    for sample_no in range(preview_count):
        image, mask = hybrid_dataset[sample_no]
        # get_source is optional; fall back to a placeholder when it is absent.
        if hasattr(hybrid_dataset, "get_source"):
            source = hybrid_dataset.get_source(sample_no)
        else:
            source = "unknown"
        print(f"  Sample {sample_no}: source={source}, image shape={image.shape}")

except Exception as e:
    # Best-effort tutorial cell: report the problem rather than abort the run.
    print(f"Hybrid dataset not available: {e}")
    print("Please ensure all component datasets are available.")

## 5. 数据集统计分析

了解数据集的统计特性对于模型训练很重要。


In [None]:
# Analyze CIFAR-10 dataset statistics
def analyze_dataset(dataset, num_samples=1000):
    """Print and plot summary statistics for a random subset of ``dataset``.

    Draws up to ``num_samples`` distinct indices, reads ``(image, label)``
    pairs at those indices, prints the mean, standard deviation and value
    range of the images, prints the label distribution and shows it as a bar
    chart.

    Only samples whose image is a ``torch.Tensor`` are analysed; other samples
    are skipped together with their label so that the image statistics and
    the class distribution describe the same set of samples.

    Args:
        dataset: Object supporting ``len()`` and integer indexing that returns
            ``(image, label)`` pairs. Images are stacked and reduced over
            dims 0, 2 and 3, so each image must be a 3-D tensor whose first
            axis is the channel axis, and all images must share one shape.
        num_samples: Upper bound on the number of samples to draw.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.

    Raises:
        ValueError: If none of the drawn samples has a tensor image.
    """
    # Never request more samples than the dataset holds, and report the
    # number actually drawn rather than the number requested.
    sample_count = min(num_samples, len(dataset))
    print(f"Analyzing dataset (using {sample_count} samples)...")

    # Sample randomly, without repetition
    indices = np.random.choice(len(dataset), sample_count, replace=False)

    # Collect images and labels in lockstep
    images = []
    labels = []
    skipped = 0

    for idx in indices:
        # Plain int: not every dataset accepts NumPy integer indices.
        image, label = dataset[int(idx)]
        if not isinstance(image, torch.Tensor):
            skipped += 1
            continue
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            # Unwrap single-element tensor labels so np.unique sees scalars.
            label = label.item()
        images.append(image)
        labels.append(label)

    if skipped:
        print(f"  Skipped {skipped} samples whose image is not a torch.Tensor")
    if not images:
        raise ValueError(
            "analyze_dataset needs at least one sample whose image is a "
            "torch.Tensor"
        )

    images = torch.stack(images)
    if not images.is_floating_point():
        # mean/std are undefined for integer tensors (e.g. uint8 images).
        images = images.float()
    labels = np.array(labels)

    # Per-channel statistics: reduce over batch, height and width
    mean = images.mean(dim=[0, 2, 3])
    std = images.std(dim=[0, 2, 3])

    # Format every channel instead of hard-coding three, so images with a
    # different channel count do not raise IndexError.
    channel_tag = "RGB" if mean.numel() == 3 else "per-channel"
    mean_text = ", ".join(f"{value:.4f}" for value in mean.tolist())
    std_text = ", ".join(f"{value:.4f}" for value in std.tolist())

    print("\nImage Statistics:")
    print(f"  Mean ({channel_tag}): [{mean_text}]")
    print(f"  Std ({channel_tag}):  [{std_text}]")
    print(f"  Value range: [{images.min():.4f}, {images.max():.4f}]")

    # Class distribution
    unique, counts = np.unique(labels, return_counts=True)
    print("\nClass Distribution:")
    for cls, count in zip(unique, counts):
        print(f"  Class {cls}: {count} samples ({count/len(labels)*100:.1f}%)")

    # Visualize class distribution
    plt.figure(figsize=(10, 5))
    plt.bar(unique, counts, color="steelblue")
    plt.xlabel("Class", fontsize=12)
    plt.ylabel("Number of Samples", fontsize=12)
    plt.title("Class Distribution", fontsize=14)
    plt.xticks(unique)
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.show()

    return mean, std


# Analyze CIFAR-10: run the statistics helper on the dataset loaded in
# section 1.1, drawing at most 5000 random samples, and keep the two tensors
# it returns for later use.
mean, std = analyze_dataset(cifar_dataset, num_samples=5000)

## 6. 总结

本教程介绍了NeuroTrain Dataset模块的主要功能：

1. ✅ 加载标准和医学图像数据集
2. ✅ 配置数据增强
3. ✅ 使用DataLoader进行批量加载
4. ✅ 创建混合数据集
5. ✅ 分析数据集统计信息

### 下一步

- 查看其他教程了解模型训练
- 阅读API文档了解更多选项
- 尝试自定义数据集类

### 参考资料

- [Dataset模块文档](../docs/DATASET_MODULE.md)
- [项目架构文档](../docs/ARCHITECTURE.md)
- [API参考](../docs/API_REFERENCE.md)
