# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors as mpc
from scipy.stats import spearmanr
from adjustText import adjust_text
# Global matplotlib styling applied to every axes created below:
# typeface, text sizes, tick widths and the shared background colour.
plt.rcParams.update({
    'font.family': 'Arial',
    'font.size': 16,
    'axes.titlesize': 24,
    'axes.labelsize': 24,
    'axes.titlepad': 10,
    # 'axes.labelpad': -2,  # disabled
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'xtick.major.width': 1.6,
    'ytick.major.width': 1.6,
    # The axes and the figure canvas use the same light-grey background.
    'axes.facecolor': '#F8F8F8',
    'figure.facecolor': '#F8F8F8',
})

# Panel grid: N rows by M columns of manually positioned axes.
N = 3
M = 6
# Axes box inside one grid cell, expressed as fractions of the cell:
# [left offset, bottom offset, width, height].
BOUNDS = [0.12, 0.08, 0.76, 0.78]
# Figure-level margins, as fractions of the figure size.
MARGIN_LEFT, MARGIN_RIGHT, MARGIN_BOT, MARGIN_TOP = 0.04, 0, 0, 0

# Size of one grid cell in figure coordinates.
subplot_width = (1 - MARGIN_LEFT - MARGIN_RIGHT) / M
subplot_height = (1 - MARGIN_TOP - MARGIN_BOT) / N

# Size of the axes drawn inside each cell.
AXES_WIDTH = subplot_width * BOUNDS[2]
AXES_HEIGHT = subplot_height * BOUNDS[3]

# Left edge of every column and bottom edge of every row (bottom row first).
AXES_LEFT = MARGIN_LEFT + (np.arange(M) + BOUNDS[0]) * subplot_width
AXES_BOT = MARGIN_BOT + (np.arange(N) + BOUNDS[1]) * subplot_height

fig = plt.figure(figsize=(20, 12))

# ax[row][col], with row 0 at the top of the figure.  AXES_BOT is ordered
# bottom-up, so it is walked in reverse.  Each row holds at most M axes and
# the grid is capped at 18 axes in total.
ax = [
    [plt.axes([AXES_LEFT[col], bottom, AXES_WIDTH, AXES_HEIGHT])
     for col in range(min(M, 18 - row * M))]
    for row, bottom in enumerate(AXES_BOT[::-1])
]

# Input tables: node-level observations, the siteid -> station/COMID lookup,
# and the per-station table read into df_final_all.
df_all = pd.read_csv('../node_clip_0907.csv')
df_comid = pd.read_csv('../swot_w_h_north_china.csv')
df_final_all = pd.read_csv('../swot_fit.csv')

# -999999999999 is treated as the missing-value marker: convert it to NaN and
# discard every row that contains a NaN in any column.
df_all = df_all.replace(-999999999999, np.nan)
df_all = df_all.dropna()
df_all = df_all.rename(columns={'STCD': 'siteid'})

# Attach station name and COMID to each observation; the lookup is reduced to
# one row per siteid first so the merge cannot duplicate observations.
df_comid = df_comid[['siteid', 'station', 'COMID']].drop_duplicates(subset='siteid')
df_all = df_all.merge(df_comid, on='siteid', how='left')

# Remove the saved-index column when the CSV carries one; errors='ignore'
# keeps the script working if the file was written without it.
df_all = df_all.drop(columns=['Unnamed: 0'], errors='ignore')

# The left merge leaves station = NaN for any siteid missing from the lookup.
# Such rows cannot be assigned to a station, and a NaN mixed in with the
# station names makes sorted() raise TypeError, so they are dropped here.
df_all = df_all.dropna(subset=['station'])

# Stations excluded from the figure.
df_all = df_all[~df_all['station'].isin(['baimasi', 'kuerbin', 'longmenzhen'])]
stations = sorted(df_all['station'].unique())

# Order the rows by station, then by each node's mean longitude (ascending),
# so nodes appear in a consistent order within a station.
df_all = df_all.assign(
    lon_mean=df_all.groupby('node_id')['lon'].transform('mean')
).sort_values(by=['station', 'lon_mean'])

# Uncertainty filter: keep an observation only when its wse uncertainty is at
# most 0.4 and its width uncertainty is at most 10% of the width itself.
df_all = df_all.assign(width_u_r=df_all['width_u'] / df_all['width'])
wse_ok = df_all['wse_u'] <= 0.4
width_ok = df_all['width_u_r'] <= 0.1
df_all = df_all[wse_ok & width_ok]

# calculate rank_corr
# One row per node, in the order the nodes first appear in df_all, with a
# fresh 0..n-1 index and node_id kept as the first column.
df_node_all = (
    df_all[['node_id', 'station', 'lon_mean']]
    .drop_duplicates(subset='node_id')
    .reset_index(drop=True)
)

# Spearman rank correlation between width and wse for every node, computed in
# a single groupby pass instead of re-scanning the whole frame once per node.
node_rank_corr = {
    node: spearmanr(obs['width'], obs['wse'])[0]
    for node, obs in df_all.groupby('node_id', sort=False)
}

# map() always creates the column, so 'rank_corr' exists even when no
# observation survived the earlier filtering.
df_node_all['rank_corr'] = df_node_all['node_id'].map(node_rank_corr).astype(float)

# main
# Flatten the axes grid row by row, remembering each axes' column index so
# that only the left-most panel of a row receives a y-label.
panels = [(col, axis) for row in ax for col, axis in enumerate(row)]

# zip() stops at the shorter sequence: with fewer stations than axes the spare
# axes are simply left empty (instead of raising IndexError), and stations
# beyond the last axes are not drawn.
for station, (col, axis) in zip(stations, panels):
    df_node = df_node_all[df_node_all['station'] == station]
    corr_node = df_node['rank_corr'].values

    # Grey bars: one per node of this station, in df_node_all row order.
    # NOTE(review): the x axis only labels slots N1-N3; a fourth node would be
    # drawn on the 'Final' slot -- confirm no station has more than three.
    axis.bar(x=np.arange(len(corr_node)), height=corr_node, width=0.75,
             fc='#B0B0B0', ec='#606060', lw=1, zorder=0)

    # Red bar: rank correlation of this station's rows in df_final_all.  It is
    # not drawn for 'xingjiawopeng', so the correlation is only computed when
    # it is actually used.
    if station != 'xingjiawopeng':
        df_final = df_final_all[df_final_all['station'] == station]
        corr_final = spearmanr(df_final['width'], df_final['wse'])[0]
        axis.bar(x=[3], height=[corr_final], width=0.75,
                 fc='#D56565', ec='#703030', lw=1, zorder=0)

    # Zero reference line, drawn above the bars.
    axis.plot([-5, 5], [0, 0], color='black', lw=1, zorder=1)

    axis.set_xlim(-0.5, 3.5)
    axis.set_ylim(-0.5, 0.8)
    axis.set_xticks(np.arange(4), ['N1', 'N2', 'N3', 'Final'])
    axis.set_yticks(np.arange(-0.4, 1, 0.4))
    axis.set_title(station, fontweight='bold')

    if col == 0:
        axis.set_ylabel('Rank CC')

# plt.show()  # uncomment for an interactive preview
plt.savefig('fig5.png', dpi=200)