In [1]:
from src.plot_utils import plot_scatter,plot_map
import os,glob,cmaps,sys,re
import hydroeval as he
import numpy as np
from pyogrio import read_dataframe
import pandas as pd
import pickle
import matplotlib.ticker as ticker
import seaborn as sns
import cartopy.feature as cf
import matplotlib.pyplot as plt
import multiprocessing as mp
import matplotlib as mpl
import cartopy.crs as ccrs
from src.ale import _get_centres
from scipy.stats import pearsonr
import scipy.stats as stats
import pymannkendall as mk
from matplotlib.colors import LogNorm
import matplotlib.font_manager as font_manager
from scipy.interpolate import interp1d
# Register the Helvetica faces stored under $DATA so that matplotlib can
# resolve the 'Helvetica' family requested below.
_helvetica_dir = os.environ['DATA'] + '/fonts/Helvetica/'
for _face_file in ('Helvetica.ttf', 'Helvetica-Bold.ttf'):
    font_manager.fontManager.addfont(_helvetica_dir + _face_file)

# Style sheets provided by the SciencePlots package (must be installed).
# NOTE(review): recent SciencePlots releases reportedly need an explicit
# `import scienceplots` before these names resolve -- confirm for the
# installed version.
plt.style.use(['science', 'nature', 'no-latex'])
plt.rc('font', family = 'Helvetica', size = 12)

# Initialise parallel-pandas; both values are hard-coded to 24 and may need
# adjusting to the machine running the notebook.
from parallel_pandas import ParallelPandas
ParallelPandas.initialize(n_cpu = 24, split_factor = 24)

# Line/band colour per climate class; iteration order of this dict is the
# order in which the classes are drawn in the plotting cells.
palette = dict(
    tropical = '#F8D347',
    dry = '#C7B18A',
    temperate = '#65C2A5',
    cold = '#a692b0',
    # polar = '#B3B3B3',
)

In [None]:
import pickle
from src.ale import ale,_second_order_quant_plot,_ax_quantiles,_ax_labels,_ax_title
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl

# Figure: 1D ALE curves of impervious surface for Qmin7 (panel a) and Qmax7
# (panel b), one line plus a 2.5-97.5 % band per climate class in `palette`.
# A 2x2 grid is created but only the top row is drawn here; the bottom row is
# addressed by the disabled 2D-ALE code further down in this cell.
fig, axes = plt.subplots(2, 2, figsize = (10, 8))
var = 'ImperviousSurface'
s = 'urbanization'


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    The file is opened in a context manager so the handle is closed as soon
    as loading finishes (the bare `pickle.load(open(...))` form leaks it).

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


for i, name in enumerate(['Qmin7', 'Qmax7']):
    # Per-climate lists of ALE arrays and of the matching quantile arrays
    # (presumably one entry per Monte Carlo replicate, judging by the "mc_" prefix).
    mc_ale0 = _load_pickle(f'../results/mc_ale_{var}_xgb_{name}_seasonal4_multi_MSWX_meteo.pkl')
    mc_quantiles0 = _load_pickle(f'../results/mc_quantiles_{var}_xgb_{name}_seasonal4_multi_MSWX_meteo.pkl')
    ax = axes[0, i]

    for climate, color in palette.items():
        quantile = np.hstack([_get_centres(a) for a in mc_quantiles0[climate]])
        # Do not call this `ale`: that name is the function imported from
        # src.ale at the top of the cell and would be silently overwritten.
        ale_values = np.hstack(mc_ale0[climate])
        ale_long = pd.DataFrame({
            'ale': ale_values,
            'quantile': quantile,
            # label every value with the index of the replicate it came from
            'group': np.hstack([['group' + str(a)] * len(b) for a, b in enumerate(mc_ale0[climate])]),
        })

        # Common x grid: spans all replicates, with as many points as the
        # shortest replicate has.
        min_x = ale_long['quantile'].min()
        max_x = ale_long['quantile'].max()
        n_sample = ale_long.groupby('group')['ale'].count().min()
        x_range = np.linspace(min_x, max_x, n_sample)

        # Linearly interpolate (and extrapolate at the edges) each replicate
        # onto the common grid.
        interpolated = []
        for group, group_df in ale_long.groupby('group', sort = False):
            f = interp1d(group_df['quantile'], group_df['ale'], kind = 'linear', fill_value = 'extrapolate')
            interpolated.append(pd.DataFrame({'x': x_range, 'y': f(x_range), 'group': group}))
        interpolated_df = pd.concat(interpolated)

        # Mean curve and 2.5 % / 97.5 % envelope across replicates at each x.
        y_by_x = interpolated_df.groupby('x')['y']
        df_avg = pd.DataFrame({
            'ave': y_by_x.mean(),
            'low': y_by_x.quantile(.025),
            'upp': y_by_x.quantile(.975),
        }).rename_axis('x').reset_index()

        # transform ALE values from log-scale to percentage-scale
        df_avg[['ave', 'low', 'upp']] = (np.exp(df_avg[['ave', 'low', 'upp']]) - 1) * 100

        ax.plot(df_avg.x.values, df_avg.ave.values, color = color, lw = 2, label = climate, zorder = 3)
        ax.fill_between(
            df_avg.x.values,
            df_avg.low.values,
            df_avg.upp.values,
            color = color,
            ec = 'none',
            alpha = .3)

    ax.tick_params(axis = 'both', labelsize = 11)
    ax.set_xlabel('Urban area (%)', fontsize = 11)
    # raw f-string: '\D' is not a valid escape sequence in a normal string literal
    ax.set_ylabel(rf'{name} in $\Delta$%', fontsize = 11)
    ax.set_title(f'Effects of {s} on {name}', fontsize = 11)
    ax.text(-.1, 1.1, ['a', 'b'][i], weight = 'bold', fontsize = 12, ha = 'center', va = 'top', transform = ax.transAxes)
    if i == 0:
        # one legend is enough: both panels share the same climate classes
        ax.legend(fontsize = 11)

# feature1 = 'Urban area (%)'
# feature2 = 'Aridity (rainfall/ET)'

# for i,name in enumerate(['Qmin7','Qmax7']):
#     ale0, quantiles_list = pickle.load(open(f'../results/ale2D_urban&aridity_xgb_{name}_seasonal4_multi_MSWX_meteo.pkl','rb'))
#     ax = axes[1,i]
        
#     cbar = _second_order_quant_plot(fig, ax, quantiles_list, ale0, mark_empty = False)
#     cbar.ax.tick_params(labelsize=11)
#     cbar.set_label(f'{name} in $\Delta$%', fontsize = 11)
    
#     _ax_labels(
#         ax,
#         feature1,
#         feature2,
#         fontsize = 11
#     )
# #     quantiles = np.arange(0, 100, 10)
# #     for twin in ("x", "y"):
# #         _ax_quantiles(ax, quantiles, twin=twin)
#     if i == 0:
#         ax.set_title("Impacts of urbanization and aridity on Qmin7", fontsize = 11)
#     else:
#         ax.set_title("Impacts of urbanization and aridity on Qmax7", fontsize = 11)
#     ax.tick_params(axis = 'both', labelsize = 11)
#     ax.set_yticklabels(ax.get_yticks()/10000)
#     ax.text(-.1, 1.1, ['c','d'][i], weight = 'bold', fontsize = 12, ha = 'center', va = 'center', transform = ax.transAxes)

# # add arrow to indicate dry/wet direction
# ax = axes[1,0]
# ax.annotate("Dry", xy=(-.13, 0.25), xytext=(-.13, 0.1), 
#             arrowprops=dict(arrowstyle = '<|-', color = 'red'), va = 'center', ha = 'center', weight = 'bold',
#             textcoords = 'axes fraction', xycoords = 'axes fraction', fontsize = 10, color = 'red')
# ax.annotate("Wet", xy=(-.13, 0.75), xytext=(-.13, 0.9), 
#             arrowprops=dict(arrowstyle = '<|-', color = 'blue'), va = 'center', ha = 'center', weight = 'bold',
#             textcoords = 'axes fraction', xycoords = 'axes fraction', fontsize = 10, color = 'blue')
    
# fig.tight_layout()
# fig.savefig('../picture/ale_1D_2D_urban_aridity.png', dpi = 600)
