In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
from scipy.optimize import least_squares
import os
import warnings

warnings.filterwarnings('ignore')

# -----------------------------
# 配置参数
# -----------------------------
R_list = np.array([0.5, 1, 2])
GAP_list = np.array([-0.1, 0, 0.1])
W_list = np.array([0.3, 0.5, 0.7])

# -----------------------------
# 数学函数
# -----------------------------
def power_func(params, X, y):
    wse0, a, b = params
    return y - (wse0 + a * X**b)

def custom_loss(z):
    rho = np.zeros((3, len(z)))
    rho[0] = 2 * ((1 + z)**0.5 - 1)
    rho[1] = (1 + z)**(-0.5)
    rho[2] = -0.5 * (1 + z)**(-1.5)
    rho[:, 0] *= (len(z) - 2) / weight * (1 - weight) / 2
    rho[:, 1] *= (len(z) - 2) / weight * (1 - weight) / 2
    return rho

def estimate_h50(df, w50):
    df['w50_diff'] = np.abs(df['width'] - w50)
    df = df.sort_values('w50_diff').iloc[:5]
    x, y = df['width'].values, df['wse'].values
    if len(np.unique(x)) < 2:
        return np.mean(y)
    res = linregress(x, y)
    return res[0] * w50 + res[1] if res[0] >= 0 else np.mean(y)

def solve_a50(q50, slp, w50):
    return (q50 * 0.035 / slp**0.5 * w50**(2/3))**(3/5)

def estimate_bankfull_depth(width):
    return 0.27 * (width / 7.2)**0.6

# -----------------------------
# 主处理函数
# -----------------------------
def process_station(station_id, df_s3all, df_attr, df_w):
    results = []
    df_s3 = df_s3all[df_s3all['stationid'] == station_id]
    st, comid = df_s3.iloc[0][['stationid', 'COMID']]
    q50 = df_attr.loc[comid, 'q50_weighted']
    slp = df_attr.loc[comid, 'slope']
    w50, w_low, w_high = df_w.loc[st, ['w50', 'w_low', 'w_high']]
    d_bankfull = estimate_bankfull_depth(w_high)

    h50 = estimate_h50(df_s3.copy(), w50)

    df_s4 = df_s3[(df_s3['width'] >= w_low) & (df_s3['width'] <= w_high)]
    if len(df_s4) < 3:
        return []

    swot_max = df_s4.sort_values('wse', ascending=False).iloc[0]
    d_wsemax = estimate_bankfull_depth(swot_max['width'])
    a50 = solve_a50(q50, slp, w50)

    for r_low in R_list:
        for gap in GAP_list:
            for wgt in W_list:
                a_low = a50 * (r_low + 1) / r_low / w50**(r_low + 1)
                h0 = h50 - a_low * w50**r_low
                h_low = h0 + a_low * w_low**r_low
                h_high = swot_max['wse'] + (d_bankfull - d_wsemax) + gap * d_bankfull

                xdata = np.insert(df_s4['width'].values, [0, 0], [w_low, w_high])
                ydata = np.insert(df_s4['wse'].values, [0, 0], [h_low, h_high])
                a_default = (h_high - h0) / w_high**2

                global weight  # make weight accessible in loss function
                weight = wgt
                try:
                    ls = least_squares(power_func, x0=[h0, a_default, 2], loss=custom_loss, args=(xdata, ydata))
                except Exception as e:
                    print(f"\nFit failed for station {station_id}, skipping: {e}")
                    continue

                results.append({
                    'stationid': st, 'R': r_low, 'GAP': gap, 'W': wgt,
                    'wse0': ls.x[0], 'a': ls.x[1], 'b': ls.x[2],
                    'a50': a50, 'w50': w50, 'q50': q50,
                    'w_low': w_low, 'w_high': w_high,
                    'h_low': h_low, 'h_high': h_high, 'slp': slp
                })

    return results

# -----------------------------
# 主程序
# -----------------------------
def main():
    print("Loading data...")
    df_s3all = pd.read_csv('swot_s3.csv').drop_duplicates(subset=['stationid', 'time'], keep='first')
    df_attr = pd.read_csv('gages3000_GRFR_q50_slp.csv')
    df_q50_weighted = pd.read_csv('q50_weighted.csv')
    df_attr = df_attr.merge(df_q50_weighted[['stationid', 'q50_weighted']], on='stationid', how='inner')
    df_attr = df_attr.drop_duplicates(subset='COMID').set_index('COMID')

    df_w = pd.read_csv('1_3_w50.csv', index_col='stationid')
    df_s3all = df_s3all[df_s3all['stationid'].isin(df_w.index)]
    station_ids = df_s3all['stationid'].unique()

    all_results = []

    for i, sid in enumerate(station_ids):
        print(f'\rProcessing station {i + 1}/{len(station_ids)}: {sid}', end='')
        station_results = process_station(sid, df_s3all, df_attr, df_w)
        all_results.extend(station_results)

    print('\nSaving results...')
    df_res = pd.DataFrame(all_results)
    df_res = df_res.dropna()
    df_res.to_csv('3/fit_proba_modified_q50.csv', index=False)
    print("Done.")

if __name__ == '__main__':
    main()


Processing 1178/1178: GRDC_41480500007
Done!


In [19]:
df_ori = pd.read_csv('./ori/3/fit_proba_modified_q50.csv')
df_res = pd.read_csv('3/fit_proba_modified_q50.csv')
# station_ids = ['Brazil_10500000','Brazil_81350000']

# df_ori = df_ori[df_ori['stationid'].isin(station_ids)]
# df_res  = df_res[df_res['stationid'].isin(station_ids)]
# print(df_res.columns)
df_res = df_res.sort_values(by=['stationid','R', 'GAP', 'W', 'COMID']).reset_index()
df_ori = df_ori.sort_values(by=['stationid','R', 'GAP', 'W', 'COMID']).reset_index()
df_dif = df_ori[['R', 'GAP', 'W', 'COMID', 'wse0', 'a', 'b', 'a50', 'w50',
       'q50', 'w_low', 'w_high', 'h_low', 'h_high', 'slp']]-df_res[['R', 'GAP', 'W', 'COMID', 'wse0', 'a', 'b', 'a50', 'w50',
       'q50', 'w_low', 'w_high', 'h_low', 'h_high', 'slp']]
df_nonzero = df_dif[(df_dif != 0).any(axis=1)]

pd.options.display.float_format = '{:.3f}'.format
print(df_nonzero)

          R   GAP     W  COMID   wse0      a      b   a50   w50   q50  w_low  \
0     0.000 0.000 0.000  0.000  0.000 -0.000  0.000 0.000 0.000 0.000  0.000   
1     0.000 0.000 0.000  0.000  0.000 -0.000  0.000 0.000 0.000 0.000  0.000   
2     0.000 0.000 0.000  0.000  0.000 -0.000  0.000 0.000 0.000 0.000  0.000   
3     0.000 0.000 0.000  0.000 -0.000  0.000 -0.000 0.000 0.000 0.000  0.000   
4     0.000 0.000 0.000  0.000 -0.000 -0.000  0.000 0.000 0.000 0.000  0.000   
...     ...   ...   ...    ...    ...    ...    ...   ...   ...   ...    ...   
19354 0.000 0.000 0.000  0.000  0.000 -0.000  0.000 0.000 0.000 0.000  0.000   
19355 0.000 0.000 0.000  0.000 -0.000  0.000 -0.000 0.000 0.000 0.000  0.000   
19356 0.000 0.000 0.000  0.000 -0.000 -0.000  0.000 0.000 0.000 0.000  0.000   
19357 0.000 0.000 0.000  0.000 -0.000  0.000 -0.000 0.000 0.000 0.000  0.000   
19358 0.000 0.000 0.000  0.000  0.000 -0.000  0.000 0.000 0.000 0.000  0.000   

       w_high  h_low  h_high   slp  
0 