In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import warnings

# Silence every warning category so the notebook output stays uncluttered.
# NOTE(review): this also hides DeprecationWarnings from xarray/seaborn/matplotlib.
warnings.simplefilter("ignore")

def plot_trend(ax, cesm2_file, gridmet_file, title):
    """Draw a CESM2-ensemble vs. GridMET trend comparison on a single axes.

    The CESM2 ``ensemble_trends`` variable is averaged over ``lat`` and
    ``lon``. The remaining values are flattened into one array (presumably one
    value per ensemble member -- TODO confirm the leftover dimension). That
    array is drawn as a boxplot with the individual values overlaid as a strip.

    Two markers are drawn at the box's x position (0):
    - the mean of the GridMET ``observed_trends`` variable over all of its
      dimensions, as a red circle;
    - the mean of the CESM2 values, as a blue cross.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    cesm2_file : str
        Path to a NetCDF file with an ``ensemble_trends`` variable that has
        ``lat`` and ``lon`` dimensions.
    gridmet_file : str
        Path to a NetCDF file with an ``observed_trends`` variable.
    title : str
        Subplot title.

    The three labelled artists ('Ensemble Member', 'GridMET', 'Ensemble Mean')
    can be collected with ``ax.get_legend_handles_labels()``.
    """
    # Open each file in a ``with`` block so the NetCDF handle is closed once
    # the needed numbers are in memory. ``.values`` / ``.item()`` force the
    # lazily-opened data to be read before the file is closed.
    with xr.open_dataset(cesm2_file) as cesm2_ds:
        ensemble_trends_cesm2_values = (
            cesm2_ds['ensemble_trends'].mean(dim=['lat', 'lon']).values.flatten()
        )
    with xr.open_dataset(gridmet_file) as gridmet_ds:
        # Mean over every dimension of the observed trend field -> one scalar.
        gridmet_trend_value = gridmet_ds['observed_trends'].mean().item()
    ensemble_mean = ensemble_trends_cesm2_values.mean()

    # A single-column frame, so seaborn draws exactly one category at x = 0.
    trend_df = pd.DataFrame({'CESM2 Ensemble': ensemble_trends_cesm2_values})

    # The box summarises the ensemble spread; the strip shows each value.
    sns.boxplot(ax=ax, data=trend_df, color='lightgrey', width=0.5, showfliers=False)
    sns.stripplot(ax=ax, data=trend_df, color='grey', alpha=0.5, size=5, label='Ensemble Member')

    # Observation and ensemble-mean markers sit on the box's category position
    # and are drawn above it (zorder=5).
    ax.scatter(0, gridmet_trend_value, color='red', zorder=5, marker='o', s=100, label='GridMET')
    ax.scatter(0, ensemble_mean, color='blue', zorder=5, marker='x', s=100, label='Ensemble Mean')

    # Titles, labels and grid
    ax.set_title(title, fontsize=16, weight='bold')
    ax.set_ylabel('Trend (per year)', fontsize=14, weight='bold')
    ax.grid(True)

    # One tick for the single category
    ax.set_xticks([0])
    ax.set_xticklabels(['CESM2/GridMET'], fontsize=12, weight='bold')
    ax.tick_params(axis='y', labelsize=12)

    # Older matplotlib releases list the boxplot's box patches in
    # ``ax.artists``. Newer releases (3.5+) list them only in ``ax.patches``,
    # where the old ``ax.artists`` loop found nothing. Restyle both.
    for box in [*ax.artists, *ax.patches]:
        box.set_edgecolor('black')
        box.set_facecolor('lightgrey')
    # Whiskers, caps and the median line are Line2D objects in ``ax.lines``.
    for line in ax.lines:
        line.set_color('black')
        line.set_linewidth(1.5)

# Needed only to create the output directory before saving.
import os

# File paths for the CESM2 and GridMET data for each climate metric
cdd_cesm2_file = '/home/shawn_preston/daysbelow0ensemble/testtrend_cdd_analysis.nc'
cdd_gridmet_file = '/home/shawn_preston/daysbelow0ensemble/testtestgridmet_trends_cdd_analysis.nc'

last_frost_cesm2_file = '/home/shawn_preston/lastdayfrostensemble/testtesttrend_last_frost_dates_ensemble.nc'
last_frost_gridmet_file = '/home/shawn_preston/lastdayfrostensemble/testtestgridmet_trends_last_frost_analysis.nc'

gdd_jan_apr_cesm2_file = '/home/shawn_preston/gddcesm2janapr/testtesttrend_gdd_bb_analysis.nc'
gdd_jan_apr_gridmet_file = '/home/shawn_preston/gddcesm2janapr/testtestgridmet_trends_gddbb_analysis.nc'

gdd_jan_sept_cesm2_file = '/home/shawn_preston/gddcdesm2jansept/testtesttrend_gdd_gg_analysis.nc'
gdd_jan_sept_gridmet_file = '/home/shawn_preston/gddcdesm2jansept/testtestgridmet_trends_gddgg_analysis.nc'

tmax_cesm2_file = '/home/shawn_preston/tmax34ensemble/testtesttrend_days_above_threshold_analysis.nc'
tmax_gridmet_file = '/home/shawn_preston/tmax34ensemble/testtestgridmet_trends_days_above_threshold_analysis.nc'

tmin_cesm2_file = '/home/shawn_preston/tmin15ensemble/testtesttrend_days_above_threshold_analysis.nc'
tmin_gridmet_file = '/home/shawn_preston/tmin15ensemble/testtestgridmet_trends_days_above_threshold_analysis.nc'

# One entry per subplot: (CESM2 trend file, GridMET trend file, subplot title).
# Listed in row-major order so they line up with ``axes.flat`` below.
panels = [
    (cdd_cesm2_file, cdd_gridmet_file, 'CDD (November-March)'),
    (last_frost_cesm2_file, last_frost_gridmet_file, 'Last Frost Date (January-July)'),
    (gdd_jan_apr_cesm2_file, gdd_jan_apr_gridmet_file, 'GDD Bud Break (January-April)'),
    (gdd_jan_sept_cesm2_file, gdd_jan_sept_gridmet_file, 'GDD General Growth (January-September)'),
    (tmax_cesm2_file, tmax_gridmet_file, 'Extreme Heat Days (June-August)'),
    (tmin_cesm2_file, tmin_gridmet_file, 'Warm Nights (August-September)'),
]

# Set up the figure with 2 rows and 3 columns for subplots
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Plot each trend on its own subplot
for ax, (cesm2_file, gridmet_file, title) in zip(axes.flat, panels):
    plot_trend(ax, cesm2_file, gridmet_file, title)

# Every subplot carries the same labelled markers, so take the legend entries
# from the first one and show a single figure-level legend.
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, fontsize=12, loc='lower center', bbox_to_anchor=(0.5, 0.03), ncol=3)

# Add a title for the entire figure
fig.suptitle('Comparison of Observed and Modeled Climate Trends in the United States (1980-2022)', fontsize=20, weight='bold', y=0.95)

# Leave room at the bottom for the shared legend and at the top for the title.
plt.tight_layout(rect=[0, 0.08, 1, 0.95])

output_path = '/home/shawn_preston/CESM2PAPERFIGURES/Fig2updated.pdf'
# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# For a vector PDF, dpi only affects any rasterized elements.
plt.savefig(output_path, format='pdf', dpi=600, bbox_inches='tight')
plt.show()
