# In [None]:
import xarray as xr
import geopandas as gpd
import regionmask
import matplotlib.pyplot as plt
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Polygons used to clip every map panel (judging by the file name, the
# contiguous US).
gdf = gpd.read_file('/home/shawn_preston/shapefile_nation/combined_conus.shp')

# Historical-period (1991-2020) ensemble-mean datasets, one file per metric.
# Files are opened lazily by xarray, in the order listed.
(
    ds_cdd_91,
    ds_gddapr_91,
    ds_gddsep_91,
    ds_frost_91,
    ds_tmin_91,
    ds_tmax_91,
) = (
    xr.open_dataset(nc_path)
    for nc_path in (
        '/home/shawn_preston/daysbelow0ensemble/tavgdaysbelow0_1991_2020.nc',
        '/home/shawn_preston/gddcesm2janapr/GDD_1991_2020.nc',
        '/home/shawn_preston/gddcdesm2jansept/GDD_1991_2020.nc',
        '/home/shawn_preston/lastdayfrostensemble/last_frost_1991_2020.nc',
        '/home/shawn_preston/tmin15ensemble/tmin1591_20.nc',
        '/home/shawn_preston/tmax34ensemble/tmax1591_20.nc',
    )
)

# Future-period (2031-2060) counterparts, listed in the same metric order.
(
    ds_cdd_30,
    ds_gddapr_30,
    ds_gddsep_30,
    ds_frost_30,
    ds_tmin_30,
    ds_tmax_30,
) = (
    xr.open_dataset(nc_path)
    for nc_path in (
        '/home/shawn_preston/daysbelow0ensemble/tavgdaysbelow0_2031_2060.nc',
        '/home/shawn_preston/gddcesm2janapr/GDD_2031_2060.nc',
        '/home/shawn_preston/gddcdesm2jansept/GDD_2031_2060.nc',
        '/home/shawn_preston/lastdayfrostensemble/last_frost_2031_2060.nc',
        '/home/shawn_preston/tmin15ensemble/tmin1531_60.nc',
        '/home/shawn_preston/tmax34ensemble/tmax1531_60.nc',
    )
)

def _period_delta(future, past, future_var, past_var):
    """Return ``future[future_var] - past[past_var]`` (future minus historical)."""
    return future[future_var] - past[past_var]


# Change of each metric's ensemble mean: 2031-2060 value minus 1991-2020 value.
delta_days_below_0_ensemble_mean = _period_delta(
    ds_cdd_30, ds_cdd_91,
    'days_below_0_ensemble_mean_2031_2060', 'days_below_0_ensemble_mean_1991_2020',
)
delta_gdd_jan_apr_ensemble_mean = _period_delta(
    ds_gddapr_30, ds_gddapr_91,
    'gdd_ensemble_mean_2031_2060', 'gdd_ensemble_mean_1991_2020',
)
delta_gdd_jan_sept_ensemble_mean = _period_delta(
    ds_gddsep_30, ds_gddsep_91,
    'gdd_ensemble_mean_2031_2060', 'gdd_ensemble_mean_1991_2020',
)

# The last-frost variable carries a 'year' dimension, so each period is
# averaged over years before differencing.
_frost_var = 'last_frost_dates_ensemble_mean'
delta_last_frost_dates_ensemble_mean = (
    ds_frost_30[_frost_var].mean(dim='year')
    - ds_frost_91[_frost_var].mean(dim='year')
)

delta_days_above_15_ensemble_mean = _period_delta(
    ds_tmin_30, ds_tmin_91,
    'night_ensemble_mean_2031_2060', 'night_ensemble_mean_1991_2020',
)
delta_days_above_34_ensemble_mean = _period_delta(
    ds_tmax_30, ds_tmax_91,
    'heat_ensemble_mean_2031_2060', 'heat_ensemble_mean_1991_2020',
)

# Figure: a 2 x 3 grid of map panels, all in the Plate Carree projection.
# The axes array is flattened so panel k lines up with entry k of the
# per-panel lists below (reading order A-F).
fig, axes = plt.subplots(
    2,
    3,
    figsize=(30, 20),
    subplot_kw={'projection': ccrs.PlateCarree()},
)
axes = axes.flatten()

# Colormap per panel. 'RdBu' draws negative values red and 'RdBu_r' draws
# positive values red. NOTE(review): presumably chosen so the warming-type
# change is red in every panel -- confirm against the paper's legend.
cmaps = ['RdBu', 'RdBu', 'RdBu_r', 'RdBu_r', 'RdBu_r', 'RdBu_r']

# (delta field, panel title, 21 evenly spaced colour levels symmetric about 0)
delta_variables = [
    (delta_days_below_0_ensemble_mean, 'CDD (November-March)', np.linspace(-300, 300, 21)),
    (delta_last_frost_dates_ensemble_mean, 'Last Day of Spring Frost (January-July)', np.linspace(-15, 15, 21)),
    (delta_gdd_jan_apr_ensemble_mean, 'GDD Bud Break (January-April)', np.linspace(-200, 200, 21)),
    (delta_gdd_jan_sept_ensemble_mean, 'GDD General Growth (January-September)', np.linspace(-500, 500, 21)),
    (delta_days_above_34_ensemble_mean, 'Extreme Heat Days (June-August)', np.linspace(-30, 30, 21)),
    (delta_days_above_15_ensemble_mean, 'Warm Nights (August-September)', np.linspace(-15, 15, 21)),
]

# Colorbar label and corner letter for each panel, in the same order.
cbar_labels = [
    'Accumulation Days',
    'Days',
    'Accumulation Days',
    'Accumulation Days',
    'Days Above Threshold',
    'Days Above Threshold',
]
letters = 'ABCDEF'

for ax, (delta, title, levels), cmap, cbar_label, letter in zip(
    axes, delta_variables, cmaps, cbar_labels, letters
):
    # Rasterise the shapefile polygons onto this field's lon/lat grid and
    # blank (NaN) every cell that lies outside all polygons.
    region_mask = regionmask.mask_geopandas(gdf, delta['lon'].values, delta['lat'].values)
    masked_data = delta.where(region_mask.notnull())

    # Map frame: fixed lon/lat window, country borders, state outlines.
    ax.set_extent([-124.8, -67.06, 25.07, 50], crs=ccrs.PlateCarree())
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.STATES, linestyle='-')

    # Lat/lon tick labels on the left and bottom edges only; the grid
    # lines themselves are switched off.
    gl = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color='gray',
        alpha=0.5,
        linestyle='--',
    )
    gl.top_labels = gl.right_labels = False
    gl.xlabel_style = {'size': 12, 'color': 'gray'}
    gl.ylabel_style = {'size': 12, 'color': 'gray'}
    gl.xlines = gl.ylines = False

    # Draw the field with discrete colour levels; the colorbar is attached
    # separately so it can sit horizontally beneath the map.
    im = masked_data.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        levels=levels,
        add_colorbar=False,
    )
    cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.03, aspect=30, shrink=1)
    cbar.set_ticks(levels[::2])  # label every other level boundary
    cbar.set_label(cbar_label, fontweight='bold', fontsize=14)
    ax.set_title(title, fontsize=18, fontweight='bold')
    cbar.ax.tick_params(labelsize=16)

    # Panel letter in the lower-left corner (axes-fraction coordinates).
    ax.text(
        0.01, 0.01, letter,
        transform=ax.transAxes,
        fontsize=16,
        va='bottom',
        ha='left',
        fontweight='bold',
    )

# Final layout: a negative hspace shrinks the vertical gap between the two
# panel rows; the figure title is anchored at the left.
plt.subplots_adjust(hspace=-.6)
plt.suptitle('Fig. 3 Absolute Change of ensemble means (2031-2060 - 1991-2020)', x=0.10, ha='left', y=0.63, fontweight='bold', fontsize=28)

# Set to True to write the figure to disk.
# Saving must happen BEFORE plt.show(): with a blocking or inline backend,
# show() closes the figure, so a savefig() issued afterwards writes a new,
# empty figure.
SAVE_FIGURE = False
if SAVE_FIGURE:
    plt.savefig('/home/shawn_preston/CESM2PAPERFIGURES/Figure3.png', dpi=300, bbox_inches='tight')
    plt.savefig('/home/shawn_preston/CESM2PAPERFIGURES/Figure3.pdf', format='pdf', dpi=600, bbox_inches='tight')

# Show the plot
plt.show()
