# In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import geopandas as gpd
import regionmask
from scipy.stats import linregress
import warnings

# Suppress every warning for the remainder of the session so the output stays uncluttered.
warnings.filterwarnings(action='ignore')

# Grid latitudes in degrees (26 ascending values), kept as literals because the
# stored precision cannot be regenerated exactly from a formula.
latitudes = np.array(
    [
        25.91623037, 26.85863874, 27.80104712, 28.7434555,
        29.68586387, 30.62827225, 31.57068063, 32.51308901,
        33.45549738, 34.39790576, 35.34031414, 36.28272251,
        37.22513089, 38.16753927, 39.10994764, 40.05235602,
        40.9947644, 41.93717277, 42.87958115, 43.82198953,
        44.76439791, 45.70680628, 46.64921466, 47.59162304,
        48.53403141, 49.47643979,
    ]
)

# Grid longitudes in degrees: 49 values from 235 to 295 in steps of 1.25.
# Every value is an exact multiple of 0.25, so this arithmetic form reproduces
# the literal list bit-for-bit.
longitudes = 235.0 + 1.25 * np.arange(49)

def earth_radius(lat):
    """Return the ellipsoidal Earth radius, in meters, at latitude ``lat``.

    ``lat`` is a geodetic latitude in degrees (scalar or array-like). It is
    converted to geocentric latitude before the radius of the oblate spheroid
    is evaluated, so the result runs from the semi-major axis at the equator
    to the semi-minor axis at the poles.
    """
    # Spheroid axes in meters (the values match the WGS84 ellipsoid).
    semi_major = 6378137
    semi_minor = 6356752.3142
    ecc_sq = 1 - (semi_minor**2 / semi_major**2)

    # Geodetic -> geocentric latitude, both in radians.
    geodetic = np.deg2rad(lat)
    geocentric = np.arctan((1 - ecc_sq) * np.tan(geodetic))

    numerator = semi_major * (1 - ecc_sq)**0.5
    denominator = (1 - (ecc_sq * np.cos(geocentric)**2))**0.5
    return numerator / denominator

def area_grid(lat, lon):
    """Return the approximate area of every cell on a lat/lon grid.

    ``lat`` and ``lon`` are coordinate vectors in degrees. The result is a 2-D
    array with one row per latitude and one column per longitude, in square
    meters (``earth_radius`` works in meters). Cell extents come from
    ``np.gradient`` of the coordinates, so a descending axis gives negative
    values.
    """
    lon_2d, lat_2d = np.meshgrid(lon, lat)

    # Angular spacing between neighbouring grid points, in radians.
    step_lat = np.deg2rad(np.gradient(lat_2d, axis=0))
    step_lon = np.deg2rad(np.gradient(lon_2d, axis=1))

    radius = earth_radius(lat_2d)

    # North-south extent, and east-west extent shrunk by cos(latitude).
    height = step_lat * radius
    width = step_lon * radius * np.cos(np.deg2rad(lat_2d))
    return height * width

def prepare_data(latitudes, longitudes, shapefile_path, data_path, variable_name, land_mask_variable, region, time_coord):
    """Load one metric's dataset and reduce it to regional summaries.

    Parameters
    ----------
    latitudes, longitudes : array-like
        Kept for interface compatibility; not used by this function.
        NOTE(review): the previous body derived normalized cell-area weights
        from these but never applied them to any returned value, so every
        spatial mean here is (and always was) unweighted. Confirm whether
        area weighting was intended.
    shapefile_path : str
        Vector file read with geopandas; it must have a ``RegionName`` column.
    data_path : str
        NetCDF file opened with xarray. It must provide ``lat`` and ``lon``
        and an ``ensemble`` dimension on the variables used here.
    variable_name : str
        Name of the per-ensemble trend variable.
    land_mask_variable : str
        Name of the yearly per-ensemble variable (despite the parameter name,
        this is a data variable, not a mask).
    region : str
        Pattern matched against ``RegionName`` with ``Series.str.contains``
        (a regular expression by default, so ``'A|B'`` selects both).
    time_coord : str
        Name of the time-like coordinate of the yearly variable; it is
        renamed to ``'year'`` on the returned yearly series.

    Returns
    -------
    tuple
        ``(data, mean_ensemble_trends, mean_gdd, yearly_max_trend,
        yearly_min_trend, max_trend_ensemble_id, min_trend_ensemble_id,
        masked_trends, region_df, years, land_mask)``. ``data`` is the open
        dataset; the caller is responsible for closing it.

    Raises
    ------
    ValueError
        If no row of the shapefile matches ``region``.
    """
    gdf = gpd.read_file(shapefile_path)
    # na=False: a missing RegionName counts as "no match" instead of putting
    # NaN into the boolean indexer.
    region_df = gdf[gdf['RegionName'].str.contains(region, na=False)]
    if region_df.empty:
        raise ValueError(f"No RegionName in {shapefile_path!r} matches {region!r}")

    data = xr.open_dataset(data_path)
    try:
        # True for grid cells that fall inside any of the selected regions.
        mask = regionmask.mask_geopandas(region_df, data['lon'].values, data['lat'].values)
        land_mask = mask.notnull()

        spatial_dims = ['lat', 'lon']

        # Mask each variable once; drop=True trims to the mask's bounding box.
        masked_trends = data[variable_name].where(land_mask, drop=True)
        masked_yearly = data[land_mask_variable].where(land_mask, drop=True)

        mean_ensemble_trends = masked_trends.mean(dim=spatial_dims)
        mean_gdd = masked_yearly.mean(dim='ensemble').mean(dim=spatial_dims)

        # Ensemble members with the largest / smallest regional-mean trend.
        max_trend_ensemble_id = mean_ensemble_trends.argmax(dim='ensemble').item()
        min_trend_ensemble_id = mean_ensemble_trends.argmin(dim='ensemble').item()

        yearly_max_trend = masked_yearly.isel(ensemble=max_trend_ensemble_id).mean(dim=spatial_dims)
        yearly_min_trend = masked_yearly.isel(ensemble=min_trend_ensemble_id).mean(dim=spatial_dims)

        # Give all yearly series the same coordinate name regardless of source file.
        yearly_max_trend = yearly_max_trend.rename({time_coord: 'year'})
        yearly_min_trend = yearly_min_trend.rename({time_coord: 'year'})
        mean_gdd = mean_gdd.rename({time_coord: 'year'})

        # Prefer an existing 'year' entry (original behaviour); fall back to the
        # declared time coordinate for files that only carry e.g. 'time'.
        try:
            years = data['year'].values
        except KeyError:
            years = data[time_coord].values
    except Exception:
        # The caller never receives the handle on failure, so release it here.
        data.close()
        raise

    return (
        data,
        mean_ensemble_trends,
        mean_gdd,
        yearly_max_trend,
        yearly_min_trend,
        max_trend_ensemble_id,
        min_trend_ensemble_id,
        masked_trends,
        region_df,
        years,
        land_mask,
    )

def _add_map_gridlines(ax):
    """Add bottom/left coordinate labels to ``ax`` without drawing grid lines."""
    gl = ax.gridlines(draw_labels=True, color='gray', alpha=0.5)
    gl.top_labels = False
    gl.right_labels = False
    gl.xlines = False
    gl.ylines = False
    gl.xlocator = plt.MultipleLocator(1)  # Tick positions at multiples of 1
    gl.ylocator = plt.MultipleLocator(1)
    gl.xlabel_style = {'size': 8}
    gl.ylabel_style = {'size': 8}


def plot_min_trends(metrics, latitudes, longitudes, shapefile_path, plot_path):
    """Map, for each metric, the trend field of its minimum-trend ensemble member.

    Parameters
    ----------
    metrics : iterable of tuple
        One 8-tuple per panel: ``(data_path, variable_name,
        land_mask_variable, ylabel, colorbar_label, cmap, levels,
        time_coord)``. The figure has 2 x 3 panels; panels without a metric
        are hidden.
    latitudes, longitudes : array-like
        Forwarded unchanged to ``prepare_data``.
    shapefile_path : str
        Region shapefile forwarded to ``prepare_data``.
    plot_path : str
        Destination of the saved PNG (600 dpi).

    Each panel shows the masked trend field, the region outline, the ensemble
    index, and dots where ``ensemble_p_values`` is below 0.05. The dataset
    opened for a panel is closed as soon as that panel has been drawn.
    """
    metrics = list(metrics)
    fig, axs = plt.subplots(2, 3, figsize=(18, 12), subplot_kw={'projection': ccrs.PlateCarree()})
    axs = axs.flatten()

    letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'  # Panel labels, indexed by panel number

    for idx, metric in enumerate(metrics):
        data_path, variable_name, land_mask_variable, ylabel, colorbar_label, cmap, levels, time_coord = metric
        # Only the pieces needed for the map are kept from prepare_data's 11-tuple.
        data, _, _, _, _, _, min_trend_ensemble_id, masked_trends, region_df, _, land_mask = prepare_data(
            latitudes, longitudes, shapefile_path, data_path, variable_name, land_mask_variable, 'Northwest|Southwest', time_coord
        )

        try:
            bounds = region_df.total_bounds  # (minx, miny, maxx, maxy)
            north_extend = bounds[3] + 0.1  # Small margin above the northern edge
            # Fixed offsets pull the east edge west and the south edge north.
            # NOTE(review): presumably this crops the view to the Northwest
            # although both regions are selected - confirm.
            x0, x1, y0, y1 = bounds[0], bounds[2] - 8.88, bounds[1] + 10.6, north_extend

            ax = axs[idx]
            ax.add_feature(cfeature.COASTLINE)
            ax.add_feature(cfeature.STATES, linestyle='-')
            plot = masked_trends.isel(ensemble=min_trend_ensemble_id).plot(
                ax=ax,
                transform=ccrs.PlateCarree(),
                cmap=cmap,
                levels=levels,
                add_colorbar=True,
                cbar_kwargs={
                    'orientation': 'horizontal',
                    'shrink': 0.87,
                    'pad': 0.05,
                    'ticks': levels[::3],  # Label every third level
                    'label': colorbar_label
                }
            )
            cbar = plot.colorbar
            cbar.set_label(colorbar_label, fontsize=16)
            cbar.ax.tick_params(labelsize=16)
            region_df.boundary.plot(ax=ax, edgecolor='black', linewidth=2, transform=ccrs.PlateCarree())
            ax.set_extent([x0, x1, y0, y1], crs=ccrs.PlateCarree())
            ax.text(-0.1, .95, f'{letters[idx]}', transform=ax.transAxes, fontsize=16, fontweight='bold', va='bottom', ha='right')
            ax.text(0.95, 0.90, f"{min_trend_ensemble_id}", transform=ax.transAxes, fontsize=16, fontweight='bold', ha='right')
            ax.set_title(ylabel, fontsize=16, fontweight='bold')

            # Stipple cells whose p-value for this ensemble member is below 0.05.
            p_values = data['ensemble_p_values'].isel(ensemble=min_trend_ensemble_id).where(land_mask, drop=True)
            significant = (p_values < 0.05).values
            lon, lat = np.meshgrid(masked_trends.lon, masked_trends.lat)
            ax.scatter(lon[significant], lat[significant], color='k', marker='.', s=9, transform=ccrs.PlateCarree())

            _add_map_gridlines(ax)
        finally:
            # Everything drawn above is already in memory; release the file handle.
            data.close()

    # Hide panels that received no metric so the figure has no empty map frames.
    for unused_ax in axs[len(metrics):]:
        unused_ax.set_visible(False)

    plt.subplots_adjust(hspace=-0.1)  # Negative spacing pulls the two rows together
    # NOTE(review): the title names only the Northwest while the region filter
    # above is 'Northwest|Southwest' - confirm this is intended.
    plt.suptitle('Minimum Climate Extreme Trends for Northwest (1980-2064)', fontsize=26, fontweight='bold', y=.84)
    plt.savefig(plot_path, format='png', dpi=600, bbox_inches='tight')
    plt.show()

# One entry per figure panel, in plotting order. Each tuple is
# (data_path, variable_name, land_mask_variable, ylabel, colorbar_label, cmap, levels, time_coord),
# which is the order plot_min_trends unpacks it in.
# Every `levels` array is symmetric about zero so the diverging colormap is
# centred: the arange stop sits just past the positive endpoint to include it.
# NOTE(review): 'gddcdesm2jansept' differs in spelling from 'gddcesm2janapr' -
# confirm the directory name is correct on disk.
metrics = [
    ('/home/shawn_preston/daysbelow0ensemble/cdd_trend_analysis.nc', 'ensemble_trends', 'ensemble_cdd_yearly', 'Cold Degree Days', 'CDD/Yr', 'RdBu', np.arange(-5, 6, 1), 'time'),
    ('/home/shawn_preston/lastdayfrostensemble/trend_last_frost_dates_ensemble.nc', 'ensemble_trends', 'ensemble_last_frost_dates_yearly', 'Days Since Jan. 1', 'Frost Day/Yr', 'RdBu', np.arange(-1, 1.1, .2), 'year'),
    ('/home/shawn_preston/gddcesm2janapr/gdd_trends_analysis.nc', 'ensemble_trends', 'ensemble_gdd_yearly', 'GDD Bud Break', 'GDD/Yr', 'RdBu_r', np.arange(-5, 6, 1), 'year'),
    # Stop at 21 (not 20) so +20 is included, matching the -20 lower bound.
    ('/home/shawn_preston/gddcdesm2jansept/gdd_trends_analysis.nc', 'ensemble_trends', 'ensemble_gdd_yearly', 'GDD General Growth', 'GDD/Yr', 'RdBu_r', np.arange(-20, 21, 2), 'year'),
    ('/home/shawn_preston/tmax34ensemble/trend_days_above_threshold_analysis.nc', 'ensemble_trends', 'ensemble_days_above_yearly', 'Extreme Heat Days', 'Extreme Heat Days/Yr', 'RdBu_r', np.arange(-.5, .51, .1), 'year'),
    ('/home/shawn_preston/tmin15ensemble/trend_days_above_threshold_analysis.nc', 'ensemble_trends', 'ensemble_days_above_yearly', 'Warm Nights', 'Warm Nights/Yr', 'RdBu_r', np.arange(-.5, .51, .1), 'year'),
]

# Draw the figure for all configured metrics and write it to disk.
plot_min_trends(
    metrics=metrics,
    latitudes=latitudes,
    longitudes=longitudes,
    shapefile_path='/home/shawn_preston/NCA/NCA_Regions.shp',
    plot_path='/home/shawn_preston/CESM2PAPERFIGURES/Min_Trends.png',
)
