# 02 - 模型训练：SOH + RUL联合预测

本notebook演示基于历史循环序列的SOH和RUL联合预测。

## 核心设计

**输入**: 从第1次循环到当前的所有历史充电曲线  
**输出**: 当前SOH + 预测的RUL（剩余循环次数）  
**特征**: 电压、电流、时间（3通道，跨数据集通用，无温度）

## 新特性

1. **16维协议无关特征**：适用于所有充电协议（CC、CC-CV、多阶段等）
2. **3通道时序数据**：V, I, time（保留真实秒数）
3. **随机EOL阈值**：训练时随机选择EOL阈值，范围为 [电池最低SOH, 当前SOH]
4. **多数据集支持**：可选择 MATR, CALCE, HUST, XJTU 等数据集

## 训练策略

**随机截取策略**：训练时随机选择截断点，使用从第1次循环到截断点的历史序列进行预测：
- 每个epoch看到不同的截断位置
- 模型学习从寿命的不同阶段进行预测
- 增加数据多样性，提升泛化能力

**前置要求**: 运行 `00_data_preprocessing.ipynb` 完成数据预处理

In [None]:
# Put the parent directory first on sys.path; presumably this is what makes
# the `src` package imported below resolvable when the notebook is run from
# its own folder (depends on the working directory — TODO confirm).
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']  # SimHei is the recommended font on Windows
matplotlib.rcParams['axes.unicode_minus'] = False  # fixes minus-sign rendering
import torch
from pathlib import Path
import warnings
# NOTE: this silences every warning category for the rest of the session.
warnings.filterwarnings('ignore')

from src.data import load_processed_batteries, load_all_processed
from src.data.historical_dataset import (
    HistoricalBatteryDataset,
    create_historical_dataloaders,
)
from src.models import HistoricalSOHRULModel, get_model

# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'使用设备: {device}')

# Directory for saved checkpoints; created (with parents) if it is missing.
CHECKPOINT_DIR = Path('..') / 'results' / 'checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

## 1. 加载数据

In [None]:
# Load a single pre-processed dataset.
dataset_name = 'MATR'  # alternatives: MATR, CALCE, HUST, XJTU, ...
batteries = load_processed_batteries(f'../data/processed/{dataset_name}')

# Fail early with an actionable message instead of the opaque IndexError that
# `batteries[0]` below would raise when nothing was loaded.
if len(batteries) == 0:
    raise RuntimeError(
        f"未在 ../data/processed/{dataset_name} 找到任何电池数据，"
        "请先运行 00_data_preprocessing.ipynb 完成数据预处理。"
    )

print(f"数据集: {dataset_name}")
print(f"电池数: {len(batteries)}")

# Count the batteries that have an EOL cycle recorded (eol_cycle is not None).
eol_count = sum(1 for b in batteries if b.eol_cycle is not None)
print(f"已达EOL: {eol_count}/{len(batteries)}")

# Show the first battery as an example. The SOH array is fetched once and
# reused for both the minimum and the maximum.
bat = batteries[0]
soh_values = bat.get_soh_array()
print(f"\n示例: {bat.cell_id}")
print(f"  循环数: {len(bat)}")
print(f"  SOH范围: {soh_values.min():.3f} - {soh_values.max():.3f}")
print(f"  EOL: {bat.eol_cycle}")




## 2. 创建数据集和加载器

使用 `random_truncate_train=True` 启用随机截取训练策略。

In [None]:
# Loader settings gathered in one place, then unpacked into the factory call.
loader_config = dict(
    input_mode='summary',           # use the summary input mode
    num_samples=200,                # sample points per cycle
    num_summary_cycles=50,          # number of summary cycles
    train_ratio=0.7,
    val_ratio=0.15,                 # presumably the remainder becomes the test split — TODO confirm
    batch_size=32,
    predict_rul=True,               # predict RUL alongside SOH
    random_truncate_train=True,     # random truncation during training
    truncate_min_ratio=0.2,         # truncate to at least 20%
    truncate_max_ratio=1.0,         # truncate to at most 100%
    seed=42,
)

# Build the train / validation / test loaders from the loaded batteries.
train_loader, val_loader, test_loader = create_historical_dataloaders(
    batteries=batteries,
    **loader_config,
)

# Training size is reported in batches, validation/test sizes in samples.
print(f"\n训练批次: {len(train_loader)}")
print(f"验证样本: {len(val_loader.dataset)}")
print(f"测试样本: {len(test_loader.dataset)}")

In [None]:
# Pull the first batch from the training loader and print what it contains.
batch = next(iter(train_loader))

print("批次内容:")
print(f"  feature: {batch['feature'].shape}  # (batch, cycles, samples, channels)")
soh_tensor = batch['soh']
print(f"  soh: {soh_tensor.shape}")

if 'rul' in batch:
    rul_tensor = batch['rul']
    print(f"  rul: {rul_tensor.shape}")
    # Only non-negative RUL entries are counted as valid labels here.
    valid_rul = rul_tensor[rul_tensor >= 0]
    print(f"  有效RUL数量: {len(valid_rul)}/{len(rul_tensor)}")
    if len(valid_rul):
        print(f"  有效RUL范围: [{valid_rul.min():.0f}, {valid_rul.max():.0f}]")

if 'trend_features' in batch:
    print(f"  trend_features: {batch['trend_features'].shape}")

print(f"\n  SOH范围: [{soh_tensor.min():.3f}, {soh_tensor.max():.3f}]")

## 3. RUL标签诊断

In [None]:
# Inspect the distribution of RUL labels in the training set.
all_rul = []
all_soh = []
for batch in train_loader:
    if 'rul' not in batch:
        continue
    # Keep only entries with a non-negative RUL, together with their SOH.
    keep = batch['rul'] >= 0
    all_rul.extend(batch['rul'][keep].tolist())
    all_soh.extend(batch['soh'][keep].tolist())

if not all_rul:
    # Nothing usable was collected: tell the user what to check.
    print("警告: 训练集中没有有效的RUL标签！")
    print("请检查数据预处理，确保电池已计算EOL。")
else:
    all_rul = np.array(all_rul)
    all_soh = np.array(all_soh)

    # Summary statistics.
    print(f"训练集RUL统计:")
    print(f"  有效RUL样本数: {len(all_rul)}")
    print(f"  RUL范围: [{all_rul.min():.0f}, {all_rul.max():.0f}]")
    print(f"  RUL均值: {all_rul.mean():.1f}")
    print(f"  RUL中位数: {np.median(all_rul):.1f}")

    # Left panel: RUL histogram. Right panel: SOH vs. RUL scatter.
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    hist_ax, scatter_ax = axes

    hist_ax.hist(all_rul, bins=50, alpha=0.7, color='steelblue')
    hist_ax.set_xlabel('RUL (循环)')
    hist_ax.set_ylabel('频数')
    hist_ax.set_title('训练集RUL分布')
    hist_ax.grid(True, alpha=0.3)

    scatter_ax.scatter(all_soh, all_rul, alpha=0.3, s=5)
    scatter_ax.set_xlabel('SOH')
    scatter_ax.set_ylabel('RUL (循环)')
    scatter_ax.set_title('SOH vs RUL关系')
    scatter_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 4. 创建和训练模型

In [None]:
# Hyper-parameters of the joint SOH + RUL model.
model_config = dict(
    backbone='lstm',             # options: 'lstm', 'transformer', 'cnn'
    in_channels=3,               # three channels: voltage, current, time
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    use_trend_features=True,     # use the degradation-trend features
    predict_rul=True,            # predict RUL alongside SOH
)
model = HistoricalSOHRULModel(**model_config).to(device)

# Report the backbone type and the total number of parameter elements.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"模型: {model.backbone_type}")
print(f"参数量: {total_params:,}")

In [None]:
# Train the model.
# Original author's notes: the RUL loss weight was raised to 1.0 so the model
# learns RUL properly, and training uses a combined (MSE + MAE) loss with an
# improved initialisation — both live inside `fit_historical`, not shown here.
# NOTE(review): in the training log recorded below this cell the RUL term
# (~1e4) is several orders of magnitude larger than the SOH term (~1e-3), so
# with equal weights the total loss is essentially the RUL loss.
checkpoint_file = CHECKPOINT_DIR / f'historical_{dataset_name}.pth'

history = model.fit_historical(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=100,
    lr=1e-3,
    soh_weight=1.0,              # SOH loss weight
    rul_weight=1.0,              # RUL loss weight (raised from 0.1 to 1.0)
    early_stop_patience=20,
    save_path=str(checkpoint_file),
    verbose=True,
)

Epoch 10: loss=57528.833984, soh=0.002868, rul=57528.830078, val=36144.955489


Training:  20%|██████████████▏                                                        | 20/100 [08:27<15:49, 11.87s/it]

Epoch 20: loss=36961.198242, soh=0.006434, rul=36961.192383, val=22610.705720


Training:  30%|█████████████████████▎                                                 | 30/100 [10:23<13:32, 11.61s/it]

Epoch 30: loss=36286.351562, soh=0.005813, rul=36286.345703, val=19004.448538


Training:  40%|████████████████████████████▍                                          | 40/100 [12:22<11:50, 11.83s/it]

Epoch 40: loss=37371.562012, soh=0.007771, rul=37371.553223, val=19606.794226


Training:  50%|███████████████████████████████████▌                                   | 50/100 [14:18<09:35, 11.51s/it]

Epoch 50: loss=26966.047852, soh=0.005554, rul=26966.042480, val=19410.839642


Training:  50%|███████████████████████████████████▌                                   | 50/100 [14:29<14:29, 17.40s/it]

Early stopping at epoch 51





In [None]:
# Plot the training history: total loss, SOH loss and RUL loss side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Panel 1: total loss. The validation curve is drawn only when it is non-empty.
ax = axes[0]
ax.plot(history['train_loss'], label='训练')
val_curve = history['val_loss']
if val_curve:
    ax.plot(val_curve, label='验证')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('总损失')
ax.legend()
ax.grid(True, alpha=0.3)

# Panels 2 and 3: one training curve each, formatted identically.
single_curve_panels = (
    (axes[1], 'train_soh_loss', 'SOH Loss', 'SOH损失'),
    (axes[2], 'train_rul_loss', 'RUL Loss', 'RUL损失'),
)
for ax, key, y_label, title in single_curve_panels:
    ax.plot(history[key])
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5. 模型评估

In [None]:
# Evaluate the trained model on the test loader and print the metrics.
metrics = model.evaluate_historical(test_loader)

print("===== 测试集评估结果 =====")
print(f"SOH RMSE: {metrics['soh_rmse']:.4f}")
print(f"SOH MAE:  {metrics['soh_mae']:.4f}")

# RUL metrics are printed only when the evaluation returned them; MAPE is
# optional even then.
has_rul_metrics = 'rul_rmse' in metrics
if has_rul_metrics:
    print(f"RUL RMSE: {metrics['rul_rmse']:.1f} 循环")
    print(f"RUL MAE:  {metrics['rul_mae']:.1f} 循环")
if has_rul_metrics and 'rul_mape' in metrics:
    print(f"RUL MAPE: {metrics['rul_mape']:.1f}%")

In [None]:
# Collect predictions on the test loader and plot predicted vs. true values.
results = model.predict_batch(test_loader)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
soh_ax, rul_ax = axes

# SOH panel: scatter plus the y = x reference line, clipped to [0.6, 1.05].
soh_ax.scatter(results['soh_label'], results['soh_pred'], alpha=0.5, s=10)
soh_ax.plot([0.6, 1.1], [0.6, 1.1], 'r--', linewidth=2, label='理想线')
soh_ax.set_xlabel('真实SOH')
soh_ax.set_ylabel('预测SOH')
soh_ax.set_title(f'SOH预测 (RMSE={metrics["soh_rmse"]:.4f})')
soh_ax.legend()
soh_ax.grid(True, alpha=0.3)
soh_ax.set_xlim(0.6, 1.05)
soh_ax.set_ylim(0.6, 1.05)

# RUL panel: drawn only when both predictions and labels are present;
# otherwise the right-hand axes stay empty.
if 'rul_pred' in results and 'rul_label' in results:
    rul_true = results['rul_label']
    rul_hat = results['rul_pred']
    rul_ax.scatter(rul_true, rul_hat, alpha=0.5, s=10)
    # The reference line runs from 0 to the largest value on either axis.
    max_rul = max(rul_true.max(), rul_hat.max())
    rul_ax.plot([0, max_rul], [0, max_rul], 'r--', linewidth=2, label='理想线')
    rul_ax.set_xlabel('真实RUL (循环)')
    rul_ax.set_ylabel('预测RUL (循环)')
    rul_ax.set_title(f'RUL预测 (RMSE={metrics.get("rul_rmse", 0):.1f})')
    rul_ax.legend()
    rul_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Histograms of the signed prediction error (prediction minus label).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
soh_ax, rul_ax = axes

# SOH error distribution, with a dashed marker at zero error.
soh_error = results['soh_pred'] - results['soh_label']
soh_ax.hist(soh_error, bins=50, alpha=0.7, color='steelblue')
soh_ax.axvline(x=0, color='r', linestyle='--')
soh_ax.set_xlabel('SOH预测误差')
soh_ax.set_ylabel('频数')
soh_ax.set_title(f'SOH误差分布 (均值={soh_error.mean():.4f}, 标准差={soh_error.std():.4f})')
soh_ax.grid(True, alpha=0.3)

# RUL error distribution, drawn only when RUL predictions and labels exist.
rul_available = 'rul_pred' in results and 'rul_label' in results
if rul_available:
    rul_error = results['rul_pred'] - results['rul_label']
    rul_ax.hist(rul_error, bins=50, alpha=0.7, color='coral')
    rul_ax.axvline(x=0, color='r', linestyle='--')
    rul_ax.set_xlabel('RUL预测误差 (循环)')
    rul_ax.set_ylabel('频数')
    rul_ax.set_title(f'RUL误差分布 (均值={rul_error.mean():.1f}, 标准差={rul_error.std():.1f})')
    rul_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. 不同模型架构比较

In [None]:
# Train one model per backbone and collect its test metrics.
# NOTE(review): unlike the main model, these candidates do not pass
# num_layers / dropout, loss weights or early stopping, so they run with
# whatever defaults the class and `fit_historical` define — confirm the
# numbers are meant to be compared under those settings.
backbones = ['lstm', 'transformer', 'cnn']
results_compare = {}

# Settings shared by every candidate; only the backbone differs.
shared_model_kwargs = dict(
    in_channels=3,
    hidden_size=128,
    use_trend_features=True,
    predict_rul=True,
)
shared_fit_kwargs = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=50,
    lr=1e-3,
    verbose=False,
)

for backbone in backbones:
    print(f"\n训练 {backbone} 模型...")

    model_bb = HistoricalSOHRULModel(backbone=backbone, **shared_model_kwargs).to(device)
    model_bb.fit_historical(**shared_fit_kwargs)

    # Store the test metrics under the backbone name for the comparison.
    metrics_bb = model_bb.evaluate_historical(test_loader)
    results_compare[backbone] = metrics_bb

    print(f"  SOH RMSE: {metrics_bb['soh_rmse']:.4f}")
    if 'rul_rmse' in metrics_bb:
        print(f"  RUL RMSE: {metrics_bb['rul_rmse']:.1f}")

In [None]:
# Bar charts comparing the backbones, followed by a plain-text table.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

names = list(results_compare.keys())
soh_rmse = [results_compare[n]['soh_rmse'] for n in names]
# A backbone without an RUL metric is drawn as a zero-height bar.
rul_rmse = [results_compare[n].get('rul_rmse', 0) for n in names]
x = np.arange(len(names))

# Both panels share the same layout; only data, colour and labels differ.
panels = (
    (axes[0], soh_rmse, 'steelblue', 'SOH RMSE', 'SOH预测精度'),
    (axes[1], rul_rmse, 'coral', 'RUL RMSE (循环)', 'RUL预测精度'),
)
for ax, heights, colour, y_label, title in panels:
    ax.bar(x, heights, color=colour)
    ax.set_xticks(x)
    ax.set_xticklabels(names)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Comparison table; a missing RUL metric is shown as '-'.
print("\n===== 模型比较 =====")
print(f"{'模型':<15} {'SOH RMSE':<12} {'RUL RMSE':<12}")
print("-" * 39)
for name in names:
    m = results_compare[name]
    if 'rul_rmse' in m:
        rul = f"{m['rul_rmse']:.1f}"
    else:
        rul = '-'
    print(f"{name:<15} {m['soh_rmse']:<12.4f} {rul:<12}")

## 7. 单电池预测轨迹可视化

In [None]:
# Show the full prediction trajectory of one battery, using a dataset built
# without random truncation.
# A seeded generator (same seed value as the data loaders) keeps the selected
# battery, and therefore the figure, identical across runs.
# NOTE(review): the candidates are drawn from *all* loaded batteries, not from
# the split behind `test_loader`, so the chosen cell may have been used for
# training — confirm against `create_historical_dataloaders` before treating
# this plot as held-out performance.
rng = np.random.default_rng(42)
test_batteries = [batteries[i] for i in rng.permutation(len(batteries))[-5:]]

full_dataset = HistoricalBatteryDataset(
    batteries=test_batteries[:1],  # use a single battery only
    input_mode='summary',
    num_samples=200,
    num_summary_cycles=50,
    random_truncate=False,
    predict_rul=True,
)

# shuffle=False keeps the samples in dataset order for the line plots below.
full_loader = torch.utils.data.DataLoader(full_dataset, batch_size=32, shuffle=False)

# Predict
pred_results = model.predict_batch(full_loader)

# SOH level drawn as the reference "EOL threshold" line in the left panel.
EOL_SOH_THRESHOLD = 0.8

# Visualise
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# SOH trajectory: label vs. prediction over the sample index.
ax = axes[0]
ax.plot(pred_results['soh_label'], 'b-', linewidth=2, label='真实SOH')
ax.plot(pred_results['soh_pred'], 'r--', linewidth=2, label='预测SOH')
ax.axhline(y=EOL_SOH_THRESHOLD, color='green', linestyle=':', label='EOL阈值')
ax.set_xlabel('样本索引')
ax.set_ylabel('SOH')
ax.set_title('SOH预测轨迹')
ax.legend()
ax.grid(True, alpha=0.3)

# RUL trajectory, drawn only when both predictions and labels are present.
if 'rul_pred' in pred_results and 'rul_label' in pred_results:
    ax = axes[1]
    ax.plot(pred_results['rul_label'], 'b-', linewidth=2, label='真实RUL')
    ax.plot(pred_results['rul_pred'], 'r--', linewidth=2, label='预测RUL')
    ax.set_xlabel('样本索引')
    ax.set_ylabel('RUL (循环)')
    ax.set_title('RUL预测轨迹')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. 保存和加载模型

In [None]:
# Save the trained model.
save_path = CHECKPOINT_DIR / f'historical_{dataset_name}_final.pth'
model.save(str(save_path))
print(f"模型已保存: {save_path}")

# Rebuild a model to load the checkpoint into. Every constructor argument is
# repeated exactly as for the trained model (including num_layers and dropout)
# so the round trip does not depend on the class defaults happening to match.
model_loaded = HistoricalSOHRULModel(
    backbone='lstm',
    in_channels=3,
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    use_trend_features=True,
    predict_rul=True,
).to(device)
model_loaded.load(str(save_path))

# Sanity check: evaluate the reloaded model on the test loader.
metrics_loaded = model_loaded.evaluate_historical(test_loader)
print(f"加载模型验证 - SOH RMSE: {metrics_loaded['soh_rmse']:.4f}")

## 下一步

单数据集训练完成后，可以进行跨域迁移：
- `03_cross_domain_transfer.ipynb` - 跨数据集迁移与零样本预测