In [1]:
%load_ext autoreload
%autoreload 2
    
import os
import sys
import xarray as xr
import pandas as pd
import netCDF4 as nc
import numpy as np
from dask.distributed import Client
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

sys.path.append('/home/563/sc1326/repos/cdrmip_extremes')
from cdrmip_extremes.configs import data_dir, models, expts
from cdrmip_extremes import load_data, ext_freq, utils
from cdrmip_extremes.plotting import plot_extremes

In [2]:
# Start a dask.distributed client for the computations below:
# 28 workers with one thread each. memory_limit=None is forwarded as-is;
# presumably this disables the per-worker memory cap — TODO confirm.
client = Client(
    n_workers=28,
    threads_per_worker=1,
    memory_limit=None,
)

### Load concatenated temperature data & pre-calculated GWL crossing years ###

In [3]:
# Concatenated temperature data; iterated further down as a
# {model: dataset} mapping via tas_data.items().
tas_data = load_data.load_tas_concat()

In [4]:
# Pre-calculated GWL crossing years; looked up per model (gwl_years[model])
# when the GWL periods are extracted.
gwl_years = load_data.load_gwl_years()

### Load thresholds & extreme data ###

In [5]:
# Extreme thresholds; indexed further down as
# thresholds[model]['heat_thresholds'] and thresholds[model]['cold_thresholds'].
thresholds = load_data.load_threshold_data()

### Select hottest and coldest months for cdr-reversibility experiment ###

In [6]:
# Load the monthly extreme data, requesting only the 'extreme_months' variable.
extreme_months_dict = load_data.load_monthly_extreme_data(
    ext_vars=['extreme_months']
)

# Flatten {model: {'extreme_months': obj}} to {model: obj} so each model's
# entry can be used directly.
extreme_months = {
    name: per_model['extreme_months']
    for name, per_model in extreme_months_dict.items()
}

#### Rename the 'max'/'min' extrema coordinate labels to 'heat'/'cold' to match the extreme types used below ####

In [7]:
# Relabel the `extrema` coordinate of every model's entry:
# 'max' -> 'heat', 'min' -> 'cold'; any other label is left untouched.
# Only existing keys are reassigned, so the dict keeps its size and order.
for model in list(extreme_months):
    months = extreme_months[model]
    labels = months.extrema
    cold_or_original = xr.where(labels == 'min', 'cold', labels)
    relabelled = xr.where(labels == 'max', 'heat', cold_or_original)
    extreme_months[model] = months.assign_coords(extrema=relabelled)


### Index data by GWL period ###

In [8]:
# For each model, pass its temperature data together with that model's
# extreme-month selection to ext_freq.select_extreme_month.
ext_month_tas = {
    model: ext_freq.select_extreme_month(ds, extreme_months[model])
    for model, ds in tas_data.items()
}

In [9]:
# Extract the GWL period for every model along the 'year' dimension, using the
# model's pre-calculated crossing years. The positional 21 is presumably the
# window length (cf. window=21 in the equivalent-GWL extraction) — TODO confirm.
gwl_periods = {
    model: utils.extract_gwl_period(ds, gwl_years[model], 21, time_dim='year')
    for model, ds in ext_month_tas.items()
}

### Repeat for end-of-simulation equivalent GWLs ###

In [10]:
# Equivalent-GWL match data; selected per model below with .sel(model=...),
# where its `tas` and `year` values are used as the final GWL and exceedance year.
match_ds = load_data.load_equiv_gwls()

In [11]:
# Extract, per model, the period matching its end-of-simulation equivalent GWL.
equiv_gwl_periods = {}
for model, ds in ext_month_tas.items():
    # Row of the match table belonging to this model.
    model_match = match_ds.sel(model=model)
    equiv_gwl_periods[model] = utils.extract_equiv_gwl_period(
        ds,
        model_match.tas.values,   # final GWL
        model_match.year.values,  # exceedance year
        window=21,
        time_dim='year',
    )

## Calculate frequency of extreme threshold exceedance ##

#### Firstly for 1.5, 2 and 3 degree GWLs ####

In [12]:
# (result key, threshold key, ext_type) for each kind of extreme,
# evaluated in this order for every model.
_gwl_exceedance_specs = (
    ('heat_exceedances', 'heat_thresholds', 'heat'),
    ('cold_exceedances', 'cold_thresholds', 'cold'),
)

# Exceedances of the heat and cold thresholds within the GWL periods.
exceedances = {}
for model in models:
    model_tas = gwl_periods[model].tas
    model_thresholds = thresholds[model]
    exceedances[model] = {}
    for result_key, threshold_key, kind in _gwl_exceedance_specs:
        exceedances[model][result_key] = ext_freq.calculate_exceedances(
            model_tas,
            model_thresholds[threshold_key],
            ext_type=kind,
        )

##### Calculate multi-model median #####

In [13]:
# Stack every model's exceedances along a new 'model' dimension ...
cold_stack = xr.concat(
    [exceedances[model]['cold_exceedances'] for model in models],
    dim='model',
    compat='override',
    coords='minimal',
)
heat_stack = xr.concat(
    [exceedances[model]['heat_exceedances'] for model in models],
    dim='model',
    compat='override',
    coords='minimal',
)

# ... then reduce over that dimension to get the multi-model median.
median_cold = cold_stack.median(dim='model')
median_heat = heat_stack.median(dim='model')

# Store the median next to the individual models under the key 'median'.
exceedances['median'] = {
    'heat_exceedances': median_heat,
    'cold_exceedances': median_cold,
}

#### Then for final 21-year period and equivalent ramp-up GWL ####

In [14]:
# (result key, threshold key, ext_type) for each kind of extreme,
# evaluated in this order for every model.
_final_exceedance_specs = (
    ('heat_exceedances', 'heat_thresholds', 'heat'),
    ('cold_exceedances', 'cold_thresholds', 'cold'),
)

# Exceedances of the heat and cold thresholds within the equivalent-GWL periods.
exceedances_final = {}
for model in models:
    model_tas = equiv_gwl_periods[model].tas
    model_thresholds = thresholds[model]
    exceedances_final[model] = {}
    for result_key, threshold_key, kind in _final_exceedance_specs:
        exceedances_final[model][result_key] = ext_freq.calculate_exceedances(
            model_tas,
            model_thresholds[threshold_key],
            ext_type=kind,
        )

##### Calculate multi-model median #####

In [15]:
# Stack every model's equivalent-GWL exceedances along a new 'model' dimension ...
cold_stack = xr.concat(
    [exceedances_final[model]['cold_exceedances'] for model in models],
    dim='model',
    compat='override',
    coords='minimal',
)
heat_stack = xr.concat(
    [exceedances_final[model]['heat_exceedances'] for model in models],
    dim='model',
    compat='override',
    coords='minimal',
)

# ... then reduce over that dimension to get the multi-model median.
median_cold = cold_stack.median(dim='model')
median_heat = heat_stack.median(dim='model')

# Store the median next to the individual models under the key 'median'.
exceedances_final['median'] = {
    'heat_exceedances': median_heat,
    'cold_exceedances': median_cold,
}

## Save Data ##

In [16]:
# Write the GWL exceedances (every model plus the 'median' entry) to
# <data_dir>/processed/extremes/<var>/<model>_<var>.nc, creating each
# variable directory on demand.
save_dir = os.path.join(data_dir, 'processed/extremes')
for model in exceedances:
    for var in exceedances[model]:
        out_dir = os.path.join(save_dir, var)
        os.makedirs(out_dir, exist_ok=True)
        exceedances[model][var].to_netcdf(
            os.path.join(out_dir, f"{model}_{var}.nc")
        )

In [17]:
# Write the final-period / equivalent-GWL exceedances (every model plus the
# 'median' entry) to
# <data_dir>/processed/extremes/<var>_final/<model>_<var>_final.nc.
#
# This must iterate over `exceedances_final`. Looping over `exceedances`
# here would write a second copy of the GWL results under the *_final
# names and never save the final-period data.
save_dir = os.path.join(
    data_dir,'processed/extremes'
)
for model, ds_dict in exceedances_final.items():
    for var, ds in ds_dict.items():
        var_dir = os.path.join(save_dir,f"{var}_final")
        # Create the per-variable output directory on demand.
        os.makedirs(var_dir,exist_ok=True)
        path = os.path.join(
            var_dir,
            f"{model}_{var}_final.nc"
        )
        ds.to_netcdf(path)