# 03 - 跨域迁移与零样本预测

本notebook演示跨数据集的模型迁移能力：

1. **多域联合训练** - 同时使用多个数据集训练
2. **零样本预测** - 在未见过的数据集上直接预测
3. **域适应** - 通过少量目标域数据微调

**核心优势**：
- 3通道特征（V, I, T）适应不同数据集
- 域平衡采样，避免大数据集主导训练
- 历史循环特征捕捉跨协议的通用退化模式

**前置要求**: 
- 运行 `00_data_preprocessing.ipynb` 预处理多个数据集
- 运行 `02_model_training.ipynb` 了解基本训练流程

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

from src.data import load_all_processed, load_processed_batteries
from src.data.historical_dataset import (
    HistoricalBatteryDataset,
    create_historical_dataloaders,
    create_multi_domain_historical_loaders,
)
from src.models import HistoricalSOHRULModel

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'使用设备: {device}')

# Directory used for model checkpoints; created up front if it is missing.
CHECKPOINT_DIR = Path('..') / 'results' / 'checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

## 1. 加载多个数据集

In [None]:
# Load every preprocessed dataset; the result is iterated as a mapping of
# dataset name -> collection of battery objects.
all_data = load_all_processed('../data/processed')

print("===== 可用数据集 =====")
for name, batteries in all_data.items():
    # Count the batteries whose end-of-life cycle is set.
    eol_count = len([b for b in batteries if b.eol_cycle is not None])
    print(f"  {name}: {len(batteries)} 电池 (EOL: {eol_count})")

# Grand total across all datasets.
total_batteries = sum(len(group) for group in all_data.values())
print(f"\n总计: {total_batteries} 电池, {len(all_data)} 个数据集")

## 2. 多域联合训练

使用多个数据集联合训练，学习跨域通用的特征表示。

In [None]:
# Keep only the datasets that contain at least 5 batteries for joint training.
train_domains = dict(
    filter(lambda item: len(item[1]) >= 5, all_data.items())
)
print(f"用于训练的数据集: {list(train_domains)}")

# Build the train/val/test loaders spanning all selected domains.
(train_loader, val_loader, test_loader,
 domain_info) = create_multi_domain_historical_loaders(
    domain_batteries=train_domains,
    # Input representation.
    input_mode='summary',
    num_samples=200,
    num_summary_cycles=50,
    predict_rul=True,
    # Data split and reproducibility.
    train_ratio=0.7,
    val_ratio=0.15,
    seed=42,
    # Sampling: balance domains and randomly truncate histories in training.
    balance_domains=True,
    random_truncate=True,
    truncate_min_ratio=0.2,
    batch_size=32,
)

In [None]:
# Instantiate the multi-domain model, then move it to the selected device.
model_multi = HistoricalSOHRULModel(
    in_channels=3,
    backbone='lstm',
    num_layers=2,
    hidden_size=128,
    predict_rul=True,
    use_trend_features=True,
)
model_multi = model_multi.to(device)

# Joint training on the multi-domain loaders; the checkpoint path is passed
# to fit_historical via save_path.
print("\n开始多域联合训练...")
history = model_multi.fit_historical(
    train_loader=train_loader,
    val_loader=val_loader,
    save_path=str(CHECKPOINT_DIR / 'historical_multidomain.pth'),
    epochs=100,
    early_stop_patience=20,
    lr=1e-3,
    soh_weight=1.0,
    rul_weight=0.1,
    verbose=True,
)

In [None]:
# Score the multi-domain model on the test loader of the joint split.
metrics_multi = model_multi.evaluate_historical(test_loader)

print("\n===== 多域模型评估 =====")
print("SOH RMSE: {soh_rmse:.4f}".format_map(metrics_multi))
print("SOH MAE:  {soh_mae:.4f}".format_map(metrics_multi))
# RUL metrics are printed only when the evaluation result contains them.
if 'rul_rmse' in metrics_multi:
    print("RUL RMSE: {rul_rmse:.1f}".format_map(metrics_multi))
    print("RUL MAE:  {rul_mae:.1f}".format_map(metrics_multi))

## 3. 零样本预测（跨域泛化）

在训练时未见过的数据集上直接进行预测。注意：电池数不少于 5 的数据集已参与第 2 节的联合训练，只有其余数据集上的结果才是严格意义上的零样本预测。

In [None]:
# Evaluation helper: score an already-trained model on a target dataset
# without any further training.
def zero_shot_evaluate(model, target_data, dataset_name):
    """Evaluate ``model`` on ``target_data`` as-is (no fine-tuning).

    Args:
        model: Object exposing ``evaluate_historical(loader)``.
        target_data: Batteries handed to ``HistoricalBatteryDataset``.
        dataset_name: Label of the target dataset. Currently unused; kept so
            existing call sites stay valid.

    Returns:
        Whatever ``model.evaluate_historical`` returns for the target loader.
    """
    # Full histories (no random truncation) in a fixed order.
    eval_loader = torch.utils.data.DataLoader(
        HistoricalBatteryDataset(
            batteries=target_data,
            input_mode='summary',
            num_samples=200,
            num_summary_cycles=50,
            random_truncate=False,
            predict_rul=True,
        ),
        batch_size=32,
        shuffle=False,
    )
    return model.evaluate_historical(eval_loader)

# Evaluate the multi-domain model on every dataset with at least 3 batteries.
# Datasets that are in `train_domains` were seen during joint training, so
# only the remaining ones are genuinely zero-shot. Each result is tagged so
# the two groups are not conflated.
print("\n===== 零样本预测评估 =====")
zero_shot_results = {}

for ds_name, ds_data in all_data.items():
    if len(ds_data) < 3:
        continue

    metrics = zero_shot_evaluate(model_multi, ds_data, ds_name)
    zero_shot_results[ds_name] = metrics

    domain_tag = "训练域" if ds_name in train_domains else "未见域/零样本"
    print(f"\n{ds_name} [{domain_tag}]:")
    print(f"  SOH RMSE: {metrics['soh_rmse']:.4f}")
    if 'rul_rmse' in metrics:
        print(f"  RUL RMSE: {metrics['rul_rmse']:.1f}")

# Make it explicit when no evaluated dataset is actually unseen.
unseen_domains = [n for n in zero_shot_results if n not in train_domains]
if zero_shot_results and not unseen_domains:
    print("\n注意: 所有被评估的数据集都参与了训练，以上结果并非真正的零样本预测。")

In [None]:
# Bar charts of per-dataset SOH and RUL RMSE; skipped when nothing was evaluated.
if zero_shot_results:
    names = list(zero_shot_results)
    soh_rmse = [zero_shot_results[n]['soh_rmse'] for n in names]
    # Datasets without an RUL metric are drawn as a zero-height bar.
    rul_rmse = [zero_shot_results[n].get('rul_rmse', 0) for n in names]

    x = np.arange(len(names))
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # One panel per metric: (axis, bar heights, colour, y-label, title).
    panels = (
        (axes[0], soh_rmse, 'steelblue', 'SOH RMSE', '各数据集零样本SOH预测'),
        (axes[1], rul_rmse, 'coral', 'RUL RMSE', '各数据集零样本RUL预测'),
    )
    for ax, heights, color, ylabel, title in panels:
        ax.bar(x, heights, color=color)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

## 4. 单域 vs 多域模型比较

比较在单个数据集训练的模型与多域训练模型的跨域性能。

In [None]:
# Train a baseline model on a single source dataset.
source_name = 'MATR'
if source_name in all_data:
    print(f"训练单域模型 (源域: {source_name})...")

    # The third returned loader is discarded here.
    train_loader_single, val_loader_single, _ = create_historical_dataloaders(
        batteries=all_data[source_name],
        input_mode='summary',
        num_samples=200,
        num_summary_cycles=50,
        random_truncate_train=True,
        batch_size=32,
    )

    # Mirror the multi-domain architecture exactly (including num_layers=2)
    # so the single- vs multi-domain comparison differs only in training data.
    model_single = HistoricalSOHRULModel(
        backbone='lstm',
        in_channels=3,
        hidden_size=128,
        num_layers=2,
        use_trend_features=True,
        predict_rul=True,
    ).to(device)

    model_single.fit_historical(
        train_loader=train_loader_single,
        val_loader=val_loader_single,
        epochs=50,
        verbose=False,
    )

    print("完成!")
else:
    # Say so explicitly instead of silently skipping the baseline.
    print(f"未找到数据集 {source_name}，跳过单域模型训练。")

In [None]:
# Compare the single-domain and multi-domain models on every dataset with at
# least 3 batteries, then plot their SOH RMSE side by side.
if source_name in all_data:
    comparison = {'单域模型': {}, '多域模型': {}}

    for ds_name, ds_data in all_data.items():
        if len(ds_data) < 3:
            continue

        # Use cell-local names so the test-set `metrics_multi` computed by the
        # multi-domain evaluation cell is not overwritten here.
        single_eval = zero_shot_evaluate(model_single, ds_data, ds_name)
        multi_eval = zero_shot_evaluate(model_multi, ds_data, ds_name)
        comparison['单域模型'][ds_name] = single_eval['soh_rmse']
        comparison['多域模型'][ds_name] = multi_eval['soh_rmse']

    ds_names = list(comparison['单域模型'])
    if not ds_names:
        # Nothing to plot or average; avoid an empty figure and a nan result.
        print("没有电池数不少于3的数据集，跳过模型比较。")
    else:
        # Grouped bar chart: one pair of bars per dataset.
        x = np.arange(len(ds_names))
        width = 0.35

        fig, ax = plt.subplots(figsize=(12, 5))
        ax.bar(x - width/2, [comparison['单域模型'][n] for n in ds_names],
               width, label=f'单域模型({source_name})', color='steelblue')
        ax.bar(x + width/2, [comparison['多域模型'][n] for n in ds_names],
               width, label='多域模型', color='coral')

        ax.set_xticks(x)
        ax.set_xticklabels(ds_names, rotation=45)
        ax.set_ylabel('SOH RMSE')
        ax.set_title('单域 vs 多域模型跨域性能比较')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Average relative improvement of the multi-domain model.
        single_avg = np.mean(list(comparison['单域模型'].values()))
        multi_avg = np.mean(list(comparison['多域模型'].values()))
        if single_avg > 0:
            improvement = (single_avg - multi_avg) / single_avg * 100
            print(f"\n平均性能提升: {improvement:.1f}%")
        else:
            # A zero (or nan) baseline makes the relative change undefined.
            print("\n单域模型平均 SOH RMSE 不是正数，无法计算相对提升。")

## 5. 域适应（Few-shot微调）

使用少量目标域数据对模型进行微调。

In [None]:
# Choose the target domain: prefer CALCE, otherwise the first loaded dataset.
# With no datasets at all this yields None and the cell is skipped.
target_name = 'CALCE' if 'CALCE' in all_data else next(iter(all_data), None)

if target_name in all_data and len(all_data[target_name]) >= 3:
    target_batteries = all_data[target_name]

    # Few-shot budget: at most 3 batteries and never more than half of the
    # target domain, so some batteries always remain for testing.
    few_shot_count = min(3, len(target_batteries) // 2)

    # Deterministic split into fine-tuning and test batteries.
    np.random.seed(42)
    indices = np.random.permutation(len(target_batteries))
    finetune_batteries = [target_batteries[i] for i in indices[:few_shot_count]]
    test_batteries = [target_batteries[i] for i in indices[few_shot_count:]]

    print(f"目标域: {target_name}")
    print(f"微调数据: {few_shot_count} 电池")
    print(f"测试数据: {len(test_batteries)} 电池")

    # Fresh model with exactly the architecture of model_multi. num_layers=2
    # must be given explicitly, otherwise the state dict copied below may not
    # match this model's parameters.
    model_finetuned = HistoricalSOHRULModel(
        backbone='lstm',
        in_channels=3,
        hidden_size=128,
        num_layers=2,
        use_trend_features=True,
        predict_rul=True,
    ).to(device)

    # Start from the pretrained multi-domain weights.
    model_finetuned.load_state_dict(model_multi.state_dict())

    # Loader over the few-shot batteries, with random truncation enabled.
    finetune_dataset = HistoricalBatteryDataset(
        batteries=finetune_batteries,
        input_mode='summary',
        num_samples=200,
        num_summary_cycles=50,
        random_truncate=True,
        predict_rul=True,
    )
    finetune_loader = torch.utils.data.DataLoader(
        finetune_dataset, batch_size=16, shuffle=True
    )

    # Fine-tune with a smaller learning rate than the 1e-3 used for pretraining.
    print("\n微调中...")
    model_finetuned.fit_historical(
        train_loader=finetune_loader,
        epochs=20,
        lr=1e-4,
        verbose=False,
    )

    # NOTE(review): if the target domain has >= 5 batteries it was part of
    # train_domains, so the "before" numbers are not strictly zero-shot.
    print("\n评估结果:")
    metrics_before = zero_shot_evaluate(model_multi, test_batteries, target_name)
    metrics_after = zero_shot_evaluate(model_finetuned, test_batteries, target_name)

    print(f"  微调前 SOH RMSE: {metrics_before['soh_rmse']:.4f}")
    print(f"  微调后 SOH RMSE: {metrics_after['soh_rmse']:.4f}")
    if metrics_before['soh_rmse'] > 0:
        improvement = (metrics_before['soh_rmse'] - metrics_after['soh_rmse']) / metrics_before['soh_rmse'] * 100
        print(f"  提升: {improvement:.1f}%")
    else:
        # A zero (or nan) baseline makes the relative change undefined.
        print("  微调前 SOH RMSE 不是正数，无法计算相对提升。")

## 6. 总结

### 跨域迁移策略

| 策略 | 适用场景 | 数据需求 |
|------|----------|----------|
| 多域联合训练 | 有多个已标注数据集 | 多个源域 |
| 零样本预测 | 目标域无标注数据 | 仅源域 |
| Few-shot微调 | 少量目标域标注数据 | 源域+少量目标域 |

### 关键设计

1. **3通道特征（V, I, T）**: 适应不同数据集
2. **域平衡采样**: 避免大数据集主导训练
3. **随机截取训练**: 增加数据多样性
4. **历史序列建模**: 捕捉通用退化模式