In [22]:
%load_ext autoreload
%autoreload 2
    
import os
import sys
import xarray as xr
import pandas as pd
import netCDF4 as nc
import numpy as np
from dask.distributed import Client
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

sys.path.append('/home/563/sc1326/repos/cdrmip_extremes')
from cdrmip_extremes.configs import data_dir, models, expts
from cdrmip_extremes import load_data, ext_freq
from cdrmip_extremes.plotting import plot_extremes

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Start a local dask cluster: 28 workers with one thread each and
# memory_limit=None (dask treats None as "no per-worker memory limit").
client = Client(
    n_workers=28,
    threads_per_worker=1,
    memory_limit=None,
)

## Load extreme frequency data ##

In [12]:
# Load the extreme-frequency data for all models. The next cell indexes the
# result as ext_freq_data[model]['heat_exceedances'] / ['cold_exceedances'],
# so it is presumably a dict keyed by model name -- TODO confirm in load_data.
ext_freq_data = load_data.load_ext_freq_data()

In [15]:
def _select_exceedances(var_name):
    """Return a ``{model: entry}`` dict holding ``var_name`` for every model in ``ext_freq_data``."""
    return {
        model_name: model_vars[var_name]
        for model_name, model_vars in ext_freq_data.items()
    }


# Split the per-model data into separate heat and cold exceedance mappings.
heat_freq = _select_exceedances('heat_exceedances')
cold_freq = _select_exceedances('cold_exceedances')

### Calculate differences between global warming levels (GWLs) and model agreement ###

In [23]:
# Run the same GWL-difference / model-agreement calculation on the heat and
# cold exceedance mappings; each call is unpacked into (differences, agreement).
heat_differences, heat_agreement = ext_freq.calc_gwl_differences(heat_freq)
cold_differences, cold_agreement = ext_freq.calc_gwl_differences(cold_freq)

## Plot multi-model median ##

In [None]:
def _add_horizontal_colorbar(fig, rect, norm, cmap, label, **colorbar_kwargs):
    """Add a standalone horizontal colorbar to ``fig``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives a new colorbar axes.
    rect : sequence of 4 floats
        ``[left, bottom, width, height]`` of the colorbar axes in figure
        coordinates (passed straight to ``fig.add_axes``).
    norm, cmap
        Normalisation and colormap the colorbar represents.
    label : str
        Colorbar label, drawn at fontsize 16 (tick labels use fontsize 14).
    **colorbar_kwargs
        Extra keywords forwarded to ``fig.colorbar`` (e.g. ``extend='both'``).

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    cax = fig.add_axes(rect)
    cbar = fig.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=cmap),
        cax=cax,
        orientation='horizontal',
        **colorbar_kwargs,
    )
    cbar.ax.tick_params(labelsize=14)
    cbar.set_label(label, fontsize=16)
    return cbar


def plot_heat_freq(exceedance, gwl, model, agreement, levels, levels_diff,
                   extreme_type='Heat', cmap='YlOrRd'):
    """Plot ramp-up, ramp-down and ramp-down minus ramp-up maps at one GWL.

    Draws a 3x3 grid of Robinson-projection maps. Rows are the thresholds
    'sigma1', 'sigma2' and 'sigma3'. Columns are ramp-up, ramp-down and their
    difference (ramp-down minus ramp-up). Hatching from ``agreement`` is
    overlaid on the difference column. The figure is shown with
    ``plt.show()``; nothing is returned.

    Parameters
    ----------
    exceedance : mapping of xarray objects
        Indexed by 'sigma1', 'sigma2' and 'sigma3'. Each entry must have
        ``gwl`` and ``simulation`` ('ramp_up' / 'ramp_down') coordinates.
        After selection, each field is drawn with ``imshow`` over the extent
        lon 0-360, lat -90..90 with origin='lower'. This presumably means
        (lat, lon) with ascending latitude -- TODO confirm.
    gwl
        Global warming level, selected via ``.sel(gwl=gwl)`` and shown in
        the title as "<gwl>°C".
    model : str
        Label used only in the figure title.
    agreement : xarray.Dataset
        Must have ``lon``/``lat`` coordinates, a ``gwl`` coordinate and one
        variable per threshold name. Its values are contoured with hatching
        on the difference column. Presumably these are NaN where models
        disagree -- TODO confirm against ext_freq.calc_gwl_differences.
    levels : sequence of float
        BoundaryNorm boundaries for the ramp-up / ramp-down panels.
    levels_diff : sequence of float
        BoundaryNorm boundaries for the difference panels.
    extreme_type : str, optional
        Word used in the title ("Heat" by default, e.g. "Cold").
    cmap : str or Colormap, optional
        Colormap for the ramp-up / ramp-down panels (default 'YlOrRd').
        The difference panels always use 'RdBu_r'.
    """
    # Imported here so the function works on a fresh kernel; the notebook's
    # import cell does not bring these names into scope.
    import cartopy.crs as ccrs
    from cartopy.util import add_cyclic_point
    from matplotlib.colors import BoundaryNorm

    thresholds = ['sigma1', 'sigma2', 'sigma3']
    # Raw strings: '\s' is an invalid escape sequence in a normal literal.
    threshold_symbs = [r'1$\sigma$', r'2$\sigma$', r'3$\sigma$']
    columns = ['ramp_up', 'ramp_down', 'difference']
    column_titles = ['Ramp Up', 'Ramp Down', 'Ramp Down - Ramp Up']
    column_labels = ['(a)', '(b)', '(c)']
    diff_cmap = 'RdBu_r'

    # Extract the fields at this GWL for each threshold, plus their difference.
    to_plot = {}
    for threshold in thresholds:
        at_gwl = exceedance[threshold].sel(gwl=gwl)
        ramp_up = at_gwl.sel(simulation='ramp_up')
        ramp_down = at_gwl.sel(simulation='ramp_down')
        to_plot[threshold] = {
            'ramp_up': ramp_up,
            'ramp_down': ramp_down,
            'difference': ramp_down - ramp_up,
        }

    fig, axes = plt.subplots(
        len(thresholds),
        len(columns),
        subplot_kw={"projection": ccrs.Robinson()},
        figsize=(16, 9),
        sharey=True,
    )

    norm_abs = BoundaryNorm(levels, ncolors=plt.get_cmap(cmap).N, clip=True)
    norm_diff = BoundaryNorm(levels_diff, ncolors=plt.get_cmap(diff_cmap).N, clip=True)

    for row, threshold in enumerate(thresholds):
        for col, column in enumerate(columns):
            ax = axes[row, col]
            is_diff = column == 'difference'
            ax.imshow(
                to_plot[threshold][column],
                transform=ccrs.PlateCarree(),
                cmap=diff_cmap if is_diff else cmap,
                origin='lower',
                extent=(0, 360, -90, 90),
                norm=norm_diff if is_diff else norm_abs,
            )
            ax.coastlines()
        # Row label (1σ / 2σ / 3σ) to the left of the first column.
        axes[row, 0].annotate(
            threshold_symbs[row], xy=(-0.025, 0.5), xycoords="axes fraction",
            ha="right", va="center", fontsize=20, rotation=0,
        )

    # One shared colorbar under the first two columns, and one under the difference column.
    _add_horizontal_colorbar(
        fig, [0.057, -0.02, 0.6, 0.025], norm_abs, cmap,
        "Frequency of occurrence (% of years within window)",
    )
    _add_horizontal_colorbar(
        fig, [0.7, -0.02, 0.265, 0.025], norm_diff, diff_cmap,
        "Change in frequency (difference in % of years)",
        extend='both',
    )

    # Column titles, with (a)/(b)/(c) panel labels left-aligned at title height.
    for col, (title, panel_label) in enumerate(zip(column_titles, column_labels)):
        top_ax = axes[0, col]
        top_ax.set_title(title, fontsize=20)
        top_ax.text(
            0.0, 1.02, panel_label,
            transform=top_ax.transAxes,
            fontsize=20, va='bottom', ha='left',
        )

    # Stippling: hatch the difference column wherever the agreement field is contoured.
    lon = np.asarray(agreement.lon)
    lat = np.asarray(agreement.lat)
    agreement_at_gwl = agreement.sel(gwl=gwl)
    for row, threshold in enumerate(thresholds):
        # add_cyclic_point appends a wrap-around longitude column (along the
        # last axis, so this assumes lon is the last dimension) to avoid a seam.
        hatch_data, cyclic_lon = add_cyclic_point(
            np.asarray(agreement_at_gwl[threshold]), coord=lon
        )
        axes[row, 2].contourf(
            cyclic_lon,
            lat,
            hatch_data,
            transform=ccrs.PlateCarree(),
            levels=50,
            colors=[(0, 0, 0, 0.5)],
            hatches=['/////'],
            alpha=0,  # transparent fill so only the hatch pattern is intended to show
        )

    # '\u00b0' is the degree sign; the original literal was mis-encoded as 'Â°'.
    fig.suptitle(
        f'{model} Change in {extreme_type} Extreme Frequency at {gwl}\u00b0C',
        y=1.01, fontsize=25,
    )
    fig.tight_layout()
    plt.show()