# 02 - 模型训练：三种输入模式 × 八种模型

本notebook支持三种输入特征格式和八种深度学习模型的灵活组合。

## 三种输入模式

| 输入类型 | 形状 | 说明 | 适用模型 |
|---------|------|------|---------|
| `features` | (16,) | 16维统计特征（手工提取） | MLP, PINN |
| `sequence` | (T, 3) | 插值时序网格（单循环充电曲线） | CNN1D, LSTM/BiLSTM, Transformer, FNO |
| `image` | (H, W, 3) | 热力图矩阵（多循环堆叠） | CNN2D, ViT |

## 八种模型

### 统计特征输入 (features)
- **MLP**: 多层感知机，轻量高效
- **PINN**: 物理信息神经网络，带物理约束

### 时序输入 (sequence)
- **CNN1D**: 一维卷积，提取局部特征
- **LSTM/BiLSTM**: 循环网络，建模时序依赖
- **Transformer**: 注意力机制，捕获长距离依赖
- **FNO**: 傅里叶神经算子，学习算子映射

### 热力图输入 (image)
- **CNN2D**: 二维卷积，提取空间特征
- **ViT**: Vision Transformer，全局注意力

## 特征说明

- **3通道**: 电压(V)、电流(I)、时间(s)
- **协议无关**: 适用于CC、CC-CV、多阶段等充电协议

**前置要求**: 运行 `00_data_preprocessing.ipynb` 完成数据预处理

## 配置选项

修改下方配置单元格来选择：
1. **数据集** - MATR, CALCE, HUST, XJTU, TJU, NASA, RWTH
2. **输入类型** - features (16维特征), sequence (时序), image (热力图)
3. **模型** - 根据输入类型选择兼容的模型

In [None]:
# ==================== Configuration ====================
# Edit the values below to configure the training run.

# 1. Dataset: 'MATR', 'CALCE', 'HUST', 'XJTU', 'TJU', 'NASA', 'RWTH'
DATASET_NAME = 'MATR'

# 2. Input type (one of the three input representations)
#   'features' - 16-dim hand-crafted statistical features, shape (16,)
#                compatible models: MLP, PINN
#   'sequence' - interpolated single-cycle charge curve, shape (num_samples, 3),
#                e.g. (200, 3); channels: V, I, time
#                compatible models: CNN1D, LSTM, BiLSTM, Transformer, FNO
#   'image'    - heat-map of stacked cycles, shape (window_size, num_samples, 3),
#                e.g. (100, 200, 3); compatible models: CNN2D, ViT
INPUT_TYPE = 'sequence'

# 3. Model (must be compatible with INPUT_TYPE)
#   features: 'mlp' (recommended, lightweight), 'pinn' (physics-informed)
#   sequence: 'lstm' (recommended), 'bilstm', 'cnn1d', 'transformer', 'fno'
#   image:    'cnn2d' (recommended), 'vit'
MODEL_NAME = 'lstm'

# Input type -> models that accept that input.
INPUT_MODEL_COMPATIBILITY = {
    'features': ['mlp', 'pinn'],
    'sequence': ['lstm', 'bilstm', 'cnn1d', 'transformer', 'fno'],
    'image': ['cnn2d', 'vit'],
}

# 4. Training hyper-parameters
BATCH_SIZE = 32          # batch size
HIDDEN_SIZE = 128        # hidden width (MLP/LSTM/Transformer, ...)
NUM_LAYERS = 2           # number of layers
DROPOUT = 0.1            # dropout rate
LEARNING_RATE = 1e-3     # learning rate
NUM_EPOCHS = 100         # maximum number of epochs
PATIENCE = 20            # early-stopping patience (epochs)

# 5. Data parameters
NUM_SAMPLES = 200        # points sampled per cycle (sequence/image)
WINDOW_SIZE = 100        # cycles per heat-map (image)
TRAIN_RATIO = 0.7        # training split ratio
VAL_RATIO = 0.15         # validation split ratio (test = 1 - train - val)

# 6. Prediction target: 'soh', 'rul' or 'both'
TARGET_TYPE = 'both'

# ==================== Validation ====================
# Fail fast on configuration mistakes instead of deep inside training.

if INPUT_TYPE not in INPUT_MODEL_COMPATIBILITY:
    raise ValueError(
        f"未知输入类型 '{INPUT_TYPE}'，可选: {list(INPUT_MODEL_COMPATIBILITY)}"
    )

compatible_models = INPUT_MODEL_COMPATIBILITY[INPUT_TYPE]
# create_model lower-cases the model name, so compare case-insensitively here.
if MODEL_NAME.lower() not in compatible_models:
    raise ValueError(
        f"模型 '{MODEL_NAME}' 与输入类型 '{INPUT_TYPE}' 不兼容!\n"
        f"输入类型 '{INPUT_TYPE}' 支持的模型: {compatible_models}"
    )

if TARGET_TYPE not in ('soh', 'rul', 'both'):
    raise ValueError(f"未知预测目标 '{TARGET_TYPE}'，可选: 'soh', 'rul', 'both'")

# The test split is whatever remains, so it must be non-empty.
if not (TRAIN_RATIO > 0 and VAL_RATIO >= 0 and TRAIN_RATIO + VAL_RATIO < 1):
    raise ValueError(
        f"数据划分比例无效: TRAIN_RATIO={TRAIN_RATIO}, VAL_RATIO={VAL_RATIO} "
        f"(需满足 TRAIN_RATIO > 0, VAL_RATIO >= 0, 且二者之和 < 1)"
    )

print("=" * 60)
print("训练配置:")
print("=" * 60)
print(f"  数据集: {DATASET_NAME}")
print(f"  输入类型: {INPUT_TYPE}")
print(f"    - features: 16维统计特征 → MLP/PINN")
print(f"    - sequence: 插值网格 → CNN1D/LSTM/Transformer/FNO")
print(f"    - image: 热力图矩阵 → CNN2D/ViT")
print(f"  选择模型: {MODEL_NAME}")
print(f"  预测目标: {TARGET_TYPE}")
print(f"  批大小: {BATCH_SIZE}")
print(f"  训练轮数: {NUM_EPOCHS}")
print("=" * 60)

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
matplotlib.rcParams['axes.unicode_minus'] = False
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# 数据加载
from src.data import load_processed_batteries
from src.data.dataset import BatteryDataset, create_dataloaders

# 所有模型
from src.models import (
    # 统计特征输入 (16维)
    MLP, PINN,
    # 时序输入 (T, 3)
    CNN1D, LSTM, BiLSTM, TransformerModel, FNO,
    # 热力图输入 (H, W, 3)
    CNN2D, ViT, SimpleViT,
)
from src.models.base import BaseModel

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'使用设备: {device}')

# Directory for model checkpoints (created if missing).
CHECKPOINT_DIR = Path('../results') / 'checkpoints'
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

# Model factory: builds the network selected by (model_name, input_type).
def create_model(model_name: str, input_type: str, **kwargs) -> nn.Module:
    """Instantiate the model called ``model_name``.

    Per-input-type defaults are applied first and are then overridden by
    ``kwargs``. Each constructor reads only the keys it needs, so extra keys
    (e.g. ``window_size`` for an LSTM) are silently ignored.

    Args:
        model_name: Case-insensitive key: 'mlp', 'pinn', 'lstm', 'bilstm',
            'cnn1d', 'transformer', 'fno', 'cnn2d' or 'vit'.
        input_type: 'features', 'sequence' or 'image'; selects the defaults
            (an unknown type simply contributes no defaults).
        **kwargs: Hyper-parameter overrides.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not one of the known keys.
    """
    per_input_defaults = {
        'features': {'input_dim': 16, 'hidden_dims': [64, 32], 'output_dim': 1},
        'sequence': {'in_channels': 3, 'hidden_size': 128, 'num_layers': 2, 'output_dim': 1},
        'image': {'in_channels': 3, 'hidden_channels': 32, 'output_dim': 1},
    }
    cfg = dict(per_input_defaults.get(input_type, {}))
    cfg.update(kwargs)
    opt = cfg.get

    def dense_net(cls):
        # MLP and PINN share the same constructor signature.
        return cls(
            input_dim=opt('input_dim', 16),
            hidden_dims=opt('hidden_dims', [64, 32]),
            output_dim=opt('output_dim', 1),
            dropout=opt('dropout', 0.1),
        )

    def recurrent_net(cls):
        # LSTM and BiLSTM share the same constructor signature.
        return cls(
            in_channels=opt('in_channels', 3),
            hidden_size=opt('hidden_size', 128),
            num_layers=opt('num_layers', 2),
            output_dim=opt('output_dim', 1),
            dropout=opt('dropout', 0.1),
        )

    # Builders are lazy so only the requested model is constructed.
    builders = {
        'mlp': lambda: dense_net(MLP),
        'pinn': lambda: dense_net(PINN),
        'lstm': lambda: recurrent_net(LSTM),
        'bilstm': lambda: recurrent_net(BiLSTM),
        'cnn1d': lambda: CNN1D(
            in_channels=opt('in_channels', 3),
            hidden_channels=opt('hidden_channels', 64),
            num_layers=opt('num_layers', 3),
            output_dim=opt('output_dim', 1),
            dropout=opt('dropout', 0.1),
        ),
        'transformer': lambda: TransformerModel(
            in_channels=opt('in_channels', 3),
            d_model=opt('hidden_size', 64),
            nhead=opt('nhead', 4),
            num_layers=opt('num_layers', 2),
            output_dim=opt('output_dim', 1),
            dropout=opt('dropout', 0.1),
        ),
        'fno': lambda: FNO(
            in_channels=opt('in_channels', 3),
            width=opt('hidden_size', 32),
            modes=opt('modes', 16),
            num_layers=opt('num_layers', 4),
            output_dim=opt('output_dim', 1),
        ),
        'cnn2d': lambda: CNN2D(
            in_channels=opt('in_channels', 3),
            hidden_channels=opt('hidden_channels', 32),
            num_layers=opt('num_layers', 3),
            output_dim=opt('output_dim', 1),
            dropout=opt('dropout', 0.1),
        ),
        'vit': lambda: ViT(
            img_height=opt('window_size', 100),
            img_width=opt('num_samples', 200),
            in_channels=opt('in_channels', 3),
            embed_dim=opt('hidden_size', 128),
            num_heads=opt('nhead', 4),
            num_layers=opt('num_layers', 4),
            output_dim=opt('output_dim', 1),
            dropout=opt('dropout', 0.1),
        ),
    }

    key = model_name.lower()
    try:
        build = builders[key]
    except KeyError:
        raise ValueError(f"未知模型: {key}") from None
    return build()

print("模型和数据集已加载")

## 1. 加载数据

In [None]:
# Load the preprocessed dataset written by 00_data_preprocessing.ipynb.
batteries = load_processed_batteries(f'../data/processed/{DATASET_NAME}')
if len(batteries) == 0:
    # Fail early with an actionable message instead of IndexError at batteries[0].
    raise RuntimeError(
        f"数据集 {DATASET_NAME} 为空，请先运行 00_data_preprocessing.ipynb"
    )

print(f"数据集: {DATASET_NAME}")
print(f"电池数: {len(batteries)}")

# Batteries with a known end-of-life cycle (eol_cycle is not None).
eol_count = sum(b.eol_cycle is not None for b in batteries)
print(f"已达EOL: {eol_count}/{len(batteries)}")

# Summary of the first battery.
bat = batteries[0]
print(f"\n示例电池: {bat.cell_id}")
print(f"  循环数: {len(bat)}")
soh_arr = bat.get_soh_array()
if len(soh_arr) > 0:
    print(f"  SOH范围: {soh_arr.min():.3f} - {soh_arr.max():.3f}")
print(f"  EOL循环: {bat.eol_cycle}")

# Example charge curve: cycle index 10 when available, else the first cycle.
cycles = bat.cycles
if len(cycles) == 0:
    print("\n示例电池没有循环数据")
else:
    cycle = cycles[10] if len(cycles) > 10 else cycles[0]
    print(f"\n示例循环 #{cycle.cycle_number}:")
    if cycle.voltage is not None:
        print(f"  电压点数: {len(cycle.voltage)}")
    if cycle.current is not None:
        print(f"  电流点数: {len(cycle.current)}")
    if cycle.time is not None:
        print(f"  时间点数: {len(cycle.time)}")

## 2. 创建数据集和加载器

如需启用随机截取训练策略，可向 `create_dataloaders` 传入 `random_truncate_train=True`（下方调用未显式传入该参数）。

In [None]:
# Build train/val/test loaders. INPUT_TYPE selects the feature representation:
#   'features' - 16-dim statistical features
#   'sequence' - single-cycle time series (num_samples, 3)
#   'image'    - multi-cycle heat-map (window_size, num_samples, 3)
_loader_cfg = dict(
    input_type=INPUT_TYPE,       # features / sequence / image
    target_type=TARGET_TYPE,     # soh / rul / both
    window_size=WINDOW_SIZE,     # heat-map window (image mode)
    num_samples=NUM_SAMPLES,     # points per cycle
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    batch_size=BATCH_SIZE,
    seed=42,
)
train_loader, val_loader, test_loader = create_dataloaders(batteries=batteries, **_loader_cfg)

for _caption, _count in (
    ("\n训练批次", len(train_loader)),
    ("验证样本", len(val_loader.dataset)),
    ("测试样本", len(test_loader.dataset)),
):
    print(f"{_caption}: {_count}")

In [None]:
# Inspect the first training batch.
batch = next(iter(train_loader))

print(f"输入类型: {INPUT_TYPE}")
print(f"模型: {MODEL_NAME}")
print("-" * 40)
print("批次内容:")

# Shape annotation for each input type.
_shape_notes = {
    'features': "  # (batch, 16) - 16维统计特征",
    'sequence': "  # (batch, num_samples, channels) - 单循环时序",
    'image': "  # (batch, window, samples, channels) - 热力图",
}
feature_shape = batch['feature'].shape
_note = _shape_notes.get(INPUT_TYPE)
if _note is not None:
    print(f"  feature: {feature_shape}{_note}")

print(f"  label: {batch['label'].shape}")
print(f"  soh: {batch['soh'].shape}")

if 'rul' in batch:
    rul_t = batch['rul']
    print(f"  rul: {rul_t.shape}")
    # Negative RUL marks an unavailable label.
    valid_rul = rul_t[rul_t >= 0]
    print(f"  有效RUL数量: {len(valid_rul)}/{len(rul_t)}")
    if len(valid_rul) > 0:
        print(f"  有效RUL范围: [{valid_rul.min():.0f}, {valid_rul.max():.0f}]")

soh_t = batch['soh']
print(f"\n标签统计:")
print(f"  SOH范围: [{soh_t.min():.3f}, {soh_t.max():.3f}]")

## 3. RUL标签诊断

In [None]:
# Gather every valid (>= 0) RUL label in the training set with its SOH.
all_rul, all_soh = [], []
for batch in train_loader:
    if 'rul' not in batch:
        continue
    keep = batch['rul'] >= 0
    all_rul.extend(batch['rul'][keep].tolist())
    all_soh.extend(batch['soh'][keep].tolist())

if not all_rul:
    print("警告: 训练集中没有有效的RUL标签！")
    print("请检查数据预处理，确保电池已计算EOL。")
else:
    all_rul = np.asarray(all_rul)
    all_soh = np.asarray(all_soh)
    print(f"训练集RUL统计:")
    print(f"  有效RUL样本数: {len(all_rul)}")
    print(f"  RUL范围: [{all_rul.min():.0f}, {all_rul.max():.0f}]")
    print(f"  RUL均值: {all_rul.mean():.1f}")
    print(f"  RUL中位数: {np.median(all_rul):.1f}")

    # Left: RUL histogram. Right: SOH vs RUL scatter.
    fig, (hist_ax, scatter_ax) = plt.subplots(1, 2, figsize=(12, 4))

    hist_ax.hist(all_rul, bins=50, alpha=0.7, color='steelblue')
    hist_ax.set_xlabel('RUL (循环)')
    hist_ax.set_ylabel('频数')
    hist_ax.set_title('训练集RUL分布')
    hist_ax.grid(True, alpha=0.3)

    scatter_ax.scatter(all_soh, all_rul, alpha=0.3, s=5)
    scatter_ax.set_xlabel('SOH')
    scatter_ax.set_ylabel('RUL (循环)')
    scatter_ax.set_title('SOH vs RUL关系')
    scatter_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 4. 创建和训练模型

In [None]:
# Build the configured model. Input type -> compatible models:
#   'features' (16,)   -> MLP, PINN
#   'sequence' (T, 3)  -> CNN1D, LSTM, BiLSTM, Transformer, FNO
#   'image' (H, W, 3)  -> CNN2D, ViT
_model_kwargs = dict(
    in_channels=3,                     # V, I, time channels
    hidden_size=HIDDEN_SIZE,
    hidden_channels=HIDDEN_SIZE // 2,  # read by the CNN models
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    output_dim=1,
    window_size=WINDOW_SIZE,           # read by ViT
    num_samples=NUM_SAMPLES,           # read by ViT
)
model = create_model(model_name=MODEL_NAME, input_type=INPUT_TYPE, **_model_kwargs).to(device)

# Parameter counts.
_all_params = list(model.parameters())
total_params = sum(p.numel() for p in _all_params)
trainable_params = sum(p.numel() for p in _all_params if p.requires_grad)

_bar = "=" * 50
print("\n".join([
    _bar,
    "模型配置:",
    _bar,
    f"  模型名称: {MODEL_NAME}",
    f"  输入类型: {INPUT_TYPE}",
    f"  模型类型: {type(model).__name__}",
    f"  总参数量: {total_params:,}",
    f"  可训练参数: {trainable_params:,}",
    _bar,
]))

In [None]:
# Train the model
# Generic training loop shared by every model in this notebook.
def train_model(model, train_loader, val_loader, epochs, lr, patience, save_path, device):
    """Train ``model`` with Adam + MSE loss and early stopping on validation loss.

    The state dict with the lowest validation loss is written to ``save_path``
    and reloaded into ``model`` before returning.

    Args:
        model: Network mapping ``batch['feature']`` to a prediction.
        train_loader: DataLoader yielding dicts with 'feature' and 'label'.
        val_loader: DataLoader with the same structure. If it yields no batches,
            the training loss is used for model selection instead.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate.
        patience: Epochs without validation improvement before stopping.
        save_path: File the best state dict is saved to.
        device: Device the batches are moved to.

    Returns:
        dict with per-epoch 'train_loss' and 'val_loss' lists (element-weighted MSE).
    """
    try:
        from tqdm import tqdm
    except ImportError:  # tqdm is optional: fall back to a plain loop + print
        tqdm = None
    epoch_iter = tqdm(range(epochs), desc='训练中') if tqdm is not None else range(epochs)
    log = tqdm.write if tqdm is not None else print

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10
    )
    criterion = nn.MSELoss()

    def _align(outputs, labels):
        # Reshape predictions to the label shape. A bare `.squeeze()` would turn
        # a (1, 1) output from a size-1 batch into a 0-d tensor and broadcast.
        if outputs.numel() == labels.numel():
            return outputs.reshape(labels.shape)
        return outputs.squeeze(-1)

    history = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')
    best_saved = False
    patience_counter = 0

    for epoch in epoch_iter:
        # Training phase
        model.train()
        train_sum, train_count = 0.0, 0
        for batch in train_loader:
            features = batch['feature'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = _align(model(features), labels)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Weight by element count so a short final batch is not over-weighted.
            train_sum += loss.item() * labels.numel()
            train_count += labels.numel()

        train_loss = train_sum / max(train_count, 1)
        history['train_loss'].append(train_loss)

        # Validation phase
        model.eval()
        val_sum, val_count = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                features = batch['feature'].to(device)
                labels = batch['label'].to(device)
                outputs = _align(model(features), labels)
                val_sum += criterion(outputs, labels).item() * labels.numel()
                val_count += labels.numel()

        # An empty validation set would otherwise divide by zero.
        val_loss = val_sum / val_count if val_count else train_loss
        history['val_loss'].append(val_loss)
        scheduler.step(val_loss)

        # Early stopping / checkpointing
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            best_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                log(f"\n早停于 epoch {epoch + 1}")
                break

        if (epoch + 1) % 10 == 0:
            log(f"Epoch {epoch+1}: train_loss={train_loss:.6f}, val_loss={val_loss:.6f}")

    if not best_saved:
        # Validation loss never improved (e.g. NaN); keep the last weights so
        # the reload below and later cells still find a checkpoint.
        log("警告: 验证损失未改善，保存最后一轮权重")
        torch.save(model.state_dict(), save_path)

    # Restore the best checkpoint.
    model.load_state_dict(torch.load(save_path, map_location=device))
    return history

# Train the selected model; the best checkpoint is written to save_path.
save_path = str(CHECKPOINT_DIR / f'{MODEL_NAME}_{DATASET_NAME}_{INPUT_TYPE}.pth')
_train_cfg = dict(
    epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    patience=PATIENCE,
    save_path=save_path,
    device=device,
)
history = train_model(model, train_loader, val_loader, **_train_cfg)
print(f"\n模型已保存: {save_path}")

In [None]:
# Plot the training curves.
# train_model records only 'train_loss' / 'val_loss'. The SOH/RUL panels are
# drawn only when those per-target keys exist; otherwise they are marked as
# empty instead of raising KeyError.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Total loss
ax = axes[0]
ax.plot(history['train_loss'], label='训练')
if history.get('val_loss'):
    ax.plot(history['val_loss'], label='验证')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('总损失')
ax.legend()
ax.grid(True, alpha=0.3)

# Per-target losses (optional keys)
for ax, key, ylabel, title in (
    (axes[1], 'train_soh_loss', 'SOH Loss', 'SOH损失'),
    (axes[2], 'train_rul_loss', 'RUL Loss', 'RUL损失'),
):
    series = history.get(key)
    if series:
        ax.plot(series)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, '无数据', ha='center', va='center', transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 5. 模型评估

In [None]:
# Generic evaluation function
def evaluate_model(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and compute regression metrics.

    Labels below 0 (e.g. the -1 placeholder for an unknown RUL) are excluded
    from every metric. MAPE additionally skips zero labels, where the relative
    error is undefined.

    Args:
        model: Network mapping ``batch['feature']`` to predictions.
        test_loader: DataLoader yielding dicts with 'feature' and 'label'.
        device: Device the features are moved to.

    Returns:
        dict with 'rmse', 'mae', 'mape' (percent; NaN when no usable labels)
        and the flattened 'predictions' / 'labels' arrays (all samples,
        invalid labels included).

    Raises:
        ValueError: If predictions and labels cannot be aligned element-wise.
    """
    model.eval()
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in test_loader:
            outputs = model(batch['feature'].to(device))
            # Flatten to 1-D. `.squeeze()` turned a size-1 batch into a 0-d
            # array, which cannot be iterated/extended.
            pred_chunks.append(outputs.reshape(-1).cpu().numpy())
            label_chunks.append(batch['label'].reshape(-1).cpu().numpy())

    preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0)
    labels = np.concatenate(label_chunks) if label_chunks else np.empty(0)
    if preds.shape != labels.shape:
        raise ValueError(
            f"predictions {preds.shape} and labels {labels.shape} have different sizes"
        )

    nan = float('nan')
    valid = labels >= 0
    if valid.any():
        err = preds[valid] - labels[valid]
        rmse = float(np.sqrt(np.mean(err ** 2)))
        mae = float(np.mean(np.abs(err)))
    else:
        rmse = mae = nan

    positive = labels > 0
    if positive.any():
        rel = (labels[positive] - preds[positive]) / labels[positive]
        mape = float(np.mean(np.abs(rel)) * 100)
    else:
        mape = nan

    return {
        'rmse': rmse,
        'mae': mae,
        'mape': mape,
        'predictions': preds,
        'labels': labels,
    }

# Evaluate the trained model on the held-out test loader and print a summary.
metrics = evaluate_model(model, test_loader, device)

_rule = "=" * 50
_report = [
    _rule,
    f"测试集评估结果 ({MODEL_NAME} + {INPUT_TYPE})",
    _rule,
    f"  RMSE: {metrics['rmse']:.4f}",
    f"  MAE:  {metrics['mae']:.4f}",
    f"  MAPE: {metrics['mape']:.2f}%",
    _rule,
]
print("\n".join(_report))

In [None]:
# Predicted vs. true values on the test set.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Drop placeholder labels (< 0, e.g. -1 for an unknown RUL) before plotting.
valid = metrics['labels'] >= 0
y_true = metrics['labels'][valid]
y_pred = metrics['predictions'][valid]

# Left: prediction vs truth with the y = x reference line.
ax = axes[0]
ax.scatter(y_true, y_pred, alpha=0.3, s=10)
if y_true.size > 0:
    lo = min(y_true.min(), y_pred.min())
    hi = max(y_true.max(), y_pred.max())
    ax.plot([lo, hi], [lo, hi], 'r--', linewidth=2, label='理想线')
    ax.legend()
ax.set_xlabel('真实值')
ax.set_ylabel('预测值')
ax.set_title(f'{TARGET_TYPE.upper()} 预测 (RMSE={metrics["rmse"]:.4f})')
ax.grid(True, alpha=0.3)
if TARGET_TYPE == 'soh':
    # Fixed SOH window. Applying it to other targets would clip RUL-scale values.
    ax.set_xlim(0.6, 1.05)
    ax.set_ylim(0.6, 1.05)

# Right: residuals vs truth. evaluate_model returns one prediction array,
# so there is no separate RUL head to plot here.
ax = axes[1]
ax.scatter(y_true, y_pred - y_true, alpha=0.3, s=10)
ax.axhline(y=0, color='r', linestyle='--', linewidth=2)
ax.set_xlabel('真实值')
ax.set_ylabel('残差 (预测 - 真实)')
ax.set_title(f'{TARGET_TYPE.upper()} 残差')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Distribution of test-set prediction errors (from `metrics`).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Only labels >= 0 are real targets; negatives are placeholders.
valid = metrics['labels'] >= 0
y_true = metrics['labels'][valid]
errors = metrics['predictions'][valid] - y_true
target_name = TARGET_TYPE.upper()

# Left: histogram of signed errors.
ax = axes[0]
ax.hist(errors, bins=50, alpha=0.7, color='steelblue')
ax.axvline(x=0, color='r', linestyle='--')
ax.set_xlabel(f'{target_name}预测误差')
ax.set_ylabel('频数')
if errors.size > 0:
    ax.set_title(f'{target_name}误差分布 (均值={errors.mean():.4f}, 标准差={errors.std():.4f})')
else:
    ax.set_title(f'{target_name}误差分布 (无有效样本)')
ax.grid(True, alpha=0.3)

# Right: absolute error as a function of the true value.
ax = axes[1]
ax.scatter(y_true, np.abs(errors), alpha=0.3, s=10, color='coral')
ax.set_xlabel('真实值')
ax.set_ylabel('绝对误差')
ax.set_title(f'{target_name}绝对误差 vs 真实值')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. 不同模型架构比较

In [None]:
# Compare every model that accepts the current INPUT_TYPE, using a short
# training budget (30 epochs, patience 10) on the shared loaders.
models_to_compare = INPUT_MODEL_COMPATIBILITY[INPUT_TYPE]
results_compare = {}

_cmp_model_kwargs = dict(
    in_channels=3,
    hidden_size=HIDDEN_SIZE,
    hidden_channels=HIDDEN_SIZE // 2,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    output_dim=1,
    window_size=WINDOW_SIZE,
    num_samples=NUM_SAMPLES,
)


def _train_and_score(name):
    """Briefly train model ``name`` and return its test-set metrics."""
    candidate = create_model(model_name=name, input_type=INPUT_TYPE, **_cmp_model_kwargs).to(device)
    ckpt = str(CHECKPOINT_DIR / f'{name}_{DATASET_NAME}_{INPUT_TYPE}_cmp.pth')
    train_model(
        model=candidate,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=30,
        lr=LEARNING_RATE,
        patience=10,
        save_path=ckpt,
        device=device,
    )
    return evaluate_model(candidate, test_loader, device)


print(f"输入类型: {INPUT_TYPE}")
print(f"比较模型: {models_to_compare}")
print("-" * 50)

for name in models_to_compare:
    print(f"\n训练 {name} 模型...")
    # Best-effort: a model that fails to build/train is reported and skipped.
    try:
        scores = _train_and_score(name)
        results_compare[name] = scores
        print(f"  RMSE: {scores['rmse']:.4f}")
        print(f"  MAE: {scores['mae']:.4f}")
    except Exception as e:
        print(f"  跳过 {name}: {e}")

In [None]:
# Bar charts and a table of the comparison results.
def _primary_rmse(m):
    # evaluate_model reports 'rmse'; prefer 'soh_rmse' if a multi-head
    # evaluator supplied it. Missing both -> NaN instead of KeyError.
    return m.get('soh_rmse', m.get('rmse', float('nan')))


fig, axes = plt.subplots(1, 2, figsize=(12, 5))

names = list(results_compare.keys())
soh_rmse = [_primary_rmse(results_compare[n]) for n in names]
rul_rmse = [results_compare[n].get('rul_rmse', 0) for n in names]
has_rul = any('rul_rmse' in results_compare[n] for n in names)
x = np.arange(len(names))

axes[0].bar(x, soh_rmse, color='steelblue')
axes[0].set_xticks(x)
axes[0].set_xticklabels(names)
axes[0].set_ylabel('SOH RMSE')
axes[0].set_title('SOH预测精度')
axes[0].grid(True, alpha=0.3, axis='y')

if has_rul:
    axes[1].bar(x, rul_rmse, color='coral')
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(names)
    axes[1].set_ylabel('RUL RMSE (循环)')
    axes[1].grid(True, alpha=0.3, axis='y')
else:
    # No model reported a RUL metric; avoid drawing misleading zero bars.
    axes[1].text(0.5, 0.5, '无数据', ha='center', va='center', transform=axes[1].transAxes)
    axes[1].set_xticks([])
    axes[1].set_yticks([])
axes[1].set_title('RUL预测精度')

plt.tight_layout()
plt.show()

# Comparison table
print("\n===== 模型比较 =====")
print(f"{'模型':<15} {'SOH RMSE':<12} {'RUL RMSE':<12}")
print("-" * 39)
for name in names:
    m = results_compare[name]
    rul = f"{m['rul_rmse']:.1f}" if 'rul_rmse' in m else '-'
    print(f"{name:<15} {_primary_rmse(m):<12.4f} {rul:<12}")

## 7. 单电池预测轨迹可视化

In [None]:
# Pick one battery and plot the model's prediction trajectory over its cycles.
# NOTE(review): this draw is independent of the split performed inside
# `create_dataloaders`, so the battery is not guaranteed to be a held-out test
# battery. A seeded generator at least makes the choice reproducible.
_rng = np.random.default_rng(42)
test_batteries = [batteries[i] for i in _rng.permutation(len(batteries))[-5:]]
selected_battery = test_batteries[0]

# Dataset over the selected battery; min_cycle_for_prediction is WINDOW_SIZE
# for image inputs and 10 otherwise.
trajectory_dataset = BatteryDataset(
    batteries=[selected_battery],
    input_type=INPUT_TYPE,
    target_type=TARGET_TYPE,
    window_size=WINDOW_SIZE,
    num_samples=NUM_SAMPLES,
    min_cycle_for_prediction=WINDOW_SIZE if INPUT_TYPE == 'image' else 10,
)

# shuffle=False keeps samples in dataset order for the x-axis.
full_loader = torch.utils.data.DataLoader(trajectory_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Run the trained model over every sample.
model.eval()
_pred_chunks, _label_chunks = [], []
with torch.no_grad():
    for traj_batch in full_loader:
        outputs = model(traj_batch['feature'].to(device))
        _pred_chunks.append(outputs.reshape(-1).cpu().numpy())
        _label_chunks.append(traj_batch['label'].reshape(-1).cpu().numpy())

if not _pred_chunks:
    print(f"警告: 电池 {selected_battery.cell_id} 没有可预测的样本")
else:
    traj_pred = np.concatenate(_pred_chunks)
    traj_label = np.concatenate(_label_chunks)
    target_name = TARGET_TYPE.upper()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: true vs predicted trajectory.
    ax = axes[0]
    ax.plot(traj_label, 'b-', linewidth=2, label=f'真实{target_name}')
    ax.plot(traj_pred, 'r--', linewidth=2, label=f'预测{target_name}')
    if TARGET_TYPE == 'soh':
        # The 0.8 EOL threshold only makes sense on the SOH scale.
        ax.axhline(y=0.8, color='green', linestyle=':', label='EOL阈值')
    ax.set_xlabel('样本索引')
    ax.set_ylabel(target_name)
    ax.set_title(f'{target_name}预测轨迹')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Right: per-sample error; placeholder labels (< 0) are masked as NaN.
    ax = axes[1]
    traj_err = np.where(traj_label >= 0, traj_pred - traj_label, np.nan)
    ax.plot(traj_err, 'm-', linewidth=1.5)
    ax.axhline(y=0, color='r', linestyle='--')
    ax.set_xlabel('样本索引')
    ax.set_ylabel('预测误差')
    ax.set_title(f'{target_name}预测误差轨迹')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 8. 保存和加载模型

In [None]:
# Save the final weights ({model}_{dataset}_{input_type}_final.pth) and check
# them by reloading into a freshly built model.
final_save_path = CHECKPOINT_DIR / f'{MODEL_NAME}_{DATASET_NAME}_{INPUT_TYPE}_final.pth'
_final_ckpt = str(final_save_path)
torch.save(model.state_dict(), _final_ckpt)
print(f"模型已保存: {final_save_path}")

# Rebuild with the same hyper-parameters as the trained model.
model_loaded = create_model(
    MODEL_NAME,
    INPUT_TYPE,
    in_channels=3,
    hidden_size=HIDDEN_SIZE,
    hidden_channels=HIDDEN_SIZE // 2,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    output_dim=1,
    window_size=WINDOW_SIZE,
    num_samples=NUM_SAMPLES,
).to(device)
model_loaded.load_state_dict(torch.load(_final_ckpt))

# The reloaded model should reproduce the test metrics.
metrics_loaded = evaluate_model(model_loaded, test_loader, device)
print(f"加载模型验证 - RMSE: {metrics_loaded['rmse']:.4f}")

## 下一步

单数据集训练完成后，可以进行跨域迁移：
- `03_cross_domain_transfer.ipynb` - 跨数据集迁移与零样本预测