In [4]:
# 載入HFL_global_model
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

from Model import TransformerModel
# Pick the compute device: prefer Apple-Silicon MPS (the original target), then CUDA,
# and finally CPU, so the notebook also runs on machines without an MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint locations, relative to the notebook's working directory.
ssl_model_path = "ssl_pretrain.pt"
hfl_model_path = "HFL_global_model.pth"

# HFL global model: 14 input features, with a 1-dimensional output head.
HFL_Global_Model = TransformerModel(
    feature_dim=14,
    d_model=256,
    nhead=8,
    num_layers=4,
    output_dim=1,
    max_seq_length=100,
    dropout=0.1
).to(device)

# SSL model: 9 input features, built without an output head (output_dim=None).
SSL_Model = TransformerModel(
    feature_dim=9,
    d_model=256,
    nhead=8,
    num_layers=4,
    output_dim=None,
    max_seq_length=5000,
    dropout=0.1
).to(device)

# Load the HFL weights (the file is used directly as a state dict).
HFL_Global_Model.load_state_dict(torch.load(hfl_model_path, map_location=device))

# Load the SSL weights. The file may be a full training checkpoint, in which case
# the weights live under the 'model_state_dict' key; otherwise it is the state dict itself.
ssl_checkpoint = torch.load(ssl_model_path, map_location=device)
if isinstance(ssl_checkpoint, dict) and 'model_state_dict' in ssl_checkpoint:
    SSL_Model.load_state_dict(ssl_checkpoint['model_state_dict'])
else:
    SSL_Model.load_state_dict(ssl_checkpoint)

print("模型載入成功！")

模型載入成功！


In [None]:
# 使用 Trainer 載入 Dataset
import glob
import os
from config import load_config
from DataLoader import SequenceCSVDataset
from Trainer import FederatedTrainer

# Load the configuration. If config.yaml cannot be loaded, report why and fall back
# to a minimal in-notebook config object instead of silently hiding the failure.
try:
    config = load_config('config.yaml')
except Exception as exc:
    # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still interrupt the cell.
    print(f"⚠ 無法載入 config.yaml（{type(exc).__name__}: {exc}），改用內建的 SimpleConfig")

    class SimpleConfig:
        """Minimal stand-in for the project config; edit the placeholders before use."""

        def __init__(self):
            self.data_path = "./data"  # adjust to the actual data directory
            self.input_length = 96
            self.output_length = 1
            # Placeholder column names: replace with the real feature columns.
            # NOTE(review): the models in this notebook are built with feature_dim=14,
            # so a two-feature placeholder list will presumably not match them - confirm.
            self.features = ['feature1', 'feature2']
            self.target = 'Power_Demand'
            self.batch_size = 32
            self.device = device

    config = SimpleConfig()

# Load the dataset, using the first client CSV (in sorted filename order) as the example.
csv_pattern = os.path.join(config.data_path, "*.csv")
csv_files = sorted(glob.glob(csv_pattern))

if not csv_files:
    print(f"錯誤: 在 {config.data_path} 目錄下沒有找到任何 CSV 文件")
else:
    # The dataset is addressed by directory plus the file name without its extension.
    csv_file = csv_files[0]
    csv_name = os.path.splitext(os.path.basename(csv_file))[0]

    print(f"載入客戶端數據: {csv_name}")

    # Build the dataset object.
    dataset_kwargs = dict(
        csv_path=config.data_path,
        csv_name=csv_name,
        input_len=config.input_length,
        output_len=config.output_length,
        features=config.features,
        target=config.target,
        save_path=config.data_path,
        train_ratio=0.8,
        val_ratio=0.1,
        split_type='time_based',
        fit_scalers=False,  # reuse previously saved scalers
    )
    dataset = SequenceCSVDataset(**dataset_kwargs)

    # Let the trainer split the dataset and wrap each split in a data loader.
    trainer = FederatedTrainer(HFL_Global_Model, config, device)
    train_dataset, val_dataset, test_dataset = trainer.split_dataset(dataset)
    train_loader, val_loader, test_loader = trainer.create_data_loaders(
        train_dataset, val_dataset, test_dataset
    )

    print(f"\n✓ 數據載入成功！")
    for split_label, split_data in (
        ("訓練集", train_dataset),
        ("驗證集", val_dataset),
        ("測試集", test_dataset),
    ):
        print(f"  - {split_label}大小: {len(split_data)} 樣本")
    print(f"  - 批次大小: {config.batch_size}")

    # Peek at the first batch (if any) to show the tensor shapes.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        inputs, targets = first_batch
        print(f"\n數據形狀:")
        print(f"  - 輸入: {inputs.shape}")
        print(f"  - 目標: {targets.shape}")

In [None]:
# 使用 Personalizer 獲取每個客戶端的個性化模型
from Personalizer import initialize_personalized_models, save_personalized_models

print("=" * 70)
print("Per-FedAvg 個性化模型初始化")
print("=" * 70)

# Approach 1: drive personalization from the config object.
# Fill in any personalization settings the config does not already provide.
if not hasattr(config, 'adaptation_lr'):
    config.adaptation_lr = 0.01  # learning rate of the personalization (adaptation) steps
if not hasattr(config, 'personalization_steps'):
    config.personalization_steps = 10  # number of adaptation steps
if not hasattr(config, 'model_save_path'):
    config.model_save_path = "./models"  # where models are saved

# Global model to personalize from: the HFL checkpoint path defined in the loading cell.
global_model_path = hfl_model_path

# One source of truth for the input width, shared by the demo model and the dummy
# batch below so the two cannot disagree when config.feature_dim is not 14.
personalized_feature_dim = getattr(config, 'feature_dim', 14)

# Demo-only sizes. Kept under their own names so this cell does not overwrite the
# notebook-level `batch_size` / `seq_length` used by the VFL cells.
demo_batch_size = 8
demo_seq_length = 96

# Tracks whether the whole block ran, so the closing banner reports the real outcome.
personalization_succeeded = False

try:
    # Personalize the global model for every client.
    print(f"\n使用全局模型: {global_model_path}")
    print(f"個性化參數:")
    print(f"  - 適應學習率: {config.adaptation_lr}")
    print(f"  - 適應步數: {config.personalization_steps}")
    print(f"  - 設備: {config.device}")

    # Mapping client name -> personalized state dict (it is iterated by key and
    # passed to load_state_dict below).
    client_models = initialize_personalized_models(config, global_model_path)

    print(f"\n✓ 成功獲取 {len(client_models)} 個客戶端的個性化模型")
    print(f"\n客戶端列表:")
    for idx, client_name in enumerate(client_models.keys(), 1):
        print(f"  {idx}. {client_name}")

    # Optional: persist the personalized models to disk.
    save_dir = "personalized_models"
    save_personalized_models(client_models, save_dir)

    # Demo: how to use one client's personalized model.
    print(f"\n{'=' * 70}")
    print("使用範例：載入特定客戶端的個性化模型")
    print("=" * 70)

    if client_models:
        # Use the first client as the example.
        first_client = list(client_models.keys())[0]

        print(f"\n示範：載入客戶端 '{first_client}' 的個性化模型")

        # Fresh model instance without an output head (VFL only needs embeddings).
        personalized_model_client_1 = TransformerModel(
            feature_dim=personalized_feature_dim,
            d_model=256,
            nhead=8,
            num_layers=4,
            output_dim=None,
            max_seq_length=100,
            dropout=0.1
        ).to(device)

        # Load the personalized weights and switch to inference mode.
        personalized_model_client_1.load_state_dict(client_models[first_client])
        personalized_model_client_1.eval()

        print(f"✓ 成功載入客戶端 '{first_client}' 的個性化模型")
        print(f"\n模型已準備好用於 VFL 訓練！")

        # Demo: forward pass that stops at the embedding (no output projection).
        print(f"\n{'=' * 70}")
        print("VFL 場景示範：使用個性化模型生成嵌入向量")
        print("=" * 70)

        # Random stand-in for Party A's (client 1) local batch.
        x_party_a = torch.randn(
            demo_batch_size, demo_seq_length, personalized_feature_dim
        ).to(device)

        with torch.no_grad():
            embedding_a = personalized_model_client_1.forward_embedding(x_party_a)

        print(f"\n✓ Party A 嵌入向量形狀: {embedding_a.shape}")
        print(f"  - 批次大小: {embedding_a.shape[0]}")
        print(f"  - 嵌入維度: {embedding_a.shape[1]}")
        print(f"\n這個嵌入向量可以發送到 Server 進行融合預測！")
        print(f"原始數據 ({x_party_a.shape}) 保持在本地，保護隱私。")

    personalization_succeeded = True

except FileNotFoundError as e:
    print(f"\n❌ 錯誤: {e}")
    print("請確認全局模型路徑是否正確")
except Exception as e:
    # Best-effort cell: report the problem and let the notebook continue; the VFL
    # training cell falls back to the global model when `client_models` is missing.
    print(f"\n❌ 發生錯誤: {e}")
    print("請檢查配置和數據路徑是否正確")

print(f"\n{'=' * 70}")
if personalization_succeeded:
    print("個性化模型初始化完成")
else:
    print("個性化模型初始化未完成（詳見上方錯誤訊息）")
print("=" * 70)

In [None]:
# ============================================================================
# VFL 場景：載入 Weather (雲端) 和 HFL (本地) 的 Dataset
# ============================================================================

import pandas as pd
from sklearn.preprocessing import StandardScaler
import pickle

print("=" * 80)
print("VFL 數據載入：Weather (雲端) + HFL (本地)")
print("=" * 80)

# --- Configuration -----------------------------------------------------------
data_path = "./data"
weather_csv = "Weather.csv"
hfl_csv = "processed/Consumer_01.csv"  # first consumer, used as the example client

# HFL feature columns (held by the local client).
hfl_features = [
    'AC1',
    'AC2',
    'AC3',
    'AC4',
    'Dish washer',
    'Washing Machine',
    'Dryer',
    'Water heater',
    'TV',
    'Microwave',
    'Kettle',
    'Lighting',
    'Refrigerator',
    'Consumption_Total',
]

# Weather feature columns (held by the cloud party).
weather_features = [
    'TemperatureC',
    'DewpointC',
    'PressurehPa',
    'WindSpeedKMH',
    'WindSpeedGustKMH',
    'Humidity',
    'HourlyPrecipMM',
    'dailyrainMM',
    'SolarRadiationWatts_m2',
]

# Prediction target column.
target = 'Power_Demand'

# Sliding-window length and mini-batch size.
seq_length = 96
batch_size = 32

# Echo the configuration in a single write.
config_report = [
    "\n【數據配置】",
    f"  - Weather 特徵數: {len(weather_features)}",
    f"  - HFL 特徵數: {len(hfl_features)}",
    f"  - 序列長度: {seq_length}",
    f"  - 批次大小: {batch_size}",
]
print("\n".join(config_report))

# ============================================================================
# Step 1: load the Weather data (cloud side)
# ============================================================================
print(f"\n【步驟 1】載入 Weather 數據（雲端）")

weather_df = pd.read_csv(f"{data_path}/{weather_csv}")
print(f"  ✓ Weather 原始數據: {weather_df.shape}")
print(f"  - 可用特徵: {list(weather_df.columns)}")

# Order the rows chronologically when a datetime column is available.
if 'datetime' in weather_df.columns:
    weather_df['datetime'] = pd.to_datetime(weather_df['datetime'])
    weather_df = weather_df.set_index('datetime').sort_index()

# Raw (unscaled) Weather feature matrix.
weather_data = weather_df[weather_features].values
print(f"  ✓ Weather 特徵矩陣: {weather_data.shape}")

# ============================================================================
# Step 2: load the HFL data (local client)
# ============================================================================
print(f"\n【步驟 2】載入 HFL 數據（本地客戶端）")

hfl_df = pd.read_csv(f"{data_path}/{hfl_csv}")
print(f"  ✓ HFL 原始數據: {hfl_df.shape}")

# Fall back to 'Consumption_Total' when the configured target column is absent.
if target not in hfl_df.columns:
    print(f"  ⚠ 警告: 沒有找到 '{target}' 欄位，使用 'Consumption_Total' 作為目標")
    target = 'Consumption_Total'

# Raw (unscaled) HFL feature matrix and target vector.
hfl_data = hfl_df[hfl_features].values
target_data = hfl_df[target].values
print(f"  ✓ HFL 特徵矩陣: {hfl_data.shape}")
print(f"  ✓ 目標變量: {target_data.shape}")

# ============================================================================
# Step 3: align the two sources, then standardize
# ============================================================================
print(f"\n【步驟 3】對齊時間並創建序列數據")

# Alignment is positional: every source is truncated to the shortest length.
# NOTE(review): this assumes both CSVs start at the same timestamp and share the
# same sampling interval; no datetime join is performed - confirm.
min_len = min(len(weather_data), len(hfl_data), len(target_data))
if min_len <= seq_length:
    raise ValueError(
        f"Aligned data has {min_len} rows, but at least {seq_length + 1} are needed "
        f"to build one sequence of length {seq_length}."
    )
print(f"  ✓ 對齊後數據長度: {min_len}")

# Fit the scalers ONLY on rows that training windows can touch, so validation and
# test statistics do not leak into normalization. The split itself happens later on
# the windowed samples with an 80% training share; training window i reads rows
# i .. i + seq_length (inclusive of its target row), hence the + seq_length below.
scaler_train_ratio = 0.8  # must match the training share used by the split step
n_sequences = min_len - seq_length
scaler_fit_rows = int(scaler_train_ratio * n_sequences) + seq_length

# --- Weather scaler (created on the cloud side) ---
weather_scaler = StandardScaler().fit(weather_data[:scaler_fit_rows])
weather_data_scaled = weather_scaler.transform(weather_data[:min_len])
print(f"  ✓ Weather 數據已標準化")
print(f"    - 均值: {weather_scaler.mean_[:3]}...")
print(f"    - 標準差: {weather_scaler.scale_[:3]}...")

# Persist the Weather scaler for later inference.
scaler_path = f"{data_path}/weather_scaler.pkl"
with open(scaler_path, 'wb') as f:
    pickle.dump(weather_scaler, f)
print(f"  ✓ Weather 標準化器已保存: {scaler_path}")

# --- HFL scaler (created on the local side) ---
hfl_scaler = StandardScaler().fit(hfl_data[:scaler_fit_rows])
hfl_data_scaled = hfl_scaler.transform(hfl_data[:min_len])
print(f"  ✓ HFL 數據已標準化")

# --- Target scaler (StandardScaler needs a 2-D column, flattened back afterwards) ---
target_column = target_data[:min_len].reshape(-1, 1)
target_scaler = StandardScaler().fit(target_column[:scaler_fit_rows])
target_data_scaled = target_scaler.transform(target_column).flatten()
print(f"  ✓ 目標變量已標準化")

# Persist the HFL and target scalers.
hfl_scaler_path = f"{data_path}/hfl_scaler.pkl"
target_scaler_path = f"{data_path}/target_scaler.pkl"
with open(hfl_scaler_path, 'wb') as f:
    pickle.dump(hfl_scaler, f)
with open(target_scaler_path, 'wb') as f:
    pickle.dump(target_scaler, f)
print(f"  ✓ HFL 標準化器已保存: {hfl_scaler_path}")
print(f"  ✓ 目標標準化器已保存: {target_scaler_path}")

# 創建序列數據集
def create_sequences(weather, hfl, targets, seq_len):
    """Build sliding-window samples for next-step prediction.

    Sample ``i`` pairs the window of rows ``i .. i + seq_len - 1`` from ``weather``
    and ``hfl`` with the single target value at row ``i + seq_len``.

    Args:
        weather: array-like whose first axis is time.
        hfl: array-like on the same time axis; must have at least as many rows as
            ``weather`` (extra trailing rows are ignored).
        targets: array-like on the same time axis; same length rule as ``hfl``.
        seq_len: window length, a positive integer.

    Returns:
        Tuple ``(X_weather, X_hfl, y)`` of NumPy arrays whose first axis has
        ``max(len(weather) - seq_len, 0)`` entries. The window arrays carry a second
        axis of length ``seq_len`` followed by any trailing axes of the input, also
        when no sample can be built.

    Raises:
        ValueError: if ``seq_len`` is not positive, or ``hfl`` / ``targets`` are
            shorter than ``weather``.
    """
    weather = np.asarray(weather)
    hfl = np.asarray(hfl)
    targets = np.asarray(targets)

    if seq_len < 1:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len}")
    if len(hfl) < len(weather) or len(targets) < len(weather):
        raise ValueError(
            "hfl and targets must cover the same time span as weather: "
            f"got lengths weather={len(weather)}, hfl={len(hfl)}, targets={len(targets)}"
        )

    n_samples = max(len(weather) - seq_len, 0)

    # Row r of `window_rows` lists the time indices r .. r + seq_len - 1 of sample r.
    # Indexing with it gathers every window at once and keeps a well-formed
    # (0, seq_len, ...) shape when there are no samples.
    window_rows = np.arange(n_samples)[:, None] + np.arange(seq_len)[None, :]

    X_weather = weather[window_rows]
    X_hfl = hfl[window_rows]
    y = targets[seq_len:seq_len + n_samples]  # the step right after each window

    return X_weather, X_hfl, y

# Window the aligned, scaled series into (weather window, HFL window, next-step target) samples.
X_weather_seq, X_hfl_seq, y_seq = create_sequences(
    weather=weather_data_scaled,
    hfl=hfl_data_scaled,
    targets=target_data_scaled,
    seq_len=seq_length,
)

for seq_label, seq_array in (
    ("Weather 序列", X_weather_seq),
    ("HFL 序列", X_hfl_seq),
    ("目標序列", y_seq),
):
    print(f"  ✓ {seq_label}: {seq_array.shape}")

# ============================================================================
# Step 4: chronological train / validation / test split
# ============================================================================
print(f"\n【步驟 4】分割訓練/驗證/測試集（8:1:1）")

total_samples = len(X_weather_seq)
train_size = int(0.8 * total_samples)
val_size = int(0.1 * total_samples)

# Contiguous, order-preserving index ranges: the earliest samples train the model,
# the next block validates it, and everything left over is the test set.
train_range = slice(None, train_size)
val_range = slice(train_size, train_size + val_size)
test_range = slice(train_size + val_size, None)

X_weather_train, X_hfl_train, y_train = (
    X_weather_seq[train_range], X_hfl_seq[train_range], y_seq[train_range]
)
X_weather_val, X_hfl_val, y_val = (
    X_weather_seq[val_range], X_hfl_seq[val_range], y_seq[val_range]
)
X_weather_test, X_hfl_test, y_test = (
    X_weather_seq[test_range], X_hfl_seq[test_range], y_seq[test_range]
)

for split_label, split_inputs in (
    ("訓練集", X_weather_train),
    ("驗證集", X_weather_val),
    ("測試集", X_weather_test),
):
    print(f"  ✓ {split_label}: {len(split_inputs)} 樣本")

# ============================================================================
# Step 5: build the PyTorch DataLoaders
# ============================================================================
print(f"\n【步驟 5】創建 PyTorch DataLoader")

from torch.utils.data import TensorDataset, DataLoader


def _to_device_tensor(array, as_column=False):
    """Convert a NumPy array to a float tensor on `device`; optionally add a column axis."""
    tensor = torch.FloatTensor(array)
    if as_column:
        tensor = tensor.unsqueeze(1)
    return tensor.to(device)


# Convert every split to tensors; targets become column vectors.
X_weather_train_t = _to_device_tensor(X_weather_train)
X_hfl_train_t = _to_device_tensor(X_hfl_train)
y_train_t = _to_device_tensor(y_train, as_column=True)

X_weather_val_t = _to_device_tensor(X_weather_val)
X_hfl_val_t = _to_device_tensor(X_hfl_val)
y_val_t = _to_device_tensor(y_val, as_column=True)

X_weather_test_t = _to_device_tensor(X_weather_test)
X_hfl_test_t = _to_device_tensor(X_hfl_test)
y_test_t = _to_device_tensor(y_test, as_column=True)

# Each dataset item is a (weather window, HFL window, target) triple.
train_dataset_vfl = TensorDataset(X_weather_train_t, X_hfl_train_t, y_train_t)
val_dataset_vfl = TensorDataset(X_weather_val_t, X_hfl_val_t, y_val_t)
test_dataset_vfl = TensorDataset(X_weather_test_t, X_hfl_test_t, y_test_t)

# Only the training loader shuffles; validation and test keep their order.
train_loader_vfl = DataLoader(train_dataset_vfl, batch_size=batch_size, shuffle=True)
val_loader_vfl = DataLoader(val_dataset_vfl, batch_size=batch_size, shuffle=False)
test_loader_vfl = DataLoader(test_dataset_vfl, batch_size=batch_size, shuffle=False)

for loader_label, loader in (
    ("訓練", train_loader_vfl),
    ("驗證", val_loader_vfl),
    ("測試", test_loader_vfl),
):
    print(f"  ✓ {loader_label} DataLoader: {len(loader)} 批次")

# ============================================================================
# Data loading finished: print a summary
# ============================================================================
print(f"\n{'=' * 80}")
print("✓ VFL 數據載入完成！")
print(f"{'=' * 80}")
summary_lines = [
    "\n數據摘要:",
    "  【雲端 - Weather】",
    f"    - 特徵維度: {len(weather_features)}",
    f"    - 訓練樣本: {X_weather_train.shape}",
    "  【本地 - HFL】",
    f"    - 特徵維度: {len(hfl_features)}",
    f"    - 訓練樣本: {X_hfl_train.shape}",
    "  【目標】",
    f"    - 變量: {target}",
    f"    - 訓練樣本: {y_train.shape}",
    "\n下一步：建立 Weather Model（雲端）和 Fusion Model（伺服器）",
]
print("\n".join(summary_lines))

In [None]:
# ============================================================================
# VFL 訓練：Fusion Model（本地）與 Weather Model（雲端）的訓練
# ============================================================================

from Model import TransformerModel, FusionModel
from tqdm import tqdm

print("=" * 80)
print("VFL 訓練：Fusion Model (本地) ↔ Weather Model (雲端)")
print("=" * 80)

# ============================================================================
# Step 1: initialize the models
# ============================================================================
print(f"\n【步驟 1】初始化模型")

# --- Weather model (cloud side, trainable) ---
weather_model = TransformerModel(
    feature_dim=len(weather_features),
    d_model=256,
    nhead=8,
    num_layers=4,
    output_dim=None,  # no output head: this model only produces embeddings
    max_seq_length=seq_length,
    dropout=0.1
).to(device)

print(f"  ✓ Weather Model 已初始化（雲端，可訓練）")
print(f"    - 輸入維度: {len(weather_features)}")
print(f"    - 嵌入維度: 256")
print(f"    - 可訓練參數: {sum(p.numel() for p in weather_model.parameters()):,}")

# --- HFL model (local side, frozen) ---
# The architecture is the same whichever weights are loaded, so build it once.
# NOTE(review): the checkpoints in this notebook were created with max_seq_length=100
# while this encoder uses seq_length; if the model stores a length-dependent tensor
# in its state dict the loads below will report a size mismatch - confirm.
hfl_model = TransformerModel(
    feature_dim=len(hfl_features),
    d_model=256,
    nhead=8,
    num_layers=4,
    output_dim=None,  # no output head
    max_seq_length=seq_length,
    dropout=0.1
).to(device)

if 'client_models' in globals() and len(client_models) > 0:
    # Preferred: the personalized weights produced by the personalization cell.
    first_client = list(client_models.keys())[0]
    hfl_model.load_state_dict(client_models[first_client])
    print(f"  ✓ HFL Model 已載入（本地，來自個性化模型 '{first_client}'）")
else:
    # Fallback: reuse the HFL global model weights, minus its output head.
    try:
        # The checkpoint was saved from a model with an output head, so load it
        # into a matching temporary model first.
        temp_model = TransformerModel(
            feature_dim=14,
            d_model=256,
            nhead=8,
            num_layers=4,
            output_dim=1,
            max_seq_length=100,
            dropout=0.1
        ).to(device)
        temp_model.load_state_dict(torch.load(hfl_model_path, map_location=device))

        # Keep everything except the output-projection weights.
        hfl_state_dict = {
            key: value
            for key, value in temp_model.state_dict().items()
            if 'output_proj' not in key
        }

        # strict=False tolerates absent keys, so surface them instead of ignoring them.
        load_result = hfl_model.load_state_dict(hfl_state_dict, strict=False)
        if load_result.missing_keys:
            print(f"  ⚠ 以下權重未從全局模型載入: {load_result.missing_keys}")
        print(f"  ✓ HFL Model 已載入（本地，來自全局模型）")
    except Exception as exc:
        # Still best-effort (training proceeds with random weights), but say why:
        # a frozen, randomly initialized encoder is easy to miss otherwise.
        print(f"  ⚠ HFL Model 使用隨機初始化（無法載入預訓練權重）")
        print(f"    - 原因: {type(exc).__name__}: {exc}")

# Freeze the HFL model: no gradients, inference mode.
for param in hfl_model.parameters():
    param.requires_grad = False
hfl_model.eval()
print(f"  ✓ HFL Model 已凍結（不參與訓練）")

# --- Fusion model (local client, trainable) ---
fusion_model = FusionModel(
    embedding_dim_party_a=256,  # Weather embedding width
    embedding_dim_party_b=256,  # HFL embedding width
    hidden_dim=256,
    output_dim=1,  # single predicted value (power demand)
    dropout=0.1
).to(device)

print(f"  ✓ Fusion Model 已初始化（本地客戶端，可訓練）")
print(f"    - 可訓練參數: {sum(p.numel() for p in fusion_model.parameters()):,}")

# ============================================================================
# Step 2: optimizers and loss function
# ============================================================================
print(f"\n【步驟 2】設置優化器和損失函數")

# Both trainable models use Adam with the same hyper-parameters.
adam_settings = {"lr": 0.001, "weight_decay": 1e-4}

# Weather model optimizer (cloud side).
weather_optimizer = torch.optim.Adam(weather_model.parameters(), **adam_settings)

# Fusion model optimizer (local side).
fusion_optimizer = torch.optim.Adam(fusion_model.parameters(), **adam_settings)

# Mean-squared-error loss.
criterion = nn.MSELoss()

print(f"  ✓ 優化器已設置")
print(f"    - Weather Model LR: 0.001 (雲端)")
print(f"    - Fusion Model LR: 0.001 (本地)")

# ============================================================================
# Step 3: training configuration
# ============================================================================
num_epochs = 50
print_every = 5

print(f"\n【步驟 3】訓練配置")
print(f"  - 訓練輪數: {num_epochs}")
print(f"  - 訓練策略: 本地訓練 Fusion Model，雲端訓練 Weather Model")
print(f"  - HFL Model: 凍結（不訓練）")

# ASCII diagram of the VFL data flow, printed one line at a time.
architecture_diagram = (
    "\n【VFL 架構說明】",
    "  ┌─────────────────────────────────────────────────────────┐",
    "  │                    雲端 (Cloud)                         │",
    "  │  ┌───────────────────────────────────────────┐          │",
    "  │  │  Weather Model (可訓練)                    │          │",
    "  │  │  - 輸入: Weather 數據 (9 特徵)             │          │",
    "  │  │  - 輸出: Weather 嵌入向量 (256 維)        │          │",
    "  │  └──────────────┬────────────────────────────┘          │",
    "  └─────────────────┼─────────────────────────────────────────┘",
    "                    │ Weather 嵌入向量 (256 維)",
    "                    │ 傳送到本地客戶端",
    "                    ▼",
    "  ┌─────────────────────────────────────────────────────────┐",
    "  │                 本地客戶端 (Client)                      │",
    "  │  ┌───────────────────────────────────────────┐          │",
    "  │  │  HFL Model (凍結，不訓練)                  │          │",
    "  │  │  - 輸入: HFL 本地數據 (14 特徵)            │          │",
    "  │  │  - 輸出: HFL 嵌入向量 (256 維)            │          │",
    "  │  └──────────────┬────────────────────────────┘          │",
    "  │                 │ HFL 嵌入 (256)                         │",
    "  │                 │                                        │",
    "  │  ┌──────────────▼────────────────────────────┐          │",
    "  │  │  Fusion Model (本地，可訓練)               │          │",
    "  │  │  - 輸入: Weather 嵌入 + HFL 嵌入          │          │",
    "  │  │  - 輸出: Power Demand 預測                │          │",
    "  │  └──────────────┬────────────────────────────┘          │",
    "  └─────────────────┼─────────────────────────────────────────┘",
    "                    │ Weather 嵌入的梯度",
    "                    │ 傳回雲端",
    "                    ▼",
    "            更新 Weather Model",
)
for diagram_line in architecture_diagram:
    print(diagram_line)

# Per-epoch loss history.
train_losses = []
val_losses = []

# ============================================================================
# Step 4: training loop
# ============================================================================
# Each epoch trains the Weather and Fusion models jointly on train_loader_vfl,
# then measures the loss on val_loader_vfl. The HFL model is only used for
# inference (its forward pass always runs under torch.no_grad()).
print(f"\n{'=' * 80}")
print("開始訓練...")
print("=" * 80)

for epoch in range(num_epochs):
    # ========================================================================
    # Training phase
    # ========================================================================
    # Put both trainable models in training mode.
    fusion_model.train()
    weather_model.train()
    
    # Running sum of the per-batch losses for this epoch.
    epoch_train_loss = 0.0
    
    for weather_batch, hfl_batch, targets in train_loader_vfl:
        # === Forward pass ===
        # 1) Weather model produces its embedding (gradients are tracked).
        weather_embedding = weather_model.forward_embedding(weather_batch)
        
        # 2) HFL model produces its embedding without tracking gradients.
        with torch.no_grad():
            hfl_embedding = hfl_model.forward_embedding(hfl_batch)
        
        # === Fusion model ===
        # 3) Clear gradients left over from the previous step, then predict.
        fusion_optimizer.zero_grad()
        weather_optimizer.zero_grad()
        
        predictions = fusion_model(weather_embedding, hfl_embedding)
        loss = criterion(predictions, targets)
        
        # === Backward pass ===
        # 4) One backward call; gradients flow through the fusion model into the
        #    Weather model via weather_embedding.
        loss.backward()
        
        # 5) Update the Fusion model.
        fusion_optimizer.step()
        
        # 6) Update the Weather model with the gradients from the same backward pass.
        weather_optimizer.step()
        
        epoch_train_loss += loss.item()
    
    # ========================================================================
    # Validation phase
    # ========================================================================
    weather_model.eval()
    fusion_model.eval()
    
    val_loss = 0.0
    with torch.no_grad():
        for weather_batch, hfl_batch, targets in val_loader_vfl:
            weather_embedding = weather_model.forward_embedding(weather_batch)
            hfl_embedding = hfl_model.forward_embedding(hfl_batch)
            predictions = fusion_model(weather_embedding, hfl_embedding)
            loss = criterion(predictions, targets)
            val_loss += loss.item()
    
    # Average the summed batch losses over the number of batches
    # (an empty loader would make these divisions fail).
    avg_train_loss = epoch_train_loss / len(train_loader_vfl)
    avg_val_loss = val_loss / len(val_loader_vfl)
    
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    
    # Report on the first epoch and then every `print_every` epochs.
    if (epoch + 1) % print_every == 0 or epoch == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  - Train Loss: {avg_train_loss:.6f}")
        print(f"  - Val Loss: {avg_val_loss:.6f}")

print(f"\n{'=' * 80}")
print("✓ 訓練完成！")
print("=" * 80)

# ============================================================================
# Step 5: evaluate on the test set
# ============================================================================
print(f"\n【步驟 5】評估模型")

weather_model.eval()
fusion_model.eval()

test_loss = 0.0
prediction_chunks = []
target_chunks = []

with torch.no_grad():
    for weather_batch, hfl_batch, targets in test_loader_vfl:
        # Cloud side: Weather embedding; local side: HFL embedding, then fused prediction.
        weather_embedding = weather_model.forward_embedding(weather_batch)
        hfl_embedding = hfl_model.forward_embedding(hfl_batch)
        predictions = fusion_model(weather_embedding, hfl_embedding)

        loss = criterion(predictions, targets)
        test_loss += loss.item()

        # Keep one NumPy array per batch and join them after the loop.
        prediction_chunks.append(predictions.cpu().numpy())
        target_chunks.append(targets.cpu().numpy())

avg_test_loss = test_loss / len(test_loader_vfl)
all_predictions = np.concatenate(prediction_chunks)
all_targets = np.concatenate(target_chunks)

# Regression metrics, computed on the values exactly as the model saw them
# (targets were built from target_data_scaled, so these are in scaled units).
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

mse = mean_squared_error(all_targets, all_predictions)
mae = mean_absolute_error(all_targets, all_predictions)
rmse = np.sqrt(mse)
r2 = r2_score(all_targets, all_predictions)

print(f"\n測試集性能:")
for metric_name, metric_value in (("MSE", mse), ("MAE", mae), ("RMSE", rmse), ("R²", r2)):
    print(f"  - {metric_name}: {metric_value:.6f}")

# ============================================================================
# Step 6: visualize the results
# ============================================================================
print(f"\n【步驟 6】可視化結果")

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(ax_loss, ax_scatter), (ax_series, ax_error) = axes

# Panel 1: training and validation loss per epoch.
ax_loss.plot(train_losses, label='Train Loss', color='blue')
ax_loss.plot(val_losses, label='Val Loss', color='orange')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training and Validation Loss')

# Panel 2: predicted vs. actual values, with the y = x reference line.
target_range = [all_targets.min(), all_targets.max()]
ax_scatter.scatter(all_targets, all_predictions, alpha=0.5, s=10)
ax_scatter.plot(target_range, target_range, 'r--', lw=2, label='Perfect Prediction')
ax_scatter.set_xlabel('Actual')
ax_scatter.set_ylabel('Predicted')
ax_scatter.set_title(f'Predictions vs Actual (R²={r2:.4f})')

# Panel 3: the first (up to) 200 test samples as a time series.
sample_size = min(200, len(all_targets))
ax_series.plot(all_targets[:sample_size], label='Actual', color='blue', linewidth=2)
ax_series.plot(
    all_predictions[:sample_size],
    label='Predicted', color='red', linewidth=1, alpha=0.7,
)
ax_series.set_xlabel('Time Step')
ax_series.set_ylabel('Power Demand')
ax_series.set_title('Time Series Prediction (First 200 samples)')

# Panel 4: histogram of prediction errors (predicted minus actual).
errors = all_predictions.flatten() - all_targets.flatten()
ax_error.hist(errors, bins=50, alpha=0.7, color='green', edgecolor='black')
ax_error.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero Error')
ax_error.set_xlabel('Prediction Error')
ax_error.set_ylabel('Frequency')
ax_error.set_title(f'Error Distribution (MAE={mae:.4f})')

# Every panel gets a legend and a light grid.
for panel in (ax_loss, ax_scatter, ax_series, ax_error):
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('vfl_training_results.png', dpi=150, bbox_inches='tight')
print(f"  ✓ 可視化圖表已保存: vfl_training_results.png")

plt.show()

def _parameter_count(model):
    """Total number of elements across all parameters of `model`."""
    return sum(p.numel() for p in model.parameters())


# Closing report: model status, training strategy and data flow, written in one go.
final_report = [
    f"\n{'=' * 80}",
    "VFL 訓練和評估完成！",
    "=" * 80,
    "\n模型摘要:",
    "  【Weather Model（雲端）】",
    "    - 狀態: 已訓練",
    "    - 位置: 雲端",
    f"    - 參數: {_parameter_count(weather_model):,}",
    "  【HFL Model（本地）】",
    "    - 狀態: 凍結（未訓練）",
    "    - 位置: 本地客戶端",
    f"    - 參數: {_parameter_count(hfl_model):,}",
    "  【Fusion Model（本地）】",
    "    - 狀態: 已訓練",
    "    - 位置: 本地客戶端（每個客戶端自有）",
    f"    - 參數: {_parameter_count(fusion_model):,}",
    "\n訓練策略:",
    "  ✓ Fusion Model 在本地訓練，融合 Weather 和 HFL 嵌入",
    "  ✓ Weather Model 在雲端訓練，通過梯度反向傳播更新",
    "  ✓ HFL Model 保持凍結，保留個性化特徵",
    "  ✓ 隱私保護機制:",
    "    - Weather 數據存於雲端，只傳送嵌入向量到本地",
    "    - HFL 原始數據不離開本地",
    "    - Fusion Model 在本地，預測結果不外傳",
    "    - 只有 Weather 嵌入的梯度傳回雲端",
    "\n資料流向:",
    "  前向: 雲端(Weather嵌入) → 本地(融合+預測)",
    "  反向: 本地(梯度) → 雲端(更新Weather Model)",
]
print("\n".join(final_report))