# In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import cartopy.feature as cfeature
import matplotlib as mpl
import xarray as xr
import geopandas as gpd
import regionmask

# Load the data from NetCDF files
ds_1991_2020 = xr.open_dataset('/home/shawn_preston/tmax34ensemble/tmax1591_20.nc')
ds_2031_2060 = xr.open_dataset('/home/shawn_preston/tmax34ensemble/tmax1531_60.nc')

# Load the shapefile
gdf = gpd.read_file('/home/shawn_preston/shapefile_nation/combined_conus.shp')

# Longitude/latitude axes of the baseline grid. They are used both to build the
# shapefile mask and as the pcolormesh coordinates of every panel.
lon_1d = ds_1991_2020['lon'].values
lat_1d = ds_1991_2020['lat'].values

# Grid-point mask from the shapefile: regionmask labels each grid point with the
# number of the polygon that contains it, and NaN outside all polygons.
mask = regionmask.mask_geopandas(gdf, lon_1d, lat_1d)

# Turn each member's count into a per-year value by dividing by the 30 years in
# each period.
# NOTE(review): assumes both variables hold 30-year totals per grid point and
# share the same member/lat/lon layout; confirm against the NetCDF files.
days_1991_2020_ensemble = ds_1991_2020['heat_ensemble_1991_2020'] / 30
days_2031_2060_ensemble = ds_2031_2060['heat_ensemble_2031_2060'] / 30

# Absolute (i.e. in days, not percent) change for each ensemble member: future
# minus baseline. The sign is deliberately kept. The figure uses a symmetric
# diverging -30..+30 colour scale, so increases and decreases must fall on
# opposite halves of it. Taking np.abs() here would fold every decrease onto
# the "more heat days" side and leave the blue half of the scale unused.
delta_days_ensemble_members = days_2031_2060_ensemble - days_1991_2020_ensemble

# Colour-scale edges: -30 .. +30 in steps of 2 (31 edges -> 30 bins), with a
# tick label on every second edge.
boundaries = np.linspace(-30, 30, 31)
ticks = boundaries[0::2]

# Diverging red/blue palette resampled to exactly one colour per bin, plus a
# BoundaryNorm that maps each value onto the bin whose edges enclose it.
cmap = plt.get_cmap('RdBu_r', boundaries.size - 1)
norm = mpl.colors.BoundaryNorm(boundaries, ncolors=cmap.N)

# 10 rows x 5 columns of PlateCarree map panels, packed with no gaps between
# neighbouring panels.
nrows, ncols = 10, 5
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(30, 40),
    subplot_kw=dict(projection=ccrs.PlateCarree()),
    gridspec_kw=dict(wspace=0, hspace=0),
)

# Draw one map panel per ensemble member, filling the grid row by row.
# Points outside every shapefile polygon are NaN in `mask`. Keeping only the
# non-NaN points clips the field to the shapefile no matter how many polygons
# it contains. Comparing against a single region number would keep only that
# one polygon. The mask is the same for every member, so compute it once.
# NOTE(review): assumes the first dimension of `delta_days_ensemble_members` is
# the ensemble-member dimension and that each member slice is laid out
# (lat, lon) like `mask`; confirm against the NetCDF files.
inside_shapefile = mask.notnull().values
n_members = len(delta_days_ensemble_members)

for i, ax in enumerate(axes.flat):
    if i >= n_members:
        # Fewer members than grid cells: hide the leftover axes so they do not
        # render as empty framed maps.
        ax.set_visible(False)
        continue

    ax.set_extent([-126, -67, 25, 50], ccrs.PlateCarree())
    ax.add_feature(cfeature.BORDERS, linestyle='-')
    ax.add_feature(cfeature.STATES, linestyle='-')

    delta_days = delta_days_ensemble_members[i]
    delta_days_masked = np.where(inside_shapefile, delta_days, np.nan)

    pcm = ax.pcolormesh(lon_1d, lat_1d, delta_days_masked,
                        transform=ccrs.PlateCarree(), cmap=cmap, norm=norm)

    # Label the panel with its 1-based ensemble number in the bottom-left corner.
    ax.text(0.05, 0.05, f'{i + 1}', transform=ax.transAxes, fontsize=16,
            color='black', ha='left', va='bottom', fontweight='bold')

# Shared horizontal colorbar along the bottom of the figure. It is built from a
# standalone ScalarMappable that carries the same cmap/norm every panel uses,
# so it does not rely on a `pcm` left over from the panel loop. That variable
# would be undefined if no panel had been drawn.
colorbar_mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar_ax = fig.add_axes([0.04, 0.04, 0.92, 0.02])
cbar = fig.colorbar(colorbar_mappable, cax=cbar_ax, orientation='horizontal',
                    boundaries=boundaries, ticks=ticks)
cbar.ax.tick_params(labelsize=16)
cbar.set_label('Absolute Change In Extreme Heat Days', weight='bold', fontsize=18)
for label in cbar.ax.get_xticklabels():
    label.set_fontweight('bold')

# Set the outer margins and add the title. Inter-panel spacing is not set here.
# The GridSpec created by plt.subplots(gridspec_kw=...) carries its own
# wspace/hspace, and matplotlib gives non-None GridSpec values precedence over
# figure-level ones, so passing hspace/wspace here would have no effect.
fig.subplots_adjust(left=0.04, right=0.96, top=0.94, bottom=0.08)
fig.suptitle('Fig. 4 Absolute Change in Extreme Heat Days (JJA) for CESM2-LE Ensembles \n(2031-2060 vs. 1991-2020)',
             fontsize=40, weight='bold', y=0.99, ha='left', x=0.04)

# Save the figure
fig.savefig('/home/shawn_preston/CESM2PAPERFIGURES/Figure4.png', dpi=300, bbox_inches='tight')
plt.show()
