In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress, spearmanr
from scipy.optimize import least_squares
from sklearn.metrics import mean_squared_error
import os
import warnings
import multiprocessing as mp
from functools import partial
import itertools
from numba import jit, prange
import gc

warnings.filterwarnings('ignore')

# ============================================================================
# Global configuration
# ============================================================================
# Leave one core free for the parent process and cap the pool at 64 workers,
# but never drop below one worker: on a single-core host cpu_count() - 1 is 0
# and mp.Pool(processes=0) raises ValueError.
N_WORKERS = max(1, min(mp.cpu_count() - 1, 64))
print(f"Using {N_WORKERS} workers for parallel processing")

# ============================================================================
# 全局变量（用于多进程共享数据）
# ============================================================================
# Module-level slots read by the pool worker functions defined below. The
# parent code assigns them right before creating an mp.Pool and resets them
# once the pool has finished.
# NOTE(review): workers presumably inherit these values through a fork-style
# process start -- confirm the start method on the target platform.
_GLOBAL_NODE_DATA = {}
_GLOBAL_WIDTH_STATS = None
_GLOBAL_SWOT_DATA = None
_GLOBAL_FITTER = None
_GLOBAL_QC_DATA = None
_GLOBAL_SMOOTH_WINDOW = 5  # [added] global rolling-window size

# ============================================================================
# Numba加速的核心计算函数
# ============================================================================
@jit(nopython=True, parallel=True, cache=True)
def compute_inconsistency_matrix(w, h):
    """Count, for every point, how many other points it is discordant with.

    Two points are discordant when their width difference and their height
    difference have opposite signs (product strictly below zero).

    Parameters
    ----------
    w, h : 1-D arrays of equal length (widths and heights).

    Returns
    -------
    int64 array with one discordance count per input point.
    """
    size = len(w)
    discordant = np.zeros(size, dtype=np.int64)
    # Outer loop is parallel; every iteration writes only its own slot.
    for row in prange(size):
        w_row = w[row]
        h_row = h[row]
        tally = 0
        for col in range(size):
            if (w_row - w[col]) * (h_row - h[col]) < 0:
                tally += 1
        discordant[row] = tally
    return discordant

@jit(nopython=True, cache=True)
def calculate_areas_numba(w_list, h_list, w50, a50):
    """Integrate a width-height curve into areas, anchored at (w50, a50).

    The height at ``w50`` is linearly interpolated between the two curve
    points bracketing it; the area is then accumulated with the trapezoid
    rule upwards and downwards from that anchor.

    Parameters
    ----------
    w_list, h_list : 1-D arrays of equal length (curve widths and heights).
    w50 : anchor width.
    a50 : area assigned to the anchor width.

    Returns
    -------
    Array of areas, one per curve point; all NaN when fewer than two points
    are supplied.
    """
    count = len(w_list)
    areas = np.full(count, np.nan)

    # Nothing to integrate with fewer than two points.
    if count < 2:
        return areas

    # Index of the upper bracket point, clamped to [1, count - 1].
    pivot = np.searchsorted(w_list, w50)
    if pivot >= count:
        pivot = count - 1
    if pivot < 1:
        pivot = 1
    below = pivot - 1

    # Interpolate the height at w50; fall back to the mean of the two
    # bracket heights when their widths are (almost) identical.
    span = w_list[pivot] - w_list[below]
    if abs(span) < 1e-10:
        h50 = (h_list[below] + h_list[pivot]) / 2.0
    else:
        h50 = (h_list[below] * (w_list[pivot] - w50) +
               h_list[pivot] * (w50 - w_list[below])) / span

    areas[pivot] = a50 + 0.5 * (w50 + w_list[pivot]) * (h_list[pivot] - h50)

    # Accumulate trapezoids towards larger widths ...
    for k in range(pivot + 1, count):
        areas[k] = areas[k-1] + 0.5 * (w_list[k-1] + w_list[k]) * \
                   (h_list[k] - h_list[k-1])

    # ... and subtract them towards smaller widths.
    for k in range(pivot - 1, -1, -1):
        areas[k] = areas[k+1] - 0.5 * (w_list[k+1] + w_list[k]) * \
                   (h_list[k+1] - h_list[k])

    return areas

@jit(nopython=True, cache=True)
def nse_numba(observed, simulated):
    """Return the Nash-Sutcliffe efficiency of ``simulated`` vs ``observed``.

    Returns NaN when the observations have zero variance around their mean.
    """
    residual_ss = np.sum((observed - simulated)**2)
    total_ss = np.sum((observed - np.mean(observed))**2)
    if total_ss == 0:
        return np.nan
    return 1 - residual_ss / total_ss

@jit(nopython=True, cache=True)
def kge_numba(observed, simulated):
    """Return a Kling-Gupta style efficiency of ``simulated`` vs ``observed``.

    Combines the linear correlation, the ratio of means and the ratio of
    coefficients of variation. Returns NaN when either series has zero mean
    or zero standard deviation.
    """
    mean_obs = np.mean(observed)
    mean_sim = np.mean(simulated)
    std_obs = np.std(observed)
    std_sim = np.std(simulated)

    # Any zero here would make one of the ratios below undefined.
    if std_obs == 0 or std_sim == 0 or mean_obs == 0 or mean_sim == 0:
        return np.nan

    size = len(observed)
    covariance = np.sum((observed - mean_obs) * (simulated - mean_sim)) / size
    correlation = covariance / (std_obs * std_sim)

    mean_ratio = mean_sim / mean_obs
    cv_ratio = (std_sim / mean_sim) / (std_obs / mean_obs)  # ratio of coefficients of variation

    return 1 - np.sqrt((correlation - 1)**2 + (mean_ratio - 1)**2 + (cv_ratio - 1)**2)

# ============================================================================
# 全局并行处理函数（必须在模块级别定义才能被pickle）
# ============================================================================
def _compute_node_corr(node_id):
    """Compute the Spearman rank correlation of width vs. WSE for one node.

    Reads the node's arrays from the module-level ``_GLOBAL_NODE_DATA`` cache.

    Parameters
    ----------
    node_id : key into ``_GLOBAL_NODE_DATA``.

    Returns
    -------
    tuple ``(node_id, stationid, corr)``. ``stationid`` is None when the node
    is not cached; ``corr`` is 0.0 when the node is missing, has fewer than
    five points, or the correlation is NaN / cannot be computed.
    """
    data = _GLOBAL_NODE_DATA.get(node_id)
    if data is None:
        return (node_id, None, 0.0)
    if len(data['width']) < 5:
        return (node_id, data['stationid'], 0.0)

    try:
        corr, _ = spearmanr(data['width'], data['wse'])
        if np.isnan(corr):
            corr = 0.0
    except Exception:
        # Best effort: a node whose correlation cannot be computed is ranked
        # as uncorrelated. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt / SystemExit propagate out of the worker.
        corr = 0.0

    return (node_id, data['stationid'], corr)

def _process_qc_station(stationid):
    """Run the three quality-control steps for one station (pool worker).

    Reads the width statistics and SWOT observations from the module-level
    ``_GLOBAL_WIDTH_STATS`` / ``_GLOBAL_SWOT_DATA`` slots.

    Steps
    -----
    1. Uncertainty filter: keep rows with ``wse_u <= 0.4`` and a relative
       width uncertainty ``width_u / width <= 0.1``.
    2. Order-consistency filter via ``_remove_inconsistent_points``.
    3. Outlier filter: keep rows whose WSE lies within one bankfull depth
       (``0.27 * (w_high / 7.2) ** 0.6``) of the median WSE.

    Returns
    -------
    DataFrame of the surviving rows, or None when the shared data is not
    set, the station has no width statistics, or fewer than five rows
    remain after any step.
    """
    if _GLOBAL_WIDTH_STATS is None or _GLOBAL_SWOT_DATA is None:
        return None
    if stationid not in _GLOBAL_WIDTH_STATS.index:
        return None

    df = _GLOBAL_SWOT_DATA[_GLOBAL_SWOT_DATA['stationid'] == stationid].copy()
    if len(df) < 5:
        return None

    # Step 1: uncertainty screening
    df['width_u_r'] = df['width_u'] / df['width']
    df1 = df[(df['wse_u'] <= 0.4) & (df['width_u_r'] <= 0.1)]
    if len(df1) < 5:
        return None

    # Step 2: order-consistency removal
    df2 = _remove_inconsistent_points(df1)
    if len(df2) < 5:
        return None

    # Step 3: outlier removal around the median water-surface elevation
    w_high = _GLOBAL_WIDTH_STATS.loc[stationid, 'w_high']
    d_bankfull = 0.27 * (w_high / 7.2) ** 0.6
    h50 = df2['wse'].median()

    df3 = df2[(df2['wse'] <= h50 + d_bankfull) & (df2['wse'] >= h50 - d_bankfull)]

    return df3 if len(df3) >= 5 else None

def _remove_inconsistent_points(df, inverse_ratio_thresh=0.5):
    """Iteratively drop the point most inconsistent with the width-WSE order.

    On every pass the discordance count of each remaining point is computed
    with ``compute_inconsistency_matrix``; the point with the highest count
    is removed as long as its count divided by the number of remaining
    points is at least ``inverse_ratio_thresh``. The loop also stops once
    fewer than five points remain.

    Rows are tracked by position rather than by index label, so the result
    is correct even when ``df`` has a non-unique index, and the width / WSE
    arrays are extracted once instead of re-slicing the frame on each pass.

    Parameters
    ----------
    df : DataFrame with ``width`` and ``wse`` columns.
    inverse_ratio_thresh : minimum discordance ratio that triggers removal.

    Returns
    -------
    DataFrame with the surviving rows in their original order.
    """
    widths = df['width'].to_numpy(dtype=np.float64)
    heights = df['wse'].to_numpy(dtype=np.float64)
    positions = list(range(len(df)))

    while len(positions) >= 5:
        inverse = compute_inconsistency_matrix(widths[positions], heights[positions])

        worst = int(np.argmax(inverse))
        if inverse[worst] / len(positions) < inverse_ratio_thresh:
            break

        positions.pop(worst)

    return df.iloc[positions]

def _fit_station_wrapper(stationid):
    """Fit one station using the fitter and QC table held in module globals.

    Returns whatever ``_GLOBAL_FITTER.fit_station`` returns, or None when the
    QC table has no rows for ``stationid``. The COMID passed to the fitter is
    taken from the station's first row.
    """
    qc_table = _GLOBAL_QC_DATA
    rows = qc_table[qc_table['stationid'] == stationid]
    if len(rows) < 1:
        return None

    return _GLOBAL_FITTER.fit_station(rows, stationid, rows.iloc[0]['COMID'])

def _validate_station_wrapper(args):
    """Validate the discharge estimate of one station (pool worker).

    Parameters
    ----------
    args : tuple
        ``(stationid, df_hypso, df_width, df_val_folder, df_fit, start_date,
        skip_width_filter)``. The observed discharge is read from
        ``<df_val_folder>/<stationid>.csv``; its rows are dated daily starting
        at ``start_date``.

    Returns
    -------
    DataFrame with per-date widths, interpolated areas, observed and
    estimated discharge plus the station-wide KGE / NSE / NRMSE repeated on
    every row, or None when the file is missing, the station has no fit or
    curve, fewer than ten usable rows remain, or any error occurs (the error
    is printed).
    """
    (station, hypso_all, width_all, val_dir,
     fit_all, first_day, skip_width_filter) = args

    csv_path = os.path.join(val_dir, f'{station}.csv')
    if not os.path.exists(csv_path):
        return None

    try:
        obs_df = pd.read_csv(csv_path)
        obs_df['date'] = pd.date_range(start=first_day, periods=len(obs_df), freq='D')
        obs_df['stationid'] = station
        obs_df = obs_df.dropna(subset=['qobs'])

        # Attach the station's width time series by date.
        station_widths = width_all[width_all['stationid'] == station]
        obs_df = obs_df.merge(station_widths, on=['stationid', 'date'], how='inner')

        curve = hypso_all[hypso_all['stationid'] == station].reset_index(drop=True)
        fit_rows = fit_all[fit_all['stationid'] == station]
        if fit_rows.empty or curve.empty:
            return None

        params = fit_rows.iloc[0]
        w_low = params['w_low']
        w_high = params['w_high']
        slp = params['slp']

        # node mode restricts widths to [w_low, w_high]; datemean mode
        # (skip_width_filter) keeps every width. Both de-duplicate by date.
        if not skip_width_filter:
            within_range = (obs_df['width'] >= w_low) & (obs_df['width'] <= w_high)
            obs_df = obs_df[within_range]
        obs_df = obs_df.drop_duplicates('date')

        if len(obs_df) < 10:
            return None

        # Vectorised interpolation of area along the hypsometric curve.
        val_width = obs_df['width'].values
        area_hypso = np.interp(val_width, curve['width'].values, curve['area'].values)

        obs_df['area_hypso'] = area_hypso
        obs_df['Q_est'] = (area_hypso**(5/3) * val_width**(-2/3) * slp**0.5 / 0.035)

        obs_df = obs_df.dropna()
        if len(obs_df) < 10:
            return None

        observed = obs_df['qobs'].values.astype(np.float64)
        simulated = obs_df['Q_est'].values.astype(np.float64)

        rmse = np.sqrt(np.mean((observed - simulated)**2))
        obs_df['kge'] = kge_numba(observed, simulated)
        obs_df['nse'] = nse_numba(observed, simulated)
        obs_df['nrmse'] = rmse / np.mean(observed)

        output_columns = ['stationid', 'date', 'width', 'area_hypso',
                          'qobs', 'Q_est', 'kge', 'nse', 'nrmse']
        return obs_df[output_columns]
    except Exception as err:
        print(f"Error processing station {station}: {err}")
        return None

def _rolling_median_group(group):
    """Smooth ``width`` and ``wse`` of one group with a centred rolling median.

    The group is sorted by ``date`` first; the window size is read from the
    module-level ``_GLOBAL_SMOOTH_WINDOW``. Returns the sorted, smoothed frame.
    """
    size = _GLOBAL_SMOOTH_WINDOW
    ordered = group.sort_values('date')
    for column in ('width', 'wse'):
        ordered[column] = (
            ordered[column]
            .rolling(window=size, center=True, min_periods=1)
            .median()
        )
    return ordered

# ============================================================================
# 模块1: 数据统计工具
# ============================================================================
class WidthStatistics:
    """Per-station statistics of river width."""

    # Percentile configuration: {option: (low percentile, high percentile)}
    PERCENTILE_CONFIG = {
        '0': (0, 100),    # use the full data range
        '3': (3, 97),
        '5': (5, 95),
        '10': (10, 90)
    }

    @staticmethod
    def calculate_width_percentiles(df, min_width=30, valid_ratio=0.95):
        """Compute the median and the configured width percentiles per station.

        A station is kept only when at least ``valid_ratio`` of its non-NaN
        widths are ``>= min_width`` and more than ten such widths exist.

        Parameters
        ----------
        df : DataFrame with ``stationid`` and ``width`` columns.
        min_width : widths below this value are excluded from the statistics.
        valid_ratio : minimum share of widths that must reach ``min_width``.

        Returns
        -------
        DataFrame with one row per retained station and the columns
        ``stationid``, ``w50`` and ``w<low>_low`` / ``w<high>_high`` for every
        entry of ``PERCENTILE_CONFIG``. The columns are present even when no
        station qualifies, so downstream column lookups do not fail.
        """
        percentile_pairs = list(WidthStatistics.PERCENTILE_CONFIG.values())

        columns = ['stationid', 'w50']
        for low_pct, high_pct in percentile_pairs:
            columns.append(f'w{low_pct}_low')
            columns.append(f'w{high_pct}_high')

        result_data = []
        # A single groupby pass replaces one full-frame boolean scan per
        # station; sort=False keeps the stations in order of appearance.
        for stationid, widths in df.groupby('stationid', sort=False)['width']:
            station_data_all = widths.dropna()
            if len(station_data_all) == 0:
                continue

            station_data = station_data_all[station_data_all >= min_width]
            if len(station_data) / len(station_data_all) < valid_ratio:
                continue
            if len(station_data) <= 10:
                continue

            row_data = {'stationid': stationid, 'w50': station_data.median()}
            for low_pct, high_pct in percentile_pairs:
                w_low, w_high = np.percentile(station_data, [low_pct, high_pct])
                row_data[f'w{low_pct}_low'] = w_low
                row_data[f'w{high_pct}_high'] = w_high

            result_data.append(row_data)

        return pd.DataFrame(result_data, columns=columns)

    @staticmethod
    def get_percentile_columns(b_option):
        """Return the (low, high) column names for a percentile option.

        Raises
        ------
        ValueError
            If ``b_option`` is not a key of ``PERCENTILE_CONFIG``.
        """
        if b_option not in WidthStatistics.PERCENTILE_CONFIG:
            raise ValueError(f"Invalid B option: {b_option}. Valid options: {list(WidthStatistics.PERCENTILE_CONFIG.keys())}")

        low_pct, high_pct = WidthStatistics.PERCENTILE_CONFIG[b_option]
        return f'w{low_pct}_low', f'w{high_pct}_high'

# ============================================================================
# 模块2: 节点选择
# ============================================================================
class NodeSelector:
    """Select the best node for every station."""

    @staticmethod
    def select_best_nodes(df_swot, min_data_points=10):
        """Keep, per station, the node with the highest rank correlation.

        Nodes with fewer than ``min_data_points`` rows are discarded first.
        The width-WSE Spearman correlation of each remaining node is computed
        in a process pool through ``_compute_node_corr``, which reads the
        per-node arrays from the module-level ``_GLOBAL_NODE_DATA`` cache.

        Parameters
        ----------
        df_swot : DataFrame with ``node_id``, ``stationid``, ``width`` and
            ``wse`` columns.
        min_data_points : minimum number of rows a node must have.

        Returns
        -------
        (df_filtered, df_node_rmax) : the rows of the selected nodes, and a
        table with ``node_id``, ``stationid`` and ``rank_corr`` of each
        selected node.
        """
        global _GLOBAL_NODE_DATA

        # Drop nodes with too few observations.
        node_counts = df_swot.groupby('node_id').size()
        valid_nodes = node_counts[node_counts >= min_data_points].index
        df_swot = df_swot[df_swot['node_id'].isin(valid_nodes)].copy()

        # Pre-compute the per-node arrays the workers will read.
        _GLOBAL_NODE_DATA = {
            node_id: {
                'width': group['width'].values,
                'wse': group['wse'].values,
                'stationid': group['stationid'].iloc[0]
            }
            for node_id, group in df_swot.groupby('node_id')
        }

        try:
            node_ids = list(_GLOBAL_NODE_DATA)
            with mp.Pool(processes=N_WORKERS) as pool:
                results = pool.map(_compute_node_corr, node_ids)
        finally:
            # Release the cache even when the pool fails, so stale data
            # is not kept around for a later call.
            _GLOBAL_NODE_DATA = {}

        df_node = pd.DataFrame(results, columns=['node_id', 'stationid', 'rank_corr'])
        df_node = df_node.dropna(subset=['stationid'])

        # Pick the node with the largest rank correlation per station.
        max_idx = df_node.groupby('stationid')['rank_corr'].idxmax()
        df_node_rmax = df_node.loc[max_idx]

        df_filtered = df_swot[df_swot['node_id'].isin(df_node_rmax['node_id'])]

        print(f"Original nodes: {len(df_node)}, Selected nodes: {len(df_node_rmax)}")

        return df_filtered, df_node_rmax

# ============================================================================
# 模块3: 数据质控
# ============================================================================
class DataQualityControl:
    """Quality control of SWOT observations."""

    def __init__(self, width_stats):
        """Store the width statistics indexed by ``stationid``."""
        self.width_stats = width_stats.set_index('stationid')

    def apply_qc(self, df_swot, draw_figure=False, output_folder=None):
        """Run the per-station quality control in a process pool.

        The width statistics and ``df_swot`` are published through the
        module-level ``_GLOBAL_WIDTH_STATS`` / ``_GLOBAL_SWOT_DATA`` slots,
        which ``_process_qc_station`` reads inside the workers.

        Parameters
        ----------
        df_swot : DataFrame of SWOT observations with a ``stationid`` column.
        draw_figure, output_folder : accepted for interface compatibility;
            not used by this method.

        Returns
        -------
        DataFrame of all rows that passed QC, de-duplicated on
        ``node_id`` / ``date`` / ``stationid`` with a fresh index, or an empty
        DataFrame when no station passes.
        """
        global _GLOBAL_WIDTH_STATS, _GLOBAL_SWOT_DATA

        _GLOBAL_WIDTH_STATS = self.width_stats
        _GLOBAL_SWOT_DATA = df_swot

        try:
            stationids = df_swot['stationid'].unique()
            with mp.Pool(processes=N_WORKERS) as pool:
                results = pool.map(_process_qc_station, stationids)
        finally:
            # Always drop the shared references, even when the pool fails.
            _GLOBAL_WIDTH_STATS = None
            _GLOBAL_SWOT_DATA = None

        # Merge the per-station results.
        results = [df for df in results if df is not None]

        if not results:
            return pd.DataFrame()

        result = pd.concat(results)
        result = result.drop_duplicates(subset=['node_id', 'date', 'stationid'])
        result.reset_index(drop=True, inplace=True)

        return result

# ============================================================================
# 模块4: 水位-面积曲线拟合
# ============================================================================
class HydraulicCurveFitter:
    """Fit the water-surface-elevation vs. width power curve per station."""

    def __init__(self, width_stats, river_attrs, skip_width_filter=False):
        """
        Parameters
        ----------
        width_stats : DataFrame
            Width statistics; must contain ``stationid``, ``w50``, ``w_low``
            and ``w_high``.
        river_attrs : DataFrame
            River attributes; must contain ``COMID``, ``q50_weighted`` and
            ``slope``.
        skip_width_filter : bool
            Skip the [w_low, w_high] width filter (used in datemean mode).
        """
        self.width_stats = width_stats.set_index('stationid')
        self.river_attrs = river_attrs.set_index('COMID')
        self.skip_width_filter = skip_width_filter

        # Parameter grids explored for every station.
        self.R_list = np.array([0.5, 1, 2, 4, 8])
        self.GAP_list = np.array([-0.3, -0.1, 0, 0.1, 0.3])
        self.W_list = np.array([0.3, 0.5, 0.7])

    @staticmethod
    def power_function(params, X, y):
        """Residuals of the model ``wse0 + a * X**b`` against ``y``."""
        wse0, a, b = params
        return y - (wse0 + a * X**b)

    def loss_function(self, z, weight, n_swot):
        """Custom loss for ``least_squares``: rho(z) and its two derivatives.

        Returns a ``(3, len(z))`` array. The first two columns (the two
        synthetic anchor points that ``_fit_single_config`` prepends to the
        data) are scaled by ``(n_swot - 2) / weight * (1 - weight) / 2``.
        """
        rho = np.zeros((3, len(z)))
        rho[0] = 2 * ((1 + z)**0.5 - 1)
        rho[1] = (1 + z)**(-0.5)
        rho[2] = -0.5 * (1 + z)**(-1.5)

        factor = (n_swot - 2) / weight * (1 - weight) / 2
        rho[:, 0] *= factor
        rho[:, 1] *= factor

        return rho

    def calculate_h50(self, df, w50):
        """Estimate the water-surface elevation at the median width ``w50``.

        Uses a linear regression through the five observations whose width is
        closest to ``w50``. Falls back to the mean WSE of those observations
        when they share a single width or the regression slope is negative.
        """
        df = df.copy()
        df['w50_diff'] = np.abs(df['width'] - w50)
        nearest = df.sort_values('w50_diff').iloc[:5]

        xdata = nearest['width'].values
        ydata = nearest['wse'].values

        if len(np.unique(xdata)) < 2:
            return nearest['wse'].mean()

        res = linregress(xdata, ydata)
        if res[0] >= 0:
            return res[0] * w50 + res[1]
        return nearest['wse'].mean()

    def fit_station(self, df_station, stationid, comid):
        """Fit the WSE-width relation of one station over the parameter grid.

        Returns
        -------
        DataFrame with one row per successful (R, GAP, W) configuration, or
        None when the station / COMID is unknown, fewer than three rows
        remain after the optional width filter, no configuration converges,
        or an error occurs (the error is printed).
        """
        if stationid not in self.width_stats.index:
            return None
        if comid not in self.river_attrs.index:
            return None

        try:
            q50 = self.river_attrs.loc[comid, 'q50_weighted']
            slp = self.river_attrs.loc[comid, 'slope']
            w50, w_low, w_high = self.width_stats.loc[stationid, ['w50', 'w_low', 'w_high']]
            d_bankfull = 0.27 * (w_high / 7.2)**0.6

            h50 = self.calculate_h50(df_station, w50)
            a50 = (q50 * 0.035 / slp**0.5 * w50**(2/3))**(3/5)

            if self.skip_width_filter:
                # datemean mode: keep every width
                df_filtered = df_station.copy()
            else:
                # node mode: restrict to the [w_low, w_high] range
                df_filtered = df_station[
                    (df_station['width'] >= w_low) &
                    (df_station['width'] <= w_high)
                ]

            if len(df_filtered) < 3:
                return None

            swot_wsemax = df_filtered.sort_values('wse', ascending=False).iloc[0]
            d_wsemax = 0.27 * (swot_wsemax['width'] / 7.2)**0.6

            results = []
            # Same traversal order as three nested loops: R outermost, W innermost.
            for r_low, gap, weight in itertools.product(
                    self.R_list, self.GAP_list, self.W_list):
                result = self._fit_single_config(
                    df_filtered, r_low, gap, weight,
                    w_low, w_high, w50, h50, a50,
                    swot_wsemax, d_bankfull, d_wsemax, slp, q50
                )
                if result is None:
                    continue
                result.update({
                    'stationid': stationid,
                    'COMID': comid,
                    'R': r_low,
                    'GAP': gap,
                    'W': weight
                })
                results.append(result)

            return pd.DataFrame(results) if results else None
        except Exception as e:
            # Best effort per station, but report the failure instead of
            # swallowing it silently (same style as the validation worker).
            print(f"Error fitting station {stationid}: {e}")
            return None

    def _fit_single_config(self, df, r_low, gap, weight, w_low, w_high, w50,
                          h50, a50, swot_wsemax, d_bankfull, d_wsemax, slp, q50):
        """Fit one (R, GAP, W) configuration.

        Two synthetic anchor points, ``(w_low, h_low)`` and
        ``(w_high, h_high)``, are prepended to the observations before the
        robust least-squares fit.

        Returns
        -------
        dict of fitted and derived parameters, or None when the solver fails,
        does not converge (``status <= 0``) or yields ``a * b < 0``.
        """
        a_low = a50 * (r_low + 1) / r_low / w50**(r_low + 1)
        h0 = h50 - a_low * w50**r_low
        h_low = h0 + a_low * w_low**r_low
        h_high = swot_wsemax['wse'] + (d_bankfull - d_wsemax) + gap * d_bankfull

        xdata = np.insert(df['width'].values, 0, [w_low, w_high])
        ydata = np.insert(df['wse'].values, 0, [h_low, h_high])
        a_default = (h_high - h0) / w_high**2

        n_swot = len(df)

        def loss_wrapper(z):
            return self.loss_function(z, weight, n_swot)

        try:
            ls = least_squares(
                self.power_function,
                x0=[h0, a_default, 2],
                loss=loss_wrapper,
                args=(xdata, ydata),
                max_nfev=100
            )
        except Exception:
            # A single configuration failing (e.g. non-finite residuals) is
            # expected on the grid; skip it without aborting the station.
            return None

        if ls.status <= 0:
            return None

        wse0, a, b = ls.x
        if a * b < 0:
            return None

        return {
            'wse0': wse0, 'a': a, 'b': b,
            'a50': a50, 'w50': w50, 'q50': q50,
            'w_low': w_low, 'w_high': w_high,
            'h_low': h_low, 'h_high': h_high,
            'slp': slp
        }

    def fit_all_stations(self, df_qc):
        """Fit every station of ``df_qc`` in a process pool.

        The fitter and ``df_qc`` are published through the module-level
        ``_GLOBAL_FITTER`` / ``_GLOBAL_QC_DATA`` slots read by
        ``_fit_station_wrapper``.

        Returns
        -------
        Concatenated DataFrame of all station fits, or None when no station
        could be fitted.
        """
        global _GLOBAL_FITTER, _GLOBAL_QC_DATA

        _GLOBAL_FITTER = self
        _GLOBAL_QC_DATA = df_qc

        try:
            unique_stations = df_qc['stationid'].unique()
            with mp.Pool(processes=N_WORKERS) as pool:
                results = pool.map(_fit_station_wrapper, unique_stations)
        finally:
            # Always drop the shared references, even when the pool fails.
            _GLOBAL_FITTER = None
            _GLOBAL_QC_DATA = None

        results = [df for df in results if df is not None]

        if results:
            return pd.concat(results, ignore_index=True)
        return None

# ============================================================================
# 模块5: 水位-面积曲线生成
# ============================================================================
class HypsometricCurveGenerator:
    """Build width-elevation-area curves from the fitted parameters."""

    @staticmethod
    def generate_curves(df_fit, n_points=100):
        """Generate the median width-WSE-area curve of every station.

        For each station the fitted power curves of all its configurations
        are evaluated on ``n_points`` evenly spaced widths between ``w_low``
        and ``w_high``; the median, maximum and minimum elevation per width
        are kept and the median curve is integrated into areas with
        ``calculate_areas_numba``. Stations with an invalid width range are
        skipped with a warning.

        Returns
        -------
        DataFrame with ``stationid``, ``width``, ``wse``, ``wse_max``,
        ``wse_min`` and ``area``; empty when no curve could be produced.
        """
        curves = []

        for s in sorted(df_fit['stationid'].unique()):
            fits = df_fit[df_fit['stationid'] == s]
            head = fits.iloc[0]
            w_low, w_high, w50, a50 = (
                head[column] for column in ('w_low', 'w_high', 'w50', 'a50')
            )

            # Boundary check: skip degenerate or inverted width ranges.
            if w_high <= w_low or abs(w_high - w_low) < 1e-6:
                print(f"Warning: Skipping station {s} due to invalid width range (w_low={w_low}, w_high={w_high})")
                continue

            # One column vector per parameter so that broadcasting evaluates
            # every configuration on the whole width grid at once.
            base = fits['wse0'].values[:, np.newaxis]
            scale = fits['a'].values[:, np.newaxis]
            exponent = fits['b'].values[:, np.newaxis]

            w_list = np.linspace(w_low, w_high, n_points)
            heights_all = base + scale * w_list**exponent

            h_list = np.median(heights_all, axis=0)

            curves.append(pd.DataFrame({
                'stationid': s,
                'width': w_list,
                'wse': h_list,
                'wse_max': np.max(heights_all, axis=0),
                'wse_min': np.min(heights_all, axis=0),
                'area': calculate_areas_numba(w_list, h_list, w50, a50)
            }))

        if not curves:
            return pd.DataFrame()
        return pd.concat(curves, ignore_index=True)

# ============================================================================
# 模块6: 验证与评估
# ============================================================================
class ModelValidator:
    """Model validation and performance evaluation."""

    @staticmethod
    def relative_rmse(observed, simulated):
        """Return the RMSE divided by the mean of the observations."""
        mse = mean_squared_error(observed, simulated)
        return np.sqrt(mse) / np.mean(observed)

    def validate(self, df_hypso, df_width, df_val_folder, df_fit, skip_width_filter=False):
        """
        Validate the model for every station of ``df_hypso`` in a process pool.

        Parameters
        ----------
        df_hypso : DataFrame
            Hypsometric curves; its ``stationid`` values define the stations.
        df_width : DataFrame
            Width time series handed to each worker.
        df_val_folder : str
            Folder containing one ``<stationid>.csv`` per station.
        df_fit : DataFrame
            Fit parameters handed to each worker.
        skip_width_filter : bool
            Skip the width filter (used in datemean mode).

        Returns
        -------
        Concatenated per-station validation tables, or an empty DataFrame when
        no station produced a result.
        """
        start_date = pd.to_datetime('1979-01-01')

        # One argument tuple per station, in sorted station order.
        tasks = []
        for station in sorted(df_hypso['stationid'].unique()):
            tasks.append(
                (station, df_hypso, df_width, df_val_folder, df_fit,
                 start_date, skip_width_filter)
            )

        with mp.Pool(processes=N_WORKERS) as pool:
            outputs = pool.map(_validate_station_wrapper, tasks)

        frames = [frame for frame in outputs if frame is not None]
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

# ============================================================================
# 配置运行函数
# ============================================================================
def run_configuration(a, b, c, d, e, common_data):
    """
    Run the full pipeline for a single configuration.

    Parameters
    ----------
    a : str
        Processing mode: 'node' or 'datemean'.
    b : str
        Percentile option: '0', '3', '5', '10'.
    c : str
        QA option: 'noqa', 'qaloose', 'qastrict'.
    d : str
        Version option: 'VersionD', 'VersionC'.
    e : int
        Rolling-window size, e.g. 3, 5, 7 (only used in datemean mode; None
        in node mode).
    common_data : dict
        Shared inputs: 'df_l8', 'df_comid', 'df_attrs', 'df_width'. An
        optional 'val_folder' entry overrides the folder holding the daily
        discharge CSV files (default '/home/xj/device5/data/daily_Q').

    Raises
    ------
    ValueError
        If ``a`` is not a known processing mode or ``b`` is not a known
        percentile option.
    """
    global _GLOBAL_SMOOTH_WINDOW

    # Fail before any file is written; an unknown mode used to surface only
    # later as a NameError on the never-assigned QC table.
    if a not in ('node', 'datemean'):
        raise ValueError(f"Invalid A option: {a}. Valid options: ['node', 'datemean']")

    # Configuration tag used in every output file name.
    if a == 'datemean':
        config_str = f'{a}_{b}_{c}_{d}_w{e}'
    else:
        config_str = f'{a}_{b}_{c}_{d}'

    print(f"\n{'='*60}")
    print(f"Running configuration: A={a}, B={b}, C={c}, D={d}, E={e}")
    print(f"{'='*60}")

    gc.collect()

    # Step 1: width statistics
    print("Step 1: Calculating width statistics...")
    df_l8 = common_data['df_l8']
    df_w_stats = WidthStatistics.calculate_width_percentiles(df_l8)

    low_col, high_col = WidthStatistics.get_percentile_columns(b)
    df_w_stats['w_low'] = df_w_stats[low_col]
    df_w_stats['w_high'] = df_w_stats[high_col]

    df_w_stats.to_csv(f'1.width_statistic_{config_str}.csv', index=False)

    df_comid = common_data['df_comid']
    df_attrs = common_data['df_attrs']

    # Load the SWOT data matching the QA option.
    if c == 'noqa':
        # With noqa the file name depends only on the version.
        df_swot = pd.read_csv(f'1.all_matched_points_{d}.csv')
    else:
        # With qaloose / qastrict the file name also depends on the mode.
        df_swot = pd.read_csv(f'2.swot_{a}_{c}_{d}.csv')

    df_swot = df_swot.merge(df_comid, on='stationid', how='inner')

    if a == 'node':
        print("Step 2: Selecting best nodes...")
        df_swot_filtered, _ = NodeSelector.select_best_nodes(df_swot, min_data_points=10)

        print("Step 3: Applying quality control...")
        qc = DataQualityControl(df_w_stats)
        df_qc = qc.apply_qc(df_swot_filtered, draw_figure=False)

    else:
        print(f"Using smoothed data (skipping width filter, window={e})...")

        # Publish the window size before the pool is created so that the
        # workers read the value of this configuration.
        _GLOBAL_SMOOTH_WINDOW = e

        # Rolling median per station, computed in a process pool.
        groups = [group for _, group in df_swot.groupby('stationid')]

        with mp.Pool(processes=N_WORKERS) as pool:
            results = pool.map(_rolling_median_group, groups)

        df_qc = pd.concat(results)

    if 'COMID' not in df_qc.columns:
        df_qc = df_qc.merge(df_comid, on='stationid', how='left')

    # Put COMID first in the exported table.
    cols = ['COMID'] + [col for col in df_qc.columns if col != 'COMID']
    df_qc = df_qc[cols]
    df_qc.to_csv(f'2.swot-points-selection_{config_str}.csv', index=False)

    # Step 4: fitting
    print("Step 4: Fitting hydraulic curves...")
    # datemean mode skips the width filter in both fitting and validation.
    skip_width_filter = (a == 'datemean')
    fitter = HydraulicCurveFitter(df_w_stats, df_attrs, skip_width_filter=skip_width_filter)
    df_fit_all = fitter.fit_all_stations(df_qc)

    if df_fit_all is None or len(df_fit_all) == 0:
        print(f"No fit data for {config_str}")
        return

    df_fit_all.to_csv(f'3.fit_proba_modified_q50_{config_str}.csv', index=False)

    # Step 5: curve generation
    print("Step 5: Generating hypsometric curves...")
    df_hypso = HypsometricCurveGenerator.generate_curves(df_fit_all)

    if df_hypso is None or len(df_hypso) == 0:
        print(f"No hypsometric curves generated for {config_str}")
        return

    df_hypso.to_csv(f'4.hypso_med_modified_q50_{config_str}.csv', index=False)

    # Step 6: validation
    print("Step 6: Validating model...")
    validator = ModelValidator()
    df_width = common_data['df_width']
    val_folder = common_data.get('val_folder', '/home/xj/device5/data/daily_Q')

    df_results = validator.validate(
        df_hypso, df_width,
        val_folder,
        df_fit_all,
        skip_width_filter=skip_width_filter
    )

    if df_results is None or len(df_results) == 0:
        print(f"No validation results for {config_str}")
        return

    df_results.to_csv(f'5.q_kge_med_modified_q50_{config_str}.csv', index=False)

    print(f"Configuration {config_str} completed!")
    gc.collect()

# ============================================================================
# 配置生成函数
# ============================================================================
def generate_configs():
    """
    Build every valid (A, B, C, D, E) configuration tuple.

    Rules:
    - 'noqa' is combined with 'node' only.
    - 'qaloose' and 'qastrict' are combined with every A option.
    - 'datemean' is expanded over the rolling-window sizes 3, 5 and 7 (E).
    - 'node' always gets E = None.
    """
    modes = ['datemean', 'node']
    percentiles = ['0', '3', '5', '10']
    qa_levels = ['noqa', 'qaloose', 'qastrict']
    versions = ['VersionD', 'VersionC']
    windows = [3, 5, 7]  # rolling-window sizes

    configs = []
    for a, b, c, d in itertools.product(modes, percentiles, qa_levels, versions):
        # noqa is only meaningful in node mode
        if c == 'noqa' and a != 'node':
            continue
        window_choices = windows if a == 'datemean' else [None]
        configs.extend((a, b, c, d, e) for e in window_choices)

    return configs

# ============================================================================
# 主程序
# ============================================================================
def main():
    """Run every configuration, then draw one comparison boxplot per metric."""
    import time
    total_start = time.time()

    print("Loading common data...")
    # Each input file is read once; the second table built from the same file
    # is derived from the first instead of parsing the CSV again.
    df_l8 = pd.read_csv('../2-preprocess/1.gages3000_glow_datemean_width_timeseries.csv')
    df_attrs = pd.read_csv('../2-preprocess/4.q50_weighted_slp.csv')
    df_comid = df_attrs[['stationid', 'COMID']].copy()
    df_width = df_l8.copy()
    df_width['date'] = pd.to_datetime(df_width['date'])

    common_data = {
        'df_l8': df_l8,
        'df_comid': df_comid,
        'df_attrs': df_attrs,
        'df_width': df_width
    }

    # Build the valid configurations.
    configs = generate_configs()

    print(f"\nTotal configurations to run: {len(configs)}")
    print("Configurations:")
    for cfg in configs:
        print(f"  {cfg}")

    # Run every configuration.
    for a, b, c, d, e in configs:
        start = time.time()
        run_configuration(a, b, c, d, e, common_data)
        print(f"Time for ({a}, {b}, {c}, {d}, {e}): {time.time() - start:.2f}s")

    # Collect the per-station metrics of every configuration that produced
    # a result file.
    print("\nGenerating boxplot comparisons...")
    metrics = ['kge', 'nse', 'nrmse']
    data_dict = {metric: [] for metric in metrics}
    labels = []

    for a, b, c, d, e in configs:
        # Same file-name tag as used by run_configuration.
        if a == 'datemean':
            config_str = f'{a}_{b}_{c}_{d}_w{e}'
        else:
            config_str = f'{a}_{b}_{c}_{d}'

        file = f'5.q_kge_med_modified_q50_{config_str}.csv'
        if os.path.exists(file):
            df = pd.read_csv(file)
            labels.append(config_str)
            for metric in metrics:
                station_metrics = df.groupby('stationid')[metric].mean().values
                data_dict[metric].append(station_metrics)

    for metric in metrics:
        if data_dict[metric]:
            fig, ax = plt.subplots(figsize=(20, 8))
            # Label the ticks after plotting rather than through boxplot's
            # `labels` keyword, which newer Matplotlib releases deprecate in
            # favour of `tick_labels`; this form works on both.
            ax.boxplot(data_dict[metric])
            ax.set_xticklabels(labels, rotation=90, ha='center')
            ax.set_title(f'{metric.upper()} Boxplot Comparison')
            ax.set_xlabel('Configuration (A_B_C_D_E)')
            ax.set_ylabel(metric.upper())
            plt.tight_layout()
            plt.savefig(f'boxplot_{metric}.png', dpi=150)
            plt.close(fig)

    print(f"\nTotal time: {time.time() - total_start:.2f}s")

if __name__ == '__main__':
    main()

Using 64 workers for parallel processing
Loading common data...

Total configurations to run: 72
Configurations:
  ('datemean', '0', 'qaloose', 'VersionD', 3)
  ('datemean', '0', 'qaloose', 'VersionD', 5)
  ('datemean', '0', 'qaloose', 'VersionD', 7)
  ('datemean', '0', 'qaloose', 'VersionC', 3)
  ('datemean', '0', 'qaloose', 'VersionC', 5)
  ('datemean', '0', 'qaloose', 'VersionC', 7)
  ('datemean', '0', 'qastrict', 'VersionD', 3)
  ('datemean', '0', 'qastrict', 'VersionD', 5)
  ('datemean', '0', 'qastrict', 'VersionD', 7)
  ('datemean', '0', 'qastrict', 'VersionC', 3)
  ('datemean', '0', 'qastrict', 'VersionC', 5)
  ('datemean', '0', 'qastrict', 'VersionC', 7)
  ('datemean', '3', 'qaloose', 'VersionD', 3)
  ('datemean', '3', 'qaloose', 'VersionD', 5)
  ('datemean', '3', 'qaloose', 'VersionD', 7)
  ('datemean', '3', 'qaloose', 'VersionC', 3)
  ('datemean', '3', 'qaloose', 'VersionC', 5)
  ('datemean', '3', 'qaloose', 'VersionC', 7)
  ('datemean', '3', 'qastrict', 'VersionD', 3)
  ('da