# Calculate and plot model disagreement #

In [1]:
%load_ext autoreload
%autoreload 2
    
import os
import sys
import xarray as xr
import pandas as pd
import netCDF4 as nc
import numpy as np
from dask.distributed import Client
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

sys.path.append('/home/563/sc1326/repos/cdrmip_extremes')
from cdrmip_extremes.configs import data_dir, models, expts
from cdrmip_extremes import load_data, ext_freq, utils
from cdrmip_extremes.plotting import plot_extremes

In [2]:
# Start a local Dask client: 28 workers, one thread per worker, and
# memory_limit=None so dask applies no per-worker memory cap.
client = Client(
    n_workers=28,
    threads_per_worker=1,
    memory_limit=None,
)

## Load extreme frequency data ##

In [3]:
# Load the per-model extreme-frequency data. Later cells iterate it as a
# mapping of model -> dict of datasets holding 'heat_exceedances' and
# 'cold_exceedances' entries.
ext_freq_data = load_data.load_ext_freq_data()

In [4]:
# Per-model heat-extreme exceedance data, pulled out of each model's datasets.
heat_freq = {name: datasets['heat_exceedances']
             for name, datasets in ext_freq_data.items()}
# Per-model cold-extreme exceedance data, built the same way.
cold_freq = {name: datasets['cold_exceedances']
             for name, datasets in ext_freq_data.items()}

### Calculate differences between GWLs and model agreement ###

In [5]:
# ext_freq.calc_gwl_differences returns a pair (differences, agreement) for the
# given per-model exceedance data; heat and cold extremes are handled separately.
# `*_differences` is consumed below via `.values()` — presumably a dict keyed by
# model (TODO confirm against ext_freq.calc_gwl_differences).
heat_differences, heat_agreement = ext_freq.calc_gwl_differences(heat_freq)
cold_differences, cold_agreement = ext_freq.calc_gwl_differences(cold_freq)

### Calculate inter-model standard deviation ###

In [6]:
# Stack every model's GWL differences along a new 'model' dimension.
# coords='minimal' + compat='override': coordinates that do not already carry
# the 'model' dimension are taken from the first model without checking that
# they match across models, instead of being stacked/compared.
heat_freq_all = xr.concat(
    list(heat_differences.values()),
    dim='model',
    compat='override',
    coords='minimal',
)
cold_freq_all = xr.concat(
    list(cold_differences.values()),
    dim='model',
    compat='override',
    coords='minimal',
)

# Inter-model spread: standard deviation over the 'model' dimension
# (xarray's default ddof=0, i.e. population standard deviation).
heat_freq_std_dev = heat_freq_all.std(dim='model')
cold_freq_std_dev = cold_freq_all.std(dim='model')

## Plot ##

In [None]:
def add_colorbar(axes, norm, cmap, label, extend=False, custom_ticks=False,
                 custom_labels=False, fig=None):
    """Add a vertical colorbar in its own, explicitly positioned axes.

    Parameters
    ----------
    axes : sequence of 4 floats
        Rectangle ``[left, bottom, width, height]`` in figure coordinates for
        the new colorbar axes (passed to ``fig.add_axes``).
    norm : matplotlib.colors.Normalize
        Normalization mapping data values onto the colormap.
    cmap : matplotlib.colors.Colormap
        Colormap shown by the colorbar.
    label : str
        Colorbar label, drawn at fontsize 12.
    extend : {'neither', 'both', 'min', 'max'} or False, optional
        Which ends get triangular extensions. ``False`` (the default) and
        ``None`` mean ``'neither'``; matplotlib itself only accepts the strings.
    custom_ticks : sequence of float, optional
        Tick positions. When non-empty, the ticks are set, ``custom_labels``
        (if given) are applied, and tick labels are drawn at size 10.
    custom_labels : sequence of str, optional
        Tick labels matching ``custom_ticks``.
    fig : matplotlib.figure.Figure, optional
        Figure to draw on. Defaults to a global ``fig`` when one is defined
        (the historical behaviour), otherwise to the current figure.

    Returns
    -------
    cbar_ax : matplotlib.axes.Axes
        The newly added colorbar axes.
    cbar : matplotlib.colorbar.Colorbar
        The colorbar drawn into ``cbar_ax``.
    """
    if fig is None:
        # Backward compatibility: this helper used to draw on an implicit
        # notebook-global ``fig``; fall back to the current figure otherwise.
        fig = globals().get('fig')
        if fig is None:
            fig = plt.gcf()

    # matplotlib's Colorbar rejects anything but the four extend strings, so
    # the historical ``False`` default would raise a ValueError.
    if extend is None or extend is False:
        extend = 'neither'

    cbar_ax = fig.add_axes(axes)  # [left, bottom, width, height]
    # No ``pad``: it only applies when matplotlib carves space from a parent
    # axes and is ignored when an explicit ``cax`` is supplied.
    cbar = fig.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=cmap),
        cax=cbar_ax,
        orientation='vertical',
        extend=extend,
    )

    # Explicit sentinel checks: ``if custom_ticks:`` raises for NumPy arrays
    # with more than one element.
    if custom_ticks is not None and custom_ticks is not False and len(custom_ticks):
        cbar.set_ticks(custom_ticks)
        # Only relabel when labels were actually supplied (previously
        # ``set_yticklabels(False)`` was called when they were omitted).
        if custom_labels is not None and custom_labels is not False:
            cbar.ax.set_yticklabels(custom_labels)
        cbar.ax.tick_params(labelsize=10)
    cbar.set_label(label, fontsize=12)
    return cbar_ax, cbar