# Surface air temperature (SAT) calculations #

In [1]:
%load_ext autoreload
%autoreload 2
    
import os
import sys
from dask.distributed import Client
import matplotlib.pyplot as plt
import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation and runtime
# warnings raised by the libraries used below — consider narrowing it.
warnings.filterwarnings('ignore')

# Make the local checkout of the project importable.
# NOTE(review): hard-coded, user-specific path — adjust it (or install the
# package) before running this notebook from another account or machine.
sys.path.append('/home/563/sc1326/repos/cdrmip_extremes')
from cdrmip_extremes.configs import data_dir, models, expts
from cdrmip_extremes import load_data, sat, utils

In [2]:
# Start the Dask client with 28 workers, one thread each.
# NOTE(review): memory_limit=None presumably disables the per-worker memory
# cap — confirm against the dask.distributed documentation.
client = Client(
    n_workers=28,
    threads_per_worker=1,
    memory_limit=None,
)

## Load data ##

In [3]:
# Load the tas anomaly data. It is iterated below with .items() as
# (model name, dataset) pairs.
tas_anom = load_data.load_tas_anom()

In [4]:
# Load the GWL crossing years; indexed by model name in the cells below.
gwl_years = load_data.load_gwl_years()

In [5]:
# Model used for the spot check in the next cell. Note that later loops reuse
# the name `model` as their loop variable and will overwrite this value.
model='NorESM2-LM'

In [6]:
# Display the raw crossing-year array for the selected model.
# NOTE(review): presumably one row per GWL and one column per branch
# (ramp-up / ramp-down) — confirm against load_gwl_years.
gwl_years[model].values

array([[ 70, 214],
       [ 94, 197],
       [127, 168]])

### Extract mean temperatures over global warming level (GWL) periods (21-year average centred on the time of GWL crossing) ###

In [7]:
# For every model, extract the 21-year GWL period from its tas anomalies
# using that model's crossing years; results are keyed by model name.
gwl_periods = {
    model: utils.extract_gwl_period(ds, gwl_years[model], 21, time_dim='time')
    for model, ds in tas_anom.items()
}

### Extract mean temperatures over final 21-year period and equivalent ramp-up GWL ###

In [8]:
# Load the equivalent-GWL matches. The next cell selects one model at a time
# and reads its `tas` and `year` values.
match_ds = load_data.load_equiv_gwls()

In [9]:
# For every model, look up its matched final GWL (`tas`) and exceedance year
# (`year`) and extract the corresponding 21-year periods.
equiv_gwl_periods = {}
for model, ds in tas_anom.items():
    match = match_ds.sel(model=model)
    equiv_gwl_periods[model] = utils.extract_equiv_gwl_period(
        ds, match.tas.values, match.year.values, window=21, time_dim='time'
    )

In [10]:
# Spot check: show the `year` values that remain on the ramp_up branch for
# ACCESS-ESM1-5 once entries with missing data along `year` are dropped.
equiv_gwl_periods['ACCESS-ESM1-5'].sel(branch='ramp_up').dropna(dim='year').year

### Calculate differences between ramp-up and ramp-down periods ###

In [11]:
# Compare the GWL means per model: once for the 1.5 GWL selection and once
# for the equivalent-GWL periods. Each result is keyed by model name.
gwl = 1.5
slices_15 = {
    model: utils.compare_gwl_means(ds.sel(gwl=gwl))
    for model, ds in gwl_periods.items()
}

slices_equiv = {
    model: utils.compare_gwl_means(ds)
    for model, ds in equiv_gwl_periods.items()
}

### Save both extracted tas periods ###

In [12]:
# Save each model/period slice of the 1.5 GWL data to its own file.
# Writing is switched off by default, matching the previous state of this
# cell (its write call was commented out, so the loop only built paths).
# Set SAVE_TAS_15 = True to (re)write the files.
SAVE_TAS_15 = False
save_dir = os.path.join(data_dir,'processed/tas/tas_15')
if SAVE_TAS_15:
    # Make sure the target directory exists before writing into it.
    os.makedirs(save_dir, exist_ok=True)
for model, ds_dict in slices_15.items():
    for period, ds in ds_dict.items():
        path = os.path.join(
            save_dir,
            f"{model}_tas_15_{period}.nc"
        )
        if SAVE_TAS_15:
            ds.to_netcdf(path)

In [14]:
# Save each model/period slice of the equivalent-GWL data to its own file.
save_dir = os.path.join(data_dir,'processed/tas/tas_equiv_gwl')
# Create the target directory up front so the first write does not fail on a
# fresh data directory.
os.makedirs(save_dir, exist_ok=True)
for model, ds_dict in slices_equiv.items():
    for period, ds in ds_dict.items():
        path = os.path.join(
            save_dir,
            f"{model}_tas_equiv_gwl_{period}.nc"
        )
        ds.to_netcdf(path)