# In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress, spearmanr
from scipy.optimize import least_squares
from sklearn.metrics import mean_squared_error
import os
import warnings
import multiprocessing as mp
from functools import partial
import itertools
from numba import jit, prange
import gc

# Silence every warning for the whole run.
# NOTE(review): this also hides pandas/scipy/numba warnings that may point at
# real data problems — consider narrowing the filter.
warnings.filterwarnings('ignore')

# ============================================================================
# Global configuration
# ============================================================================
# Size of every multiprocessing pool created in this file: one less than the
# CPU count, capped at 64 and floored at 1. Without the floor a single-CPU
# host yields 0 and mp.Pool(processes=0) raises ValueError.
# NOTE(review): the original carried a trailing "# 80" comment — presumably
# the core count of the target machine; confirm.
N_WORKERS = max(1, min(mp.cpu_count() - 1, 64))
print(f"Using {N_WORKERS} workers for parallel processing")

# ============================================================================
# Module-level globals (used to hand data to pool workers)
# ============================================================================
# The pool worker functions in this file read these names instead of receiving
# the data as task arguments.
# NOTE(review): workers only see values assigned in the parent if the process
# start method copies parent memory (e.g. "fork") — confirm on other platforms.
_GLOBAL_NODE_DATA = {}
_GLOBAL_WIDTH_STATS = None
_GLOBAL_SWOT_DATA = None
_GLOBAL_FITTER = None
_GLOBAL_QC_DATA = None

# Added later: global intended for collecting diagnostic information.
_GLOBAL_DROP_REASONS = None

# ============================================================================
# Numba加速的核心计算函数
# ============================================================================
@jit(nopython=True, parallel=True, cache=True)
def compute_inconsistency_matrix(w, h):
    """Count, for every point, how many other points it is discordant with.

    Points ``i`` and ``j`` are discordant when their width difference and
    height difference have strictly opposite signs (product < 0); ties in
    either variable do not count.

    Parameters
    ----------
    w, h : 1-D arrays of equal length (widths and heights).

    Returns
    -------
    int64 array where element ``i`` is the number of ``j`` discordant with ``i``.
    """
    total = len(w)
    inverse = np.zeros(total, dtype=np.int64)
    # Outer loop is the numba-parallel one; each iteration writes only its
    # own slot of the output.
    for i in prange(total):
        wi = w[i]
        hi = h[i]
        discordant = 0
        for j in range(total):
            if (wi - w[j]) * (hi - h[j]) < 0:
                discordant += 1
        inverse[i] = discordant
    return inverse

@jit(nopython=True, cache=True)
def calculate_areas_numba(w_list, h_list, w50, a50):
    """Integrate a width-height curve into areas, anchored at (w50, a50).

    The height at ``w50`` is linearly interpolated between the two curve
    samples bracketing it. The area is ``a50`` at that point and is
    accumulated outwards in both directions with the trapezoid rule
    ``0.5 * (w_a + w_b) * (h_b - h_a)``.

    Parameters
    ----------
    w_list, h_list : 1-D arrays of equal length; ``w_list`` is passed to
        ``np.searchsorted`` (presumably ascending — confirm at the caller).
    w50 : width at which the area is known.
    a50 : area at ``w50``.

    Returns
    -------
    Array of areas, one per sample; all NaN when fewer than two samples.
    """
    n_pts = len(w_list)
    areas = np.full(n_pts, np.nan)

    # Nothing to interpolate between with fewer than two samples.
    if n_pts < 2:
        return areas

    # Upper index of the bracketing pair, clamped into [1, n_pts - 1].
    hi = max(min(np.searchsorted(w_list, w50), n_pts - 1), 1)
    lo = hi - 1

    span = w_list[hi] - w_list[lo]
    if abs(span) < 1e-10:
        # Bracketing widths are numerically equal: average the two heights
        # instead of dividing by ~0.
        h50 = (h_list[lo] + h_list[hi]) / 2.0
    else:
        h50 = (h_list[lo] * (w_list[hi] - w50) +
               h_list[hi] * (w50 - w_list[lo])) / span

    # Seed: area at the upper bracketing sample.
    areas[hi] = a50 + 0.5 * (w50 + w_list[hi]) * (h_list[hi] - h50)

    # Walk towards larger widths.
    for k in range(hi + 1, n_pts):
        slab = 0.5 * (w_list[k - 1] + w_list[k]) * (h_list[k] - h_list[k - 1])
        areas[k] = areas[k - 1] + slab

    # Walk towards smaller widths.
    for k in range(hi - 1, -1, -1):
        slab = 0.5 * (w_list[k + 1] + w_list[k]) * (h_list[k + 1] - h_list[k])
        areas[k] = areas[k + 1] - slab

    return areas

@jit(nopython=True, cache=True)
def nse_numba(observed, simulated):
    """Nash-Sutcliffe efficiency: 1 - SS_residual / SS_total.

    Returns NaN when the observations have zero variance about their mean.
    """
    residual_ss = np.sum((observed - simulated) ** 2)
    total_ss = np.sum((observed - np.mean(observed)) ** 2)
    if total_ss == 0:
        return np.nan
    return 1 - residual_ss / total_ss

@jit(nopython=True, cache=True)
def kge_numba(observed, simulated):
    """Kling-Gupta efficiency built from correlation, mean ratio and CV ratio.

    Computes ``1 - sqrt((r - 1)^2 + (mean_ratio - 1)^2 + (cv_ratio - 1)^2)``
    where ``r`` is the Pearson correlation (population covariance over
    population standard deviations), ``mean_ratio = mean(sim) / mean(obs)``
    and ``cv_ratio`` is the ratio of coefficients of variation (sim over obs).

    Returns NaN if either series has zero standard deviation or zero mean.
    """
    mu_obs = np.mean(observed)
    mu_sim = np.mean(simulated)
    sd_obs = np.std(observed)
    sd_sim = np.std(simulated)

    # Any of these being zero makes one of the three ratios undefined.
    if sd_obs == 0 or sd_sim == 0 or mu_obs == 0 or mu_sim == 0:
        return np.nan

    n = len(observed)
    covariance = np.sum((observed - mu_obs) * (simulated - mu_sim)) / n
    r = covariance / (sd_obs * sd_sim)

    mean_ratio = mu_sim / mu_obs
    cv_ratio = (sd_sim / mu_sim) / (sd_obs / mu_obs)  # ratio of coefficients of variation

    return 1 - np.sqrt((r - 1)**2 + (mean_ratio - 1)**2 + (cv_ratio - 1)**2)

# ============================================================================
# 全局并行处理函数（必须在模块级别定义才能被pickle）
# ============================================================================
def _compute_node_corr(node_id):
    """Spearman rank correlation between width and WSE for one node.

    Reads the node's arrays from the module global ``_GLOBAL_NODE_DATA``.

    Parameters
    ----------
    node_id : key into ``_GLOBAL_NODE_DATA``.

    Returns
    -------
    tuple ``(node_id, stationid, corr)``. ``stationid`` is ``None`` when the
    node is unknown. ``corr`` is 0.0 when the node is unknown, has fewer than
    five samples, the correlation is NaN, or the computation fails.
    """
    global _GLOBAL_NODE_DATA
    data = _GLOBAL_NODE_DATA.get(node_id)
    if data is None:
        return (node_id, None, 0.0)
    if len(data['width']) < 5:
        return (node_id, data['stationid'], 0.0)

    try:
        corr, _ = spearmanr(data['width'], data['wse'])
        if np.isnan(corr):
            corr = 0.0
    except Exception:
        # Best effort: a failing node simply scores 0. Deliberately not a bare
        # ``except`` so KeyboardInterrupt/SystemExit still reach the worker.
        corr = 0.0

    return (node_id, data['stationid'], corr)

def _process_qc_station(stationid):
    """Run the three-step quality control for one station (pool worker).

    Reads the module globals ``_GLOBAL_WIDTH_STATS`` (indexed by station id)
    and ``_GLOBAL_SWOT_DATA``.

    Steps
    -----
    1. Uncertainty screening: keep rows with ``wse_u <= 0.4`` and relative
       width uncertainty ``width_u / width <= 0.1``.
    2. Order-consistency pruning via ``_remove_inconsistent_points``.
    3. Outlier removal: keep rows whose WSE lies within one bankfull depth
       ``0.27 * (w_high / 7.2) ** 0.6`` of the station's median WSE.

    Returns
    -------
    The filtered DataFrame (including the helper column ``width_u_r``), or
    ``None`` if the station is unknown or fewer than five rows survive any
    step.
    """
    global _GLOBAL_WIDTH_STATS, _GLOBAL_SWOT_DATA

    min_points = 5  # minimum rows required after every step

    if _GLOBAL_WIDTH_STATS is None or stationid not in _GLOBAL_WIDTH_STATS.index:
        return None

    df = _GLOBAL_SWOT_DATA[_GLOBAL_SWOT_DATA['stationid'] == stationid].copy()
    if len(df) < min_points:
        return None

    # Step 1: uncertainty screening.
    df['width_u_r'] = df['width_u'] / df['width']
    df1 = df[(df['wse_u'] <= 0.4) & (df['width_u_r'] <= 0.1)]
    if len(df1) < min_points:
        return None

    # Step 2: order-consistency pruning.
    df2 = _remove_inconsistent_points(df1)
    if len(df2) < min_points:
        return None

    # Step 3: outlier removal around the median WSE. Only w_high is needed
    # here; the previously fetched w_low was never used.
    w_high = _GLOBAL_WIDTH_STATS.loc[stationid, 'w_high']
    d_bankfull = 0.27 * (w_high / 7.2) ** 0.6
    h50 = df2['wse'].median()

    df3 = df2[(df2['wse'] <= h50 + d_bankfull) & (df2['wse'] >= h50 - d_bankfull)]

    return df3 if len(df3) >= min_points else None

def _remove_inconsistent_points(df, inverse_ratio_thresh=0.5):
    """Iteratively drop the point most at odds with the width-WSE ordering.

    On each pass the discordance count of every remaining point is computed
    with ``compute_inconsistency_matrix``. The point with the largest count is
    removed if ``count / n_remaining >= inverse_ratio_thresh``. The loop stops
    when no point reaches the threshold or fewer than five points remain.

    Rows are tracked by position rather than by index label, so a DataFrame
    with a non-unique index is handled correctly (label-based ``.loc`` would
    return duplicated rows). The width/WSE arrays are extracted once instead
    of re-slicing the DataFrame on every pass.

    Parameters
    ----------
    df : DataFrame with ``width`` and ``wse`` columns.
    inverse_ratio_thresh : float, discordance ratio at or above which the
        worst point is removed.

    Returns
    -------
    DataFrame containing the surviving rows in their original order.
    """
    widths = df['width'].to_numpy(dtype=np.float64)
    heights = df['wse'].to_numpy(dtype=np.float64)
    keep = np.arange(len(df))

    while len(keep) >= 5:
        inverse = compute_inconsistency_matrix(widths[keep], heights[keep])

        worst = int(np.argmax(inverse))
        if inverse[worst] / len(keep) < inverse_ratio_thresh:
            break

        keep = np.delete(keep, worst)

    return df.iloc[keep]

def _fit_station_wrapper(stationid):
    """Pool worker: fit one station and return ``(result, drop_info)``.

    Looks the station up in the module global ``_GLOBAL_QC_DATA`` and
    delegates to ``_GLOBAL_FITTER.fit_station_with_diagnostics``. The COMID is
    taken from the station's first row. When the station has no rows, the
    result is ``None`` and the diagnostic dict carries reason
    ``'no_data_in_qc'``.
    """
    global _GLOBAL_FITTER, _GLOBAL_QC_DATA

    station_mask = _GLOBAL_QC_DATA['stationid'] == stationid
    station_rows = _GLOBAL_QC_DATA[station_mask]

    if not len(station_rows):
        missing = {
            'stationid': stationid,
            'reason': 'no_data_in_qc',
            'details': 'Station not found in QC data',
        }
        return (None, missing)

    first_row = station_rows.iloc[0]
    fitted, diagnostics = _GLOBAL_FITTER.fit_station_with_diagnostics(
        station_rows, stationid, first_row['COMID']
    )
    return (fitted, diagnostics)

import glob  # 在文件顶部添加

def _validate_station_wrapper(args):
    """Pool worker: validate estimated discharge for one station.

    Parameters
    ----------
    args : tuple ``(s, df_hypso, df_width, df_val_folder, df_fit,
        skip_width_filter)``
        s : station id.
        df_hypso : width-area curves (columns used: stationid, width, area).
        df_width : observed widths (columns used: stationid, date, width).
        df_val_folder : folder holding one validation CSV per station; the
            CSV must provide ``date`` and ``Q`` columns.
        df_fit : fit table (columns used: stationid, w_low, w_high, slp).
        skip_width_filter : if true, widths outside [w_low, w_high] are kept.

    Returns
    -------
    DataFrame with columns stationid, date, width, area_hypso, qobs, Q_est,
    kge, nse, nrmse (metrics repeated on every row), or ``None`` when the
    file, fit or curve is missing, fewer than 10 rows remain, or any error
    occurs (the error is printed).
    """
    s, df_hypso, df_width, df_val_folder, df_fit, skip_width_filter = args

    # Locate the validation file. Prefer the exact "<station>.csv": a bare
    # prefix glob also matches other stations whose id starts with `s`
    # (e.g. "12" vs "123.csv"). Fall back to the prefix glob for files with a
    # suffix after the id, sorted so the choice is deterministic.
    exact_path = os.path.join(df_val_folder, f'{s}.csv')
    if os.path.isfile(exact_path):
        file_path = exact_path
    else:
        file_pattern = os.path.join(df_val_folder, f'{s}*.csv')
        matching_files = sorted(glob.glob(file_pattern))
        if not matching_files:
            return None
        file_path = matching_files[0]

    try:
        df_val = pd.read_csv(file_path)
        df_val = df_val.rename(columns={'Q': 'qobs'})
        df_val['stationid'] = s
        df_val = df_val.dropna(subset=['qobs'])

        # Normalise both date columns to datetime so the merge keys match.
        df_val['date'] = pd.to_datetime(df_val['date'])

        df_width_s = df_width[df_width['stationid'] == s].copy()
        df_width_s['date'] = pd.to_datetime(df_width_s['date'])

        df_val = df_val.merge(df_width_s, on=['stationid', 'date'], how='inner')

        df_curve = df_hypso[df_hypso['stationid'] == s].reset_index(drop=True)
        station_fit = df_fit[df_fit['stationid'] == s]
        if station_fit.empty or df_curve.empty:
            return None

        row = station_fit.iloc[0]
        w_low, w_high, slp = row['w_low'], row['w_high'], row['slp']

        # If both inputs had a "width" column the merge suffixes them; in that
        # case "width_x" (the validation file's column) is used.
        width_col = 'width_x' if 'width_x' in df_val.columns else 'width'

        if skip_width_filter:
            df_val = df_val.drop_duplicates('date')
        else:
            df_val = df_val[
                (df_val[width_col] >= w_low) &
                (df_val[width_col] <= w_high)
            ].drop_duplicates('date')

        if len(df_val) < 10:
            return None

        curve_width = df_curve['width'].values
        curve_area = df_curve['area'].values
        val_width = df_val[width_col].values

        # Area from the curve; np.interp clamps widths outside the curve range
        # to the end values.
        area_hypso = np.interp(val_width, curve_width, curve_area)

        df_val['area_hypso'] = area_hypso
        # Q = A^(5/3) * W^(-2/3) * S^(1/2) / 0.035
        df_val['Q_est'] = (area_hypso**(5/3) * val_width**(-2/3) * slp**0.5 / 0.035)

        # NOTE(review): this drops rows with NaN in ANY column, including
        # merged columns that are not used below — confirm that is intended.
        df_val = df_val.dropna()
        if len(df_val) < 10:
            return None

        obs = df_val['qobs'].values.astype(np.float64)
        sim = df_val['Q_est'].values.astype(np.float64)

        kge_val = kge_numba(obs, sim)
        nse_val = nse_numba(obs, sim)
        rmse = np.sqrt(np.mean((obs - sim)**2))
        nrmse_val = rmse / np.mean(obs)

        df_val['kge'] = kge_val
        df_val['nse'] = nse_val
        df_val['nrmse'] = nrmse_val

        # Always expose the width column as plain "width".
        return df_val[['stationid', 'date', width_col, 'area_hypso',
                       'qobs', 'Q_est', 'kge', 'nse', 'nrmse']].rename(
                           columns={width_col: 'width'})
    except Exception as e:
        print(f"Error processing station {s}: {e}")
        return None


def _rolling_median_group(group):
    """Smooth ``width`` and ``wse`` of one group with a rolling median.

    The group is sorted by ``date`` first; each of the two columns is then
    replaced by its centred 5-sample rolling median (windows are truncated at
    the edges, at least one sample required). Returns the sorted, smoothed
    frame.
    """
    ordered = group.sort_values('date')
    for column in ('width', 'wse'):
        window = ordered[column].rolling(window=5, center=True, min_periods=1)
        ordered[column] = window.median()
    return ordered

# ============================================================================
# 模块1: 数据统计工具 (修改为IQR方法)
# ============================================================================
class WidthStatistics:
    """Per-station river-width statistics based on the inter-quartile range."""

    # IQR multiplier options: {option key: k}, with
    #   w_low  = Q1 - k * IQR
    #   w_high = Q3 + k * IQR
    IQR_CONFIG = {
        '1.0': 1.0,
        '1.5': 1.5,
        '2.0': 2.0,
        '2.5': 2.5,
        '3.0': 3.0,
        '4.0': 4.0
    }

    @staticmethod
    def calculate_width_iqr(df, min_width=30, valid_ratio=0.95, min_iqr=5):
        """Compute the IQR-based width range of every station.

        A station is kept only if, among its non-NaN widths, the share that
        is ``>= min_width`` is at least ``valid_ratio``, more than 10 such
        widths exist, and their IQR is at least ``min_iqr``. Stations
        rejected for a small IQR are reported on stdout (first 10 listed).

        Parameters
        ----------
        df : DataFrame
            Input data with ``stationid`` and ``width`` columns.
        min_width : float
            Minimum valid width; also the floor applied to every ``w_low``.
        valid_ratio : float
            Minimum share of widths that must be ``>= min_width``.
        min_iqr : float
            Minimum IQR; stations below it are skipped.

        Returns
        -------
        DataFrame with columns ``stationid, w50, Q1, Q3, IQR`` followed by
        ``w_low_iqr<key>, w_high_iqr<key>`` for every key in ``IQR_CONFIG``.
        The columns are present even when no station qualifies, so callers
        can select them without a KeyError.
        """
        columns = ['stationid', 'w50', 'Q1', 'Q3', 'IQR']
        for key in WidthStatistics.IQR_CONFIG:
            columns.extend([f'w_low_iqr{key}', f'w_high_iqr{key}'])

        result_data = []
        skipped_stations = []

        # One pass over the data instead of one boolean mask per station.
        # sort=False keeps stations in order of first appearance.
        for stationid, widths in df.groupby('stationid', sort=False)['width']:
            station_data_all = widths.dropna()
            if len(station_data_all) == 0:
                continue

            station_data = station_data_all[station_data_all >= min_width]
            if len(station_data) / len(station_data_all) < valid_ratio:
                continue
            if len(station_data) <= 10:
                continue

            w50 = station_data.median()

            # Quartiles and IQR.
            q1 = station_data.quantile(0.25)
            q3 = station_data.quantile(0.75)
            iqr = q3 - q1

            # Q1 and Q3 too close together: skip the station.
            if iqr < min_iqr:
                skipped_stations.append((stationid, q1, q3, iqr))
                continue

            row_data = {
                'stationid': stationid,
                'w50': w50,
                'Q1': q1,
                'Q3': q3,
                'IQR': iqr
            }

            # Width range for every configured IQR multiplier.
            for key, k in WidthStatistics.IQR_CONFIG.items():
                row_data[f'w_low_iqr{key}'] = max(q1 - k * iqr, min_width)  # never below min_width
                row_data[f'w_high_iqr{key}'] = q3 + k * iqr

            result_data.append(row_data)

        # Report skipped stations.
        if skipped_stations:
            print(f"Skipped {len(skipped_stations)} stations due to small IQR (< {min_iqr}):")
            for sid, q1, q3, iqr in skipped_stations[:10]:  # first 10 only
                print(f"  Station {sid}: Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}")
            if len(skipped_stations) > 10:
                print(f"  ... and {len(skipped_stations) - 10} more stations")

        return pd.DataFrame(result_data, columns=columns)

    @staticmethod
    def get_iqr_columns(b_option):
        """Return the ``(w_low, w_high)`` column names for an IQR option key.

        Raises
        ------
        ValueError
            If ``b_option`` is not a key of ``IQR_CONFIG``.
        """
        if b_option not in WidthStatistics.IQR_CONFIG:
            raise ValueError(f"Invalid B option: {b_option}. Valid options: {list(WidthStatistics.IQR_CONFIG.keys())}")

        return f'w_low_iqr{b_option}', f'w_high_iqr{b_option}'

    # Old entry point kept for compatibility.
    @staticmethod
    def calculate_width_percentiles(df, min_width=30, valid_ratio=0.95, min_iqr=5):
        """Compatibility alias that forwards to ``calculate_width_iqr``."""
        return WidthStatistics.calculate_width_iqr(df, min_width, valid_ratio, min_iqr)

# ============================================================================
# 模块2: 节点选择
# ============================================================================
class NodeSelector:
    """Pick the best node for every station."""

    @staticmethod
    def select_best_nodes(df_swot, min_data_points=10):
        """Keep, per station, the node with the largest width-WSE rank correlation.

        Nodes with fewer than ``min_data_points`` rows are discarded first.
        The per-node Spearman correlations are computed in a process pool by
        ``_compute_node_corr``, which reads the node arrays from the module
        global ``_GLOBAL_NODE_DATA``. A node is attributed to the station id
        of its first row.

        Parameters
        ----------
        df_swot : DataFrame with ``node_id``, ``stationid``, ``width`` and
            ``wse`` columns.
        min_data_points : int, minimum rows a node needs to be considered.

        Returns
        -------
        tuple ``(df_filtered, df_node_rmax)``: the rows of ``df_swot`` that
        belong to a selected node, and one row per station with columns
        ``node_id, stationid, rank_corr``.
        """
        global _GLOBAL_NODE_DATA

        # Drop nodes with too few observations.
        node_counts = df_swot.groupby('node_id').size()
        valid_nodes = node_counts[node_counts >= min_data_points].index
        df_swot = df_swot[df_swot['node_id'].isin(valid_nodes)].copy()

        # Pre-extract per-node arrays for the workers.
        _GLOBAL_NODE_DATA = {
            node_id: {
                'width': group['width'].values,
                'wse': group['wse'].values,
                'stationid': group['stationid'].iloc[0]
            }
            for node_id, group in df_swot.groupby('node_id')
        }
        node_ids = list(_GLOBAL_NODE_DATA.keys())

        try:
            with mp.Pool(processes=N_WORKERS) as pool:
                results = pool.map(_compute_node_corr, node_ids)
        finally:
            # Release the shared cache even if the pool fails, so a later
            # call never sees stale node data.
            _GLOBAL_NODE_DATA = {}

        df_node = pd.DataFrame(results, columns=['node_id', 'stationid', 'rank_corr'])
        df_node = df_node.dropna(subset=['stationid'])

        # Per station, the node with the largest rank correlation.
        max_idx = df_node.groupby('stationid')['rank_corr'].idxmax()
        df_node_rmax = df_node.loc[max_idx]

        # Keep only rows belonging to a selected node.
        df_filtered = df_swot[df_swot['node_id'].isin(df_node_rmax['node_id'])]

        print(f"Original nodes: {len(df_node)}, Selected nodes: {len(df_node_rmax)}")

        return df_filtered, df_node_rmax

# ============================================================================
# 模块3: 数据质控
# ============================================================================
class DataQualityControl:
    """Quality control of SWOT observations."""

    def __init__(self, width_stats):
        """Store the width statistics indexed by ``stationid``."""
        self.width_stats = width_stats.set_index('stationid')

    def apply_qc(self, df_swot, draw_figure=False, output_folder=None):
        """Run the per-station QC in a process pool and merge the survivors.

        The width statistics and ``df_swot`` are published through the module
        globals ``_GLOBAL_WIDTH_STATS`` / ``_GLOBAL_SWOT_DATA`` for the
        ``_process_qc_station`` workers and reset afterwards, also when the
        pool raises.

        Parameters
        ----------
        df_swot : DataFrame with at least ``stationid``, ``node_id`` and
            ``date`` columns, plus whatever ``_process_qc_station`` reads.
        draw_figure, output_folder : accepted for interface compatibility;
            not used by this method.

        Returns
        -------
        DataFrame of the rows that passed QC, de-duplicated on
        ``(node_id, date, stationid)`` with a fresh index; an empty DataFrame
        when no station passes.
        """
        global _GLOBAL_WIDTH_STATS, _GLOBAL_SWOT_DATA

        stationids = df_swot['stationid'].unique()

        _GLOBAL_WIDTH_STATS = self.width_stats
        _GLOBAL_SWOT_DATA = df_swot
        try:
            with mp.Pool(processes=N_WORKERS) as pool:
                results = pool.map(_process_qc_station, stationids)
        finally:
            # Always drop the references so the large frames can be freed.
            _GLOBAL_WIDTH_STATS = None
            _GLOBAL_SWOT_DATA = None

        # Keep only stations that produced data.
        results = [df for df in results if df is not None]

        if not results:
            return pd.DataFrame()

        result = pd.concat(results)
        result = result.drop_duplicates(subset=['node_id', 'date', 'stationid'])
        result.reset_index(drop=True, inplace=True)

        return result

# ============================================================================
# 模块4: 水位-面积曲线拟合 (添加诊断功能)
# ============================================================================
class HydraulicCurveFitter:
    """Fit the stage-width curve ``wse = wse0 + a * width**b`` per station.

    Every station is fitted once for each combination of the three
    hyper-parameter lists ``R_list``, ``GAP_list`` and ``W_list``. Each fit
    adds two synthetic anchor points (at ``w_low`` and ``w_high``) to the
    observations. Diagnostics explaining why a station was dropped are
    produced alongside the results.
    """

    def __init__(self, width_stats, river_attrs, skip_width_filter=False):
        """
        Parameters
        ----------
        width_stats : DataFrame
            Width statistics with columns ``stationid, w50, w_low, w_high``.
        river_attrs : DataFrame
            River attributes with columns ``COMID, q50_weighted, slope``.
        skip_width_filter : bool
            If true, observations outside ``[w_low, w_high]`` are kept
            (used for the "datemean" mode).
        """
        self.width_stats = width_stats.set_index('stationid')
        self.river_attrs = river_attrs.set_index('COMID')
        self.skip_width_filter = skip_width_filter

        # Hyper-parameter grid explored for every station.
        self.R_list = np.array([0.5, 1, 2, 4, 8])
        self.GAP_list = np.array([-0.3,-0.1, 0, 0.1,0.3])
        self.W_list = np.array([0.3, 0.5, 0.7])

        # Per-station diagnostics of the most recent fit_all_stations() run.
        self.drop_reasons = []

    @staticmethod
    def power_function(params, X, y):
        """Residuals ``y - (wse0 + a * X**b)`` for ``params = (wse0, a, b)``."""
        wse0, a, b = params
        return y - (wse0 + a * X**b)

    def loss_function(self, z, weight, n_swot):
        """Custom loss for ``scipy.optimize.least_squares``.

        Returns a ``(3, len(z))`` array holding ``rho(z) = 2*(sqrt(1+z) - 1)``
        and its first and second derivatives. The first two columns — the
        two anchor points that the caller inserts at the front of the data —
        are scaled by ``(n_swot - 2) / weight * (1 - weight) / 2``.
        """
        rho = np.zeros((3, len(z)))
        rho[0] = 2 * ((1 + z)**0.5 - 1)
        rho[1] = (1 + z)**(-0.5)
        rho[2] = -0.5 * (1 + z)**(-1.5)

        factor = (n_swot - 2) / weight * (1 - weight) / 2
        rho[:, 0] *= factor
        rho[:, 1] *= factor

        return rho

    def calculate_h50(self, df, w50):
        """Estimate the water-surface elevation at width ``w50``.

        Uses the (up to) five observations whose width is closest to ``w50``.
        If they have at least two distinct widths and a non-negative
        regression slope, the regression line evaluated at ``w50`` is
        returned; otherwise the mean WSE of those observations.
        """
        df = df.copy()
        df['w50_diff'] = np.abs(df['width'] - w50)
        nearest = df.sort_values('w50_diff').iloc[:5]

        xdata = nearest['width'].values
        ydata = nearest['wse'].values

        if len(np.unique(xdata)) < 2:
            return nearest['wse'].mean()

        res = linregress(xdata, ydata)
        if res.slope >= 0:
            return res.slope * w50 + res.intercept
        return nearest['wse'].mean()

    def fit_station_with_diagnostics(self, df_station, stationid, comid):
        """Fit one station over the whole hyper-parameter grid.

        Returns
        -------
        tuple ``(result, drop_info)``. ``result`` is a DataFrame with one row
        per successful parameter combination, or ``None``. ``drop_info`` is a
        dict with keys ``stationid, COMID, reason, details``; ``reason`` is
        one of ``'success'``, ``'not_in_width_stats'``,
        ``'not_in_river_attrs'``, ``'too_few_points_after_width_filter'``,
        ``'all_fits_failed'`` or ``'exception'``.
        """
        drop_info = {'stationid': stationid, 'COMID': comid, 'reason': None, 'details': None}

        # Check 1: the station must have width statistics.
        if stationid not in self.width_stats.index:
            drop_info['reason'] = 'not_in_width_stats'
            drop_info['details'] = f'Station {stationid} not found in width statistics'
            return None, drop_info

        # Check 2: the reach must have river attributes.
        if comid not in self.river_attrs.index:
            drop_info['reason'] = 'not_in_river_attrs'
            drop_info['details'] = f'COMID {comid} not found in river attributes'
            return None, drop_info

        try:
            q50 = self.river_attrs.loc[comid, 'q50_weighted']
            slp = self.river_attrs.loc[comid, 'slope']
            w50, w_low, w_high = self.width_stats.loc[stationid, ['w50', 'w_low', 'w_high']]
            d_bankfull = 0.27 * (w_high / 7.2)**0.6

            h50 = self.calculate_h50(df_station, w50)
            # Q = A^(5/3) * W^(-2/3) * S^(1/2) / 0.035 solved for A at (q50, w50).
            a50 = (q50 * 0.035 / slp**0.5 * w50**(2/3))**(3/5)

            original_count = len(df_station)

            if self.skip_width_filter:
                # "datemean" mode: no width filtering.
                df_filtered = df_station.copy()
            else:
                # "node" mode: keep widths inside [w_low, w_high].
                df_filtered = df_station[
                    (df_station['width'] >= w_low) &
                    (df_station['width'] <= w_high)
                ]

            filtered_count = len(df_filtered)

            # Check 3: enough points left after the width filter.
            if filtered_count < 3:
                drop_info['reason'] = 'too_few_points_after_width_filter'
                drop_info['details'] = (f'Original points: {original_count}, '
                                        f'After width filter: {filtered_count} (< 3 required), '
                                        f'w_low={w_low:.2f}, w_high={w_high:.2f}, '
                                        f'width range in data: [{df_station["width"].min():.2f}, {df_station["width"].max():.2f}]')
                return None, drop_info

            # Observation with the highest WSE and the depth formula
            # evaluated at its width.
            swot_wsemax = df_filtered.sort_values('wse', ascending=False).iloc[0]
            d_wsemax = 0.27 * (swot_wsemax['width'] / 7.2)**0.6

            results = []
            fit_failures = {'ab_negative': 0, 'ls_failed': 0, 'exception': 0}

            # Same iteration order as three nested loops over R, GAP, W.
            for r_low, gap, weight in itertools.product(
                    self.R_list, self.GAP_list, self.W_list):
                result, fail_reason = self._fit_single_config_with_diagnostics(
                    df_filtered, r_low, gap, weight,
                    w_low, w_high, w50, h50, a50,
                    swot_wsemax, d_bankfull, d_wsemax, slp, q50
                )
                if result is not None:
                    result.update({
                        'stationid': stationid,
                        'COMID': comid,
                        'R': r_low,
                        'GAP': gap,
                        'W': weight
                    })
                    results.append(result)
                elif fail_reason:
                    fit_failures[fail_reason] = fit_failures.get(fail_reason, 0) + 1

            # Check 4: at least one parameter combination must succeed.
            if not results:
                total_configs = len(self.R_list) * len(self.GAP_list) * len(self.W_list)
                drop_info['reason'] = 'all_fits_failed'
                drop_info['details'] = (f'All {total_configs} parameter combinations failed. '
                                        f'Failures: ab_negative={fit_failures["ab_negative"]}, '
                                        f'ls_failed={fit_failures["ls_failed"]}, '
                                        f'exception={fit_failures["exception"]}')
                return None, drop_info

            drop_info['reason'] = 'success'
            drop_info['details'] = f'Successfully fitted with {len(results)} parameter combinations'
            return pd.DataFrame(results), drop_info

        except Exception as e:
            # Any unexpected failure is reported as a diagnostic, not raised.
            drop_info['reason'] = 'exception'
            drop_info['details'] = f'Exception during fitting: {str(e)}'
            return None, drop_info

    def fit_station(self, df_station, stationid, comid):
        """Fit one station; same as ``fit_station_with_diagnostics`` without the diagnostics."""
        result, _ = self.fit_station_with_diagnostics(df_station, stationid, comid)
        return result

    def _fit_single_config_with_diagnostics(self, df, r_low, gap, weight, w_low, w_high, w50,
                          h50, a50, swot_wsemax, d_bankfull, d_wsemax, slp, q50):
        """Fit one ``(r_low, gap, weight)`` combination.

        Returns
        -------
        tuple ``(params, fail_reason)``. On success ``params`` is a dict with
        the fitted ``wse0, a, b`` plus the inputs that describe the fit, and
        ``fail_reason`` is ``None``. On failure ``params`` is ``None`` and
        ``fail_reason`` is ``'ab_negative'`` (``a * b < 0``), ``'ls_failed'``
        (solver status <= 0) or ``'exception'``.
        """
        try:
            # Low anchor: a power law with exponent r_low through (w50, h50),
            # with its coefficient derived from a50.
            a_low = a50 * (r_low + 1) / r_low / w50**(r_low + 1)
            h0 = h50 - a_low * w50**r_low
            h_low = h0 + a_low * w_low**r_low
            # High anchor: highest observed WSE, adjusted by the depth
            # difference and by gap * d_bankfull.
            h_high = swot_wsemax['wse'] + (d_bankfull - d_wsemax) + gap * d_bankfull

            # The two anchors go in FRONT of the observations; loss_function
            # re-weights exactly those first two residuals.
            xdata = np.insert(df['width'].values, 0, [w_low, w_high])
            ydata = np.insert(df['wse'].values, 0, [h_low, h_high])
            a_default = (h_high - h0) / w_high**2

            n_swot = len(df)

            def loss_wrapper(z):
                return self.loss_function(z, weight, n_swot)

            ls = least_squares(
                self.power_function,
                x0=[h0, a_default, 2],
                loss=loss_wrapper,
                args=(xdata, ydata),
                max_nfev=100
            )

            if ls.status <= 0:
                return None, 'ls_failed'

            wse0, a, b = ls.x
            if a * b < 0:
                return None, 'ab_negative'

            return {
                'wse0': wse0, 'a': a, 'b': b,
                'a50': a50, 'w50': w50, 'q50': q50,
                'w_low': w_low, 'w_high': w_high,
                'h_low': h_low, 'h_high': h_high,
                'slp': slp
            }, None
        except Exception:
            return None, 'exception'

    def _fit_single_config(self, df, r_low, gap, weight, w_low, w_high, w50,
                          h50, a50, swot_wsemax, d_bankfull, d_wsemax, slp, q50):
        """Fit one parameter combination; returns only the params dict (or ``None``)."""
        result, _ = self._fit_single_config_with_diagnostics(
            df, r_low, gap, weight, w_low, w_high, w50,
            h50, a50, swot_wsemax, d_bankfull, d_wsemax, slp, q50
        )
        return result

    def fit_all_stations(self, df_qc):
        """Fit every station of ``df_qc`` in a process pool and print a report.

        This fitter and ``df_qc`` are published through the module globals
        ``_GLOBAL_FITTER`` / ``_GLOBAL_QC_DATA`` for the
        ``_fit_station_wrapper`` workers and reset afterwards, also when the
        pool raises. The per-station diagnostics are kept in
        ``self.drop_reasons``.

        Returns
        -------
        DataFrame of all successful fits, or ``None`` if no station could be
        fitted.
        """
        global _GLOBAL_FITTER, _GLOBAL_QC_DATA

        unique_stations = df_qc['stationid'].unique()
        stations_in_qc = set(unique_stations)

        print(f"\n{'='*60}")
        print(f"DIAGNOSTIC: Fitting {len(unique_stations)} stations from QC data")
        print(f"{'='*60}")

        _GLOBAL_FITTER = self
        _GLOBAL_QC_DATA = df_qc
        try:
            with mp.Pool(processes=N_WORKERS) as pool:
                results_with_info = pool.map(_fit_station_wrapper, unique_stations)
        finally:
            # Always drop the references so a failure leaves no stale state.
            _GLOBAL_FITTER = None
            _GLOBAL_QC_DATA = None

        # Separate fit results from diagnostics.
        results = [result for result, _ in results_with_info if result is not None]
        drop_infos = [drop_info for _, drop_info in results_with_info]

        # Keep the diagnostics available to the caller; previously they were
        # only printed and then discarded.
        self.drop_reasons = drop_infos

        self._generate_diagnostic_report(drop_infos, stations_in_qc)

        if results:
            return pd.concat(results, ignore_index=True)
        return None

    def _generate_diagnostic_report(self, drop_infos, stations_in_qc):
        """Print a summary of why stations were dropped.

        Returns
        -------
        DataFrame built from ``drop_infos`` (one row per station). Nothing is
        written to disk here.
        """
        # Tally reasons and collect details for every non-success.
        reason_counts = {}
        reason_details = {}

        for info in drop_infos:
            reason = info['reason']
            if reason not in reason_counts:
                reason_counts[reason] = 0
                reason_details[reason] = []
            reason_counts[reason] += 1
            if reason != 'success':
                reason_details[reason].append({
                    'stationid': info['stationid'],
                    'COMID': info.get('COMID'),
                    'details': info['details']
                })

        n_success = reason_counts.get('success', 0)

        print(f"\n{'='*60}")
        print("DIAGNOSTIC REPORT: Station Drop Analysis (QC -> Fit)")
        print(f"{'='*60}")
        print(f"Total stations in QC data: {len(stations_in_qc)}")
        print(f"Successfully fitted: {n_success}")
        print(f"Dropped stations: {len(stations_in_qc) - n_success}")
        print()

        print("Drop reasons breakdown:")
        print("-" * 40)
        for reason, count in sorted(reason_counts.items(), key=lambda x: -x[1]):
            if reason != 'success':
                print(f"  {reason}: {count} stations")
        print()

        # First few examples for every known drop reason.
        for reason in ['not_in_width_stats', 'not_in_river_attrs',
                       'too_few_points_after_width_filter', 'all_fits_failed',
                       'exception', 'no_data_in_qc']:
            examples = reason_details.get(reason)
            if examples:
                print(f"\n{reason.upper()} (showing first 5 examples):")
                print("-" * 40)
                for item in examples[:5]:
                    print(f"  Station: {item['stationid']}")
                    print(f"    COMID: {item['COMID']}")
                    print(f"    Details: {item['details']}")
                if len(examples) > 5:
                    print(f"  ... and {len(examples) - 5} more stations")

        return pd.DataFrame(drop_infos)

# ============================================================================
# 模块5: 水位-面积曲线生成
# ============================================================================
class HypsometricCurveGenerator:
    """Build width-WSE-area curves from the fitted power laws."""

    @staticmethod
    def generate_curves(df_fit, n_points=100):
        """Generate the median stage curve and its area for every station.

        For each station the width range ``[w_low, w_high]`` (taken from the
        station's first fit row) is sampled at ``n_points`` evenly spaced
        widths. Every fit row ``(wse0, a, b)`` is evaluated on that grid; the
        per-width median, maximum and minimum heights are kept, and the
        median curve is converted to areas with ``calculate_areas_numba``
        anchored at ``(w50, a50)``. Stations whose width range is empty or
        narrower than 1e-6 are skipped with a printed warning.

        Returns
        -------
        DataFrame with columns ``stationid, width, wse, wse_max, wse_min,
        area``; an empty DataFrame if no station yields a curve.
        """
        curves = []

        for s in sorted(df_fit['stationid'].unique()):
            fits = df_fit[df_fit['stationid'] == s]
            head = fits.iloc[0]
            w_low, w_high, w50, a50 = head[['w_low', 'w_high', 'w50', 'a50']]

            # Guard against a degenerate width range.
            if w_high <= w_low or abs(w_high - w_low) < 1e-6:
                print(f"Warning: Skipping station {s} due to invalid width range (w_low={w_low}, w_high={w_high})")
                continue

            w_list = np.linspace(w_low, w_high, n_points)

            # Column vectors so each fit broadcasts against the width grid:
            # the result has shape (n_fits, n_points).
            offsets = fits['wse0'].values[:, np.newaxis]
            coeffs = fits['a'].values[:, np.newaxis]
            exponents = fits['b'].values[:, np.newaxis]
            heights_all = offsets + coeffs * w_list**exponents

            h_list = np.median(heights_all, axis=0)
            hmax = np.max(heights_all, axis=0)
            hmin = np.min(heights_all, axis=0)

            areas = calculate_areas_numba(w_list, h_list, w50, a50)

            curves.append(pd.DataFrame({
                'stationid': s,
                'width': w_list,
                'wse': h_list,
                'wse_max': hmax,
                'wse_min': hmin,
                'area': areas
            }))

        if not curves:
            return pd.DataFrame()
        return pd.concat(curves, ignore_index=True)

# ============================================================================
# 模块6: 验证与评估
# ============================================================================
class ModelValidator:
    """Model validation and performance evaluation."""

    @staticmethod
    def relative_rmse(observed, simulated):
        """RMSE divided by the mean of the observations."""
        rmse = np.sqrt(mean_squared_error(observed, simulated))
        return rmse / np.mean(observed)

    def validate(self, df_hypso, df_width, df_val_folder, df_fit, skip_width_filter=False):
        """Validate every station of ``df_hypso`` in a process pool.

        Each task tuple carries only that station's rows of ``df_hypso``,
        ``df_width`` and ``df_fit``. ``_validate_station_wrapper`` filters
        its inputs by station id anyway, so results are unchanged, but the
        full tables are no longer pickled once per station.

        Parameters
        ----------
        df_hypso : DataFrame of width-area curves (needs ``stationid``).
        df_width : DataFrame of observed widths (needs ``stationid``).
        df_val_folder : folder with the per-station validation CSV files.
        df_fit : DataFrame of fit results (needs ``stationid``).
        skip_width_filter : bool
            If true, the worker keeps widths outside ``[w_low, w_high]``
            (used for the "datemean" mode).

        Returns
        -------
        Concatenated per-station validation DataFrames; an empty DataFrame
        if no station could be validated.
        """
        stationids = sorted(df_hypso['stationid'].unique())
        print(stationids)

        # Split each table once by station; a station missing from a table
        # gets an empty slice with the same columns.
        hypso_by_station = {sid: grp for sid, grp in df_hypso.groupby('stationid')}
        width_by_station = {sid: grp for sid, grp in df_width.groupby('stationid')}
        fit_by_station = {sid: grp for sid, grp in df_fit.groupby('stationid')}
        no_hypso = df_hypso.iloc[0:0]
        no_width = df_width.iloc[0:0]
        no_fit = df_fit.iloc[0:0]

        args_list = [
            (s,
             hypso_by_station.get(s, no_hypso),
             width_by_station.get(s, no_width),
             df_val_folder,
             fit_by_station.get(s, no_fit),
             skip_width_filter)
            for s in stationids
        ]

        with mp.Pool(processes=N_WORKERS) as pool:
            results = pool.map(_validate_station_wrapper, args_list)

        results = [df for df in results if df is not None]

        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

# ============================================================================
# 配置运行函数
# ============================================================================
def run_configuration(a, b, c, d, common_data):
    """
    Run the full pipeline for a single configuration.

    The stages are: width statistics (IQR) -> SWOT point selection / QC ->
    hydraulic curve fitting -> hypsometric curve generation -> validation.
    Each stage writes a CSV into the current working directory whose name
    embeds ``{a}_{b}_{c}_{d}``. The function returns early (after printing a
    message) as soon as a stage yields no data.

    Parameters:
    -----------
    a : str
        Processing mode: 'node' or 'datemean'.
    b : str
        IQR multiplier option: '1.0', '1.5', '2.0', '2.5', '3.0', '4.0'.
    c : str
        QA option: 'noqa', 'qaloose', 'qastrict'.
    d : str
        Version option: 'VersionD', 'VersionC'.
    common_data : dict
        Shared data. Required keys: 'df_l8', 'df_comid', 'df_attrs',
        'df_width'. Optional key: 'val_folder' (directory with the validation
        observations; defaults to the project's standard location).

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``a`` is not one of the supported processing modes.
    """
    # Fail fast: an unsupported mode would otherwise only surface later as a
    # NameError on ``df_qc``, after output files had already been written.
    if a not in ('node', 'datemean'):
        raise ValueError(
            f"Unknown processing mode a={a!r}; expected 'node' or 'datemean'"
        )

    print(f"\n{'='*60}")
    print(f"Running configuration: A={a}, B={b}, C={c}, D={d}")
    print(f"{'='*60}")

    gc.collect()

    # Step 1: width statistics (IQR method)
    print("Step 1: Calculating width statistics using IQR method...")
    df_l8 = common_data['df_l8']
    df_w_stats = WidthStatistics.calculate_width_iqr(df_l8)

    # Map the IQR multiplier option onto the matching low/high bound columns
    low_col, high_col = WidthStatistics.get_iqr_columns(b)
    df_w_stats['w_low'] = df_w_stats[low_col]
    df_w_stats['w_high'] = df_w_stats[high_col]

    df_w_stats.to_csv(f'1.width_statistic_iqr_{a}_{b}_{c}_{d}.csv', index=False)

    df_comid = common_data['df_comid']
    df_attrs = common_data['df_attrs']

    # Load the SWOT data selected by the QA option
    if c == 'noqa':
        # Without QA the file name depends only on the version
        df_swot = pd.read_csv(f'1.all_matched_points_{d}.csv')
    else:
        # With qaloose/qastrict the file name also depends on the mode
        df_swot = pd.read_csv(f'2.swot_{a}_{c}_{d}.csv')

    df_swot = df_swot.merge(df_comid, on='stationid', how='inner')

    if a == 'node':
        print("Step 2: Selecting best nodes...")
        df_swot_filtered, df_nodes = NodeSelector.select_best_nodes(df_swot, min_data_points=10)

        print("Step 3: Applying quality control...")
        qc = DataQualityControl(df_w_stats)
        df_qc = qc.apply_qc(df_swot_filtered, draw_figure=False)

    else:  # a == 'datemean' (validated above)
        print("Using smoothed data (skipping width filter)...")
        # One group per station, smoothed in parallel with a rolling median
        groups = [group for _, group in df_swot.groupby('stationid')]

        # pd.concat raises ValueError on an empty list, so stop here instead
        if not groups:
            print(f"No SWOT data for {a}_{b}_{c}_{d}")
            return

        with mp.Pool(processes=N_WORKERS) as pool:
            results = pool.map(_rolling_median_group, groups)

        df_qc = pd.concat(results)

    if 'COMID' not in df_qc.columns:
        df_qc = df_qc.merge(df_comid, on='stationid', how='left')

    # Put COMID first in the exported selection table
    cols = ['COMID'] + [col for col in df_qc.columns if col != 'COMID']
    df_qc = df_qc[cols]
    df_qc.to_csv(f'2.swot-points-selection_iqr_{a}_{b}_{c}_{d}.csv', index=False)

    # Diagnostic: number of stations that survived the QC stage
    stations_in_qc = df_qc['stationid'].unique()
    print(f"\n[DIAGNOSTIC] Stations after QC (Step 2): {len(stations_in_qc)}")

    # Step 4: fitting
    print("Step 4: Fitting hydraulic curves...")
    # The width filter is skipped for the datemean mode
    skip_width_filter = (a == 'datemean')
    fitter = HydraulicCurveFitter(df_w_stats, df_attrs, skip_width_filter=skip_width_filter)
    df_fit_all = fitter.fit_all_stations(df_qc)

    if df_fit_all is None or len(df_fit_all) == 0:
        print(f"No fit data for {a}_{b}_{c}_{d}")
        return

    # Diagnostic: number of stations that survived the fitting stage
    stations_in_fit = df_fit_all['stationid'].unique()
    print(f"\n[DIAGNOSTIC] Stations after Fitting (Step 3): {len(stations_in_fit)}")
    print(f"[DIAGNOSTIC] Station loss from QC to Fit: {len(stations_in_qc) - len(stations_in_fit)}")

    df_fit_all.to_csv(f'3.fit_proba_modified_q50_iqr_{a}_{b}_{c}_{d}.csv', index=False)

    # Step 5: curve generation
    print("Step 5: Generating hypsometric curves...")
    df_hypso = HypsometricCurveGenerator.generate_curves(df_fit_all)

    if df_hypso is None or len(df_hypso) == 0:
        print(f"No hypsometric curves generated for {a}_{b}_{c}_{d}")
        return

    df_hypso.to_csv(f'4.hypso_med_modified_q50_iqr_{a}_{b}_{c}_{d}.csv', index=False)

    # Step 6: validation
    print("Step 6: Validating model...")
    validator = ModelValidator()
    df_width = common_data['df_width']

    # Callers may point at another observation directory via common_data;
    # the default keeps the original hard-coded location.
    val_folder = common_data.get(
        'val_folder',
        '/home/xj/device5/202411-SWAP/northchina/landsat/4-validation/observation',
    )

    df_results = validator.validate(
        df_hypso, df_width,
        val_folder,
        df_fit_all,
        skip_width_filter=skip_width_filter
    )

    if df_results is None or len(df_results) == 0:
        print(f"No validation results for {a}_{b}_{c}_{d}")
        return

    df_results.to_csv(f'5.q_kge_med_modified_q50_iqr_{a}_{b}_{c}_{d}.csv', index=False)

    print(f"Configuration {a}_{b}_{c}_{d} completed!")
    gc.collect()

# ============================================================================
# 配置生成函数
# ============================================================================
def generate_configs():
    """
    Build every valid (a, b, c, d) configuration tuple.

    Rule: 'noqa' is only combined with 'node'; 'qaloose' and 'qastrict' may be
    combined with every A option.

    The B options are IQR multipliers: '1.0', '1.5', '2.0', '2.5', '3.0', '4.0'
    (only the values listed below are currently enabled).
    Notes:
    - With a='node' all B options are iterated and width filtering is applied.
    - With a='datemean' all B options are iterated as well (for the validation
      stage), but the fitting stage does no width filtering.

    Returns a list of tuples ordered with A varying slowest and D fastest.
    """
    modes = ['node']
    iqr_multipliers = ['1.0', '1.5']  # IQR multipliers
    qa_levels = ['noqa', 'qaloose']
    versions = ['VersionD']

    # itertools.product iterates in the same order as four nested loops.
    return [
        (a, b, c, d)
        for a, b, c, d in itertools.product(modes, iqr_multipliers, qa_levels, versions)
        # 'noqa' is only allowed together with 'node'
        if c != 'noqa' or a == 'node'
    ]

# ============================================================================
# 主程序
# ============================================================================
def main():
    """
    Main program flow.

    Loads the shared input tables, runs every configuration returned by
    ``generate_configs()`` through ``run_configuration()``, then draws one
    boxplot per metric ('kge', 'nse', 'nrmse') comparing the per-station mean
    metric across configurations. Plots are saved as ``boxplot_<metric>.png``
    in the current working directory.
    """
    import time
    total_start = time.time()

    print("Loading common data...")
    # Each CSV is parsed once; derived frames are independent copies so an
    # in-place change to one of them cannot leak into another.
    df_l8 = pd.read_csv('../2-preprocess/1.north_glow_datemean_width_timeseries.csv')
    df_attrs = pd.read_csv('../2-preprocess/4.q50_weighted_slope.csv')
    df_comid = df_attrs[['stationid', 'COMID']].copy()
    df_width = df_l8.copy()
    df_width['date'] = pd.to_datetime(df_width['date'])

    common_data = {
        'df_l8': df_l8,
        'df_comid': df_comid,
        'df_attrs': df_attrs,
        'df_width': df_width
    }

    # Build the list of valid configurations
    configs = generate_configs()

    print(f"\nTotal configurations to run: {len(configs)}")
    print("Configurations:")
    for cfg in configs:
        print(f"  {cfg}")

    # Run every configuration
    for a, b, c, d in configs:
        start = time.time()
        run_configuration(a, b, c, d, common_data)
        print(f"Time for ({a}, {b}, {c}, {d}): {time.time() - start:.2f}s")

    # Collect per-station metrics for the boxplots
    print("\nGenerating boxplot comparisons...")
    metrics = ['kge', 'nse', 'nrmse']
    data_dict = {metric: [] for metric in metrics}
    labels = []

    for a, b, c, d in configs:
        file = f'5.q_kge_med_modified_q50_iqr_{a}_{b}_{c}_{d}.csv'
        # NOTE: only existence is checked, so a result file left over from an
        # earlier run is picked up here as well.
        if not os.path.exists(file):
            continue
        df = pd.read_csv(file)
        labels.append(f'{a}_{b}_{c}_{d}')
        for metric in metrics:
            # Drop NaN station means: a single NaN would turn the quartiles of
            # the whole box into NaN and leave it blank.
            station_metrics = df.groupby('stationid')[metric].mean().dropna().values
            data_dict[metric].append(station_metrics)

    for metric in metrics:
        if not data_dict[metric]:
            continue
        fig, ax = plt.subplots(figsize=(14, 6))
        ax.boxplot(data_dict[metric])
        # Label the boxes through the Axes API instead of boxplot's `labels`
        # keyword, which is deprecated in recent Matplotlib releases.
        # Boxes are drawn at positions 1..n by default.
        ax.set_xticks(list(range(1, len(labels) + 1)))
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_title(f'{metric.upper()} Boxplot Comparison')
        ax.set_xlabel('Configuration (A_B_C_D)')
        ax.set_ylabel(metric.upper())
        fig.tight_layout()
        fig.savefig(f'boxplot_{metric}.png', dpi=150)
        plt.close(fig)

    print(f"\nTotal time: {time.time() - total_start:.2f}s")

# Entry-point guard: run the pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()

Using 64 workers for parallel processing
Loading common data...

Total configurations to run: 4
Configurations:
  ('node', '1.0', 'noqa', 'VersionD')
  ('node', '1.0', 'qaloose', 'VersionD')
  ('node', '1.5', 'noqa', 'VersionD')
  ('node', '1.5', 'qaloose', 'VersionD')

Running configuration: A=node, B=1.0, C=noqa, D=VersionD
Step 1: Calculating width statistics using IQR method...
Step 2: Selecting best nodes...
Original nodes: 61, Selected nodes: 21
Step 3: Applying quality control...

[DIAGNOSTIC] Stations after QC (Step 2): 17
Step 4: Fitting hydraulic curves...

DIAGNOSTIC: Fitting 17 stations from QC data

DIAGNOSTIC REPORT: Station Drop Analysis (QC -> Fit)
Total stations in QC data: 17
Successfully fitted: 14
Dropped stations: 3

Drop reasons breakdown:
----------------------------------------
  all_fits_failed: 2 stations
  too_few_points_after_width_filter: 1 stations


TOO_FEW_POINTS_AFTER_WIDTH_FILTER (showing first 5 examples):
----------------------------------------
  St