# In [None]:
import pandas as pd
import xarray as xr
import regionmask
import geopandas as gpd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Region boundaries are read from the NCA regions shapefile.
gdf = gpd.read_file('/home/shawn_preston/NCA/NCA_Regions.shp')

# Full region names, in the order they are processed.
regions = [
    'Northeast',
    'Southeast',
    'Midwest',
    'Northern Great Plains',
    'Southern Plains',
    'Southwest',
    'Northwest',
]

# Short display name for each region: the full name is reused unless an
# abbreviation is listed in the override table below.
_short_labels = {'Northern Great Plains': 'Northern G. P.'}
region_name_mapping = {name: _short_labels.get(name, name) for name in regions}

# Input files and variable names for each metric.
#
# file_info[metric] holds four entries: the file path for each period
# ('1991_2020', '2031_2060') and the name of the variable to read from each
# file ('var_name_1991_2020', 'var_name_2031_2060').
_DATA_DIR = '/home/shawn_preston/weightedcesm2'
_PERIODS = ('1991_2020', '2031_2060')

# Variable name inside each file, as (1991_2020, 2031_2060) per metric.
_VAR_NAMES = {
    'daysbelow0': ('days_below_0_ensemble_1991_2020', 'days_below_0_ensemble_2031_2060'),
    'frost': ('last_frost_dates_ensemble', 'last_frost_dates_ensemble'),
    'gddapr': ('gdd_ensemble_1991_2020', 'gdd_ensemble_2031_2060'),
    'gddsep': ('gdd_ensemble_1991_2020', 'gdd_ensemble_2031_2060'),
    'tmax': ('heat_ensemble_1991_2020', 'heat_ensemble_2031_2060'),
    'tmin': ('night_ensemble_1991_2020', 'night_ensemble_2031_2060'),
}


def _metric_entry(metric, var_names):
    """Return the path and variable-name entries for one metric."""
    # Paths first, then variable names, so the key order matches the layout
    # described above.
    entry = {
        period: f'{_DATA_DIR}/{metric}_{period}_weighted.nc'
        for period in _PERIODS
    }
    entry.update(
        {f'var_name_{period}': name for period, name in zip(_PERIODS, var_names)}
    )
    return entry


file_info = {
    metric: _metric_entry(metric, var_names)
    for metric, var_names in _VAR_NAMES.items()
}

# One figure holding six panels (2 rows x 3 columns). The axes grid is
# flattened to 1-D so each panel can be addressed with a single index.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(36, 24))
axs = axs.flatten()

# One row per panel, in subplot order:
# (metric key, y-axis label, panel letter, panel title).
_panels = [
    ('daysbelow0', 'Difference in Accumulation Days', 'A', 'Cold Degree Days (Nov.-Mar.)'),
    ('frost', 'Difference in Days', 'B', 'Last Day of Spring Frost(Jan.-Jul.)'),
    ('gddapr', 'Difference in Accumulation Days', 'C', 'GDD Bud Break(Jan.-Apr.)'),
    ('gddsep', 'Difference in Accumulation Days', 'D', 'GDD General Growth(Jan.-Jul.)'),
    ('tmax', 'Difference in Days Above 34°C', 'E', 'Extreme Heat Days(Jun.-Aug.)'),
    ('tmin', 'Difference in Days Above 15°C', 'F', 'Warm Nights(Aug.-Sep.)'),
]

# Split the table column-wise into the four parallel lists.
metric_order, ylabels_order, subplot_labels, subplot_titles = (
    list(column) for column in zip(*_panels)
)

# Empty list that will hold the statistics records.
stats_list = []

# Draw one boxplot panel per metric. Each box shows, for one region, the
# values of the change (2031-2060 minus 1991-2020) that remain after averaging
# over the region's grid cells. The per-region mean and median are also
# appended to stats_list.
for idx, metric in enumerate(metric_order):
    paths = file_info[metric]
    panel_ax = axs[idx]

    # Open each file in a context manager and load the variable into memory,
    # so the NetCDF file handle is closed as soon as the data has been read
    # rather than being left open for every metric.
    with xr.open_dataset(paths['1991_2020']) as hist_file:
        ds_1991_2020 = hist_file[paths['var_name_1991_2020']].load()
    with xr.open_dataset(paths['2031_2060']) as future_file:
        ds_2031_2060 = future_file[paths['var_name_2031_2060']].load()

    # Frost dates are averaged over the 'year' dimension before the change
    # is computed.
    if metric == 'frost':
        ds_1991_2020 = ds_1991_2020.mean(dim='year')
        ds_2031_2060 = ds_2031_2060.mean(dim='year')

    if metric in ['tmin', 'tmax']:
        # Scale both periods by 1/30.
        # NOTE(review): presumably converts a 30-year total into a per-year
        # value - confirm against how the input files were produced.
        ds_1991_2020 = ds_1991_2020 / 30
        ds_2031_2060 = ds_2031_2060 / 30

    # Change between the future and the historical period.
    delta = ds_2031_2060 - ds_1991_2020
    region_data = []
    region_means = []
    region_medians = []

    for region in regions:
        # Rasterise the region polygon onto the historical data's lon/lat
        # grid; cells outside the region are NaN in the mask.
        mask = regionmask.mask_geopandas(gdf[gdf['RegionName'] == region], ds_1991_2020.lon, ds_1991_2020.lat)
        # NOTE(review): this is a plain mean over the region's grid cells; any
        # area weighting implied by the figure title must already be present
        # in the input files - confirm.
        masked_data = delta.where(mask.notnull(), drop=True).mean(dim=('lat', 'lon'))
        region_values = masked_data.values
        region_data.append(region_values)

        # Mean and median of the regional values.
        region_mean = np.mean(region_values)
        region_median = np.median(region_values)
        region_means.append(region_mean)
        region_medians.append(region_median)

        stats_list.append({'Metric': metric, 'Region': region_name_mapping[region], 'Mean': round(region_mean, 3), 'Median': round(region_median, 3)})

    # One box per region, in the order of `regions`.
    sns.boxplot(data=region_data, ax=panel_ax)
    # Pin the tick positions before assigning labels so every label is tied
    # to its own box.
    panel_ax.set_xticks(list(range(len(regions))))
    panel_ax.set_xticklabels([region_name_mapping[region] for region in regions], rotation=45, fontweight='bold', fontsize=26)
    panel_ax.set_title(subplot_titles[idx], fontsize=30, fontweight='bold')

    # Mark each region's mean with a gold star on top of its box.
    for i, mean in enumerate(region_means):
        panel_ax.plot(i, mean, marker="*", color='gold', markersize=14)

    panel_ax.tick_params(axis='y', labelsize=26)
    # Single y-label call: an earlier call with fontsize=40 was always
    # overridden by this one, so it has been dropped.
    panel_ax.set_ylabel(ylabels_order[idx], rotation=90, labelpad=15, fontsize=20, fontweight='bold')
    panel_ax.grid(False)

    # Panel letter in the bottom-left corner of the axes.
    panel_ax.annotate(subplot_labels[idx], xy=(0.03, 0.03), xycoords='axes fraction', fontsize=28, fontweight='bold', color='black')

plt.tight_layout()
# The title sits above the panels (y > 1); bbox_inches='tight' keeps it inside
# the saved image.
plt.suptitle('Fig. 5 Area-Weighted Average Absolute Change in Climate Metrics (2031-2060 vs 1991-2020)', ha='left', x=0.02, y=1.04, fontsize=40, fontweight='bold')
plt.savefig('/home/shawn_preston/CESM2PAPERFIGURES/FIGURE5.png', dpi=300, bbox_inches='tight')
plt.show()

# Turn the collected statistics records into a long-format DataFrame with the
# columns 'Metric', 'Region', 'Mean' and 'Median'.
stats_df = pd.DataFrame(stats_list)

# Reshape to wide tables: one row per metric, one column per region.
mean_df = stats_df.pivot(index='Metric', columns='Region', values='Mean')
median_df = stats_df.pivot(index='Metric', columns='Region', values='Median')

# Stack the mean table on top of the median table. Naming the index levels
# gives the first column the header 'Statistic' after reset_index(), instead
# of the auto-generated 'level_0'.
combined_df = pd.concat(
    [mean_df, median_df],
    keys=['Mean', 'Median'],
    names=['Statistic', 'Metric'],
).reset_index()

# Render combined_df as a standalone table figure and save it as a PNG.
fig, ax = plt.subplots(figsize=(20, 12))
ax.axis('tight')
ax.axis('off')  # hide the axes so only the table is drawn

# Cell contents and column headers come straight from the DataFrame.
table_data = combined_df.values
table_columns = combined_df.columns
table = ax.table(cellText=table_data, colLabels=table_columns, cellLoc='center', loc='center')

# Use a fixed base font size and enlarge the cells.
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1.5, 1.5)

# Style the header. With colLabels the table has exactly one header row
# (row 0), so only that row gets the larger font and grey background. The
# previous code reset the header to 12 pt right after setting 14 pt and also
# shaded the first data row.
for (row, _col), cell in table.get_celld().items():
    if row == 0:
        cell.set_fontsize(14)
        cell.set_facecolor('#D3D3D3')

plt.title('Mean and Median Values of Climate Metrics by Region', fontsize=18, fontweight='bold')
plt.savefig('/home/shawn_preston/CESM2PAPERFIGURES/Climate_Metrics_Table.png', dpi=300, bbox_inches='tight')
plt.show()
