In [None]:
pip install torch_geometric torchprofile

##Data Loader

In [None]:
# ===========================================================
# 0. 导入 & 数据加载 (只运行一次)
# ===========================================================
import pickle
from pathlib import Path
import os
import time
from torch_geometric.data import Data # 导入 Data 以进行类型检查

print("--- 正在启动数据加载单元格 ---")

# --- 1. Key paths ---
# The pickle lives on the mounted Google Drive; the three parts are kept as
# separate globals so later cells can reuse them individually.
DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
DATA_SUBDIR = Path("Result/Sequential_13Hour_Data")
DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
DATA_PATH = DRIVE_BASE_PATH.joinpath(DATA_SUBDIR, DATA_FILENAME)

# --- 2. Load the dataset once into a notebook-wide global ---
# GLOBAL_LOADED_DATA stays None if anything below fails.
GLOBAL_LOADED_DATA = None
print(f"正在从 {DATA_PATH} 加载数据...")
load_start_time = time.time()

try:
    if not DATA_PATH.exists():
        raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only point DATA_PATH at files you produced yourself.
    with DATA_PATH.open("rb") as f:
        GLOBAL_LOADED_DATA = pickle.load(f)

    load_duration = time.time() - load_start_time
    print(f"数据加载成功! 耗时: {load_duration:.2f} 秒。")

    # Basic sanity check: expect a non-empty list whose first item is itself
    # a non-empty list (one sequence of per-timestep graphs).
    if (GLOBAL_LOADED_DATA
            and isinstance(GLOBAL_LOADED_DATA, list)
            and GLOBAL_LOADED_DATA[0]
            and isinstance(GLOBAL_LOADED_DATA[0], list)):
        print(f"数据类型: {type(GLOBAL_LOADED_DATA)}, 总序列数: {len(GLOBAL_LOADED_DATA)}")
        print(f"第一条序列的类型: {type(GLOBAL_LOADED_DATA[0])}, 长度: {len(GLOBAL_LOADED_DATA[0])}")
    else:
        print("警告：加载的数据为空或格式不正确！")

except Exception as load_error:
    # Best-effort cell: report the failure with a traceback instead of
    # aborting the notebook run.
    print(f"加载数据时发生严重错误: {load_error}")
    import traceback
    traceback.print_exc()

print("--- 数据加载单元格执行完毕 ---")

##Ours

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import RGCNConv # ⚡ USE RGCNConv
from torch_geometric.data import Data, Batch # Batch很重要
from torch_geometric.loader import DataLoader
# import matplotlib.pyplot as plt # Not used in the final reporting directly, can be commented if not needed elsewhere
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta # For time feature
from pathlib import Path
import time # Added for timing
import json # Added for report export
import torchprofile # Added for FLOPS calculation

# Free memory left over from earlier cell runs before building a new model:
# collect unreachable Python objects first, then ask PyTorch to release its
# cached CUDA memory and CUDA IPC handles.
# NOTE(review): presumably harmless on a CPU-only runtime — confirm.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


# ===========================================================
# 1. 特征生成 & 辅助模块
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps):
    """Build cyclical time features for consecutive hourly steps.

    Step ``i`` corresponds to ``base_dt_obj + i hours``. Each row holds
    ``[hour_sin, hour_cos, doy_sin, doy_cos]``.

    The hour of day is mapped onto a 24-hour cycle. Dividing by 23 (as an
    earlier version did) makes hour 0 and hour 23 land on the same point of
    the circle, so the two hours become indistinguishable.

    Args:
        base_dt_obj: ``datetime`` of the first step.
        num_steps: number of hourly steps to generate.

    Returns:
        A ``float32`` tensor of shape ``(num_steps, 4)``; an empty ``(0, 4)``
        tensor when ``num_steps`` is zero or negative.
    """
    if num_steps <= 0:
        # torch.stack / torch.tensor on an empty list would not give a
        # well-shaped result, so return the empty feature matrix explicitly.
        return torch.empty((0, 4), dtype=torch.float32)

    two_pi = 2 * math.pi
    feature_rows = []
    for step in range(num_steps):
        current_dt = base_dt_obj + timedelta(hours=step)

        # Hours 0..23 cover one full turn, so the period is 24.
        hour_angle = two_pi * (current_dt.hour / 24.0)

        year = current_dt.year
        is_leap_year = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap_year else 365.0
        doy_angle = two_pi * (current_dt.timetuple().tm_yday / days_in_year)

        feature_rows.append([
            math.sin(hour_angle),
            math.cos(hour_angle),
            math.sin(doy_angle),
            math.cos(doy_angle),
        ])
    return torch.tensor(feature_rows, dtype=torch.float32)


class MLPEncoder(nn.Module):
    """Small MLP: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    Args:
        in_dim: width of the input features.
        out_dim: width of the produced embedding.
        hid_dim: hidden width; when ``None`` it is derived from ``in_dim``
            and ``out_dim`` (see ``__init__``).
        dropout_rate: dropout probability applied after the LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        hidden_width = hid_dim
        if hidden_width is None:
            hidden_width = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            if hidden_width == 0:
                # Degenerate sizes: fall back to the first positive dimension,
                # or to a single unit when neither is positive.
                if out_dim > 0:
                    hidden_width = out_dim
                elif in_dim > 0:
                    hidden_width = in_dim
                else:
                    hidden_width = 1

        # Layers are built in this exact order so that indices into
        # ``self.mlp`` (e.g. ``mlp[0].in_features``) keep their meaning.
        layers = [
            nn.Linear(in_dim, hidden_width),
            nn.ReLU(),
            nn.LayerNorm(hidden_width),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_width, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (last dimension ``in_dim``) into ``out_dim`` features."""
        encoded = self.mlp(x)
        return encoded

class RGCNModule(nn.Module):
    """Stack of three RGCNConv layers, each followed by BatchNorm1d and PReLU.

    Dropout is applied after the first two stages only. The relation type of
    every edge is read from column 4 of ``edge_attr``.

    Args:
        rgcn_input_dim: width of the input node features.
        rgcn_hidden_dim: width of the two intermediate layers.
        rgcn_output_dim: width of the returned node embeddings.
        num_relations: number of edge relation types handed to RGCNConv.
        dropout_rate: dropout probability between stages.
    """

    def __init__(self, rgcn_input_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate=0.5):
        super().__init__()
        # Keep the configuration on the instance; external code reads these.
        self.rgcn_input_dim = rgcn_input_dim
        self.rgcn_hidden_dim = rgcn_hidden_dim
        self.rgcn_output_dim = rgcn_output_dim
        self.num_relations = num_relations

        hidden, out = rgcn_hidden_dim, rgcn_output_dim

        # Stage 1: input -> hidden
        self.conv1 = RGCNConv(rgcn_input_dim, hidden, num_relations)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.prelu1 = nn.PReLU(hidden)

        # Stage 2: hidden -> hidden
        self.conv2 = RGCNConv(hidden, hidden, num_relations)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.prelu2 = nn.PReLU(hidden)

        # Stage 3: hidden -> output
        self.conv3 = RGCNConv(hidden, out, num_relations)
        self.bn3 = nn.BatchNorm1d(out)
        self.prelu3 = nn.PReLU(out)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index, edge_attr):
        """Run the three RGCN stages and return the final node embeddings.

        Raises:
            ValueError: if ``edge_attr`` is ``None`` or has fewer than five
                columns (column 4 carries the edge type).
        """
        has_edge_type_column = edge_attr is not None and edge_attr.shape[1] >= 5
        if not has_edge_type_column:
            raise ValueError("RGCNModule: edge_attr is missing or has insufficient columns for edge_type.")
        relation_ids = edge_attr[:, 4].long()

        # (convolution, normalisation, activation, apply dropout afterwards?)
        stages = (
            (self.conv1, self.bn1, self.prelu1, True),
            (self.conv2, self.bn2, self.prelu2, True),
            (self.conv3, self.bn3, self.prelu3, False),
        )
        for conv, norm, activation, use_dropout in stages:
            x = conv(x, edge_index, edge_type=relation_ids)
            x = activation(norm(x))
            if use_dropout:
                x = self.dropout(x)
        return x

# ===========================================================
# 2. RGCN-LSTM 模型定义
# ===========================================================
class RGCNLSTMModelWithHourlyHeads(nn.Module): # Renamed from RGCNGRUModelWithHourlyHeads
    """RGCN + LSTM node-level sequence model with one prediction head per forecast step.

    Pipeline implemented by ``forward``:
      1. The first graph of the sequence is passed through ``rgcn_module_for_h0``
         and an MLP to produce the initial LSTM hidden state ``h0`` (``c0`` is zeros).
      2. Each of the following ``T_pred_horizon`` graphs is passed through
         ``rgcn_module_for_sequence``; the node embeddings are concatenated with an
         encoded global-environment vector and an encoded time vector, then fused
         by ``fusion_mlp``.
      3. The fused per-node sequences go through the LSTM, and the LSTM output at
         step ``t`` is fed to ``hourly_prediction_heads[t]``.

    Node features are standardised with the ``node_feat_mean`` / ``node_feat_std``
    buffers, which default to 0 / 1 and are expected to be overwritten by the caller.
    """
    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 rgcn_hidden_dim,
                 rgcn_output_dim,
                 num_relations,
                 lstm_hidden_dim, # Renamed from gru_hidden_dim
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 num_lstm_layers=1, # Renamed from num_gru_layers
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_rgcn=0.3,
                 dropout_rate_lstm=0.2, # Renamed from dropout_rate_gru
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2
                ):
        """Build all sub-modules.

        Args:
            static_node_in_dim: node feature width; input size of both RGCN modules
                and length of the normalisation buffers.
            global_env_in_dim: input width of the global-environment encoder.
            time_in_dim: input width of the time encoder.
            global_env_emb_dim: output width of the global-environment encoder.
            time_emb_dim: output width of the time encoder.
            rgcn_hidden_dim: hidden width inside each RGCNModule.
            rgcn_output_dim: output width of each RGCNModule.
            num_relations: number of edge relation types for the RGCN layers.
            lstm_hidden_dim: LSTM hidden size; also the output width of the h0
                encoder and the input width of every prediction head.
            fusion_mlp_output_dim: output width of the fusion MLP (= LSTM input
                size). ``None`` means "same as the concatenated feature width".
            fusion_mlp_hidden_dim: hidden width of the fusion MLP; ``None`` lets
                MLPEncoder pick it.
            dropout_rate_fusion_mlp: dropout inside the fusion MLP.
            num_lstm_layers: number of stacked LSTM layers.
            T_pred_horizon: number of forecast steps (and prediction heads).
            dropout_rate_encoders: dropout inside the three small encoders.
            dropout_rate_rgcn: dropout inside both RGCN modules.
            dropout_rate_lstm: inter-layer LSTM dropout; only used when
                ``num_lstm_layers > 1``.
            mlp_prediction_hidden_dim: hidden width of each prediction head.
            dropout_rate_pred_head: dropout inside each prediction head.
        """
        super().__init__()
        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.num_relations = num_relations
        self.rgcn_output_dim = rgcn_output_dim
        self.lstm_hidden_dim = lstm_hidden_dim # Storing for component profiling & consistency


        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders)
        # MLP to encode GCN output to match LSTM hidden dim for h0, c0
        self.h0_c0_from_rgcn_encoder = MLPEncoder(rgcn_output_dim, lstm_hidden_dim, dropout_rate=dropout_rate_encoders)

        # Two separate RGCN stacks (no weight sharing): one for the initial
        # state, one applied at every step of the sequence.
        self.rgcn_module_for_h0 = RGCNModule(static_node_in_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate_rgcn)
        self.rgcn_module_for_sequence = RGCNModule(static_node_in_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate_rgcn)

        # Per-node input to the fusion MLP: RGCN embedding + global env embedding + time embedding.
        concatenated_feature_dim = rgcn_output_dim + global_env_emb_dim + time_emb_dim
        actual_fusion_mlp_output_dim = fusion_mlp_output_dim if fusion_mlp_output_dim is not None else concatenated_feature_dim
        self.fusion_mlp_input_dim = concatenated_feature_dim

        self.fusion_mlp = MLPEncoder(
            in_dim=concatenated_feature_dim,
            out_dim=actual_fusion_mlp_output_dim,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        lstm_input_size_actual = actual_fusion_mlp_output_dim
        self.lstm_input_dim = lstm_input_size_actual # Store for component profiling

        # batch_first=True: the "batch" axis of the LSTM is the node axis.
        self.lstm = nn.LSTM( # Changed from nn.GRU
            input_size=lstm_input_size_actual,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=dropout_rate_lstm if num_lstm_layers > 1 else 0.0
        )

        # One independent two-layer MLP per forecast step, each producing a scalar per node.
        self.hourly_prediction_heads = nn.ModuleList()
        for _ in range(T_pred_horizon):
            self.hourly_prediction_heads.append(
                nn.Sequential(
                    nn.Linear(lstm_hidden_dim, mlp_prediction_hidden_dim), # Input from LSTM
                    nn.ReLU(),
                    nn.Dropout(dropout_rate_pred_head),
                    nn.Linear(mlp_prediction_hidden_dim, 1)
                )
            )

        # Identity normalisation by default; the training/eval loops overwrite these.
        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))


    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict one scaled value per node for each of the ``T_pred_horizon`` steps.

        Args:
            list_of_batched_timesteps: indexable sequence of batched graphs with at
                least ``T_pred_horizon + 1`` entries. Entry 0 seeds the LSTM state;
                entries ``1..T_pred_horizon`` form the input sequence. Each entry
                must provide ``.to()``, ``x``, ``edge_index``, ``edge_attr`` and, for
                entries >= 1, ``graph_global_env_features``, ``num_graphs``,
                ``batch`` and ``num_nodes`` (presumably PyG ``Batch`` objects —
                confirm against the loader).
            timeline_time_features: 2-D tensor; row ``t`` is the raw time feature
                vector paired with sequence entry ``t + 1``.
            device: device on which all computation is done.

        Returns:
            Tensor of shape ``(N_nodes, T_pred_horizon)`` with scaled predictions.

        Raises:
            ValueError: propagated from RGCNModule when ``edge_attr`` is missing or
                has fewer than five columns.
        """
        # --- Initial hidden state from the first graph of the sequence ---
        pyg_batch_7am = list_of_batched_timesteps[0].to(device)
        normalized_x_7am = (pyg_batch_7am.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

        rgcn_output_7am = self.rgcn_module_for_h0(
            normalized_x_7am,
            pyg_batch_7am.edge_index,
            pyg_batch_7am.edge_attr
        )
        # For LSTM, h0 and c0 are needed. We'll derive h0 from GCN and initialize c0 to zeros.
        h0_features_for_lstm_nodes = self.h0_c0_from_rgcn_encoder(rgcn_output_7am)

        h0_for_lstm = h0_features_for_lstm_nodes.unsqueeze(0) # Shape: (1, N_nodes, lstm_hidden_dim)
        c0_for_lstm = torch.zeros_like(h0_for_lstm) # Shape: (1, N_nodes, lstm_hidden_dim)

        # Every stacked LSTM layer starts from the same h0 / c0.
        if self.lstm.num_layers > 1:
            h0_for_lstm = h0_for_lstm.repeat(self.lstm.num_layers, 1, 1)
            c0_for_lstm = c0_for_lstm.repeat(self.lstm.num_layers, 1, 1)

        initial_hidden_state = (h0_for_lstm, c0_for_lstm)

        # --- Build the per-step LSTM inputs ---
        all_lstm_input_features_over_time = [] # Renamed
        for t_pred_idx in range(self.T_pred_horizon):
            # Sequence entry t+1 is the input for prediction step t.
            pyg_batch_this_timestep = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            normalized_x = (pyg_batch_this_timestep.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

            rgcn_output_nodes_t = self.rgcn_module_for_sequence(
                normalized_x,
                pyg_batch_this_timestep.edge_index,
                pyg_batch_this_timestep.edge_attr
            )

            # Global environment features are expected as (num_graphs, global_env_in_dim).
            # If batching flattened them, reshape; if the element count does not
            # match at all, fall back to zeros (with a printed warning).
            global_env_feat_t_unencoded = pyg_batch_this_timestep.graph_global_env_features
            expected_num_graphs_in_batch_t = pyg_batch_this_timestep.num_graphs
            expected_global_features_dim = self.global_env_encoder.mlp[0].in_features
            if not (global_env_feat_t_unencoded.shape == (expected_num_graphs_in_batch_t, expected_global_features_dim)):
                if global_env_feat_t_unencoded.ndim == 1 and \
                   global_env_feat_t_unencoded.shape[0] == expected_num_graphs_in_batch_t * expected_global_features_dim:
                    global_env_feat_t_unencoded = global_env_feat_t_unencoded.view(expected_num_graphs_in_batch_t, expected_global_features_dim)
                elif global_env_feat_t_unencoded.numel() == expected_num_graphs_in_batch_t * expected_global_features_dim:
                    global_env_feat_t_unencoded = global_env_feat_t_unencoded.view(expected_num_graphs_in_batch_t, expected_global_features_dim)
                else:
                    print(f"Warning: LSTM Input Time {t_pred_idx}: Correcting global_env_feat shape from {global_env_feat_t_unencoded.shape} to ({expected_num_graphs_in_batch_t}, {expected_global_features_dim}) with zeros due to mismatch.")
                    global_env_feat_t_unencoded = torch.zeros(expected_num_graphs_in_batch_t, expected_global_features_dim, device=device)
            global_env_emb_t = self.global_env_encoder(global_env_feat_t_unencoded)
            # Broadcast each graph's embedding to its nodes via the node -> graph index.
            global_env_emb_t_expanded = global_env_emb_t[pyg_batch_this_timestep.batch]

            # The time embedding is shared by every node in the batch at this step.
            current_raw_time_feat_for_timestep_t = timeline_time_features[t_pred_idx, :].to(device)
            current_emb_time_feat_for_timestep_t = self.time_encoder(current_raw_time_feat_for_timestep_t)
            num_nodes_in_pyg_batch = pyg_batch_this_timestep.num_nodes
            time_emb_t_expanded_to_nodes = current_emb_time_feat_for_timestep_t.unsqueeze(0).expand(num_nodes_in_pyg_batch, -1)

            concatenated_features_for_timestep_t = torch.cat([rgcn_output_nodes_t, global_env_emb_t_expanded, time_emb_t_expanded_to_nodes], dim=-1)
            fused_features_for_timestep_t = self.fusion_mlp(concatenated_features_for_timestep_t)
            all_lstm_input_features_over_time.append(fused_features_for_timestep_t)

        # Stack along a new axis 1 -> (N_nodes, T_pred_horizon, lstm_input_dim);
        # this requires every timestep to have the same number of nodes.
        stacked_lstm_input_features = torch.stack(all_lstm_input_features_over_time, dim=1) # Renamed

        # If the first graph had more nodes than the sequence graphs, h0/c0 are
        # truncated to fit. NOTE(review): when h0 has FEWER nodes, only the warning
        # is printed and the LSTM call below will presumably fail — confirm intent.
        if initial_hidden_state[0].shape[1] != stacked_lstm_input_features.shape[0]: # Check h0's N_nodes
            print(f"CRITICAL WARNING: Node count mismatch for LSTM h0 ({initial_hidden_state[0].shape[1]}) and LSTM input sequence ({stacked_lstm_input_features.shape[0]}).")
            if initial_hidden_state[0].shape[1] > stacked_lstm_input_features.shape[0]:
                h0_adj = initial_hidden_state[0][:, :stacked_lstm_input_features.shape[0], :]
                c0_adj = initial_hidden_state[1][:, :stacked_lstm_input_features.shape[0], :]
                initial_hidden_state = (h0_adj, c0_adj)

        lstm_out, _ = self.lstm(stacked_lstm_input_features, initial_hidden_state) # Changed from self.gru

        # --- One dedicated head per forecast step ---
        all_hourly_final_predictions_scaled = []
        for t in range(self.T_pred_horizon):
            lstm_out_t = lstm_out[:, t, :] # Use lstm_out
            prediction_t_scaled = self.hourly_prediction_heads[t](lstm_out_t)
            # Drop the trailing singleton dimension produced by the final Linear(…, 1).
            all_hourly_final_predictions_scaled.append(prediction_t_scaled.squeeze(-1))

        predictions_scaled = torch.stack(all_hourly_final_predictions_scaled, dim=1)
        return predictions_scaled

# ===========================================================
# 3. 评估指标函数 (在原始尺度上计算指标) - NO CHANGES
# ===========================================================
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over selected nodes, ignoring NaN targets.

    Args:
        predictions_scaled: predictions, same shape as ``targets_scaled``.
        targets_scaled: targets; NaN entries are excluded from the loss.
        mask: per-node boolean selector (one entry per row of the targets).

    Returns:
        Scalar MSE over the selected, non-NaN entries. When nothing is
        selected, a fresh zero tensor with ``requires_grad=True`` is returned
        so callers can still treat the result like a loss.
    """
    # Repeat the per-node selector across every column of the targets.
    row_selector = mask.unsqueeze(1).expand_as(targets_scaled)
    keep = row_selector & ~torch.isnan(targets_scaled)

    if not keep.any():
        return torch.tensor(0.0, device=predictions_scaled.device, requires_grad=True)

    return F.mse_loss(predictions_scaled[keep], targets_scaled[keep])

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute per-step MSE / MAE / RMSE / R2 on the original target scale.

    Predictions and targets are un-scaled with ``value * (std + 1e-8) + mean``
    and evaluated column by column on the nodes selected by ``node_masks``,
    skipping NaN targets.

    Args:
        predictions_scaled: 2-D tensor of scaled predictions (nodes x steps).
        targets_scaled: 2-D tensor of scaled targets, same shape.
        node_masks: per-node selector of the rows to evaluate.
        target_mean: scalar tensor used when the targets were scaled.
        target_std: scalar tensor used when the targets were scaled.

    Returns:
        Dict mapping step index to ``{'mse', 'mae', 'rmse', 'r2', 'count'}``.
        Steps with fewer than two valid samples get NaN metrics and count 0.
    """
    mean_cpu = target_mean.cpu()
    std_cpu = target_std.cpu()

    # Undo the target scaling on CPU copies.
    preds_unscaled = predictions_scaled.clone().cpu() * (std_cpu + 1e-8) + mean_cpu
    targets_unscaled = targets_scaled.clone().cpu() * (std_cpu + 1e-8) + mean_cpu

    _, horizon = preds_unscaled.shape

    preds_np = preds_unscaled.numpy()
    targets_np = targets_unscaled.numpy()
    node_selector = node_masks.cpu().numpy()

    results = {}
    for step in range(horizon):
        # Restrict to the selected nodes, then drop missing targets.
        step_preds = preds_np[:, step][node_selector]
        step_targets = targets_np[:, step][node_selector]
        observed = ~np.isnan(step_targets)
        step_preds = step_preds[observed]
        step_targets = step_targets[observed]
        n_valid = step_preds.shape[0]

        if n_valid < 2:
            results[step] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            continue

        mse = np.mean((step_preds - step_targets)**2)
        mae = np.mean(np.abs(step_preds - step_targets))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(step_targets, step_preds)
        except ValueError:
            r2 = np.nan
        results[step] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': n_valid}
    return results

# ===========================================================
# 4. 训练与评估循环 (适配y归一化) - NO CHANGES
# ===========================================================
def train_epoch(model, loader, optimizer, device, timeline_time_features,
                node_feat_mean, node_feat_std, target_mean, target_std):
    """Run one training pass over ``loader``.

    The node-feature statistics are written onto the model, targets are
    standardised with ``target_mean`` / ``target_std``, and the masked MSE on
    scaled values is optimised. A batch only triggers an optimiser step (and
    contributes to the running loss) when its loss is finite and strictly
    positive; every batch counts towards the number of processed sequences.

    Returns:
        ``(average scaled loss per sequence, elapsed seconds)``. The average is
        ``0.0`` when no sequence was processed.
    """
    model.train()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    y_mean = target_mean.to(device)
    y_std = target_std.to(device)

    loss_sum = 0
    sequences_seen = 0
    started_at = time.time()

    for timestep_batches in loader:
        optimizer.zero_grad()
        preds_scaled = model(timestep_batches, timeline_time_features.to(device), device)

        # Loss is evaluated on nodes NOT flagged by building_mask, taken from
        # the first predicted timestep (index 1).
        loss_mask = ~timestep_batches[1].to(device).building_mask

        # Entry t+1 of the sequence holds the target for prediction step t.
        targets_scaled = torch.stack(
            [
                (timestep_batches[step + 1].to(device).y.squeeze() - y_mean) / (y_std + 1e-8)
                for step in range(model.T_pred_horizon)
            ],
            dim=1,
        )

        loss = mse_loss_masked(preds_scaled, targets_scaled, loss_mask)
        batch_sequences = timestep_batches[0].num_graphs

        if torch.isfinite(loss) and loss.item() > 0:
            loss.backward()
            optimizer.step()
            loss_sum += loss.item() * batch_sequences
        sequences_seen += batch_sequences

    epoch_duration = time.time() - started_at
    if sequences_seen > 0:
        return loss_sum / sequences_seen, epoch_duration
    return 0.0, epoch_duration

def evaluate_epoch(model, loader, device, timeline_time_features,
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate the model on ``loader`` without gradient tracking.

    Returns:
        ``(average scaled loss per sequence, per-step metrics on the original
        scale, elapsed seconds)``. When the loader yields no batch, the metrics
        dict contains NaN entries with count 0 for every forecast step.

    ``epoch_type`` is accepted for call-site compatibility and is not used.
    """
    model.eval()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    y_mean = target_mean.to(device)
    y_std = target_std.to(device)

    pred_chunks = []
    target_chunks = []
    mask_chunks = []
    loss_sum = 0
    sequences_seen = 0

    started_at = time.time()
    with torch.no_grad():
        for timestep_batches in loader:
            preds_scaled = model(timestep_batches, timeline_time_features.to(device), device)

            # Metrics use nodes NOT flagged by building_mask (first predicted step).
            metric_mask = ~timestep_batches[1].to(device).building_mask

            # Entry t+1 of the sequence holds the target for prediction step t.
            targets_original = torch.stack(
                [
                    timestep_batches[step + 1].to(device).y.squeeze()
                    for step in range(model.T_pred_horizon)
                ],
                dim=1,
            )
            targets_scaled = (targets_original - y_mean) / (y_std + 1e-8)

            loss = mse_loss_masked(preds_scaled, targets_scaled, metric_mask)
            batch_sequences = timestep_batches[0].num_graphs

            if torch.isfinite(loss):
                loss_sum += loss.item() * batch_sequences
            sequences_seen += batch_sequences

            # Keep CPU copies so device memory is not held across batches.
            pred_chunks.append(preds_scaled.cpu())
            target_chunks.append(targets_scaled.cpu())
            mask_chunks.append(metric_mask.cpu())

    eval_duration = time.time() - started_at
    avg_loss_scaled = loss_sum / sequences_seen if sequences_seen > 0 else 0.0

    if not pred_chunks:
        placeholder_metrics = {}
        for step in range(model.T_pred_horizon):
            placeholder_metrics[step] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0}
        return avg_loss_scaled, placeholder_metrics, eval_duration

    hourly_metrics_original_scale = calculate_hourly_metrics(
        torch.cat(pred_chunks, dim=0),
        torch.cat(target_chunks, dim=0),
        torch.cat(mask_chunks, dim=0),
        target_mean.cpu(),
        target_std.cpu(),
    )
    return avg_loss_scaled, hourly_metrics_original_scale, eval_duration

# ===========================================================
# 5. 主训练流程 (RGCN-LSTM)
# ===========================================================

def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Average each metric over the forecast steps.

    For every metric in ``r2``, ``mse``, ``mae``, ``rmse`` the non-NaN values
    of steps ``0..T_pred_horizon-1`` present in ``hourly_metrics_dict`` are
    collected, and their mean and standard deviation are stored under
    ``avg_<metric>`` and ``std_<metric>``. Both are NaN when no value exists.
    """
    summary = {}
    for metric in ('r2', 'mse', 'mae', 'rmse'):
        usable = []
        for step in range(T_pred_horizon):
            if step not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[step][metric]
            if not np.isnan(value):
                usable.append(value)

        if usable:
            average, spread = np.mean(usable), np.std(usable)
        else:
            average = spread = np.nan
        summary[f'avg_{metric}'] = average
        summary[f'std_{metric}'] = spread
    return summary


def main_training_rgcn_lstm_hourly_heads( # Renamed
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor
):
    train_start_time = time.time()
    report_data = {'config': config}

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    expected_input_len = T_PRED_HORIZON + 1
    valid_sequences_data = [seq for seq in all_sequences_data if isinstance(seq, list) and len(seq) == expected_input_len]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # Dataset split
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]
    train_dataset = [all_sequences_data[i] for i in train_indices]
    val_dataset   = [all_sequences_data[i] for i in val_indices]
    test_dataset  = [all_sequences_data[i] for i in test_indices]
    report_data['dataset_split'] = {'total_sequences': num_total_sequences, 'train_size': len(train_dataset), 'val_size': len(val_dataset), 'test_size': len(test_dataset)}

    # Scaler calculation
    all_train_node_features_list = []
    all_train_target_values_list_for_scaling = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if hasattr(graph_data, 'x') and graph_data.x is not None: all_train_node_features_list.append(graph_data.x)
            if i_step > 0 and hasattr(graph_data, 'y') and graph_data.y is not None:
                y_original = graph_data.y.squeeze(); current_mask_for_loss = ~graph_data.building_mask
                valid_target_indices = current_mask_for_loss & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0: all_train_target_values_list_for_scaling.append(y_original[valid_target_indices])
    if not all_train_node_features_list: raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")
    all_train_node_features_tensor = torch.cat(all_train_node_features_list, dim=0)
    node_feat_mean = torch.mean(all_train_node_features_tensor, dim=0); node_feat_std = torch.std(all_train_node_features_tensor, dim=0)
    node_feat_std[node_feat_std < 1e-8] = 1.0
    scaler_path_x = Path(config['results_dir']) / "node_feature_scaler_rgcn_lstm.pth" # Renamed
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x); print(f"节点特征x scaler已保存到: {scaler_path_x}")
    if not all_train_target_values_list_for_scaling:
        target_mean = torch.tensor(0.0); target_std = torch.tensor(1.0)
    else:
        all_train_target_values_tensor = torch.cat(all_train_target_values_list_for_scaling, dim=0)
        target_mean = torch.mean(all_train_target_values_tensor.float()); target_std = torch.std(all_train_target_values_tensor.float())
        if target_std < 1e-8: target_std = torch.tensor(1.0)
    target_scaler_path = Path(config['results_dir']) / "target_scaler_rgcn_lstm.pth" # Renamed
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path); print(f"目标值y scaler已保存到: {target_scaler_path}")

    # DataLoaders
    batch_size = config.get('batch_size', 8); num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory_flag)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)

    # Model Initialization
    sample_graph_7am_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_7am_for_dims.x.shape[1]
    global_env_in_dim = sample_graph_7am_for_dims.graph_global_env_features.shape[0] if sample_graph_7am_for_dims.graph_global_env_features.ndim == 1 else sample_graph_7am_for_dims.graph_global_env_features.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]
    num_relations = config.get('num_relations', 5)

    model = RGCNLSTMModelWithHourlyHeads( # Use RGCNLSTMModel
        static_node_in_dim=static_node_in_dim, global_env_in_dim=global_env_in_dim, time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16), time_emb_dim=config.get('time_emb_dim', 8),
        rgcn_hidden_dim=config.get('gcn_hidden_dim', 128), rgcn_output_dim=config.get('gcn_output_dim', 128),
        num_relations=num_relations,
        lstm_hidden_dim=config.get('gru_hidden_dim', 128), # Use existing config key for hidden_dim
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1), # Use existing config key for num_layers
        T_pred_horizon=T_PRED_HORIZON,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_rgcn=config.get('dropout_rate_gcn', 0.3),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2), # Use existing config key for dropout
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2)
    ).to(device)
    model.node_feat_mean = node_feat_mean.to(device); model.node_feat_std = node_feat_std.to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型总参数量 (RGCN-LSTM): {total_params:,}"); report_data['model_total_parameters'] = total_params

    # ===== Component FLOPS Calculation =====
    print("\nCalculating MACs for model components (approximate FLOPS):")
    report_data['component_gmacs'] = {}
    dummy_nodes_component = 2500
    dummy_edges_component = 60000
    dummy_batch_global_comp = 1

    model.eval()

    # 1. RGCNModule (No change)
    try:
        rgcn_module_to_profile = model.rgcn_module_for_h0
        dummy_x_rgcn = torch.randn(dummy_nodes_component, rgcn_module_to_profile.rgcn_input_dim, device=device)
        dummy_ei_rgcn = torch.randint(0, dummy_nodes_component, (2, dummy_edges_component), device=device)
        dummy_ea_rgcn = torch.randn(dummy_edges_component, 5, device=device)
        dummy_ea_rgcn[:, 4] = torch.randint(0, model.num_relations, (dummy_edges_component,), device=device).float()
        macs_rgcn = torchprofile.profile_macs(rgcn_module_to_profile, args=(dummy_x_rgcn, dummy_ei_rgcn, dummy_ea_rgcn))
        report_data['component_gmacs']['rgcn_module'] = macs_rgcn / 1e9
        print(f"  RGCNModule GMACs: {macs_rgcn / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling RGCNModule: {e}"); report_data['component_gmacs']['rgcn_module'] = "Error"

    # 2. LSTM Layer (Manual MAC Calculation)
    print(f"  Manually Calculating MACs for LSTM Layer:")
    try:
        lstm_layer = model.lstm # Changed from model.gru
        N_nodes = dummy_nodes_component
        L_seq = T_PRED_HORIZON
        H_in = lstm_layer.input_size
        H_hidden = lstm_layer.hidden_size
        num_layers = lstm_layer.num_layers

        macs_lstm_manual = 0
        # For a single layer LSTM: MACs ≈ N * L * 4 * (H_in * H_hidden + H_hidden^2)
        # (Input, Forget, Cell, Output gates each have similar complexity to a GRU gate part)
        macs_lstm_manual = N_nodes * L_seq * 4 * (H_in * H_hidden + H_hidden * H_hidden) # For the first layer
        if num_layers > 1:
            # Subsequent (num_layers - 1) layers: N * L * 4 * (H_hidden * H_hidden + H_hidden^2)
            # (because input to subsequent layers is H_hidden from the layer below)
            macs_lstm_manual += N_nodes * L_seq * (num_layers - 1) * 4 * (H_hidden * H_hidden + H_hidden * H_hidden)

        gmacs_lstm_manual = macs_lstm_manual / 1e9
        report_data['component_gmacs']['lstm_layer'] = gmacs_lstm_manual # Renamed from gru_layer
        report_data['component_gmacs']['lstm_layer_profiling_notes'] = "Manually calculated based on formula."
        print(f"  LSTM Parameters: input_size={H_in}, hidden_size={H_hidden}, num_layers={num_layers}")
        print(f"  Used for calculation: N_nodes={N_nodes}, L_seq={L_seq}")
        print(f"  LSTM Layer GMACs (Manual): {gmacs_lstm_manual:.4f} (for sequence length {L_seq})")

    except Exception as e:
        print(f"  Error manually calculating LSTM Layer MACs: {e}")
        report_data['component_gmacs']['lstm_layer'] = "Error" # Renamed
        report_data['component_gmacs']['lstm_layer_profiling_notes'] = f"Error during manual calculation: {str(e)}"

    # 3. Fusion MLP (MLPEncoder) - No change
    try:
        fusion_mlp_to_profile = model.fusion_mlp
        dummy_input_fusion_mlp = torch.randn(dummy_nodes_component, model.fusion_mlp_input_dim, device=device)
        macs_fusion_mlp = torchprofile.profile_macs(fusion_mlp_to_profile, args=(dummy_input_fusion_mlp,))
        report_data['component_gmacs']['fusion_mlp'] = macs_fusion_mlp / 1e9
        print(f"  Fusion MLP GMACs: {macs_fusion_mlp / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Fusion MLP: {e}"); report_data['component_gmacs']['fusion_mlp'] = "Error"

    # 4. Prediction Head (one MLP from ModuleList) - Input dim is now lstm_hidden_dim
    try:
        pred_head_to_profile = model.hourly_prediction_heads[0]
        dummy_input_pred_head = torch.randn(dummy_nodes_component, model.lstm_hidden_dim, device=device) # Uses lstm_hidden_dim
        macs_pred_head = torchprofile.profile_macs(pred_head_to_profile, args=(dummy_input_pred_head,))
        report_data['component_gmacs']['prediction_head_mlp'] = macs_pred_head / 1e9
        print(f"  Prediction Head MLP (single hour) GMACs: {macs_pred_head / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Prediction Head: {e}"); report_data['component_gmacs']['prediction_head_mlp'] = "Error"

    # 5. Global Environment Encoder (MLPEncoder) - No change
    try:
        encoder_to_profile = model.global_env_encoder
        dummy_input_encoder = torch.randn(dummy_batch_global_comp, model.global_env_in_dim, device=device)
        macs_encoder = torchprofile.profile_macs(encoder_to_profile, args=(dummy_input_encoder,))
        report_data['component_gmacs']['global_env_encoder_mlp'] = macs_encoder / 1e9
        print(f"  Global Env Encoder MLP GMACs: {macs_encoder / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Global Env Encoder: {e}"); report_data['component_gmacs']['global_env_encoder_mlp'] = "Error"

    # 6. Time Encoder (MLPEncoder) - No change
    try:
        encoder_to_profile = model.time_encoder
        dummy_input_encoder = torch.randn(dummy_batch_global_comp, model.time_in_dim, device=device)
        macs_encoder = torchprofile.profile_macs(encoder_to_profile, args=(dummy_input_encoder,))
        report_data['component_gmacs']['time_encoder_mlp'] = macs_encoder / 1e9
        print(f"  Time Encoder MLP GMACs: {macs_encoder / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Time Encoder: {e}"); report_data['component_gmacs']['time_encoder_mlp'] = "Error"

    # 7. H0/C0 from RGCN Encoder (MLPEncoder) - Output dim is now lstm_hidden_dim
    try:
        encoder_to_profile = model.h0_c0_from_rgcn_encoder # Renamed internal variable
        dummy_input_encoder = torch.randn(dummy_nodes_component, model.rgcn_output_dim, device=device)
        macs_encoder = torchprofile.profile_macs(encoder_to_profile, args=(dummy_input_encoder,))
        report_data['component_gmacs']['h0_c0_from_rgcn_encoder_mlp'] = macs_encoder / 1e9 # Renamed key
        print(f"  H0/C0 from RGCN Encoder MLP GMACs: {macs_encoder / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling H0/C0 from RGCN Encoder: {e}"); report_data['component_gmacs']['h0_c0_from_rgcn_encoder_mlp'] = "Error"

    model.train()
    # ===== End Component FLOPS Calculation =====

    optimizer = torch.optim.Adam(model.parameters(), lr=config.get('lr', 0.001), weight_decay=config.get('weight_decay', 1e-5))
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 20), verbose=True)

    best_val_loss_scaled = float('inf'); best_val_hourly_metrics_original = None; best_epoch = 0
    patience_counter = 0; max_epochs = config.get('max_epochs', 300); early_stopping_patience = config.get('early_stopping_patience', 45)
    model_save_path = Path(config['results_dir']) / f"best_rgcnlstm_hourly_heads_model_seed{seed}.pth" # Renamed
    timeline_time_features_on_device = time_features_for_dataset.to(device)
    target_mean_cpu = target_mean.cpu(); target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device); target_std_on_device = target_std.to(device)
    epoch_times = []

    for epoch in range(1, max_epochs + 1):
        train_loss_scaled, epoch_duration = train_epoch(model, train_loader, optimizer, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_on_device, target_std_on_device)
        epoch_times.append(epoch_duration)
        val_loss_scaled, val_hourly_metrics_original, _ = evaluate_epoch(model, val_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Validation")
        scheduler.step(val_loss_scaled)
        print(f"Epoch {epoch:03d} | Train Scaled MSE: {train_loss_scaled:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Epoch Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")
        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled; best_val_hourly_metrics_original = val_hourly_metrics_original; best_epoch = epoch
            patience_counter = 0; torch.save(model.state_dict(), model_save_path)
            print(f"                     ---> Best model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement."); break

    total_training_duration = time.time() - train_start_time
    report_data['total_training_time_seconds'] = total_training_duration
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    report_data['num_epochs_trained'] = epoch; report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    model_for_eval = RGCNLSTMModelWithHourlyHeads( # Use RGCNLSTMModel
        static_node_in_dim=static_node_in_dim, global_env_in_dim=global_env_in_dim, time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16), time_emb_dim=config.get('time_emb_dim', 8),
        rgcn_hidden_dim=config.get('gcn_hidden_dim', 128), rgcn_output_dim=config.get('gcn_output_dim', 128),
        num_relations=num_relations,
        lstm_hidden_dim=config.get('gru_hidden_dim', 128),
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1), T_pred_horizon=T_PRED_HORIZON,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_rgcn=config.get('dropout_rate_gcn', 0.3),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2)
    ).to(device)
    try: model_for_eval.load_state_dict(torch.load(model_save_path, map_location=device))
    except Exception as e: print(f"无法加载最佳模型 ({e})，将使用训练循环结束时的模型。"); model_for_eval = model

    print("\n评估最佳模型在训练集上..."); best_model_train_loss_scaled, best_model_train_hourly_metrics, train_eval_duration = evaluate_epoch(model_for_eval, train_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Best Model on Train")
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics; report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration; _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original: report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    else:
        _, reeval_val_metrics, val_eval_duration = evaluate_epoch(model_for_eval, val_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Best Model on Val (Re-eval)")
        report_data['best_model_validation_set_metrics_hourly'] = reeval_val_metrics
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(reeval_val_metrics, T_PRED_HORIZON)
        report_data['best_model_validation_set_eval_time_seconds'] = val_eval_duration
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)


    print("\n评估最佳模型在测试集上..."); test_loss_scaled, test_hourly_metrics_original, test_inference_duration = evaluate_epoch(model_for_eval, test_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Test")
    report_data['test_set_inference_time_seconds'] = test_inference_duration; report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original; report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (RGCN-LSTM) " + "="*20); print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}") # Renamed
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    report_file_path = Path(config['results_dir']) / f"training_report_lstm_seed{seed}.json" # Renamed
    try:
        class NpEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, np.integer): return int(obj)
                if isinstance(obj, np.floating): return float(obj)
                if isinstance(obj, np.ndarray): return obj.tolist()
                if isinstance(obj, torch.Tensor): return obj.tolist()
                if isinstance(obj, Path): return str(obj)
                return super(NpEncoder, self).default(obj)
        with open(report_file_path, 'w') as f: json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e: print(f"保存训练报告失败: {e}")

    return model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std


def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by the aggregated averages.

    Args:
        set_name: Label inserted into the printed headings.
        hourly_metrics: Mapping from hour index to a dict read with the keys
            'r2', 'mse', 'mae', 'rmse' and 'count', or None when no metrics
            exist (a single notice line is printed in that case).
        T_pred_horizon: Number of hour rows to print (indices 0 .. T-1).
        indent: Prefix placed in front of every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    nan = np.nan
    # Stand-in row for hours that are missing from the mapping.
    missing_row = {'mse': nan, 'mae': nan, 'rmse': nan, 'r2': nan, 'count': 0}

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")
    hours = range(T_pred_horizon)
    if hours:
        # The column header is only shown when at least one row follows.
        print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
    for hour in hours:
        row = hourly_metrics.get(hour, missing_row)
        r2_cell = f"{row.get('r2', nan):13.4f}"
        mse_cell = f"{row.get('mse', nan):14.4f}"
        mae_cell = f"{row.get('mae', nan):14.4f}"
        rmse_cell = f"{row.get('rmse', nan):15.4f}"
        count_cell = f"{row.get('count', 0):7d}"
        print(f"{indent}  {hour:02d}   | {r2_cell} | {mse_cell} | {mae_cell} | {rmse_cell} | {count_cell}")

    summary = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    for label, key in (("R2", "r2"), ("MSE", "mse"), ("MAE", "mae"), ("RMSE", "rmse")):
        avg_value = summary.get(f"avg_{key}", nan)
        std_value = summary.get(f"std_{key}", nan)
        print(f"{indent}  Aggregated Avg {label:<4s} : {avg_value:.4f} (Std: {std_value:.4f})")


# ===========================================================
# 6. 主执行块
# ===========================================================
if __name__ == "__main__":
    # Entry point of the RGCN-LSTM cell: free memory, set up paths and hyper-parameters,
    # load and validate the pickled graph sequences, then run training and evaluation.
    gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    # The base directory is created when missing; the data-file existence check further
    # down is what actually reports an absent dataset.
    if not DRIVE_BASE_PATH.exists(): DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data") # NOTE(review): the data-loader cell at the top of this notebook reads the same file name from "Result/Sequential_13Hour_Data" - confirm which directory is intended
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl" # Same file name as in the data-loader cell
    RESULTS_SUBDIR = Path("Result/Final_RGCNLSTM1") # Output directory for this run (passed on as config['results_dir'])
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    os.makedirs(RESULTS_SAVE_DIR, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME

    # Calendar date and first hour used to build the per-step time features below.
    DATA_YEAR = 2023; DATA_MONTH = 5; DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8; PREDICTION_HORIZON = 12
    # The GRU-named config keys ('gru_hidden_dim', 'num_gru_layers', 'dropout_rate_gru')
    # are kept on purpose: the training function reads them to configure the LSTM layer,
    # so the same config layout works for both model variants.
    training_config = {
        'seed': 42, 'batch_size': 8, 'lr': 0.001, 'weight_decay': 1e-5,
        'max_epochs': 1000, 'scheduler_patience': 20, 'early_stopping_patience': 45, # max_epochs is an upper bound; early stopping can end training sooner
        'T_pred_horizon': PREDICTION_HORIZON, 'results_dir': str(RESULTS_SAVE_DIR),
        'global_env_emb_dim': 16, 'time_emb_dim': 8,
        'gcn_hidden_dim': 128, 'gcn_output_dim': 128, # Read as rgcn_hidden_dim / rgcn_output_dim
        'num_relations': 5,
        'gru_hidden_dim': 128, # This will be lstm_hidden_dim
        'num_gru_layers': 1,  # This will be num_lstm_layers
        'mlp_prediction_hidden_dim': 64,
        'fusion_mlp_output_dim': 128, 'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2, 'dropout_rate_encoders': 0.1,
        'dropout_rate_gcn': 0.3,
        'dropout_rate_gru': 0.2,  # This will be dropout_rate_lstm
        'dropout_rate_pred_head': 0.2, 'use_amp': False, 'enable_profiler': False, 'num_workers': 0,
        'pin_memory': False, 'train_split_ratio': 0.7, 'val_split_ratio': 0.2, 'h0_from_first_step': True
    }

    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists(): raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        # NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
        with open(DATA_PATH, "rb") as f: all_graph_sequences_loaded = pickle.load(f)
        # Expected layout: a non-empty list whose first element is a non-empty list.
        if not all_graph_sequences_loaded or not isinstance(all_graph_sequences_loaded, list) or not all_graph_sequences_loaded[0] or not isinstance(all_graph_sequences_loaded[0], list):
            raise ValueError("加载的数据格式不正确。")
        # Each sequence must hold one leading step plus T_pred_horizon further steps.
        expected_len_per_sequence = training_config['T_pred_horizon'] + 1
        processed_sequences = []
        for i, seq in enumerate(all_graph_sequences_loaded):
            # Sequences of the wrong type or length are dropped silently.
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence: continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                # A step is rejected unless it is a PyG Data object with non-None x,
                # edge_index and edge_attr (at least 5 columns; column 4 is read as the
                # edge type by RGCNModule), carries a graph_global_env_features attribute
                # (presence only, not checked for None) and, for every step after the
                # first, a non-None target y. One bad step discards the whole sequence.
                if not isinstance(graph_step_data, Data) or not hasattr(graph_step_data, 'x') or graph_step_data.x is None or \
                   not hasattr(graph_step_data, 'edge_index') or graph_step_data.edge_index is None or \
                   not hasattr(graph_step_data, 'edge_attr') or graph_step_data.edge_attr is None or \
                   graph_step_data.edge_attr.shape[1] < 5 or \
                   not hasattr(graph_step_data, 'graph_global_env_features') or \
                   (step_idx > 0 and (not hasattr(graph_step_data, 'y') or graph_step_data.y is None)):
                    valid_seq = False; break
                # 1-D targets are reshaped in place to a single column.
                if step_idx > 0 and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq: processed_sequences.append(seq)
        if not processed_sequences: raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    # Any failure above is reported and turned into "no data", which skips training below.
    except Exception as e: print(f"加载或验证数据时发生错误: {e}"); all_graph_sequences = None

    if all_graph_sequences:
        # Time features: one row per predicted step, starting at the configured start hour.
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_dataset_timeline = generate_time_features_for_sequence(base_datetime_for_timeline, training_config['T_pred_horizon'])
        # Train, evaluate and write the report; returns the evaluated model and the
        # normalisation statistics for node features and targets.
        trained_model, final_node_mean, final_node_std, final_target_mean, final_target_std = main_training_rgcn_lstm_hourly_heads(
            all_graph_sequences, training_config, time_features_for_dataset_timeline
        )
        print("RGCN-LSTM 模型训练和评估完成!") # Completion message
    else:
        print("由于数据加载失败或数据为空，训练流程未启动。")

##RGCN+GRU

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import RGCNConv # ⚡ USE RGCNConv
from torch_geometric.data import Data, Batch # Batch很重要
from torch_geometric.loader import DataLoader
# import matplotlib.pyplot as plt # Not used in the final reporting directly, can be commented if not needed elsewhere
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta # For time feature
from pathlib import Path
import time # Added for timing
import json # Added for report export
import torchprofile # Added for FLOPS calculation

# Collect unreachable Python objects first so the CUDA memory they hold is
# returned to PyTorch's caching allocator before the cache is emptied.
gc.collect()
# Guarded: torch.cuda.ipc_collect() initialises CUDA and therefore raises on a
# CPU-only runtime, which would abort this cell before any model code runs.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# ===========================================================
# 1. 特征生成 & 辅助模块
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps):
    """Build cyclic hour-of-day and day-of-year features for consecutive hours.

    Step ``i`` describes ``base_dt_obj + i hours``.

    Args:
        base_dt_obj: ``datetime`` of the first step.
        num_steps: Number of hourly steps to generate.

    Returns:
        A float32 tensor of shape ``(num_steps, 4)`` whose columns are
        ``[hour_sin, hour_cos, doy_sin, doy_cos]``. For ``num_steps <= 0`` an
        empty ``(0, 4)`` tensor is returned instead of failing.

    NOTE(review): the hour is normalised by 23.0, so hours 0 and 23 map to the
    same angle. This is kept unchanged so the features stay identical to what
    existing runs were trained on; confirm whether 24.0 was intended.
    """
    import calendar  # local import: not among the file's top-level imports

    two_pi = 2 * math.pi
    rows = []
    for step in range(num_steps):
        current_dt = base_dt_obj + timedelta(hours=step)
        hour_norm = current_dt.hour / 23.0
        days_in_year = 366.0 if calendar.isleap(current_dt.year) else 365.0
        day_of_year_norm = current_dt.timetuple().tm_yday / days_in_year
        rows.append([
            math.sin(two_pi * hour_norm),
            math.cos(two_pi * hour_norm),
            math.sin(two_pi * day_of_year_norm),
            math.cos(two_pi * day_of_year_norm),
        ])

    if not rows:
        # An empty list would otherwise lose the feature dimension.
        return torch.empty((0, 4), dtype=torch.float32)
    return torch.tensor(rows, dtype=torch.float32)


class MLPEncoder(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hid_dim: Hidden width. When None it is derived from ``in_dim`` and
            ``out_dim`` (see ``__init__``) and is never left at zero.
        dropout_rate: Dropout probability applied after the LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        if hid_dim is None:
            # Default width: the larger of the smaller endpoint and the midpoint.
            hid_dim = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            if hid_dim == 0:
                # Avoid a zero-width hidden layer: prefer out_dim, then in_dim, then 1.
                if out_dim > 0:
                    hid_dim = out_dim
                elif in_dim > 0:
                    hid_dim = in_dim
                else:
                    hid_dim = 1

        layers = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.LayerNorm(hid_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hid_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.mlp(x)

class RGCNModule(nn.Module):
    """Stack of three RGCNConv layers, each followed by BatchNorm1d and PReLU.

    Dropout is applied after the first and second stage only. The relation
    type of every edge is read from column 4 of ``edge_attr``.

    Args:
        rgcn_input_dim: Node feature width fed to the first convolution.
        rgcn_hidden_dim: Output width of the first and second convolution.
        rgcn_output_dim: Output width of the third convolution.
        num_relations: Number of relation types passed to every RGCNConv.
        dropout_rate: Dropout probability used between stages.
    """

    def __init__(self, rgcn_input_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate=0.5):
        super().__init__()
        self.rgcn_input_dim = rgcn_input_dim
        self.rgcn_hidden_dim = rgcn_hidden_dim
        self.rgcn_output_dim = rgcn_output_dim
        self.num_relations = num_relations

        # (input width, output width) of stages 1..3; sub-modules are registered
        # as conv{i}, bn{i}, prelu{i} in that order.
        stage_dims = (
            (rgcn_input_dim, rgcn_hidden_dim),
            (rgcn_hidden_dim, rgcn_hidden_dim),
            (rgcn_hidden_dim, rgcn_output_dim),
        )
        for stage_no, (width_in, width_out) in enumerate(stage_dims, start=1):
            setattr(self, f"conv{stage_no}", RGCNConv(width_in, width_out, num_relations))
            setattr(self, f"bn{stage_no}", nn.BatchNorm1d(width_out))
            setattr(self, f"prelu{stage_no}", nn.PReLU(width_out))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index, edge_attr):
        """Run the three stages and return the node embeddings.

        Raises:
            ValueError: If ``edge_attr`` is None or has fewer than 5 columns.
        """
        if edge_attr is None or edge_attr.shape[1] < 5:
            raise ValueError("RGCNModule: edge_attr is missing or has insufficient columns for edge_type.")
        edge_type = edge_attr[:, 4].long()

        # (convolution, normalisation, activation, apply dropout afterwards?)
        stages = (
            (self.conv1, self.bn1, self.prelu1, True),
            (self.conv2, self.bn2, self.prelu2, True),
            (self.conv3, self.bn3, self.prelu3, False),
        )
        for conv, norm, activation, use_dropout in stages:
            x = activation(norm(conv(x, edge_index, edge_type=edge_type)))
            if use_dropout:
                x = self.dropout(x)
        return x

# ===========================================================
# 2. RGCN-GRU 模型定义 (包含Fusion MLP, 预测归一化的y)
# ===========================================================
class RGCNGRUModelWithHourlyHeads(nn.Module):
    """RGCN + GRU sequence model with one prediction head per forecast hour.

    Pipeline of ``forward``:
      1. The first batched timestep goes through ``rgcn_module_for_h0`` and
         ``h0_from_rgcn_encoder`` to produce the initial GRU hidden state.
      2. Each of the following ``T_pred_horizon`` timesteps goes through
         ``rgcn_module_for_sequence``; the node embeddings are concatenated with
         the encoded graph-level environment features and the encoded time
         features and passed through ``fusion_mlp``.
      3. The fused per-node sequence is fed to the GRU, and the GRU output at
         step ``t`` is mapped to a scalar by ``hourly_prediction_heads[t]``.

    Node features are standardised with the ``node_feat_mean`` / ``node_feat_std``
    buffers (zeros / ones until overwritten by the training code).
    """

    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 rgcn_hidden_dim,
                 rgcn_output_dim,
                 num_relations,
                 gru_hidden_dim,
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 num_gru_layers=1,
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_rgcn=0.3,
                 dropout_rate_gru=0.2,
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2
                ):
        super().__init__()
        # Plain attributes, several of them read by the component profiling code.
        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.num_relations = num_relations
        self.rgcn_output_dim = rgcn_output_dim
        self.gru_hidden_dim = gru_hidden_dim

        # Small encoders for graph-level, time and initial-state features.
        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders)
        self.h0_from_rgcn_encoder = MLPEncoder(rgcn_output_dim, gru_hidden_dim, dropout_rate=dropout_rate_encoders)

        # Separate graph encoders for the initial state and for the sequence steps.
        self.rgcn_module_for_h0 = RGCNModule(static_node_in_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate_rgcn)
        self.rgcn_module_for_sequence = RGCNModule(static_node_in_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate_rgcn)

        # Fusion MLP: its output width defaults to the concatenated input width.
        fusion_in_dim = rgcn_output_dim + global_env_emb_dim + time_emb_dim
        fusion_out_dim = fusion_in_dim if fusion_mlp_output_dim is None else fusion_mlp_output_dim
        self.fusion_mlp_input_dim = fusion_in_dim
        self.fusion_mlp = MLPEncoder(
            in_dim=fusion_in_dim,
            out_dim=fusion_out_dim,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        self.gru_input_dim = fusion_out_dim
        # Inter-layer dropout is only passed to the GRU when it is stacked.
        inter_layer_dropout = dropout_rate_gru if num_gru_layers > 1 else 0.0
        self.gru = nn.GRU(
            input_size=fusion_out_dim,
            hidden_size=gru_hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            dropout=inter_layer_dropout
        )

        # One independent two-layer head per forecast step.
        self.hourly_prediction_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(gru_hidden_dim, mlp_prediction_hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate_pred_head),
                nn.Linear(mlp_prediction_hidden_dim, 1)
            )
            for _ in range(T_pred_horizon)
        ])

        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))

    def _standardize_nodes(self, x):
        """Standardise node features with the stored mean / std buffers."""
        return (x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

    def _global_env_as_matrix(self, features, num_graphs, step_idx, device):
        """Return graph-level features shaped ``(num_graphs, in_features)``.

        A tensor with the right number of elements is reshaped; anything else
        is replaced by zeros after printing a warning.
        """
        feature_dim = self.global_env_encoder.mlp[0].in_features
        if features.shape == (num_graphs, feature_dim):
            return features
        if features.numel() == num_graphs * feature_dim:
            return features.view(num_graphs, feature_dim)
        print(f"Warning: GRU Input Time {step_idx}: Correcting global_env_feat shape from {features.shape} to ({num_graphs}, {feature_dim}) with zeros due to mismatch.")
        return torch.zeros(num_graphs, feature_dim, device=device)

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict the scaled targets for every node and forecast step.

        Args:
            list_of_batched_timesteps: Batched graphs, index 0 for the initial
                state and indices 1..T_pred_horizon for the forecast steps.
                Each is read for ``x``, ``edge_index``, ``edge_attr`` and, for
                the forecast steps, ``graph_global_env_features``,
                ``num_graphs``, ``num_nodes`` and ``batch``.
            timeline_time_features: Tensor indexed by forecast step on its
                first dimension; row ``t`` is fed to the time encoder.
            device: Device the batches and features are moved to.

        Returns:
            Tensor obtained by stacking the per-step head outputs on dim 1
            (one row per node, one column per forecast step).
        """
        # ---- initial hidden state from the first timestep ----
        initial_batch = list_of_batched_timesteps[0].to(device)
        initial_node_emb = self.rgcn_module_for_h0(
            self._standardize_nodes(initial_batch.x),
            initial_batch.edge_index,
            initial_batch.edge_attr
        )
        h0 = self.h0_from_rgcn_encoder(initial_node_emb).unsqueeze(0)
        if self.gru.num_layers > 1:
            # Every stacked GRU layer starts from the same state.
            h0 = h0.repeat(self.gru.num_layers, 1, 1)

        # ---- fused per-node features for each forecast step ----
        fused_per_step = []
        for step_idx in range(self.T_pred_horizon):
            step_batch = list_of_batched_timesteps[step_idx + 1].to(device)

            node_emb = self.rgcn_module_for_sequence(
                self._standardize_nodes(step_batch.x),
                step_batch.edge_index,
                step_batch.edge_attr
            )

            env_features = self._global_env_as_matrix(
                step_batch.graph_global_env_features, step_batch.num_graphs, step_idx, device
            )
            # Broadcast each graph's embedding to its nodes via the batch vector.
            env_emb_per_node = self.global_env_encoder(env_features)[step_batch.batch]

            time_emb = self.time_encoder(timeline_time_features[step_idx, :].to(device))
            time_emb_per_node = time_emb.unsqueeze(0).expand(step_batch.num_nodes, -1)

            joined = torch.cat([node_emb, env_emb_per_node, time_emb_per_node], dim=-1)
            fused_per_step.append(self.fusion_mlp(joined))

        gru_inputs = torch.stack(fused_per_step, dim=1)

        # ---- guard against differing node counts between h0 and the sequence ----
        h0_nodes, seq_nodes = h0.shape[1], gru_inputs.shape[0]
        if h0_nodes != seq_nodes:
            print(f"CRITICAL WARNING: Node count mismatch for GRU h0 ({h0_nodes}) and GRU input sequence ({seq_nodes}).")
            if h0_nodes > seq_nodes:
                h0 = h0[:, :seq_nodes, :]

        gru_out, _ = self.gru(gru_inputs, h0)

        # ---- one head per forecast step ----
        per_step_predictions = [
            self.hourly_prediction_heads[t](gru_out[:, t, :]).squeeze(-1)
            for t in range(self.T_pred_horizon)
        ]
        return torch.stack(per_step_predictions, dim=1)

# ===========================================================
# 3. 评估指标函数 (在原始尺度上计算指标)
# ===========================================================
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over selected nodes, skipping NaN targets.

    Args:
        predictions_scaled: Predictions, same shape as ``targets_scaled``.
        targets_scaled: Targets; NaN entries are excluded from the loss.
        mask: Boolean per-row selector, broadcast across the columns of
            ``targets_scaled``.

    Returns:
        The MSE over the selected, non-NaN entries. When no entry qualifies, a
        fresh zero scalar with ``requires_grad=True`` on the predictions'
        device is returned.
    """
    row_selected = mask.unsqueeze(1).expand_as(targets_scaled)
    usable = row_selected & ~torch.isnan(targets_scaled)
    if not usable.any():
        return torch.tensor(0.0, device=predictions_scaled.device, requires_grad=True)
    return F.mse_loss(predictions_scaled[usable], targets_scaled[usable])

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute MSE / MAE / RMSE / R2 per forecast step on the original scale.

    Predictions and targets are un-scaled with ``value * (std + 1e-8) + mean``
    on the CPU. For every column, only rows selected by ``node_masks`` whose
    target is not NaN are used.

    Args:
        predictions_scaled: 2-D tensor of scaled predictions (rows x steps).
        targets_scaled: Scaled targets, indexed like the predictions.
        node_masks: Boolean row selector.
        target_mean: Mean used to scale the targets.
        target_std: Standard deviation used to scale the targets.

    Returns:
        Dict mapping step index to ``{'mse', 'mae', 'rmse', 'r2', 'count'}``.
        Steps with fewer than two usable samples get NaN metrics and count 0.
    """
    mean_cpu = target_mean.cpu()
    scale_cpu = target_std.cpu() + 1e-8
    preds_np = (predictions_scaled.cpu() * scale_cpu + mean_cpu).numpy()
    targets_np = (targets_scaled.cpu() * scale_cpu + mean_cpu).numpy()
    selected_rows = node_masks.cpu().numpy()

    _, horizon = preds_np.shape
    results = {}
    for step in range(horizon):
        step_preds = preds_np[:, step][selected_rows]
        step_targets = targets_np[:, step][selected_rows]
        has_target = ~np.isnan(step_targets)
        step_preds = step_preds[has_target]
        step_targets = step_targets[has_target]
        sample_count = step_preds.shape[0]

        if sample_count < 2:
            results[step] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            continue

        residuals = step_preds - step_targets
        mse = np.mean(residuals**2)
        mae = np.mean(np.abs(residuals))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(step_targets, step_preds)
        except ValueError:
            r2 = np.nan
        results[step] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': sample_count}
    return results

# ===========================================================
# 4. 训练与评估循环 (适配y归一化)
# ===========================================================
def train_epoch(model, loader, optimizer, device, timeline_time_features,
                node_feat_mean, node_feat_std, target_mean, target_std):
    """Run one optimisation pass over ``loader`` and return the mean scaled loss.

    For every batch the model is called on the list of batched timesteps,
    targets ``y`` of timesteps ``1 .. T_pred_horizon`` are normalised with
    ``(y - target_mean) / (target_std + 1e-8)`` and the masked MSE is taken on
    the nodes where ``building_mask`` of timestep 1 is False.

    Args:
        model: Network exposing ``T_pred_horizon``; ``node_feat_mean`` and
            ``node_feat_std`` are (re)attached to it on ``device``.
        loader: Iterable yielding an indexable sequence of batched timesteps
            (presumably PyG ``Batch`` objects - confirm against the loader).
        optimizer: Optimiser stepping ``model``'s parameters.
        device: Device the batches and statistics are moved to.
        timeline_time_features: Time-feature tensor passed to the model.
        node_feat_mean, node_feat_std: Node-feature normalisation statistics.
        target_mean, target_std: Target normalisation statistics.

    Returns:
        ``(avg_loss, epoch_duration)``: the loss averaged over the sequences
        of the batches that actually contributed a loss (0.0 when none did),
        and the wall-clock seconds spent in the loop.
    """
    model.train()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    # Loop-invariant: move the time features once instead of once per batch.
    time_features_dev = timeline_time_features.to(device)

    total_loss_scaled = 0.0
    num_sequences_contributing = 0

    epoch_start_time = time.time()
    for list_of_batched_timesteps in loader:
        optimizer.zero_grad()
        predictions_batch_scaled = model(list_of_batched_timesteps, time_features_dev, device)

        # The mask of the first predicted timestep is reused for every horizon step.
        first_predicted_timestep_batch = list_of_batched_timesteps[1].to(device)
        mask_for_loss = ~first_predicted_timestep_batch.building_mask

        targets_list_for_loss_scaled = []
        for t_pred_idx in range(model.T_pred_horizon):
            current_target_timestep_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            targets_t_nodes_original = current_target_timestep_batch.y.squeeze()
            targets_list_for_loss_scaled.append(
                (targets_t_nodes_original - target_mean_dev) / (target_std_dev + 1e-8)
            )

        targets_batch_scaled = torch.stack(targets_list_for_loss_scaled, dim=1)
        loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled, mask_for_loss)
        num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

        # Skip non-finite losses and the zero placeholder that mse_loss_masked
        # returns when it has nothing to score.
        if torch.isfinite(loss) and loss.item() > 0:
            loss.backward()
            optimizer.step()
            total_loss_scaled += loss.item() * num_sequences_in_this_super_batch
            # Only batches that contributed a loss enter the denominator;
            # counting skipped batches would bias the average towards zero.
            num_sequences_contributing += num_sequences_in_this_super_batch

    epoch_duration = time.time() - epoch_start_time
    avg_loss = total_loss_scaled / num_sequences_contributing if num_sequences_contributing > 0 else 0.0
    return avg_loss, epoch_duration

def evaluate_epoch(model, loader, device, timeline_time_features,
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Computes the masked MSE on normalised targets per batch and, over all
    batches together, the hourly metrics on the original target scale via
    ``calculate_hourly_metrics``.

    Args:
        model: Network exposing ``T_pred_horizon``; ``node_feat_mean`` and
            ``node_feat_std`` are (re)attached to it on ``device``.
        loader: Iterable yielding an indexable sequence of batched timesteps
            (presumably PyG ``Batch`` objects - confirm against the loader).
        device: Device the batches and statistics are moved to.
        timeline_time_features: Time-feature tensor passed to the model.
        node_feat_mean, node_feat_std: Node-feature normalisation statistics.
        target_mean, target_std: Target normalisation statistics.
        epoch_type: Label of the evaluation pass; accepted for call-site
            readability and not used inside the function.

    Returns:
        ``(avg_loss_scaled, hourly_metrics, eval_duration)``. The loss is
        averaged over the sequences of batches with a finite loss. It is
        ``0.0`` when the loader yielded nothing, and NaN when batches were
        seen but none of them had a finite loss, so that a diverged model can
        never look like a perfect one to the caller.
    """
    model.eval()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    # Loop-invariant: move the time features once instead of once per batch.
    time_features_dev = timeline_time_features.to(device)

    all_batch_predictions_scaled = []
    all_batch_targets_scaled = []
    all_batch_masks_for_metrics = []
    total_loss_scaled = 0.0
    num_sequences_with_finite_loss = 0

    eval_start_time = time.time()
    with torch.no_grad():
        for list_of_batched_timesteps in loader:
            predictions_batch_scaled = model(list_of_batched_timesteps, time_features_dev, device)

            # The mask of the first predicted timestep is reused for every horizon step.
            first_predicted_timestep_batch = list_of_batched_timesteps[1].to(device)
            mask_for_metrics = ~first_predicted_timestep_batch.building_mask

            targets_list_original = []
            for t_pred_idx in range(model.T_pred_horizon):
                current_target_timestep_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
                targets_list_original.append(current_target_timestep_batch.y.squeeze())

            targets_batch_original = torch.stack(targets_list_original, dim=1)
            targets_batch_scaled_for_loss = (targets_batch_original - target_mean_dev) / (target_std_dev + 1e-8)
            loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled_for_loss, mask_for_metrics)
            num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

            if torch.isfinite(loss):
                total_loss_scaled += loss.item() * num_sequences_in_this_super_batch
                # Non-finite batches are left out of the denominator as well,
                # otherwise they would silently pull the average towards zero.
                num_sequences_with_finite_loss += num_sequences_in_this_super_batch

            all_batch_predictions_scaled.append(predictions_batch_scaled.cpu())
            all_batch_targets_scaled.append(targets_batch_scaled_for_loss.cpu())
            all_batch_masks_for_metrics.append(mask_for_metrics.cpu())

    eval_duration = time.time() - eval_start_time

    if num_sequences_with_finite_loss > 0:
        avg_loss_scaled = total_loss_scaled / num_sequences_with_finite_loss
    elif all_batch_predictions_scaled:
        # Batches were evaluated but every loss was NaN/inf.
        avg_loss_scaled = float('nan')
    else:
        avg_loss_scaled = 0.0

    if not all_batch_predictions_scaled:
        empty_metrics = {t: {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0} for t in range(model.T_pred_horizon)}
        return avg_loss_scaled, empty_metrics, eval_duration

    final_predictions_scaled = torch.cat(all_batch_predictions_scaled, dim=0)
    final_targets_scaled = torch.cat(all_batch_targets_scaled, dim=0)
    final_masks_for_metrics = torch.cat(all_batch_masks_for_metrics, dim=0)

    hourly_metrics_original_scale = calculate_hourly_metrics(final_predictions_scaled, final_targets_scaled,
                                                             final_masks_for_metrics, target_mean.cpu(), target_std.cpu())
    return avg_loss_scaled, hourly_metrics_original_scale, eval_duration

# ===========================================================
# 5. 主训练流程 (集成y归一化, 使用RGCN, 添加报告)
# ===========================================================

def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Average the hourly metrics over the prediction horizon.

    For each of ``r2``, ``mse``, ``mae`` and ``rmse`` the values of hours
    ``0 .. T_pred_horizon - 1`` that are present in ``hourly_metrics_dict``
    and are not NaN are collected.

    Returns:
        dict with ``avg_<metric>`` and ``std_<metric>`` (``np.mean`` /
        ``np.std`` of the collected values, or NaN when nothing was collected).
    """
    summary = {}
    for name in ('r2', 'mse', 'mae', 'rmse'):
        collected = []
        for hour in range(T_pred_horizon):
            if hour not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[hour][name]
            if np.isnan(value):
                continue
            collected.append(value)

        if collected:
            average, spread = np.mean(collected), np.std(collected)
        else:
            average, spread = np.nan, np.nan
        summary[f'avg_{name}'] = average
        summary[f'std_{name}'] = spread
    return summary


def _rgcn_gru_model_kwargs(config, static_node_in_dim, global_env_in_dim, time_in_dim,
                           num_relations, T_pred_horizon):
    """Collect the RGCNGRUModelWithHourlyHeads constructor arguments from ``config``."""
    return dict(
        static_node_in_dim=static_node_in_dim,
        global_env_in_dim=global_env_in_dim,
        time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16),
        time_emb_dim=config.get('time_emb_dim', 8),
        rgcn_hidden_dim=config.get('gcn_hidden_dim', 128),
        rgcn_output_dim=config.get('gcn_output_dim', 128),
        num_relations=num_relations,
        gru_hidden_dim=config.get('gru_hidden_dim', 128),
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_gru_layers=config.get('num_gru_layers', 1),
        T_pred_horizon=T_pred_horizon,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_rgcn=config.get('dropout_rate_gcn', 0.3),
        dropout_rate_gru=config.get('dropout_rate_gru', 0.2),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2),
    )


def _profile_rgcn_gru_component_gmacs(model, device, T_pred_horizon):
    """Estimate GMACs of the model's sub-modules on dummy inputs; failures are recorded as "Error"."""
    print("\nCalculating MACs for model components (approximate FLOPS):")
    component_gmacs = {}
    dummy_nodes_component = 2500  # Representative number of nodes for a single graph component
    dummy_edges_component = 60000 # Representative number of edges
    dummy_batch_global_comp = 1 # For components that process per-graph or per-batch features

    model.eval() # Profile without dropout being active

    # 1. RGCNModule
    try:
        rgcn_module_to_profile = model.rgcn_module_for_h0 # Or _for_sequence
        dummy_x_rgcn = torch.randn(dummy_nodes_component, rgcn_module_to_profile.rgcn_input_dim, device=device)
        dummy_ei_rgcn = torch.randint(0, dummy_nodes_component, (2, dummy_edges_component), device=device)
        dummy_ea_rgcn = torch.randn(dummy_edges_component, 5, device=device) # Assuming 5 edge features
        # Column 4 is filled with relation ids in [0, num_relations).
        dummy_ea_rgcn[:, 4] = torch.randint(0, model.num_relations, (dummy_edges_component,), device=device).float()
        macs_rgcn = torchprofile.profile_macs(rgcn_module_to_profile, args=(dummy_x_rgcn, dummy_ei_rgcn, dummy_ea_rgcn))
        component_gmacs['rgcn_module'] = macs_rgcn / 1e9
        print(f"  RGCNModule GMACs: {macs_rgcn / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling RGCNModule: {e}")
        component_gmacs['rgcn_module'] = "Error"

    # 2. GRU Layer (closed-form estimate instead of tracing)
    print(f"  Manually Calculating MACs for GRU Layer:")
    try:
        gru_layer = model.gru
        N_nodes = dummy_nodes_component # Number of nodes (batch items for GRU)
        L_seq = T_pred_horizon          # Sequence length
        H_in = gru_layer.input_size
        H_hidden = gru_layer.hidden_size
        num_layers = gru_layer.num_layers

        # Per node and step: 3 gates, each with an input and a hidden matmul.
        # Layers after the first receive H_hidden-sized inputs.
        macs_first_layer = 3 * (H_in * H_hidden + H_hidden * H_hidden)
        macs_deeper_layer = 3 * (H_hidden * H_hidden + H_hidden * H_hidden)
        macs_gru_manual = N_nodes * L_seq * (macs_first_layer + (num_layers - 1) * macs_deeper_layer)

        gmacs_gru_manual = macs_gru_manual / 1e9
        component_gmacs['gru_layer'] = gmacs_gru_manual
        component_gmacs['gru_layer_profiling_notes'] = "Manually calculated based on formula."
        print(f"  GRU Parameters: input_size={H_in}, hidden_size={H_hidden}, num_layers={num_layers}")
        print(f"  Used for calculation: N_nodes={N_nodes}, L_seq={L_seq}")
        print(f"  GRU Layer GMACs (Manual): {gmacs_gru_manual:.4f} (for sequence length {L_seq})")
    except Exception as e:
        print(f"  Error manually calculating GRU Layer MACs: {e}")
        component_gmacs['gru_layer'] = "Error"
        component_gmacs['gru_layer_profiling_notes'] = f"Error during manual calculation: {str(e)}"

    def _profile_mlp(report_key, success_label, error_label, build_module_and_input):
        # The builder runs inside the try so a missing model attribute is
        # reported as "Error" just like a profiling failure.
        try:
            module_to_profile, dummy_input = build_module_and_input()
            macs = torchprofile.profile_macs(module_to_profile, args=(dummy_input,))
            component_gmacs[report_key] = macs / 1e9
            print(f"  {success_label} GMACs: {macs / 1e9:.4f}")
        except Exception as e:
            print(f"  Error profiling {error_label}: {e}")
            component_gmacs[report_key] = "Error"

    # 3. Fusion MLP (MLPEncoder)
    _profile_mlp('fusion_mlp', "Fusion MLP", "Fusion MLP",
                 lambda: (model.fusion_mlp,
                          torch.randn(dummy_nodes_component, model.fusion_mlp_input_dim, device=device)))
    # 4. Prediction Head (one MLP from ModuleList)
    _profile_mlp('prediction_head_mlp', "Prediction Head MLP (single hour)", "Prediction Head",
                 lambda: (model.hourly_prediction_heads[0],
                          torch.randn(dummy_nodes_component, model.gru_hidden_dim, device=device)))
    # 5. Global Environment Encoder (MLPEncoder)
    _profile_mlp('global_env_encoder_mlp', "Global Env Encoder MLP", "Global Env Encoder",
                 lambda: (model.global_env_encoder,
                          torch.randn(dummy_batch_global_comp, model.global_env_in_dim, device=device)))
    # 6. Time Encoder (MLPEncoder)
    _profile_mlp('time_encoder_mlp', "Time Encoder MLP", "Time Encoder",
                 lambda: (model.time_encoder,
                          torch.randn(dummy_batch_global_comp, model.time_in_dim, device=device)))
    # 7. H0 from RGCN Encoder (MLPEncoder)
    _profile_mlp('h0_from_rgcn_encoder_mlp', "H0 from RGCN Encoder MLP", "H0 from RGCN Encoder",
                 lambda: (model.h0_from_rgcn_encoder,
                          torch.randn(dummy_nodes_component, model.rgcn_output_dim, device=device)))

    model.train() # Return to train mode
    return component_gmacs


def main_training_rgcn_gru_hourly_heads(
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor
):
    """Train, select and evaluate the RGCN-GRU model and write a JSON report.

    Steps, in order: seed the RNGs, keep only sequences of length
    ``T_pred_horizon + 1``, split them randomly into train/val/test, fit and
    save the node-feature and target scalers on the training split, build the
    model, estimate component GMACs, train with ReduceLROnPlateau and early
    stopping on the validation loss, reload the best checkpoint, evaluate it
    on all three splits and dump ``training_report_seed<seed>.json``.

    Args:
        all_sequences_data: List of sequences; each sequence is a list of
            per-timestep graph objects providing ``x``, ``y``,
            ``building_mask`` and ``graph_global_env_features``.
        config: Hyper-parameters. ``results_dir`` is required; every other key
            is read with a default.
        time_features_for_dataset: 2-D tensor of time features; its second
            dimension is used as the time-feature input size.

    Returns:
        ``(model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std)``.

    Raises:
        ValueError: If no sequence has the expected length or the training
            split carries no node features.
        KeyError: If ``config`` has no ``results_dir``.
    """
    train_start_time = time.time()
    report_data = {'config': config}

    # Scalers, checkpoint and report are all written here; make sure it exists.
    results_dir = Path(config['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    # ---- Keep only sequences with one input step plus T_PRED_HORIZON targets ----
    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    expected_input_len = T_PRED_HORIZON + 1
    valid_sequences_data = [seq for seq in all_sequences_data if isinstance(seq, list) and len(seq) == expected_input_len]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # ---- Random train / val / test split ----
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_dataset = [all_sequences_data[i] for i in indices[:train_size]]
    val_dataset = [all_sequences_data[i] for i in indices[train_size : train_size + val_size]]
    test_dataset = [all_sequences_data[i] for i in indices[train_size + val_size :]]
    report_data['dataset_split'] = {'total_sequences': num_total_sequences, 'train_size': len(train_dataset), 'val_size': len(val_dataset), 'test_size': len(test_dataset)}

    # ---- Fit scalers on the training split only ----
    all_train_node_features_list = []
    all_train_target_values_list_for_scaling = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                all_train_node_features_list.append(graph_data.x)
            # Step 0 is the input step; only later steps carry targets.
            if i_step > 0 and hasattr(graph_data, 'y') and graph_data.y is not None:
                y_original = graph_data.y.squeeze()
                valid_target_indices = ~graph_data.building_mask & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0:
                    all_train_target_values_list_for_scaling.append(y_original[valid_target_indices])
    if not all_train_node_features_list:
        raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")

    all_train_node_features_tensor = torch.cat(all_train_node_features_list, dim=0)
    node_feat_mean = torch.mean(all_train_node_features_tensor, dim=0)
    node_feat_std = torch.std(all_train_node_features_tensor, dim=0)
    # Constant features, and the NaN torch.std yields for a single sample, fall back to 1.
    node_feat_std[~torch.isfinite(node_feat_std) | (node_feat_std < 1e-8)] = 1.0
    scaler_path_x = results_dir / "node_feature_scaler_rgcn_gru.pth"
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x)
    print(f"节点特征x scaler已保存到: {scaler_path_x}")

    if not all_train_target_values_list_for_scaling:
        target_mean = torch.tensor(0.0)
        target_std = torch.tensor(1.0)
    else:
        all_train_target_values_tensor = torch.cat(all_train_target_values_list_for_scaling, dim=0)
        target_mean = torch.mean(all_train_target_values_tensor.float())
        target_std = torch.std(all_train_target_values_tensor.float())
        if not torch.isfinite(target_std) or target_std < 1e-8:
            target_std = torch.tensor(1.0)
    target_scaler_path = results_dir / "target_scaler_rgcn_gru.pth"
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path)
    print(f"目标值y scaler已保存到: {target_scaler_path}")

    # ---- DataLoaders ----
    batch_size = config.get('batch_size', 8)
    num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory_flag)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)

    # ---- Model: input sizes are read off the first graph of the first sequence ----
    sample_graph_7am_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_7am_for_dims.x.shape[1]
    sample_global_env = sample_graph_7am_for_dims.graph_global_env_features
    global_env_in_dim = sample_global_env.shape[0] if sample_global_env.ndim == 1 else sample_global_env.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]
    num_relations = config.get('num_relations', 5)
    # Built once so the training model and the evaluation copy cannot drift apart.
    model_kwargs = _rgcn_gru_model_kwargs(config, static_node_in_dim, global_env_in_dim, time_in_dim,
                                          num_relations, T_PRED_HORIZON)
    model = RGCNGRUModelWithHourlyHeads(**model_kwargs).to(device)
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型总参数量: {total_params:,}")
    report_data['model_total_parameters'] = total_params

    # ---- Component GMACs (approximate FLOPS) ----
    report_data['component_gmacs'] = _profile_rgcn_gru_component_gmacs(model, device, T_PRED_HORIZON)

    # ---- Optimizer and scheduler ----
    optimizer = torch.optim.Adam(model.parameters(), lr=config.get('lr', 0.001), weight_decay=config.get('weight_decay', 1e-5))
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 20), verbose=True)

    # ---- Training loop with early stopping on the validation loss ----
    best_val_loss_scaled = float('inf')
    best_val_hourly_metrics_original = None
    best_epoch = 0
    patience_counter = 0
    max_epochs = config.get('max_epochs', 300)
    early_stopping_patience = config.get('early_stopping_patience', 45)
    model_save_path = results_dir / f"best_rgcngru_hourly_heads_model_seed{seed}.pth"
    timeline_time_features_on_device = time_features_for_dataset.to(device)
    target_mean_cpu = target_mean.cpu()
    target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device)
    target_std_on_device = target_std.to(device)
    epoch_times = []

    def _evaluate(model_to_evaluate, loader, epoch_type):
        return evaluate_epoch(model_to_evaluate, loader, device, timeline_time_features_on_device,
                              node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu,
                              epoch_type=epoch_type)

    # Tracked separately from the loop variable, which stays unbound when max_epochs < 1.
    num_epochs_trained = 0
    for epoch in range(1, max_epochs + 1):
        num_epochs_trained = epoch
        train_loss_scaled, epoch_duration = train_epoch(model, train_loader, optimizer, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_on_device, target_std_on_device)
        epoch_times.append(epoch_duration)
        val_loss_scaled, val_hourly_metrics_original, _ = _evaluate(model, val_loader, "Validation")
        scheduler.step(val_loss_scaled)
        print(f"Epoch {epoch:03d} | Train Scaled MSE: {train_loss_scaled:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Epoch Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")
        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled
            best_val_hourly_metrics_original = val_hourly_metrics_original
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"                     ---> Best model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement.")
            break

    report_data['total_training_time_seconds'] = time.time() - train_start_time
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    report_data['num_epochs_trained'] = num_epochs_trained
    report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    # ---- Reload the best checkpoint into a fresh model ----
    if best_epoch > 0:
        model_for_eval = RGCNGRUModelWithHourlyHeads(**model_kwargs).to(device)
        try:
            model_for_eval.load_state_dict(torch.load(model_save_path, map_location=device))
        except Exception as e:
            print(f"无法加载最佳模型 ({e})，将使用训练循环结束时的模型。")
            model_for_eval = model
    else:
        # Nothing was saved in this run; a file at model_save_path could only
        # be a stale checkpoint from an earlier run, so it is not loaded.
        print("本次训练未保存任何最佳模型，将使用训练循环结束时的模型。")
        model_for_eval = model

    # ---- Best model on the training split ----
    print("\n评估最佳模型在训练集上...")
    _, best_model_train_hourly_metrics, train_eval_duration = _evaluate(model_for_eval, train_loader, "Best Model on Train")
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics
    report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration
    _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    # ---- Best model on the validation split (metrics kept from the best epoch) ----
    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original:
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    else: # Fallback if training didn't run/save
        _, reeval_val_metrics, val_eval_duration = _evaluate(model_for_eval, val_loader, "Best Model on Val (Re-eval)")
        report_data['best_model_validation_set_metrics_hourly'] = reeval_val_metrics
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(reeval_val_metrics, T_PRED_HORIZON)
        report_data['best_model_validation_set_eval_time_seconds'] = val_eval_duration
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)

    # ---- Best model on the test split ----
    print("\n评估最佳模型在测试集上...")
    test_loss_scaled, test_hourly_metrics_original, test_inference_duration = _evaluate(model_for_eval, test_loader, "Test")
    report_data['test_set_inference_time_seconds'] = test_inference_duration
    report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original
    report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (RGCN-GRU) " + "="*20)
    print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}")
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    # ---- Save the report; a failure here must not discard the trained model ----
    report_file_path = results_dir / f"training_report_seed{seed}.json"
    try:
        class NpEncoder(json.JSONEncoder):
            """JSON encoder that converts NumPy scalars/arrays, tensors and paths."""
            def default(self, obj):
                if isinstance(obj, np.integer): return int(obj)
                if isinstance(obj, np.floating): return float(obj)
                if isinstance(obj, np.ndarray): return obj.tolist()
                if isinstance(obj, torch.Tensor): return obj.tolist()
                if isinstance(obj, Path): return str(obj)
                return super(NpEncoder, self).default(obj)
        with open(report_file_path, 'w') as f:
            json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e:
        print(f"保存训练报告失败: {e}")

    return model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std


def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by the horizon-wide averages.

    Args:
        set_name: Label of the data split shown in the headings.
        hourly_metrics: dict mapping hour index to a metrics dict, or None,
            in which case only a "not available" line is printed.
        T_pred_horizon: Number of hour rows to print; hours missing from
            ``hourly_metrics`` are shown as NaN with count 0.
        indent: Prefix put in front of every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")

    missing_hour = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0}
    for hour_idx in range(T_pred_horizon):
        if hour_idx == 0:
            # Column header goes out once, right before the first row.
            print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
        row = hourly_metrics.get(hour_idx, missing_hour)
        r2 = row.get('r2', np.nan)
        mse = row.get('mse', np.nan)
        mae = row.get('mae', np.nan)
        rmse = row.get('rmse', np.nan)
        count = row.get('count', 0)
        print(f"{indent}  {hour_idx:02d}   | {r2:13.4f} | {mse:14.4f} | {mae:14.4f} | {rmse:15.4f} | {count:7d}")

    # Horizon-wide averages for this split.
    aggregated = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    summary_lines = (
        ("Aggregated Avg R2   ", 'r2'),
        ("Aggregated Avg MSE  ", 'mse'),
        ("Aggregated Avg MAE  ", 'mae'),
        ("Aggregated Avg RMSE ", 'rmse'),
    )
    for label, metric in summary_lines:
        avg_value = aggregated.get(f'avg_{metric}', np.nan)
        std_value = aggregated.get(f'std_{metric}', np.nan)
        print(f"{indent}  {label}: {avg_value:.4f} (Std: {std_value:.4f})")


# ===========================================================
# 6. 主执行块
# ===========================================================
def _graph_step_is_usable(graph_step_data, step_idx):
    """Return True when one timestep graph carries every field the pipeline reads.

    Steps after the first (``step_idx > 0``) must also provide the target
    ``y`` and the ``building_mask`` used to select the nodes that are scored.
    """
    if not isinstance(graph_step_data, Data):
        return False
    for required_attr in ('x', 'edge_index', 'edge_attr', 'graph_global_env_features'):
        if getattr(graph_step_data, required_attr, None) is None:
            return False
    edge_attr = graph_step_data.edge_attr
    # At least 5 edge-feature columns are required; a 1-D edge_attr is rejected
    # here instead of raising on shape[1] and aborting the whole load.
    if edge_attr.ndim < 2 or edge_attr.shape[1] < 5:
        return False
    if step_idx > 0:
        if getattr(graph_step_data, 'y', None) is None:
            return False
        # Loss masking and target scaling both dereference building_mask.
        if getattr(graph_step_data, 'building_mask', None) is None:
            return False
    return True


if __name__ == "__main__":
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ---- Paths ----
    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    if not DRIVE_BASE_PATH.exists():
        DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the data-loader cell at the top of the notebook reads the
    # same file name from "Result/Sequential_13Hour_Data" - confirm which
    # directory is the intended one.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    RESULTS_SUBDIR = Path("Result/Final_RGCNGRU1")
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    os.makedirs(RESULTS_SAVE_DIR, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME

    # ---- Experiment configuration ----
    DATA_YEAR = 2023
    DATA_MONTH = 5
    DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8
    PREDICTION_HORIZON = 12
    training_config = {
        'seed': 42, 'batch_size': 8, 'lr': 0.001, 'weight_decay': 1e-5, 'max_epochs': 1000, 'scheduler_patience': 20,
        'early_stopping_patience': 45, 'T_pred_horizon': PREDICTION_HORIZON, 'results_dir': str(RESULTS_SAVE_DIR),
        'global_env_emb_dim': 16, 'time_emb_dim': 8,
        'gcn_hidden_dim': 128, 'gcn_output_dim': 128, 'num_relations': 5, 'gru_hidden_dim': 128,
        'num_gru_layers': 1, 'mlp_prediction_hidden_dim': 64, 'fusion_mlp_output_dim': 128, 'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2, 'dropout_rate_encoders': 0.1, 'dropout_rate_gcn': 0.3, 'dropout_rate_gru': 0.2,
        'dropout_rate_pred_head': 0.2, 'use_amp': False, 'enable_profiler': False, 'num_workers': 0,
        'pin_memory': False, 'train_split_ratio': 0.7, 'val_split_ratio': 0.2, 'h0_from_first_step': True,
    }

    # ---- Load and validate the sequences ----
    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists():
            raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        # SECURITY: pickle executes arbitrary code on load; only open files you produced yourself.
        with open(DATA_PATH, "rb") as f:
            all_graph_sequences_loaded = pickle.load(f)
        if (not all_graph_sequences_loaded
                or not isinstance(all_graph_sequences_loaded, list)
                or not all_graph_sequences_loaded[0]
                or not isinstance(all_graph_sequences_loaded[0], list)):
            raise ValueError("加载的数据格式不正确。")

        expected_len_per_sequence = training_config['T_pred_horizon'] + 1
        processed_sequences = []
        for seq in all_graph_sequences_loaded:
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence:
                continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                if not _graph_step_is_usable(graph_step_data, step_idx):
                    valid_seq = False
                    break
                # Targets are normalised to a column vector (N, 1).
                if step_idx > 0 and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq:
                processed_sequences.append(seq)

        if not processed_sequences:
            raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    except Exception as e:
        print(f"加载或验证数据时发生错误: {e}")
        all_graph_sequences = None

    # ---- Train ----
    if all_graph_sequences:
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_dataset_timeline = generate_time_features_for_sequence(base_datetime_for_timeline, training_config['T_pred_horizon'])
        trained_model, final_node_mean, final_node_std, final_target_mean, final_target_std = main_training_rgcn_gru_hourly_heads(all_graph_sequences, training_config, time_features_for_dataset_timeline)
        print("模型训练和评估完成!")
    else:
        print("由于数据加载失败或数据为空，训练流程未启动。")

##RGCN+Transformer

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta
from pathlib import Path
import time
import json
import torchprofile

# Free memory before the model definitions below: run Python's garbage
# collector, then ask PyTorch to release its cached CUDA allocator blocks and
# any CUDA IPC memory that is no longer referenced.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


# ===========================================================
# 1. 特征生成 & 辅助模块
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps, hour_period=23.0):
    """Build cyclic hour-of-day and day-of-year features for consecutive hours.

    Args:
        base_dt_obj: ``datetime`` of the first step; step ``i`` is this value
            plus ``i`` hours.
        num_steps: Number of hourly steps to generate.
        hour_period: Divisor applied to the hour before the sin/cos transform.
            The default 23.0 keeps the values this function has always
            produced.  NOTE(review): with 23.0, hour 0 and hour 23 receive the
            same encoding; 24.0 gives every hour a distinct angle, but
            switching the default changes the features fed to the model.

    Returns:
        A float32 tensor of shape ``(num_steps, 4)`` whose columns are
        ``[hour_sin, hour_cos, doy_sin, doy_cos]``.  When ``num_steps`` is zero
        or negative an empty ``(0, 4)`` tensor is returned instead of letting
        ``torch.stack`` fail on an empty list.
    """
    two_pi = 2 * math.pi
    feature_rows = []
    for step in range(num_steps):
        current_dt = base_dt_obj + timedelta(hours=step)
        year = current_dt.year
        is_leap_year = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap_year else 365.0

        hour_angle = two_pi * (current_dt.hour / hour_period)
        doy_angle = two_pi * (current_dt.timetuple().tm_yday / days_in_year)
        feature_rows.append([
            math.sin(hour_angle),
            math.cos(hour_angle),
            math.sin(doy_angle),
            math.cos(doy_angle),
        ])

    if not feature_rows:
        return torch.empty((0, 4), dtype=torch.float32)
    return torch.tensor(feature_rows, dtype=torch.float32)


class MLPEncoder(nn.Module):
    """Small MLP: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    When ``hid_dim`` is not given, the hidden width is the larger of
    ``min(in_dim, out_dim)`` and ``(in_dim + out_dim) // 2``; if that works out
    to zero it falls back to ``out_dim``, then ``in_dim``, then 1.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        if hid_dim is None:
            hid_dim = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            # Degenerate widths: try the output size first, then the input size.
            for fallback_width in (out_dim, in_dim):
                if hid_dim == 0 and fallback_width > 0:
                    hid_dim = fallback_width
            if hid_dim == 0:
                hid_dim = 1

        # Layers are created in forward order so parameter initialisation
        # consumes random numbers in the same sequence as the layer indices.
        layers = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.LayerNorm(hid_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hid_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        encoded = self.mlp(x)
        return encoded

class RGCNModule(nn.Module):
    """Three stacked ``RGCNConv`` layers, each followed by BatchNorm1d and PReLU.

    Dropout is applied after the first and second layer.  The relation id of
    every edge is read from column 4 of ``edge_attr``.

    Args:
        rgcn_input_dim: Width of the incoming node features.
        rgcn_hidden_dim: Width produced by the first two convolutions.
        rgcn_output_dim: Width produced by the last convolution.
        num_relations: Number of edge relation types passed to ``RGCNConv``.
        dropout_rate: Dropout probability used between the layers.
    """

    def __init__(self, rgcn_input_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate=0.5):
        super().__init__()
        self.rgcn_input_dim = rgcn_input_dim
        self.rgcn_hidden_dim = rgcn_hidden_dim
        self.rgcn_output_dim = rgcn_output_dim
        self.num_relations = num_relations

        self.conv1 = RGCNConv(rgcn_input_dim, rgcn_hidden_dim, num_relations)
        self.bn1 = nn.BatchNorm1d(rgcn_hidden_dim)
        self.prelu1 = nn.PReLU(rgcn_hidden_dim)

        self.conv2 = RGCNConv(rgcn_hidden_dim, rgcn_hidden_dim, num_relations)
        self.bn2 = nn.BatchNorm1d(rgcn_hidden_dim)
        self.prelu2 = nn.PReLU(rgcn_hidden_dim)

        self.conv3 = RGCNConv(rgcn_hidden_dim, rgcn_output_dim, num_relations)
        self.bn3 = nn.BatchNorm1d(rgcn_output_dim)
        self.prelu3 = nn.PReLU(rgcn_output_dim)

        self.dropout = nn.Dropout(dropout_rate)

    def _batch_norm(self, norm_layer, x):
        """Apply ``norm_layer`` unless that would fail on a single training row.

        BatchNorm1d cannot compute batch statistics from one row while
        training, so that case passes through unchanged.  In eval mode the
        running statistics are used, so a single row is normalized like any
        other input.
        """
        if x.size(0) > 1 or not self.training:
            return norm_layer(x)
        return x

    def forward(self, x, edge_index, edge_attr):
        """Return node embeddings of width ``rgcn_output_dim``.

        Args:
            x: Node feature matrix with one row per node.
            edge_index: Edge index tensor forwarded to ``RGCNConv``.
            edge_attr: 2-D edge attribute tensor with at least 5 columns;
                column 4 holds the relation type.

        Raises:
            ValueError: If ``edge_attr`` is missing, not 2-D, or has fewer
                than 5 columns (only checked for non-empty inputs).
        """
        if x.size(0) == 0:
            # An empty graph must still yield the output width so callers can
            # concatenate the result with other per-node features.
            return x.new_zeros((0, self.rgcn_output_dim))
        if edge_attr is None or edge_attr.dim() < 2 or edge_attr.shape[1] < 5:
            raise ValueError("RGCNModule: edge_attr is missing or has insufficient columns for edge_type.")
        edge_type = edge_attr[:, 4].long()

        x = self.conv1(x, edge_index, edge_type=edge_type)
        x = self._batch_norm(self.bn1, x)
        x = self.prelu1(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index, edge_type=edge_type)
        x = self._batch_norm(self.bn2, x)
        x = self.prelu2(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index, edge_type=edge_type)
        x = self._batch_norm(self.bn3, x)
        x = self.prelu3(x)
        return x

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a sequence-first tensor.

    A ``(max_len, d_model)`` table is precomputed and stored in the ``pe``
    buffer: even columns hold sines, odd columns hold cosines of the same
    angles.  For an odd ``d_model`` the last sine column has no cosine partner.
    """

    def __init__(self, d_model, dropout=0.1, max_len=50):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq  # (max_len, ceil(d_model / 2))

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # There are d_model // 2 odd columns; this drops the final angle
        # column exactly when d_model is odd.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', table)  # Shape: (max_len, d_model)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        ``x`` is indexed as ``[sequence_length, batch_size, feature_dim]``; the
        position table is sliced to the sequence length and broadcast over the
        batch dimension.
        """
        seq_len = x.size(0)
        position_codes = self.pe[:seq_len, :].unsqueeze(1)
        return self.dropout(x + position_codes)

# ===========================================================
# 2. RGCN-Transformer 模型定义
# ===========================================================
class RGCNTransformerModelWithHourlyHeads(nn.Module):
    """Encoder-decoder Transformer over per-node features produced by RGCN stacks.

    For every timestep the node features are normalized with the
    ``node_feat_mean`` / ``node_feat_std`` buffers, passed through an
    ``RGCNModule``, concatenated with an embedding of the graph-level
    environment features and an embedding of that step's time features, and
    projected to ``transformer_d_model`` by ``fusion_mlp``.  The first
    ``num_encoder_obs_steps`` timesteps feed the Transformer encoder, the next
    ``T_pred_horizon`` timesteps feed the decoder under a causal mask, and each
    horizon step has its own MLP head producing one value per node.

    Args:
        static_node_in_dim: Width of the raw node features.
        global_env_in_dim: Width of the raw graph-level environment features.
        time_in_dim: Width of the raw time features.
        global_env_emb_dim: Output width of the environment encoder.
        time_emb_dim: Output width of the time encoder.
        rgcn_hidden_dim: Hidden width inside each ``RGCNModule``.
        rgcn_output_dim: Output width of each ``RGCNModule``.
        num_relations: Number of edge relation types.
        transformer_d_model: Model width of the Transformer.
        transformer_nhead: Number of attention heads.
        transformer_num_encoder_layers: Number of encoder layers.
        transformer_num_decoder_layers: Number of decoder layers.
        transformer_dim_feedforward: Feed-forward width inside each layer.
        transformer_dropout_rate: Dropout used by the Transformer layers and
            the positional encoding.
        fusion_mlp_output_dim: Only compared against ``transformer_d_model``;
            a mismatch prints a warning and ``transformer_d_model`` is used.
        fusion_mlp_hidden_dim: Hidden width of the fusion MLP (``None`` lets
            ``MLPEncoder`` choose).
        dropout_rate_fusion_mlp: Dropout inside the fusion MLP.
        T_pred_horizon: Number of predicted steps (and prediction heads).
        dropout_rate_encoders: Dropout inside the environment/time encoders.
        dropout_rate_rgcn: Dropout inside the RGCN modules.
        mlp_prediction_hidden_dim: Hidden width of each prediction head.
        dropout_rate_pred_head: Dropout inside each prediction head.
        num_encoder_obs_steps: Number of leading timesteps fed to the encoder.

    Raises:
        ValueError: If ``T_pred_horizon`` or ``num_encoder_obs_steps`` is
            smaller than 1 (the forward pass cannot stack an empty sequence).
    """

    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 rgcn_hidden_dim,
                 rgcn_output_dim,
                 num_relations,
                 transformer_d_model,
                 transformer_nhead,
                 transformer_num_encoder_layers,
                 transformer_num_decoder_layers,
                 transformer_dim_feedforward,
                 transformer_dropout_rate,
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_rgcn=0.3,
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2,
                 num_encoder_obs_steps=1
                ):
        super().__init__()
        if T_pred_horizon < 1:
            raise ValueError(f"T_pred_horizon must be >= 1, got {T_pred_horizon}.")
        if num_encoder_obs_steps < 1:
            raise ValueError(f"num_encoder_obs_steps must be >= 1, got {num_encoder_obs_steps}.")

        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.num_relations = num_relations
        self.rgcn_output_dim = rgcn_output_dim
        self.num_encoder_obs_steps = num_encoder_obs_steps

        concatenated_feature_dim = rgcn_output_dim + global_env_emb_dim + time_emb_dim
        # The fusion MLP always projects to transformer_d_model; a conflicting
        # fusion_mlp_output_dim is reported and otherwise ignored.
        if fusion_mlp_output_dim is not None and fusion_mlp_output_dim != transformer_d_model:
            print(f"Warning: fusion_mlp_output_dim ({fusion_mlp_output_dim}) "
                  f"does not match transformer_d_model ({transformer_d_model}). "
                  f"Setting fusion_mlp_output_dim to transformer_d_model.")

        self.d_model = transformer_d_model
        self.fusion_mlp_input_dim = concatenated_feature_dim
        self.transformer_input_dim = self.d_model

        # Sub-modules are created in a fixed order so parameter initialisation
        # and state_dict key order stay stable.
        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders)

        self.rgcn_module_for_encoder_inputs = RGCNModule(static_node_in_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate_rgcn)
        self.rgcn_module_for_decoder_inputs = RGCNModule(static_node_in_dim, rgcn_hidden_dim, rgcn_output_dim, num_relations, dropout_rate_rgcn)

        self.fusion_mlp = MLPEncoder(
            in_dim=concatenated_feature_dim,
            out_dim=self.d_model,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        # A few spare rows beyond the longest sequence the model stacks.
        self.pos_encoder = PositionalEncoding(self.d_model, transformer_dropout_rate, max_len=T_pred_horizon + num_encoder_obs_steps + 5)

        encoder_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=transformer_nhead,
                                                 dim_feedforward=transformer_dim_feedforward,
                                                 dropout=transformer_dropout_rate, batch_first=False)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model=self.d_model, nhead=transformer_nhead,
                                                 dim_feedforward=transformer_dim_feedforward,
                                                 dropout=transformer_dropout_rate, batch_first=False)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=transformer_num_decoder_layers)

        # One independent head per predicted step.
        self.hourly_prediction_heads = nn.ModuleList()
        for _ in range(T_pred_horizon):
            self.hourly_prediction_heads.append(
                nn.Sequential(
                    nn.Linear(self.d_model, mlp_prediction_hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate_pred_head),
                    nn.Linear(mlp_prediction_hidden_dim, 1)
                )
            )

        # Normalisation statistics; the training/evaluation loops overwrite
        # these with values computed from the data.
        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))

    def _generate_square_subsequent_mask(self, sz, device):
        """Return a ``(sz, sz)`` float mask: 0.0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _process_one_timestep_input(self, pyg_batch_timestep, rgcn_module, time_feature_for_step, device, step_type="generic"):
        """Turn one batched timestep into per-node features of width ``d_model``.

        Args:
            pyg_batch_timestep: Batched graph providing ``x``, ``edge_index``,
                ``edge_attr``, ``graph_global_env_features``, ``batch``,
                ``num_graphs`` and ``num_nodes``.
            rgcn_module: The ``RGCNModule`` to run on this step.
            time_feature_for_step: 1-D raw time feature vector for this step.
            device: Device for the time features and any fallback tensor.
            step_type: Label included in the fallback warning.

        Returns:
            Tensor with one row per node in the batch and ``d_model`` columns.
        """
        normalized_x = (pyg_batch_timestep.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)
        rgcn_output_nodes = rgcn_module(
            normalized_x,
            pyg_batch_timestep.edge_index,
            pyg_batch_timestep.edge_attr
        )

        global_env_feat_unencoded = pyg_batch_timestep.graph_global_env_features
        expected_num_graphs = pyg_batch_timestep.num_graphs
        expected_global_dim = self.global_env_encoder.mlp[0].in_features  # in_dim of the first Linear

        if not (global_env_feat_unencoded.shape == (expected_num_graphs, expected_global_dim)):
            if global_env_feat_unencoded.numel() == expected_num_graphs * expected_global_dim:
                # Same number of values, different layout (e.g. flattened by batching).
                global_env_feat_unencoded = global_env_feat_unencoded.reshape(expected_num_graphs, expected_global_dim)
            else:
                # Best-effort fallback kept from the original behavior, but no
                # longer silent: the features are replaced with zeros.
                import warnings
                warnings.warn(
                    f"({step_type}) graph_global_env_features has shape "
                    f"{tuple(global_env_feat_unencoded.shape)}, expected "
                    f"({expected_num_graphs}, {expected_global_dim}); using zeros instead.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                global_env_feat_unencoded = torch.zeros(expected_num_graphs, expected_global_dim, device=device)

        global_env_emb = self.global_env_encoder(global_env_feat_unencoded)
        global_env_emb_expanded = global_env_emb[pyg_batch_timestep.batch]  # one row per node

        current_emb_time_feat = self.time_encoder(time_feature_for_step.to(device))
        num_nodes_in_batch = pyg_batch_timestep.num_nodes
        time_emb_expanded_to_nodes = current_emb_time_feat.unsqueeze(0).expand(num_nodes_in_batch, -1)

        concatenated_features = torch.cat([rgcn_output_nodes, global_env_emb_expanded, time_emb_expanded_to_nodes], dim=-1)
        fused_features = self.fusion_mlp(concatenated_features)
        return fused_features

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict one scaled value per node for every horizon step.

        Args:
            list_of_batched_timesteps: Sequence of batched graphs; the first
                ``num_encoder_obs_steps`` entries feed the encoder and the next
                ``T_pred_horizon`` entries feed the decoder.  All timesteps
                must contain the same number of nodes, because their per-node
                features are stacked along a new sequence dimension.
            timeline_time_features: Tensor whose row ``t`` is the raw time
                feature vector of horizon step ``t``.
            device: Device on which the computation runs.

        Returns:
            Tensor of shape ``(num_nodes_in_batch, T_pred_horizon)``.

        Raises:
            IndexError: If fewer timesteps or time-feature rows are supplied
                than the model needs.

        NOTE(review): encoder step ``i`` and decoder step ``i`` both read
        ``timeline_time_features[i]``, so when the timeline only covers the
        prediction horizon the observed steps reuse the time encoding of the
        first predicted steps.  Confirm the intended alignment with the caller.
        """
        num_obs = self.num_encoder_obs_steps
        horizon = self.T_pred_horizon
        required_steps = num_obs + horizon
        if len(list_of_batched_timesteps) < required_steps:
            raise IndexError(
                f"Expected at least {required_steps} timesteps "
                f"({num_obs} observed + {horizon} predicted), got {len(list_of_batched_timesteps)}."
            )
        if timeline_time_features.size(0) < horizon:
            raise IndexError(
                f"timeline_time_features has {timeline_time_features.size(0)} rows, "
                f"need at least {horizon}."
            )

        # 1. Encoder input: fused features of the observed steps.
        encoder_input_fused_features_list = []
        for i in range(num_obs):
            pyg_batch_enc_step = list_of_batched_timesteps[i].to(device)
            if i < timeline_time_features.size(0):
                current_time_feat_for_enc = timeline_time_features[i, :]
            else:
                # No time feature available for this observed step: zero placeholder.
                current_time_feat_for_enc = torch.zeros(self.time_in_dim, device=device)

            fused_enc_step = self._process_one_timestep_input(
                pyg_batch_enc_step, self.rgcn_module_for_encoder_inputs,
                current_time_feat_for_enc, device, step_type=f"encoder_step_{i}"
            )
            encoder_input_fused_features_list.append(fused_enc_step)

        # Sequence-first layout expected by the batch_first=False Transformer.
        src = torch.stack(encoder_input_fused_features_list, dim=0)
        src = self.pos_encoder(src)
        memory = self.transformer_encoder(src)

        # 2. Decoder input: fused features of every horizon step.
        all_decoder_input_fused_features_list = []
        for t_pred_idx in range(horizon):
            pyg_batch_this_timestep = list_of_batched_timesteps[num_obs + t_pred_idx].to(device)
            current_time_feature = timeline_time_features[t_pred_idx, :]

            fused_features_step_t = self._process_one_timestep_input(
                pyg_batch_this_timestep, self.rgcn_module_for_decoder_inputs,
                current_time_feature, device, step_type=f"decoder_step_{t_pred_idx}"
            )
            all_decoder_input_fused_features_list.append(fused_features_step_t)

        tgt = torch.stack(all_decoder_input_fused_features_list, dim=0)
        tgt = self.pos_encoder(tgt)

        # Causal mask: horizon step t may only attend to steps <= t.
        tgt_mask = self._generate_square_subsequent_mask(horizon, device)
        decoder_output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)

        # 3. One prediction head per horizon step.
        all_hourly_final_predictions_scaled = []
        for t in range(horizon):
            prediction_t_scaled = self.hourly_prediction_heads[t](decoder_output[t, :, :])
            all_hourly_final_predictions_scaled.append(prediction_t_scaled.squeeze(-1))

        predictions_scaled = torch.stack(all_hourly_final_predictions_scaled, dim=1)
        return predictions_scaled

# ===========================================================
# 3. 评估指标函数 - NO CHANGES
# ===========================================================
# ... (mse_loss_masked, calculate_hourly_metrics are identical) ...
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over selected nodes whose target is not NaN.

    ``mask`` holds one boolean per row of ``targets_scaled`` and is broadcast
    across its columns.  If no entry survives both the node mask and the NaN
    filter, a zero scalar with ``requires_grad=True`` is returned so callers
    can still invoke ``backward()``.
    """
    node_selector = mask.unsqueeze(1).expand_as(targets_scaled)
    usable_entries = node_selector & ~torch.isnan(targets_scaled)

    if usable_entries.sum() == 0:
        return torch.tensor(0.0, device=predictions_scaled.device, requires_grad=True)

    selected_predictions = predictions_scaled[usable_entries]
    selected_targets = targets_scaled[usable_entries]
    return F.mse_loss(selected_predictions, selected_targets)

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute MSE, MAE, RMSE and R^2 per horizon step in the original target scale.

    Predictions and targets are unscaled with ``value * (std + 1e-8) + mean``
    on the CPU.  For each column only rows selected by ``node_masks`` and
    having a non-NaN target are scored.  A step with fewer than two such rows
    gets NaN metrics and a count of 0.

    Returns:
        Dict mapping step index to ``{'mse', 'mae', 'rmse', 'r2', 'count'}``.
    """
    offset_cpu = target_mean.cpu()
    scale_cpu = target_std.cpu() + 1e-8

    preds_np = (predictions_scaled.clone().cpu() * scale_cpu + offset_cpu).numpy()
    targets_np = (targets_scaled.clone().cpu() * scale_cpu + offset_cpu).numpy()
    mask_np = node_masks.cpu().numpy()

    _, T_horizon = preds_np.shape
    per_step_metrics = {}

    for t in range(T_horizon):
        # First restrict to the masked nodes, then drop rows without a target.
        node_preds = preds_np[:, t][mask_np]
        node_targets = targets_np[:, t][mask_np]
        has_target = ~np.isnan(node_targets)
        y_hat = node_preds[has_target]
        y_true = node_targets[has_target]
        n_valid = y_hat.shape[0]

        if n_valid >= 2:
            residual = y_hat - y_true
            mse = np.mean(residual**2)
            mae = np.mean(np.abs(residual))
            rmse = np.sqrt(mse)
            try:
                r2 = r2_score(y_true, y_hat)
            except ValueError:
                r2 = np.nan
            per_step_metrics[t] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': n_valid}
        else:
            per_step_metrics[t] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}

    return per_step_metrics

# ===========================================================
# 4. 训练与评估循环 - Train/Eval loops need to use correct target indices
# ===========================================================
def train_epoch(model, loader, optimizer, device, timeline_time_features,
                node_feat_mean, node_feat_std, target_mean, target_std):
    """Run one optimisation pass over ``loader`` and return the mean scaled loss.

    Args:
        model: Model exposing ``num_encoder_obs_steps`` and ``T_pred_horizon``
            and callable as ``model(list_of_batched_timesteps, time_features, device)``.
        loader: Iterable yielding, per super-batch, a list of batched timesteps.
        optimizer: Optimizer stepped once per super-batch with a positive, finite loss.
        device: Device used for the batches and statistics.
        timeline_time_features: Time features forwarded to the model.
        node_feat_mean, node_feat_std: Node normalisation statistics written
            onto the model before the loop.
        target_mean, target_std: Statistics used to scale the targets.

    Returns:
        ``(avg_loss, epoch_duration)``: the sequence-weighted mean of the
        scaled loss over batches with a finite loss, and the wall time in
        seconds.  ``avg_loss`` is 0.0 when no batch contributed.
    """
    model.train()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    # Loop-invariant: move the time features to the device once.
    timeline_time_features_dev = timeline_time_features.to(device)

    total_loss_scaled = 0.0
    num_sequences_counted = 0
    first_pred_idx = model.num_encoder_obs_steps

    epoch_start_time = time.time()
    for list_of_batched_timesteps in loader:
        optimizer.zero_grad()
        predictions_batch_scaled = model(list_of_batched_timesteps, timeline_time_features_dev, device)

        # Horizon part of the sequence: provides both the loss mask and the targets.
        horizon_steps = [
            list_of_batched_timesteps[first_pred_idx + t_pred_idx].to(device)
            for t_pred_idx in range(model.T_pred_horizon)
        ]
        # Nodes flagged in building_mask of the first predicted step are excluded.
        mask_for_loss = ~horizon_steps[0].building_mask

        # reshape(-1) instead of squeeze(): a single-node batch must stay 1-D,
        # otherwise stacking along dim=1 fails on 0-d tensors.
        targets_batch_scaled = torch.stack(
            [
                (step.y.reshape(-1) - target_mean_dev) / (target_std_dev + 1e-8)
                for step in horizon_steps
            ],
            dim=1,
        )
        loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled, mask_for_loss)
        num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

        if not torch.isfinite(loss):
            # A NaN/inf loss adds nothing to the sum, so it must not inflate
            # the denominator of the average either.
            continue

        loss_value = loss.item()
        if loss_value > 0:
            loss.backward()
            optimizer.step()
        total_loss_scaled += loss_value * num_sequences_in_this_super_batch
        num_sequences_counted += num_sequences_in_this_super_batch

    epoch_duration = time.time() - epoch_start_time
    avg_loss = total_loss_scaled / num_sequences_counted if num_sequences_counted > 0 else 0.0
    return avg_loss, epoch_duration

def evaluate_epoch(model, loader, device, timeline_time_features,
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Model exposing ``num_encoder_obs_steps`` and ``T_pred_horizon``
            and callable as ``model(list_of_batched_timesteps, time_features, device)``.
        loader: Iterable yielding, per super-batch, a list of batched timesteps.
        device: Device used for the batches and statistics.
        timeline_time_features: Time features forwarded to the model.
        node_feat_mean, node_feat_std: Node normalisation statistics written
            onto the model before the loop.
        target_mean, target_std: Statistics used to scale the targets and to
            unscale values for the metrics.
        epoch_type: Label for the evaluation pass; not used inside this function.

    Returns:
        ``(avg_loss_scaled, hourly_metrics, eval_duration)``.  The loss is the
        sequence-weighted mean over batches with a finite loss (0.0 if none).
        ``hourly_metrics`` comes from ``calculate_hourly_metrics`` over all
        batches, or holds NaN entries when the loader yielded nothing.
    """
    model.eval()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    # Loop-invariant: move the time features to the device once.
    timeline_time_features_dev = timeline_time_features.to(device)

    all_batch_predictions_scaled = []
    all_batch_targets_scaled = []
    all_batch_masks_for_metrics = []
    total_loss_scaled = 0.0
    num_sequences_counted = 0
    first_pred_idx = model.num_encoder_obs_steps

    eval_start_time = time.time()
    with torch.no_grad():
        for list_of_batched_timesteps in loader:
            predictions_batch_scaled = model(list_of_batched_timesteps, timeline_time_features_dev, device)

            horizon_steps = [
                list_of_batched_timesteps[first_pred_idx + t_pred_idx].to(device)
                for t_pred_idx in range(model.T_pred_horizon)
            ]
            # Nodes flagged in building_mask of the first predicted step are excluded.
            mask_for_metrics = ~horizon_steps[0].building_mask

            # reshape(-1) instead of squeeze(): a single-node batch must stay
            # 1-D, otherwise stacking along dim=1 fails on 0-d tensors.
            targets_batch_original = torch.stack([step.y.reshape(-1) for step in horizon_steps], dim=1)
            targets_batch_scaled_for_loss = (targets_batch_original - target_mean_dev) / (target_std_dev + 1e-8)
            loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled_for_loss, mask_for_metrics)
            num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

            if torch.isfinite(loss):
                # Only batches that contribute to the sum are counted in the
                # denominator of the average.
                total_loss_scaled += loss.item() * num_sequences_in_this_super_batch
                num_sequences_counted += num_sequences_in_this_super_batch

            all_batch_predictions_scaled.append(predictions_batch_scaled.cpu())
            all_batch_targets_scaled.append(targets_batch_scaled_for_loss.cpu())
            all_batch_masks_for_metrics.append(mask_for_metrics.cpu())

    eval_duration = time.time() - eval_start_time
    avg_loss_scaled = total_loss_scaled / num_sequences_counted if num_sequences_counted > 0 else 0.0

    if not all_batch_predictions_scaled:
        # Empty loader: report NaN metrics for every horizon step.
        empty_metrics = {t: {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0} for t in range(model.T_pred_horizon)}
        return avg_loss_scaled, empty_metrics, eval_duration

    final_predictions_scaled = torch.cat(all_batch_predictions_scaled, dim=0)
    final_targets_scaled = torch.cat(all_batch_targets_scaled, dim=0)
    final_masks_for_metrics = torch.cat(all_batch_masks_for_metrics, dim=0)

    hourly_metrics_original_scale = calculate_hourly_metrics(final_predictions_scaled, final_targets_scaled,
                                                             final_masks_for_metrics, target_mean.cpu(), target_std.cpu())
    return avg_loss_scaled, hourly_metrics_original_scale, eval_duration

# ===========================================================
# 5. 主训练流程 (RGCN-Transformer)
# ===========================================================

def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Summarise per-step metrics as mean and standard deviation across steps.

    For each of ``r2``, ``mse``, ``mae`` and ``rmse`` the non-NaN values of the
    steps ``0 .. T_pred_horizon - 1`` present in ``hourly_metrics_dict`` are
    collected.  The result holds ``avg_<metric>`` and ``std_<metric>``, both
    NaN when no value was available.
    """
    report = {}
    for metric_name in ('r2', 'mse', 'mae', 'rmse'):
        usable_values = []
        for step in range(T_pred_horizon):
            if step not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[step][metric_name]
            if np.isnan(value):
                continue
            usable_values.append(value)

        if usable_values:
            mean_value, std_value = np.mean(usable_values), np.std(usable_values)
        else:
            mean_value, std_value = np.nan, np.nan
        report[f'avg_{metric_name}'] = mean_value
        report[f'std_{metric_name}'] = std_value
    return report

def _rgcn_tf_fit_scalers(train_dataset, num_encoder_obs_steps):
    """Compute standardisation statistics from the training sequences only.

    Node-feature statistics use ``x`` of every step. Target statistics use ``y``
    of the steps with index >= ``num_encoder_obs_steps``, restricted to entries
    selected by ``~building_mask`` whose value is not NaN.

    Returns:
        ``(node_feat_mean, node_feat_std, target_mean, target_std)``.

    Raises:
        ValueError: If no graph in ``train_dataset`` carries node features.
    """
    node_feature_chunks = []
    target_chunks = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                node_feature_chunks.append(graph_data.x)
            if i_step >= num_encoder_obs_steps and hasattr(graph_data, 'y') and graph_data.y is not None:
                y_original = graph_data.y.squeeze()
                current_mask_for_loss = ~graph_data.building_mask
                valid_target_indices = current_mask_for_loss & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0:
                    target_chunks.append(y_original[valid_target_indices])

    if not node_feature_chunks:
        raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")
    all_node_features = torch.cat(node_feature_chunks, dim=0)
    node_feat_mean = torch.mean(all_node_features, dim=0)
    node_feat_std = torch.std(all_node_features, dim=0)
    # Constant columns (std ~ 0) and the single-row case (std is NaN) would
    # make the normalisation divide by ~0 / NaN; leave such columns unscaled.
    node_feat_std[(node_feat_std < 1e-8) | torch.isnan(node_feat_std)] = 1.0

    if not target_chunks:
        target_mean = torch.tensor(0.0)
        target_std = torch.tensor(1.0)
    else:
        all_targets = torch.cat(target_chunks, dim=0).float()
        target_mean = torch.mean(all_targets)
        target_std = torch.std(all_targets)
        # torch.std of a single value is NaN, which `< 1e-8` alone misses.
        if torch.isnan(target_std) or target_std < 1e-8:
            target_std = torch.tensor(1.0)
    return node_feat_mean, node_feat_std, target_mean, target_std


def _rgcn_tf_profile_module(component_gmacs, key, ok_label, err_label, build_inputs):
    """Profile one sub-module with torchprofile and record its GMACs under ``key``.

    ``build_inputs`` returns ``(module, args_tuple)``; it is called inside the
    ``try`` so a missing attribute is reported like any other profiling error.
    """
    try:
        module, args = build_inputs()
        macs = torchprofile.profile_macs(module, args=args)
        component_gmacs[key] = macs / 1e9
        print(f"  {ok_label} GMACs: {macs / 1e9:.4f}")
    except Exception as e:
        # Profiling is best-effort reporting; it must never abort training.
        print(f"  Error profiling {err_label}: {e}")
        component_gmacs[key] = "Error"


def _rgcn_tf_estimate_transformer_stack(component_gmacs, role, num_nodes, seq_len, num_layers,
                                        d_model, n_head, dim_ff, num_attention_blocks, notes):
    """Record a hand-computed GMAC estimate for a Transformer encoder/decoder stack.

    Each attention block is counted as ``4 * d_model^2`` MACs per token (the
    four linear projections) and the feed-forward part as ``2 * d_model * dim_ff``
    per token. The attention score/value matmuls are not included, and
    cross-attention is counted with the same sequence length as self-attention.
    ``role`` is "Encoder" or "Decoder" and selects the report keys and labels.
    """
    key_prefix = f"transformer_{role.lower()}"
    seq_label = f"L_{role[:3].lower()}_seq"
    print(f"  Manually Calculating MACs for one Transformer {role} Layer:")
    try:
        num_tokens = num_nodes * seq_len
        macs_attention = num_attention_blocks * num_tokens * (4 * d_model * d_model)
        macs_ffn = num_tokens * ((d_model * dim_ff) + (dim_ff * d_model))
        gmacs_one_layer = (macs_attention + macs_ffn) / 1e9

        component_gmacs[f'{key_prefix}_layer_one'] = gmacs_one_layer
        component_gmacs[f'{key_prefix}_layer_profiling_notes'] = notes
        print(f"  Transformer {role} Layer (One): d_model={d_model}, n_head={n_head}, d_ff={dim_ff}")
        print(f"  Used for calculation: N_nodes={num_nodes}, {seq_label}={seq_len}")
        print(f"  One {role} Layer GMACs (Manual): {gmacs_one_layer:.4f}")
        gmacs_total = num_layers * gmacs_one_layer
        component_gmacs[f'{key_prefix}_total'] = gmacs_total
        print(f"  Total Transformer {role} ({num_layers} layers) GMACs (Manual): {gmacs_total:.4f}")
    except Exception as e:
        print(f"  Error manually calculating Transformer {role} Layer MACs: {e}")
        component_gmacs[f'{key_prefix}_layer_one'] = "Error"
        component_gmacs[f'{key_prefix}_total'] = "Error"
        component_gmacs[f'{key_prefix}_layer_profiling_notes'] = f"Error: {str(e)}"


def _rgcn_tf_profile_components(model, device, d_model, n_head, dim_ff, num_enc_layers,
                                num_dec_layers, num_encoder_obs_steps, t_pred_horizon):
    """Estimate GMACs of every model component on fixed-size dummy inputs.

    The model is switched to eval mode for profiling and back to train mode
    before returning. Returns the ``component_gmacs`` report dict.
    """
    print("\nCalculating MACs for model components (approximate FLOPS):")
    component_gmacs = {}
    dummy_nodes = 2500
    dummy_edges = 60000
    dummy_global_batch = 1
    model.eval()

    # 1. RGCN module. The dummy edge_attr has 5 columns and column 4 carries
    #    the relation id, cast to float.
    def build_rgcn_inputs():
        rgcn_module = model.rgcn_module_for_encoder_inputs
        dummy_x = torch.randn(dummy_nodes, rgcn_module.rgcn_input_dim, device=device)
        dummy_edge_index = torch.randint(0, dummy_nodes, (2, dummy_edges), device=device)
        dummy_edge_attr = torch.randn(dummy_edges, 5, device=device)
        dummy_edge_attr[:, 4] = torch.randint(0, model.num_relations, (dummy_edges,), device=device).float()
        return rgcn_module, (dummy_x, dummy_edge_index, dummy_edge_attr)

    _rgcn_tf_profile_module(component_gmacs, 'rgcn_module', "RGCNModule", "RGCNModule", build_rgcn_inputs)

    # 2./3. Transformer encoder and decoder stacks (manual estimate).
    _rgcn_tf_estimate_transformer_stack(
        component_gmacs, "Encoder", dummy_nodes, num_encoder_obs_steps, num_enc_layers,
        d_model, n_head, dim_ff, 1,
        "Manually calculated (SelfAttn(4d^2NL) + FFN(2d*d_ff*NL)).")
    _rgcn_tf_estimate_transformer_stack(
        component_gmacs, "Decoder", dummy_nodes, t_pred_horizon, num_dec_layers,
        d_model, n_head, dim_ff, 2,
        "Manually calculated (SelfAttn(4d^2NL) + CrossAttn(4d^2NL) + FFN(2d*d_ff*NL)).")

    # 4.-7. Remaining MLP-style components.
    _rgcn_tf_profile_module(
        component_gmacs, 'fusion_mlp', "Fusion MLP", "Fusion MLP",
        lambda: (model.fusion_mlp,
                 (torch.randn(dummy_nodes, model.fusion_mlp_input_dim, device=device),)))
    _rgcn_tf_profile_module(
        component_gmacs, 'prediction_head_mlp', "Prediction Head MLP (single hour)", "Prediction Head",
        lambda: (model.hourly_prediction_heads[0],
                 (torch.randn(dummy_nodes, model.d_model, device=device),)))
    _rgcn_tf_profile_module(
        component_gmacs, 'global_env_encoder_mlp', "Global Env Encoder MLP", "Global Env Encoder",
        lambda: (model.global_env_encoder,
                 (torch.randn(dummy_global_batch, model.global_env_in_dim, device=device),)))
    _rgcn_tf_profile_module(
        component_gmacs, 'time_encoder_mlp', "Time Encoder MLP", "Time Encoder",
        lambda: (model.time_encoder,
                 (torch.randn(dummy_global_batch, model.time_in_dim, device=device),)))

    model.train()
    return component_gmacs


def main_training_rgcn_transformer_hourly_heads( # Renamed
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor # This is for decoder input steps
):
    """Train, validate and test the RGCN-Transformer model and write a JSON report.

    Steps: seed the RNGs, keep only sequences of length
    ``num_encoder_obs_steps + T_pred_horizon``, split them randomly into
    train/val/test, fit and save the feature/target scalers on the training
    split, build the model, estimate per-component GMACs, train with
    ReduceLROnPlateau and early stopping on the scaled validation loss, reload
    the best checkpoint and evaluate it on all three splits.

    Args:
        all_sequences_data: List of sequences; each sequence is a list of graph
            objects exposing ``x``, ``y``, ``building_mask`` and
            ``graph_global_env_features``.
        config: Hyper-parameter dict. ``'results_dir'`` is required; every
            other key read here has a default.
        time_features_for_dataset: 2-D tensor of time features; its second
            dimension defines the time-encoder input size.

    Returns:
        ``(model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std)``
        where ``model_for_eval`` is the best checkpoint (or the last training
        state if the checkpoint cannot be loaded).

    Raises:
        ValueError: If no sequence has the expected length, or the training
            split contains no node features.
    """
    train_start_time = time.time()
    report_data = {'config': config}

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    NUM_ENCODER_OBS_STEPS = config.get('num_encoder_obs_steps', 1)
    expected_input_len = NUM_ENCODER_OBS_STEPS + T_PRED_HORIZON

    valid_sequences_data = [seq for seq in all_sequences_data if isinstance(seq, list) and len(seq) == expected_input_len]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # ---- Dataset split (random permutation; the remainder becomes the test set)
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_dataset = [all_sequences_data[i] for i in indices[:train_size]]
    val_dataset = [all_sequences_data[i] for i in indices[train_size : train_size + val_size]]
    test_dataset = [all_sequences_data[i] for i in indices[train_size + val_size :]]
    report_data['dataset_split'] = {'total_sequences': num_total_sequences, 'train_size': len(train_dataset), 'val_size': len(val_dataset), 'test_size': len(test_dataset)}

    # ---- Scalers (fitted on the training split only, then persisted)
    node_feat_mean, node_feat_std, target_mean, target_std = _rgcn_tf_fit_scalers(train_dataset, NUM_ENCODER_OBS_STEPS)

    # Make sure the output directory exists before anything is written to it.
    results_dir = Path(config['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)

    scaler_path_x = results_dir / "node_feature_scaler_rgcn_transformer.pth"
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x)
    print(f"节点特征x scaler已保存到: {scaler_path_x}")
    target_scaler_path = results_dir / "target_scaler_rgcn_transformer.pth"
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path)
    print(f"目标值y scaler已保存到: {target_scaler_path}")

    # ---- DataLoaders
    batch_size = config.get('batch_size', 8)
    num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory_flag)
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
    # Separate loader for *evaluating* on the training split: the training
    # loader drops the last incomplete batch, which would silently exclude
    # sequences from the reported train metrics.
    train_eval_loader = DataLoader(train_dataset, shuffle=False, drop_last=False, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **loader_kwargs)

    # ---- Model initialisation (input sizes are read off the first graph)
    sample_graph_encoder_input_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_encoder_input_for_dims.x.shape[1]
    sample_global_env = sample_graph_encoder_input_for_dims.graph_global_env_features
    global_env_in_dim = sample_global_env.shape[0] if sample_global_env.ndim == 1 else sample_global_env.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]
    num_relations = config.get('num_relations', 5)

    # The legacy GRU keys act as fallbacks for the Transformer settings.
    d_model = config.get('transformer_d_model', config.get('gru_hidden_dim', 128))
    n_head = config.get('transformer_nhead', 4)
    num_enc_layers = config.get('transformer_num_encoder_layers', 2)
    num_dec_layers = config.get('transformer_num_decoder_layers', 2)
    dim_ff = config.get('transformer_dim_feedforward', d_model * 4)
    transformer_dropout = config.get('transformer_dropout_rate', config.get('dropout_rate_gru', 0.2))
    fusion_mlp_out_dim = config.get('fusion_mlp_output_dim', d_model)
    fusion_mlp_hid_dim = config.get('fusion_mlp_hidden_dim', d_model // 2 if d_model else 64)

    # Built once so the training model and the evaluation model are
    # guaranteed to share the exact same architecture.
    model_kwargs = dict(
        static_node_in_dim=static_node_in_dim, global_env_in_dim=global_env_in_dim, time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16), time_emb_dim=config.get('time_emb_dim', 8),
        rgcn_hidden_dim=config.get('gcn_hidden_dim', 128), rgcn_output_dim=config.get('gcn_output_dim', 128),
        num_relations=num_relations,
        transformer_d_model=d_model,
        transformer_nhead=n_head,
        transformer_num_encoder_layers=num_enc_layers,
        transformer_num_decoder_layers=num_dec_layers,
        transformer_dim_feedforward=dim_ff,
        transformer_dropout_rate=transformer_dropout,
        fusion_mlp_output_dim=fusion_mlp_out_dim,
        fusion_mlp_hidden_dim=fusion_mlp_hid_dim,
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        T_pred_horizon=T_PRED_HORIZON,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_rgcn=config.get('dropout_rate_gcn', 0.3),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2),
        num_encoder_obs_steps=NUM_ENCODER_OBS_STEPS,
    )
    model = RGCNTransformerModelWithHourlyHeads(**model_kwargs).to(device)

    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型总参数量 (RGCN-Transformer): {total_params:,}")
    report_data['model_total_parameters'] = total_params

    # ---- Component MAC estimation (reporting only)
    report_data['component_gmacs'] = _rgcn_tf_profile_components(
        model, device, d_model, n_head, dim_ff, num_enc_layers, num_dec_layers,
        NUM_ENCODER_OBS_STEPS, T_PRED_HORIZON)

    # ---- Optimisation setup
    optimizer = torch.optim.Adam(model.parameters(), lr=config.get('lr', 0.001), weight_decay=config.get('weight_decay', 1e-5))
    # `verbose` is deprecated on the LR schedulers in recent PyTorch releases;
    # it is not needed because the current LR is printed every epoch below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 20))

    best_val_loss_scaled = float('inf')
    best_val_hourly_metrics_original = None
    best_epoch = 0
    patience_counter = 0
    max_epochs = config.get('max_epochs', 300)
    early_stopping_patience = config.get('early_stopping_patience', 45)
    model_save_path = results_dir / f"best_rgcntransformer_hourly_heads_model_seed{seed}.pth"
    timeline_time_features_on_device = time_features_for_dataset.to(device) # This is for decoder steps
    target_mean_cpu = target_mean.cpu()
    target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device)
    target_std_on_device = target_std.to(device)
    epoch_times = []

    # ---- Training loop with early stopping on the scaled validation loss
    for epoch in range(1, max_epochs + 1):
        train_loss_scaled, epoch_duration = train_epoch(model, train_loader, optimizer, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_on_device, target_std_on_device)
        epoch_times.append(epoch_duration)
        val_loss_scaled, val_hourly_metrics_original, _ = evaluate_epoch(model, val_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Validation")
        scheduler.step(val_loss_scaled)
        print(f"Epoch {epoch:03d} | Train Scaled MSE: {train_loss_scaled:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Epoch Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")
        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled
            best_val_hourly_metrics_original = val_hourly_metrics_original
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"                     ---> Best model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement.")
            break

    total_training_duration = time.time() - train_start_time
    report_data['total_training_time_seconds'] = total_training_duration
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    # One entry is appended per executed epoch, so this equals the last epoch
    # index and stays defined even when the loop body never ran.
    report_data['num_epochs_trained'] = len(epoch_times)
    report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    # ---- Reload the best checkpoint into a fresh model for the final evaluation
    model_for_eval = RGCNTransformerModelWithHourlyHeads(**model_kwargs).to(device)
    try:
        model_for_eval.load_state_dict(torch.load(model_save_path, map_location=device))
    except Exception as e:
        print(f"无法加载最佳模型 ({e})，将使用训练循环结束时的模型。")
        model_for_eval = model
    # Keep the returned model consistent with the training model, which
    # carries the feature scaler as attributes.
    model_for_eval.node_feat_mean = node_feat_mean.to(device)
    model_for_eval.node_feat_std = node_feat_std.to(device)

    print("\n评估最佳模型在训练集上...")
    _, best_model_train_hourly_metrics, train_eval_duration = evaluate_epoch(model_for_eval, train_eval_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Best Model on Train")
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics
    report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration
    _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    # Validation metrics of the best epoch were captured during training;
    # re-evaluate only if none were recorded.
    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original:
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    else:
        _, reeval_val_metrics, val_eval_duration = evaluate_epoch(model_for_eval, val_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Best Model on Val (Re-eval)")
        report_data['best_model_validation_set_metrics_hourly'] = reeval_val_metrics
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(reeval_val_metrics, T_PRED_HORIZON)
        report_data['best_model_validation_set_eval_time_seconds'] = val_eval_duration
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)

    print("\n评估最佳模型在测试集上...")
    test_loss_scaled, test_hourly_metrics_original, test_inference_duration = evaluate_epoch(model_for_eval, test_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Test")
    report_data['test_set_inference_time_seconds'] = test_inference_duration
    report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original
    report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (RGCN-Transformer) " + "="*20)
    print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}")
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    # ---- Persist the report; a failure here must not discard the trained model
    class NpEncoder(json.JSONEncoder):
        """JSON encoder that understands NumPy scalars/arrays, tensors and Paths."""
        def default(self, obj):
            if isinstance(obj, np.integer): return int(obj)
            if isinstance(obj, np.floating): return float(obj)
            if isinstance(obj, np.ndarray): return obj.tolist()
            if isinstance(obj, torch.Tensor): return obj.tolist()
            if isinstance(obj, Path): return str(obj)
            return super(NpEncoder, self).default(obj)

    report_file_path = results_dir / f"training_report_transformer_seed{seed}.json"
    try:
        with open(report_file_path, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e:
        print(f"保存训练报告失败: {e}")

    return model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std

def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    # ... (identical to previous version) ...
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return
    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")
    header_printed = False
    for hour_idx in range(T_pred_horizon):
        metrics = hourly_metrics.get(hour_idx, {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0})
        if not header_printed:
            print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
            header_printed = True
        print(f"{indent}  {hour_idx:02d}   | {metrics.get('r2', np.nan):13.4f} | {metrics.get('mse', np.nan):14.4f} | {metrics.get('mae', np.nan):14.4f} | {metrics.get('rmse', np.nan):15.4f} | {metrics.get('count', 0):7d}")
    aggregated = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    print(f"{indent}  Aggregated Avg R2   : {aggregated.get('avg_r2', np.nan):.4f} (Std: {aggregated.get('std_r2', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg MSE  : {aggregated.get('avg_mse', np.nan):.4f} (Std: {aggregated.get('std_mse', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg MAE  : {aggregated.get('avg_mae', np.nan):.4f} (Std: {aggregated.get('std_mae', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg RMSE : {aggregated.get('avg_rmse', np.nan):.4f} (Std: {aggregated.get('std_rmse', np.nan):.4f})")

# ===========================================================
# 6. Main execution block
# ===========================================================
# Loads the pickled graph sequences, validates them, builds the time features
# for the decoder timeline and runs the full training/evaluation pipeline.
# The names assigned here are notebook-level globals.
if __name__ == "__main__":
    gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    # Creates the base directory when it is missing.
    # NOTE(review): presumably this path only exists once Google Drive is
    # mounted; if it is not, this creates an empty local directory — confirm.
    if not DRIVE_BASE_PATH.exists(): DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the data-loader cell at the top of the notebook reads the
    # same file name from "Result/Sequential_13Hour_Data" — confirm which
    # directory is the intended one.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    RESULTS_SUBDIR = Path("Result/Final_RGCNTransformer1")
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    os.makedirs(RESULTS_SAVE_DIR, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME


    # Calendar date and first hour used to generate the time features below.
    DATA_YEAR = 2023; DATA_MONTH = 5; DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8; PREDICTION_HORIZON = 12

    training_config = {
        'seed': 42, 'batch_size': 8, 'lr': 0.001, 'weight_decay': 1e-5,
        'max_epochs': 1000, 'scheduler_patience': 20, 'early_stopping_patience': 45,
        'T_pred_horizon': PREDICTION_HORIZON, 'results_dir': str(RESULTS_SAVE_DIR),
        'global_env_emb_dim': 16, 'time_emb_dim': 8,
        'gcn_hidden_dim': 128, 'gcn_output_dim': 128,
        'num_relations': 5,
        # The legacy GRU keys are the fallbacks the training function uses for
        # the Transformer d_model and dropout when no 'transformer_d_model' /
        # 'transformer_dropout_rate' key is present.
        'gru_hidden_dim': 128, # Used for transformer_d_model
        'dropout_rate_gru': 0.2,  # Used for transformer_dropout_rate
        # Transformer-specific keys (the training function has defaults for all of them)
        'transformer_nhead': 4,
        'transformer_num_encoder_layers': 2,
        'transformer_num_decoder_layers': 2,
        'transformer_dim_feedforward': 512, # 4 * d_model with d_model = 128
        'num_encoder_obs_steps': 1, # Number of leading steps of each sequence fed to the encoder

        'mlp_prediction_hidden_dim': 64,
        'fusion_mlp_output_dim': 128, # Should match transformer_d_model
        'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2, 'dropout_rate_encoders': 0.1,
        'dropout_rate_gcn': 0.3,
        'dropout_rate_pred_head': 0.2,
        # 'use_amp', 'enable_profiler' and 'h0_from_first_step' are not read by
        # main_training_rgcn_transformer_hourly_heads.
        'use_amp': False, 'enable_profiler': False, 'num_workers': 0,
        'pin_memory': False, 'train_split_ratio': 0.7, 'val_split_ratio': 0.2,
        'h0_from_first_step': True # This config is less relevant for Transformer
    }

    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists(): raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # from a trusted source.
        with open(DATA_PATH, "rb") as f: all_graph_sequences_loaded = pickle.load(f)
        # Expect a non-empty list whose first element is a non-empty list.
        if not all_graph_sequences_loaded or not isinstance(all_graph_sequences_loaded, list) or not all_graph_sequences_loaded[0] or not isinstance(all_graph_sequences_loaded[0], list):
            raise ValueError("加载的数据格式不正确。")

        num_encoder_obs_steps_from_config = training_config.get('num_encoder_obs_steps', 1)
        # Every sequence must contain the encoder steps followed by one step per predicted hour.
        expected_len_per_sequence = num_encoder_obs_steps_from_config + training_config['T_pred_horizon']

        processed_sequences = []
        for i, seq in enumerate(all_graph_sequences_loaded):
            # Sequences of the wrong type or length are dropped silently.
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence: continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                # A step is rejected when it is not a Data object, lacks x /
                # edge_index / edge_attr, has fewer than 5 edge_attr columns,
                # lacks graph_global_env_features, or (for steps after the
                # encoder steps) lacks y.
                if not isinstance(graph_step_data, Data) or not hasattr(graph_step_data, 'x') or graph_step_data.x is None or \
                   not hasattr(graph_step_data, 'edge_index') or graph_step_data.edge_index is None or \
                   not hasattr(graph_step_data, 'edge_attr') or graph_step_data.edge_attr is None or \
                   graph_step_data.edge_attr.shape[1] < 5 or \
                   not hasattr(graph_step_data, 'graph_global_env_features') or \
                   (step_idx >= num_encoder_obs_steps_from_config and (not hasattr(graph_step_data, 'y') or graph_step_data.y is None)):
                    valid_seq = False; break
                # Normalise 1-D targets to a column vector (mutates the loaded object in place).
                if step_idx >= num_encoder_obs_steps_from_config and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq: processed_sequences.append(seq)
        if not processed_sequences: raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    # Any failure above is reported and disables training instead of raising.
    except Exception as e: print(f"加载或验证数据时发生错误: {e}"); all_graph_sequences = None

    if all_graph_sequences:
        # Time features cover T_pred_horizon hourly steps starting at the configured start hour.
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_decoder_timeline = generate_time_features_for_sequence(
            base_datetime_for_timeline,
            training_config['T_pred_horizon']
        )

        trained_model, final_node_mean, final_node_std, final_target_mean, final_target_std = main_training_rgcn_transformer_hourly_heads(
            all_graph_sequences, training_config, time_features_for_decoder_timeline
        )
        print("RGCN-Transformer 模型训练和评估完成!")
    else:
        print("由于数据加载失败或数据为空，训练流程未启动。")

##GCN+LSTM

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta
from pathlib import Path
import time
import json
import torchprofile

# Free memory left over from previously executed cells.
# Python garbage is collected first so tensors that are no longer referenced
# are released before the CUDA caching allocator is asked to give back its
# unused blocks.
gc.collect()
if torch.cuda.is_available():
    # Guarded so the cell does not depend on a CUDA runtime being present;
    # the training code itself falls back to the CPU.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# ===========================================================
# 1. 特征生成 & 辅助模块 (No changes)
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps, *, hour_divisor=23.0):
    """Build cyclical hour-of-day / day-of-year features for consecutive hours.

    Step ``i`` describes ``base_dt_obj + i hours``.

    Args:
        base_dt_obj: ``datetime`` of the first step.
        num_steps: Number of hourly steps to generate.
        hour_divisor: Value the hour (0-23) is divided by before the sin/cos
            encoding. The default 23.0 keeps the existing feature values, under
            which 00:00 and 23:00 map to the same angle. Pass 24.0 to give
            every hour of the day a distinct encoding; models trained with the
            default must keep using the default.

    Returns:
        float32 tensor of shape ``(num_steps, 4)`` with columns
        ``[hour_sin, hour_cos, doy_sin, doy_cos]``; an empty ``(0, 4)`` tensor
        when ``num_steps`` is not positive.
    """
    import calendar  # local import: only needed for the leap-year test

    rows = []
    for step in range(num_steps):
        current_dt = base_dt_obj + timedelta(hours=step)
        days_in_year = 366.0 if calendar.isleap(current_dt.year) else 365.0
        hour_angle = 2 * math.pi * (current_dt.hour / hour_divisor)
        doy_angle = 2 * math.pi * (current_dt.timetuple().tm_yday / days_in_year)
        rows.append([math.sin(hour_angle), math.cos(hour_angle),
                     math.sin(doy_angle), math.cos(doy_angle)])

    if not rows:
        # Keep the 4-column layout so callers can still read ``shape[1]``.
        return torch.empty((0, 4), dtype=torch.float32)
    return torch.tensor(rows, dtype=torch.float32)


class MLPEncoder(nn.Module):
    """Small MLP: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the produced embedding.
        hid_dim: Hidden width. When None it is derived from ``in_dim`` and
            ``out_dim`` (the larger of their minimum and their integer mean).
        dropout_rate: Dropout probability applied after the LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        if hid_dim is None:
            smaller_dim = min(in_dim, out_dim)
            mean_dim = (in_dim + out_dim) // 2
            hid_dim = max(smaller_dim, mean_dim)
            if hid_dim == 0:
                # Degenerate sizes: use the first positive dimension, else 1.
                if out_dim > 0:
                    hid_dim = out_dim
                elif in_dim > 0:
                    hid_dim = in_dim
                else:
                    hid_dim = 1

        layers = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.LayerNorm(hid_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hid_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.mlp(x)

class GCNModule(nn.Module):
    """Three stacked ``GCNConv`` layers, each followed by BatchNorm1d and PReLU.

    Dropout is applied after the first and second stage only; the output of
    the third stage is returned without dropout.

    Args:
        gcn_input_dim: Feature width fed to the first convolution.
        gcn_hidden_dim: Feature width produced by the first two convolutions.
        gcn_output_dim: Feature width produced by the third convolution.
        dropout_rate: Dropout probability used between stages.
    """

    def __init__(self, gcn_input_dim, gcn_hidden_dim, gcn_output_dim, dropout_rate=0.5):
        super().__init__()
        self.gcn_input_dim = gcn_input_dim
        self.gcn_hidden_dim = gcn_hidden_dim
        self.gcn_output_dim = gcn_output_dim

        # Stage 1: input -> hidden
        self.conv1 = GCNConv(gcn_input_dim, gcn_hidden_dim)
        self.bn1 = nn.BatchNorm1d(gcn_hidden_dim)
        self.prelu1 = nn.PReLU(gcn_hidden_dim)

        # Stage 2: hidden -> hidden
        self.conv2 = GCNConv(gcn_hidden_dim, gcn_hidden_dim)
        self.bn2 = nn.BatchNorm1d(gcn_hidden_dim)
        self.prelu2 = nn.PReLU(gcn_hidden_dim)

        # Stage 3: hidden -> output
        self.conv3 = GCNConv(gcn_hidden_dim, gcn_output_dim)
        self.bn3 = nn.BatchNorm1d(gcn_output_dim)
        self.prelu3 = nn.PReLU(gcn_output_dim)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index, edge_weight=None):
        """Run the three stages on node features ``x``.

        An input with zero rows is returned untouched. Batch normalisation is
        skipped whenever the tensor has a single row.
        """
        if x.size(0) == 0:
            return x

        # (convolution, batch-norm, activation, apply dropout afterwards?)
        stages = (
            (self.conv1, self.bn1, self.prelu1, True),
            (self.conv2, self.bn2, self.prelu2, True),
            (self.conv3, self.bn3, self.prelu3, False),
        )
        for conv, norm, activation, use_dropout in stages:
            x = conv(x, edge_index, edge_weight=edge_weight)
            if x.size(0) > 1:
                x = norm(x)
            x = activation(x)
            if use_dropout:
                x = self.dropout(x)
        return x

# ===========================================================
# 2. GCN-LSTM 模型定义
# ===========================================================
class GCNLSTMModelWithHourlyHeads(nn.Module): # Renamed
    """GCN + LSTM forecaster with one small MLP prediction head per horizon step.

    Pipeline implemented by ``forward``:

    1. The graph at index 0 of the input sequence goes through
       ``gcn_module_for_h0`` and ``h0_c0_from_gcn_encoder``; the result seeds
       the LSTM hidden state ``h_0`` (``c_0`` is all zeros).
    2. Each of the following ``T_pred_horizon`` graphs goes through
       ``gcn_module_for_sequence``; its node embedding is concatenated with an
       encoded per-graph environment vector and an encoded time vector, then
       fused by ``fusion_mlp``.
    3. The fused per-node sequence is fed to the LSTM, and the LSTM output at
       step ``t`` is mapped to a scalar by ``hourly_prediction_heads[t]``.

    Node features are standardised with the ``node_feat_mean`` /
    ``node_feat_std`` buffers, which start as zeros / ones.
    """

    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 gcn_hidden_dim,
                 gcn_output_dim,
                 lstm_hidden_dim, # Changed from gru_hidden_dim
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 num_lstm_layers=1, # Changed from num_gru_layers
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_gcn=0.3,
                 dropout_rate_lstm=0.2, # Changed from dropout_rate_gru
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2
                ):
        super().__init__()

        # Plain bookkeeping attributes (read by profiling / training code).
        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.gcn_output_dim = gcn_output_dim
        self.lstm_hidden_dim = lstm_hidden_dim

        # --- Encoders (registration order is kept stable on purpose) ---
        self.global_env_encoder = MLPEncoder(
            global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders
        )
        self.time_encoder = MLPEncoder(
            time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders
        )
        # Maps the GCN output of the first graph to the LSTM hidden width (h_0).
        self.h0_c0_from_gcn_encoder = MLPEncoder(
            gcn_output_dim, lstm_hidden_dim, dropout_rate=dropout_rate_encoders
        )

        # --- Two independent GCN stacks: one for h_0, one for the sequence ---
        self.gcn_module_for_h0 = GCNModule(
            static_node_in_dim, gcn_hidden_dim, gcn_output_dim, dropout_rate_gcn
        )
        self.gcn_module_for_sequence = GCNModule(
            static_node_in_dim, gcn_hidden_dim, gcn_output_dim, dropout_rate_gcn
        )

        # --- Fusion MLP over [gcn | global env | time] ---
        fused_in = gcn_output_dim + global_env_emb_dim + time_emb_dim
        fused_out = fused_in if fusion_mlp_output_dim is None else fusion_mlp_output_dim
        self.fusion_mlp_input_dim = fused_in
        self.fusion_mlp = MLPEncoder(
            in_dim=fused_in,
            out_dim=fused_out,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        # --- Temporal model ---
        self.lstm_input_dim = fused_out # For profiling
        # nn.LSTM only applies inter-layer dropout, so it is disabled for a
        # single-layer LSTM.
        inter_layer_dropout = dropout_rate_lstm if num_lstm_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=fused_out,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=inter_layer_dropout
        )

        # --- One prediction head per forecast step ---
        self.hourly_prediction_heads = nn.ModuleList(
            nn.Sequential(
                nn.Linear(lstm_hidden_dim, mlp_prediction_hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate_pred_head),
                nn.Linear(mlp_prediction_hidden_dim, 1)
            )
            for _ in range(T_pred_horizon)
        )

        # Standardisation statistics; identity transform until overwritten.
        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))

    def _encode_graph(self, gcn, graph_batch):
        """Standardise node features of ``graph_batch`` and run them through ``gcn``."""
        scaled_x = (graph_batch.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

        # ``edge_weight`` is optional on the batch; cast to float when present.
        weights = getattr(graph_batch, 'edge_weight', None)
        if weights is not None:
            weights = weights.float()

        return gcn(scaled_x, graph_batch.edge_index, edge_weight=weights)

    def _global_env_as_matrix(self, raw_env, num_graphs, step_idx, device):
        """Return ``raw_env`` as a ``(num_graphs, feature_dim)`` tensor.

        A tensor that already has that shape is returned as is; one with the
        right number of elements is reshaped with ``view``; anything else is
        replaced by zeros after printing a warning.
        """
        feat_dim = self.global_env_encoder.mlp[0].in_features
        if raw_env.shape == (num_graphs, feat_dim):
            return raw_env
        if raw_env.numel() == num_graphs * feat_dim:
            return raw_env.view(num_graphs, feat_dim)

        print(f"Warning: LSTM Input Time {step_idx}: Correcting global_env_feat shape from {raw_env.shape} to ({num_graphs}, {feat_dim}) with zeros due to mismatch.")
        return torch.zeros(num_graphs, feat_dim, device=device)

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict one scaled value per node for each of ``T_pred_horizon`` steps.

        Args:
            list_of_batched_timesteps: Sequence of batched graphs; index 0
                seeds the LSTM state and indices ``1..T_pred_horizon`` form
                the input sequence.
            timeline_time_features: Tensor indexed as ``[step, :]`` that gives
                the raw time features of forecast step ``step``.
            device: Device onto which the batches and features are moved.

        Returns:
            Tensor with one row per node and one column per forecast step.
        """
        # ---- Initial LSTM state from the first graph ----
        seed_batch = list_of_batched_timesteps[0].to(device)
        seed_embedding = self._encode_graph(self.gcn_module_for_h0, seed_batch)

        # LSTM expects (h_0, c_0), each (num_layers, N_nodes, lstm_hidden_dim).
        h0 = self.h0_c0_from_gcn_encoder(seed_embedding).unsqueeze(0)
        c0 = torch.zeros_like(h0)
        depth = self.lstm.num_layers
        if depth > 1:
            h0 = h0.repeat(depth, 1, 1)
            c0 = c0.repeat(depth, 1, 1)
        lstm_state = (h0, c0)

        # ---- Per-step fused node features ----
        fused_per_step = []
        for step in range(self.T_pred_horizon):
            step_batch = list_of_batched_timesteps[step + 1].to(device)

            node_emb = self._encode_graph(self.gcn_module_for_sequence, step_batch)

            env_matrix = self._global_env_as_matrix(
                step_batch.graph_global_env_features,
                step_batch.num_graphs,
                step,
                device,
            )
            # Broadcast each graph's environment embedding to its nodes.
            env_emb = self.global_env_encoder(env_matrix)[step_batch.batch]

            # A single time embedding is shared by every node of the batch.
            time_emb = self.time_encoder(timeline_time_features[step, :].to(device))
            time_emb = time_emb.unsqueeze(0).expand(step_batch.num_nodes, -1)

            joined = torch.cat([node_emb, env_emb, time_emb], dim=-1)
            fused_per_step.append(self.fusion_mlp(joined))

        lstm_input = torch.stack(fused_per_step, dim=1)

        # ---- Reconcile node counts between the state and the sequence ----
        state_nodes = lstm_state[0].shape[1]
        sequence_nodes = lstm_input.shape[0]
        if state_nodes != sequence_nodes:
            print(f"CRITICAL WARNING: Node count mismatch for LSTM h0/c0 ({state_nodes}) and LSTM input sequence ({sequence_nodes}).")
            # Only the "too many state rows" case is repaired (by truncation).
            if state_nodes > sequence_nodes:
                lstm_state = (
                    lstm_state[0][:, :sequence_nodes, :],
                    lstm_state[1][:, :sequence_nodes, :],
                )

        lstm_out, _ = self.lstm(lstm_input, lstm_state)

        # ---- One head per forecast step ----
        per_step_predictions = [
            self.hourly_prediction_heads[t](lstm_out[:, t, :]).squeeze(-1)
            for t in range(self.T_pred_horizon)
        ]
        return torch.stack(per_step_predictions, dim=1)

# ===========================================================
# 3. 评估指标函数 (在原始尺度上计算指标) - NO CHANGES
# ===========================================================
# ... (mse_loss_masked, calculate_hourly_metrics are identical) ...
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over entries that are both selected and non-NaN.

    Args:
        predictions_scaled: Predictions, same shape as ``targets_scaled``.
        targets_scaled: Targets; NaN entries are excluded from the loss.
        mask: Boolean per-row selector; it is broadcast across dimension 1 of
            ``targets_scaled``.

    Returns:
        Scalar loss tensor. When no entry qualifies, a fresh ``0.0`` tensor
        with ``requires_grad=True`` on the predictions' device is returned.
    """
    row_selected = mask.unsqueeze(1).expand_as(targets_scaled)
    has_target = torch.isnan(targets_scaled).logical_not()
    usable = row_selected & has_target

    if usable.sum() == 0:
        return torch.tensor(0.0, device=predictions_scaled.device, requires_grad=True)

    return F.mse_loss(predictions_scaled[usable], targets_scaled[usable])

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute per-step MSE / MAE / RMSE / R2 in the original target units.

    Predictions and targets are un-scaled with
    ``value * (target_std + 1e-8) + target_mean`` on the CPU. For every column
    (forecast step) only rows selected by ``node_masks`` whose target is not
    NaN are scored.

    Returns:
        Dict keyed by step index. Each value holds ``mse``, ``mae``, ``rmse``,
        ``r2`` and ``count``. Steps with fewer than two usable points get NaN
        metrics and a ``count`` of 0.
    """
    offset = target_mean.cpu()
    scale = target_std.cpu()

    def _to_original_units(scaled_tensor):
        return (scaled_tensor.clone().cpu() * (scale + 1e-8) + offset).numpy()

    preds_np = _to_original_units(predictions_scaled)
    targets_np = _to_original_units(targets_scaled)
    selected_rows = node_masks.cpu().numpy()

    _, horizon = preds_np.shape
    report = {}

    for step in range(horizon):
        step_preds = preds_np[:, step][selected_rows]
        step_targets = targets_np[:, step][selected_rows]

        observed = ~np.isnan(step_targets)
        y_pred = step_preds[observed]
        y_true = step_targets[observed]

        n_points = y_pred.shape[0]
        if n_points < 2:
            report[step] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            continue

        mse = np.mean((y_pred - y_true)**2)
        mae = np.mean(np.abs(y_pred - y_true))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(y_true, y_pred)
        except ValueError:
            r2 = np.nan

        report[step] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': n_points}

    return report

# ===========================================================
# 4. 训练与评估循环 (适配y归一化) - NO CHANGES
# ===========================================================
# ... (train_epoch, evaluate_epoch are identical) ...
def train_epoch(model, loader, optimizer, device, timeline_time_features,
                node_feat_mean, node_feat_std, target_mean, target_std):
    """Train ``model`` for one pass over ``loader``.

    Each item yielded by ``loader`` is indexed as a sequence of batched
    timesteps: index 0 is only used for its ``num_graphs``; indices
    ``1..model.T_pred_horizon`` supply the targets (``.y``), and index 1 also
    supplies ``building_mask`` (rows where it is True are excluded from the
    loss). Targets are standardised with ``target_mean`` / ``target_std``
    before the masked MSE is computed.

    A batch whose loss is NaN, infinite or not strictly positive is skipped
    (no backward / optimizer step), but its sequences still count towards the
    denominator of the returned average.

    Returns:
        ``(average scaled loss per sequence, wall-clock seconds)``.
    """
    model.train()

    # Overwrite the model's standardisation statistics with the given ones.
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    mean_dev = target_mean.to(device)
    std_dev = target_std.to(device)

    weighted_loss_sum = 0
    sequences_seen = 0
    started_at = time.time()

    for timestep_batches in loader:
        optimizer.zero_grad()
        predictions = model(timestep_batches, timeline_time_features.to(device), device)

        # Rows flagged in ``building_mask`` do not contribute to the loss.
        loss_mask = ~timestep_batches[1].to(device).building_mask

        scaled_targets = torch.stack(
            [
                (timestep_batches[step + 1].to(device).y.squeeze() - mean_dev) / (std_dev + 1e-8)
                for step in range(model.T_pred_horizon)
            ],
            dim=1,
        )

        loss = mse_loss_masked(predictions, scaled_targets, loss_mask)
        sequences_in_batch = timestep_batches[0].num_graphs

        loss_value = loss.item()
        if math.isfinite(loss_value) and loss_value > 0:
            loss.backward()
            optimizer.step()
            weighted_loss_sum += loss_value * sequences_in_batch
        sequences_seen += sequences_in_batch

    elapsed = time.time() - started_at
    if sequences_seen > 0:
        return weighted_loss_sum / sequences_seen, elapsed
    return 0.0, elapsed

def evaluate_epoch(model, loader, device, timeline_time_features,
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The scaled masked-MSE is averaged per sequence (batches with a NaN or
    infinite loss add nothing to the numerator but still count in the
    denominator). Predictions, scaled targets and masks of every batch are
    gathered on the CPU and passed to ``calculate_hourly_metrics`` to obtain
    per-step metrics in the original target units.

    ``epoch_type`` is accepted for interface compatibility and is not used in
    the body.

    Returns:
        ``(average scaled loss, per-step metrics dict, wall-clock seconds)``.
        If the loader yields nothing, the metrics dict contains NaN entries
        with ``count`` 0 for every step in ``range(model.T_pred_horizon)``.
    """
    model.eval()

    # Overwrite the model's standardisation statistics with the given ones.
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    mean_dev = target_mean.to(device)
    std_dev = target_std.to(device)

    collected_predictions, collected_targets, collected_masks = [], [], []
    weighted_loss_sum = 0
    sequences_seen = 0

    started_at = time.time()
    with torch.no_grad():
        for timestep_batches in loader:
            predictions = model(timestep_batches, timeline_time_features.to(device), device)

            # Rows flagged in ``building_mask`` are excluded from loss/metrics.
            metric_mask = ~timestep_batches[1].to(device).building_mask

            raw_targets = torch.stack(
                [
                    timestep_batches[step + 1].to(device).y.squeeze()
                    for step in range(model.T_pred_horizon)
                ],
                dim=1,
            )
            scaled_targets = (raw_targets - mean_dev) / (std_dev + 1e-8)

            loss = mse_loss_masked(predictions, scaled_targets, metric_mask)
            sequences_in_batch = timestep_batches[0].num_graphs

            loss_is_nan = torch.isnan(loss)
            loss_is_inf = torch.isinf(loss)
            if not (loss_is_nan or loss_is_inf):
                weighted_loss_sum += loss.item() * sequences_in_batch
            sequences_seen += sequences_in_batch

            collected_predictions.append(predictions.cpu())
            collected_targets.append(scaled_targets.cpu())
            collected_masks.append(metric_mask.cpu())
    elapsed = time.time() - started_at

    if sequences_seen > 0:
        avg_loss_scaled = weighted_loss_sum / sequences_seen
    else:
        avg_loss_scaled = 0.0

    if not collected_predictions:
        # Nothing was evaluated: report NaN metrics for every step.
        placeholder = {
            t: {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0}
            for t in range(model.T_pred_horizon)
        }
        return avg_loss_scaled, placeholder, elapsed

    hourly_metrics = calculate_hourly_metrics(
        torch.cat(collected_predictions, dim=0),
        torch.cat(collected_targets, dim=0),
        torch.cat(collected_masks, dim=0),
        target_mean.cpu(),
        target_std.cpu(),
    )
    return avg_loss_scaled, hourly_metrics, elapsed

# ===========================================================
# 5. 主训练流程 (GCN-LSTM)
# ===========================================================

def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Summarise per-step metrics as mean and standard deviation across steps.

    For each of ``r2``, ``mse``, ``mae`` and ``rmse`` the values of steps
    ``0..T_pred_horizon-1`` that are present in ``hourly_metrics_dict`` and
    not NaN are collected. The result maps ``avg_<metric>`` and
    ``std_<metric>`` to their ``np.mean`` / ``np.std``; both are NaN when no
    value was collected.
    """
    summary = {}
    for metric in ('r2', 'mse', 'mae', 'rmse'):
        collected = []
        for step in range(T_pred_horizon):
            if step not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[step][metric]
            if np.isnan(value):
                continue
            collected.append(value)

        if collected:
            mean_value, spread = np.mean(collected), np.std(collected)
        else:
            mean_value, spread = np.nan, np.nan

        summary[f'avg_{metric}'] = mean_value
        summary[f'std_{metric}'] = spread
    return summary

def main_training_gcn_lstm_hourly_heads( # Renamed
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor
):
    train_start_time = time.time()
    report_data = {'config': config}

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    expected_input_len = T_PRED_HORIZON + 1
    valid_sequences_data = [seq for seq in all_sequences_data if isinstance(seq, list) and len(seq) == expected_input_len]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # Dataset split
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]
    train_dataset = [all_sequences_data[i] for i in train_indices]
    val_dataset   = [all_sequences_data[i] for i in val_indices]
    test_dataset  = [all_sequences_data[i] for i in test_indices]
    report_data['dataset_split'] = {'total_sequences': num_total_sequences, 'train_size': len(train_dataset), 'val_size': len(val_dataset), 'test_size': len(test_dataset)}

    # Scaler calculation
    all_train_node_features_list = []
    all_train_target_values_list_for_scaling = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                all_train_node_features_list.append(graph_data.x)
            if i_step > 0 and hasattr(graph_data, 'y') and graph_data.y is not None:
                y_original = graph_data.y.squeeze(); current_mask_for_loss = ~graph_data.building_mask
                valid_target_indices = current_mask_for_loss & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0:
                    all_train_target_values_list_for_scaling.append(y_original[valid_target_indices])

    if not all_train_node_features_list: raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")
    all_train_node_features_tensor = torch.cat(all_train_node_features_list, dim=0)
    node_feat_mean = torch.mean(all_train_node_features_tensor, dim=0); node_feat_std = torch.std(all_train_node_features_tensor, dim=0)
    node_feat_std[node_feat_std < 1e-8] = 1.0
    scaler_path_x = Path(config['results_dir']) / "node_feature_scaler_gcn_lstm.pth" # Renamed
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x); print(f"节点特征x scaler已保存到: {scaler_path_x}")

    if not all_train_target_values_list_for_scaling:
        target_mean = torch.tensor(0.0); target_std = torch.tensor(1.0)
    else:
        all_train_target_values_tensor = torch.cat(all_train_target_values_list_for_scaling, dim=0)
        target_mean = torch.mean(all_train_target_values_tensor.float()); target_std = torch.std(all_train_target_values_tensor.float())
        if target_std < 1e-8: target_std = torch.tensor(1.0)
    target_scaler_path = Path(config['results_dir']) / "target_scaler_gcn_lstm.pth" # Renamed
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path); print(f"目标值y scaler已保存到: {target_scaler_path}")

    # DataLoaders
    batch_size = config.get('batch_size', 8); num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory_flag)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)

    # Model Initialization
    sample_graph_7am_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_7am_for_dims.x.shape[1]
    global_env_in_dim = sample_graph_7am_for_dims.graph_global_env_features.shape[0] if sample_graph_7am_for_dims.graph_global_env_features.ndim == 1 else sample_graph_7am_for_dims.graph_global_env_features.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]

    model = GCNLSTMModelWithHourlyHeads( # Use GCNLSTMModel
        static_node_in_dim=static_node_in_dim, global_env_in_dim=global_env_in_dim, time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16), time_emb_dim=config.get('time_emb_dim', 8),
        gcn_hidden_dim=config.get('gcn_hidden_dim', 128),
        gcn_output_dim=config.get('gcn_output_dim', 128),
        lstm_hidden_dim=config.get('gru_hidden_dim', 128), # Using gru_hidden_dim key for LSTM hidden
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1), # Using num_gru_layers key for LSTM layers
        T_pred_horizon=T_PRED_HORIZON,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_gcn=config.get('dropout_rate_gcn', 0.3),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2), # Using dropout_rate_gru key for LSTM dropout
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2)
    ).to(device)

    model.node_feat_mean = node_feat_mean.to(device); model.node_feat_std = node_feat_std.to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型总参数量 (GCN-LSTM): {total_params:,}"); report_data['model_total_parameters'] = total_params

    # ===== Component FLOPS Calculation =====
    print("\nCalculating MACs for model components (approximate FLOPS):")
    report_data['component_gmacs'] = {}
    dummy_nodes_component = 2500
    dummy_edges_component = 60000
    dummy_batch_global_comp = 1
    model.eval()

    # 1. GCNModule
    try:
        gcn_module_to_profile = model.gcn_module_for_h0
        dummy_x_gcn = torch.randn(dummy_nodes_component, gcn_module_to_profile.gcn_input_dim, device=device)
        dummy_ei_gcn = torch.randint(0, dummy_nodes_component, (2, dummy_edges_component), device=device)
        dummy_ew_gcn = torch.rand(dummy_edges_component, device=device)
        macs_gcn = torchprofile.profile_macs(gcn_module_to_profile, args=(dummy_x_gcn, dummy_ei_gcn, dummy_ew_gcn))
        report_data['component_gmacs']['gcn_module'] = macs_gcn / 1e9
        print(f"  GCNModule GMACs: {macs_gcn / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling GCNModule: {e}"); report_data['component_gmacs']['gcn_module'] = "Error"

    # 2. LSTM Layer (Manual MAC Calculation)
    print(f"  Manually Calculating MACs for LSTM Layer:")
    try:
        lstm_layer = model.lstm # Changed from model.gru
        N_nodes = dummy_nodes_component
        L_seq = T_PRED_HORIZON
        H_in = lstm_layer.input_size
        H_hidden = lstm_layer.hidden_size
        num_layers = lstm_layer.num_layers

        macs_lstm_manual = 0
        # For a single layer LSTM: MACs ≈ N * L * 4 * (H_in * H_hidden + H_hidden^2)
        macs_lstm_manual = N_nodes * L_seq * 4 * (H_in * H_hidden + H_hidden * H_hidden) # For the first layer
        if num_layers > 1:
            # Subsequent (num_layers - 1) layers: N * L * 4 * (H_hidden * H_hidden + H_hidden^2)
            macs_lstm_manual += N_nodes * L_seq * (num_layers - 1) * 4 * (H_hidden * H_hidden + H_hidden * H_hidden)

        gmacs_lstm_manual = macs_lstm_manual / 1e9
        report_data['component_gmacs']['lstm_layer'] = gmacs_lstm_manual # Renamed key
        report_data['component_gmacs']['lstm_layer_profiling_notes'] = "Manually calculated based on formula."
        print(f"  LSTM Parameters: input_size={H_in}, hidden_size={H_hidden}, num_layers={num_layers}")
        print(f"  Used for calculation: N_nodes={N_nodes}, L_seq={L_seq}")
        print(f"  LSTM Layer GMACs (Manual): {gmacs_lstm_manual:.4f} (for sequence length {L_seq})")

    except Exception as e:
        print(f"  Error manually calculating LSTM Layer MACs: {e}")
        report_data['component_gmacs']['lstm_layer'] = "Error" # Renamed key
        report_data['component_gmacs']['lstm_layer_profiling_notes'] = f"Error during manual calculation: {str(e)}"

    # 3. Fusion MLP (MLPEncoder)
    try:
        fusion_mlp_to_profile = model.fusion_mlp
        dummy_input_fusion_mlp = torch.randn(dummy_nodes_component, model.fusion_mlp_input_dim, device=device)
        macs_fusion_mlp = torchprofile.profile_macs(fusion_mlp_to_profile, args=(dummy_input_fusion_mlp,))
        report_data['component_gmacs']['fusion_mlp'] = macs_fusion_mlp / 1e9
        print(f"  Fusion MLP GMACs: {macs_fusion_mlp / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Fusion MLP: {e}"); report_data['component_gmacs']['fusion_mlp'] = "Error"

    # 4. Prediction Head (one MLP from ModuleList)
    try:
        pred_head_to_profile = model.hourly_prediction_heads[0]
        dummy_input_pred_head = torch.randn(dummy_nodes_component, model.lstm_hidden_dim, device=device) # Input is lstm_hidden_dim
        macs_pred_head = torchprofile.profile_macs(pred_head_to_profile, args=(dummy_input_pred_head,))
        report_data['component_gmacs']['prediction_head_mlp'] = macs_pred_head / 1e9
        print(f"  Prediction Head MLP (single hour) GMACs: {macs_pred_head / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Prediction Head: {e}"); report_data['component_gmacs']['prediction_head_mlp'] = "Error"

    # 5. Global Environment Encoder (MLPEncoder)
    try:
        encoder_to_profile = model.global_env_encoder
        dummy_input_encoder = torch.randn(dummy_batch_global_comp, model.global_env_in_dim, device=device)
        macs_encoder = torchprofile.profile_macs(encoder_to_profile, args=(dummy_input_encoder,))
        report_data['component_gmacs']['global_env_encoder_mlp'] = macs_encoder / 1e9
        print(f"  Global Env Encoder MLP GMACs: {macs_encoder / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Global Env Encoder: {e}"); report_data['component_gmacs']['global_env_encoder_mlp'] = "Error"

    # 6. Time Encoder (MLPEncoder)
    try:
        encoder_to_profile = model.time_encoder
        dummy_input_encoder = torch.randn(dummy_batch_global_comp, model.time_in_dim, device=device)
        macs_encoder = torchprofile.profile_macs(encoder_to_profile, args=(dummy_input_encoder,))
        report_data['component_gmacs']['time_encoder_mlp'] = macs_encoder / 1e9
        print(f"  Time Encoder MLP GMACs: {macs_encoder / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling Time Encoder: {e}"); report_data['component_gmacs']['time_encoder_mlp'] = "Error"

    # 7. H0/C0 from GCN Encoder (MLPEncoder)
    try:
        encoder_to_profile = model.h0_c0_from_gcn_encoder # Renamed
        dummy_input_encoder = torch.randn(dummy_nodes_component, model.gcn_output_dim, device=device)
        macs_encoder = torchprofile.profile_macs(encoder_to_profile, args=(dummy_input_encoder,))
        report_data['component_gmacs']['h0_c0_from_gcn_encoder_mlp'] = macs_encoder / 1e9 # Renamed key
        print(f"  H0/C0 from GCN Encoder MLP GMACs: {macs_encoder / 1e9:.4f}") # Note: This MLP now maps to lstm_hidden_dim
    except Exception as e:
        print(f"  Error profiling H0/C0 from GCN Encoder: {e}"); report_data['component_gmacs']['h0_c0_from_gcn_encoder_mlp'] = "Error"

    model.train()
    # ===== End Component FLOPS Calculation =====

    optimizer = torch.optim.Adam(model.parameters(), lr=config.get('lr', 0.001), weight_decay=config.get('weight_decay', 1e-5))
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 20), verbose=True)

    best_val_loss_scaled = float('inf'); best_val_hourly_metrics_original = None; best_epoch = 0
    patience_counter = 0; max_epochs = config.get('max_epochs', 300); early_stopping_patience = config.get('early_stopping_patience', 45)
    model_save_path = Path(config['results_dir']) / f"best_gcnlstm_hourly_heads_model_seed{seed}.pth" # Renamed
    timeline_time_features_on_device = time_features_for_dataset.to(device)
    target_mean_cpu = target_mean.cpu(); target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device); target_std_on_device = target_std.to(device)
    epoch_times = []

    for epoch in range(1, max_epochs + 1):
        train_loss_scaled, epoch_duration = train_epoch(model, train_loader, optimizer, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_on_device, target_std_on_device)
        epoch_times.append(epoch_duration)
        val_loss_scaled, val_hourly_metrics_original, _ = evaluate_epoch(model, val_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Validation")
        scheduler.step(val_loss_scaled)
        print(f"Epoch {epoch:03d} | Train Scaled MSE: {train_loss_scaled:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Epoch Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")
        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled; best_val_hourly_metrics_original = val_hourly_metrics_original; best_epoch = epoch
            patience_counter = 0; torch.save(model.state_dict(), model_save_path)
            print(f"                     ---> Best model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement."); break

    total_training_duration = time.time() - train_start_time
    report_data['total_training_time_seconds'] = total_training_duration
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    report_data['num_epochs_trained'] = epoch; report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    model_for_eval = GCNLSTMModelWithHourlyHeads( # Use GCNLSTMModel
        static_node_in_dim=static_node_in_dim, global_env_in_dim=global_env_in_dim, time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16), time_emb_dim=config.get('time_emb_dim', 8),
        gcn_hidden_dim=config.get('gcn_hidden_dim', 128),
        gcn_output_dim=config.get('gcn_output_dim', 128),
        lstm_hidden_dim=config.get('gru_hidden_dim', 128),
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1), T_pred_horizon=T_PRED_HORIZON,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_gcn=config.get('dropout_rate_gcn', 0.3),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2)
    ).to(device)
    try: model_for_eval.load_state_dict(torch.load(model_save_path, map_location=device))
    except Exception as e: print(f"无法加载最佳模型 ({e})，将使用训练循环结束时的模型。"); model_for_eval = model

    print("\n评估最佳模型在训练集上..."); best_model_train_loss_scaled, best_model_train_hourly_metrics, train_eval_duration = evaluate_epoch(model_for_eval, train_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Best Model on Train")
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics; report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration; _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original: report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    else:
        _, reeval_val_metrics, val_eval_duration = evaluate_epoch(model_for_eval, val_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Best Model on Val (Re-eval)")
        report_data['best_model_validation_set_metrics_hourly'] = reeval_val_metrics
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(reeval_val_metrics, T_PRED_HORIZON)
        report_data['best_model_validation_set_eval_time_seconds'] = val_eval_duration
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)


    print("\n评估最佳模型在测试集上..."); test_loss_scaled, test_hourly_metrics_original, test_inference_duration = evaluate_epoch(model_for_eval, test_loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Test")
    report_data['test_set_inference_time_seconds'] = test_inference_duration; report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original; report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (GCN-LSTM) " + "="*20); print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}") # Renamed
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    report_file_path = Path(config['results_dir']) / f"training_report_gcn_lstm_seed{seed}.json" # Renamed
    try:
        class NpEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, np.integer): return int(obj)
                if isinstance(obj, np.floating): return float(obj)
                if isinstance(obj, np.ndarray): return obj.tolist()
                if isinstance(obj, torch.Tensor): return obj.tolist()
                if isinstance(obj, Path): return str(obj)
                return super(NpEncoder, self).default(obj)
        with open(report_file_path, 'w') as f: json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e: print(f"保存训练报告失败: {e}")

    return model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std

def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by the aggregated averages.

    Args:
        set_name: Label inserted into the printed headings.
        hourly_metrics: Mapping from hour index to a dict holding the keys
            'r2', 'mse', 'mae', 'rmse' and 'count'. ``None`` prints a single
            "not available" line and returns immediately.
        T_pred_horizon: Number of hour rows to print (indices 0..T-1).
        indent: Prefix prepended to every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")

    hours = range(T_pred_horizon)
    # The column header is only shown when at least one hour row follows.
    if hours:
        print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")

    # Row used for hours that are missing from the mapping.
    missing_row = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0}
    for hour_idx in hours:
        metrics = hourly_metrics.get(hour_idx, missing_row)
        print(f"{indent}  {hour_idx:02d}   | {metrics.get('r2', np.nan):13.4f} | {metrics.get('mse', np.nan):14.4f} | {metrics.get('mae', np.nan):14.4f} | {metrics.get('rmse', np.nan):15.4f} | {metrics.get('count', 0):7d}")

    # Mean/std across hours are computed by the shared report helper.
    aggregated = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    print(f"{indent}  Aggregated Avg R2   : {aggregated.get('avg_r2', np.nan):.4f} (Std: {aggregated.get('std_r2', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg MSE  : {aggregated.get('avg_mse', np.nan):.4f} (Std: {aggregated.get('std_mse', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg MAE  : {aggregated.get('avg_mae', np.nan):.4f} (Std: {aggregated.get('std_mae', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg RMSE : {aggregated.get('avg_rmse', np.nan):.4f} (Std: {aggregated.get('std_rmse', np.nan):.4f})")

# ===========================================================
# 6. 主执行块
# ===========================================================
# Script entry point for the GCN-LSTM cell: configure paths and
# hyper-parameters, load and validate the pickled graph sequences, then run
# training and evaluation.
if __name__ == "__main__":
    # Free Python garbage and cached CUDA memory left over from earlier cells.
    gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    if not DRIVE_BASE_PATH.exists(): DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the data-loader cell at the top of this file reads the same
    # filename from "Result/Sequential_13Hour_Data"; confirm which directory is
    # the intended one for this cell.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    RESULTS_SUBDIR = Path("Result/Final_GCNLSTM1") # Renamed output subdir
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    os.makedirs(RESULTS_SAVE_DIR, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME


    # Calendar date and first hour used to build the time-feature timeline;
    # PREDICTION_HORIZON is the number of hourly steps in that timeline.
    DATA_YEAR = 2023; DATA_MONTH = 5; DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8; PREDICTION_HORIZON = 12

    # Hyper-parameters handed to the training function. The 'gru_*' keys keep
    # their historical names but are read as the LSTM settings by the trainer.
    training_config = {
        'seed': 42, 'batch_size': 8, 'lr': 0.001, 'weight_decay': 1e-5,
        'max_epochs': 1000, 'scheduler_patience': 20, 'early_stopping_patience': 45,
        'T_pred_horizon': PREDICTION_HORIZON, 'results_dir': str(RESULTS_SAVE_DIR),
        'global_env_emb_dim': 16, 'time_emb_dim': 8,
        'gcn_hidden_dim': 128,
        'gcn_output_dim': 128,
        # 'num_relations' and 'gine_edge_dim' are not needed for GCN
        'gru_hidden_dim': 128, # This will be lstm_hidden_dim
        'num_gru_layers': 1,  # This will be num_lstm_layers
        'mlp_prediction_hidden_dim': 64,
        'fusion_mlp_output_dim': 128,
        'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2, 'dropout_rate_encoders': 0.1,
        'dropout_rate_gcn': 0.3,
        'dropout_rate_gru': 0.2,  # This will be dropout_rate_lstm
        'dropout_rate_pred_head': 0.2,
        'use_amp': False, 'enable_profiler': False, 'num_workers': 0,
        'pin_memory': False, 'train_split_ratio': 0.7, 'val_split_ratio': 0.2,
        'h0_from_first_step': True
    }

    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists(): raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point DATA_PATH at data you produced or otherwise trust.
        with open(DATA_PATH, "rb") as f: all_graph_sequences_loaded = pickle.load(f)
        # Top-level object must be a non-empty list whose first item is a non-empty list.
        if not all_graph_sequences_loaded or not isinstance(all_graph_sequences_loaded, list) or not all_graph_sequences_loaded[0] or not isinstance(all_graph_sequences_loaded[0], list):
            raise ValueError("加载的数据格式不正确。")

        # Each sequence must hold one more step than the prediction horizon
        # (presumably the extra leading step seeds the recurrent state -- confirm
        # against the model's forward pass).
        expected_len_per_sequence = training_config['T_pred_horizon'] + 1

        processed_sequences = []
        for i, seq in enumerate(all_graph_sequences_loaded):
            # Sequences of the wrong type or length are silently dropped.
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence: continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                # Every step needs x, edge_index and graph_global_env_features;
                # every step after the first additionally needs a target y.
                if not isinstance(graph_step_data, Data) or not hasattr(graph_step_data, 'x') or graph_step_data.x is None or \
                   not hasattr(graph_step_data, 'edge_index') or graph_step_data.edge_index is None or \
                   not hasattr(graph_step_data, 'graph_global_env_features') or \
                   (step_idx > 0 and (not hasattr(graph_step_data, 'y') or graph_step_data.y is None)):
                    valid_seq = False; break
                # Normalise 1-D targets to a column vector (mutates the loaded object in place).
                if step_idx > 0 and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq: processed_sequences.append(seq)
        if not processed_sequences: raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    # Best-effort: any loading/validation failure is reported and training is skipped below.
    except Exception as e: print(f"加载或验证数据时发生错误: {e}"); all_graph_sequences = None

    if all_graph_sequences:
        # Build one row of time features per predicted hour, starting at the configured hour.
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_dataset_timeline = generate_time_features_for_sequence(
            base_datetime_for_timeline,
            training_config['T_pred_horizon']
        )

        # Returns the evaluated model together with the normalisation statistics.
        trained_model, final_node_mean, final_node_std, final_target_mean, final_target_std = main_training_gcn_lstm_hourly_heads( # Renamed call
            all_graph_sequences, training_config, time_features_for_dataset_timeline
        )
        print("GCN-LSTM 模型训练和评估完成!")
    else:
        print("由于数据加载失败或数据为空，训练流程未启动。")

##GINE+LSTM

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GINEConv # Using GINEConv
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta
from pathlib import Path
import time
import json
import torchprofile

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


# ===========================================================
# 1. 特征生成 & 辅助模块 (No changes)
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps):
    """Build cyclic hour-of-day / day-of-year features for consecutive hours.

    Args:
        base_dt_obj: ``datetime`` of the first step.
        num_steps: Number of hourly steps to generate.

    Returns:
        A float32 tensor of shape ``(num_steps, 4)`` whose columns are
        ``[hour_sin, hour_cos, doy_sin, doy_cos]``.
    """
    two_pi = 2 * math.pi
    rows = []
    for step in range(num_steps):
        moment = base_dt_obj + timedelta(hours=step)

        # Gregorian leap-year rule decides the day-of-year denominator.
        year = moment.year
        is_leap_year = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap_year else 365.0

        # NOTE: the hour is divided by 23, so hour 0 and hour 23 land on the
        # same angle (0 and 2*pi) and therefore share the same sin/cos pair.
        hour_angle = two_pi * (moment.hour / 23.0)
        doy_angle = two_pi * (moment.timetuple().tm_yday / days_in_year)

        rows.append(
            torch.tensor(
                [math.sin(hour_angle), math.cos(hour_angle), math.sin(doy_angle), math.cos(doy_angle)],
                dtype=torch.float32,
            )
        )
    return torch.stack(rows)


class MLPEncoder(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        hid_dim: Hidden width. When ``None`` it defaults to
            ``max(min(in_dim, out_dim), (in_dim + out_dim) // 2)``, falling
            back to ``out_dim``, then ``in_dim``, then 1 if that value is 0.
        dropout_rate: Dropout probability applied after the LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        if hid_dim is None:
            hid_dim = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            if hid_dim == 0:
                # Degenerate sizes: pick the first positive candidate, else 1.
                if out_dim > 0:
                    hid_dim = out_dim
                elif in_dim > 0:
                    hid_dim = in_dim
                else:
                    hid_dim = 1

        # Layer order fixes the state-dict keys (mlp.0, mlp.2, mlp.4).
        layers = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.LayerNorm(hid_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hid_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        encoded = self.mlp(x)
        return encoded

class GINEModule(nn.Module):
    """Three stacked GINEConv layers followed by a linear output projection.

    Each GINEConv wraps a Linear -> ReLU -> Linear network and is followed by
    BatchNorm1d and a per-channel PReLU. Dropout is applied after the first
    two layers only.

    Args:
        gine_input_dim: Node feature size fed to the first layer.
        gine_hidden_dim: Width of all three GINE layers.
        gine_output_dim: Node feature size produced by ``lin_out``.
        edge_feature_dim: Edge feature size passed to GINEConv as ``edge_dim``.
        dropout_rate: Dropout probability used between layers.
    """

    def __init__(self, gine_input_dim, gine_hidden_dim, gine_output_dim, edge_feature_dim, dropout_rate=0.5):
        super().__init__()
        self.gine_input_dim = gine_input_dim
        self.gine_hidden_dim = gine_hidden_dim
        self.gine_output_dim = gine_output_dim
        self.edge_feature_dim = edge_feature_dim

        # Layer 1: input width -> hidden width.
        nn1 = nn.Sequential(
            nn.Linear(gine_input_dim, gine_hidden_dim),
            nn.ReLU(),
            nn.Linear(gine_hidden_dim, gine_hidden_dim)
        )
        self.conv1 = GINEConv(nn1, edge_dim=self.edge_feature_dim)
        self.bn1 = nn.BatchNorm1d(gine_hidden_dim)
        self.prelu1 = nn.PReLU(gine_hidden_dim)

        # Layer 2: hidden width -> hidden width.
        nn2 = nn.Sequential(
            nn.Linear(gine_hidden_dim, gine_hidden_dim),
            nn.ReLU(),
            nn.Linear(gine_hidden_dim, gine_hidden_dim)
        )
        self.conv2 = GINEConv(nn2, edge_dim=self.edge_feature_dim)
        self.bn2 = nn.BatchNorm1d(gine_hidden_dim)
        self.prelu2 = nn.PReLU(gine_hidden_dim)

        # Layer 3: hidden width -> hidden width.
        nn3 = nn.Sequential(
            nn.Linear(gine_hidden_dim, gine_hidden_dim),
            nn.ReLU(),
            nn.Linear(gine_hidden_dim, gine_hidden_dim)
        )
        self.conv3 = GINEConv(nn3, edge_dim=self.edge_feature_dim)
        self.bn3 = nn.BatchNorm1d(gine_hidden_dim)
        self.prelu3 = nn.PReLU(gine_hidden_dim)

        self.lin_out = nn.Linear(gine_hidden_dim, gine_output_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, edge_index, edge_attr):
        """Encode node features.

        Args:
            x: Node feature matrix; rows are nodes.
            edge_index: Edge connectivity forwarded to each GINEConv.
            edge_attr: Edge features forwarded to each GINEConv.

        Returns:
            Node embeddings with ``gine_output_dim`` columns. For an input
            with zero nodes an empty ``(0, gine_output_dim)`` tensor with the
            dtype and device of ``x`` is returned.
        """
        if x.size(0) == 0:
            # Previously the raw input was returned here, which has
            # gine_input_dim columns and breaks every downstream layer that
            # expects gine_output_dim whenever the two widths differ.
            return x.new_zeros((0, self.gine_output_dim))

        x = self.conv1(x, edge_index, edge_attr=edge_attr)
        # BatchNorm is skipped when only a single node is present.
        if x.size(0) > 1: x = self.bn1(x)
        x = self.prelu1(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index, edge_attr=edge_attr)
        if x.size(0) > 1: x = self.bn2(x)
        x = self.prelu2(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index, edge_attr=edge_attr)
        if x.size(0) > 1: x = self.bn3(x)
        x = self.prelu3(x)

        x = self.lin_out(x)
        return x

# ===========================================================
# 2. GINE-LSTM 模型定义
# ===========================================================
class GINELSTMModelWithHourlyHeads(nn.Module):
    """GINE graph encoder + LSTM with one prediction head per forecast hour.

    The first graph of every sequence is encoded by ``gine_module_for_h0`` and
    mapped to the LSTM's initial hidden state (the cell state starts at zero).
    Each of the following ``T_pred_horizon`` graphs is encoded by
    ``gine_module_for_sequence``, concatenated with an embedded global
    environment vector and an embedded time vector, fused by an MLP and fed to
    the LSTM. Every LSTM output step is decoded by its own small MLP head.
    """

    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 gine_hidden_dim,
                 gine_output_dim,
                 gine_edge_dim,
                 lstm_hidden_dim,
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 num_lstm_layers=1,
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_gine=0.3,
                 dropout_rate_lstm=0.2,
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2
                ):
        super().__init__()
        # Dimensions kept as attributes (read by profiling / forward code).
        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.gine_output_dim = gine_output_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.gine_edge_dim = gine_edge_dim

        # Small MLP encoders for graph-level inputs and for the h0 projection.
        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders)
        self.h0_c0_from_gine_encoder = MLPEncoder(gine_output_dim, lstm_hidden_dim, dropout_rate=dropout_rate_encoders)

        # Two independent graph encoders: one for the initial state, one for the sequence.
        self.gine_module_for_h0 = GINEModule(
            static_node_in_dim, gine_hidden_dim, gine_output_dim,
            edge_feature_dim=gine_edge_dim, dropout_rate=dropout_rate_gine,
        )
        self.gine_module_for_sequence = GINEModule(
            static_node_in_dim, gine_hidden_dim, gine_output_dim,
            edge_feature_dim=gine_edge_dim, dropout_rate=dropout_rate_gine,
        )

        # Fusion MLP: defaults to an output as wide as its concatenated input.
        fused_in_dim = gine_output_dim + global_env_emb_dim + time_emb_dim
        fused_out_dim = fused_in_dim if fusion_mlp_output_dim is None else fusion_mlp_output_dim
        self.fusion_mlp_input_dim = fused_in_dim
        self.fusion_mlp = MLPEncoder(
            in_dim=fused_in_dim,
            out_dim=fused_out_dim,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        self.lstm_input_dim = fused_out_dim  # exposed for profiling
        self.lstm = nn.LSTM(
            input_size=fused_out_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            # Inter-layer dropout is only enabled for stacked LSTMs.
            dropout=dropout_rate_lstm if num_lstm_layers > 1 else 0.0
        )

        # One independent regression head per forecast hour.
        self.hourly_prediction_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(lstm_hidden_dim, mlp_prediction_hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate_pred_head),
                nn.Linear(mlp_prediction_hidden_dim, 1)
            )
            for _ in range(T_pred_horizon)
        ])

        # Node-feature normalisation statistics; overwritten by the training loop.
        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))

    def _edge_attr_or_zeros(self, pyg_batch, device):
        """Return the batch's edge attributes as float, or zeros if it has none."""
        edge_attr = pyg_batch.edge_attr
        if edge_attr is not None:
            return edge_attr.float()
        edge_count = pyg_batch.edge_index.size(1)
        return torch.zeros(edge_count, self.gine_edge_dim, device=device)

    def _coerce_global_env_features(self, global_env_feat_t_unencoded, expected_num_graphs_in_batch_t, t_pred_idx, device):
        """Return global-env features shaped ``(num_graphs, in_features)``.

        A tensor with the right element count is reshaped; anything else is
        replaced by zeros after printing a warning.
        """
        expected_global_features_dim = self.global_env_encoder.mlp[0].in_features
        wanted_shape = (expected_num_graphs_in_batch_t, expected_global_features_dim)
        if global_env_feat_t_unencoded.shape == wanted_shape:
            return global_env_feat_t_unencoded
        if global_env_feat_t_unencoded.numel() == expected_num_graphs_in_batch_t * expected_global_features_dim:
            return global_env_feat_t_unencoded.view(expected_num_graphs_in_batch_t, expected_global_features_dim)
        print(f"Warning: LSTM Input Time {t_pred_idx}: Correcting global_env_feat shape from {global_env_feat_t_unencoded.shape} to ({expected_num_graphs_in_batch_t}, {expected_global_features_dim}) with zeros due to mismatch.")
        return torch.zeros(expected_num_graphs_in_batch_t, expected_global_features_dim, device=device)

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict scaled targets for every node and forecast hour.

        Args:
            list_of_batched_timesteps: Batched graphs; index 0 seeds the LSTM
                state, indices 1..T_pred_horizon form the input sequence.
            timeline_time_features: Tensor indexed by forecast step on its
                first axis; row ``t`` is the raw time feature of step ``t``.
            device: Device all batches and features are moved to.

        Returns:
            Tensor stacking the per-hour head outputs along dim 1
            (one row per node, one column per forecast hour).
        """
        # ---- Initial hidden state from the first graph -------------------
        seed_batch = list_of_batched_timesteps[0].to(device)
        seed_x = (seed_batch.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)
        seed_edge_attr = self._edge_attr_or_zeros(seed_batch, device)
        seed_node_emb = self.gine_module_for_h0(
            seed_x,
            seed_batch.edge_index,
            edge_attr=seed_edge_attr
        )

        h0 = self.h0_c0_from_gine_encoder(seed_node_emb).unsqueeze(0)
        c0 = torch.zeros_like(h0)
        layer_count = self.lstm.num_layers
        if layer_count > 1:
            # Every LSTM layer starts from the same state.
            h0 = h0.repeat(layer_count, 1, 1)
            c0 = c0.repeat(layer_count, 1, 1)
        initial_hidden_state = (h0, c0)

        # ---- Per-step fused features --------------------------------------
        per_step_inputs = []
        for t_pred_idx in range(self.T_pred_horizon):
            step_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            step_x = (step_batch.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)
            step_edge_attr = self._edge_attr_or_zeros(step_batch, device)
            node_emb = self.gine_module_for_sequence(
                step_x,
                step_batch.edge_index,
                edge_attr=step_edge_attr
            )

            # Graph-level environment embedding, broadcast to nodes via batch ids.
            env_raw = self._coerce_global_env_features(
                step_batch.graph_global_env_features, step_batch.num_graphs, t_pred_idx, device
            )
            env_emb_per_node = self.global_env_encoder(env_raw)[step_batch.batch]

            # One time embedding per step, shared by every node of the batch.
            time_emb = self.time_encoder(timeline_time_features[t_pred_idx, :].to(device))
            time_emb_per_node = time_emb.unsqueeze(0).expand(step_batch.num_nodes, -1)

            fused = self.fusion_mlp(torch.cat([node_emb, env_emb_per_node, time_emb_per_node], dim=-1))
            per_step_inputs.append(fused)

        stacked_lstm_input_features = torch.stack(per_step_inputs, dim=1)

        # ---- Guard against node-count drift between seed and sequence -----
        if initial_hidden_state[0].shape[1] != stacked_lstm_input_features.shape[0]:
            print(f"CRITICAL WARNING: Node count mismatch for LSTM h0/c0 ({initial_hidden_state[0].shape[1]}) and LSTM input sequence ({stacked_lstm_input_features.shape[0]}).")
            if initial_hidden_state[0].shape[1] > stacked_lstm_input_features.shape[0]:
                # Only the "too many state rows" case is repaired, by truncation.
                node_count = stacked_lstm_input_features.shape[0]
                initial_hidden_state = tuple(state[:, :node_count, :] for state in initial_hidden_state)

        lstm_out, _ = self.lstm(stacked_lstm_input_features, initial_hidden_state)

        # ---- Hour-specific heads -------------------------------------------
        hourly_predictions = [
            self.hourly_prediction_heads[t](lstm_out[:, t, :]).squeeze(-1)
            for t in range(self.T_pred_horizon)
        ]
        return torch.stack(hourly_predictions, dim=1)

# ===========================================================
# 3. 评估指标函数 (在原始尺度上计算指标) - NO CHANGES
# ===========================================================
# ... (mse_loss_masked, calculate_hourly_metrics are identical) ...
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over selected nodes, ignoring NaN targets.

    Args:
        predictions_scaled: Predictions with the same shape as the targets.
        targets_scaled: Targets; NaN entries are excluded from the loss.
        mask: Per-row boolean selector, broadcast across the columns of
            ``targets_scaled``.

    Returns:
        The MSE over all selected, non-NaN entries, or a fresh zero scalar
        with ``requires_grad=True`` when no entry is selected.
    """
    keep = mask.unsqueeze(1).expand_as(targets_scaled) & ~torch.isnan(targets_scaled)
    if not keep.any():
        # Nothing to score: hand back a differentiable zero on the right device.
        return torch.tensor(0.0, device=predictions_scaled.device, requires_grad=True)
    return F.mse_loss(predictions_scaled[keep], targets_scaled[keep])

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute per-hour MSE / MAE / RMSE / R2 on the original target scale.

    Predictions and targets are un-scaled with ``target_mean`` / ``target_std``
    on the CPU, restricted to the rows selected by ``node_masks`` and, per
    hour, to entries whose target is not NaN.

    Returns:
        Dict mapping each hour index to a dict with keys 'mse', 'mae', 'rmse',
        'r2' and 'count'. Hours with fewer than two valid samples get NaN
        metrics and a count of 0.
    """
    mean_cpu = target_mean.cpu()
    std_cpu = target_std.cpu()

    def _to_original_scale(scaled):
        # Inverse of the (value - mean) / (std + 1e-8) normalisation.
        return scaled.clone().cpu() * (std_cpu + 1e-8) + mean_cpu

    preds_unscaled = _to_original_scale(predictions_scaled)
    targets_unscaled = _to_original_scale(targets_scaled)
    _, T_horizon = preds_unscaled.shape

    # The node selection does not depend on the hour, so apply it once.
    node_selector = node_masks.cpu().numpy()
    preds_selected = preds_unscaled.numpy()[node_selector]
    targets_selected = targets_unscaled.numpy()[node_selector]

    results = {}
    for t in range(T_horizon):
        hour_targets = targets_selected[:, t]
        has_target = ~np.isnan(hour_targets)
        y_pred = preds_selected[:, t][has_target]
        y_true = hour_targets[has_target]
        sample_count = y_pred.shape[0]

        if sample_count < 2:
            results[t] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            continue

        residual = y_pred - y_true
        mse = np.mean(residual**2)
        mae = np.mean(np.abs(residual))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(y_true, y_pred)
        except ValueError:
            r2 = np.nan
        results[t] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': sample_count}
    return results

# ===========================================================
# 4. 训练与评估循环 (适配y归一化) - NO CHANGES
# ===========================================================
# ... (train_epoch, evaluate_epoch are identical) ...
def train_epoch(model, loader, optimizer, device, timeline_time_features,
                node_feat_mean, node_feat_std, target_mean, target_std):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network called as ``model(batched_timesteps, time_features,
            device)``; must expose ``T_pred_horizon`` and the
            ``node_feat_mean`` / ``node_feat_std`` attributes.
        loader: Iterable yielding lists of batched graphs, one per time step
            (index 0 is the seed step, indices 1..T carry the targets ``y``).
        optimizer: Optimiser stepped once per batch with a usable loss.
        device: Device on which the computation runs.
        timeline_time_features: Time features forwarded to the model.
        node_feat_mean, node_feat_std: Node-feature statistics copied onto the
            model before training.
        target_mean, target_std: Statistics used to scale the targets.

    Returns:
        ``(avg_loss, epoch_duration)``: the sequence-weighted mean scaled MSE
        and the wall-clock duration of the epoch in seconds. Batches whose
        loss is NaN, infinite or not strictly positive are skipped for the
        optimiser step but still counted in the averaging denominator.
    """
    model.train()
    total_loss_scaled = 0
    num_sequences_processed = 0
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)

    epoch_start_time = time.time()
    # Loop-invariant: move the time features to the device once per epoch.
    time_features_dev = timeline_time_features.to(device)

    for list_of_batched_timesteps in loader:
        optimizer.zero_grad()
        predictions_batch_scaled = model(list_of_batched_timesteps, time_features_dev, device)

        # Loss is computed on nodes that are NOT flagged by building_mask,
        # taken from the first predicted time step.
        first_predicted_timestep_batch = list_of_batched_timesteps[1].to(device)
        mask_for_loss = ~first_predicted_timestep_batch.building_mask

        targets_list_for_loss_scaled = []
        for t_pred_idx in range(model.T_pred_horizon):
            current_target_timestep_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            # reshape(-1) instead of squeeze(): squeeze() turns a single-node
            # batch into a 0-dim tensor, which cannot be stacked along dim=1.
            targets_t_nodes_original = current_target_timestep_batch.y.reshape(-1)
            targets_t_nodes_scaled = (targets_t_nodes_original - target_mean_dev) / (target_std_dev + 1e-8)
            targets_list_for_loss_scaled.append(targets_t_nodes_scaled)

        targets_batch_scaled = torch.stack(targets_list_for_loss_scaled, dim=1)
        loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled, mask_for_loss)
        num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

        # A zero loss is what mse_loss_masked returns when nothing is selected;
        # such batches (and non-finite losses) do not update the weights.
        if not torch.isnan(loss) and not torch.isinf(loss) and loss.item() > 0 :
            loss.backward()
            optimizer.step()
            total_loss_scaled += loss.item() * num_sequences_in_this_super_batch
        num_sequences_processed += num_sequences_in_this_super_batch

    epoch_duration = time.time() - epoch_start_time
    avg_loss = total_loss_scaled / num_sequences_processed if num_sequences_processed > 0 else 0.0
    return avg_loss, epoch_duration

def evaluate_epoch(model, loader, device, timeline_time_features,
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; must expose ``T_pred_horizon``.
        loader: Iterable of batched-timestep sequences (indexed like the
            training loader: element 0 for ``num_graphs``, element 1 for
            ``building_mask``, elements ``1 .. T_pred_horizon`` for ``y``).
        device: Device the tensors are moved to.
        timeline_time_features: Time-feature tensor forwarded to the model.
        node_feat_mean, node_feat_std: Node-feature statistics attached to the
            model as attributes.
        target_mean, target_std: Statistics used to standardise the targets;
            CPU copies are also passed to ``calculate_hourly_metrics``.
        epoch_type: Label for the evaluation pass; not used inside this function.

    Returns:
        ``(avg_loss_scaled, hourly_metrics, eval_duration)``. When the loader
        yields nothing, ``hourly_metrics`` holds NaN entries with ``count`` 0
        for every horizon step.
    """
    model.eval()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    mean_on_device = target_mean.to(device)
    std_on_device = target_std.to(device)

    collected_preds = []
    collected_targets = []
    collected_masks = []
    loss_sum = 0
    sequences_seen = 0

    started_at = time.time()
    with torch.no_grad():
        for timestep_batches in loader:
            preds_scaled = model(timestep_batches, timeline_time_features.to(device), device)

            # Mask comes from the first predicted timestep and is reused for all hours.
            metric_mask = ~timestep_batches[1].to(device).building_mask

            raw_targets = torch.stack(
                [timestep_batches[hour + 1].to(device).y.squeeze()
                 for hour in range(model.T_pred_horizon)],
                dim=1,
            )
            targets_scaled = (raw_targets - mean_on_device) / (std_on_device + 1e-8)

            loss = mse_loss_masked(preds_scaled, targets_scaled, metric_mask)
            graphs_in_batch = timestep_batches[0].num_graphs

            # Non-finite losses are left out of the sum but the batch still counts.
            if not (torch.isnan(loss) or torch.isinf(loss)):
                loss_sum += loss.item() * graphs_in_batch
            sequences_seen += graphs_in_batch

            # Keep per-batch results on the CPU until they are concatenated below.
            collected_preds.append(preds_scaled.cpu())
            collected_targets.append(targets_scaled.cpu())
            collected_masks.append(metric_mask.cpu())

    eval_duration = time.time() - started_at
    if sequences_seen > 0:
        avg_loss_scaled = loss_sum / sequences_seen
    else:
        avg_loss_scaled = 0.0

    if not collected_preds:
        nan_metrics = {
            hour: {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0}
            for hour in range(model.T_pred_horizon)
        }
        return avg_loss_scaled, nan_metrics, eval_duration

    hourly_metrics_original_scale = calculate_hourly_metrics(
        torch.cat(collected_preds, dim=0),
        torch.cat(collected_targets, dim=0),
        torch.cat(collected_masks, dim=0),
        target_mean.cpu(),
        target_std.cpu(),
    )
    return avg_loss_scaled, hourly_metrics_original_scale, eval_duration

# ===========================================================
# 5. 主训练流程 (GINE-LSTM)
# ===========================================================

def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Summarise per-hour metrics as a mean and standard deviation across hours.

    Args:
        hourly_metrics_dict: Mapping ``hour_index -> {'r2', 'mse', 'mae', 'rmse', ...}``.
        T_pred_horizon: Number of hours to consider (indices ``0 .. T_pred_horizon - 1``).

    Returns:
        Dict with ``avg_<metric>`` and ``std_<metric>`` for r2, mse, mae and
        rmse. Hours missing from the mapping and NaN values are ignored; a
        metric with no usable value yields NaN for both entries.
    """
    report = {}
    for name in ('r2', 'mse', 'mae', 'rmse'):
        usable = []
        for hour in range(T_pred_horizon):
            if hour not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[hour][name]
            if np.isnan(value):
                continue
            usable.append(value)

        has_values = bool(usable)
        report[f'avg_{name}'] = np.mean(usable) if has_values else np.nan
        report[f'std_{name}'] = np.std(usable) if has_values else np.nan
    return report

def _build_gine_lstm_model(config, static_node_in_dim, global_env_in_dim, time_in_dim,
                           gine_edge_dim, T_pred_horizon, device):
    """Instantiate ``GINELSTMModelWithHourlyHeads`` from ``config`` on ``device``.

    Used for both the training model and the evaluation copy so the two
    architectures cannot drift apart (a mismatch would break ``load_state_dict``).
    The config keeps its historical ``gcn_*`` / ``gru_*`` key names; they are
    mapped onto the GINE / LSTM constructor arguments here.
    """
    return GINELSTMModelWithHourlyHeads(
        static_node_in_dim=static_node_in_dim,
        global_env_in_dim=global_env_in_dim,
        time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16),
        time_emb_dim=config.get('time_emb_dim', 8),
        gine_hidden_dim=config.get('gcn_hidden_dim', 128),
        gine_output_dim=config.get('gcn_output_dim', 128),
        gine_edge_dim=gine_edge_dim,
        lstm_hidden_dim=config.get('gru_hidden_dim', 128),
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1),
        T_pred_horizon=T_pred_horizon,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_gine=config.get('dropout_rate_gcn', 0.3),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2),
    ).to(device)


def _fit_gine_lstm_train_scalers(train_dataset):
    """Compute normalisation statistics from the training sequences only.

    Node-feature statistics use ``x`` of every timestep; target statistics use
    ``y`` of timesteps after the first, restricted to entries where
    ``building_mask`` is False and the target is not NaN.

    Returns:
        ``(node_feat_mean, node_feat_std, target_mean, target_std)``.

    Raises:
        ValueError: If no timestep in ``train_dataset`` carries node features.
    """
    node_feature_chunks = []
    target_chunks = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if getattr(graph_data, 'x', None) is not None:
                node_feature_chunks.append(graph_data.x)
            if i_step > 0 and getattr(graph_data, 'y', None) is not None:
                y_original = graph_data.y.squeeze()
                valid_target_indices = ~graph_data.building_mask & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0:
                    target_chunks.append(y_original[valid_target_indices])

    if not node_feature_chunks:
        raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")

    all_node_features = torch.cat(node_feature_chunks, dim=0)
    node_feat_mean = torch.mean(all_node_features, dim=0)
    node_feat_std = torch.std(all_node_features, dim=0)
    # Constant columns (std ~ 0) and undefined std (NaN from a single sample)
    # fall back to 1 so the later division is a no-op instead of blowing up.
    node_feat_std[~torch.isfinite(node_feat_std) | (node_feat_std < 1e-8)] = 1.0

    if not target_chunks:
        target_mean = torch.tensor(0.0)
        target_std = torch.tensor(1.0)
    else:
        all_targets = torch.cat(target_chunks, dim=0).float()
        target_mean = torch.mean(all_targets)
        target_std = torch.std(all_targets)
        if not torch.isfinite(target_std) or target_std < 1e-8:
            target_std = torch.tensor(1.0)
    return node_feat_mean, node_feat_std, target_mean, target_std


def _profile_gine_lstm_component_gmacs(model, T_pred_horizon, device):
    """Estimate GMACs per model component on fixed-size dummy inputs (best effort).

    Every component is profiled inside its own ``try`` so a profiling failure
    is recorded as ``"Error"`` and never aborts training. The model is switched
    to eval mode for profiling and back to train mode before returning.

    Returns:
        Dict mapping component name to GMACs (float) or ``"Error"``.
    """
    print("\nCalculating MACs for model components (approximate FLOPS):")
    gmacs = {}
    dummy_nodes_component = 2500
    dummy_edges_component = 60000
    dummy_batch_global_comp = 1
    model.eval()

    # 1. GINEModule
    try:
        gine_module_to_profile = model.gine_module_for_h0
        dummy_x_gine = torch.randn(dummy_nodes_component, gine_module_to_profile.gine_input_dim, device=device)
        dummy_ei_gine = torch.randint(0, dummy_nodes_component, (2, dummy_edges_component), device=device)
        dummy_ea_gine = torch.randn(dummy_edges_component, gine_module_to_profile.edge_feature_dim, device=device)

        macs_gine = torchprofile.profile_macs(gine_module_to_profile, args=(dummy_x_gine, dummy_ei_gine, dummy_ea_gine))
        gmacs['gine_module'] = macs_gine / 1e9
        print(f"  GINEModule GMACs (edge_dim={gine_module_to_profile.edge_feature_dim}): {macs_gine / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling GINEModule: {e}")
        gmacs['gine_module'] = "Error"

    # 2. LSTM layer: computed by formula rather than with the profiler.
    print("  Manually Calculating MACs for LSTM Layer:")
    try:
        lstm_layer = model.lstm
        N_nodes = dummy_nodes_component
        L_seq = T_pred_horizon
        H_in = lstm_layer.input_size
        H_hidden = lstm_layer.hidden_size
        num_layers = lstm_layer.num_layers

        # First layer: MACs ~ N * L * 4 * (H_in * H_hidden + H_hidden^2);
        # every further layer takes H_hidden-sized input.
        macs_lstm_manual = N_nodes * L_seq * 4 * (H_in * H_hidden + H_hidden * H_hidden)
        if num_layers > 1:
            macs_lstm_manual += N_nodes * L_seq * (num_layers - 1) * 4 * (H_hidden * H_hidden + H_hidden * H_hidden)

        gmacs_lstm_manual = macs_lstm_manual / 1e9
        gmacs['lstm_layer'] = gmacs_lstm_manual
        gmacs['lstm_layer_profiling_notes'] = "Manually calculated based on formula."
        print(f"  LSTM Parameters: input_size={H_in}, hidden_size={H_hidden}, num_layers={num_layers}")
        print(f"  Used for calculation: N_nodes={N_nodes}, L_seq={L_seq}")
        print(f"  LSTM Layer GMACs (Manual): {gmacs_lstm_manual:.4f} (for sequence length {L_seq})")
    except Exception as e:
        print(f"  Error manually calculating LSTM Layer MACs: {e}")
        gmacs['lstm_layer'] = "Error"
        gmacs['lstm_layer_profiling_notes'] = f"Error during manual calculation: {str(e)}"

    def _profile_mlp(report_key, ok_label, error_label, get_module, get_in_dim, rows):
        # The getters are called inside the try so a missing attribute is
        # reported like any other profiling failure.
        try:
            module = get_module()
            dummy_input = torch.randn(rows, get_in_dim(), device=device)
            macs = torchprofile.profile_macs(module, args=(dummy_input,))
            gmacs[report_key] = macs / 1e9
            print(f"  {ok_label} GMACs: {macs / 1e9:.4f}")
        except Exception as e:
            print(f"  Error profiling {error_label}: {e}")
            gmacs[report_key] = "Error"

    # 3. Fusion MLP, 4. Prediction Head, 5. Global Env Encoder, 6. Time Encoder,
    # 7. H0/C0 from GINE Encoder
    _profile_mlp('fusion_mlp', "Fusion MLP", "Fusion MLP",
                 lambda: model.fusion_mlp, lambda: model.fusion_mlp_input_dim,
                 dummy_nodes_component)
    _profile_mlp('prediction_head_mlp', "Prediction Head MLP (single hour)", "Prediction Head",
                 lambda: model.hourly_prediction_heads[0], lambda: model.lstm_hidden_dim,
                 dummy_nodes_component)
    _profile_mlp('global_env_encoder_mlp', "Global Env Encoder MLP", "Global Env Encoder",
                 lambda: model.global_env_encoder, lambda: model.global_env_in_dim,
                 dummy_batch_global_comp)
    _profile_mlp('time_encoder_mlp', "Time Encoder MLP", "Time Encoder",
                 lambda: model.time_encoder, lambda: model.time_in_dim,
                 dummy_batch_global_comp)
    _profile_mlp('h0_c0_from_gine_encoder_mlp', "H0/C0 from GINE Encoder MLP", "H0/C0 from GINE Encoder",
                 lambda: model.h0_c0_from_gine_encoder, lambda: model.gine_output_dim,
                 dummy_nodes_component)

    model.train()
    return gmacs


def main_training_gine_lstm_hourly_heads( # Renamed
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor
):
    """Train, validate and test the GINE-LSTM model and write a JSON report.

    Steps: seed the RNGs, keep only sequences of length ``T_pred_horizon + 1``,
    split them randomly into train/val/test, fit scalers on the training split,
    profile the model components, train with early stopping on the scaled
    validation loss, reload the best checkpoint and evaluate it on all splits.

    Args:
        all_sequences_data: List of sequences, each a list of graph timesteps.
        config: Hyper-parameters and paths; ``config['results_dir']`` is required.
        time_features_for_dataset: Time-feature tensor; its second dimension
            sets the model's time input size.

    Returns:
        ``(model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std)``.

    Raises:
        KeyError: If ``config`` has no ``'results_dir'``.
        ValueError: If no sequence has the expected length, or the training
            split carries no node features.
    """
    train_start_time = time.time()
    report_data = {'config': config}

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    # Scalers, checkpoint and report are all written here; create it up front
    # so the first torch.save cannot fail on a missing directory.
    results_dir = Path(config['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    expected_input_len = T_PRED_HORIZON + 1
    valid_sequences_data = [seq for seq in all_sequences_data if isinstance(seq, list) and len(seq) == expected_input_len]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # ----- Dataset split (random permutation; remainder goes to the test set) -----
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]
    train_dataset = [all_sequences_data[i] for i in train_indices]
    val_dataset = [all_sequences_data[i] for i in val_indices]
    test_dataset = [all_sequences_data[i] for i in test_indices]
    report_data['dataset_split'] = {'total_sequences': num_total_sequences, 'train_size': len(train_dataset), 'val_size': len(val_dataset), 'test_size': len(test_dataset)}

    # ----- Scalers (fitted on the training split only) -----
    node_feat_mean, node_feat_std, target_mean, target_std = _fit_gine_lstm_train_scalers(train_dataset)
    scaler_path_x = results_dir / "node_feature_scaler_gine_lstm.pth"
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x)
    print(f"节点特征x scaler已保存到: {scaler_path_x}")
    target_scaler_path = results_dir / "target_scaler_gine_lstm.pth"
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path)
    print(f"目标值y scaler已保存到: {target_scaler_path}")

    # ----- DataLoaders -----
    batch_size = config.get('batch_size', 8)
    num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory_flag)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)

    # ----- Model initialisation (dimensions read from the first timestep) -----
    sample_graph_7am_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_7am_for_dims.x.shape[1]
    global_env_features = sample_graph_7am_for_dims.graph_global_env_features
    global_env_in_dim = global_env_features.shape[0] if global_env_features.ndim == 1 else global_env_features.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]

    if hasattr(sample_graph_7am_for_dims, 'edge_attr') and sample_graph_7am_for_dims.edge_attr is not None:
        inferred_gine_edge_dim = sample_graph_7am_for_dims.edge_attr.shape[1]
    else:
        inferred_gine_edge_dim = config.get('gine_edge_dim', 5)
        print(f"Warning: Could not infer gine_edge_dim from sample data, using config or default: {inferred_gine_edge_dim}")
    # An explicit config value takes precedence over the inferred one.
    gine_edge_dim_config = config.get('gine_edge_dim', inferred_gine_edge_dim)

    model = _build_gine_lstm_model(config, static_node_in_dim, global_env_in_dim, time_in_dim,
                                   gine_edge_dim_config, T_PRED_HORIZON, device)
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型总参数量 (GINE-LSTM): {total_params:,}")
    report_data['model_total_parameters'] = total_params

    report_data['component_gmacs'] = _profile_gine_lstm_component_gmacs(model, T_PRED_HORIZON, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.get('lr', 0.001), weight_decay=config.get('weight_decay', 1e-5))
    # `verbose` is deprecated for LR schedulers since torch 2.2, so it is not
    # passed; the current LR is already printed in the per-epoch line below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 20))

    # ----- Training loop with early stopping on the scaled validation loss -----
    best_val_loss_scaled = float('inf')
    best_val_hourly_metrics_original = None
    best_epoch = 0
    patience_counter = 0
    max_epochs = config.get('max_epochs', 300)
    early_stopping_patience = config.get('early_stopping_patience', 45)
    model_save_path = results_dir / f"best_ginelstm_hourly_heads_model_seed{seed}.pth"
    timeline_time_features_on_device = time_features_for_dataset.to(device)
    target_mean_cpu = target_mean.cpu()
    target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device)
    target_std_on_device = target_std.to(device)
    epoch_times = []
    # Tracked separately so the report is well-defined even if the loop body
    # never runs (max_epochs < 1).
    num_epochs_trained = 0

    def _evaluate(eval_model, loader, epoch_type):
        return evaluate_epoch(eval_model, loader, device, timeline_time_features_on_device,
                              node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu,
                              epoch_type=epoch_type)

    for epoch in range(1, max_epochs + 1):
        num_epochs_trained = epoch
        train_loss_scaled, epoch_duration = train_epoch(model, train_loader, optimizer, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_on_device, target_std_on_device)
        epoch_times.append(epoch_duration)
        val_loss_scaled, val_hourly_metrics_original, _ = _evaluate(model, val_loader, "Validation")
        scheduler.step(val_loss_scaled)
        print(f"Epoch {epoch:03d} | Train Scaled MSE: {train_loss_scaled:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Epoch Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")
        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled
            best_val_hourly_metrics_original = val_hourly_metrics_original
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"                     ---> Best model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement.")
            break

    total_training_duration = time.time() - train_start_time
    report_data['total_training_time_seconds'] = total_training_duration
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    report_data['num_epochs_trained'] = num_epochs_trained
    report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    # ----- Reload the best checkpoint into a fresh model for the final evaluation -----
    model_for_eval = _build_gine_lstm_model(config, static_node_in_dim, global_env_in_dim, time_in_dim,
                                            gine_edge_dim_config, T_PRED_HORIZON, device)
    try:
        model_for_eval.load_state_dict(torch.load(model_save_path, map_location=device))
    except Exception as e:
        print(f"无法加载最佳模型 ({e})，将使用训练循环结束时的模型。")
        model_for_eval = model

    # Note: train_loader shuffles and drops the last partial batch, so this
    # train-set evaluation may skip a few sequences.
    print("\n评估最佳模型在训练集上...")
    _, best_model_train_hourly_metrics, train_eval_duration = _evaluate(model_for_eval, train_loader, "Best Model on Train")
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics
    report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration
    _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original:
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    else:
        # No validation metrics were captured during training; recompute them.
        _, reeval_val_metrics, val_eval_duration = _evaluate(model_for_eval, val_loader, "Best Model on Val (Re-eval)")
        report_data['best_model_validation_set_metrics_hourly'] = reeval_val_metrics
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(reeval_val_metrics, T_PRED_HORIZON)
        report_data['best_model_validation_set_eval_time_seconds'] = val_eval_duration
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)

    print("\n评估最佳模型在测试集上...")
    test_loss_scaled, test_hourly_metrics_original, test_inference_duration = _evaluate(model_for_eval, test_loader, "Test")
    report_data['test_set_inference_time_seconds'] = test_inference_duration
    report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original
    report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (GINE-LSTM) " + "="*20)
    print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}")
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    # ----- Persist the report; a failure here must not discard the trained model -----
    report_file_path = results_dir / f"training_report_gine_lstm_seed{seed}.json"
    try:
        class NpEncoder(json.JSONEncoder):
            """JSON encoder that also accepts NumPy scalars/arrays, tensors and paths."""

            def default(self, obj):
                if isinstance(obj, np.integer): return int(obj)
                if isinstance(obj, np.floating): return float(obj)
                if isinstance(obj, np.ndarray): return obj.tolist()
                if isinstance(obj, torch.Tensor): return obj.tolist()
                if isinstance(obj, Path): return str(obj)
                return super().default(obj)

        with open(report_file_path, 'w') as f:
            json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e:
        print(f"保存训练报告失败: {e}")

    return model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std

def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by the aggregated averages.

    Args:
        set_name: Label of the data split shown in the headings.
        hourly_metrics: Mapping ``hour_index -> metrics dict`` or ``None``.
        T_pred_horizon: Number of hourly rows to print.
        indent: Prefix put in front of every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")
    for hour_idx in range(T_pred_horizon):
        # Hours without an entry are shown as NaN with a zero count.
        row = hourly_metrics.get(hour_idx, {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0})
        if hour_idx == 0:
            # Column header goes out once, just before the first row.
            print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
        r2 = row.get('r2', np.nan)
        mse = row.get('mse', np.nan)
        mae = row.get('mae', np.nan)
        rmse = row.get('rmse', np.nan)
        count = row.get('count', 0)
        print(f"{indent}  {hour_idx:02d}   | {r2:13.4f} | {mse:14.4f} | {mae:14.4f} | {rmse:15.4f} | {count:7d}")

    summary = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    print(f"{indent}  Aggregated Avg R2   : {summary.get('avg_r2', np.nan):.4f} (Std: {summary.get('std_r2', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg MSE  : {summary.get('avg_mse', np.nan):.4f} (Std: {summary.get('std_mse', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg MAE  : {summary.get('avg_mae', np.nan):.4f} (Std: {summary.get('std_mae', np.nan):.4f})")
    print(f"{indent}  Aggregated Avg RMSE : {summary.get('avg_rmse', np.nan):.4f} (Std: {summary.get('std_rmse', np.nan):.4f})")

# ===========================================================
# 6. 主执行块
# ===========================================================
if __name__ == "__main__":
    # Free memory left over from earlier runs before loading data.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ----- Paths -----
    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    if not DRIVE_BASE_PATH.exists():
        DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the notebook's data-loader cell reads the same file name
    # from "Result/Sequential_13Hour_Data" — confirm which directory is correct.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    RESULTS_SUBDIR = Path("Result/Final_GINELSTM1")
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    RESULTS_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME

    # ----- Timeline of the predicted hours -----
    DATA_YEAR = 2023
    DATA_MONTH = 5
    DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8
    PREDICTION_HORIZON = 12

    # ----- Hyper-parameters (gcn_* / gru_* keys are reused for GINE / LSTM) -----
    training_config = {
        'seed': 42,
        'batch_size': 8,
        'lr': 0.001,
        'weight_decay': 1e-5,
        'max_epochs': 1000,
        'scheduler_patience': 20,
        'early_stopping_patience': 45,
        'T_pred_horizon': PREDICTION_HORIZON,
        'results_dir': str(RESULTS_SAVE_DIR),
        'global_env_emb_dim': 16,
        'time_emb_dim': 8,
        'gcn_hidden_dim': 128,  # used as gine_hidden_dim
        'gcn_output_dim': 128,  # used as gine_output_dim
        'gine_edge_dim': 6,     # must match the edge_attr width of the data
        'gru_hidden_dim': 128,  # used as lstm_hidden_dim
        'num_gru_layers': 1,    # used as num_lstm_layers
        'mlp_prediction_hidden_dim': 64,
        'fusion_mlp_output_dim': 128,
        'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2,
        'dropout_rate_encoders': 0.1,
        'dropout_rate_gcn': 0.3,  # used as dropout_rate_gine
        'dropout_rate_gru': 0.2,  # used as dropout_rate_lstm
        'dropout_rate_pred_head': 0.2,
        'use_amp': False,
        'enable_profiler': False,
        'num_workers': 0,
        'pin_memory': False,
        'train_split_ratio': 0.7,
        'val_split_ratio': 0.2,
        'h0_from_first_step': True,
    }

    # ----- Load and validate the pickled sequences -----
    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists():
            raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        with open(DATA_PATH, "rb") as f:
            all_graph_sequences_loaded = pickle.load(f)
        if (not all_graph_sequences_loaded
                or not isinstance(all_graph_sequences_loaded, list)
                or not all_graph_sequences_loaded[0]
                or not isinstance(all_graph_sequences_loaded[0], list)):
            raise ValueError("加载的数据格式不正确。")

        expected_len_per_sequence = training_config['T_pred_horizon'] + 1
        expected_edge_dim = training_config.get('gine_edge_dim', 5)

        processed_sequences = []
        for i, seq in enumerate(all_graph_sequences_loaded):
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence:
                continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                has_edge_attr = hasattr(graph_step_data, 'edge_attr') and graph_step_data.edge_attr is not None
                # Every step needs x, edge_index, edge_attr of the configured
                # width and global env features; steps after the first also need y.
                step_is_valid = (
                    isinstance(graph_step_data, Data)
                    and hasattr(graph_step_data, 'x') and graph_step_data.x is not None
                    and hasattr(graph_step_data, 'edge_index') and graph_step_data.edge_index is not None
                    and has_edge_attr
                    and graph_step_data.edge_attr.shape[1] == expected_edge_dim
                    and hasattr(graph_step_data, 'graph_global_env_features')
                    and (step_idx <= 0 or (hasattr(graph_step_data, 'y') and graph_step_data.y is not None))
                )
                if not step_is_valid:
                    # Only an edge-width mismatch is reported explicitly.
                    if has_edge_attr and graph_step_data.edge_attr.shape[1] != expected_edge_dim:
                        print(f"Warning: Seq {i}, step {step_idx} has edge_attr dim {graph_step_data.edge_attr.shape[1]}, expected {expected_edge_dim}")
                    valid_seq = False
                    break
                # Targets stored as a 1-D tensor are reshaped in place to a column.
                if step_idx > 0 and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq:
                processed_sequences.append(seq)

        if not processed_sequences:
            raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    except Exception as e:
        print(f"加载或验证数据时发生错误: {e}")
        all_graph_sequences = None

    # ----- Train only when usable data was loaded -----
    if all_graph_sequences:
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_dataset_timeline = generate_time_features_for_sequence(
            base_datetime_for_timeline,
            training_config['T_pred_horizon']
        )

        (trained_model, final_node_mean, final_node_std,
         final_target_mean, final_target_std) = main_training_gine_lstm_hourly_heads(
            all_graph_sequences, training_config, time_features_for_dataset_timeline
        )
        print("GINE-LSTM 模型训练和评估完成!")
    else:
        print("由于数据加载失败或数据为空，训练流程未启动。")

##GAE+LSTM

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import RGCNConv # Using RGCNConv for C-GVAE
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta
from pathlib import Path
import time
import json
import torchprofile

# Free unreachable Python objects first, then release cached CUDA memory.
gc.collect()
# torch.cuda.ipc_collect() initialises CUDA lazily and raises on CPU-only
# runtimes, so only touch the CUDA allocator when a GPU is actually available.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# ===========================================================
# 1. 特征生成 & 辅助模块 (No changes)
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps):
    """Build cyclic hour-of-day and day-of-year features for consecutive hours.

    Args:
        base_dt_obj: datetime of the first step.
        num_steps: number of hourly steps to generate, starting at
            ``base_dt_obj`` and advancing one hour per step.

    Returns:
        float32 tensor of shape ``(num_steps, 4)`` whose columns are
        ``[hour_sin, hour_cos, doy_sin, doy_cos]``.

    NOTE: the hour is divided by 23 (not 24) before being turned into an
    angle, so hours 0 and 23 map to the same point on the circle.
    """
    full_turn = 2 * math.pi

    def _features_at(offset_hours):
        # One feature row for the moment `offset_hours` after the base time.
        moment = base_dt_obj + timedelta(hours=offset_hours)
        year = moment.year
        is_leap_year = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap_year else 365.0

        hour_fraction = moment.hour / 23.0
        year_fraction = moment.timetuple().tm_yday / days_in_year
        hour_angle = full_turn * hour_fraction
        year_angle = full_turn * year_fraction

        row = [
            math.sin(hour_angle),
            math.cos(hour_angle),
            math.sin(year_angle),
            math.cos(year_angle),
        ]
        return torch.tensor(row, dtype=torch.float32)

    return torch.stack([_features_at(step) for step in range(num_steps)])


class MLPEncoder(nn.Module):
    """Small MLP: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    Args:
        in_dim: size of the input feature vector.
        out_dim: size of the produced embedding.
        hid_dim: width of the hidden layer. When ``None`` it defaults to
            ``max(min(in_dim, out_dim), (in_dim + out_dim) // 2)``; if that
            is 0, the first positive value of ``out_dim``, ``in_dim`` is used,
            and 1 if neither is positive.
        dropout_rate: dropout probability applied after the LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        if hid_dim is None:
            hid_dim = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            if hid_dim == 0:
                # Degenerate sizes: prefer out_dim, then in_dim, then 1.
                hid_dim = next((d for d in (out_dim, in_dim) if d > 0), 1)

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer indices.
        stages = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.LayerNorm(hid_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hid_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must be ``in_dim``)."""
        return self.mlp(x)

# ===========================================================
# 2. C-GVAE Components (using RGCNConv)
# ===========================================================

class RGCNBlock(nn.Module):
    """RGCNConv followed by BatchNorm, PReLU, dropout and an optional residual.

    Args:
        in_channels: width of the incoming node features.
        out_channels: width of the produced node features.
        num_relations: number of relation types handed to ``RGCNConv``.
        dropout_rate: dropout probability applied after the activation.
        use_residual: when True the (possibly linearly projected) input is
            added to the block output.
    """

    def __init__(self, in_channels, out_channels, num_relations, dropout_rate, use_residual=True):
        super().__init__()
        self.conv = RGCNConv(in_channels, out_channels, num_relations=num_relations)
        self.norm = nn.BatchNorm1d(out_channels)
        self.activation = nn.PReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.use_residual = use_residual

        if self.use_residual:
            if in_channels == out_channels:
                self.residual_projection = nn.Identity()
            else:
                # Widths differ, so the skip path needs a learned projection.
                self.residual_projection = nn.Linear(in_channels, out_channels)

    @staticmethod
    def _relation_types(edge_attr):
        """Extract the integer relation id of every edge from ``edge_attr``.

        A 1-D tensor is taken to be the relation ids themselves. For a 2-D
        tensor the relation id is read from column 4 when at least five
        columns are present, otherwise from column 0.

        Raises:
            ValueError: if ``edge_attr`` is None or has no columns.
        """
        if edge_attr is None:
            raise ValueError("RGCNBlock requires edge_attr with relation types.")
        if edge_attr.dim() == 1:
            return edge_attr.long()
        if edge_attr.shape[1] < 1:
            raise ValueError("RGCNBlock requires edge_attr with relation types.")
        relation_column = 4 if edge_attr.shape[1] >= 5 else 0
        return edge_attr[:, relation_column].long()

    def forward(self, x_input, edge_index, edge_attr):
        """Run the block on one (batched) graph.

        Args:
            x_input: node features, ``(num_nodes, in_channels)``.
            edge_index: edge list passed straight to ``RGCNConv``.
            edge_attr: edge attributes carrying the relation type
                (see ``_relation_types``).

        Returns:
            Node features of width ``out_channels``.

        Raises:
            ValueError: if ``edge_attr`` carries no relation types and the
                graph has at least one node.
        """
        if x_input.size(0) == 0:
            # No nodes: return an empty tensor that already has the output
            # width so downstream layers see a consistent feature size.
            return x_input.new_zeros((0, self.norm.num_features))

        edge_type = self._relation_types(edge_attr)

        h = self.conv(x_input, edge_index, edge_type=edge_type)
        if h.shape[0] > 1:
            # BatchNorm1d in training mode needs more than one row.
            h = self.norm(h)
        h = self.activation(h)
        h = self.dropout(h)

        if self.use_residual:
            h = h + self.residual_projection(x_input)
        return h

class Encoder_CGVAE(nn.Module):
    """Stack of RGCNBlocks followed by two linear heads producing mu and logvar.

    Args:
        input_dim: width of the incoming node features.
        hidden_dim: width of every RGCNBlock output.
        latent_dim: width of ``mu`` and ``logvar``.
        num_layers: number of RGCNBlocks.
        num_relations: number of relation types for each RGCNBlock.
        dropout_rate: dropout probability inside each RGCNBlock.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_layers, num_relations, dropout_rate):
        super().__init__()
        blocks = []
        width_in = input_dim
        for depth in range(num_layers):
            # The first block never uses a residual; later blocks do so only
            # when their input and output widths match.
            residual_allowed = depth > 0 and width_in == hidden_dim
            blocks.append(
                RGCNBlock(width_in, hidden_dim, num_relations, dropout_rate,
                          use_residual=residual_allowed)
            )
            width_in = hidden_dim
        self.rgcn_layers = nn.ModuleList(blocks)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return ``(mu, logvar)`` computed from the node features ``x``."""
        hidden = x
        for block in self.rgcn_layers:
            hidden = block(hidden, edge_index, edge_attr)
        return self.fc_mu(hidden), self.fc_logvar(hidden)

class Decoder_CGVAE(nn.Module):
    """Conditional decoder: RGCNBlocks over ``[z, x_original]`` plus a linear head.

    Args:
        latent_dim: width of the latent code ``z``.
        hidden_dim: width of every RGCNBlock output.
        output_dim: width of the features returned by ``forward``.
        original_node_feature_dim: width of the original node features that
            are concatenated to ``z``.
        num_layers: number of RGCNBlocks.
        num_relations: number of relation types for each RGCNBlock.
        dropout_rate: dropout probability inside each RGCNBlock.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim, original_node_feature_dim,
                 num_layers, num_relations, dropout_rate):
        super().__init__()
        # The decoder sees the latent code and the original features side by side.
        self.initial_input_dim = latent_dim + original_node_feature_dim

        blocks = []
        width_in = self.initial_input_dim
        for depth in range(num_layers):
            # The first block never uses a residual; later blocks do so only
            # when their input and output widths match.
            residual_allowed = depth > 0 and width_in == hidden_dim
            blocks.append(
                RGCNBlock(width_in, hidden_dim, num_relations, dropout_rate,
                          use_residual=residual_allowed)
            )
            width_in = hidden_dim
        self.rgcn_layers = nn.ModuleList(blocks)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, x_original, edge_index, edge_attr):
        """Decode ``z`` conditioned on ``x_original`` into ``output_dim`` features per node."""
        hidden = torch.cat([z, x_original], dim=1)
        for block in self.rgcn_layers:
            hidden = block(hidden, edge_index, edge_attr)
        return self.fc_out(hidden)

class RGCN_CGVAE_FeatureExtractor(nn.Module):
    """Encoder/decoder pair used as a deterministic node-feature extractor.

    ``forward`` feeds the encoder mean straight into the decoder; it does not
    sample and does not return ``mu``/``logvar``. ``reparameterize`` is kept
    for callers that want to sample, but ``forward`` does not call it.

    Args:
        input_dim: width of the incoming node features.
        hidden_dim: hidden width of encoder and decoder blocks.
        latent_dim: width of the latent code.
        output_feature_dim: width of the features returned by ``forward``.
        num_encoder_layers: number of RGCNBlocks in the encoder.
        num_decoder_layers: number of RGCNBlocks in the decoder.
        num_relations: number of relation types.
        dropout_rate: dropout probability inside the blocks.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, output_feature_dim,
                 num_encoder_layers, num_decoder_layers, num_relations, dropout_rate):
        super().__init__()
        self.encoder = Encoder_CGVAE(
            input_dim, hidden_dim, latent_dim,
            num_encoder_layers, num_relations, dropout_rate,
        )
        self.decoder = Decoder_CGVAE(
            latent_dim, hidden_dim, output_feature_dim,
            input_dim, num_decoder_layers, num_relations, dropout_rate,
        )
        self.latent_dim = latent_dim  # kept as an attribute for profiling

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` unchanged otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, edge_index, edge_attr):
        """Encode ``x``, then decode the encoder mean conditioned on ``x``."""
        mu, _logvar = self.encoder(x, edge_index, edge_attr)
        # The mean is used directly as the latent code; logvar is discarded.
        return self.decoder(mu, x, edge_index, edge_attr)

# ===========================================================
# 2.b CGVAE-LSTM 模型定义 (Main Model)
# ===========================================================
class CGVAELSTMModelWithHourlyHeads(nn.Module):
    """C-GVAE node features fed through an LSTM with one prediction head per hour.

    Pipeline implemented by ``forward``:
      1. The graph at index 0 is passed through ``cgvae_module_for_h0`` and an
         MLP to build the LSTM's initial hidden state (cell state is zeros).
      2. Each of the following ``T_pred_horizon`` graphs is passed through
         ``cgvae_module_for_sequence``; its node features are concatenated
         with the encoded per-graph global features and the encoded time
         features, fused by an MLP and used as the LSTM input for that step.
      3. A separate MLP head per step maps the LSTM output to one scalar.

    Args:
        static_node_in_dim: width of the raw node features (C-GVAE input).
        global_env_in_dim: width of the per-graph global features.
        time_in_dim: width of one row of the time features.
        global_env_emb_dim: embedding width of the global features.
        time_emb_dim: embedding width of the time features.
        cgvae_hidden_dim: hidden width inside the C-GVAE blocks.
        cgvae_latent_dim: latent width of the C-GVAE.
        cgvae_output_dim: width of the C-GVAE decoder output.
        cgvae_num_encoder_layers: number of encoder blocks.
        cgvae_num_decoder_layers: number of decoder blocks.
        cgvae_dropout_rate: dropout inside the C-GVAE blocks.
        num_relations: number of edge relation types.
        lstm_hidden_dim: LSTM hidden size.
        fusion_mlp_output_dim: output width of the fusion MLP (LSTM input
            size); defaults to the width of the concatenated features.
        fusion_mlp_hidden_dim: hidden width of the fusion MLP.
        dropout_rate_fusion_mlp: dropout inside the fusion MLP.
        num_lstm_layers: number of stacked LSTM layers.
        T_pred_horizon: number of predicted steps (and prediction heads).
        dropout_rate_encoders: dropout inside the global/time/h0 encoders.
        dropout_rate_lstm: inter-layer LSTM dropout (only when more than one layer).
        mlp_prediction_hidden_dim: hidden width of each prediction head.
        dropout_rate_pred_head: dropout inside each prediction head.
    """

    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 cgvae_hidden_dim,
                 cgvae_latent_dim,
                 cgvae_output_dim,
                 cgvae_num_encoder_layers,
                 cgvae_num_decoder_layers,
                 cgvae_dropout_rate,
                 num_relations,
                 lstm_hidden_dim,
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 num_lstm_layers=1,
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_lstm=0.2,
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2
                ):
        super().__init__()
        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.cgvae_output_dim = cgvae_output_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_relations = num_relations  # kept for profiling the C-GVAE part

        # Encoders for the per-graph global features and the time features.
        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders)

        # Maps the C-GVAE output of the first graph to the LSTM's h0.
        self.h0_c0_from_cgvae_encoder = MLPEncoder(cgvae_output_dim, lstm_hidden_dim, dropout_rate=dropout_rate_encoders)

        # Two independent C-GVAE extractors: one for h0, one for the sequence.
        self.cgvae_module_for_h0 = RGCN_CGVAE_FeatureExtractor(
            input_dim=static_node_in_dim, hidden_dim=cgvae_hidden_dim, latent_dim=cgvae_latent_dim,
            output_feature_dim=cgvae_output_dim, num_encoder_layers=cgvae_num_encoder_layers,
            num_decoder_layers=cgvae_num_decoder_layers, num_relations=num_relations,
            dropout_rate=cgvae_dropout_rate
        )
        self.cgvae_module_for_sequence = RGCN_CGVAE_FeatureExtractor(
            input_dim=static_node_in_dim, hidden_dim=cgvae_hidden_dim, latent_dim=cgvae_latent_dim,
            output_feature_dim=cgvae_output_dim, num_encoder_layers=cgvae_num_encoder_layers,
            num_decoder_layers=cgvae_num_decoder_layers, num_relations=num_relations,
            dropout_rate=cgvae_dropout_rate
        )

        concatenated_feature_dim = cgvae_output_dim + global_env_emb_dim + time_emb_dim
        actual_fusion_mlp_output_dim = fusion_mlp_output_dim if fusion_mlp_output_dim is not None else concatenated_feature_dim
        self.fusion_mlp_input_dim = concatenated_feature_dim

        self.fusion_mlp = MLPEncoder(
            in_dim=concatenated_feature_dim,
            out_dim=actual_fusion_mlp_output_dim,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        self.lstm_input_dim = actual_fusion_mlp_output_dim

        self.lstm = nn.LSTM(
            input_size=actual_fusion_mlp_output_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=dropout_rate_lstm if num_lstm_layers > 1 else 0.0
        )

        # One independent head per predicted step.
        self.hourly_prediction_heads = nn.ModuleList()
        for _ in range(T_pred_horizon):
            self.hourly_prediction_heads.append(
                nn.Sequential(
                    nn.Linear(lstm_hidden_dim, mlp_prediction_hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate_pred_head),
                    nn.Linear(mlp_prediction_hidden_dim, 1)
                )
            )

        # Node-feature normalisation statistics; identity until overwritten.
        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))

    @staticmethod
    def _float_edge_attr(edge_attr, error_message):
        """Return ``edge_attr`` cast to float, or raise ValueError when it is None."""
        if edge_attr is None:
            raise ValueError(error_message)
        return edge_attr.float()

    @staticmethod
    def _align_global_env_features(features, num_graphs, feature_dim, t_pred_idx, device):
        """Bring the per-graph global features to shape ``(num_graphs, feature_dim)``.

        Features that already have the right shape are returned as is; features
        with the right number of elements are reshaped with ``view``; anything
        else is replaced by zeros after printing a warning.
        """
        if features.shape == (num_graphs, feature_dim):
            return features
        if features.numel() == num_graphs * feature_dim:
            return features.view(num_graphs, feature_dim)
        print(f"Warning: LSTM Input Time {t_pred_idx}: Correcting global_env_feat shape from {features.shape} to ({num_graphs}, {feature_dim}) with zeros due to mismatch.")
        return torch.zeros(num_graphs, feature_dim, device=device)

    def _normalize_nodes(self, x):
        """Standardise node features with the registered mean/std buffers."""
        return (x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict one scaled value per node and per step.

        Args:
            list_of_batched_timesteps: sequence of batch objects; index 0
                seeds the LSTM state and indices ``1 .. T_pred_horizon`` are
                the LSTM inputs. Each object must provide ``.to(device)``,
                ``x``, ``edge_index``, ``edge_attr`` and, for indices >= 1,
                ``graph_global_env_features``, ``num_graphs``, ``batch`` and
                ``num_nodes``.
            timeline_time_features: 2-D tensor; row ``t`` is used for
                prediction step ``t`` (i.e. together with batch index ``t + 1``).
            device: device the computation runs on.

        Returns:
            Tensor of shape ``(num_nodes, T_pred_horizon)`` with scaled predictions.

        Raises:
            ValueError: if any used batch has ``edge_attr`` set to None.
        """
        # ---- Initial LSTM state from the first graph of the sequence ----
        first_batch = list_of_batched_timesteps[0].to(device)
        first_x = self._normalize_nodes(first_batch.x)
        first_edge_attr = self._float_edge_attr(
            first_batch.edge_attr,
            "edge_attr cannot be None for RGCN_CGVAE_FeatureExtractor.",
        )
        first_cgvae_out = self.cgvae_module_for_h0(first_x, first_batch.edge_index, first_edge_attr)

        h0_for_lstm = self.h0_c0_from_cgvae_encoder(first_cgvae_out).unsqueeze(0)
        c0_for_lstm = torch.zeros_like(h0_for_lstm)
        if self.lstm.num_layers > 1:
            # Every stacked layer starts from the same state.
            h0_for_lstm = h0_for_lstm.repeat(self.lstm.num_layers, 1, 1)
            c0_for_lstm = c0_for_lstm.repeat(self.lstm.num_layers, 1, 1)
        initial_hidden_state = (h0_for_lstm, c0_for_lstm)

        # ---- Per-step LSTM inputs ----
        # Move the time features once instead of once per step.
        timeline_on_device = timeline_time_features.to(device)
        expected_global_features_dim = self.global_env_encoder.mlp[0].in_features

        all_lstm_input_features_over_time = []
        for t_pred_idx in range(self.T_pred_horizon):
            step_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            step_x = self._normalize_nodes(step_batch.x)
            step_edge_attr = self._float_edge_attr(
                step_batch.edge_attr,
                "edge_attr cannot be None for RGCN_CGVAE_FeatureExtractor in sequence.",
            )
            cgvae_output_nodes_t = self.cgvae_module_for_sequence(step_x, step_batch.edge_index, step_edge_attr)

            # Per-graph global features -> embedding -> broadcast to the nodes.
            global_env_feat_t = self._align_global_env_features(
                step_batch.graph_global_env_features,
                step_batch.num_graphs,
                expected_global_features_dim,
                t_pred_idx,
                device,
            )
            global_env_emb_t_expanded = self.global_env_encoder(global_env_feat_t)[step_batch.batch]

            # One time embedding per step, shared by all nodes of the batch.
            time_emb_t = self.time_encoder(timeline_on_device[t_pred_idx, :])
            time_emb_t_expanded_to_nodes = time_emb_t.unsqueeze(0).expand(step_batch.num_nodes, -1)

            concatenated_features_for_timestep_t = torch.cat(
                [cgvae_output_nodes_t, global_env_emb_t_expanded, time_emb_t_expanded_to_nodes],
                dim=-1,
            )
            all_lstm_input_features_over_time.append(self.fusion_mlp(concatenated_features_for_timestep_t))

        # (num_nodes, T_pred_horizon, lstm_input_dim): nodes act as the LSTM batch.
        stacked_lstm_input_features = torch.stack(all_lstm_input_features_over_time, dim=1)

        # NOTE: a node-count mismatch is only repaired when h0 has MORE nodes
        # than the sequence (h0/c0 are truncated); in the opposite case the
        # LSTM call below raises.
        if initial_hidden_state[0].shape[1] != stacked_lstm_input_features.shape[0]:
            print(f"CRITICAL WARNING: Node count mismatch for LSTM h0/c0 ({initial_hidden_state[0].shape[1]}) and LSTM input sequence ({stacked_lstm_input_features.shape[0]}).")
            if initial_hidden_state[0].shape[1] > stacked_lstm_input_features.shape[0]:
                h0_adj = initial_hidden_state[0][:, :stacked_lstm_input_features.shape[0], :]
                c0_adj = initial_hidden_state[1][:, :stacked_lstm_input_features.shape[0], :]
                initial_hidden_state = (h0_adj, c0_adj)

        lstm_out, _ = self.lstm(stacked_lstm_input_features, initial_hidden_state)

        # ---- One head per step ----
        all_hourly_final_predictions_scaled = []
        for t in range(self.T_pred_horizon):
            prediction_t_scaled = self.hourly_prediction_heads[t](lstm_out[:, t, :])
            all_hourly_final_predictions_scaled.append(prediction_t_scaled.squeeze(-1))

        return torch.stack(all_hourly_final_predictions_scaled, dim=1)

# ===========================================================
# 3. 评估指标函数 (在原始尺度上计算指标) - NO CHANGES
# ===========================================================
# ... (mse_loss_masked, calculate_hourly_metrics are identical) ...
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over the selected nodes, ignoring NaN targets.

    Args:
        predictions_scaled: predictions, same shape as ``targets_scaled``.
        targets_scaled: 2-D targets; NaN entries are excluded from the loss.
        mask: 1-D boolean node mask, broadcast over the second dimension.

    Returns:
        Scalar loss tensor. When no entry is selected, a fresh zero tensor
        with ``requires_grad=True`` on the predictions' device is returned.
    """
    node_selected = mask.unsqueeze(1).expand_as(targets_scaled)
    target_present = torch.logical_not(torch.isnan(targets_scaled))
    usable = node_selected & target_present

    if usable.sum() == 0:
        return torch.tensor(0.0, device=predictions_scaled.device, requires_grad=True)

    return F.mse_loss(predictions_scaled[usable], targets_scaled[usable])

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute per-step MSE / MAE / RMSE / R2 on the original target scale.

    Predictions and targets are un-scaled with ``value * (std + 1e-8) + mean``
    on the CPU, restricted to the nodes selected by ``node_masks`` and to
    non-NaN targets, and evaluated column by column.

    Args:
        predictions_scaled: 2-D tensor ``(num_nodes, horizon)`` of scaled predictions.
        targets_scaled: scaled targets of the same shape; NaN marks missing values.
        node_masks: 1-D boolean tensor selecting the nodes to evaluate.
        target_mean: mean used for target scaling.
        target_std: standard deviation used for target scaling.

    Returns:
        Dict mapping each step index to
        ``{'mse', 'mae', 'rmse', 'r2', 'count'}``. Steps with fewer than two
        valid samples get NaN metrics and a count of 0.
    """
    mean_cpu = target_mean.cpu()
    std_cpu = target_std.cpu()

    preds_np = (predictions_scaled.cpu() * (std_cpu + 1e-8) + mean_cpu).numpy()
    targets_np = (targets_scaled.cpu() * (std_cpu + 1e-8) + mean_cpu).numpy()
    mask_np = node_masks.cpu().numpy()

    def _metrics_for_step(t):
        # Restrict to evaluated nodes, then drop nodes without a target value.
        step_preds = preds_np[:, t][mask_np]
        step_targets = targets_np[:, t][mask_np]
        has_target = ~np.isnan(step_targets)
        y_pred = step_preds[has_target]
        y_true = step_targets[has_target]

        sample_count = y_pred.shape[0]
        if sample_count < 2:
            return {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}

        residual = y_pred - y_true
        mse = np.mean(residual**2)
        mae = np.mean(np.abs(residual))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(y_true, y_pred)
        except ValueError:
            r2 = np.nan
        return {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': sample_count}

    _, horizon = preds_np.shape
    return {t: _metrics_for_step(t) for t in range(horizon)}

# ===========================================================
# 4. 训练与评估循环 (适配y归一化) - NO CHANGES
# ===========================================================
# ... (train_epoch, evaluate_epoch are identical) ...
def train_epoch(model, loader, optimizer, device, timeline_time_features,
                node_feat_mean, node_feat_std, target_mean, target_std):
    """Run one training epoch and return the mean scaled loss and its duration.

    Args:
        model: model exposing ``T_pred_horizon`` and called as
            ``model(list_of_batched_timesteps, time_features, device)``.
        loader: iterable yielding one list of batched timesteps per step;
            element 0 seeds the model, elements ``1 .. T_pred_horizon`` carry
            the targets in ``.y`` and element 1 carries ``building_mask``.
        optimizer: optimizer stepping the model parameters.
        device: device the computation runs on.
        timeline_time_features: time features forwarded to the model.
        node_feat_mean: node-feature mean written to ``model.node_feat_mean``.
        node_feat_std: node-feature std written to ``model.node_feat_std``.
        target_mean: mean used to scale the targets.
        target_std: standard deviation used to scale the targets.

    Returns:
        ``(avg_loss, epoch_duration)``: the loss averaged over all processed
        sequences (weighted by ``num_graphs`` of element 0) and the elapsed
        wall-clock seconds. Batches whose loss is NaN, infinite or not
        positive are not back-propagated and add nothing to the loss sum, but
        they still count towards the number of processed sequences.
    """
    model.train()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    # Loop-invariant: move the time features to the device once per epoch.
    timeline_on_device = timeline_time_features.to(device)

    total_loss_scaled = 0
    num_sequences_processed = 0

    epoch_start_time = time.time()
    for list_of_batched_timesteps in loader:
        optimizer.zero_grad()
        predictions_batch_scaled = model(list_of_batched_timesteps, timeline_on_device, device)

        # The loss is evaluated on the nodes where building_mask is False.
        first_predicted_timestep_batch = list_of_batched_timesteps[1].to(device)
        mask_for_loss = ~first_predicted_timestep_batch.building_mask

        targets_list_for_loss_scaled = []
        for t_pred_idx in range(model.T_pred_horizon):
            current_target_timestep_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            # reshape(-1) keeps a 1-D vector even for a single-node batch,
            # where squeeze() would collapse it to a 0-d tensor and break the
            # stack along dim=1 below.
            targets_t_nodes_original = current_target_timestep_batch.y.reshape(-1)
            targets_t_nodes_scaled = (targets_t_nodes_original - target_mean_dev) / (target_std_dev + 1e-8)
            targets_list_for_loss_scaled.append(targets_t_nodes_scaled)

        targets_batch_scaled = torch.stack(targets_list_for_loss_scaled, dim=1)
        loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled, mask_for_loss)
        num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

        # Skip the update for degenerate losses (NaN / inf / exactly zero).
        if not torch.isnan(loss) and not torch.isinf(loss) and loss.item() > 0:
            loss.backward()
            optimizer.step()
            total_loss_scaled += loss.item() * num_sequences_in_this_super_batch
        num_sequences_processed += num_sequences_in_this_super_batch

    epoch_duration = time.time() - epoch_start_time
    avg_loss = total_loss_scaled / num_sequences_processed if num_sequences_processed > 0 else 0.0
    return avg_loss, epoch_duration

def evaluate_epoch(model, loader, device, timeline_time_features,
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate the model on ``loader`` without gradient tracking.

    Args:
        model: model exposing ``T_pred_horizon`` and called as
            ``model(list_of_batched_timesteps, time_features, device)``.
        loader: iterable yielding one list of batched timesteps per step;
            element 0 seeds the model, elements ``1 .. T_pred_horizon`` carry
            the targets in ``.y`` and element 1 carries ``building_mask``.
        device: device the computation runs on.
        timeline_time_features: time features forwarded to the model.
        node_feat_mean: node-feature mean written to ``model.node_feat_mean``.
        node_feat_std: node-feature std written to ``model.node_feat_std``.
        target_mean: mean used to scale the targets.
        target_std: standard deviation used to scale the targets.
        epoch_type: label of the evaluation pass; currently unused here.

    Returns:
        ``(avg_loss_scaled, hourly_metrics, eval_duration)``: the scaled loss
        averaged over all processed sequences, the per-step metrics dict from
        ``calculate_hourly_metrics`` (all-NaN entries when the loader yielded
        nothing) and the elapsed wall-clock seconds.
    """
    model.eval()
    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)

    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    # Loop-invariant: move the time features to the device once per pass.
    timeline_on_device = timeline_time_features.to(device)

    all_batch_predictions_scaled = []
    all_batch_targets_scaled = []
    all_batch_masks_for_metrics = []
    total_loss_scaled = 0
    num_sequences_processed = 0

    eval_start_time = time.time()
    with torch.no_grad():
        for list_of_batched_timesteps in loader:
            predictions_batch_scaled = model(list_of_batched_timesteps, timeline_on_device, device)

            # Metrics are evaluated on the nodes where building_mask is False.
            first_predicted_timestep_batch = list_of_batched_timesteps[1].to(device)
            mask_for_metrics = ~first_predicted_timestep_batch.building_mask

            targets_list_original = []
            for t_pred_idx in range(model.T_pred_horizon):
                current_target_timestep_batch = list_of_batched_timesteps[t_pred_idx + 1].to(device)
                # reshape(-1) keeps a 1-D vector even for a single-node batch,
                # where squeeze() would collapse it to a 0-d tensor and break
                # the stack along dim=1 below.
                targets_list_original.append(current_target_timestep_batch.y.reshape(-1))

            targets_batch_original = torch.stack(targets_list_original, dim=1)
            targets_batch_scaled_for_loss = (targets_batch_original - target_mean_dev) / (target_std_dev + 1e-8)
            loss = mse_loss_masked(predictions_batch_scaled, targets_batch_scaled_for_loss, mask_for_metrics)
            num_sequences_in_this_super_batch = list_of_batched_timesteps[0].num_graphs

            # Non-finite losses are left out of the sum but the sequences
            # still count towards the average's denominator.
            if not torch.isnan(loss) and not torch.isinf(loss):
                total_loss_scaled += loss.item() * num_sequences_in_this_super_batch
            num_sequences_processed += num_sequences_in_this_super_batch

            all_batch_predictions_scaled.append(predictions_batch_scaled.cpu())
            all_batch_targets_scaled.append(targets_batch_scaled_for_loss.cpu())
            all_batch_masks_for_metrics.append(mask_for_metrics.cpu())

    eval_duration = time.time() - eval_start_time
    avg_loss_scaled = total_loss_scaled / num_sequences_processed if num_sequences_processed > 0 else 0.0

    if not all_batch_predictions_scaled:
        # Empty loader: report NaN metrics for every step.
        empty_metrics = {t: {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0} for t in range(model.T_pred_horizon)}
        return avg_loss_scaled, empty_metrics, eval_duration

    # Concatenate along the node dimension so metrics cover the whole loader.
    final_predictions_scaled = torch.cat(all_batch_predictions_scaled, dim=0)
    final_targets_scaled = torch.cat(all_batch_targets_scaled, dim=0)
    final_masks_for_metrics = torch.cat(all_batch_masks_for_metrics, dim=0)

    hourly_metrics_original_scale = calculate_hourly_metrics(final_predictions_scaled, final_targets_scaled,
                                                             final_masks_for_metrics, target_mean.cpu(), target_std.cpu())
    return avg_loss_scaled, hourly_metrics_original_scale, eval_duration

# ===========================================================
# 5. 主训练流程 (CGVAE-LSTM)
# ===========================================================

def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Average the per-step metrics over the prediction horizon.

    Args:
        hourly_metrics_dict: dict mapping step index to a dict with the keys
            ``'r2'``, ``'mse'``, ``'mae'`` and ``'rmse'``.
        T_pred_horizon: number of steps to consider (indices ``0 .. T-1``);
            steps missing from the dict are skipped.

    Returns:
        Dict with ``avg_<metric>`` and ``std_<metric>`` for each metric,
        computed over the non-NaN values; both are NaN when no value remains.
    """
    steps_present = [t for t in range(T_pred_horizon) if t in hourly_metrics_dict]

    report = {}
    for metric in ('r2', 'mse', 'mae', 'rmse'):
        usable_values = []
        for t in steps_present:
            value = hourly_metrics_dict[t][metric]
            if not np.isnan(value):
                usable_values.append(value)

        if usable_values:
            average, spread = np.mean(usable_values), np.std(usable_values)
        else:
            average = spread = np.nan

        report[f'avg_{metric}'] = average
        report[f'std_{metric}'] = spread
    return report

def main_training_cgvae_lstm_hourly_heads(
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor
):
    """Train, validate and test the CGVAE-LSTM model with per-hour heads.

    Steps performed, in order:
      1. Seed NumPy / PyTorch and pick the device.
      2. Keep only sequences whose length is ``T_pred_horizon + 1``.
      3. Split the sequences into train / val / test with a random permutation.
      4. Fit node-feature and target scalers on the training split and save
         them under ``config['results_dir']``.
      5. Build the model and record approximate MAC counts per component.
      6. Train with Adam + ReduceLROnPlateau and early stopping on the scaled
         validation loss, checkpointing the best weights.
      7. Reload the best checkpoint, evaluate on all three splits and write a
         JSON report.

    Args:
        all_sequences_data: List of sequences, each a list of per-step graph
            objects. The attributes read in this function are ``x``, ``y``,
            ``building_mask`` and ``graph_global_env_features``.
        config: Hyper-parameter dictionary. ``'results_dir'`` is required;
            every other key falls back to the default given at its
            ``config.get`` call.
        time_features_for_dataset: 2-D tensor of time features; its second
            dimension is used as the time-encoder input size.

    Returns:
        Tuple ``(model_for_eval, node_feat_mean, node_feat_std, target_mean,
        target_std)``.

    Raises:
        ValueError: If no sequence has the expected length, or the training
            split contains no node features.
    """
    train_start_time = time.time()
    report_data = {'config': config}

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    # Every artefact (scalers, checkpoint, report) is written below this
    # directory; create it up front so the first torch.save cannot fail on a
    # missing folder.
    results_dir = Path(config['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    expected_input_len = T_PRED_HORIZON + 1
    valid_sequences_data = [
        seq for seq in all_sequences_data
        if isinstance(seq, list) and len(seq) == expected_input_len
    ]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # ------------------------------------------------------------------
    # Dataset split (whatever is left after train + val becomes test)
    # ------------------------------------------------------------------
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]
    train_dataset = [all_sequences_data[i] for i in train_indices]
    val_dataset   = [all_sequences_data[i] for i in val_indices]
    test_dataset  = [all_sequences_data[i] for i in test_indices]
    report_data['dataset_split'] = {
        'total_sequences': num_total_sequences,
        'train_size': len(train_dataset),
        'val_size': len(val_dataset),
        'test_size': len(test_dataset),
    }

    # ------------------------------------------------------------------
    # Scaler statistics, computed on the training split only
    # ------------------------------------------------------------------
    all_train_node_features_list = []
    all_train_target_values_list_for_scaling = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                all_train_node_features_list.append(graph_data.x)
            # Step 0 contributes no targets; for later steps only entries
            # where building_mask is False and y is not NaN are counted.
            if i_step > 0 and hasattr(graph_data, 'y') and graph_data.y is not None:
                y_original = graph_data.y.squeeze()
                current_mask_for_loss = ~graph_data.building_mask
                valid_target_indices = current_mask_for_loss & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0:
                    all_train_target_values_list_for_scaling.append(y_original[valid_target_indices])

    if not all_train_node_features_list:
        raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")
    all_train_node_features_tensor = torch.cat(all_train_node_features_list, dim=0)
    node_feat_mean = torch.mean(all_train_node_features_tensor, dim=0)
    node_feat_std = torch.std(all_train_node_features_tensor, dim=0)
    node_feat_std[node_feat_std < 1e-8] = 1.0  # avoid division by ~0 for constant features
    scaler_path_x = results_dir / "node_feature_scaler_cgvae_lstm.pth"
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x)
    print(f"节点特征x scaler已保存到: {scaler_path_x}")

    if not all_train_target_values_list_for_scaling:
        # No valid targets at all: fall back to an identity transform.
        target_mean = torch.tensor(0.0)
        target_std = torch.tensor(1.0)
    else:
        all_train_target_values_tensor = torch.cat(all_train_target_values_list_for_scaling, dim=0)
        target_mean = torch.mean(all_train_target_values_tensor.float())
        target_std = torch.std(all_train_target_values_tensor.float())
        if target_std < 1e-8:
            target_std = torch.tensor(1.0)
    target_scaler_path = results_dir / "target_scaler_cgvae_lstm.pth"
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path)
    print(f"目标值y scaler已保存到: {target_scaler_path}")

    # ------------------------------------------------------------------
    # DataLoaders
    # ------------------------------------------------------------------
    batch_size = config.get('batch_size', 8)
    num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory_flag)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)

    # ------------------------------------------------------------------
    # Model initialisation
    # ------------------------------------------------------------------
    sample_graph_7am_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_7am_for_dims.x.shape[1]
    sample_global_env = sample_graph_7am_for_dims.graph_global_env_features
    global_env_in_dim = sample_global_env.shape[0] if sample_global_env.ndim == 1 else sample_global_env.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]
    num_relations = config.get('num_relations', 5)  # needed for the RGCN blocks inside the CGVAE

    # Built once and reused for the training model and for the model that
    # reloads the best checkpoint, so the two can never drift apart.
    # Several config keys are repurposed from earlier GCN/GRU variants.
    model_kwargs = dict(
        static_node_in_dim=static_node_in_dim,
        global_env_in_dim=global_env_in_dim,
        time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16),
        time_emb_dim=config.get('time_emb_dim', 8),
        cgvae_hidden_dim=config.get('gcn_hidden_dim', 128),       # repurposed 'gcn_hidden_dim'
        cgvae_latent_dim=config.get('cgvae_latent_dim', 64),
        cgvae_output_dim=config.get('gcn_output_dim', 128),       # repurposed 'gcn_output_dim'
        cgvae_num_encoder_layers=config.get('cgvae_num_encoder_layers', 2),
        cgvae_num_decoder_layers=config.get('cgvae_num_decoder_layers', 2),
        cgvae_dropout_rate=config.get('dropout_rate_gcn', 0.3),   # repurposed 'dropout_rate_gcn'
        num_relations=num_relations,
        lstm_hidden_dim=config.get('gru_hidden_dim', 128),        # repurposed 'gru_hidden_dim'
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1),          # repurposed 'num_gru_layers'
        T_pred_horizon=T_PRED_HORIZON,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2),    # repurposed 'dropout_rate_gru'
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2),
    )
    model = CGVAELSTMModelWithHourlyHeads(**model_kwargs).to(device)

    model.node_feat_mean = node_feat_mean.to(device)
    model.node_feat_std = node_feat_std.to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"模型总参数量 (CGVAE-LSTM): {total_params:,}")
    report_data['model_total_parameters'] = total_params

    # ===== Component MAC calculation (approximate FLOPS) =====
    print("\nCalculating MACs for model components (approximate FLOPS):")
    report_data['component_gmacs'] = {}
    dummy_nodes_component = 2500
    dummy_edges_component = 60000
    dummy_batch_global_comp = 1
    model.eval()

    # 1. RGCN_CGVAE_FeatureExtractor (profile one instance).
    # Both names are initialised before the try so that the cleanup below is
    # safe even when the failure happens before they would have been assigned.
    cgvae_to_profile = None
    original_training_state = None
    try:
        cgvae_to_profile = model.cgvae_module_for_h0
        dummy_x_cgvae = torch.randn(dummy_nodes_component, model.static_node_in_dim, device=device)
        dummy_ei_cgvae = torch.randint(0, dummy_nodes_component, (2, dummy_edges_component), device=device)
        # Five edge_attr columns, with the relation id stored in column 4.
        dummy_ea_cgvae = torch.randn(dummy_edges_component, 5, device=device)
        dummy_ea_cgvae[:, 4] = torch.randint(0, model.num_relations, (dummy_edges_component,), device=device).float()

        original_training_state = cgvae_to_profile.training
        cgvae_to_profile.eval()  # profile in eval mode

        macs_cgvae = torchprofile.profile_macs(cgvae_to_profile, args=(dummy_x_cgvae, dummy_ei_cgvae, dummy_ea_cgvae))

        report_data['component_gmacs']['rgcn_cgvae_feature_extractor'] = macs_cgvae / 1e9
        print(f"  RGCN_CGVAE_FeatureExtractor GMACs: {macs_cgvae / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling RGCN_CGVAE_FeatureExtractor: {e}")
        report_data['component_gmacs']['rgcn_cgvae_feature_extractor'] = "Error"
    finally:
        # Restore the training flag on success and on failure alike.
        if cgvae_to_profile is not None and original_training_state is not None:
            cgvae_to_profile.train(original_training_state)

    # 2. LSTM layer (manual MAC calculation): 4 gates, each with an
    # input-to-hidden and a hidden-to-hidden product per node and time step.
    print(f"  Manually Calculating MACs for LSTM Layer:")
    try:
        lstm_layer = model.lstm
        N_nodes = dummy_nodes_component
        L_seq = T_PRED_HORIZON
        H_in = lstm_layer.input_size
        H_hidden = lstm_layer.hidden_size
        num_layers = lstm_layer.num_layers
        macs_lstm_manual = N_nodes * L_seq * 4 * (H_in * H_hidden + H_hidden * H_hidden)
        if num_layers > 1:
            # Stacked layers take the previous layer's hidden state as input.
            macs_lstm_manual += N_nodes * L_seq * (num_layers - 1) * 4 * (H_hidden * H_hidden + H_hidden * H_hidden)
        gmacs_lstm_manual = macs_lstm_manual / 1e9
        report_data['component_gmacs']['lstm_layer'] = gmacs_lstm_manual
        report_data['component_gmacs']['lstm_layer_profiling_notes'] = "Manually calculated based on formula."
        print(f"  LSTM Parameters: input_size={H_in}, hidden_size={H_hidden}, num_layers={num_layers}")
        print(f"  Used for calculation: N_nodes={N_nodes}, L_seq={L_seq}")
        print(f"  LSTM Layer GMACs (Manual): {gmacs_lstm_manual:.4f} (for sequence length {L_seq})")
    except Exception as e:
        print(f"  Error manually calculating LSTM Layer MACs: {e}")
        report_data['component_gmacs']['lstm_layer'] = "Error"
        report_data['component_gmacs']['lstm_layer_profiling_notes'] = f"Error during manual calculation: {str(e)}"

    def _profile_single_input(report_key, success_label, error_label, get_module, get_input_shape):
        """Profile one single-input sub-module; store its GMACs or "Error" under report_key."""
        try:
            # The getters run inside the try so a missing attribute is
            # reported for this component only.
            module_to_profile = get_module()
            dummy_input = torch.randn(*get_input_shape(), device=device)
            macs = torchprofile.profile_macs(module_to_profile, args=(dummy_input,))
            report_data['component_gmacs'][report_key] = macs / 1e9
            print(f"  {success_label} GMACs: {macs / 1e9:.4f}")
        except Exception as e:
            print(f"  Error profiling {error_label}: {e}")
            report_data['component_gmacs'][report_key] = "Error"

    # 3. Fusion MLP, 4. Prediction head, 5. Global env encoder, 6. Time encoder
    _profile_single_input(
        'fusion_mlp', "Fusion MLP", "Fusion MLP",
        lambda: model.fusion_mlp,
        lambda: (dummy_nodes_component, model.fusion_mlp_input_dim))
    _profile_single_input(
        'prediction_head_mlp', "Prediction Head MLP (single hour)", "Prediction Head",
        lambda: model.hourly_prediction_heads[0],
        lambda: (dummy_nodes_component, model.lstm_hidden_dim))
    _profile_single_input(
        'global_env_encoder_mlp', "Global Env Encoder MLP", "Global Env Encoder",
        lambda: model.global_env_encoder,
        lambda: (dummy_batch_global_comp, model.global_env_in_dim))
    _profile_single_input(
        'time_encoder_mlp', "Time Encoder MLP", "Time Encoder",
        lambda: model.time_encoder,
        lambda: (dummy_batch_global_comp, model.time_in_dim))
    # 7. H0/C0 encoder fed by the CGVAE output
    _profile_single_input(
        'h0_c0_from_cgvae_encoder_mlp', "H0/C0 from CGVAE Encoder MLP", "H0/C0 from CGVAE Encoder",
        lambda: model.h0_c0_from_cgvae_encoder,
        lambda: (dummy_nodes_component, model.cgvae_output_dim))

    model.train()
    # ===== End component MAC calculation =====

    optimizer = torch.optim.Adam(model.parameters(), lr=config.get('lr', 0.001), weight_decay=config.get('weight_decay', 1e-5))
    # `verbose` is deprecated in PyTorch >= 2.2 and may be rejected by newer
    # releases; the current learning rate is already logged every epoch below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.get('scheduler_patience', 20))

    best_val_loss_scaled = float('inf')
    best_val_hourly_metrics_original = None
    best_epoch = 0
    patience_counter = 0
    max_epochs = config.get('max_epochs', 300)
    early_stopping_patience = config.get('early_stopping_patience', 45)
    model_save_path = results_dir / f"best_cgvaelstm_hourly_heads_model_seed{seed}.pth"
    timeline_time_features_on_device = time_features_for_dataset.to(device)
    target_mean_cpu = target_mean.cpu()
    target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device)
    target_std_on_device = target_std.to(device)
    epoch_times = []

    def _evaluate(eval_model, loader, epoch_type):
        """Run evaluate_epoch with the arguments shared by every evaluation call."""
        return evaluate_epoch(eval_model, loader, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type=epoch_type)

    epoch = 0  # stays 0 if max_epochs < 1, so the report below is still well-defined
    for epoch in range(1, max_epochs + 1):
        train_loss_scaled, epoch_duration = train_epoch(model, train_loader, optimizer, device, timeline_time_features_on_device, node_feat_mean, node_feat_std, target_mean_on_device, target_std_on_device)
        epoch_times.append(epoch_duration)
        val_loss_scaled, val_hourly_metrics_original, _ = _evaluate(model, val_loader, "Validation")
        scheduler.step(val_loss_scaled)
        print(f"Epoch {epoch:03d} | Train Scaled MSE: {train_loss_scaled:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Epoch Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")
        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled
            best_val_hourly_metrics_original = val_hourly_metrics_original
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), model_save_path)
            print(f"                     ---> Best model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement.")
            break

    total_training_duration = time.time() - train_start_time
    report_data['total_training_time_seconds'] = total_training_duration
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    report_data['num_epochs_trained'] = epoch
    report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    # ------------------------------------------------------------------
    # Reload the best checkpoint into a fresh model for final evaluation
    # ------------------------------------------------------------------
    model_for_eval = CGVAELSTMModelWithHourlyHeads(**model_kwargs).to(device)
    try:
        model_for_eval.load_state_dict(torch.load(model_save_path, map_location=device))
    except Exception as e:
        print(f"无法加载最佳模型 ({e})，将使用训练循环结束时的模型。")
        model_for_eval = model
    # Attach the same scaler attributes that were set on the training model,
    # so the returned model is configured consistently with it.
    model_for_eval.node_feat_mean = node_feat_mean.to(device)
    model_for_eval.node_feat_std = node_feat_std.to(device)

    print("\n评估最佳模型在训练集上...")
    best_model_train_loss_scaled, best_model_train_hourly_metrics, train_eval_duration = _evaluate(model_for_eval, train_loader, "Best Model on Train")
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics
    report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration
    _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original:
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    else:
        # No validation metrics were captured during training: re-evaluate.
        _, reeval_val_metrics, val_eval_duration = _evaluate(model_for_eval, val_loader, "Best Model on Val (Re-eval)")
        report_data['best_model_validation_set_metrics_hourly'] = reeval_val_metrics
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(reeval_val_metrics, T_PRED_HORIZON)
        report_data['best_model_validation_set_eval_time_seconds'] = val_eval_duration
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)

    print("\n评估最佳模型在测试集上...")
    test_loss_scaled, test_hourly_metrics_original, test_inference_duration = _evaluate(model_for_eval, test_loader, "Test")
    report_data['test_set_inference_time_seconds'] = test_inference_duration
    report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original
    report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (CGVAE-LSTM) " + "="*20)
    print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}")
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    # ------------------------------------------------------------------
    # Persist the report (best effort: a failure here must not lose the model)
    # ------------------------------------------------------------------
    report_file_path = results_dir / f"training_report_cgvae_lstm_seed{seed}.json"
    try:
        class NpEncoder(json.JSONEncoder):
            """JSON encoder that also accepts NumPy scalars/arrays, tensors and Paths."""
            def default(self, obj):
                if isinstance(obj, np.integer): return int(obj)
                if isinstance(obj, np.floating): return float(obj)
                if isinstance(obj, np.ndarray): return obj.tolist()
                if isinstance(obj, torch.Tensor): return obj.tolist()
                if isinstance(obj, Path): return str(obj)
                return super(NpEncoder, self).default(obj)
        with open(report_file_path, 'w') as f:
            json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e:
        print(f"保存训练报告失败: {e}")

    return model_for_eval, node_feat_mean, node_feat_std, target_mean, target_std

def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by the horizon-level aggregates.

    Args:
        set_name: Label of the data split, inserted into the printed titles.
        hourly_metrics: Mapping from hour index to a metrics dict (keys 'r2',
            'mse', 'mae', 'rmse', 'count'), or ``None`` when unavailable.
        T_pred_horizon: Number of hour rows to print (0 .. T_pred_horizon - 1).
        indent: Prefix put in front of every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")
    for hour_idx in range(T_pred_horizon):
        # Hours without an entry are shown as NaN with a zero count.
        row = hourly_metrics.get(hour_idx, {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0})
        if hour_idx == 0:
            # The header is only emitted when there is at least one row.
            print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
        r2_value = row.get('r2', np.nan)
        mse_value = row.get('mse', np.nan)
        mae_value = row.get('mae', np.nan)
        rmse_value = row.get('rmse', np.nan)
        sample_count = row.get('count', 0)
        print(f"{indent}  {hour_idx:02d}   | {r2_value:13.4f} | {mse_value:14.4f} | {mae_value:14.4f} | {rmse_value:15.4f} | {sample_count:7d}")

    summary = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    for label, key in (("R2", "r2"), ("MSE", "mse"), ("MAE", "mae"), ("RMSE", "rmse")):
        mean_value = summary.get(f'avg_{key}', np.nan)
        std_value = summary.get(f'std_{key}', np.nan)
        print(f"{indent}  Aggregated Avg {label:<5s}: {mean_value:.4f} (Std: {std_value:.4f})")

# ===========================================================
# 6. 主执行块
# ===========================================================
if __name__ == "__main__":
    # Entry point of the CGVAE-LSTM experiment: free memory, resolve paths,
    # define the hyper-parameters, load and validate the pickled sequences,
    # then run the full training / evaluation routine.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")

    if not DRIVE_BASE_PATH.exists():
        DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the data-loader cell at the top of the notebook reads from
    # "Result/Sequential_13Hour_Data" -- confirm which directory is intended.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    RESULTS_SUBDIR = Path("Result/Final_CGVAELSTM1")
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    RESULTS_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME

    DATA_YEAR = 2023
    DATA_MONTH = 5
    DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8
    PREDICTION_HORIZON = 12

    training_config = {
        'seed': 42,
        'batch_size': 8,
        'lr': 0.001,
        'weight_decay': 1e-5,
        'max_epochs': 1000,
        'scheduler_patience': 20,
        'early_stopping_patience': 45,
        'T_pred_horizon': PREDICTION_HORIZON,
        'results_dir': str(RESULTS_SAVE_DIR),
        'global_env_emb_dim': 16,
        'time_emb_dim': 8,

        # CGVAE parameters (some reuse the older GCN key names)
        'gcn_hidden_dim': 128,           # -> cgvae_hidden_dim
        'cgvae_latent_dim': 64,          # latent dimension of the CGVAE
        'gcn_output_dim': 128,           # -> cgvae_output_dim (CGVAE decoder output)
        'cgvae_num_encoder_layers': 3,
        'cgvae_num_decoder_layers': 3,
        'dropout_rate_gcn': 0.3,         # -> cgvae_dropout_rate
        'num_relations': 5,              # for RGCNConv inside the CGVAE

        # LSTM parameters (reuse the older GRU key names)
        'gru_hidden_dim': 128,           # -> lstm_hidden_dim
        'num_gru_layers': 1,             # -> num_lstm_layers
        'dropout_rate_gru': 0.2,         # -> dropout_rate_lstm

        # MLP and miscellaneous parameters
        'mlp_prediction_hidden_dim': 64,
        'fusion_mlp_output_dim': 128,
        'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2,
        'dropout_rate_encoders': 0.1,
        'dropout_rate_pred_head': 0.2,
        'use_amp': False,
        'enable_profiler': False,
        'num_workers': 0,
        'pin_memory': False,
        'train_split_ratio': 0.7,
        'val_split_ratio': 0.2,
        'h0_from_first_step': True
    }

    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists():
            raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(DATA_PATH, "rb") as f:
            all_graph_sequences_loaded = pickle.load(f)
        # Expect a non-empty list whose first element is a non-empty list.
        if not (all_graph_sequences_loaded
                and isinstance(all_graph_sequences_loaded, list)
                and all_graph_sequences_loaded[0]
                and isinstance(all_graph_sequences_loaded[0], list)):
            raise ValueError("加载的数据格式不正确。")

        expected_len_per_sequence = training_config['T_pred_horizon'] + 1

        processed_sequences = []
        for i, seq in enumerate(all_graph_sequences_loaded):
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence:
                continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                # A step is usable when it is a Data object carrying x,
                # edge_index, edge_attr (with enough columns), the global
                # environment features and, after step 0, a target y.
                step_is_valid = (
                    isinstance(graph_step_data, Data)
                    and hasattr(graph_step_data, 'x')
                    and graph_step_data.x is not None
                    and hasattr(graph_step_data, 'edge_index')
                    and graph_step_data.edge_index is not None
                    and hasattr(graph_step_data, 'edge_attr')
                    and graph_step_data.edge_attr is not None
                    and graph_step_data.edge_attr.shape[1] >= (1 if training_config.get('num_relations',0) > 0 else 0)
                    and hasattr(graph_step_data, 'graph_global_env_features')
                    and (step_idx <= 0
                         or (hasattr(graph_step_data, 'y') and graph_step_data.y is not None))
                )
                if not step_is_valid:
                    # Extra hint when the rejected step also has fewer than
                    # five edge_attr columns.
                    if (training_config.get('num_relations',0) > 0
                            and hasattr(graph_step_data, 'edge_attr')
                            and graph_step_data.edge_attr is not None
                            and graph_step_data.edge_attr.shape[1] < 5):
                        print(f"Warning: Seq {i}, step {step_idx} has edge_attr dim {graph_step_data.edge_attr.shape[1]}, RGCN may need more for types.")
                    valid_seq = False
                    break
                # Turn 1-D targets into column vectors (in place).
                if step_idx > 0 and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq:
                processed_sequences.append(seq)
        if not processed_sequences:
            raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    except Exception as e:
        print(f"加载或验证数据时发生错误: {e}")
        all_graph_sequences = None

    if not all_graph_sequences:
        print("由于数据加载失败或数据为空，训练流程未启动。")
    else:
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_dataset_timeline = generate_time_features_for_sequence(
            base_datetime_for_timeline,
            training_config['T_pred_horizon']
        )

        (trained_model,
         final_node_mean,
         final_node_std,
         final_target_mean,
         final_target_std) = main_training_cgvae_lstm_hourly_heads(
            all_graph_sequences, training_config, time_features_for_dataset_timeline
        )
        print("CGVAE-LSTM 模型训练和评估完成!")

##GGAN+LSTM

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta
from pathlib import Path
import time
import json
import torchprofile

# gc.collect()
# torch.cuda.empty_cache()
# torch.cuda.ipc_collect()


# ===========================================================
# 1. 特征生成 & 辅助模块 (No changes)
# ===========================================================

def generate_time_features_for_sequence(base_dt_obj, num_steps):
    """Build cyclic hour-of-day and day-of-year features for consecutive hours.

    Args:
        base_dt_obj: ``datetime`` of the first step; step ``i`` corresponds to
            ``base_dt_obj + i hours``.
        num_steps: Number of hourly steps to generate.

    Returns:
        float32 tensor of shape ``(num_steps, 4)`` whose columns are
        ``[hour_sin, hour_cos, doy_sin, doy_cos]``. For ``num_steps <= 0`` an
        empty ``(0, 4)`` tensor is returned instead of failing on an empty
        stack.
    """
    if num_steps <= 0:
        return torch.empty((0, 4), dtype=torch.float32)

    feature_rows = []
    for step in range(num_steps):
        current_dt = base_dt_obj + timedelta(hours=step)

        # NOTE(review): dividing by 23 maps hour 0 and hour 23 to the same
        # angle (2*pi*0 vs 2*pi*1). Kept as-is so the features stay identical
        # to the ones existing models were trained on.
        hour_norm = current_dt.hour / 23.0

        year = current_dt.year
        is_leap_year = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap_year else 365.0
        day_of_year_norm = current_dt.timetuple().tm_yday / days_in_year

        hour_angle = 2 * math.pi * hour_norm
        doy_angle = 2 * math.pi * day_of_year_norm
        feature_rows.append([
            math.sin(hour_angle),
            math.cos(hour_angle),
            math.sin(doy_angle),
            math.cos(doy_angle),
        ])
    return torch.tensor(feature_rows, dtype=torch.float32)


class MLPEncoder(nn.Module):
    """Small MLP: Linear -> ReLU -> LayerNorm -> Dropout -> Linear.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        hid_dim: Hidden width. When ``None`` it is derived from ``in_dim`` and
            ``out_dim`` (see ``__init__``).
        dropout_rate: Dropout probability applied after the LayerNorm.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1):
        super().__init__()
        if hid_dim is None:
            # Default: the larger of min(in, out) and the integer midpoint.
            hid_dim = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            if hid_dim == 0:
                # Degenerate sizes: use the first positive of out_dim / in_dim,
                # otherwise a width of 1.
                hid_dim = next((dim for dim in (out_dim, in_dim) if dim > 0), 1)

        layers = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.LayerNorm(hid_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hid_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        encoded = self.mlp(x)
        return encoded

# ===========================================================
# 2. C-GGAN Components (using RGCNConv)
# ===========================================================

class RGCNBlock(nn.Module):
    """One relational graph-convolution stage with normalisation and optional skip.

    Pipeline: ``RGCNConv`` -> ``BatchNorm1d`` (skipped when only one row is
    present) -> activation -> dropout, optionally followed by a residual
    connection from the block input (linearly projected when the input and
    output widths differ).

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the produced node features.
        num_relations: Number of edge relation types handled by ``RGCNConv``.
        dropout_rate: Dropout probability applied after the activation.
        activation_fn: Activation module; defaults to a fresh ``nn.PReLU()``.
        use_residual: Whether to add the (projected) input to the output.
        is_discriminator_block: Stored on the instance only; ``forward`` does
            not branch on it.
    """

    def __init__(self, in_channels, out_channels, num_relations, dropout_rate,
                 activation_fn=None, use_residual=True, is_discriminator_block=False):
        super().__init__()
        self.conv = RGCNConv(in_channels, out_channels, num_relations=num_relations)
        self.norm = nn.BatchNorm1d(out_channels)
        self.activation = activation_fn if activation_fn is not None else nn.PReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.use_residual = use_residual
        self.is_discriminator_block = is_discriminator_block

        if self.use_residual:
            if in_channels == out_channels:
                self.residual_projection = nn.Identity()
            else:
                self.residual_projection = nn.Linear(in_channels, out_channels)

    def forward(self, x_input, edge_index, edge_attr):
        """Run the block on one (batched) graph.

        Args:
            x_input: Node feature matrix with ``in_channels`` columns.
            edge_index: Graph connectivity as accepted by ``RGCNConv``.
            edge_attr: Edge attribute matrix or ``None``. The relation type is
                read from column 4 when at least five columns exist, otherwise
                from column 0.

        Returns:
            Node features with ``out_channels`` columns.

        Raises:
            ValueError: If ``edge_attr`` is ``None`` while more than one
                relation is configured, or if ``edge_attr`` has zero columns.
        """
        if x_input.size(0) == 0:
            # Nothing to convolve. Keep the output width consistent with the
            # non-empty path so downstream layers still see `out_channels`.
            out_channels = self.norm.num_features
            if x_input.size(-1) == out_channels:
                return x_input
            return x_input.new_zeros((0, out_channels))

        if edge_attr is None:
            if self.conv.num_relations > 1:
                raise ValueError("RGCNBlock: edge_attr is None but num_relations > 1 in RGCNConv.")
            if isinstance(edge_index, torch.Tensor) and edge_index.layout == torch.strided:
                # RGCNConv asserts that edge_type is not None for a dense
                # edge_index, so with a single relation assign every edge to
                # relation 0 instead of passing None.
                edge_type = torch.zeros(edge_index.size(1), dtype=torch.long,
                                        device=edge_index.device)
            else:
                # Sparse adjacency inputs carry their relation ids themselves.
                edge_type = None
        elif edge_attr.shape[1] >= 5:
            edge_type = edge_attr[:, 4].long()
        elif edge_attr.shape[1] > 0:
            # Fewer than five columns: assume the first one is the relation type.
            edge_type = edge_attr[:, 0].long()
        else:
            raise ValueError("RGCNBlock: edge_attr has 0 columns but is not None.")

        h = self.conv(x_input, edge_index, edge_type=edge_type)
        if h.shape[0] > 1:
            # BatchNorm1d in training mode rejects a single-row batch.
            h = self.norm(h)
        h = self.activation(h)
        h = self.dropout(h)

        if self.use_residual:
            h = h + self.residual_projection(x_input)
        return h

class NodeFeatureGeneratorRGCN(nn.Module):
    """Stack of ``RGCNBlock`` layers followed by a linear projection.

    The first block maps ``input_dim`` to ``hidden_dim`` without a skip
    connection; every later block keeps ``hidden_dim`` and uses a residual.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, num_relations, dropout_rate):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_relations = num_relations

        blocks = nn.ModuleList()
        block_in_dim = input_dim
        for layer_idx in range(num_layers):
            # A skip connection is only used past the first layer, and only
            # when the block does not change the feature width.
            with_skip = layer_idx > 0 and block_in_dim == hidden_dim
            blocks.append(
                RGCNBlock(block_in_dim, hidden_dim, num_relations, dropout_rate,
                          use_residual=with_skip)
            )
            block_in_dim = hidden_dim
        self.rgcn_layers = blocks
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return per-node features of width ``output_dim``."""
        hidden = x
        for block in self.rgcn_layers:
            hidden = block(hidden, edge_index, edge_attr)
        return self.fc_out(hidden)

class PredictionDiscriminatorRGCN(nn.Module):
    """Graph discriminator scoring (node features, candidate target) pairs.

    Node features and the candidate target are concatenated column-wise,
    passed through a stack of ``RGCNBlock`` layers with LeakyReLU(0.2)
    activations, and reduced to one logit per node.
    """

    def __init__(self, original_node_feature_dim, prediction_dim,
                 hidden_dim, num_layers, num_relations, dropout_rate):
        super().__init__()
        self.initial_input_dim = original_node_feature_dim + prediction_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_relations = num_relations

        blocks = nn.ModuleList()
        block_in_dim = self.initial_input_dim
        for layer_idx in range(num_layers):
            # Residuals only after the first layer and only at constant width.
            with_skip = layer_idx > 0 and block_in_dim == hidden_dim
            blocks.append(
                RGCNBlock(block_in_dim, hidden_dim, num_relations, dropout_rate,
                          activation_fn=nn.LeakyReLU(0.2, inplace=True),
                          use_residual=with_skip,
                          is_discriminator_block=True)
            )
            block_in_dim = hidden_dim
        self.rgcn_layers = blocks
        self.fc_out = nn.Linear(hidden_dim, 1)

    def forward(self, x_original_normalized, y_candidate_scaled, edge_index, edge_attr):
        """Return one real/fake logit per node.

        ``x_original_normalized`` and ``y_candidate_scaled`` are concatenated
        along dim 1, so they must have the same number of rows.
        ``edge_index`` / ``edge_attr`` must describe the graph those rows
        belong to: if only a subset of nodes were passed, ``edge_index`` would
        have to be re-mapped to that subset. Callers in this file therefore
        pass the full graph and mask the returned logits afterwards.
        """
        hidden = torch.cat([x_original_normalized, y_candidate_scaled], dim=1)
        for block in self.rgcn_layers:
            hidden = block(hidden, edge_index, edge_attr)
        return self.fc_out(hidden)


# ===========================================================
# 2.b CGGAN-LSTM 模型定义 (Main Model - Generator Path)
# ===========================================================
class CGGANLSTMModelWithHourlyHeads(nn.Module):
    """Generator path of the C-GGAN: RGCN node encoders + LSTM + per-hour heads.

    For every prediction step the node features are encoded by an RGCN
    generator, concatenated with an embedding of the graph-level environment
    features and an embedding of the time features, fused by an MLP and fed to
    an LSTM that runs over the prediction horizon for each node. A separate
    RGCN generator encodes the first (observation) graph to initialise the
    LSTM hidden state. Each horizon step has its own small MLP head producing
    one scaled prediction per node.

    The buffers ``node_feat_mean`` / ``node_feat_std`` hold the node-feature
    scaler used to normalise ``x`` inside ``forward``; they default to 0 / 1.
    """
    def __init__(self,
                 static_node_in_dim,
                 global_env_in_dim,
                 time_in_dim,
                 global_env_emb_dim,
                 time_emb_dim,
                 gen_hidden_dim,
                 gen_output_dim,
                 gen_num_layers,
                 gen_dropout_rate,
                 num_relations,
                 lstm_hidden_dim,
                 fusion_mlp_output_dim=None,
                 fusion_mlp_hidden_dim=None,
                 dropout_rate_fusion_mlp=0.1,
                 num_lstm_layers=1,
                 T_pred_horizon=12,
                 dropout_rate_encoders=0.1,
                 dropout_rate_lstm=0.2,
                 mlp_prediction_hidden_dim=64,
                 dropout_rate_pred_head=0.2
                ):
        """Create all sub-modules.

        Args:
            static_node_in_dim: Width of the raw node feature matrix ``x``.
            global_env_in_dim: Width of the per-graph environment features.
            time_in_dim: Width of one row of the time-feature table.
            global_env_emb_dim: Output width of the environment encoder.
            time_emb_dim: Output width of the time encoder.
            gen_hidden_dim: Hidden width of both RGCN generators.
            gen_output_dim: Output width of both RGCN generators.
            gen_num_layers: Number of RGCN blocks in each generator.
            gen_dropout_rate: Dropout rate inside the RGCN blocks.
            num_relations: Number of edge relation types.
            lstm_hidden_dim: Hidden size of the LSTM.
            fusion_mlp_output_dim: Output width of the fusion MLP; defaults to
                the width of the concatenated features.
            fusion_mlp_hidden_dim: Hidden width of the fusion MLP (``None``
                lets ``MLPEncoder`` choose).
            dropout_rate_fusion_mlp: Dropout rate of the fusion MLP.
            num_lstm_layers: Number of stacked LSTM layers.
            T_pred_horizon: Number of prediction steps (and prediction heads).
            dropout_rate_encoders: Dropout rate of the environment, time and
                h0 encoders.
            dropout_rate_lstm: Inter-layer LSTM dropout, used only when
                ``num_lstm_layers > 1``.
            mlp_prediction_hidden_dim: Hidden width of each prediction head.
            dropout_rate_pred_head: Dropout rate inside each prediction head.
        """
        super().__init__()
        self.T_pred_horizon = T_pred_horizon
        self.static_node_in_dim = static_node_in_dim
        self.global_env_in_dim = global_env_in_dim
        self.time_in_dim = time_in_dim
        self.gen_output_dim = gen_output_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_relations = num_relations

        # Encoders for graph-level environment features and time features.
        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_encoders)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_encoders)

        # Maps the generator output of the first graph to the LSTM hidden size.
        self.h0_c0_from_gnn_encoder = MLPEncoder(gen_output_dim, lstm_hidden_dim, dropout_rate=dropout_rate_encoders)

        # Two independent generators with identical hyper-parameters: one for
        # the initial hidden state, one shared across all prediction steps.
        self.node_feature_generator_h0 = NodeFeatureGeneratorRGCN(
            input_dim=static_node_in_dim, hidden_dim=gen_hidden_dim, output_dim=gen_output_dim,
            num_layers=gen_num_layers, num_relations=num_relations, dropout_rate=gen_dropout_rate
        )
        self.node_feature_generator_sequence = NodeFeatureGeneratorRGCN(
            input_dim=static_node_in_dim, hidden_dim=gen_hidden_dim, output_dim=gen_output_dim,
            num_layers=gen_num_layers, num_relations=num_relations, dropout_rate=gen_dropout_rate
        )

        # The fusion MLP consumes [generator output | env embedding | time embedding].
        concatenated_feature_dim = gen_output_dim + global_env_emb_dim + time_emb_dim
        actual_fusion_mlp_output_dim = fusion_mlp_output_dim if fusion_mlp_output_dim is not None else concatenated_feature_dim
        self.fusion_mlp_input_dim = concatenated_feature_dim

        self.fusion_mlp = MLPEncoder(
            in_dim=concatenated_feature_dim,
            out_dim=actual_fusion_mlp_output_dim,
            hid_dim=fusion_mlp_hidden_dim,
            dropout_rate=dropout_rate_fusion_mlp
        )

        lstm_input_size_actual = actual_fusion_mlp_output_dim
        self.lstm_input_dim = lstm_input_size_actual

        # batch_first=True: the LSTM "batch" axis is the node axis here.
        self.lstm = nn.LSTM(
            input_size=lstm_input_size_actual,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            dropout=dropout_rate_lstm if num_lstm_layers > 1 else 0.0
        )

        # One independent prediction head per horizon step.
        self.hourly_prediction_heads = nn.ModuleList()
        for _ in range(T_pred_horizon):
            self.hourly_prediction_heads.append(
                nn.Sequential(
                    nn.Linear(lstm_hidden_dim, mlp_prediction_hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate_pred_head),
                    nn.Linear(mlp_prediction_hidden_dim, 1)
                )
            )

        # Node-feature scaler; identity (mean 0, std 1) until overwritten.
        self.register_buffer('node_feat_mean', torch.zeros(static_node_in_dim))
        self.register_buffer('node_feat_std', torch.ones(static_node_in_dim))

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict scaled targets for every node over the prediction horizon.

        Args:
            list_of_batched_timesteps: Sequence of batched graphs. Element 0
                initialises the LSTM hidden state; elements
                ``1 .. T_pred_horizon`` are the per-step inputs. Each element
                is read for ``x``, ``edge_index``, ``edge_attr`` and (for
                elements >= 1) ``graph_global_env_features``, ``num_graphs``,
                ``batch`` and ``num_nodes``.
            timeline_time_features: Time-feature table; row ``t`` is used for
                prediction step ``t``.
            device: Device the graphs and time features are moved to.

        Returns:
            Tensor of shape ``(num_nodes, T_pred_horizon)`` with one scaled
            prediction per node and horizon step.

        Raises:
            ValueError: If any used graph has ``edge_attr`` set to ``None``.
        """
        # The local names call the first graph the "7am" observation step.
        pyg_batch_7am = list_of_batched_timesteps[0].to(device)
        # Ensure x is normalized before passing to generator
        normalized_x_7am = (pyg_batch_7am.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

        edge_attr_7am = pyg_batch_7am.edge_attr
        if edge_attr_7am is None: raise ValueError("edge_attr required for NodeFeatureGeneratorRGCN (h0).")
        edge_attr_7am = edge_attr_7am.float()

        generated_features_7am = self.node_feature_generator_h0(
            normalized_x_7am, # Pass normalized x
            pyg_batch_7am.edge_index,
            edge_attr=edge_attr_7am
        )
        h0_features_for_lstm_nodes = self.h0_c0_from_gnn_encoder(generated_features_7am)

        # h0 comes from the first graph; the cell state c0 starts at zero.
        h0_for_lstm = h0_features_for_lstm_nodes.unsqueeze(0)
        c0_for_lstm = torch.zeros_like(h0_for_lstm)

        # Every stacked LSTM layer receives the same initial state.
        if self.lstm.num_layers > 1:
            h0_for_lstm = h0_for_lstm.repeat(self.lstm.num_layers, 1, 1)
            c0_for_lstm = c0_for_lstm.repeat(self.lstm.num_layers, 1, 1)
        initial_hidden_state = (h0_for_lstm, c0_for_lstm)

        all_lstm_input_features_over_time = []
        for t_pred_idx in range(self.T_pred_horizon):
            # Prediction step t reads the graph at list position t + 1.
            pyg_batch_this_timestep = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            normalized_x_t = (pyg_batch_this_timestep.x - self.node_feat_mean) / (self.node_feat_std + 1e-8)

            edge_attr_this_timestep = pyg_batch_this_timestep.edge_attr
            if edge_attr_this_timestep is None: raise ValueError("edge_attr required for NodeFeatureGeneratorRGCN (sequence).")
            edge_attr_this_timestep = edge_attr_this_timestep.float()

            generated_features_t = self.node_feature_generator_sequence(
                normalized_x_t, # Pass normalized x
                pyg_batch_this_timestep.edge_index,
                edge_attr=edge_attr_this_timestep
            )

            # Bring the graph-level features into shape (num_graphs, env_dim).
            # Inputs with the right element count are reshaped; anything else
            # is replaced by zeros after printing a warning.
            global_env_feat_t_unencoded = pyg_batch_this_timestep.graph_global_env_features
            expected_num_graphs_in_batch_t = pyg_batch_this_timestep.num_graphs
            expected_global_features_dim = self.global_env_encoder.mlp[0].in_features
            if not (global_env_feat_t_unencoded.shape == (expected_num_graphs_in_batch_t, expected_global_features_dim)):
                if global_env_feat_t_unencoded.ndim == 1 and \
                   global_env_feat_t_unencoded.shape[0] == expected_num_graphs_in_batch_t * expected_global_features_dim:
                    global_env_feat_t_unencoded = global_env_feat_t_unencoded.view(expected_num_graphs_in_batch_t, expected_global_features_dim)
                elif global_env_feat_t_unencoded.numel() == expected_num_graphs_in_batch_t * expected_global_features_dim:
                    global_env_feat_t_unencoded = global_env_feat_t_unencoded.view(expected_num_graphs_in_batch_t, expected_global_features_dim)
                else:
                    print(f"Warning: LSTM Input Time {t_pred_idx}: Correcting global_env_feat shape from {global_env_feat_t_unencoded.shape} to ({expected_num_graphs_in_batch_t}, {expected_global_features_dim}) with zeros due to mismatch.")
                    global_env_feat_t_unencoded = torch.zeros(expected_num_graphs_in_batch_t, expected_global_features_dim, device=device)
            global_env_emb_t = self.global_env_encoder(global_env_feat_t_unencoded)
            # Broadcast each graph's embedding to its nodes via the batch vector.
            global_env_emb_t_expanded = global_env_emb_t[pyg_batch_this_timestep.batch]

            # One time embedding per step, shared by all nodes of the batch.
            current_raw_time_feat_for_timestep_t = timeline_time_features[t_pred_idx, :].to(device)
            current_emb_time_feat_for_timestep_t = self.time_encoder(current_raw_time_feat_for_timestep_t)
            num_nodes_in_pyg_batch = pyg_batch_this_timestep.num_nodes
            time_emb_t_expanded_to_nodes = current_emb_time_feat_for_timestep_t.unsqueeze(0).expand(num_nodes_in_pyg_batch, -1)

            concatenated_features_for_timestep_t = torch.cat([generated_features_t, global_env_emb_t_expanded, time_emb_t_expanded_to_nodes], dim=-1)
            fused_features_for_timestep_t = self.fusion_mlp(concatenated_features_for_timestep_t)
            all_lstm_input_features_over_time.append(fused_features_for_timestep_t)

        # (num_nodes, T_pred_horizon, fused_dim): nodes act as the LSTM batch.
        stacked_lstm_input_features = torch.stack(all_lstm_input_features_over_time, dim=1)

        # If the first graph has more nodes than the sequence graphs, the
        # initial state is truncated to match. The opposite case only prints
        # the warning and is left for the LSTM call to reject.
        if initial_hidden_state[0].shape[1] != stacked_lstm_input_features.shape[0]:
            print(f"CRITICAL WARNING: Node count mismatch for LSTM h0/c0 ({initial_hidden_state[0].shape[1]}) and LSTM input sequence ({stacked_lstm_input_features.shape[0]}).")
            if initial_hidden_state[0].shape[1] > stacked_lstm_input_features.shape[0]:
                h0_adj = initial_hidden_state[0][:, :stacked_lstm_input_features.shape[0], :]
                c0_adj = initial_hidden_state[1][:, :stacked_lstm_input_features.shape[0], :]
                initial_hidden_state = (h0_adj, c0_adj)

        lstm_out, _ = self.lstm(stacked_lstm_input_features, initial_hidden_state)

        # Apply the step-specific head to the LSTM output of each step.
        all_hourly_final_predictions_scaled = []
        for t in range(self.T_pred_horizon):
            lstm_out_t = lstm_out[:, t, :]
            prediction_t_scaled = self.hourly_prediction_heads[t](lstm_out_t)
            all_hourly_final_predictions_scaled.append(prediction_t_scaled.squeeze(-1))

        predictions_scaled = torch.stack(all_hourly_final_predictions_scaled, dim=1)
        return predictions_scaled

# ===========================================================
# 3. 评估指标函数 - NO CHANGES
# ===========================================================
# ... (mse_loss_masked, calculate_hourly_metrics are identical) ...
def mse_loss_masked(predictions_scaled, targets_scaled, mask):
    """Mean squared error over selected nodes, ignoring NaN targets.

    Args:
        predictions_scaled: Predictions, same shape as ``targets_scaled``.
        targets_scaled: Targets; NaN entries are excluded from the loss.
        mask: Boolean vector over dim 0 selecting the rows that contribute.

    Returns:
        Scalar loss tensor. When no entry is selected, a zero scalar with
        ``requires_grad=True`` on the predictions' device.
    """
    row_selector = mask.unsqueeze(1).expand_as(targets_scaled)
    has_target = torch.isnan(targets_scaled).logical_not()
    selected = row_selector & has_target

    if not selected.any():
        return torch.zeros((), device=predictions_scaled.device, requires_grad=True)

    return F.mse_loss(predictions_scaled[selected], targets_scaled[selected])

def calculate_hourly_metrics(predictions_scaled, targets_scaled, node_masks, target_mean, target_std):
    """Compute per-horizon-step regression metrics on the original target scale.

    Predictions and targets are un-scaled with ``target_mean`` / ``target_std``
    on the CPU. For each column (horizon step) only rows selected by
    ``node_masks`` with a non-NaN target are scored.

    Returns:
        Dict mapping the step index to ``{'mse', 'mae', 'rmse', 'r2', 'count'}``.
        Steps with fewer than two usable samples get NaN metrics and count 0.
    """
    mean_cpu = target_mean.cpu()
    scale_cpu = target_std.cpu() + 1e-8

    preds_np = (predictions_scaled.clone().cpu() * scale_cpu + mean_cpu).numpy()
    targets_np = (targets_scaled.clone().cpu() * scale_cpu + mean_cpu).numpy()
    mask_np = node_masks.cpu().numpy()

    _, horizon = preds_np.shape
    metrics_by_hour = {}

    for hour in range(horizon):
        # Restrict to the nodes of interest, then drop missing targets.
        hour_preds = preds_np[:, hour][mask_np]
        hour_targets = targets_np[:, hour][mask_np]
        observed = ~np.isnan(hour_targets)
        y_pred = hour_preds[observed]
        y_true = hour_targets[observed]
        sample_count = y_pred.shape[0]

        if sample_count < 2:
            metrics_by_hour[hour] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            continue

        residual = y_pred - y_true
        mse = np.mean(residual**2)
        mae = np.mean(np.abs(residual))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(y_true, y_pred)
        except ValueError:
            r2 = np.nan

        metrics_by_hour[hour] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': sample_count}

    return metrics_by_hour

# ===========================================================
# 4. 对抗训练与评估循环
# ===========================================================
def train_epoch_adversarial(model_G, model_D, loader,
                            optimizer_G, optimizer_D,
                            device, timeline_time_features,
                            node_feat_mean, node_feat_std,
                            target_mean, target_std,
                            lambda_adv, T_pred_horizon, static_node_in_dim): # Added static_node_in_dim
    """Run one epoch of alternating discriminator / generator updates.

    For every super-batch the discriminator is first trained to separate real
    scaled targets from detached generator predictions, then the generator is
    trained on ``masked MSE + lambda_adv * adversarial loss``. Losses are only
    computed on nodes where ``building_mask`` is False.

    Args:
        model_G: Generator (main prediction model).
        model_D: Discriminator taking ``(x, y, edge_index, edge_attr)``.
        loader: Iterable of lists of batched graphs; element 0 is the
            observation step, elements ``1..T_pred_horizon`` the targets.
        optimizer_G: Optimizer for ``model_G``.
        optimizer_D: Optimizer for ``model_D``.
        device: Device used for all tensors.
        timeline_time_features: Time-feature table passed to ``model_G``.
        node_feat_mean: Node-feature mean used to normalise ``x`` for D.
        node_feat_std: Node-feature std used to normalise ``x`` for D.
        target_mean: Target mean used for scaling ``y``.
        target_std: Target std used for scaling ``y``.
        lambda_adv: Weight of the adversarial term in the generator loss.
        T_pred_horizon: Number of prediction steps.
        static_node_in_dim: Unused; kept for call-site compatibility.

    Returns:
        ``(avg_loss_G, avg_loss_D, avg_pred_loss, avg_adv_loss_G,
        epoch_duration)`` where the averages are weighted by the number of
        graphs per super-batch.

    Raises:
        ValueError: If a target-step graph has ``edge_attr`` set to ``None``.
    """
    model_G.train()
    model_D.train()

    total_loss_G_epoch = 0
    total_loss_D_epoch = 0
    total_pred_loss_epoch = 0
    total_adv_loss_G_epoch = 0
    num_sequences_processed = 0

    # model_G normalises its own inputs through its registered buffers; the
    # scalers below are only needed for the discriminator inputs and targets.
    target_mean_dev = target_mean.to(device)
    target_std_dev = target_std.to(device)
    node_feat_mean_dev = node_feat_mean.to(device)
    node_feat_std_dev = node_feat_std.to(device)
    timeline_time_features_dev = timeline_time_features.to(device)

    epoch_start_time = time.time()

    for list_of_batched_timesteps in loader:
        # Move every timestep to the device once and reuse it for all passes
        # instead of repeating the host->device copy in each inner loop.
        batches_dev = [step_batch.to(device) for step_batch in list_of_batched_timesteps]
        num_graphs_in_batch = batches_dev[0].num_graphs

        # --- Train Discriminator ---
        optimizer_D.zero_grad()

        with torch.no_grad():
            predictions_batch_scaled_fake = model_G(batches_dev,
                                                    timeline_time_features_dev,
                                                    device).detach()

        loss_D_batch_total = torch.tensor(0.0, device=device)
        # Per-step discriminator inputs, cached for the generator's
        # adversarial pass below: (x_normalized, edge_index, edge_attr, mask).
        per_step_inputs = []

        for t_idx in range(T_pred_horizon):
            current_graph_data = batches_dev[t_idx + 1]
            x_original_t_normalized = (current_graph_data.x - node_feat_mean_dev) / (node_feat_std_dev + 1e-8)

            y_real_t_original = current_graph_data.y.squeeze()
            y_real_t_scaled = ((y_real_t_original - target_mean_dev) / (target_std_dev + 1e-8)).unsqueeze(-1)
            y_fake_t_scaled = predictions_batch_scaled_fake[:, t_idx].unsqueeze(-1)

            mask_for_D_loss_t = ~current_graph_data.building_mask
            valid_targets_mask_t = ~torch.isnan(y_real_t_scaled.squeeze())
            final_mask_real_t = mask_for_D_loss_t & valid_targets_mask_t

            # NaN targets must not reach the discriminator: graph convolution
            # and BatchNorm would spread them to every node's logit, turning
            # the whole step loss into NaN (which is then silently skipped).
            # Those nodes are excluded from the loss by final_mask_real_t, so
            # a neutral fill value of 0 is sufficient.
            y_real_t_for_D = torch.where(torch.isnan(y_real_t_scaled),
                                         torch.zeros_like(y_real_t_scaled),
                                         y_real_t_scaled)

            edge_attr_t = current_graph_data.edge_attr
            if edge_attr_t is None:
                raise ValueError(f"Discriminator requires edge_attr at t_idx={t_idx}")
            edge_attr_t = edge_attr_t.float()

            per_step_inputs.append((x_original_t_normalized,
                                    current_graph_data.edge_index,
                                    edge_attr_t,
                                    mask_for_D_loss_t))

            # Real samples: D sees the full graph for context; the loss is
            # restricted to the relevant nodes by masking the logits.
            if final_mask_real_t.sum() > 0:
                logits_real_all_nodes = model_D(x_original_t_normalized,
                                                y_real_t_for_D,
                                                current_graph_data.edge_index,
                                                edge_attr_t)
                loss_D_real_t = F.binary_cross_entropy_with_logits(
                    logits_real_all_nodes[final_mask_real_t],
                    torch.ones_like(logits_real_all_nodes[final_mask_real_t])
                )
            else:
                loss_D_real_t = torch.tensor(0.0, device=device)

            # Fake samples
            if mask_for_D_loss_t.sum() > 0:
                logits_fake_all_nodes = model_D(x_original_t_normalized,
                                                y_fake_t_scaled,
                                                current_graph_data.edge_index,
                                                edge_attr_t)
                loss_D_fake_t = F.binary_cross_entropy_with_logits(
                    logits_fake_all_nodes[mask_for_D_loss_t],
                    torch.zeros_like(logits_fake_all_nodes[mask_for_D_loss_t])
                )
            else:
                loss_D_fake_t = torch.tensor(0.0, device=device)

            loss_D_t = (loss_D_real_t + loss_D_fake_t) / 2.0
            if not (torch.isnan(loss_D_t) or torch.isinf(loss_D_t)):
                loss_D_batch_total += loss_D_t

        if T_pred_horizon > 0:  # Avoid division by zero if T_pred_horizon is 0
            avg_loss_D_batch = loss_D_batch_total / T_pred_horizon
            if avg_loss_D_batch.item() > 0:
                avg_loss_D_batch.backward()
                optimizer_D.step()
            total_loss_D_epoch += avg_loss_D_batch.item() * num_graphs_in_batch

        # --- Train Generator (Main Model G) ---
        optimizer_G.zero_grad()

        predictions_batch_scaled_for_G = model_G(batches_dev,
                                                 timeline_time_features_dev,
                                                 device)

        # 1. Prediction Loss (MSE); the node mask comes from the first
        #    prediction step and NaN targets are dropped inside the loss.
        mask_for_G_pred_loss = ~batches_dev[1].building_mask
        targets_batch_scaled = torch.stack(
            [
                (batches_dev[t_pred_idx + 1].y.squeeze() - target_mean_dev) / (target_std_dev + 1e-8)
                for t_pred_idx in range(T_pred_horizon)
            ],
            dim=1,
        )

        loss_pred_G = mse_loss_masked(predictions_batch_scaled_for_G, targets_batch_scaled, mask_for_G_pred_loss)

        # 2. Adversarial Loss for Generator: G wants D to label its
        #    predictions as real (target = 1).
        loss_adv_G_batch_total = torch.tensor(0.0, device=device)
        for t_idx, (x_norm_t, edge_index_t, edge_attr_t, mask_adv_t) in enumerate(per_step_inputs):
            if mask_adv_t.sum() > 0:
                fake_logits_for_G_all_nodes = model_D(
                    x_norm_t,
                    predictions_batch_scaled_for_G[:, t_idx].unsqueeze(-1),
                    edge_index_t,
                    edge_attr_t)
                loss_adv_G_t = F.binary_cross_entropy_with_logits(
                    fake_logits_for_G_all_nodes[mask_adv_t],
                    torch.ones_like(fake_logits_for_G_all_nodes[mask_adv_t])
                )
                if not (torch.isnan(loss_adv_G_t) or torch.isinf(loss_adv_G_t)):
                    loss_adv_G_batch_total += loss_adv_G_t

        avg_loss_adv_G_batch = torch.tensor(0.0, device=device)
        if T_pred_horizon > 0:
            avg_loss_adv_G_batch = loss_adv_G_batch_total / T_pred_horizon

        loss_G_total_batch = loss_pred_G + lambda_adv * avg_loss_adv_G_batch

        if not torch.isnan(loss_G_total_batch) and not torch.isinf(loss_G_total_batch) and loss_G_total_batch.item() > 0:
            loss_G_total_batch.backward()
            optimizer_G.step()

        total_loss_G_epoch += loss_G_total_batch.item() * num_graphs_in_batch
        total_pred_loss_epoch += loss_pred_G.item() * num_graphs_in_batch
        total_adv_loss_G_epoch += avg_loss_adv_G_batch.item() * num_graphs_in_batch

        num_sequences_processed += num_graphs_in_batch

    epoch_duration = time.time() - epoch_start_time
    avg_loss_G = total_loss_G_epoch / num_sequences_processed if num_sequences_processed > 0 else 0.0
    avg_loss_D = total_loss_D_epoch / num_sequences_processed if num_sequences_processed > 0 else 0.0
    avg_pred_loss = total_pred_loss_epoch / num_sequences_processed if num_sequences_processed > 0 else 0.0
    avg_adv_loss_G = total_adv_loss_G_epoch / num_sequences_processed if num_sequences_processed > 0 else 0.0

    return avg_loss_G, avg_loss_D, avg_pred_loss, avg_adv_loss_G, epoch_duration


# Evaluate epoch remains the same, it only uses the Generator (main model)
def evaluate_epoch(model_G, loader, device, timeline_time_features, # model_G is the main model
                   node_feat_mean, node_feat_std, target_mean, target_std, epoch_type="Eval"):
    """Evaluate the generator on ``loader`` without gradient tracking.

    The node-feature scaler is written onto ``model_G`` before inference.
    The loss is the masked MSE on scaled targets, averaged over sequences;
    metrics are computed per horizon step on the original target scale for
    nodes where the first prediction step's ``building_mask`` is False.

    Returns:
        ``(avg_loss_scaled, hourly_metrics, eval_duration)``. With an empty
        loader the metrics dict contains NaN entries with count 0.
    """
    model_G.eval()
    model_G.node_feat_mean = node_feat_mean.to(device)
    model_G.node_feat_std = node_feat_std.to(device)

    mean_dev = target_mean.to(device)
    std_dev = target_std.to(device)

    prediction_chunks = []
    target_chunks = []
    mask_chunks = []
    loss_sum = 0  # accumulated (MSE * sequences)
    sequences_seen = 0

    started_at = time.time()
    with torch.no_grad():
        for timestep_batches in loader:
            preds_scaled = model_G(timestep_batches, timeline_time_features.to(device), device)

            # Node mask is taken from the first predicted step (one observation step assumed).
            metric_mask = ~timestep_batches[1].to(device).building_mask

            raw_targets = torch.stack(
                [
                    timestep_batches[step + 1].to(device).y.squeeze()
                    for step in range(model_G.T_pred_horizon)
                ],
                dim=1,
            )
            targets_scaled = (raw_targets - mean_dev) / (std_dev + 1e-8)

            batch_loss = mse_loss_masked(preds_scaled, targets_scaled, metric_mask)
            sequences_in_batch = timestep_batches[0].num_graphs

            # Non-finite losses are left out of the sum but still counted.
            if not (torch.isnan(batch_loss) or torch.isinf(batch_loss)):
                loss_sum += batch_loss.item() * sequences_in_batch
            sequences_seen += sequences_in_batch

            prediction_chunks.append(preds_scaled.cpu())
            target_chunks.append(targets_scaled.cpu())
            mask_chunks.append(metric_mask.cpu())

    eval_duration = time.time() - started_at
    avg_loss_scaled = loss_sum / sequences_seen if sequences_seen > 0 else 0.0

    if not prediction_chunks:
        placeholder_metrics = {}
        for step in range(model_G.T_pred_horizon):
            placeholder_metrics[step] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count':0}
        return avg_loss_scaled, placeholder_metrics, eval_duration

    hourly_metrics_original_scale = calculate_hourly_metrics(
        torch.cat(prediction_chunks, dim=0),
        torch.cat(target_chunks, dim=0),
        torch.cat(mask_chunks, dim=0),
        target_mean.cpu(),
        target_std.cpu(),
    )
    return avg_loss_scaled, hourly_metrics_original_scale, eval_duration


# ===========================================================
# 5. 主训练流程 (CGGAN-LSTM)
# ===========================================================
def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon):
    """Summarise per-step metrics as mean and standard deviation over the horizon.

    For each of ``r2``, ``mse``, ``mae`` and ``rmse`` the non-NaN values of
    steps ``0 .. T_pred_horizon - 1`` present in ``hourly_metrics_dict`` are
    aggregated into ``avg_<metric>`` and ``std_<metric>``; both are NaN when
    no value is available.
    """
    report = {}
    for metric in ('r2', 'mse', 'mae', 'rmse'):
        usable_values = []
        for step in range(T_pred_horizon):
            if step not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[step][metric]
            if not np.isnan(value):
                usable_values.append(value)

        if usable_values:
            mean_value, std_value = np.mean(usable_values), np.std(usable_values)
        else:
            mean_value, std_value = np.nan, np.nan

        report[f'avg_{metric}'] = mean_value
        report[f'std_{metric}'] = std_value
    return report

def _build_cggan_lstm_generator(config, static_node_in_dim, global_env_in_dim, time_in_dim,
                                num_relations, T_pred_horizon, device):
    """Instantiate the generator from ``config`` and move it to ``device``.

    Used for both the training model and the evaluation copy so the two are
    always built with identical hyper-parameters. The LSTM settings are read
    from the legacy ``gru_*`` / ``gcn_*`` config keys.
    """
    return CGGANLSTMModelWithHourlyHeads(
        static_node_in_dim=static_node_in_dim, global_env_in_dim=global_env_in_dim, time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim', 16), time_emb_dim=config.get('time_emb_dim', 8),
        gen_hidden_dim=config.get('gcn_hidden_dim', 128),
        gen_output_dim=config.get('gcn_output_dim', 128),
        gen_num_layers=config.get('cggan_gen_num_layers', 2),
        gen_dropout_rate=config.get('dropout_rate_gcn', 0.3),
        num_relations=num_relations,
        lstm_hidden_dim=config.get('gru_hidden_dim', 128),
        fusion_mlp_output_dim=config.get('fusion_mlp_output_dim', 128),
        fusion_mlp_hidden_dim=config.get('fusion_mlp_hidden_dim', 64),
        dropout_rate_fusion_mlp=config.get('dropout_rate_fusion_mlp', 0.2),
        num_lstm_layers=config.get('num_gru_layers', 1),
        T_pred_horizon=T_pred_horizon,
        dropout_rate_encoders=config.get('dropout_rate_encoders', 0.1),
        dropout_rate_lstm=config.get('dropout_rate_gru', 0.2),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim', 64),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head', 0.2)
    ).to(device)


def _dummy_edge_attr_for_profiling(gnn_module, num_edges, device):
    """Return random ``(num_edges, 5)`` edge attributes for MAC profiling.

    Column 4 is overwritten with integer values in ``[0, num_relations)``;
    presumably that is the column the RGCN modules read the relation type
    from (assumption carried over from the original code - confirm).
    """
    edge_attr = torch.randn(num_edges, 5, device=device)
    if gnn_module.num_relations > 0:
        edge_attr[:, 4] = torch.randint(0, gnn_module.num_relations, (num_edges,), device=device).float()
    return edge_attr


def _profile_cggan_lstm_component_gmacs(model_G, model_D, device, static_node_in_dim, T_pred_horizon,
                                        num_nodes=2500, num_edges=60000, global_batch=1):
    """Estimate per-component GMACs of the generator and discriminator.

    Each component is profiled independently on dummy inputs; a failure is
    printed and recorded as the string ``"Error"`` so profiling can never
    abort training. Both models are switched to eval mode for profiling and
    back to train mode before returning.

    Returns:
        Dict mapping component name to GMACs (float) or ``"Error"``, plus a
        ``lstm_layer_profiling_notes`` entry.
    """
    print("\nCalculating MACs for model components (approximate FLOPS):")
    gmacs = {}

    model_G.eval()
    model_D.eval()

    # 1. NodeFeatureGeneratorRGCN (from model_G)
    try:
        gen_gnn = model_G.node_feature_generator_h0
        dummy_x = torch.randn(num_nodes, gen_gnn.input_dim, device=device)
        dummy_edge_index = torch.randint(0, num_nodes, (2, num_edges), device=device)
        dummy_edge_attr = _dummy_edge_attr_for_profiling(gen_gnn, num_edges, device)
        macs = torchprofile.profile_macs(gen_gnn, args=(dummy_x, dummy_edge_index, dummy_edge_attr))
        gmacs['node_feature_generator_rgcn'] = macs / 1e9
        print(f"  NodeFeatureGeneratorRGCN GMACs: {macs / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling NodeFeatureGeneratorRGCN: {e}")
        gmacs['node_feature_generator_rgcn'] = "Error"

    # 2. PredictionDiscriminatorRGCN
    try:
        dummy_x = torch.randn(num_nodes, static_node_in_dim, device=device)
        dummy_y = torch.randn(num_nodes, 1, device=device)
        dummy_edge_index = torch.randint(0, num_nodes, (2, num_edges), device=device)
        dummy_edge_attr = _dummy_edge_attr_for_profiling(model_D, num_edges, device)
        macs = torchprofile.profile_macs(model_D, args=(dummy_x, dummy_y, dummy_edge_index, dummy_edge_attr))
        gmacs['prediction_discriminator_rgcn'] = macs / 1e9
        print(f"  PredictionDiscriminatorRGCN GMACs: {macs / 1e9:.4f}")
    except Exception as e:
        print(f"  Error profiling PredictionDiscriminatorRGCN: {e}")
        gmacs['prediction_discriminator_rgcn'] = "Error"

    # 3. LSTM layer: closed-form count (4 gates, input + recurrent matmuls).
    print(f"  Manually Calculating MACs for LSTM Layer:")
    try:
        lstm_layer = model_G.lstm
        H_in = lstm_layer.input_size
        H_hidden = lstm_layer.hidden_size
        num_layers = lstm_layer.num_layers
        macs_lstm = num_nodes * T_pred_horizon * 4 * (H_in * H_hidden + H_hidden * H_hidden)
        if num_layers > 1:
            # Layers above the first take the hidden state as their input.
            macs_lstm += num_nodes * T_pred_horizon * (num_layers - 1) * 4 * (H_hidden * H_hidden + H_hidden * H_hidden)
        gmacs_lstm = macs_lstm / 1e9
        gmacs['lstm_layer'] = gmacs_lstm
        gmacs['lstm_layer_profiling_notes'] = "Manually calculated based on formula."
        print(f"  LSTM Parameters: input_size={H_in}, hidden_size={H_hidden}, num_layers={num_layers}")
        print(f"  Used for calculation: N_nodes={num_nodes}, L_seq={T_pred_horizon}")
        print(f"  LSTM Layer GMACs (Manual): {gmacs_lstm:.4f} (for sequence length {T_pred_horizon})")
    except Exception as e:
        print(f"  Error manually calculating LSTM Layer MACs: {e}")
        gmacs['lstm_layer'] = "Error"
        gmacs['lstm_layer_profiling_notes'] = f"Error during manual calculation: {str(e)}"

    # 4. Dense sub-modules of model_G. Attribute access is deferred through
    # lambdas so a missing attribute is reported per component.
    dense_components = (
        ('fusion_mlp', "Fusion MLP", "Fusion MLP",
         lambda: model_G.fusion_mlp, num_nodes, lambda: model_G.fusion_mlp_input_dim),
        ('prediction_head_mlp', "Prediction Head MLP", "Prediction Head",
         lambda: model_G.hourly_prediction_heads[0], num_nodes, lambda: model_G.lstm_hidden_dim),
        ('global_env_encoder_mlp', "Global Env Encoder MLP", "Global Env Encoder",
         lambda: model_G.global_env_encoder, global_batch, lambda: model_G.global_env_in_dim),
        ('time_encoder_mlp', "Time Encoder MLP", "Time Encoder",
         lambda: model_G.time_encoder, global_batch, lambda: model_G.time_in_dim),
        ('h0_c0_from_gnn_encoder_mlp', "H0/C0 from GNN Encoder MLP", "H0/C0 from GNN Encoder",
         lambda: model_G.h0_c0_from_gnn_encoder, num_nodes, lambda: model_G.gen_output_dim),
    )
    for report_key, label, error_label, get_module, rows, get_in_dim in dense_components:
        try:
            module = get_module()
            dummy_input = torch.randn(rows, get_in_dim(), device=device)
            macs = torchprofile.profile_macs(module, args=(dummy_input,))
            gmacs[report_key] = macs / 1e9
            print(f"  {label} GMACs: {macs / 1e9:.4f}")
        except Exception as e:
            print(f"  Error profiling {error_label}: {e}")
            gmacs[report_key] = "Error"

    model_G.train()
    model_D.train()
    return gmacs


def main_training_cggan_lstm_hourly_heads(
    all_sequences_data: list,
    config: dict,
    time_features_for_dataset: torch.Tensor
):
    """Train, validate and test the CGGAN-LSTM generator/discriminator pair.

    Pipeline: seed RNGs, keep only sequences of length ``T_pred_horizon + 1``,
    split them randomly into train/val/test, fit feature and target scalers
    on the training split only, build both models, profile their MACs, run
    adversarial training with early stopping on the validation loss, then
    evaluate the best generator on all three splits and write a JSON report.

    Args:
        all_sequences_data: List of sequences; entries that are not lists of
            exactly ``T_pred_horizon + 1`` graph objects are dropped.
        config: Hyper-parameters and paths. ``config['results_dir']`` is
            required; every other key is read with a default.
        time_features_for_dataset: 2-D tensor whose second dimension is the
            time-feature width; moved to the device and handed to the
            train/eval routines.

    Returns:
        ``(generator, node_feat_mean, node_feat_std, target_mean,
        target_std)`` where ``generator`` is the best checkpoint when one was
        saved during this run and could be loaded, otherwise the generator as
        it stood at the end of training.

    Raises:
        KeyError: ``config`` has no ``'results_dir'``.
        ValueError: No sequence has the expected length, or the training
            split contains no node features ``x``.
    """
    train_start_time = time.time()
    report_data = {'config': config}

    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    # Scalers, checkpoint and report are all written here; make sure it exists.
    results_dir = Path(config['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    expected_input_len = T_PRED_HORIZON + 1
    valid_sequences_data = [seq for seq in all_sequences_data if isinstance(seq, list) and len(seq) == expected_input_len]
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告: 从 {len(all_sequences_data)} 个序列中筛选出 {len(valid_sequences_data)} 个长度为 {expected_input_len} 的有效序列。")
    if not valid_sequences_data:
        raise ValueError(f"没有找到长度为 {expected_input_len} 的有效序列数据。")
    all_sequences_data = valid_sequences_data

    # ---- Dataset split (random permutation driven by the seeded NumPy RNG) ----
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_split_ratio = config.get('train_split_ratio', 0.7)
    val_split_ratio = config.get('val_split_ratio', 0.2)
    train_size = int(train_split_ratio * num_total_sequences)
    val_size = int(val_split_ratio * num_total_sequences)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]
    train_dataset = [all_sequences_data[i] for i in train_indices]
    val_dataset   = [all_sequences_data[i] for i in val_indices]
    test_dataset  = [all_sequences_data[i] for i in test_indices]
    report_data['dataset_split'] = {'total_sequences': num_total_sequences, 'train_size': len(train_dataset), 'val_size': len(val_dataset), 'test_size': len(test_dataset)}

    # ---- Scalers, fitted on the training split only ----
    all_train_node_features_list = []
    all_train_target_values_list_for_scaling = []
    for seq in train_dataset:
        for i_step, graph_data in enumerate(seq):
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                all_train_node_features_list.append(graph_data.x)
            # Step 0 is input only; targets come from steps 1..T.
            if i_step > 0 and hasattr(graph_data, 'y') and graph_data.y is not None:
                y_original = graph_data.y.squeeze()
                current_mask_for_loss = ~graph_data.building_mask
                valid_target_indices = current_mask_for_loss & ~torch.isnan(y_original)
                if valid_target_indices.sum() > 0:
                    all_train_target_values_list_for_scaling.append(y_original[valid_target_indices])

    if not all_train_node_features_list:
        raise ValueError("训练数据中未找到节点特征 'x'，无法计算scaler！")
    all_train_node_features_tensor = torch.cat(all_train_node_features_list, dim=0)
    node_feat_mean = torch.mean(all_train_node_features_tensor, dim=0)
    node_feat_std = torch.std(all_train_node_features_tensor, dim=0)
    # torch.std yields NaN for a single sample; treat that like a constant feature.
    node_feat_std[(node_feat_std < 1e-8) | ~torch.isfinite(node_feat_std)] = 1.0
    scaler_path_x = results_dir / "node_feature_scaler_cggan_lstm.pth"
    torch.save({'mean': node_feat_mean, 'std': node_feat_std}, scaler_path_x)
    print(f"节点特征x scaler已保存到: {scaler_path_x}")

    if not all_train_target_values_list_for_scaling:
        target_mean = torch.tensor(0.0)
        target_std = torch.tensor(1.0)
    else:
        all_train_target_values_tensor = torch.cat(all_train_target_values_list_for_scaling, dim=0)
        target_mean = torch.mean(all_train_target_values_tensor.float())
        target_std = torch.std(all_train_target_values_tensor.float())
        if not torch.isfinite(target_std) or target_std < 1e-8:
            target_std = torch.tensor(1.0)
    target_scaler_path = results_dir / "target_scaler_cggan_lstm.pth"
    torch.save({'mean': target_mean, 'std': target_std}, target_scaler_path)
    print(f"目标值y scaler已保存到: {target_scaler_path}")

    # ---- DataLoaders ----
    batch_size = config.get('batch_size', 8)
    num_workers = config.get('num_workers', 0)
    pin_memory_flag = config.get('pin_memory', False) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=pin_memory_flag)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=pin_memory_flag)

    # ---- Model dimensions, taken from the first graph of the first sequence ----
    sample_graph_for_dims = all_sequences_data[0][0]
    static_node_in_dim = sample_graph_for_dims.x.shape[1]
    env_features = sample_graph_for_dims.graph_global_env_features
    global_env_in_dim = env_features.shape[0] if env_features.ndim == 1 else env_features.shape[1]
    time_in_dim = time_features_for_dataset.shape[1]
    num_relations = config.get('num_relations', 5)

    # ---- Generator (main model) ----
    model_G = _build_cggan_lstm_generator(
        config, static_node_in_dim, global_env_in_dim, time_in_dim, num_relations, T_PRED_HORIZON, device
    )
    model_G.node_feat_mean = node_feat_mean.to(device)
    model_G.node_feat_std = node_feat_std.to(device)
    total_params_G = sum(p.numel() for p in model_G.parameters() if p.requires_grad)
    print(f"模型总参数量 (Generator - CGGAN-LSTM): {total_params_G:,}")
    report_data['model_G_total_parameters'] = total_params_G

    # ---- Discriminator ----
    model_D = PredictionDiscriminatorRGCN(
        original_node_feature_dim=static_node_in_dim,
        prediction_dim=1,
        hidden_dim=config.get('cggan_disc_hidden_dim', 64),
        num_layers=config.get('cggan_disc_num_layers', 2),
        num_relations=num_relations,
        dropout_rate=config.get('cggan_disc_dropout_rate', 0.3)
    ).to(device)
    total_params_D = sum(p.numel() for p in model_D.parameters() if p.requires_grad)
    print(f"模型总参数量 (Discriminator): {total_params_D:,}")
    report_data['model_D_total_parameters'] = total_params_D

    # ---- Component MAC profiling (leaves both models in train mode) ----
    report_data['component_gmacs'] = _profile_cggan_lstm_component_gmacs(
        model_G, model_D, device, static_node_in_dim, T_PRED_HORIZON
    )

    # ---- Optimizers / scheduler ----
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=config.get('lr_g', config.get('lr', 0.001)),
                                   betas=(config.get('beta1_g', 0.5), 0.999))
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=config.get('lr_d', config.get('lr', 0.001)),
                                   betas=(config.get('beta1_d', 0.5), 0.999))

    scheduler_G = ReduceLROnPlateau(optimizer_G, mode='min', factor=0.5,
                                    patience=config.get('scheduler_patience_g', config.get('scheduler_patience', 20)), verbose=True)

    best_val_loss_scaled = float('inf')
    best_val_hourly_metrics_original = None
    best_epoch = 0
    patience_counter = 0
    max_epochs = config.get('max_epochs', 300)
    early_stopping_patience = config.get('early_stopping_patience', 45)
    lambda_adv = config.get('lambda_adv', 0.01)
    model_save_path = results_dir / f"best_cgganlstm_model_G_seed{seed}.pth"

    timeline_time_features_on_device = time_features_for_dataset.to(device)
    target_mean_cpu = target_mean.cpu()
    target_std_cpu = target_std.cpu()
    target_mean_on_device = target_mean.to(device)
    target_std_on_device = target_std.to(device)
    epoch_times = []

    # ---- Training loop ----
    epoch = 0  # stays 0 when max_epochs < 1, so the report below never hits an unbound name
    for epoch in range(1, max_epochs + 1):
        avg_loss_G, avg_loss_D, avg_pred_loss, avg_adv_loss_G, epoch_duration = train_epoch_adversarial(
            model_G, model_D, train_loader, optimizer_G, optimizer_D,
            device, timeline_time_features_on_device,
            node_feat_mean, node_feat_std,
            target_mean_on_device, target_std_on_device,
            lambda_adv, T_PRED_HORIZON, static_node_in_dim
        )
        epoch_times.append(epoch_duration)

        val_loss_scaled, val_hourly_metrics_original, _ = evaluate_epoch(
            model_G, val_loader, device,
            timeline_time_features_on_device,
            node_feat_mean, node_feat_std,
            target_mean_cpu, target_std_cpu,
            epoch_type="Validation"
        )

        scheduler_G.step(val_loss_scaled)

        print(f"Epoch {epoch:03d} | G Pred MSE: {avg_pred_loss:.4f} | G Adv: {avg_adv_loss_G:.4f} | D Loss: {avg_loss_D:.4f} | Val Scaled MSE: {val_loss_scaled:.4f} | LR_G: {optimizer_G.param_groups[0]['lr']:.6f} | Time: {epoch_duration:.2f}s")
        _print_hourly_metrics_summary("Val", val_hourly_metrics_original, T_PRED_HORIZON, indent="                     ")

        if val_loss_scaled < best_val_loss_scaled:
            best_val_loss_scaled = val_loss_scaled
            best_val_hourly_metrics_original = val_hourly_metrics_original
            best_epoch = epoch
            patience_counter = 0
            torch.save(model_G.state_dict(), model_save_path)
            print(f"                     ---> Best Generator model saved (Epoch: {epoch}, Val Scaled MSE: {best_val_loss_scaled:.4f})")
        else:
            patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch} due to no improvement.")
            break

    total_training_duration = time.time() - train_start_time
    report_data['total_training_time_seconds'] = total_training_duration
    report_data['average_epoch_time_seconds'] = np.mean(epoch_times) if epoch_times else np.nan
    report_data['num_epochs_trained'] = epoch
    report_data['best_validation_epoch'] = best_epoch
    report_data['best_validation_scaled_mse'] = best_val_loss_scaled

    # ---- Pick the generator to evaluate ----
    # Only trust the checkpoint file when this run actually wrote it;
    # otherwise a file left over from an earlier run could be loaded silently.
    model_G_for_eval = model_G
    if best_epoch > 0:
        try:
            candidate_G = _build_cggan_lstm_generator(
                config, static_node_in_dim, global_env_in_dim, time_in_dim, num_relations, T_PRED_HORIZON, device
            )
            candidate_G.load_state_dict(torch.load(model_save_path, map_location=device))
            # Mirror the plain attributes attached to the training model above.
            candidate_G.node_feat_mean = node_feat_mean.to(device)
            candidate_G.node_feat_std = node_feat_std.to(device)
            model_G_for_eval = candidate_G
            print(f"Successfully loaded best Generator model from {model_save_path}")
        except Exception as e:
            print(f"无法加载最佳Generator模型 ({e})，将使用训练循环结束时的Generator模型。")
    else:
        print("验证损失从未改善，本次运行未保存最佳Generator模型；将使用训练循环结束时的Generator模型。")

    # ---- Final evaluation on train / validation / test ----
    print("\n评估最佳Generator模型在训练集上...")
    best_model_train_loss_scaled, best_model_train_hourly_metrics, train_eval_duration = evaluate_epoch(
        model_G_for_eval, train_loader, device, timeline_time_features_on_device,
        node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu,
        epoch_type="Best Model on Train"
    )
    report_data['best_model_train_set_metrics_hourly'] = best_model_train_hourly_metrics
    report_data['best_model_train_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_model_train_hourly_metrics, T_PRED_HORIZON)
    report_data['best_model_train_set_eval_time_seconds'] = train_eval_duration
    _print_hourly_metrics_summary("最佳模型训练集", best_model_train_hourly_metrics, T_PRED_HORIZON)

    # Validation metrics are the ones recorded at the best epoch, not re-computed.
    report_data['best_model_validation_set_metrics_hourly'] = best_val_hourly_metrics_original
    if best_val_hourly_metrics_original:
        report_data['best_model_validation_set_metrics_aggregated'] = calculate_aggregated_metrics_report(best_val_hourly_metrics_original, T_PRED_HORIZON)
    _print_hourly_metrics_summary("最佳模型验证集", report_data['best_model_validation_set_metrics_hourly'], T_PRED_HORIZON)

    print("\n评估最佳Generator模型在测试集上...")
    test_loss_scaled, test_hourly_metrics_original, test_inference_duration = evaluate_epoch(
        model_G_for_eval, test_loader, device, timeline_time_features_on_device,
        node_feat_mean, node_feat_std, target_mean_cpu, target_std_cpu, epoch_type="Test"
    )
    report_data['test_set_inference_time_seconds'] = test_inference_duration
    report_data['best_model_test_set_metrics_hourly'] = test_hourly_metrics_original
    report_data['best_model_test_set_metrics_aggregated'] = calculate_aggregated_metrics_report(test_hourly_metrics_original, T_PRED_HORIZON)
    print("\n" + "="*20 + " 最终测试集评估结果 (CGGAN-LSTM) " + "="*20)
    print(f"平均测试 Scaled MSE: {test_loss_scaled:.4f}")
    _print_hourly_metrics_summary("测试集", test_hourly_metrics_original, T_PRED_HORIZON)

    agg_test = report_data['best_model_test_set_metrics_aggregated']
    print(f"平均测试 MSE (Orig) : {agg_test.get('avg_mse', np.nan):.4f} (Std: {agg_test.get('std_mse', np.nan):.4f})")
    print(f"平均测试 R2 (Orig)  : {agg_test.get('avg_r2', np.nan):.4f} (Std: {agg_test.get('std_r2', np.nan):.4f})")
    print(f"平均测试 MAE (Orig) : {agg_test.get('avg_mae', np.nan):.4f} (Std: {agg_test.get('std_mae', np.nan):.4f})")
    print(f"平均测试 RMSE (Orig): {agg_test.get('avg_rmse', np.nan):.4f} (Std: {agg_test.get('std_rmse', np.nan):.4f})")
    print("="*70)

    # ---- JSON report (best effort: a failure here must not lose the model) ----
    report_file_path = results_dir / f"training_report_cggan_lstm_seed{seed}.json"
    try:
        class NpEncoder(json.JSONEncoder):
            """JSON encoder that understands NumPy scalars/arrays, tensors and paths."""
            def default(self, obj):
                if isinstance(obj, np.integer): return int(obj)
                if isinstance(obj, np.floating): return float(obj)
                if isinstance(obj, np.ndarray): return obj.tolist()
                if isinstance(obj, torch.Tensor): return obj.tolist()
                if isinstance(obj, Path): return str(obj)
                return super(NpEncoder, self).default(obj)
        with open(report_file_path, 'w') as f:
            json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告已保存到: {report_file_path}")
    except Exception as e:
        print(f"保存训练报告失败: {e}")

    return model_G_for_eval, node_feat_mean, node_feat_std, target_mean, target_std

def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by horizon-wide averages.

    Args:
        set_name: Label of the data split, inserted into the printed titles.
        hourly_metrics: Mapping ``hour index -> metrics dict`` or ``None``.
            ``None`` prints a single "not available" line and returns.
        T_pred_horizon: Number of hour rows to print; hours missing from
            ``hourly_metrics`` are shown as NaN with a count of 0.
        indent: Prefix put in front of every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")
    missing_row = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
    for hour_idx in range(T_pred_horizon):
        row = hourly_metrics.get(hour_idx, missing_row)
        if hour_idx == 0:
            # Column header goes out once, right before the first row.
            print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
        print(f"{indent}  {hour_idx:02d}   | {row.get('r2', np.nan):13.4f} | {row.get('mse', np.nan):14.4f} | {row.get('mae', np.nan):14.4f} | {row.get('rmse', np.nan):15.4f} | {row.get('count', 0):7d}")

    overall = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    for title, key in (("R2", 'r2'), ("MSE", 'mse'), ("MAE", 'mae'), ("RMSE", 'rmse')):
        mean_value = overall.get(f'avg_{key}', np.nan)
        std_value = overall.get(f'std_{key}', np.nan)
        print(f"{indent}  Aggregated Avg {title:<5s}: {mean_value:.4f} (Std: {std_value:.4f})")

# ===========================================================
# 6. 主执行块
# ===========================================================
# Entry point of this notebook cell: resolve paths, define the training
# configuration, load and validate the pickled graph sequences, then run
# CGGAN-LSTM training. Every name assigned here becomes a notebook global.
if __name__ == "__main__":
    # gc.collect()
    # if torch.cuda.is_available(): torch.cuda.empty_cache()

    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")

    # NOTE(review): creating the base path when it is missing also hides an
    # unmounted Drive; the data file check below then reports the failure.
    if not DRIVE_BASE_PATH.exists(): DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the shared data-loader cell at the top of the notebook
    # reads the same file name from "Result/Sequential_13Hour_Data" - confirm
    # that the "12Hour" directory used here is intended.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    RESULTS_SUBDIR = Path("Result/Final_CGGANLSTM1")
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    os.makedirs(RESULTS_SAVE_DIR, exist_ok=True)
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME


    # Calendar date and first hour used to generate the time features; the
    # timeline covers PREDICTION_HORIZON consecutive hours from that hour.
    DATA_YEAR = 2023; DATA_MONTH = 5; DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8; PREDICTION_HORIZON = 12

    training_config = {
        'seed': 42, 'batch_size': 8, # Keep consistent with previous runs for data loading
        'lr_g': 0.0002, 'lr_d': 0.00005, 'beta1_g': 0.5, 'beta1_d': 0.5, # Adam settings for generator / discriminator
        'lambda_adv': 0.001, # Weight of the adversarial loss term

        'max_epochs': 1000, 'scheduler_patience_g': 20, 'early_stopping_patience': 45, # Standard training setup
        'T_pred_horizon': PREDICTION_HORIZON,
        'results_dir': str(RESULTS_SAVE_DIR),

        'global_env_emb_dim': 16, 'time_emb_dim': 8, # Embedding sizes of the global-env and time encoders

        # === Parameters for NodeFeatureGeneratorRGCN (Generator's GNN part) ===
        'gcn_hidden_dim': 128,          # Passed to the model as gen_hidden_dim
        'gcn_output_dim': 128,          # Passed to the model as gen_output_dim
        'cggan_gen_num_layers': 3,      # Passed to the model as gen_num_layers
        'dropout_rate_gcn': 0.3,        # Passed to the model as gen_dropout_rate

        # === Parameters for PredictionDiscriminatorRGCN ===
        'cggan_disc_hidden_dim': 64,    # Discriminator hidden size (smaller than the generator's)
        'cggan_disc_num_layers': 2,     # Number of discriminator GNN layers
        'cggan_disc_dropout_rate': 0.3,

        # === Common GNN Parameter ===
        'num_relations': 5,             # Shared by generator and discriminator

        # === Parameters for LSTM (legacy GRU key names are reused) ===
        'gru_hidden_dim': 128,          # Passed to the model as lstm_hidden_dim
        'num_gru_layers': 1,            # Passed to the model as num_lstm_layers
        'dropout_rate_gru': 0.2,        # Passed to the model as dropout_rate_lstm

        # === MLP and other parameters (match previous RGCN+LSTM setup) ===
        'mlp_prediction_hidden_dim': 64,
        'fusion_mlp_output_dim': 128,     # Should align with LSTM input if direct, or output of fusion MLP
        'fusion_mlp_hidden_dim': 64,
        'dropout_rate_fusion_mlp': 0.2,
        'dropout_rate_encoders': 0.1,
        'dropout_rate_pred_head': 0.2,

        # NOTE(review): 'use_amp', 'enable_profiler' and 'h0_from_first_step'
        # are not read by main_training_cggan_lstm_hourly_heads as written in
        # this cell - confirm whether they are still needed.
        'use_amp': False, 'enable_profiler': False, 'num_workers': 0,
        'pin_memory': False, 'train_split_ratio': 0.7, 'val_split_ratio': 0.2,
        'h0_from_first_step': True
    }

    # Load the pickle and keep only well-formed sequences. Any exception in
    # this section is printed and leaves all_graph_sequences set to None.
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists(): raise FileNotFoundError(f"数据文件在指定路径未找到: {DATA_PATH}")
        with open(DATA_PATH, "rb") as f: all_graph_sequences_loaded = pickle.load(f)
        if not all_graph_sequences_loaded or not isinstance(all_graph_sequences_loaded, list) or not all_graph_sequences_loaded[0] or not isinstance(all_graph_sequences_loaded[0], list):
            raise ValueError("加载的数据格式不正确。")

        # One input step plus T_pred_horizon target steps per sequence.
        expected_len_per_sequence = training_config['T_pred_horizon'] + 1

        processed_sequences = []
        for i, seq in enumerate(all_graph_sequences_loaded):
            # Sequences of the wrong type or length are dropped silently.
            if not isinstance(seq, list) or len(seq) != expected_len_per_sequence: continue
            valid_seq = True
            for step_idx, graph_step_data in enumerate(seq):
                # A step is rejected when it is not a Data object, lacks x /
                # edge_index / edge_attr / graph_global_env_features, has an
                # edge_attr narrower than 5 columns (1 when num_relations is
                # 0), or - for steps after the first - has no target y.
                if not isinstance(graph_step_data, Data) or not hasattr(graph_step_data, 'x') or graph_step_data.x is None or \
                   not hasattr(graph_step_data, 'edge_index') or graph_step_data.edge_index is None or \
                   not hasattr(graph_step_data, 'edge_attr') or graph_step_data.edge_attr is None or \
                   graph_step_data.edge_attr.shape[1] < (5 if training_config.get('num_relations',0) > 0 else 1) or \
                   not hasattr(graph_step_data, 'graph_global_env_features') or \
                   (step_idx > 0 and (not hasattr(graph_step_data, 'y') or graph_step_data.y is None)):
                    # Only the edge_attr width problem gets an explicit warning.
                    if hasattr(graph_step_data, 'edge_attr') and graph_step_data.edge_attr is not None and \
                       graph_step_data.edge_attr.shape[1] < (5 if training_config.get('num_relations',0) > 0 else 1) :
                       print(f"Warning: Seq {i}, step {step_idx} has edge_attr dim {graph_step_data.edge_attr.shape[1]}, RGCN needs compatible dim.")
                    valid_seq = False; break
                # Normalise 1-D targets to a column vector (mutates the loaded object in place).
                if step_idx > 0 and isinstance(graph_step_data.y, torch.Tensor) and graph_step_data.y.ndim == 1:
                    graph_step_data.y = graph_step_data.y.unsqueeze(1)
            if valid_seq: processed_sequences.append(seq)
        if not processed_sequences: raise ValueError(f"数据处理后没有长度为 {expected_len_per_sequence} 的有效序列。")
        all_graph_sequences = processed_sequences
        print(f"成功加载并处理 {len(all_graph_sequences)} 个空间窗口的序列数据。")
    except Exception as e: print(f"加载或验证数据时发生错误: {e}"); all_graph_sequences = None

    if all_graph_sequences:
        # One row of time features per predicted hour, starting at the configured hour.
        base_datetime_for_timeline = dt_datetime(DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES)
        time_features_for_dataset_timeline = generate_time_features_for_sequence(
            base_datetime_for_timeline,
            training_config['T_pred_horizon']
        )

        trained_model_G, final_node_mean, final_node_std, final_target_mean, final_target_std = main_training_cggan_lstm_hourly_heads(
            all_graph_sequences, training_config, time_features_for_dataset_timeline
        )
        print("CGGAN-LSTM 模型训练和评估完成!")
    else:
        print("由于数据加载失败或数据为空，训练流程未启动。")

##CGAN+LSTM

In [None]:
# ===========================================================
# 0. 环境 & 依赖
# ===========================================================
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import r2_score
import gc
import math
from datetime import datetime as dt_datetime, timedelta
from pathlib import Path
import time
import json
import torchprofile

# Best-effort memory cleanup before this cell defines and trains its models:
# run Python's garbage collector first so unreferenced tensors are dropped,
# then ask PyTorch to release cached CUDA memory and collect CUDA IPC memory.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# ===========================================================
# 1. 特征生成 & 辅助模块
# ===========================================================
def generate_time_features_for_sequence(base_dt_obj, num_steps, hour_period=23.0):
    """Build cyclical hour-of-day and day-of-year features for consecutive hours.

    Step ``i`` describes ``base_dt_obj + i hours`` as
    ``[sin(hour), cos(hour), sin(day_of_year), cos(day_of_year)]``.

    Args:
        base_dt_obj: ``datetime`` of the first step.
        num_steps: Number of hourly steps to generate.
        hour_period: Divisor that maps the hour (0-23) onto one full turn.
            The default 23.0 reproduces the historical features, under which
            hours 0 and 23 land on the same angle; pass 24.0 for an encoding
            in which every hour of the day is distinct.

    Returns:
        ``float32`` tensor of shape ``(num_steps, 4)``; an empty ``(0, 4)``
        tensor when ``num_steps`` is zero or negative.
    """
    if num_steps <= 0:
        # torch.stack / torch.tensor cannot infer the width from no rows.
        return torch.empty((0, 4), dtype=torch.float32)

    two_pi = 2 * math.pi
    rows = []
    for step in range(num_steps):
        current_dt = base_dt_obj + timedelta(hours=step)
        year = current_dt.year
        is_leap_year = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap_year else 365.0

        hour_angle = two_pi * (current_dt.hour / hour_period)
        doy_angle = two_pi * (current_dt.timetuple().tm_yday / days_in_year)
        rows.append([math.sin(hour_angle), math.cos(hour_angle),
                     math.sin(doy_angle), math.cos(doy_angle)])
    return torch.tensor(rows, dtype=torch.float32)

class MLPEncoder(nn.Module):
    """Two-layer perceptron: Linear -> [activation] -> [LayerNorm] -> Dropout -> Linear.

    Args:
        in_dim: size of the input features.
        out_dim: size of the output features.
        hid_dim: hidden width; when None it is derived from in_dim/out_dim.
        dropout_rate: dropout probability applied before the output layer.
        activation_fn: activation class instantiated after the first layer, or
            None to omit the activation.
        add_layer_norm: whether to insert a LayerNorm over the hidden features.
    """

    def __init__(self, in_dim, out_dim, hid_dim=None, dropout_rate=0.1, activation_fn=nn.ReLU, add_layer_norm=True):
        super().__init__()
        if hid_dim is None:
            hid_dim = max(min(in_dim, out_dim), (in_dim + out_dim) // 2)
            if hid_dim == 0:
                # Degenerate sizes: fall back to out_dim, then in_dim, then 1.
                if out_dim > 0:
                    hid_dim = out_dim
                elif in_dim > 0:
                    hid_dim = in_dim
                else:
                    hid_dim = 1

        stack = [nn.Linear(in_dim, hid_dim)]
        if activation_fn is not None:
            stack.append(activation_fn())
        if add_layer_norm:
            stack.append(nn.LayerNorm(hid_dim))
        stack.append(nn.Dropout(dropout_rate))
        stack.append(nn.Linear(hid_dim, out_dim))
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to the last dimension of `x`."""
        return self.mlp(x)

# ===========================================================
# 2. U-Net based Generator and PatchGAN Discriminator (Per-Pixel LSTM version)
# ===========================================================
class PreUNetFusionModule(nn.Module):
    """Fuse static image features with broadcast global/time embeddings.

    The three inputs are concatenated along the channel axis and passed through
    two 3x3 convolutions (InstanceNorm + ReLU, dropout in between) that map them
    to `unet_input_channels` channels at unchanged spatial size.
    """

    def __init__(self, static_feat_dim, global_env_emb_dim, time_emb_dim, unet_input_channels, dropout_rate=0.1):
        super().__init__()
        fused_in = static_feat_dim + global_env_emb_dim + time_emb_dim
        # The intermediate width is the midpoint between input and output widths.
        bridge = (fused_in + unet_input_channels) // 2
        self.fusion_convs = nn.Sequential(
            nn.Conv2d(fused_in, bridge, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(bridge),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
            nn.Conv2d(bridge, unet_input_channels, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(unet_input_channels),
            nn.ReLU(inplace=True),
        )
        self.input_channels_for_profiling = fused_in
        self.output_channels_for_profiling = unet_input_channels

    def forward(self, static_img_feat, expanded_global_emb, expanded_time_emb):
        """Concatenate the inputs on dim 1 and run the fusion convolutions."""
        stacked = torch.cat((static_img_feat, expanded_global_emb, expanded_time_emb), dim=1)
        return self.fusion_convs(stacked)

class UNetFeatureExtractorModule(nn.Module):
    """Three-level U-Net that returns a feature map at the input resolution.

    Each encoder stage halves the spatial size with a stride-2 convolution and
    each decoder stage doubles it with a stride-2 transposed convolution. Skip
    connections concatenate encoder outputs onto the decoder path, and the
    final 1x1 convolution output is bilinearly resized to the input size.
    """

    def __init__(self, input_channels, output_feature_channels=32,
                 encoder_channels=(64, 128, 256), middle_channels=256, decoder_channels=(128, 64)):
        super().__init__()
        self.input_channels_for_profiling = input_channels
        self.output_feature_channels_for_profiling = output_feature_channels

        enc1_c = encoder_channels[0]
        enc2_c = encoder_channels[1]
        enc3_c = encoder_channels[2]
        dec3_c = decoder_channels[0]
        dec2_c = decoder_channels[1]

        self.encoder1 = self._encoder_block(input_channels, enc1_c)
        self.encoder2 = self._encoder_block(enc1_c, enc2_c)
        self.encoder3 = self._encoder_block(enc2_c, enc3_c)
        self.middle = nn.Sequential(
            nn.Conv2d(enc3_c, middle_channels, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(middle_channels),
            nn.ReLU(inplace=True),
        )
        # Decoder inputs are widened by the skip connection they are concatenated with.
        self.decoder3 = self._decoder_block(middle_channels + enc3_c, dec3_c)
        self.decoder2 = self._decoder_block(dec3_c + enc2_c, dec2_c)
        self.final_feature_conv = nn.Conv2d(dec2_c + enc1_c, output_feature_channels, 1)

    def _encoder_block(self, in_c, out_c, norm=True):
        """Stride-2 4x4 convolution, optional InstanceNorm, LeakyReLU(0.2)."""
        parts = [nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)]
        if norm:
            parts.append(nn.InstanceNorm2d(out_c))
        parts.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*parts)

    def _decoder_block(self, in_c, out_c, norm=True):
        """Stride-2 4x4 transposed convolution, optional InstanceNorm, ReLU."""
        parts = [nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False)]
        if norm:
            parts.append(nn.InstanceNorm2d(out_c))
        parts.append(nn.ReLU(inplace=True))
        return nn.Sequential(*parts)

    @staticmethod
    def _resize_like(tensor, reference):
        """Bilinearly resize `tensor` to the spatial size of `reference` if they differ."""
        if tensor.size()[2:] != reference.size()[2:]:
            tensor = F.interpolate(tensor, size=reference.size()[2:], mode='bilinear', align_corners=False)
        return tensor

    def forward(self, x):
        """Return a [B, output_feature_channels, H, W] map for input [B, C, H, W]."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(skip1)
        skip3 = self.encoder3(skip2)
        bottleneck = self.middle(skip3)

        # Odd spatial sizes do not round-trip through stride-2 stages, so each
        # upsampled map is aligned to its skip connection before concatenation.
        up3 = self.decoder3(torch.cat([bottleneck, skip3], dim=1))
        up3 = self._resize_like(up3, skip2)
        up2 = self.decoder2(torch.cat([up3, skip2], dim=1))
        up2 = self._resize_like(up2, skip1)

        features = self.final_feature_conv(torch.cat([up2, skip1], dim=1))
        return F.interpolate(features, size=x.size()[2:], mode='bilinear', align_corners=False)

class H0PixelFeatureExtractor(nn.Module):
    """Derive the per-pixel initial LSTM hidden state from the first-step image.

    A dedicated U-Net extracts a feature map from the normalized static image,
    an MLP projects every pixel's feature vector to `lstm_hidden_dim`, and the
    result is repeated once per LSTM layer.

    Args:
        static_feat_dim: channel count of the static image (kept for profiling only).
        unet_input_ch_for_h0_unet: input channels of the h0 U-Net.
        unet_feature_output_channels_for_h0: output channels of the h0 U-Net.
        unet_enc_ch_list, unet_mid_ch, unet_dec_ch_list: h0 U-Net widths.
        lstm_hidden_dim: hidden size of the LSTM being initialized.
        num_lstm_layers: number of LSTM layers (repeat count of h0).
        H_img, W_img: accepted for interface symmetry; not used here.
        dropout_rate: dropout used inside the projection MLP.
    """

    def __init__(self, static_feat_dim,
                 unet_input_ch_for_h0_unet,
                 unet_feature_output_channels_for_h0,
                 unet_enc_ch_list, unet_mid_ch, unet_dec_ch_list,
                 lstm_hidden_dim, num_lstm_layers, H_img=50, W_img=50, dropout_rate=0.1):
        super().__init__()
        self.num_lstm_layers = num_lstm_layers

        # U-Net that turns the first-step image into per-pixel features.
        self.h0_unet_feature_extractor = UNetFeatureExtractorModule(
            input_channels=unet_input_ch_for_h0_unet,
            output_feature_channels=unet_feature_output_channels_for_h0,
            encoder_channels=unet_enc_ch_list,
            middle_channels=unet_mid_ch,
            decoder_channels=unet_dec_ch_list,
        )

        # Per-pixel projection from U-Net feature width to the LSTM hidden width.
        projector_hidden = (unet_feature_output_channels_for_h0 + lstm_hidden_dim) // 2
        self.h0_pixel_projector = MLPEncoder(
            in_dim=unet_feature_output_channels_for_h0,
            out_dim=lstm_hidden_dim,
            hid_dim=projector_hidden,
            dropout_rate=dropout_rate,
            add_layer_norm=False,
        )

        self.static_feat_dim_for_profiling = static_feat_dim
        self.unet_feature_output_channels_for_h0_profiling = unet_feature_output_channels_for_h0

    def forward(self, static_7am_image_normalized):
        """Return h0 of shape [num_lstm_layers, B*H*W, lstm_hidden_dim].

        The image is fed to the h0 U-Net unchanged, so its channel count must
        equal the `unet_input_ch_for_h0_unet` the module was built with.
        """
        feature_map = self.h0_unet_feature_extractor(static_7am_image_normalized)
        batch, channels, height, width = feature_map.shape

        # [B, C, H, W] -> [B, H, W, C] -> [B*H*W, C]: one row per pixel.
        per_pixel = feature_map.permute(0, 2, 3, 1).reshape(batch * height * width, channels)
        hidden = self.h0_pixel_projector(per_pixel)

        # Every LSTM layer starts from the same projected state.
        return hidden.unsqueeze(0).repeat(self.num_lstm_layers, 1, 1)

class UNetPerPixelLSTMHeadModel_Generator(nn.Module): # GENERATOR
    """Generator: per-timestep U-Net features -> per-pixel LSTM -> one MLP head per hour.

    For every prediction step the normalized first-step image is fused with that
    step's global-environment and time embeddings, passed through a U-Net, and the
    resulting per-pixel feature vectors form the LSTM input sequence. The LSTM's
    initial hidden state comes from a separate U-Net applied to the first-step
    image. Each prediction step has its own MLP head producing one value per pixel.
    """
    def __init__(self, static_feat_dim, global_env_in_dim, time_in_dim, global_env_emb_dim, time_emb_dim,
                 unet_input_channels_after_fusion, dropout_rate_pre_fusion,
                 unet_feature_output_channels, unet_encoder_channels_list, unet_middle_channels_val, unet_decoder_channels_list,
                 # H0 Extractor specific U-Net params
                 h0_unet_input_channels, h0_unet_output_channels,
                 h0_unet_enc_ch_list, h0_unet_mid_ch, h0_unet_dec_ch_list,
                 lstm_hidden_dim, num_lstm_layers, dropout_rate_lstm,
                 mlp_prediction_hidden_dim, dropout_rate_pred_head,
                 T_pred_horizon, H_img=50, W_img=50, dropout_rate_other_mlps=0.1):
        """Build the encoders, fusion module, both U-Nets, the LSTM and the hourly heads.

        Args:
            static_feat_dim: number of per-node static features (image channels).
            global_env_in_dim / time_in_dim: raw sizes of the global-environment
                and time feature vectors.
            global_env_emb_dim / time_emb_dim: embedding sizes produced by their encoders.
            unet_input_channels_after_fusion: channels produced by the fusion module
                and consumed by the main U-Net.
            dropout_rate_pre_fusion: Dropout2d rate inside the fusion module.
            unet_feature_output_channels: main U-Net output channels (= LSTM input size).
            unet_encoder_channels_list / unet_middle_channels_val /
                unet_decoder_channels_list: main U-Net widths.
            h0_unet_*: configuration of the U-Net used only for the initial hidden state.
            lstm_hidden_dim / num_lstm_layers / dropout_rate_lstm: LSTM configuration
                (LSTM dropout is only applied when there is more than one layer).
            mlp_prediction_hidden_dim / dropout_rate_pred_head: per-hour head configuration.
            T_pred_horizon: number of predicted steps (and of prediction heads).
            H_img, W_img: spatial size used to reshape flat node tensors into images.
            dropout_rate_other_mlps: dropout for the global/time encoders and h0 projector.
        """
        super().__init__()
        self.T_pred_horizon = T_pred_horizon; self.static_feat_dim = static_feat_dim
        self.H_img, self.W_img = H_img, W_img
        self.unet_input_channels_after_fusion_for_profiling = unet_input_channels_after_fusion
        self.unet_feature_output_channels_for_profiling = unet_feature_output_channels

        self.global_env_encoder = MLPEncoder(global_env_in_dim, global_env_emb_dim, dropout_rate=dropout_rate_other_mlps)
        self.time_encoder = MLPEncoder(time_in_dim, time_emb_dim, dropout_rate=dropout_rate_other_mlps)
        self.pre_unet_fusion = PreUNetFusionModule(static_feat_dim, global_env_emb_dim, time_emb_dim, unet_input_channels_after_fusion, dropout_rate_pre_fusion)

        # Main U-Net for per-timestep feature extraction
        self.unet_feature_extractor = UNetFeatureExtractorModule(unet_input_channels_after_fusion, unet_feature_output_channels, unet_encoder_channels_list, unet_middle_channels_val, unet_decoder_channels_list)

        # H0 Pixel Feature Extractor: produces the LSTM's initial hidden state
        self.h0_pixel_feature_extractor = H0PixelFeatureExtractor(
            static_feat_dim=static_feat_dim, # Recorded by the extractor for profiling
            unet_input_ch_for_h0_unet=h0_unet_input_channels, # Must match the channels of the image passed in forward()
            unet_feature_output_channels_for_h0=h0_unet_output_channels, # Output of its U-Net
            unet_enc_ch_list=h0_unet_enc_ch_list, unet_mid_ch=h0_unet_mid_ch, unet_dec_ch_list=h0_unet_dec_ch_list,
            lstm_hidden_dim=lstm_hidden_dim, num_lstm_layers=num_lstm_layers,
            H_img=H_img, W_img=W_img, dropout_rate=dropout_rate_other_mlps
        )

        self.lstm_input_size_for_profiling = unet_feature_output_channels # LSTM input is C_unet_out (per pixel)
        self.lstm_hidden_dim_for_profiling = lstm_hidden_dim
        self.lstm = nn.LSTM(unet_feature_output_channels, lstm_hidden_dim, num_lstm_layers, batch_first=True, dropout=dropout_rate_lstm if num_lstm_layers > 1 else 0.0)

        # One independent regression head per predicted step.
        self.hourly_prediction_heads = nn.ModuleList()
        for _ in range(T_pred_horizon): # Output 1 value per pixel
            self.hourly_prediction_heads.append(MLPEncoder(lstm_hidden_dim, 1, mlp_prediction_hidden_dim, dropout_rate_pred_head, add_layer_norm=False, activation_fn=None)) # No activation for regression output

        # Normalization statistics for the static features; identity (0 / 1) until
        # overwritten by the caller.
        self.register_buffer('static_image_feat_mean', torch.zeros(static_feat_dim))
        self.register_buffer('static_image_feat_std', torch.ones(static_feat_dim))

    def forward(self, list_of_batched_timesteps: list, timeline_time_features: torch.Tensor, device: torch.device):
        """Predict one image per step of the horizon.

        Args:
            list_of_batched_timesteps: batched graph objects; index 0 supplies the
                static node features `x`, indices 1..T_pred_horizon supply
                `graph_global_env_features` for each predicted step.
            timeline_time_features: tensor indexed as [step, feature]; row t is
                encoded for prediction step t.
            device: device all inputs are moved to.

        Returns:
            Tuple of
            - generated images, shape [B, T, 1, H, W];
            - fused conditioning images for every step, stacked on dim 1
              ([B, T, C_fused, H, W]);
            - flat predictions, shape [B*H*W, T].
        """
        pyg_batch_7am = list_of_batched_timesteps[0].to(device)
        current_B = pyg_batch_7am.num_graphs; nodes_per_graph = self.H_img * self.W_img

        # Normalize the flat node features, then reshape [B*H*W, C] -> [B, C, H, W].
        # NOTE(review): assumes nodes of each graph are stored in row-major pixel order - confirm against the data builder.
        static_x_image_flat = pyg_batch_7am.x
        normalized_static_x_flat = (static_x_image_flat - self.static_image_feat_mean.to(device)) / (self.static_image_feat_std.to(device) + 1e-8)
        static_x_image_normalized = normalized_static_x_flat.view(current_B, nodes_per_graph, self.static_feat_dim).permute(0,2,1).contiguous().view(current_B, self.static_feat_dim, self.H_img, self.W_img)

        # Initial LSTM state: learned h0 from the first-step image, zero cell state.
        h0 = self.h0_pixel_feature_extractor(static_x_image_normalized); c0 = torch.zeros_like(h0, device=device)
        lstm_initial_state = (h0, c0)

        all_unet_output_pixel_features_T = []; all_fused_images_for_unet_T = []
        for t_pred_idx in range(self.T_pred_horizon):
            # Step t of the horizon reads its global features from list index t + 1.
            pyg_batch_t = list_of_batched_timesteps[t_pred_idx + 1].to(device)
            global_env_feat_t_unenc = pyg_batch_t.graph_global_env_features
            exp_glob_dim = self.global_env_encoder.mlp[0].in_features
            # Fallback: global features that are not exactly [B, in_features] are replaced by zeros.
            if not (global_env_feat_t_unenc.shape==(current_B,exp_glob_dim)): global_env_feat_t_unenc=torch.zeros(current_B,exp_glob_dim,device=device)
            # Broadcast both embeddings over the spatial grid so they can be concatenated as channels.
            glob_emb_t = self.global_env_encoder(global_env_feat_t_unenc).unsqueeze(-1).unsqueeze(-1).expand(-1,-1,self.H_img,self.W_img)
            time_emb_t = self.time_encoder(timeline_time_features[t_pred_idx,:].to(device)).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(current_B,-1,self.H_img,self.W_img)

            fused_img_t = self.pre_unet_fusion(static_x_image_normalized, glob_emb_t, time_emb_t) # Use normalized static for consistency
            all_fused_images_for_unet_T.append(fused_img_t)
            unet_feat_map_t = self.unet_feature_extractor(fused_img_t) # [B, C_unet_out, H, W]
            # Permute and reshape for per-pixel LSTM: [B, C, H, W] -> [B, H, W, C] -> [B*H*W, C]
            pixel_feats_t = unet_feat_map_t.permute(0,2,3,1).reshape(current_B * nodes_per_graph, -1)
            all_unet_output_pixel_features_T.append(pixel_feats_t)

        lstm_input_sequence = torch.stack(all_unet_output_pixel_features_T, dim=1) # [B*H*W, T, C_unet_out]
        stacked_fused_images_for_unet = torch.stack(all_fused_images_for_unet_T, dim=1) # [B, T, C_fused, H, W] (for D)

        # Every pixel is an independent sequence in the LSTM batch dimension.
        lstm_out_sequence, _ = self.lstm(lstm_input_sequence, lstm_initial_state) # [B*H*W, T, LSTM_Hidden_Dim]

        all_hourly_predictions_flat_pixels = []
        for t in range(self.T_pred_horizon):
            # Head t only ever sees the LSTM output of step t.
            pred_pixels_t = self.hourly_prediction_heads[t](lstm_out_sequence[:, t, :]) # [B*H*W, 1]
            all_hourly_predictions_flat_pixels.append(pred_pixels_t.squeeze(-1)) # [B*H*W]

        stacked_predictions_flat = torch.stack(all_hourly_predictions_flat_pixels, dim=1) # [B*H*W, T] (for eval metrics)

        # Reshape for GAN L1 loss & D: [B*H*W, T] -> [B, H*W, T] -> [B, T, H*W] -> [B, T, 1, H, W]
        generated_sequence_images = stacked_predictions_flat.view(current_B, nodes_per_graph, self.T_pred_horizon)\
                                     .permute(0,2,1).contiguous().view(current_B, self.T_pred_horizon, 1, self.H_img, self.W_img)

        return generated_sequence_images, stacked_fused_images_for_unet, stacked_predictions_flat

class PatchGANDiscriminator(nn.Module):
    """PatchGAN critic: four conv stages followed by a 1-channel conv of patch logits.

    Args:
        input_channels: channels of the concatenated (condition, image) input.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.input_channels_for_profiling = input_channels
        stages = [
            # The first stage is left without normalization.
            self._disc_block(input_channels, 64, stride=2, normalize=False),
            self._disc_block(64, 128, stride=2),
            self._disc_block(128, 256, stride=2),
            self._disc_block(256, 512, stride=1, padding=1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]
        self.model = nn.Sequential(*stages)

    def _disc_block(self, in_c, out_c, stride=2, padding=1, normalize=True):
        """4x4 convolution, optional affine InstanceNorm, LeakyReLU(0.2)."""
        parts = [nn.Conv2d(in_c, out_c, 4, stride, padding, bias=False)]
        if normalize:
            parts.append(nn.InstanceNorm2d(out_c, affine=True))
        parts.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*parts)

    def forward(self, x):
        """Return the map of per-patch logits for input `x`."""
        return self.model(x)

# ===========================================================
# 3. Loss Functions & Metrics (masked_l1_loss, calculate_hourly_metrics - as before)
# ===========================================================
# ... (masked_l1_loss and calculate_hourly_metrics are identical to previous cGAN code block)
def masked_l1_loss(generated_images, target_images, image_masks):
    """Mean absolute error over pixels that are both masked-in and have a non-NaN target.

    Args:
        generated_images: predictions, same shape as `target_images`.
        target_images: targets; NaN entries are excluded from the loss.
        image_masks: truthy where the loss applies; a 3-D mask gets a channel
            axis inserted at dim 1.

    Returns:
        Scalar tensor. A fresh zero that requires grad is returned when no pixel
        contributes, so callers can still call backward().
    """
    def _zero_loss():
        return torch.tensor(0.0, device=generated_images.device, requires_grad=True)

    if image_masks.sum() == 0:
        return _zero_loss()

    mask = image_masks.unsqueeze(1) if image_masks.ndim == 3 else image_masks  # [B,H,W] -> [B,1,H,W]
    keep = mask.bool() & ~torch.isnan(target_images)
    if keep.sum() == 0:
        return _zero_loss()

    abs_error = torch.abs(generated_images - target_images)
    return abs_error[keep].mean()

def calculate_hourly_metrics(predictions_flat_scaled, targets_flat_scaled, node_masks_flat, target_mean, target_std):
    """Compute per-step MSE/MAE/RMSE/R2 on the original (un-standardized) scale.

    Args:
        predictions_flat_scaled: standardized predictions, shape [N, T].
        targets_flat_scaled: standardized targets, shape [N, T]; may contain NaN.
        node_masks_flat: length-N mask, truthy for nodes that enter the metrics.
        target_mean, target_std: statistics used to undo the standardization.

    Returns:
        Dict mapping step index t to {'mse', 'mae', 'rmse', 'r2', 'count'}.
        Steps with fewer than two valid samples get NaN metrics and count 0.
    """
    mean_cpu = target_mean.cpu()
    std_cpu = target_std.cpu()
    preds_np = (predictions_flat_scaled.clone().cpu() * (std_cpu + 1e-8) + mean_cpu).numpy()
    targets_np = (targets_flat_scaled.clone().cpu() * (std_cpu + 1e-8) + mean_cpu).numpy()
    _, horizon = preds_np.shape
    keep_nodes = node_masks_flat.cpu().numpy().astype(bool)

    metrics = {}
    for hour in range(horizon):
        hour_preds = preds_np[:, hour][keep_nodes]
        hour_targets = targets_np[:, hour][keep_nodes]
        # Missing observations are excluded on top of the node mask.
        observed = ~np.isnan(hour_targets)
        hour_preds = hour_preds[observed]
        hour_targets = hour_targets[observed]

        n_valid = hour_preds.shape[0]
        if n_valid < 2:
            metrics[hour] = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            continue

        mse = np.mean((hour_preds - hour_targets) ** 2)
        mae = np.mean(np.abs(hour_preds - hour_targets))
        rmse = np.sqrt(mse)
        try:
            r2 = r2_score(hour_targets, hour_preds)
        except ValueError:
            r2 = np.nan
        metrics[hour] = {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2, 'count': n_valid}
    return metrics

# ===========================================================
# 4. GAN Training and Evaluation Epochs (Adapted for Per-Pixel LSTM)
# ===========================================================
# train_epoch_unet_pixel_lstm_cgan and evaluate_epoch_unet_pixel_lstm_cgan follow largely the
# same logic as the previous cGAN variant, adapted to the three outputs returned by G.
def train_epoch_unet_pixel_lstm_cgan( # Renamed
    model_G, model_D, loader, optimizer_G, optimizer_D,
    criterion_GAN, criterion_L1, lambda_L1,
    device, timeline_time_features, H_img, W_img,
    static_img_feat_mean, static_img_feat_std, target_mean, target_std
):
    """Run one conditional-GAN training epoch (one D step and one G step per batch).

    Args:
        model_G: generator; called as model_G(batch_list, time_features, device) and
            expected to return (images [B,T,1,H,W], fused conditions [B,T,C,H,W], flat preds).
        model_D: discriminator applied to cat(condition, image) along dim 1.
        loader: yields lists of batched timesteps (index 0 = input step,
            index t+1 = target of prediction step t).
        optimizer_G, optimizer_D: optimizers of the two networks.
        criterion_GAN: adversarial criterion applied to D outputs vs. ones/zeros.
        criterion_L1: unused; the reconstruction term is computed by masked_l1_loss.
        lambda_L1: weight of the masked L1 term in the generator loss.
        device: device used for all tensors.
        timeline_time_features: time features forwarded to the generator.
        H_img, W_img: spatial size used to reshape flat targets and masks.
        static_img_feat_mean, static_img_feat_std: written into the generator's
            normalization buffers before training.
        target_mean, target_std: statistics used to standardize the targets.

    Returns:
        (avg_loss_G, avg_loss_D, avg_loss_G_GAN, avg_loss_G_L1, epoch_duration),
        losses averaged per sequence (0 when the loader yields nothing).
    """
    model_G.train(); model_D.train()
    total_loss_G, total_loss_D, total_loss_G_GAN, total_loss_G_L1 = 0, 0, 0, 0
    num_sequences_processed = 0
    model_G.static_image_feat_mean = static_img_feat_mean.to(device)
    model_G.static_image_feat_std = static_img_feat_std.to(device)
    target_mean_dev = target_mean.to(device); target_std_dev = target_std.to(device)
    timeline_on_device = timeline_time_features.to(device)
    horizon = model_G.T_pred_horizon
    epoch_start_time = time.time()

    for list_of_batched_timesteps in loader:
        current_B = list_of_batched_timesteps[0].num_graphs
        # G forward returns: generated_img_seq [B,T,1,H,W], stacked_fused_inputs [B,T,C_fused,H,W], preds_flat [B*H*W,T]
        generated_sequence_images, stacked_fused_conditions, _ = model_G(list_of_batched_timesteps, timeline_on_device, device)

        # Standardized target images are needed by both the D and the G step; build them once.
        scaled_target_images = []
        for t in range(horizon):
            real_target_y_flat = list_of_batched_timesteps[t + 1].y.to(device).squeeze()
            real_target_images_t = real_target_y_flat.view(current_B, 1, H_img, W_img)
            scaled_target_images.append((real_target_images_t - target_mean_dev) / (target_std_dev + 1e-8))

        # --- Train Discriminator ---
        optimizer_D.zero_grad()
        loss_D_real_accum = 0; loss_D_fake_accum = 0
        for t in range(horizon):
            # The condition is produced by the generator. It must be detached here:
            # otherwise loss_D.backward() runs through (and frees) the generator's
            # fusion graph, and the generator's own backward pass below fails with
            # "Trying to backward through the graph a second time".
            condition_for_D = stacked_fused_conditions[:, t, :, :, :].detach()
            D_input_real = torch.cat((condition_for_D, scaled_target_images[t]), dim=1)
            pred_real = model_D(D_input_real)
            loss_D_real_accum += criterion_GAN(pred_real, torch.ones_like(pred_real, device=device))
            fake_images_t = generated_sequence_images[:, t, :, :, :].detach()
            D_input_fake = torch.cat((condition_for_D, fake_images_t), dim=1)
            pred_fake = model_D(D_input_fake)
            loss_D_fake_accum += criterion_GAN(pred_fake, torch.zeros_like(pred_fake, device=device))
        loss_D = 0.5 * (loss_D_real_accum + loss_D_fake_accum) / horizon # Average over T
        # A NaN loss skips the update; it is still accumulated so it shows up in the epoch average.
        if not torch.isnan(loss_D): loss_D.backward(); optimizer_D.step()
        total_loss_D += loss_D.item() * current_B

        # --- Train Generator ---
        optimizer_G.zero_grad()
        loss_G_GAN_accum = 0; loss_G_L1_accum = 0
        building_mask_flat = list_of_batched_timesteps[1].building_mask.to(device) # [B*H*W]
        # The L1 term is evaluated on non-building pixels only.
        image_mask_for_loss = (~building_mask_flat).view(current_B, 1, H_img, W_img) # [B,1,H,W]

        for t in range(horizon):
            fake_images_t = generated_sequence_images[:, t, :, :, :]
            fused_condition_t = stacked_fused_conditions[:, t, :, :, :]
            D_input_for_G = torch.cat((fused_condition_t, fake_images_t), dim=1)
            pred_fake_for_G = model_D(D_input_for_G)
            loss_G_GAN_accum += criterion_GAN(pred_fake_for_G, torch.ones_like(pred_fake_for_G, device=device))
            loss_G_L1_accum += masked_l1_loss(fake_images_t, scaled_target_images[t], image_mask_for_loss)

        avg_loss_G_GAN = loss_G_GAN_accum / horizon
        avg_loss_G_L1 = loss_G_L1_accum / horizon
        loss_G = avg_loss_G_GAN + lambda_L1 * avg_loss_G_L1
        if not torch.isnan(loss_G): loss_G.backward(); optimizer_G.step()
        total_loss_G += loss_G.item() * current_B
        total_loss_G_GAN += avg_loss_G_GAN.item() * current_B
        total_loss_G_L1 += avg_loss_G_L1.item() * current_B
        num_sequences_processed += current_B

    epoch_duration = time.time() - epoch_start_time
    if num_sequences_processed > 0:
        avg_loss_G = total_loss_G / num_sequences_processed
        avg_loss_D = total_loss_D / num_sequences_processed
        avg_loss_G_GAN = total_loss_G_GAN / num_sequences_processed
        avg_loss_G_L1 = total_loss_G_L1 / num_sequences_processed
    else:
        avg_loss_G = avg_loss_D = avg_loss_G_GAN = avg_loss_G_L1 = 0
    return avg_loss_G, avg_loss_D, avg_loss_G_GAN, avg_loss_G_L1, epoch_duration

def evaluate_epoch_unet_pixel_lstm_cgan( # Renamed
    model_G, loader, criterion_L1, device, timeline_time_features, H_img, W_img,
    static_img_feat_mean, static_img_feat_std, target_mean, target_std, lambda_L1, epoch_type="Eval"
):
    """Evaluate the generator on `loader` without gradient tracking.

    Args:
        model_G: generator returning (images [B,T,1,H,W], fused conditions, flat preds [B*H*W,T]).
        loader: yields lists of batched timesteps (index t+1 holds the target of step t).
        criterion_L1, lambda_L1, epoch_type: accepted for interface compatibility; unused.
        device: device used for the forward pass.
        timeline_time_features: time features forwarded to the generator.
        H_img, W_img: spatial size used to reshape flat targets and masks.
        static_img_feat_mean, static_img_feat_std: written into the generator's
            normalization buffers.
        target_mean, target_std: statistics used to standardize targets and to
            un-standardize values for the metrics.

    Returns:
        (mean masked L1 on the standardized scale, per-step metrics dict on the
        original scale, evaluation duration in seconds).
    """
    model_G.eval()
    model_G.static_image_feat_mean = static_img_feat_mean.to(device)
    model_G.static_image_feat_std = static_img_feat_std.to(device)
    mean_dev = target_mean.to(device)
    std_dev = target_std.to(device)

    pred_chunks = []    # flat predictions per batch (third generator output)
    target_chunks = []  # flat standardized targets per batch
    mask_chunks = []    # non-building masks per batch
    l1_running = 0
    n_seen = 0
    started = time.time()

    with torch.no_grad():
        for timesteps in loader:
            batch_size = timesteps[0].num_graphs
            gen_images, _, preds_flat = model_G(timesteps, timeline_time_features.to(device), device)

            # The loss is only evaluated on non-building pixels.
            building_flat = timesteps[1].building_mask.to(device)
            loss_mask = (~building_flat).view(batch_size, 1, H_img, W_img)

            batch_l1 = 0
            flat_targets = []
            for t in range(model_G.T_pred_horizon):
                target_flat = timesteps[t + 1].y.to(device).squeeze()
                target_img = target_flat.view(batch_size, 1, H_img, W_img)
                target_scaled = (target_img - mean_dev) / (std_dev + 1e-8)
                batch_l1 += masked_l1_loss(gen_images[:, t, :, :, :], target_scaled, loss_mask)
                # Flatten in the same pixel order as the generator's flat predictions.
                flat_targets.append(target_scaled.permute(0, 2, 3, 1).reshape(-1))

            l1_running += (batch_l1 / model_G.T_pred_horizon).item() * batch_size
            pred_chunks.append(preds_flat.cpu())
            target_chunks.append(torch.stack(flat_targets, dim=1).cpu())  # [B*H*W, T]
            mask_chunks.append((~timesteps[1].building_mask).cpu())
            n_seen += batch_size

    elapsed = time.time() - started
    mean_l1 = l1_running / n_seen if n_seen > 0 else 0

    if not pred_chunks:
        # Nothing was evaluated: report NaN metrics for every step.
        nan_metrics = {
            t: {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
            for t in range(model_G.T_pred_horizon)
        }
        return mean_l1, nan_metrics, elapsed

    hourly_metrics = calculate_hourly_metrics(
        torch.cat(pred_chunks, dim=0),
        torch.cat(target_chunks, dim=0),
        torch.cat(mask_chunks, dim=0),
        target_mean.cpu(),
        target_std.cpu(),
    )
    return mean_l1, hourly_metrics, elapsed

# ===========================================================
# 5. 主训练流程 (U-Net Per-Pixel LSTM + Heads + cGAN)
# ===========================================================
def calculate_aggregated_metrics_report(hourly_metrics_dict, T_pred_horizon): # No change
    """Average each metric over the prediction steps, ignoring NaN or missing steps.

    Args:
        hourly_metrics_dict: maps step index to a dict with 'r2', 'mse', 'mae', 'rmse'.
        T_pred_horizon: number of steps to look up (indices 0..T_pred_horizon-1).

    Returns:
        Dict with 'avg_<metric>' and 'std_<metric>' per metric; both are NaN when
        no step has a non-NaN value for that metric.
    """
    report = {}
    for metric in ('r2', 'mse', 'mae', 'rmse'):
        usable = []
        for hour in range(T_pred_horizon):
            if hour not in hourly_metrics_dict:
                continue
            value = hourly_metrics_dict[hour][metric]
            if not np.isnan(value):
                usable.append(value)
        report[f'avg_{metric}'] = np.mean(usable) if usable else np.nan
        report[f'std_{metric}'] = np.std(usable) if usable else np.nan
    return report

def _format_gmacs(value):
    """Render one GMACs figure for the profiling summary line.

    A failed profiling step is recorded as the string sentinel "Error".
    Applying a float format spec to that sentinel raises, so anything that
    cannot be formatted as a float is passed through as plain text.
    """
    try:
        return f"{value:.3f}"
    except (TypeError, ValueError):
        return str(value)


def _build_unet_pixel_lstm_cgan_generator(config, global_env_in_dim, time_in_dim,
                                          T_pred_horizon, H_img, W_img, device):
    """Construct the generator from `config` and move it to `device`.

    Used for both the training model and the evaluation copy so the two are
    always built from the same hyper-parameters.
    """
    return UNetPerPixelLSTMHeadModel_Generator(
        static_feat_dim=config['static_feat_dim'],
        global_env_in_dim=global_env_in_dim,
        time_in_dim=time_in_dim,
        global_env_emb_dim=config.get('global_env_emb_dim'),
        time_emb_dim=config.get('time_emb_dim'),
        unet_input_channels_after_fusion=config.get('unet_input_channels_after_fusion'),
        dropout_rate_pre_fusion=config.get('dropout_rate_pre_fusion'),
        unet_feature_output_channels=config.get('unet_feature_output_channels'),
        unet_encoder_channels_list=config.get('unet_encoder_channels_list'),
        unet_middle_channels_val=config.get('unet_middle_channels_val'),
        unet_decoder_channels_list=config.get('unet_decoder_channels_list'),
        h0_unet_input_channels=config.get('h0_unet_input_channels'),
        h0_unet_output_channels=config.get('h0_unet_output_channels'),
        h0_unet_enc_ch_list=config.get('h0_unet_enc_ch_list'),
        h0_unet_mid_ch=config.get('h0_unet_mid_ch'),
        h0_unet_dec_ch_list=config.get('h0_unet_dec_ch_list'),
        lstm_hidden_dim=config.get('lstm_hidden_dim'),
        num_lstm_layers=config.get('num_lstm_layers'),
        dropout_rate_lstm=config.get('dropout_rate_lstm'),
        mlp_prediction_hidden_dim=config.get('mlp_prediction_hidden_dim'),
        dropout_rate_pred_head=config.get('dropout_rate_pred_head'),
        T_pred_horizon=T_pred_horizon,
        H_img=H_img,
        W_img=W_img,
        dropout_rate_other_mlps=config.get('dropout_rate_other_mlps'),
    ).to(device)


def main_training_unet_pixel_lstm_cgan(
    all_sequences_data: list, config: dict, time_features_for_dataset: torch.Tensor
):
    """Train and evaluate the U-Net per-pixel LSTM generator with a PatchGAN critic.

    Steps, in order: seed the RNGs, filter out malformed sequences, split into
    train/val/test, fit and save the input/target scalers on the training
    split, build generator and discriminator, estimate per-component GMACs,
    run the epoch loop with early stopping on the validation L1 loss, reload
    the best generator checkpoint, evaluate it on train and test, and write a
    JSON report.

    Args:
        all_sequences_data: List of sequences; each sequence is a list of
            ``T_pred_horizon + 1`` graph objects exposing ``x``. Sequences that
            do not match are dropped with a warning.
        config: Hyper-parameters and paths. ``config['results_dir']`` is
            required; ``config['static_feat_dim']`` is overwritten when it
            disagrees with the feature width found in the data.
        time_features_for_dataset: 2-D tensor of time features; its second
            dimension is used as the time-encoder input width.

    Returns:
        Tuple ``(model_G_eval, model_D, static_img_mean, static_img_std,
        tgt_mean, tgt_std)``.

    Raises:
        ValueError: If no valid sequence remains after filtering, or if the
            training split has no first-step ``x`` features.
    """
    train_start_time = time.time()
    report_data = {'config': config}

    # --- Reproducibility and device selection ---
    seed = config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    report_data['device'] = str(device)

    T_PRED_HORIZON = config.get('T_pred_horizon', 12)
    H_IMG = config.get('H_IMG', 50)
    W_IMG = config.get('W_IMG', 50)
    nodes_per_image = H_IMG * W_IMG
    expected_input_len = T_PRED_HORIZON + 1  # one input step plus the predicted steps

    # --- Data validation ---
    expected_static_dim = config.get('static_feat_dim', 16)
    valid_sequences_data = []
    for seq in all_sequences_data:
        if not (isinstance(seq, list) and len(seq) == expected_input_len):
            continue
        steps_ok = all(
            hasattr(step_data, 'x') and step_data.x is not None
            and step_data.x.shape[0] % nodes_per_image == 0
            and step_data.x.shape[1] == expected_static_dim
            for step_data in seq
        )
        if steps_ok:
            valid_sequences_data.append(seq)
    if len(valid_sequences_data) != len(all_sequences_data):
        print(f"警告:筛选出{len(valid_sequences_data)}/{len(all_sequences_data)}有效序列。")
    if not valid_sequences_data:
        raise ValueError("无有效序列数据。")
    all_sequences_data = valid_sequences_data

    # --- Random train / val / test split (remainder goes to test) ---
    num_total_sequences = len(all_sequences_data)
    indices = np.random.permutation(num_total_sequences)
    train_s_r = config.get('train_split_ratio', 0.7)
    val_s_r = config.get('val_split_ratio', 0.2)
    train_size = int(train_s_r * num_total_sequences)
    val_size = int(val_s_r * num_total_sequences)
    train_idx = indices[:train_size]
    val_idx = indices[train_size:train_size + val_size]
    test_idx = indices[train_size + val_size:]
    train_dataset = [all_sequences_data[i] for i in train_idx]
    val_dataset = [all_sequences_data[i] for i in val_idx]
    test_dataset = [all_sequences_data[i] for i in test_idx]
    report_data['dataset_split'] = {
        'total': num_total_sequences, 'train': len(train_dataset),
        'val': len(val_dataset), 'test': len(test_dataset),
    }

    # --- Input scaler: per-feature statistics of the first step, train split only ---
    all_7am_node_feat_list = [
        s[0].x for s in train_dataset if hasattr(s[0], 'x') and s[0].x is not None
    ]
    if not all_7am_node_feat_list:
        raise ValueError("训练数据7am无特征x！")
    all_7am_tensor = torch.cat(all_7am_node_feat_list, dim=0)
    static_img_mean = torch.mean(all_7am_tensor, dim=0)
    static_img_std = torch.std(all_7am_tensor, dim=0)
    static_img_std[static_img_std < 1e-8] = 1.0  # avoid division by ~0 for constant features

    # Scalers, checkpoints and the report are all written here, so make sure
    # the directory exists instead of failing at the first torch.save.
    results_dir = Path(config['results_dir'])
    results_dir.mkdir(parents=True, exist_ok=True)

    scaler_path_x = results_dir / "static_img_scaler_unet_pixel_lstm_cgan.pth"
    torch.save({'mean': static_img_mean, 'std': static_img_std}, scaler_path_x)
    print(f"静态图像特征scaler:{scaler_path_x}")

    # --- Target scaler: non-building, non-NaN targets of steps 1.. in the train split ---
    all_train_tgt_val_list = []
    for seq in train_dataset:
        for i_s, g_data in enumerate(seq):
            if i_s > 0 and hasattr(g_data, 'y') and g_data.y is not None:
                y_orig = g_data.y.squeeze()
                mask_loss = ~g_data.building_mask
                valid_idx = mask_loss & ~torch.isnan(y_orig)
                if valid_idx.sum() > 0:
                    all_train_tgt_val_list.append(y_orig[valid_idx])
    if not all_train_tgt_val_list:
        tgt_mean = torch.tensor(0.0)
        tgt_std = torch.tensor(1.0)
    else:
        all_tgt_tensor = torch.cat(all_train_tgt_val_list, dim=0)
        tgt_mean = torch.mean(all_tgt_tensor.float())
        tgt_std = torch.std(all_tgt_tensor.float())
    if tgt_std < 1e-8:
        tgt_std = torch.tensor(1.0)
    tgt_scaler_path = results_dir / "tgt_scaler_unet_pixel_lstm_cgan.pth"
    torch.save({'mean': tgt_mean, 'std': tgt_std}, tgt_scaler_path)
    print(f"目标y scaler:{tgt_scaler_path}")

    # --- DataLoaders ---
    batch_size = config.get('batch_size', 4)
    num_workers = config.get('num_workers', 0)
    pin_mem = bool(config.get('pin_memory', False)) and device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True,
                              num_workers=num_workers, pin_memory=pin_mem)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, drop_last=False,
                            num_workers=num_workers, pin_memory=pin_mem)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False, drop_last=False,
                             num_workers=num_workers, pin_memory=pin_mem)

    # --- Model initialisation (dimensions are read from the first sample) ---
    s_g_7am = all_sequences_data[0][0]
    stat_feat_dim = s_g_7am.x.shape[1]
    glob_env_in_dim = s_g_7am.graph_global_env_features.shape[-1]
    time_in_dim = time_features_for_dataset.shape[1]
    if stat_feat_dim != config.get('static_feat_dim', 16):
        config['static_feat_dim'] = stat_feat_dim

    model_G = _build_unet_pixel_lstm_cgan_generator(
        config, glob_env_in_dim, time_in_dim, T_PRED_HORIZON, H_IMG, W_IMG, device
    )
    model_D = PatchGANDiscriminator(
        input_channels=config.get('discriminator_input_channels')
    ).to(device)
    model_G.static_image_feat_mean = static_img_mean.to(device)
    model_G.static_image_feat_std = static_img_std.to(device)
    params_G = sum(p.numel() for p in model_G.parameters() if p.requires_grad)
    params_D = sum(p.numel() for p in model_D.parameters() if p.requires_grad)
    print(f"Generator Params: {params_G:,}, Discriminator Params: {params_D:,}")
    report_data.update({'p_G': params_G, 'p_D': params_D, 'p_total': params_G + params_D})

    # --- MACs estimation (best effort: a failing component is recorded as "Error") ---
    print("\nCalculating MACs for UNet Per-Pixel LSTM cGAN components:")
    report_data['component_gmacs'] = {}
    dummy_B = 1
    model_G.eval()
    model_D.eval()

    try:
        macs_ge = torchprofile.profile_macs(
            model_G.global_env_encoder,
            args=(torch.randn(dummy_B, glob_env_in_dim, device=device),)) / 1e9
    except Exception:
        macs_ge = "Error"
        print("Error profiling G_global_env_encoder_mlp")
    report_data['component_gmacs']['G_global_env_encoder_mlp(single)'] = macs_ge

    try:
        macs_t = torchprofile.profile_macs(
            model_G.time_encoder,
            args=(torch.randn(time_in_dim, device=device),)) / 1e9
    except Exception:
        macs_t = "Error"
        print("Error profiling G_time_encoder_mlp")
    report_data['component_gmacs']['G_time_encoder_mlp(single)'] = macs_t

    dummy_stat_img = torch.randn(dummy_B, config['static_feat_dim'], H_IMG, W_IMG, device=device)
    dummy_exp_ge = torch.randn(dummy_B, config.get('global_env_emb_dim'), H_IMG, W_IMG, device=device)
    dummy_exp_te = torch.randn(dummy_B, config.get('time_emb_dim'), H_IMG, W_IMG, device=device)
    try:
        macs_pre_fus = torchprofile.profile_macs(
            model_G.pre_unet_fusion,
            args=(dummy_stat_img, dummy_exp_ge, dummy_exp_te)) / 1e9
    except Exception:
        macs_pre_fus = "Error"
        print("Error profiling G_pre_unet_fusion")
    report_data['component_gmacs']['G_pre_unet_fusion(single)'] = macs_pre_fus

    dummy_unet_in = torch.randn(dummy_B, config.get('unet_input_channels_after_fusion'),
                                H_IMG, W_IMG, device=device)
    try:
        macs_unet_ext = torchprofile.profile_macs(
            model_G.unet_feature_extractor, args=(dummy_unet_in,)) / 1e9
    except Exception:
        macs_unet_ext = "Error"
        print("Error profiling G_unet_feature_extractor")
    report_data['component_gmacs']['G_unet_feature_extractor(single)'] = macs_unet_ext

    # A separate dummy input is only used when the H0 U-Net width differs from
    # the static feature width; otherwise the static dummy image is reused.
    dummy_h0_unet_in = torch.randn(dummy_B, config.get('h0_unet_input_channels'),
                                   H_IMG, W_IMG, device=device)
    try:
        if config.get('h0_unet_input_channels') != config.get('static_feat_dim'):
            h0_profile_input = dummy_h0_unet_in
        else:
            h0_profile_input = dummy_stat_img
        macs_h0_ext = torchprofile.profile_macs(
            model_G.h0_pixel_feature_extractor, args=(h0_profile_input,)) / 1e9
    except Exception:
        macs_h0_ext = "Error"
        print("Error profiling G_h0_pixel_feature_extractor")
    report_data['component_gmacs']['G_h0_pixel_feature_extractor'] = macs_h0_ext

    try:
        # Manual LSTM estimate: four gates, each with an input-to-hidden and a
        # hidden-to-hidden product, per pixel and per time step.
        lstm_l = model_G.lstm
        N_b_L = dummy_B * H_IMG * W_IMG
        L_s_L = T_PRED_HORIZON
        H_i_L = model_G.lstm_input_size_for_profiling
        H_h_L = model_G.lstm_hidden_dim_for_profiling
        n_l_L = lstm_l.num_layers
        macs_LST = N_b_L * L_s_L * 4 * (H_i_L * H_h_L + H_h_L * H_h_L)
        if n_l_L > 1:
            # Layers above the first take the hidden state as their input.
            macs_LST += N_b_L * L_s_L * (n_l_L - 1) * 4 * (H_h_L * H_h_L + H_h_L * H_h_L)
        gmacs_LST = macs_LST / 1e9
    except Exception:
        gmacs_LST = "Error"
        print("Error in LSTM MACs")
    report_data['component_gmacs']['G_lstm_layer(manual)'] = gmacs_LST

    try:
        macs_ph = torchprofile.profile_macs(
            model_G.hourly_prediction_heads[0],
            args=(torch.randn(dummy_B * H_IMG * W_IMG, config.get('lstm_hidden_dim'), device=device),)) / 1e9
    except Exception:
        macs_ph = "Error"
        print("Error profiling G_prediction_head_mlp")
    report_data['component_gmacs']['G_prediction_head_mlp(single)'] = macs_ph

    # Any value above may be the "Error" sentinel, so format defensively; a
    # failed profiling step must not abort the training run.
    print(
        f" G GMACs Est(single): GEEnc={_format_gmacs(macs_ge)},TEnc={_format_gmacs(macs_t)},"
        f"PreFus={_format_gmacs(macs_pre_fus)},UNetExt={_format_gmacs(macs_unet_ext)},"
        f"H0Ext={_format_gmacs(macs_h0_ext)},LSTM={_format_gmacs(gmacs_LST)},"
        f"PredHead={_format_gmacs(macs_ph)}"
    )

    try:
        dummy_D_in = torch.randn(dummy_B, config.get('discriminator_input_channels'),
                                 H_IMG, W_IMG, device=device)
        macs_D = torchprofile.profile_macs(model_D, args=(dummy_D_in,)) / 1e9
        report_data['component_gmacs']['D_patchgan_discriminator(single)'] = macs_D
        print(f"  Discriminator PatchGAN GMACs (single pass): {macs_D:.4f}")
    except Exception as e:
        print(f"  Error profiling Discriminator: {e}")
    model_G.train()
    model_D.train()

    # --- Optimizers, schedulers, loss criteria ---
    adam_betas = (config.get('beta1', 0.5), config.get('beta2', 0.999))
    opt_G = torch.optim.Adam(model_G.parameters(), lr=config.get('lr_G', 0.0002), betas=adam_betas)
    opt_D = torch.optim.Adam(model_D.parameters(), lr=config.get('lr_D', 0.0002), betas=adam_betas)
    # NOTE(review): `verbose` is documented as deprecated in recent PyTorch
    # releases - confirm against the installed version.
    sch_G = ReduceLROnPlateau(opt_G, 'min', factor=0.5,
                              patience=config.get('scheduler_patience_G', 25), verbose=True)
    sch_D = ReduceLROnPlateau(opt_D, 'min', factor=0.5,
                              patience=config.get('scheduler_patience_D', 25), verbose=True)
    crit_GAN = nn.MSELoss().to(device)
    crit_L1 = nn.L1Loss().to(device)
    lambda_L1_val = config.get('lambda_L1', 100.0)

    best_val_L1 = float('inf')
    best_val_metrics = None
    best_ep = 0
    pat_count = 0
    max_ep = config.get('max_epochs', 100)
    early_stop_pat = config.get('early_stopping_patience', 30)
    model_G_path = results_dir / f"best_unet_pixel_lstm_cgan_G_s{seed}.pth"
    model_D_path = results_dir / f"best_unet_pixel_lstm_cgan_D_s{seed}.pth"
    timeline_feat_dev = time_features_for_dataset.to(device)
    tgt_mean_cpu = tgt_mean.cpu()
    tgt_std_cpu = tgt_std.cpu()
    ep_times = []

    # --- Epoch loop with early stopping on the validation L1 loss ---
    ep = 0  # stays 0 when max_epochs < 1, so the report below can still be built
    for ep in range(1, max_ep + 1):
        avg_L_G, avg_L_D, avg_L_G_GAN, avg_L_G_L1, ep_dur = train_epoch_unet_pixel_lstm_cgan(
            model_G, model_D, train_loader, opt_G, opt_D, crit_GAN, crit_L1, lambda_L1_val,
            device, timeline_feat_dev, H_IMG, W_IMG, static_img_mean, static_img_std, tgt_mean, tgt_std
        )
        ep_times.append(ep_dur)
        val_L1, val_metrics, _ = evaluate_epoch_unet_pixel_lstm_cgan(
            model_G, val_loader, crit_L1, device, timeline_feat_dev, H_IMG, W_IMG,
            static_img_mean, static_img_std, tgt_mean, tgt_std, lambda_L1_val, epoch_type="Validation"
        )
        sch_G.step(val_L1)
        sch_D.step(avg_L_D)
        print(f"Ep {ep:03d}|T:{ep_dur:.1f}s|G_L:{avg_L_G:.4f}(GAN:{avg_L_G_GAN:.4f},L1:{avg_L_G_L1:.4f})|D_L:{avg_L_D:.4f}|Val_L1:{val_L1:.4f}|LR_G:{opt_G.param_groups[0]['lr']:.6f}")
        _print_hourly_metrics_summary("Val", val_metrics, T_PRED_HORIZON, indent=" " * 21)
        if val_L1 < best_val_L1:
            best_val_L1 = val_L1
            best_val_metrics = val_metrics
            best_ep = ep
            pat_count = 0
            torch.save(model_G.state_dict(), model_G_path)
            torch.save(model_D.state_dict(), model_D_path)
            print(f"{' '*21}---> Best models saved (Ep:{ep}, Val L1:{best_val_L1:.4f})")
        else:
            pat_count += 1
        if pat_count >= early_stop_pat:
            print(f"Early stopping at ep {ep}.")
            break

    # --- Reporting and final evaluation ---
    report_data['total_training_time_seconds'] = time.time() - train_start_time
    report_data['average_epoch_time_seconds'] = np.mean(ep_times) if ep_times else np.nan
    report_data['num_epochs_trained'] = ep
    report_data['best_validation_epoch'] = best_ep
    report_data['best_validation_scaled_L1_loss'] = best_val_L1

    # Fresh generator loaded from the best checkpoint; if loading fails, fall
    # back to the in-memory model from the last epoch.
    model_G_eval = _build_unet_pixel_lstm_cgan_generator(
        config, glob_env_in_dim, time_in_dim, T_PRED_HORIZON, H_IMG, W_IMG, device
    )
    try:
        model_G_eval.load_state_dict(torch.load(model_G_path, map_location=device))
    except Exception as e:
        print(f"无法加载最佳G模型({e})")
        model_G_eval = model_G

    print("\nEval G on Train...")
    train_L1, train_metrics, train_dur = evaluate_epoch_unet_pixel_lstm_cgan(
        model_G_eval, train_loader, crit_L1, device, timeline_feat_dev, H_IMG, W_IMG,
        static_img_mean, static_img_std, tgt_mean_cpu, tgt_std_cpu, lambda_L1_val, "Best G on Train"
    )
    report_data.update({
        'best_G_train_metrics_hr': train_metrics,
        'best_G_train_metrics_agg': calculate_aggregated_metrics_report(train_metrics, T_PRED_HORIZON),
        'best_G_train_eval_time_s': train_dur,
    })
    _print_hourly_metrics_summary("最佳G训练集", train_metrics, T_PRED_HORIZON)

    report_data['best_G_val_metrics_hr'] = best_val_metrics
    if best_val_metrics:
        report_data['best_G_val_metrics_agg'] = calculate_aggregated_metrics_report(best_val_metrics, T_PRED_HORIZON)
    _print_hourly_metrics_summary("最佳G验证集", report_data['best_G_val_metrics_hr'], T_PRED_HORIZON)

    print("\nEval G on Test...")
    test_L1, test_metrics, test_dur = evaluate_epoch_unet_pixel_lstm_cgan(
        model_G_eval, test_loader, crit_L1, device, timeline_feat_dev, H_IMG, W_IMG,
        static_img_mean, static_img_std, tgt_mean_cpu, tgt_std_cpu, lambda_L1_val, "Test"
    )
    report_data.update({
        'test_infer_time_s': test_dur,
        'best_G_test_metrics_hr': test_metrics,
        'best_G_test_metrics_agg': calculate_aggregated_metrics_report(test_metrics, T_PRED_HORIZON),
        'test_scaled_L1': test_L1,
    })
    print(f"\n{'='*10} Final Test (U-Net Pixel LSTM cGAN - G) {'='*10}\nAvg Test Scaled L1: {test_L1:.4f}")
    _print_hourly_metrics_summary("测试集 (G)", test_metrics, T_PRED_HORIZON)
    agg_tst = report_data['best_G_test_metrics_agg']
    print(f"Avg Test MSE(Orig): {agg_tst.get('avg_mse',np.nan):.4f} (Std:{agg_tst.get('std_mse',np.nan):.4f})")
    print(f"Avg Test R2 (Orig): {agg_tst.get('avg_r2',np.nan):.4f} (Std:{agg_tst.get('std_r2',np.nan):.4f})")
    print(f"Avg Test MAE(Orig): {agg_tst.get('avg_mae',np.nan):.4f} (Std:{agg_tst.get('std_mae',np.nan):.4f})")
    print(f"Avg Test RMSE(Orig):{agg_tst.get('avg_rmse',np.nan):.4f} (Std:{agg_tst.get('std_rmse',np.nan):.4f})")
    print("=" * 70)

    # --- JSON report (best effort: a serialisation failure is reported, not raised) ---
    report_file = results_dir / f"report_unet_pixel_lstm_cgan_s{seed}.json"
    try:
        class NpEncoder(json.JSONEncoder):
            """JSON encoder that also accepts NumPy scalars/arrays, tensors and paths."""

            def default(self, o):
                if isinstance(o, np.integer):
                    return int(o)
                if isinstance(o, np.floating):
                    return float(o)
                if isinstance(o, np.ndarray):
                    return o.tolist()
                if isinstance(o, torch.Tensor):
                    return o.tolist()
                if isinstance(o, Path):
                    return str(o)
                return super(NpEncoder, self).default(o)

        with open(report_file, 'w') as f:
            json.dump(report_data, f, indent=4, cls=NpEncoder)
        print(f"训练报告: {report_file}")
    except Exception as e:
        print(f"保存训练报告失败: {e}")
    return model_G_eval, model_D, static_img_mean, static_img_std, tgt_mean, tgt_std


def _print_hourly_metrics_summary(set_name, hourly_metrics, T_pred_horizon, indent="  "):
    """Print a per-hour metrics table followed by the aggregated averages.

    Args:
        set_name: Label of the data split, inserted into the headings.
        hourly_metrics: Mapping from hour index to a metrics dict, or None.
            None prints a "not available" line and returns immediately.
        T_pred_horizon: Number of hour rows to print (0 .. T_pred_horizon - 1).
        indent: Prefix put in front of every printed line.
    """
    if hourly_metrics is None:
        print(f"{indent}{set_name} metrics not available.")
        return

    print(f"\n{indent}每小时 {set_name} 指标 (Original Scale):")

    # Row shown for hours that have no entry in the mapping.
    placeholder_row = {'mse': np.nan, 'mae': np.nan, 'rmse': np.nan, 'r2': np.nan, 'count': 0}
    for hour_idx in range(T_pred_horizon):
        row = hourly_metrics.get(hour_idx, placeholder_row)
        if hour_idx == 0:
            # Column header goes out once, right before the first row.
            print(f"{indent}  Hour | {'R2':>13s} | {'MSE':>14s} | {'MAE':>14s} | {'RMSE':>15s} | {'Count':>7s}")
        r2_val = row.get('r2', np.nan)
        mse_val = row.get('mse', np.nan)
        mae_val = row.get('mae', np.nan)
        rmse_val = row.get('rmse', np.nan)
        n_val = row.get('count', 0)
        print(f"{indent}  {hour_idx:02d}   | {r2_val:13.4f} | {mse_val:14.4f} | {mae_val:14.4f} | {rmse_val:15.4f} | {n_val:7d}")

    aggregated = calculate_aggregated_metrics_report(hourly_metrics, T_pred_horizon)
    for label, key in (('R2', 'r2'), ('MSE', 'mse'), ('MAE', 'mae'), ('RMSE', 'rmse')):
        avg_val = aggregated.get(f'avg_{key}', np.nan)
        std_val = aggregated.get(f'std_{key}', np.nan)
        # Label is padded to five characters so the colons line up.
        print(f"{indent}  Aggregated Avg {label:<5s}: {avg_val:.4f} (Std: {std_val:.4f})")


# ===========================================================
# 6. 主执行块 (Adapted for U-Net Per-Pixel LSTM cGAN)
# ===========================================================
if __name__ == "__main__":
    # Free memory left over from earlier runs before building new models.
    gc.collect()
    torch.cuda.empty_cache()

    # --- Paths ---
    DRIVE_BASE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    DRIVE_BASE_PATH.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the notebook's data-loader cell reads this same file name
    # from "Result/Sequential_13Hour_Data" - confirm which directory is correct.
    DATA_SUBDIR = Path("Result/Sequential_12Hour_Data")
    DATA_FILENAME = "graph_seq_20230503_SeqH7to19_NpyH8fill0.0.pkl"
    DATA_PATH = DRIVE_BASE_PATH / DATA_SUBDIR / DATA_FILENAME
    RESULTS_SUBDIR = Path("Result/Final_UNetPixelLSTM1")
    RESULTS_SAVE_DIR = DRIVE_BASE_PATH / RESULTS_SUBDIR
    RESULTS_SAVE_DIR.mkdir(parents=True, exist_ok=True)

    # --- Date, timeline and image-grid constants ---
    DATA_YEAR = 2023
    DATA_MONTH = 5
    DATA_DAY = 3
    START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES = 8
    PREDICTION_HORIZON = 12
    H_IMG, W_IMG = 50, 50

    training_config = {
        # Optimisation
        'seed': 42,
        'batch_size': 8,
        'lr_G': 0.0002,
        'lr_D': 0.0002,
        'beta1': 0.5,
        'beta2': 0.999,
        'lambda_L1': 100.0,
        'max_epochs': 1000,
        'scheduler_patience_G': 20,
        'scheduler_patience_D': 20,
        'early_stopping_patience': 45,
        # Problem geometry and output location
        'T_pred_horizon': PREDICTION_HORIZON,
        'results_dir': str(RESULTS_SAVE_DIR),
        'H_IMG': H_IMG,
        'W_IMG': W_IMG,
        # Encoders and pre-U-Net fusion
        'static_feat_dim': 16,
        'global_env_emb_dim': 16,
        'time_emb_dim': 8,
        'unet_input_channels_after_fusion': 32,
        'dropout_rate_pre_fusion': 0.1,
        # Main U-Net
        'unet_feature_output_channels': 128,
        'unet_encoder_channels_list': (64, 128, 256),
        'unet_middle_channels_val': 512,
        'unet_decoder_channels_list': (256, 128, 64),
        # H0 U-Net (narrower channel lists than the main U-Net)
        'h0_unet_input_channels': 16,
        'h0_unet_output_channels': 128,
        'h0_unet_enc_ch_list': (32, 64, 128),
        'h0_unet_mid_ch': 256,
        'h0_unet_dec_ch_list': (128, 64, 32),
        # LSTM and prediction heads
        'lstm_hidden_dim': 128,
        'num_lstm_layers': 1,
        'dropout_rate_lstm': 0.2,
        'mlp_prediction_hidden_dim': 64,
        'dropout_rate_pred_head': 0.2,
        'dropout_rate_other_mlps': 0.1,
        # Discriminator input: 32 fused channels plus 1 output-image channel
        'discriminator_input_channels': 32 + 1,
        # Runtime switches
        'use_amp': False,
        'enable_profiler': False,
        'num_workers': 0,
        'pin_memory': False,
        # Dataset split (the remainder is used as the test set)
        'train_split_ratio': 0.7,
        'val_split_ratio': 0.2,
    }

    # --- Load the pickled sequences; any failure leaves the variable as None ---
    all_graph_sequences_loaded = None
    try:
        if not DATA_PATH.exists():
            raise FileNotFoundError(f"Data file not found: {DATA_PATH}")
        with open(DATA_PATH, "rb") as f:
            all_graph_sequences_loaded = pickle.load(f)
        # Expect a non-empty list whose first element is a non-empty list.
        if not (all_graph_sequences_loaded
                and isinstance(all_graph_sequences_loaded, list)
                and all_graph_sequences_loaded[0]
                and isinstance(all_graph_sequences_loaded[0], list)):
            raise ValueError("Loaded data is not in expected list-of-lists format or is empty.")
        print(f"Original loaded sequences: {len(all_graph_sequences_loaded)}")
    except Exception as e:
        print(f"Error loading data: {e}")
        all_graph_sequences_loaded = None

    # --- Train only when data is available ---
    if not all_graph_sequences_loaded:
        print("Training not started due to data loading issues.")
    else:
        base_dt_timeline = dt_datetime(
            DATA_YEAR, DATA_MONTH, DATA_DAY, START_HOUR_OF_DAY_FOR_TIMELINE_FEATURES
        )
        time_features_for_dataset = generate_time_features_for_sequence(
            base_dt_timeline, training_config['T_pred_horizon']
        )
        main_training_unet_pixel_lstm_cgan(
            all_graph_sequences_loaded, training_config, time_features_for_dataset
        )
        print("U-Net Per-Pixel LSTM cGAN training and evaluation complete!")