In [None]:
!pip install pvlib
# 1. Uninstall the libraries that cause dependency conflicts.
# The -y flag confirms the uninstall automatically, so no manual 'y' input is needed.
!pip uninstall -y tensorflow numba

# 2. Re-run the PyG (PyTorch Geometric) installation script.
import torch

# Detect the installed PyTorch version (local build suffix after '+' is dropped).
TORCH = torch.__version__.split('+')[0]
# Build the CUDA tag used in the wheel index URL, e.g. '12.4' becomes 'cu124';
# fall back to 'cpu' when torch reports no CUDA version.
if torch.version.cuda:
    CUDA = "cu" + torch.version.cuda.replace('.', '')
else:
    CUDA = 'cpu'

# Print the detected versions so they can be checked against the wheel index.
print(f"PyTorch version: {TORCH}")
print(f"CUDA version: {CUDA}")

# -q (quiet) reduces pip output.
# NOTE(review): the original note also mentioned --no-cache-dir, but none of the commands below pass that flag.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# 3. Verify the installation by importing the package and printing its version.
try:
    import torch_geometric
    print("\nPyG installation successful!")
    print(f"PyG version: {torch_geometric.__version__}")
except ImportError:
    print("\nPyG installation failed.")
# GraphData Construction

In [None]:
# %%capture
# !pip install torch_geometric
# !pip install pvlib # 需要安装此库用于太阳位置计算

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from torch_scatter import scatter_mean # Assuming torch_scatter is available if add_neighbor_agg=True


import networkx as nx
import pickle
import re
# import textwrap # Removed unused
from pathlib import Path
import scipy.sparse.linalg as spla
from sklearn.metrics import pairwise_distances

import pvlib # 用于太阳位置计算
from datetime import datetime, timezone # timezone需要显式导入
import math

# ==================== Solar position calculation module ====================
class SolarCalculator:
    """Compute the sun's azimuth and elevation for a fixed site via pvlib.

    The site is described by latitude, longitude, altitude and an IANA
    time-zone name; all four are stored unchanged on the instance.
    """

    def __init__(self, latitude, longitude, altitude=0, timezone_str='Asia/Singapore'):
        self.latitude, self.longitude = latitude, longitude
        self.altitude = altitude
        self.timezone_str = timezone_str

    def get_solar_position(self, year, month, day, hour_of_day):
        """Return the solar position at the start of a local hour.

        Args:
            year, month, day, hour_of_day: Local wall-clock date and hour
                (minutes and seconds are fixed to zero), interpreted in
                ``self.timezone_str``.

        Returns:
            dict with ``azimuth_deg``, ``elevation_deg`` (as reported by
            pvlib) and the same two angles converted to radians as
            ``azimuth_rad`` and ``elevation_rad``.
        """
        # Attach the site time zone to the naive local timestamp.
        local_stamp = pd.Timestamp(
            datetime(year, month, day, hour_of_day, 0, 0),
            tz=self.timezone_str,
        )
        query_times = pd.DatetimeIndex([local_stamp])

        # Pressure is derived from the site altitude; the temperature is a
        # fixed 25 passed straight through to pvlib.
        site_pressure = pvlib.atmosphere.alt2pres(self.altitude)
        position_frame = pvlib.solarposition.get_solarposition(
            time=query_times,
            latitude=self.latitude,
            longitude=self.longitude,
            altitude=self.altitude,
            temperature=25,
            pressure=site_pressure,
        )

        azimuth = position_frame['azimuth'].iloc[0]
        elevation = position_frame['elevation'].iloc[0]
        result = {'azimuth_deg': azimuth, 'elevation_deg': elevation}
        result['azimuth_rad'] = math.radians(azimuth)
        result['elevation_rad'] = math.radians(elevation)
        return result

# ==================== CSV weather data processing module ===================
class CSVWeatherParser:
    """Load an hourly weather CSV and serve the record for a given month/day/hour.

    The CSV must contain ``Date`` (``YYYY/MM/DD``) and ``HH:MM`` columns plus
    every column listed in ``self.required_columns``.
    """

    def __init__(self, csv_path: str):
        self.required_columns = [
            'Dry Bulb Temperature {C}', 'Relative Humidity {%}',
            'Wind Speed {m/s}', 'Atmospheric Pressure {Pa}',
            'Global Horizontal Radiation {Wh/m2}', 'Wind Direction {deg}'
        ]
        self.df = self._load_and_preprocess_csv(csv_path)

    def _load_and_preprocess_csv(self, path: str) -> pd.DataFrame:
        """Read the CSV, derive ``datetime``/``month``/``day``/``hour`` and coerce data columns.

        Non-numeric cells in the required data columns become NaN.

        Raises:
            ValueError: if the file cannot be read, a required column is
                missing, or the date/time columns cannot be parsed. The
                underlying exception is attached as ``__cause__``.
        """
        try:
            df = pd.read_csv(path)
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise ValueError(f"无法读取CSV文件 '{path}': {e}") from e

        for col in ('Date', 'HH:MM'):
            if col not in df.columns:
                raise ValueError(f"CSV文件 '{path}' 缺少必要的日期/时间列: '{col}'")
        try:
            df['datetime'] = pd.to_datetime(
                df['Date'].astype(str) + ' ' + df['HH:MM'].astype(str),
                format='%Y/%m/%d %H:%M', errors='raise'
            )
        except Exception as e:
            raise ValueError(f"解析日期时间列时出错: {e}.") from e

        df['month'] = df['datetime'].dt.month
        df['day'] = df['datetime'].dt.day
        df['hour'] = df['datetime'].dt.hour

        for col_name_csv in self.required_columns:
            if col_name_csv not in df.columns:
                raise ValueError(f"CSV文件 '{path}' 缺少必要的气象数据列: '{col_name_csv}'")
            df[col_name_csv] = pd.to_numeric(df[col_name_csv], errors='coerce')
        return df

    def get_hourly_data(self, target_month: int, target_day: int, target_hour: int) -> dict:
        """Return the weather parameters for one month/day/hour.

        If no record matches the exact day, the first record with the same
        month and hour is used instead (a warning is printed). NaN values in
        the chosen record are replaced by the mean of that column over all
        records sharing the month and hour.

        Returns:
            dict with temperature, humidity, wind speed, pressure, global
            horizontal radiation, the raw wind direction, and the direction
            rotated by 180 degrees (in degrees and radians).

        Raises:
            ValueError: if no record exists for the month/hour at all, or a
                NaN value cannot be filled from the month/hour mean.
        """
        # All records for this month and hour: used for the exact match, the
        # fallback record and the NaN fill value.
        month_hour_rows = self.df[
            (self.df['month'] == target_month) & (self.df['hour'] == target_hour)
        ]
        selected_data = month_hour_rows[month_hour_rows['day'] == target_day]

        if selected_data.empty:
            if month_hour_rows.empty:
                raise ValueError(f"未找到气象数据：月份={target_month}, 日期={target_day}, 小时={target_hour}, 且无回退数据。")
            print(f"警告: 未找到 {target_month}-{target_day} {target_hour}:00 的精确气象数据。将使用该月该小时的平均/首条记录。")
            selected_data = month_hour_rows.iloc[[0]]
        elif len(selected_data) > 1:
            print(f"警告: 找到 {len(selected_data)} 条记录：月份={target_month}, 日期={target_day}, 小时={target_hour}. 将使用第一条记录。")

        row = selected_data.iloc[0].copy()
        for col_name_csv in self.required_columns:
            if not pd.isna(row[col_name_csv]):
                continue
            # NOTE(review): this is an arithmetic mean, which is not a circular
            # mean for 'Wind Direction {deg}' — confirm that is acceptable.
            mean_val = month_hour_rows[col_name_csv].mean()
            if pd.isna(mean_val):
                raise ValueError(f"在 {target_month}-{target_day} {target_hour}:00 的数据中，列 '{col_name_csv}' 包含无效值 (NaN) 且无法用月均值填充。")
            print(f"警告: 在 {target_month}-{target_day} {target_hour}:00 的数据中，列 '{col_name_csv}' 无效(NaN)。已用月均值 {mean_val:.2f} 填充。")
            row[col_name_csv] = mean_val

        # The CSV value is stored as the "from" direction; adding 180 degrees
        # gives the direction the wind is blowing towards.
        wind_direction_deg_raw = row['Wind Direction {deg}']
        wind_blowing_to_meteo_deg = (wind_direction_deg_raw + 180) % 360

        return {
            'temperature_c': row['Dry Bulb Temperature {C}'],
            'humidity_percent': row['Relative Humidity {%}'],
            'wind_speed_mps': row['Wind Speed {m/s}'],
            'pressure_pa': row['Atmospheric Pressure {Pa}'],
            'global_horizontal_radiation_whm2': row['Global Horizontal Radiation {Wh/m2}'],
            'wind_direction_from_deg': wind_direction_deg_raw,
            'wind_blowing_to_meteo_deg': wind_blowing_to_meteo_deg,
            'wind_blowing_to_meteo_rad': math.radians(wind_blowing_to_meteo_deg)
        }

# ==================== Data preprocessing module ===================
class ClimateDataPreprocessor:
    """Turn raw input/output rasters into aligned sliding windows.

    The input array is indexed as ``[row, col, channel]`` with channel 0
    used as the tree value, channel 1 as the surface class and channel 2 as
    the building value (cells with channel 2 > 0 form the building mask).
    The output array is indexed as ``[variable, row, col, time]``.
    """

    def __init__(self, input_path, output_path, window_size=50, stride=40,
                 global_tree_max=None, global_building_max=None):
        self.input_data = np.load(input_path).astype(np.float32)
        raw_output = np.load(output_path).astype(np.float32)
        self.window_size = window_size
        self.stride = stride
        self.building_mask_full = (self.input_data[:, :, 2] > 0)

        # Normalisation scales: caller-supplied, else the raster maximum;
        # a zero scale is replaced by 1.0 so later divisions are safe.
        tree_scale = self.input_data[:, :, 0].max() if global_tree_max is None else global_tree_max
        self.global_tree_max = 1.0 if tree_scale == 0 else tree_scale
        building_scale = self.input_data[:, :, 2].max() if global_building_max is None else global_building_max
        self.global_building_max = 1.0 if building_scale == 0 else building_scale

        self.output_data, self.nan_report = self._process_nan(raw_output)
        self._validate_shapes()

    def _validate_shapes(self):
        """Assert the input is 3-D and the output 4-D, both on a 250x250 grid."""
        input_shape = self.input_data.shape
        output_shape = self.output_data.shape
        input_ok = len(input_shape) == 3 and input_shape[:2] == (250, 250)
        assert input_ok, \
            f"输入数据形状应为(250,250,3)，实际得到{input_shape}"
        output_ok = len(output_shape) == 4 and output_shape[1:3] == (250, 250)
        assert output_ok, \
            f"输出数据形状应为(6,250,250,12)，实际得到{output_shape}"

    def _process_nan(self, raw_output):
        """Fill NaNs outside buildings with the per-slice non-building mean.

        For every (variable, time) slice the mean of the non-NaN,
        non-building cells is written into the NaN non-building cells
        (0 if that mean is itself NaN). NaNs inside buildings are kept.

        Returns:
            (cleaned_array, report) where ``report`` holds NaN counts before
            and after cleaning.
        """
        inside = self.building_mask_full
        outside = ~inside
        inside_4d = inside[np.newaxis, :, :, np.newaxis]
        raw_nan = np.isnan(raw_output)
        outside_nan_before = np.sum(raw_nan & ~inside_4d)

        cleaned = raw_output.copy()
        for var_idx, t_idx in np.ndindex(cleaned.shape[0], cleaned.shape[3]):
            # Basic indexing yields a view, so writes land directly in `cleaned`.
            plane = cleaned[var_idx, :, :, t_idx]
            fill_value = np.nanmean(plane[outside])
            if np.isnan(fill_value):
                fill_value = 0
            plane[np.isnan(plane) & outside] = fill_value

        outside_nan_after = np.sum(np.isnan(cleaned) & ~inside_4d)
        report = {
            "total_nan_raw": np.sum(np.isnan(raw_output)),
            "nan_in_buildings_raw": np.sum(raw_nan & inside_4d),
            "nan_outside_buildings_raw": outside_nan_before,
            "total_nan_cleaned": np.sum(np.isnan(cleaned)),
            "nan_in_buildings_cleaned": np.sum(np.isnan(cleaned) & inside_4d),
            "replaced_nan_outside_buildings": outside_nan_before - outside_nan_after
        }
        leftover = report["total_nan_cleaned"] > report["nan_in_buildings_cleaned"]
        if leftover:
            print(f"Warning: NANs still exist outside building areas after cleaning. Count: {report['total_nan_cleaned'] - report['nan_in_buildings_cleaned']}")
        return cleaned, report

    def _generate_windows(self, data, is_output=False):
        """Cut square windows of ``window_size`` every ``stride`` cells.

        For output data the spatial axes are 1 and 2; otherwise 0 and 1.
        Windows are ordered row-major by their top-left corner.
        """
        size, step = self.window_size, self.stride
        if is_output:
            h, w = data.shape[1:3]
        else:
            h, w = data.shape[:2]
        corners = [
            (top, left)
            for top in range(0, h - size + 1, step)
            for left in range(0, w - size + 1, step)
        ]
        if is_output:
            windows = [data[:, top:top + size, left:left + size, :] for top, left in corners]
        else:
            windows = [data[top:top + size, left:left + size] for top, left in corners]
        return np.array(windows)

    def process(self):
        """Return ``(input_windows, output_windows)`` cut at matching positions."""
        features = self._create_input_features()
        return (
            self._generate_windows(features),
            self._generate_windows(self.output_data, is_output=True),
        )

    def _create_input_features(self):
        """Build the 8-channel feature raster.

        Channels: tree value / tree scale, five one-hot surface classes
        (channel 1 clipped to 1..5), building value / building scale, and the
        building mask as 0/1.
        """
        surface_class = np.clip(self.input_data[:, :, 1].astype(int), 1, 5)
        surface_onehot = np.eye(5)[surface_class - 1]
        tree_channel = self.input_data[:, :, 0] / self.global_tree_max
        building_channel = self.input_data[:, :, 2] / self.global_building_max
        mask_channel = self.building_mask_full.astype(float)
        stacked = [
            tree_channel[..., None],
            surface_onehot,
            building_channel[..., None],
            mask_channel[..., None],
        ]
        return np.concatenate(stacked, axis=-1)

# ==================== Graph structure construction module (core modification area) ===================
# Edge-type constants; the value is stored as the last column of each edge's attribute row.
EDGE_TYPE_TREE_ACTIVITY = 0.0
EDGE_TYPE_SHADOW = 1.0
EDGE_TYPE_LOCAL_WIND = 2.0 # local connections for edge (non-internal) nodes and ground nodes
EDGE_TYPE_SIMILARITY = 3.0
EDGE_TYPE_INTERNAL_N8 = 4.0 # added: 8-neighbourhood connections for internal building/tree nodes

class GraphConstructor:
    def __init__(self,
                 global_tree_max_val,
                 global_building_max_val,
                 grid_size=4,
                 k_similarity=8,
                 target_attr_index=5,
                 base_building_shadow_radius_max_grids=15,
                 base_tree_shadow_radius_max_grids=5,
                 base_tree_activity_radius_max_grids=4,
                 wind_effect_on_radius_factor=0.3,
                 max_expected_wind_speed=8.0,
                 shadow_angular_width_deg = 30.0,
                 base_edge_weight=1.0,
                 shadow_influence_weight_factor=1.5, # 在当前简化权重计算中未使用
                 wind_alignment_weight_factor=0.5,   # 在当前简化权重计算中未使用
                 distance_decay_factor_per_grid=0.1, # 用于非相似性边的衰减
                 similarity_dist_decay_factor_per_grid=0.005, # 新增：相似性边的专属衰减率 (例如，常规衰减率的一小部分)
                 actual_shadow_boost_factor=1.05,
                 tree_activity_height_influence_factor=0.2,
                 knn_node_feature_normalization_epsilon=1e-6,
                 internal_n8_weight=0.1 # 在当前简化权重计算中未使用
                ):

        self.grid_size = grid_size
        self.k_similarity = k_similarity
        self.target_attr_index = target_attr_index
        self.global_tree_max = global_tree_max_val
        self.global_building_max = global_building_max_val
        self.base_building_shadow_radius_max_grids = base_building_shadow_radius_max_grids
        self.base_tree_shadow_radius_max_grids = base_tree_shadow_radius_max_grids
        self.base_tree_activity_radius_max_grids = base_tree_activity_radius_max_grids
        self.wind_effect_on_radius_factor = wind_effect_on_radius_factor
        self.max_expected_wind_speed = max_expected_wind_speed if max_expected_wind_speed > 0 else 1.0
        self.shadow_angular_width_rad = math.radians(shadow_angular_width_deg)
        self.base_edge_weight = base_edge_weight
        self.shadow_influence_weight_factor = shadow_influence_weight_factor
        self.wind_alignment_weight_factor = wind_alignment_weight_factor
        self.distance_decay_factor_per_grid = distance_decay_factor_per_grid
        self.similarity_dist_decay_factor_per_grid = similarity_dist_decay_factor_per_grid # 新增属性
        self.actual_shadow_boost_factor = actual_shadow_boost_factor
        self.knn_epsilon = knn_node_feature_normalization_epsilon
        self.internal_n8_weight = internal_n8_weight
        self.tree_activity_height_influence_factor = tree_activity_height_influence_factor # <<<< 保存新参数


    def _is_internal_node(self, r, c, window_h, window_w, hcoords, node_features_local, is_checking_tree: bool):
        neighbor_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        for dr_offset, dc_offset in neighbor_offsets:
            nr, nc = r + dr_offset, c + dc_offset
            if not (0 <= nr < window_h and 0 <= nc < window_w):
                return False
            neighbor_node_idx = hcoords.get((nr, nc))
            if neighbor_node_idx is None:
                 return False
            neighbor_feat = node_features_local[neighbor_node_idx]
            if is_checking_tree:
                if not (neighbor_feat[0] > 0):
                    return False
            else: # is_checking_building
                if not (neighbor_feat[7] > 0.5):
                    return False
        return True

    def build_graph(self, input_window_base_features, output_window_all_hours,
                    hourly_weather_params: dict, solar_params: dict,
                    target_hour_index_in_day: int):
        """Assemble the PyG ``Data`` object for one window and one hour.

        Node features and positions come from the input window; dynamic local
        edges and feature-similarity edges are built separately and merged by
        ``_merge_undirected_edges``. The edge weight is appended as the last
        column of ``edge_attr`` and also stored on its own as ``edge_weight``.

        Raises:
            ValueError: if the merged edge index refers to a node index that
                is not smaller than the number of nodes.
        """
        node_features_local, positions = self._extract_local_node_features(
            input_window_base_features
        )
        num_nodes = len(node_features_local)

        # Graph-level environment vector: five weather values, then sun azimuth and elevation.
        weather_keys = (
            'temperature_c',
            'humidity_percent',
            'wind_speed_mps',
            'pressure_pa',
            'global_horizontal_radiation_whm2',
        )
        env_values = [hourly_weather_params[key] for key in weather_keys]
        env_values.append(solar_params['azimuth_rad'])
        env_values.append(solar_params['elevation_rad'])
        graph_global_env_f = np.array(env_values, dtype=np.float32)

        # Map (row, col) to the row-major node index.
        window_h, window_w = input_window_base_features.shape[:2]
        hcoords = {
            (r_idx, c_idx): r_idx * window_w + c_idx
            for r_idx in range(window_h)
            for c_idx in range(window_w)
        }

        dynamic_edges = self._build_dynamic_local_edges(
            node_features_local, positions, hcoords, window_h, window_w,
            hourly_weather_params, solar_params
        )
        edge_index_dynamic, edge_attr_dynamic, edge_weights_dynamic = dynamic_edges

        similarity_edges = self._build_edges_topk_feature_similarity(
            node_features_local, positions, hourly_weather_params, solar_params
        )
        edge_index_sim, edge_attr_sim, edge_weights_sim = similarity_edges

        edge_index_final, edge_attr_final_no_weight, edge_weights_final = self._merge_undirected_edges(
            edge_index_dynamic, edge_attr_dynamic, edge_weights_dynamic,
            edge_index_sim, edge_attr_sim, edge_weights_sim
        )

        node_targets = self._process_targets(output_window_all_hours, target_hour_index_in_day)

        has_edges = edge_index_final.size > 0
        if has_edges and edge_index_final.max() >= num_nodes:
            raise ValueError(
                f"检测到越界节点索引：edge_index 存在 {edge_index_final.max()}，节点总数仅 {num_nodes}！"
            )

        building_mask_for_window_flat = (input_window_base_features[:, :, 7].flatten() > 0.5)

        has_weights = edge_weights_final.size > 0
        if has_weights:
            weight_column = edge_weights_final.reshape(-1, 1)
            edge_attr_final_with_weight = np.concatenate([edge_attr_final_no_weight, weight_column], axis=-1)
        else:
            # Width of an attribute row: 4 geometric/wind values + edge type + weight.
            expected_cols = 5 + 1
            edge_attr_final_with_weight = np.empty((0, expected_cols) , dtype=np.float32)

        x_tensor = torch.FloatTensor(node_features_local)
        if has_edges:
            edge_index_tensor = torch.LongTensor(edge_index_final).t().contiguous()
        else:
            edge_index_tensor = torch.empty((2,0), dtype=torch.long)
        edge_attr_tensor = torch.FloatTensor(edge_attr_final_with_weight)
        if has_weights:
            edge_weight_tensor = torch.FloatTensor(edge_weights_final)
        else:
            edge_weight_tensor = torch.empty(0, dtype=torch.float)

        # The weather/solar dicts are flattened in key-sorted order.
        weather_sorted = [v for k,v in sorted(hourly_weather_params.items())]
        solar_sorted = [v for k,v in sorted(solar_params.items())]

        return Data(
            x=x_tensor,
            edge_index=edge_index_tensor,
            edge_attr=edge_attr_tensor,
            edge_weight=edge_weight_tensor,
            y=torch.FloatTensor(node_targets),
            pos=torch.FloatTensor(positions),
            building_mask=torch.BoolTensor(building_mask_for_window_flat),
            hourly_weather=torch.FloatTensor(weather_sorted),
            solar_position=torch.FloatTensor(solar_sorted),
            graph_global_env_features=torch.FloatTensor(graph_global_env_f)
        )

    def _extract_local_node_features(self, window_base_features):
        h, w, c_base = window_base_features.shape
        features_list = window_base_features.reshape(-1, c_base)
        positions_list = []
        for r_idx in range(h):
            for c_idx in range(w):
                positions_list.append([c_idx * self.grid_size, r_idx * self.grid_size])
        return np.array(features_list), np.array(positions_list)

    def _process_targets(self, output_window_all_hours, target_hour_index_in_day):
        target_slice_one_hour = output_window_all_hours[
            self.target_attr_index, :, :, target_hour_index_in_day
        ]
        return target_slice_one_hour.reshape(-1, 1)

    # Method of the GraphConstructor class:
    def _calculate_edge_weight(self, dist_m, wind_align_cos, wind_speed, is_shadow_interaction, edge_type, tree_height_norm_factor=None):
        """Compute the scalar weight of one edge.

        Starts from ``self.base_edge_weight`` and applies, in order:
        a distance decay ``1 / (1 + decay * dist_in_grids)`` (similarity edges
        use their own decay factor), a height-based boost for tree-activity
        edges, and a fixed boost for shadow edges flagged as shadow
        interactions. The result is never below 0.01.

        ``wind_align_cos`` and ``wind_speed`` are accepted but not used by
        this version of the calculation.
        """
        weight = self.base_edge_weight

        # Similarity edges decay with their own (gentler) factor.
        if edge_type == EDGE_TYPE_SIMILARITY:
            decay_per_grid = self.similarity_dist_decay_factor_per_grid
        else:
            decay_per_grid = self.distance_decay_factor_per_grid

        # No decay when the factor is non-positive or the distance is zero.
        if decay_per_grid > 0 and dist_m > 0:
            dist_grids = dist_m / self.grid_size
            weight /= (1.0 + decay_per_grid * dist_grids)

        # Tree-activity edges: scale by (1 + factor * tree height factor).
        height_boost_applies = (
            edge_type == EDGE_TYPE_TREE_ACTIVITY
            and tree_height_norm_factor is not None
            and self.tree_activity_height_influence_factor > 0
        )
        if height_boost_applies:
            weight *= 1.0 + self.tree_activity_height_influence_factor * tree_height_norm_factor

        # Shadow edges flagged as actual shadow interactions get a small boost.
        if edge_type == EDGE_TYPE_SHADOW and is_shadow_interaction:
            weight *= self.actual_shadow_boost_factor

        # Keep the weight strictly positive.
        return max(weight, 0.01)


    def _build_dynamic_local_edges(self, node_features_local, positions, hcoords,
                                   window_h, window_w,
                                   hourly_weather_params: dict, solar_params: dict):
        """Build the directed, weather-dependent local edges of one window.

        For every node, in node order:

        * internal tree/building nodes (see ``_is_internal_node``) only get
          8-neighbourhood edges of type ``EDGE_TYPE_INTERNAL_N8``;
        * other tree nodes get ``EDGE_TYPE_TREE_ACTIVITY`` edges inside a
          circular radius that scales with the radiation value;
        * other tree/building nodes get ``EDGE_TYPE_SHADOW`` edges towards
          cells inside the shadow cone, when the sun elevation exceeds 1 degree;
        * every non-internal node gets ``EDGE_TYPE_LOCAL_WIND`` edges whose
          reach is stretched downwind and shrunk upwind.

        Args:
            node_features_local: per-node feature rows (index 0 tree value,
                6 building value, 7 building mask are read here).
            positions: per-node ``[x, y]`` positions.
            hcoords: mapping ``(row, col) -> node index``.
            window_h, window_w: window size in cells.
            hourly_weather_params: needs ``wind_speed_mps``,
                ``wind_blowing_to_meteo_rad`` and
                ``global_horizontal_radiation_whm2``.
            solar_params: needs ``azimuth_rad`` and ``elevation_rad``.

        Returns:
            ``(edge_index, edge_attr, edge_weight)`` as an ``(E, 2)`` int64
            array, an ``(E, 5)`` float32 array with rows
            ``[dist, dc, dr, wind_align_cos, edge_type]`` and an ``(E,)``
            float32 array.
        """
        edge_index_list = []
        edge_attr_list = []
        edge_weights_list = []
        num_nodes = len(node_features_local)
        two_pi = 2 * math.pi

        sol_elev_rad = solar_params['elevation_rad']
        wind_speed = hourly_weather_params['wind_speed_mps']
        wind_blowing_to_meteo_rad = hourly_weather_params['wind_blowing_to_meteo_rad']
        radiation = hourly_weather_params['global_horizontal_radiation_whm2']
        # Shadows fall opposite to the sun's azimuth; the result lies in [0, 2*pi).
        shadow_main_direction_meteo_rad = (solar_params['azimuth_rad'] + math.pi) % two_pi

        # Loop-invariant wind terms, computed once instead of once per candidate edge.
        wind_blowing_to_cartesian_rad = (math.pi / 2 - wind_blowing_to_meteo_rad) % two_pi
        wind_strength_factor = np.clip(wind_speed / self.max_expected_wind_speed, 0, 1)

        base_edge_attr_dim_no_globals = 4

        def wind_alignment(vec_ij):
            # Cosine of the angle between the edge direction and the wind direction.
            # NOTE(review): this uses atan2(dy, dx) on positions whose y grows
            # with the row index, while the shadow test below treats a
            # decreasing row index as north (atan2(dc, -dr)). The two
            # conventions disagree on the north/south orientation — confirm
            # which one matches the rasters.
            return math.cos(math.atan2(vec_ij[1], vec_ij[0]) - wind_blowing_to_cartesian_rad)

        def add_edge(src, dst, dist, dc, dr, wind_align_cos, edge_type,
                     is_shadow=False, tree_height=None):
            # Record one directed edge together with its attributes and weight.
            edge_index_list.append([src, dst])
            edge_attr_list.append([dist, float(dc), float(dr), wind_align_cos, edge_type])
            edge_weights_list.append(self._calculate_edge_weight(
                dist, wind_align_cos, wind_speed, is_shadow, edge_type,
                tree_height_norm_factor=tree_height
            ))

        for i in range(num_nodes):
            src_feat_local = node_features_local[i]
            src_pos = positions[i]
            src_grid_r = int(round(src_pos[1] / self.grid_size))
            src_grid_c = int(round(src_pos[0] / self.grid_size))

            is_tree_node = src_feat_local[0] > 0
            is_building_node = src_feat_local[7] > 0.5
            is_object_node = is_tree_node or is_building_node
            is_internal = False
            if is_object_node:
                is_internal = self._is_internal_node(
                    src_grid_r, src_grid_c, window_h, window_w, hcoords,
                    node_features_local, is_checking_tree=is_tree_node
                )

            if is_internal:
                # Internal object cells: plain 8-neighbourhood, no wind term.
                for dr_n8 in (-1, 0, 1):
                    for dc_n8 in (-1, 0, 1):
                        if dr_n8 == 0 and dc_n8 == 0:
                            continue
                        j = hcoords.get((src_grid_r + dr_n8, src_grid_c + dc_n8))
                        if j is None:
                            continue
                        dist = np.linalg.norm(src_pos - positions[j])
                        add_edge(i, j, dist, dc_n8, dr_n8, 0.0, EDGE_TYPE_INTERNAL_N8)
                continue

            if is_tree_node:
                activity_rad_factor = np.clip(radiation / 1000, 0.5, 1.2)
                tree_activity_radius_grids = max(
                    1, int(round(self.base_tree_activity_radius_max_grids * activity_rad_factor))
                )
                radius_sq = tree_activity_radius_grids * tree_activity_radius_grids
                # Feature 0 is passed on as the tree-height factor for the weight.
                normalized_tree_height = src_feat_local[0]
                for dr in range(-tree_activity_radius_grids, tree_activity_radius_grids + 1):
                    for dc in range(-tree_activity_radius_grids, tree_activity_radius_grids + 1):
                        if dr == 0 and dc == 0:
                            continue
                        if dr * dr + dc * dc > radius_sq:
                            continue
                        j = hcoords.get((src_grid_r + dr, src_grid_c + dc))
                        if j is None:
                            continue
                        vec_ij = positions[j] - src_pos
                        dist = np.linalg.norm(vec_ij)
                        if dist < 1e-6:
                            continue
                        add_edge(i, j, dist, dc, dr, wind_alignment(vec_ij),
                                 EDGE_TYPE_TREE_ACTIVITY,
                                 tree_height=normalized_tree_height)

            if is_object_node and sol_elev_rad > math.radians(1.0):
                obj_height_norm = src_feat_local[0] if is_tree_node else src_feat_local[6]
                obj_height_abs = obj_height_norm * (self.global_tree_max if is_tree_node else self.global_building_max)
                tan_elev = math.tan(sol_elev_rad)
                shadow_length_m = obj_height_abs / tan_elev if tan_elev > 1e-6 else obj_height_abs * 1000
                base_max_grids_for_obj = self.base_tree_shadow_radius_max_grids if is_tree_node else self.base_building_shadow_radius_max_grids
                shadow_length_grids = min(int(round(shadow_length_m / self.grid_size)), base_max_grids_for_obj)
                shadow_length_grids = max(1, shadow_length_grids)
                shadow_radius_sq = shadow_length_grids * shadow_length_grids
                half_cone = self.shadow_angular_width_rad / 2.0
                for dr_s in range(-shadow_length_grids, shadow_length_grids + 1):
                    for dc_s in range(-shadow_length_grids, shadow_length_grids + 1):
                        if dr_s == 0 and dc_s == 0:
                            continue
                        if dr_s * dr_s + dc_s * dc_s > shadow_radius_sq:
                            continue
                        j = hcoords.get((src_grid_r + dr_s, src_grid_c + dc_s))
                        if j is None:
                            continue
                        # atan2 returns values in (-pi, pi] while the shadow
                        # direction lies in [0, 2*pi). Bring the bearing into
                        # [0, 2*pi) first; otherwise the raw difference can
                        # exceed 2*pi, making `2*pi - diff` negative and
                        # letting cells far outside the cone pass the test.
                        angle_to_target_meteo_rad = math.atan2(dc_s, -dr_s) % two_pi
                        angle_diff = abs(angle_to_target_meteo_rad - shadow_main_direction_meteo_rad)
                        angle_diff = min(angle_diff, two_pi - angle_diff)
                        if angle_diff > half_cone:
                            continue
                        vec_ij = positions[j] - src_pos
                        dist = np.linalg.norm(vec_ij)
                        if dist < 1e-6:
                            continue
                        add_edge(i, j, dist, dc_s, dr_s, wind_alignment(vec_ij),
                                 EDGE_TYPE_SHADOW, is_shadow=True)

            # Local wind edges: object nodes look two cells out, ground nodes one.
            base_local_radius_grids = 1 if not is_object_node else 2
            max_local_reach = base_local_radius_grids * self.grid_size
            for dr_w in range(-base_local_radius_grids, base_local_radius_grids + 1):
                for dc_w in range(-base_local_radius_grids, base_local_radius_grids + 1):
                    if dr_w == 0 and dc_w == 0:
                        continue
                    j = hcoords.get((src_grid_r + dr_w, src_grid_c + dc_w))
                    if j is None:
                        continue
                    vec_ij = positions[j] - src_pos
                    dist = np.linalg.norm(vec_ij)
                    if dist < 1e-6:
                        continue
                    wind_align_cos = wind_alignment(vec_ij)
                    # Downwind targets appear closer, upwind targets farther.
                    radius_mod_factor = 1.0 + self.wind_effect_on_radius_factor * wind_align_cos * wind_strength_factor
                    effective_dist_for_connection = dist / radius_mod_factor if radius_mod_factor > 1e-6 else dist * 1e6
                    if effective_dist_for_connection <= max_local_reach:
                        add_edge(i, j, dist, dc_w, dr_w, wind_align_cos, EDGE_TYPE_LOCAL_WIND)

        if not edge_index_list:
            return np.empty((0, 2), dtype=np.int64), \
                   np.empty((0, base_edge_attr_dim_no_globals + 1), dtype=np.float32), \
                   np.empty(0, dtype=np.float32)

        return np.array(edge_index_list, dtype=np.int64), \
               np.array(edge_attr_list, dtype=np.float32), \
               np.array(edge_weights_list, dtype=np.float32)

    def _build_edges_topk_feature_similarity(self, node_features_local, positions,
                                             hourly_weather_params, solar_params):
        num_nodes = len(node_features_local)
        base_edge_attr_dim_no_globals = 4
        edge_type = EDGE_TYPE_SIMILARITY

        if num_nodes == 0 or self.k_similarity <= 0:
            return np.empty((0, 2), dtype=np.int64), \
                   np.empty((0, base_edge_attr_dim_no_globals + 1), dtype=np.float32), \
                   np.empty(0, dtype=np.float32)

        if num_nodes > 1 :
            mean_nf = np.mean(node_features_local, axis=0, keepdims=True)
            std_nf = np.std(node_features_local, axis=0, keepdims=True)
            node_features_normalized_for_knn = (node_features_local - mean_nf) / (std_nf + self.knn_epsilon)
        else:
            node_features_normalized_for_knn = node_features_local

        dist_mat_feat = pairwise_distances(node_features_normalized_for_knn, node_features_normalized_for_knn, metric='euclidean')

        edge_index_list, edge_attr_list, edge_weights_list = [], [], []
        wind_speed = hourly_weather_params['wind_speed_mps'] # Still needed for _calculate_edge_weight signature
        wind_blowing_to_meteo_rad = hourly_weather_params['wind_blowing_to_meteo_rad']

        for i in range(num_nodes):
            dist_mat_feat[i, i] = 1e9
            actual_k = min(self.k_similarity, num_nodes - 1 if num_nodes > 1 else 0)
            if actual_k == 0 and num_nodes > 1 and self.k_similarity > 0: actual_k = 1

            if actual_k > 0:
                topk_idx = np.argsort(dist_mat_feat[i])[:actual_k]
                src_pos = positions[i]
                for nbr_idx in topk_idx:
                    dst_pos = positions[nbr_idx]
                    dist = np.linalg.norm(src_pos - dst_pos)
                    vec_ij = dst_pos - src_pos
                    if np.linalg.norm(vec_ij) < 1e-6: continue
                    dx_grid = int(round(vec_ij[0] / self.grid_size))
                    dy_grid = int(round(vec_ij[1] / self.grid_size))
                    angle_ij_cartesian_rad = math.atan2(vec_ij[1], vec_ij[0])
                    wind_blowing_to_cartesian_rad = (math.pi/2 - wind_blowing_to_meteo_rad) % (2*math.pi)
                    wind_align_cos = math.cos(angle_ij_cartesian_rad - wind_blowing_to_cartesian_rad)
                    edge_index_list.append([i, nbr_idx])
                    current_edge_attr = [dist, float(dx_grid), float(dy_grid), wind_align_cos, edge_type]
                    edge_attr_list.append(current_edge_attr)
                    # Pass all params, but only dist_m and edge_type will be used by simplified weight calc
                    weight = self._calculate_edge_weight(dist, wind_align_cos, wind_speed, False, edge_type)
                    edge_weights_list.append(weight)

        if not edge_index_list:
            return np.empty((0, 2), dtype=np.int64), \
                   np.empty((0, base_edge_attr_dim_no_globals + 1), dtype=np.float32), \
                   np.empty(0, dtype=np.float32)
        return np.array(edge_index_list, dtype=np.int64), \
               np.array(edge_attr_list, dtype=np.float32), \
               np.array(edge_weights_list, dtype=np.float32)

    def _merge_undirected_edges(self, ei1, ea1, ew1, ei2, ea2, ew2):
        edge_attr_dim_with_type = 4 + 1

        if ei1.size == 0:
            edge_index_comb, edge_attr_comb, edge_weights_comb = ei2, ea2, ew2
        elif ei2.size == 0:
            edge_index_comb, edge_attr_comb, edge_weights_comb = ei1, ea1, ew1
        else:
            edge_index_comb = np.concatenate([ei1, ei2], axis=0)
            edge_attr_comb = np.concatenate([ea1, ea2], axis=0)
            edge_weights_comb = np.concatenate([ew1, ew2], axis=0)

        if edge_index_comb.shape[0] == 0:
            return np.empty((0,2), dtype=np.int64), \
                   np.empty((0, edge_attr_dim_with_type), dtype=np.float32), \
                   np.empty(0, dtype=np.float32)

        unique_directed_edges = {}
        for idx in range(edge_index_comb.shape[0]):
            src, dst = edge_index_comb[idx, 0], edge_index_comb[idx, 1]
            attr = edge_attr_comb[idx]
            weight = edge_weights_comb[idx]
            if src == dst: continue
            if (src,dst) not in unique_directed_edges:
                unique_directed_edges[(src,dst)] = (attr, weight)

        final_ei_list, final_ea_list, final_ew_list = [], [], []
        processed_undirected_pairs = set()

        for (s, d), (attr_s_d, weight_s_d) in unique_directed_edges.items():
            u, v = min(s,d), max(s,d)
            if (u,v) in processed_undirected_pairs:
                continue

            final_ei_list.append([s,d])
            final_ea_list.append(attr_s_d)
            final_ew_list.append(weight_s_d)

            if (d,s) in unique_directed_edges:
                attr_d_s, weight_d_s = unique_directed_edges[(d,s)]
                final_ei_list.append([d,s])
                final_ea_list.append(attr_d_s)
                final_ew_list.append(weight_d_s)
            else:
                attr_d_s_created = attr_s_d.copy()
                attr_d_s_created[1] *= -1 # dx (index 1)
                attr_d_s_created[2] *= -1 # dy (index 2)
                attr_d_s_created[3] *= -1 # wind_align_cos (index 3)
                # Edge type (index 4) remains the same
                final_ei_list.append([d,s])
                final_ea_list.append(attr_d_s_created)
                final_ew_list.append(weight_s_d)
            processed_undirected_pairs.add((u,v))

        if not final_ei_list:
            return np.empty((0,2), dtype=np.int64), \
                   np.empty((0, edge_attr_dim_with_type), dtype=np.float32), \
                   np.empty(0, dtype=np.float32)

        return np.array(final_ei_list), np.array(final_ea_list), np.array(final_ew_list)


# ==================== 图增广模块 ===================
class GraphAugmentor:
    """Optional, in-place feature augmentation for a single graph.

    Each augmentation is switched on by a constructor flag and appends columns
    to ``data.x`` (structural features, neighbour means, Laplacian positional
    encoding) or to ``data.edge_attr`` (endpoint feature differences).

    NOTE(review): the node-feature augmentations are best-effort. On a graph
    without nodes/edges, or when the underlying computation raises, they add
    no columns, so the width of ``data.x`` can differ between graphs.
    """

    def __init__(
        self,
        add_struct_features=False,
        add_neighbor_agg=False,
        add_edge_diff=False,
        use_laplacian_pe=False,
        lap_pe_dim=4,
        normalization='sym',
        max_iter=2000
    ):
        """Store the augmentation switches.

        Args:
            add_struct_features: Append degree, clustering coefficient and PageRank.
            add_neighbor_agg: Append the mean feature vector of each node's neighbours.
            add_edge_diff: Append ``x[src] - x[dst]`` to every edge's attributes.
            use_laplacian_pe: Append Laplacian eigenvector positional encodings.
            lap_pe_dim: Number of Laplacian eigenvectors requested.
            normalization: Forwarded to the Laplacian encoding step, which
                currently does not use it (the combinatorial Laplacian is built).
            max_iter: Iteration cap forwarded to the sparse eigensolver.

        Raises:
            ImportError: If ``add_neighbor_agg`` is set but neither
                ``scatter_mean`` nor ``torch_scatter`` is present in the module globals.
        """
        self.add_struct_features = add_struct_features
        self.add_neighbor_agg = add_neighbor_agg
        self.add_edge_diff = add_edge_diff
        self.use_laplacian_pe = use_laplacian_pe
        self.lap_pe_dim = lap_pe_dim
        self.normalization = normalization
        self.max_iter = max_iter
        if add_neighbor_agg and not ('scatter_mean' in globals() or 'torch_scatter' in globals()):
            raise ImportError("torch_scatter.scatter_mean is required for add_neighbor_agg=True. Please install torch-scatter.")


    def augment_static(self, data):
        """Apply every enabled augmentation to ``data`` in place and return it.

        Order matters: each step sees the columns appended by the previous
        ones (e.g. neighbour means are computed over the structural features).
        """
        if self.add_struct_features:
            self._add_structural_features(data)
        if self.add_neighbor_agg:
            self._add_neighbor_mean_features(data)
        if self.add_edge_diff:
            self._add_edge_diff(data)
        if self.use_laplacian_pe:
            # Forward the constructor configuration; previously the method's own
            # defaults were always used, so a custom max_iter had no effect.
            self._add_laplacian_positional_encoding(
                data, normalization=self.normalization, max_iter=self.max_iter
            )
        return data

    def _add_structural_features(self, data):
        """Append degree, clustering coefficient and PageRank columns to ``data.x``.

        Computed on the undirected networkx view of the graph. Does nothing for
        a graph without nodes or edges; any exception is printed and swallowed.
        """
        if data.num_nodes == 0 or data.num_edges == 0 : return
        try:
            G = to_networkx(data, to_undirected=True)
            deg_dict = dict(G.degree())
            clust_dict = nx.clustering(G)
            pr_dict = nx.pagerank(G, alpha=0.85, max_iter=100, tol=1.e-04)

            # Re-index by node id so the rows line up with data.x.
            deg_list = [deg_dict.get(i, 0) for i in range(data.num_nodes)]
            clust_list = [clust_dict.get(i, 0.0) for i in range(data.num_nodes)]
            pr_list = [pr_dict.get(i, 0.0) for i in range(data.num_nodes)]

            deg_tensor = torch.tensor(deg_list, dtype=torch.float, device=data.x.device).unsqueeze(1)
            clust_tensor = torch.tensor(clust_list, dtype=torch.float, device=data.x.device).unsqueeze(1)
            pr_tensor = torch.tensor(pr_list, dtype=torch.float, device=data.x.device).unsqueeze(1)
            data.x = torch.cat([data.x, deg_tensor, clust_tensor, pr_tensor], dim=1)
        except Exception as e:
            print(f"Error in _add_structural_features: {e}")


    def _add_neighbor_mean_features(self, data):
        """Append, for every node, the mean of the features at the far end of its edges.

        For each edge ``(row, col)`` the features of ``col`` are averaged into
        ``row``. Does nothing for a graph without nodes or edges; any exception
        is printed and swallowed.
        """
        if data.num_nodes == 0 or data.num_edges == 0 : return
        try:
            row, col = data.edge_index
            x_mean = scatter_mean(data.x[col].float(), row, dim=0, dim_size=data.num_nodes)
            data.x = torch.cat([data.x, x_mean], dim=1)
        except Exception as e:
            print(f"Error in _add_neighbor_mean_features: {e}")


    def _add_edge_diff(self, data):
        """Append ``x[src] - x[dst]`` to ``data.edge_attr`` for every edge.

        Skipped (with a warning) when the graph has no ``edge_attr``; a graph
        without nodes or edges is left untouched.
        """
        if data.num_nodes == 0 or data.num_edges == 0 : return
        if not hasattr(data, 'edge_attr') or data.edge_attr is None:
            print("Warning: add_edge_diff called but data.edge_attr is None. Skipping.")
            return

        src_nodes_x = data.x[data.edge_index[0]].float()
        dst_nodes_x = data.x[data.edge_index[1]].float()
        feat_diffs = src_nodes_x - dst_nodes_x

        data.edge_attr = data.edge_attr.float()
        data.edge_attr = torch.cat([data.edge_attr, feat_diffs], dim=1)


    def _add_laplacian_positional_encoding(self, data, normalization='sym', edge_weight=None, max_iter=2000):
        """Append Laplacian eigenvectors to ``data.x`` as positional encodings.

        Builds the Laplacian of the undirected, self-loop-free networkx view
        and appends the eigenvectors of its ``k`` smallest-magnitude
        eigenvalues, ordered by eigenvalue, where
        ``k = min(self.lap_pe_dim, num_nodes - 2)``.

        Args:
            data: Graph to augment in place.
            normalization: Currently unused; the Laplacian is always the one
                returned by ``nx.laplacian_matrix``.
            edge_weight: Currently unused.
            max_iter: Iteration cap passed to ``scipy.sparse.linalg.eigsh``.

        NOTE(review): fewer than ``lap_pe_dim`` columns are appended for very
        small graphs, and none at all if the eigensolver raises (the exception
        is printed and swallowed).
        """
        if data.num_nodes == 0 or data.num_edges == 0 or self.lap_pe_dim <=0 : return
        try:
            G_nx = to_networkx(data, to_undirected=True, remove_self_loops=True)
            L_sp = nx.laplacian_matrix(G_nx).astype(float)
            N = data.num_nodes
            k = min(self.lap_pe_dim, N - 2 if N > 1 else 0)
            if k <= 0: return

            eigen_vals, eigen_vecs = spla.eigsh(
                L_sp, k=k, which='SM', tol=1e-3,
                maxiter=max_iter, return_eigenvectors=True
            )
            # Order the eigenvectors by ascending eigenvalue.
            idx = np.argsort(eigen_vals)
            eigen_vecs = eigen_vecs[:, idx]
            lap_pe_tensor = torch.from_numpy(eigen_vecs).float().to(data.x.device)
            data.x = torch.cat([data.x, lap_pe_tensor], dim=1)
        except Exception as e:
            print(f"Error in _add_laplacian_positional_encoding: {e}")

# ==================== 辅助函数 ===================
def _finite_channel_max(data, channel_idx):
    """Return the largest finite value of ``data[:, :, channel_idx]``, or None if there is none."""
    if data.ndim != 3 or data.shape[2] <= channel_idx:
        return None
    channel = data[:, :, channel_idx].astype(np.float32)
    finite_values = channel[np.isfinite(channel)]
    if finite_values.size == 0:
        return None
    return finite_values.max()


def compute_global_maxes(input_dir):
    """Scan every ``Input_*.npy`` in ``input_dir`` for the global channel maxima.

    Channel 0 is reported as the tree maximum and channel 2 as the building
    maximum. NaN/inf cells are ignored cell-by-cell, so one bad cell no longer
    removes a whole file from the result. Arrays that are not 3-D, or that
    lack the channel, contribute nothing to that maximum.

    Args:
        input_dir: Directory (str or Path) containing the ``Input_*.npy`` files.

    Returns:
        ``(tree_max, building_max)``. Each value falls back to ``1.0`` when no
        file yields a finite, strictly positive maximum, so the result is
        always a safe divisor.
    """
    input_files = sorted(Path(input_dir).glob("Input_*.npy"))
    if not input_files:
        return 1.0, 1.0

    global_tree_max = None
    global_building_max = None
    for file_path in input_files:
        data = np.load(file_path)

        current_tree_max = _finite_channel_max(data, 0)
        if current_tree_max is not None:
            global_tree_max = current_tree_max if global_tree_max is None else max(global_tree_max, current_tree_max)

        current_bldg_max = _finite_channel_max(data, 2)
        if current_bldg_max is not None:
            global_building_max = current_bldg_max if global_building_max is None else max(global_building_max, current_bldg_max)

    final_tree_max = global_tree_max if global_tree_max is not None and global_tree_max > 0 else 1.0
    final_bldg_max = global_building_max if global_building_max is not None and global_building_max > 0 else 1.0
    return final_tree_max, final_bldg_max


def _report_tensor_to_numpy(tensor):
    """Convert a tensor to numpy for printing, regardless of device or autograd state."""
    return tensor.detach().cpu().numpy()


def _append_edge_type_weight_stats(report, graph, edge_type_value, label):
    """Append count/min/max/mean/std of ``edge_weight`` for edges whose type equals ``edge_type_value``.

    The edge type is read from column index 4 of ``graph.edge_attr``.
    Failures are written into the report instead of being raised.
    """
    try:
        edge_types = graph.edge_attr[:, 4]
        type_mask = (edge_types == edge_type_value)

        num_matching = type_mask.sum().item()
        report.append(f"  数量: {num_matching}")

        if num_matching > 0:
            weights = graph.edge_weight[type_mask]
            report.append(f"  Min: {weights.min().item():.4f}")
            report.append(f"  Max: {weights.max().item():.4f}")
            report.append(f"  Mean: {weights.mean().item():.4f}")
            report.append(f"  Std: {weights.std().item():.4f}")
        else:
            report.append(f"  未找到{label}。")
    except IndexError:
        report.append("  错误：无法从edge_attr中提取边类型（维度不足）。")
    except Exception as e:
        report.append(f"  提取{label}权重时出错: {e}")


def generate_data_report(graph, save_report=False):
    """Print a structural summary of a single (one-hour) graph.

    The report covers node/edge counts, tensor shapes, the first five rows of
    the node features, edge attributes, edge weights and targets, weight
    statistics for shadow edges and tree-activity edges, and any graph-level
    environment / weather / solar tensors attached to the graph.

    Args:
        graph: Graph object exposing ``num_nodes``, ``num_edges`` and ``x``;
            ``y``, ``edge_attr``, ``edge_weight`` and the metadata attributes
            are optional.
        save_report: Currently unused; the report is only printed.

    Returns:
        None.
    """
    edge_attr = getattr(graph, 'edge_attr', None)
    edge_weight = getattr(graph, 'edge_weight', None)
    y = getattr(graph, 'y', None)

    report = ["="*40 + "\n图数据结构分析报告 (单小时)\n" + "="*40]
    report.append(f"\n[文件ID (如有)]: {getattr(graph, 'file_id', 'N/A')}")
    report.append(f"[小时 (如有)]: {getattr(graph, 'hour_of_day', 'N/A')}")
    report.append(f"\n[维度信息]")
    report.append(f"节点数量: {graph.num_nodes}")
    report.append(f"边数量: {graph.num_edges}")
    report.append(f"节点特征维度: {graph.x.shape}")
    edge_attr_shape_str = 'N/A'
    if edge_attr is not None:
        edge_attr_shape_str = str(edge_attr.shape)
        if edge_attr.shape[1] == 6:  # 4 local physical + 1 type + 1 weight
            report.append(f"  (边特征包含: 4局部物理 + 1类型 + 1权重)")
        elif edge_attr.shape[1] == 5:
            report.append(f"  (边特征包含: 4局部物理 + 1类型)")

    report.append(f"边特征维度 (edge_attr): {edge_attr_shape_str}")
    report.append(f"边权重维度 (edge_weight): {edge_weight.shape if edge_weight is not None else 'N/A'}")
    report.append(f"目标值维度: {y.shape if y is not None else 'N/A'}")

    report.append("\n[节点特征X (前5行样本)]:\n" + str(_report_tensor_to_numpy(graph.x[:5])))
    if edge_attr is not None and graph.num_edges > 0:
        report.append("\n[边特征EA (前5行样本)]:\n" + str(_report_tensor_to_numpy(edge_attr[:5])))
        if edge_attr.shape[1] >= 5:
            edge_types_sample = _report_tensor_to_numpy(edge_attr[:min(5, graph.num_edges), 4])
            report.append(f"  边类型 (前{len(edge_types_sample)}条边, 第5列): {edge_types_sample}")
        if edge_attr.shape[1] >= 6:
            edge_weights_in_attr_sample = _report_tensor_to_numpy(edge_attr[:min(5, graph.num_edges), 5])
            report.append(f"  边权重 (来自边特征第6列, 前{len(edge_weights_in_attr_sample)}条): {edge_weights_in_attr_sample}")

    if edge_weight is not None and graph.num_edges > 0:
        report.append("\n[独立边权重EW (前5行样本)]:\n" + str(_report_tensor_to_numpy(edge_weight[:5])))
        report.append(f"  独立边权重统计: min={edge_weight.min().item():.3f}, max={edge_weight.max().item():.3f}, mean={edge_weight.mean().item():.3f}")

    # Per-type weight statistics need the type column (index 4) plus standalone weights.
    can_split_by_type = (
        edge_attr is not None and edge_weight is not None
        and graph.num_edges > 0 and edge_attr.shape[1] >= 5
    )

    # --- Shadow-edge weight statistics ---
    report.append("\n[阴影边 (EDGE_TYPE_SHADOW) 权重统计]")
    if can_split_by_type:
        _append_edge_type_weight_stats(report, graph, EDGE_TYPE_SHADOW, "阴影边")
    else:
        report.append("  无边或无edge_attr/edge_weight用于统计阴影边。")

    # --- Tree-activity-edge weight statistics ---
    report.append("\n[树木活动边 (EDGE_TYPE_TREE_ACTIVITY=0.0) 权重统计]")
    if can_split_by_type:
        _append_edge_type_weight_stats(report, graph, EDGE_TYPE_TREE_ACTIVITY, "树木活动边")
    else:
        report.append("  无边或无edge_attr/edge_weight用于统计树木活动边。")

    y_sample_str = str(_report_tensor_to_numpy(y[:5])) if y is not None else 'N/A'
    report.append("\n[目标值Y (前5行样本)]:\n" + y_sample_str)

    if getattr(graph, 'graph_global_env_features', None) is not None:
        report.append("\n[图级别全局环境特征 (7维)]:\n" + str(_report_tensor_to_numpy(graph.graph_global_env_features)))

    if getattr(graph, 'hourly_weather', None) is not None:
        report.append("\n[原始小时气象参数]:\n" + str(_report_tensor_to_numpy(graph.hourly_weather)))
    if getattr(graph, 'solar_position', None) is not None:
        report.append("\n[原始太阳位置参数]:\n" + str(_report_tensor_to_numpy(graph.solar_position)))
    final_report = "\n".join(report)
    print(final_report)

def verify_edge_structure(graph, grid_size=4, sample_size=10):
    """Spot-check stored edge geometry against the node positions.

    Prints min/max/mean of the stored distances (column 0 of ``edge_attr``),
    then randomly samples up to ``sample_size`` edges and, for each one,
    prints the stored ``(dist, dx, dy)`` next to the values recomputed from
    ``graph.pos``, flagging the edge when they disagree by more than 0.1 in
    distance or 0.6 grid cells in dx/dy.

    Args:
        graph: Graph with ``edge_attr`` (>= 3 columns), ``edge_index`` indexed
            as ``[0, e]`` / ``[1, e]``, ``pos`` and ``num_edges``.
        grid_size: Divisor converting position offsets into grid cells.
        sample_size: Maximum number of edges to sample.
    """
    attrs = getattr(graph, 'edge_attr', None)
    if attrs is None or attrs.size(0) == 0 or attrs.size(1) < 3:
        print("[verify_edge_structure] 警告：edge_attr 列数不足 (<3) 或无边，无法验证 (dist, dx, dy)")
        return

    attr_table = attrs.cpu().numpy()
    endpoints = graph.edge_index.cpu().numpy()
    coords = graph.pos.cpu().numpy()

    stored_dists = attr_table[:, 0]
    print(f"\n[verify_edge_structure] 边距离统计 (来自edge_attr[:,0]): min={stored_dists.min():.2f}, max={stored_dists.max():.2f}, mean={stored_dists.mean():.2f}")
    if (stored_dists < 0).any():
        print("   [警告] 发现负距离！")

    n_checked = min(sample_size, graph.num_edges)
    if n_checked == 0:
        print("   无边可供抽样检查。")
        return

    picked_edges = np.random.choice(graph.num_edges, n_checked, replace=False)
    print(f"\n   抽样 {n_checked} 条边进行 (dist, dx, dy) 对比:")
    for rank, eid in enumerate(picked_edges, 1):
        src_idx = endpoints[0, eid]
        dst_idx = endpoints[1, eid]
        dist_stored, dx_grid_stored, dy_grid_stored = attr_table[eid, 0], attr_table[eid, 1], attr_table[eid, 2]

        # Recompute the geometry straight from the two endpoint positions.
        offset = coords[dst_idx] - coords[src_idx]
        dist_calc = np.linalg.norm(offset)
        dx_grid_calc = offset[0] / grid_size
        dy_grid_calc = offset[1] / grid_size

        print(f"   [{rank}] EdgeID={eid} ({src_idx}->{dst_idx})")
        print(f"     存储: dist={dist_stored:.2f}, dx_g={dx_grid_stored:.2f}, dy_g={dy_grid_stored:.2f}")
        print(f"     计算: dist={dist_calc:.2f}, dx_g={dx_grid_calc:.2f}, dy_g={dy_grid_calc:.2f}")

        dist_off = abs(dist_stored - dist_calc) > 1e-1
        dx_off = abs(dx_grid_stored - dx_grid_calc) > 0.6
        dy_off = abs(dy_grid_stored - dy_grid_calc) > 0.6
        if dist_off or dx_off or dy_off:
            print(f"     [警告] 差异较大!")
        else:
            print(f"     [OK]")

def generate_sequence_y_report(graph_sequence, sequence_index=0, num_y_samples=5):
    """Print per-hour statistics of the target tensor ``y`` for one graph sequence.

    For every graph in the sequence the report lists the shape of ``y``, its
    element count, the number of NaNs, min/max/mean/std over the non-NaN
    values and a few raw sample values (NaNs included if they are drawn).
    ``None`` entries in the sequence are reported and skipped.

    Args:
        graph_sequence: List of graph objects (or ``None`` placeholders).
        sequence_index: Index shown in the report header.
        num_y_samples: Number of raw ``y`` values to show per graph; values
            below zero are treated as zero. When ``y`` has more elements than
            this, the samples are drawn at random without replacement.
    """
    if not graph_sequence:
        print(f"序列 {sequence_index} 为空，无法生成y值报告。")
        return

    report_lines = [
        f"\n{'='*25} Y 值详细报告: 序列 {sequence_index} (第一个成功处理的空间窗口) {'='*25}",
        f"序列中图的数量 (小时数): {len(graph_sequence)}"
    ]

    # The first non-None graph supplies the file / window identifiers.
    first_graph_for_ids = next((g for g in graph_sequence if g is not None), None)
    if first_graph_for_ids is not None:
        report_lines.append(f"文件ID (来自序列首图): {getattr(first_graph_for_ids, 'file_id', 'N/A')}")
        report_lines.append(f"窗口索引 (来自序列首图): {getattr(first_graph_for_ids, 'window_index', 'N/A')}")
    else:
        report_lines.append("警告: 序列中所有图均为None，无法获取文件ID或窗口索引。")

    for i, graph_data in enumerate(graph_sequence):
        if graph_data is None:
            report_lines.append(f"\n--- [序列内索引 {i}] 图数据: None ---")
            continue

        actual_hour = getattr(graph_data, 'hour_of_day', 'N/A')
        hour_in_seq_idx = getattr(graph_data, 'hour_index_in_sequence', i)  # expected to equal i

        report_lines.append(f"\n--- [序列内索引 {hour_in_seq_idx}] 实际小时: {actual_hour if actual_hour != 'N/A' else '(未知)'}:00 ---")

        y_tensor = graph_data.y
        if y_tensor is None:
            report_lines.append("  y 值: None")
            continue

        # Flatten so an (N, 1) target is summarised as N plain values.
        y_np = y_tensor.detach().cpu().numpy().flatten()
        num_elements_in_y = len(y_np)
        report_lines.append(f"  y 原始张量形状: {list(y_tensor.shape)}")
        report_lines.append(f"  y 中元素数量 (通常为节点数): {num_elements_in_y}")

        if num_elements_in_y == 0:
            report_lines.append("  y 值: 空 (y中无元素/节点数为0)")
            continue

        nan_mask = np.isnan(y_np)
        nan_count = np.sum(nan_mask)
        report_lines.append(f"  y 中 NaN 值数量: {nan_count}")

        if nan_count == num_elements_in_y:
            report_lines.append("  y 值统计: 全部为 NaN")
            continue

        # At least one value is not NaN from here on.
        non_nan_y = y_np[~nan_mask]
        report_lines.append(f"  y 值统计 (非NaN部分): Min={np.min(non_nan_y):.4f}, Max={np.max(non_nan_y):.4f}, Mean={np.mean(non_nan_y):.4f}, Std={np.std(non_nan_y):.4f}")

        # Sample from the raw values so NaNs stay visible if they are drawn.
        samples_to_show = max(0, min(num_y_samples, num_elements_in_y))
        if num_elements_in_y <= samples_to_show:
            sampled_indices = np.arange(num_elements_in_y)
        else:
            sampled_indices = np.random.choice(num_elements_in_y, samples_to_show, replace=False)

        sampled_values = y_np[sampled_indices]
        report_lines.append(f"  y 值随机样本 ({len(sampled_values)}个): {sampled_values}")

    report_lines.append(f"\n{'='*70}")
    print("\n".join(report_lines))

# 在 process_12_hour_sequences 函数定义处添加新的参数
def process_12_hour_sequences(
    input_dir_str: str,
    output_dir_str: str,
    csv_path_str: str,
    target_date_tuple: tuple,
    start_hour_of_day_sequence: int,
    num_hours_in_sequence: int,
    output_npy_effective_start_hour: int,
    output_npy_fill_value_for_missing_y: float = float('nan'),
    window_size: int = 50,
    stride: int = 40,
    target_attr_index_in_output: int = 5,
    graph_constructor_params: dict = None,
    graph_augmentor_instance=None
):
    """Build one hourly graph sequence per spatial window of every Input/Output file pair.

    For each ``Input_<id>.npy`` with a matching ``Output_<id>.npy`` the data is
    cut into spatial windows and, for every window, one graph is built per
    hour of the requested sequence. Hours for which the Output file holds no
    data get their ``y`` replaced by ``output_npy_fill_value_for_missing_y``.
    A window is kept only if a graph was produced for every hour.

    Args:
        input_dir_str: Directory containing the ``Input_*.npy`` files.
        output_dir_str: Directory containing the ``Output_*.npy`` files.
        csv_path_str: Path of the weather CSV handed to ``CSVWeatherParser``.
        target_date_tuple: ``(year, month, day)`` used for solar and weather lookups.
        start_hour_of_day_sequence: Hour of day of the first graph in each sequence.
        num_hours_in_sequence: Number of consecutive hourly graphs per sequence.
        output_npy_effective_start_hour: Hour of day that time slice 0 of the
            Output files corresponds to.
        output_npy_fill_value_for_missing_y: Value written into ``y`` for hours
            outside the range covered by the Output file.
        window_size: Spatial window size forwarded to the preprocessor.
        stride: Spatial stride forwarded to the preprocessor.
        target_attr_index_in_output: Passed to ``GraphConstructor`` as ``target_attr_index``.
        graph_constructor_params: Extra keyword arguments for ``GraphConstructor``;
            ``None`` means no extra arguments.
        graph_augmentor_instance: Optional object whose ``augment_static(graph)``
            is applied to every graph.

    Returns:
        List of sequences; each sequence is a list of ``num_hours_in_sequence``
        graphs for one window. Empty when no input file is found.
    """
    import traceback

    input_dir = Path(input_dir_str)
    output_dir = Path(output_dir_str)
    input_files = sorted(input_dir.glob("Input_*.npy"))

    if not input_files:
        print(f"错误: 在 '{input_dir}' 未找到任何 Input_*.npy 文件。")
        return []

    # --- One-off global initialisation ---
    global_tree_max, global_building_max = compute_global_maxes(str(input_dir))
    print(f"全局最大树高: {global_tree_max}, 全局最大建筑高度: {global_building_max}")

    sg_latitude, sg_longitude = 1.3521, 103.8198 # Singapore coordinates
    solar_calc = SolarCalculator(latitude=sg_latitude, longitude=sg_longitude, timezone_str='Asia/Singapore')
    weather_parser = CSVWeatherParser(csv_path=csv_path_str)

    year, month, day = target_date_tuple

    # Copy the caller's parameters (tolerating the None default) before adding
    # the computed globals, so the caller's dict is never mutated.
    gc_params_with_globals = dict(graph_constructor_params) if graph_constructor_params is not None else {}
    gc_params_with_globals['global_tree_max_val'] = global_tree_max
    gc_params_with_globals['global_building_max_val'] = global_building_max
    gc_params_with_globals['target_attr_index'] = target_attr_index_in_output # Ensure this is passed

    graph_builder = GraphConstructor(**gc_params_with_globals)

    all_window_sequences = [] # outer list: one hourly graph sequence per window

    # --- Outer loop: input files ---
    for input_path in input_files:
        match = re.search(r"Input_(\d+)\.npy", input_path.name)
        if not match:
            print(f"跳过不匹配的文件名: {input_path.name}")
            continue
        file_id = match.group(1)
        output_path_npy = output_dir / f"Output_{file_id}.npy"

        if not output_path_npy.exists():
            print(f"跳过 {input_path.name}: 缺少对应 Output_{file_id}.npy")
            continue

        print(f"\n=== 开始处理文件对 {file_id} 的所有12小时序列 ===")
        try:
            preprocessor = ClimateDataPreprocessor(
                input_path=str(input_path),
                output_path=str(output_path_npy),
                window_size=window_size,
                stride=stride,
                global_tree_max=global_tree_max,
                global_building_max=global_building_max
            )
            input_windows, output_windows = preprocessor.process()

            # The number of hourly slices is read from the raw output array
            # (last axis of a 4-D array) rather than from the windowed data.
            # NOTE(review): the axis layout of preprocessor.output_data is not
            # visible here - confirm that axis 3 is the hour axis.
            if preprocessor.output_data.ndim == 4 and preprocessor.output_data.shape[3] > 0:
                 hours_in_output_npy = preprocessor.output_data.shape[3]
            else:
                 print(f"警告: 无法从 Output_{file_id}.npy 的形状确定小时数，将假定为 {num_hours_in_sequence}")
                 hours_in_output_npy = num_hours_in_sequence

            # --- Middle loop: spatial windows of this file ---
            for window_idx, (inp_win_base_feats, out_win_all_hrs) in enumerate(zip(input_windows, output_windows)):
                current_window_hourly_graphs = []
                print(f"  -- 处理文件 {file_id}, 窗口索引 {window_idx} --")

                # --- Inner loop: hours of the sequence ---
                for hour_offset in range(num_hours_in_sequence):
                    actual_hour_of_day = start_hour_of_day_sequence + hour_offset

                    # Index of this hour inside the Output file; it is only valid
                    # when it falls within the slices the file actually contains.
                    y_data_slice_index = actual_hour_of_day - output_npy_effective_start_hour
                    y_is_valid_for_this_hour = False
                    target_hour_index_for_graph_build = 0 # placeholder when y is invalid; y is overwritten below

                    if 0 <= y_data_slice_index < hours_in_output_npy:
                        y_is_valid_for_this_hour = True
                        target_hour_index_for_graph_build = y_data_slice_index

                    # Solar position and weather for this hour.
                    try:
                        current_solar_params = solar_calc.get_solar_position(year, month, day, actual_hour_of_day)
                        current_hourly_weather_params = weather_parser.get_hourly_data(month, day, actual_hour_of_day)
                    except Exception as e:
                        print(f"    错误: 无法获取 {month}-{day} {actual_hour_of_day}:00 的太阳/气象数据: {e}. 跳过此小时。")
                        print(f"    由于数据获取失败，跳过文件 {file_id}, 窗口 {window_idx}, 小时偏移 {hour_offset} 的图构建。")
                        continue

                    # Build the graph for this hour.
                    try:
                        graph = graph_builder.build_graph(
                            inp_win_base_feats,
                            out_win_all_hrs,
                            current_hourly_weather_params,
                            current_solar_params,
                            target_hour_index_in_day=target_hour_index_for_graph_build
                        )
                        graph.file_id = file_id
                        graph.window_index = window_idx
                        graph.hour_of_day = actual_hour_of_day
                        graph.hour_index_in_sequence = hour_offset

                        # No Output data for this hour: overwrite y with the fill value.
                        if not y_is_valid_for_this_hour:
                            if graph.y is not None:
                                fill_value_tensor = torch.full_like(graph.y, output_npy_fill_value_for_missing_y)
                                graph.y = fill_value_tensor

                        if graph_augmentor_instance:
                            graph = graph_augmentor_instance.augment_static(graph)

                        current_window_hourly_graphs.append(graph)

                    except Exception as graph_err:
                        print(f"    构建图失败 (文件:{file_id}, 窗口索引:{window_idx}, 小时偏移:{hour_offset}, 实际小时:{actual_hour_of_day}): {graph_err}")
                        traceback.print_exc()

                # Keep only complete sequences; a partial one is reported and dropped.
                if len(current_window_hourly_graphs) == num_hours_in_sequence:
                    all_window_sequences.append(current_window_hourly_graphs)
                elif current_window_hourly_graphs:
                     print(f"  警告: 文件 {file_id}, 窗口 {window_idx} 的序列不完整 (仅 {len(current_window_hourly_graphs)}/{num_hours_in_sequence} 个图)，已跳过此窗口序列。")

        except Exception as e:
            print(f"处理文件对 {file_id} 时发生严重错误: {e}")
            traceback.print_exc()
            continue

    print(f"\n处理完成，共生成 {len(all_window_sequences)} 个窗口的{num_hours_in_sequence}小时图序列。")
    if all_window_sequences and all_window_sequences[0]:
        print("\n=== 抽样检查首个窗口的首个图 ===")
        if all_window_sequences[0][0] is not None:
             generate_data_report(all_window_sequences[0][0]) # report for the first hour of the first window
             verify_edge_structure(all_window_sequences[0][0], grid_size=graph_builder.grid_size)
        else:
            print("抽样检查失败：首个窗口的首个图未能成功构建。")

    return all_window_sequences

# ==================== 主函数 (12小时序列生成) ===================
def main_generate_12_hour_data():
    """Build the hourly graph sequences for one fixed day and pickle them.

    Every setting (paths, date, hour range, graph-constructor parameters,
    windowing) is hard-coded in the body. The function:

    1. creates the result directory on the mounted drive,
    2. calls ``process_12_hour_sequences`` with the configured values,
    3. pickles the returned list of per-window graph sequences, and
    4. prints a Y-value report for the first sequence whose graphs are all
       non-``None``.

    Note: despite the function name, the sequence length is whatever
    ``num_hours`` is set to below (currently 13, starting at hour 7).

    Returns:
        None. Results are written to disk and progress is printed.
    """
    # --- Locations on the mounted drive (make sure the root path is correct) ---
    drive_root = Path("/content/drive/MyDrive/Colab Notebooks/Graph Data Process")
    input_dir = drive_root / "Input"
    output_dir = drive_root / "Output"
    csv_path = drive_root / "SGP_SINGAPORE-CHANGI-IAP_486980S_23EPW.csv"
    save_dir = drive_root / "Result" / "Sequential_12Hour_Data"
    save_dir.mkdir(parents=True, exist_ok=True)

    # --- Time parameters ---
    year, month, day = 2023, 5, 3
    # First real hour-of-day of the sequence and how many consecutive hours
    # it spans; with these values the last hour is 7 + 13 - 1 = 19.
    seq_start_hour = 7
    num_hours = 13
    last_hour = seq_start_hour + num_hours - 1

    # Hour-of-day that time slice 0 of the Output_*.npy files is taken to
    # represent (the original note assumes 8:00 -- verify against the data).
    npy_start_hour = 8
    # Value written into y when no target data exists for an hour.
    # Use float('nan') here instead to fill with NaN rather than zero.
    missing_y_fill = 0.0

    # --- GraphConstructor parameters (forwarded as a dict) ---
    # 'target_attr_index' and the global tree/building maxima are not set
    # here; per the original notes they are filled in by the processing
    # function.
    constructor_cfg = dict(
        grid_size=4,  # default, but can be set
        k_similarity=8,
        base_building_shadow_radius_max_grids=15,
        base_tree_shadow_radius_max_grids=5,
        base_tree_activity_radius_max_grids=5,  # example: t_rmax_activity_grids
        base_edge_weight=1.0,
        shadow_influence_weight_factor=1.5,  # noted as currently unused
        wind_alignment_weight_factor=0.5,  # noted as currently unused
        distance_decay_factor_per_grid=0.01,
        similarity_dist_decay_factor_per_grid=0.005,
        actual_shadow_boost_factor=1.2,
        tree_activity_height_influence_factor=0.2,
        shadow_angular_width_deg=25.0,
        knn_node_feature_normalization_epsilon=1e-5,
        internal_n8_weight=0.1,  # noted as currently unused
    )

    # --- Remaining parameters ---
    target_var_index = 5  # index of the target variable inside Output.npy
    window_size = 50
    stride = 40

    # Pass None instead to disable augmentation.
    augmentor = GraphAugmentor(add_neighbor_agg=True, add_edge_diff=False)

    print(f"\n--- 开始为 {year}-{month}-{day} 生成 {num_hours} 小时图序列数据 ---")
    print(f"序列起始小时 (真实小时): {seq_start_hour}")
    print(f"GraphConstructor 参数: {constructor_cfg}")

    all_sequences = process_12_hour_sequences(
        input_dir_str=str(input_dir),
        output_dir_str=str(output_dir),
        csv_path_str=str(csv_path),
        target_date_tuple=(year, month, day),
        start_hour_of_day_sequence=seq_start_hour,
        num_hours_in_sequence=num_hours,
        output_npy_effective_start_hour=npy_start_hour,
        output_npy_fill_value_for_missing_y=missing_y_fill,
        window_size=window_size,
        stride=stride,
        target_attr_index_in_output=target_var_index,
        graph_constructor_params=constructor_cfg,
        graph_augmentor_instance=augmentor,
    )

    # Nothing produced: report and stop before touching the disk.
    if not all_sequences:
        print("\n未生成任何图序列数据。")
        return

    # File name encodes the date, the covered hour range, the npy start
    # hour and the fill value.
    date_tag = f"{year}{month:02d}{day:02d}"
    save_name = (
        f"graph_seq_{date_tag}_"
        f"SeqH{seq_start_hour}to{last_hour}_"
        f"NpyH{npy_start_hour}fill{missing_y_fill!s}.pkl"
    )
    save_path = save_dir / save_name
    with save_path.open("wb") as out_file:
        pickle.dump(all_sequences, out_file)
    print(f"\n[已保存] {num_hours}小时图序列数据 => {save_path}")

    # --- Y-value report for the first complete sequence ---
    # all_sequences is a list of lists: the outer list enumerates spatial
    # windows, each inner list is that window's graphs over time.
    print("\n\n=== 开始生成Y值详细报告 (针对第一个成功处理的空间窗口的完整时间序列) ===")
    for seq_idx, window_sequence in enumerate(all_sequences):
        if not window_sequence:
            continue  # empty sequences are silently ignored
        num_valid_graphs = sum(1 for g in window_sequence if g is not None)
        if num_valid_graphs == len(window_sequence):
            generate_sequence_y_report(window_sequence, sequence_index=seq_idx)
            break  # only the first complete sequence is reported
        print(f"提示: 序列 {seq_idx} 不完整或包含None图 (有效图: {num_valid_graphs}/{len(window_sequence)})，跳过其Y值报告。")
    else:
        # Loop finished without a break: no complete sequence exists.
        print("\n提示: 未找到任何完整的图序列进行Y值报告。")


# Script entry point: generate and save the hourly graph-sequence dataset.
if __name__ == "__main__":
    main_generate_12_hour_data()