In [None]:
%matplotlib inline

import intake
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy import stats
import warnings

import analysis_utils as au

In [None]:
# Re-import the local helper module so that edits to analysis_utils.py are
# picked up without restarting the kernel.  The imports cell binds the module
# to the alias `au`; no name `ru` exists in this notebook, so reloading `ru`
# raised a NameError.
import importlib
importlib.reload(au)

#### Load dataset

In [None]:
# Load dataset
# The .npy file holds a pickled dict that is indexed further down as
# dset_dict[variable][model].
# NOTE: loading with allow_pickle=True can execute arbitrary code embedded in
# the file -- only load collections you created yourself or otherwise trust.
collection_fname = 'dset_dict_historical.npy'
#collection_fname = 'dset_dict_piControl.npy'

# allow_pickle expects a bool; the string 'TRUE' only worked because any
# non-empty string is truthy.
dset_dict = np.load(collection_fname, allow_pickle=True).item()

# Model names are taken from the first variable in the collection.
# presumably every variable carries the same set of models -- TODO confirm
first_dset = list(dset_dict.keys())[0]
models_intersect = dset_dict[first_dset].keys()

#### Regression of sea ice extent on Arctic temperature for each model, using multiple ensemble members

In [None]:
# Set maximum number of ensemble members to look at for each model.
# N.B. doing this for anomalies gives you the same answer as not using
# anomalies, which makes sense when you think about it.
max_ems = 5

# Per-model results: slopes_all[model][i] and r_all[model][i], where i is the
# position of the ensemble member in the (capped) member list.
slopes_all, r_all = {}, {}

# Silence warnings only while the regressions run.  The context manager
# restores the previously active warning filters on exit -- also when a
# regression call raises -- instead of leaving 'ignore' switched on or forcing
# the filters to 'default' afterwards.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore')

    for m in models_intersect:
        # Ensemble members for this model, capped at max_ems (the slice is a
        # no-op when fewer members are available).
        ems = dset_dict['siconc'][m]['member_id'].values[:max_ems]
        print(m, len(ems))

        # Perform regression for every selected member.
        # [2, 8] presumably selects March and September as 0-based month
        # indices (matches the labels used when the results are printed and
        # plotted); the False argument is presumably a plot switch --
        # confirm both against analysis_utils.
        slopes_all[m], r_all[m] = {}, {}
        for i, em in enumerate(ems):
            print(em)
            slopes_all[m][i], r_all[m][i] = au.scatter_tas_SIE_linreg(
                dset_dict['tas'][m]['tas_arc_mean'].sel(member_id=em),
                dset_dict['siconc'][m]['sie_tot_arc'].sel(member_id=em),
                [2, 8], False, m)

In [None]:
# Average the per-member slopes and r values of each model over its ensemble
slopes_mean, r_mean = {}, {}

print('Model, slopes (mar, sept), r (mar, sept)')
print()
for m in models_intersect:
    member_keys = slopes_all[m].keys()

    # Collect one entry per ensemble member, then average along axis 0
    # (the member axis) so the per-month structure of each entry is kept.
    slopes_mean[m] = np.mean([slopes_all[m][k] for k in member_keys], 0)
    r_mean[m] = np.mean([r_all[m][k] for k in member_keys], 0)

    print(m, slopes_mean[m], r_mean[m])

#### Save

In [None]:
# Save dictionaries for future use.
#
# np.save() writes exactly one array per file; its third positional argument
# is `allow_pickle`, not a second array.  The earlier call
# np.save(results_fname, slopes_mean, r_mean) therefore stored only
# slopes_mean and silently dropped r_mean.  slopes_mean is still written to
# results_fname (unchanged content, so existing readers keep working) and
# r_mean now goes to its own sibling file.
results_fname = 'results_' + collection_fname[10:]       # [10:] drops the 'dset_dict_' prefix
r_results_fname = 'results_r_' + collection_fname[10:]
save_flag = True
if save_flag and dset_dict:
    # Each dict is stored as a pickled 0-d object array; read it back with
    # np.load(fname, allow_pickle=True).item().
    np.save(results_fname, slopes_mean, allow_pickle=True)
    np.save(r_results_fname, r_mean, allow_pickle=True)

#### Plots

In [None]:
# Plot slopes for all models
# Rows: 0 = March, 1 = September; one column per model, in slopes_mean order.
models_plot = list(slopes_mean.keys())
slopes_plot = np.zeros((2, len(models_plot)))

for im, m in enumerate(models_plot):
    slopes_plot[0, im] = slopes_mean[m][0]
    slopes_plot[1, im] = slopes_mean[m][1]

fig = plt.figure(figsize=(10,6))
plt.pcolormesh(slopes_plot, cmap='Reds_r')
# Colour limits must be given as (vmin, vmax) with vmin <= vmax; the reversed
# pair (0, -1.2) is not a valid normalisation range.  With the reversed
# colormap the most negative slopes are drawn darkest.
plt.clim(-1.2, 0)
# Tick positions and labels come from the same key list that filled the
# columns, so a label can never end up on another model's column.
plt.xticks(np.arange(len(models_plot)) + 0.5, models_plot, fontsize=14, rotation='vertical')
plt.yticks([0.5,1.5],['March','September'],fontsize=14)
plt.title('Slope: mean Arctic temperature vs. total Arctic sea ice extent', fontsize=18)
plt.colorbar(label='Slope ((10$^{6}$ km$^{2}$)/K)')

In [None]:
# Plot R squared for all models
# Rows: 0 = March, 1 = September; one column per model, in r_mean order.
r_model_names = list(r_mean.keys())
r_plot = np.zeros((2, len(r_model_names)))

for col, model_name in enumerate(r_model_names):
    r_march = r_mean[model_name][0]
    r_sept = r_mean[model_name][1]
    # Square the ensemble-mean correlation coefficient to get R^2
    r_plot[0, col] = r_march * r_march
    r_plot[1, col] = r_sept * r_sept

fig = plt.figure(figsize=(10,6))
plt.pcolormesh(r_plot, cmap='Reds')
plt.clim(0, 0.75)
tick_positions = np.arange(0, len(models_intersect), 1) + 0.5
tick_labels = list(slopes_all.keys())
plt.xticks(tick_positions, tick_labels, fontsize=14, rotation='vertical')
plt.yticks([0.5,1.5],['March','September'],fontsize=14)
plt.title('R$^{2}$: mean Arctic temperature vs. total Arctic sea ice extent', fontsize=18)
plt.colorbar(label='R$^{2}$')

### Sorting by September slope

In [None]:
# Index every model by its position in slopes_mean:
#   dumb[n]      -> September slope of the n-th model (element 1 of its entry)
#   dumb_name[n] -> name of the n-th model
dumb = {n: slopes_mean[name][1] for n, name in enumerate(slopes_mean.keys())}
dumb_name = dict(enumerate(slopes_mean.keys()))

In [None]:
# Order the model indices by ascending September slope (most negative first);
# dicts keep insertion order, so iterating dumb_sorted yields that order.
dumb_sorted = dict(sorted(dumb.items(), key=lambda pair: pair[1]))

# Rebuild the mean-slope and mean-r dictionaries in the sorted model order.
sorted_names = [dumb_name[idx] for idx in dumb_sorted]
slopes_sorted = {model: slopes_mean[model] for model in sorted_names}
r_sorted = {model: r_mean[model] for model in sorted_names}

In [None]:
# Plot slopes for all models, ordered by September slope
# Rows: 0 = March, 1 = September; one column per model, in slopes_sorted order.
models_splot = list(slopes_sorted.keys())
slopes_splot = np.zeros((2, len(models_splot)))

for im, m in enumerate(models_splot):
    slopes_splot[0, im] = slopes_sorted[m][0]
    slopes_splot[1, im] = slopes_sorted[m][1]

fig = plt.figure(figsize=(10,6))
plt.pcolormesh(slopes_splot, cmap='Reds_r')
# Colour limits must be given as (vmin, vmax) with vmin <= vmax; the reversed
# pair (0, -1.2) is not a valid normalisation range.
plt.clim(-1.2, 0)
# Label the columns with the *sorted* model names.  Taking the labels from
# slopes_all (unsorted order) attached each name to another model's column.
plt.xticks(np.arange(len(models_splot)) + 0.5, models_splot, fontsize=14,
           rotation='vertical')
plt.yticks([0.5,1.5],['March','September'],fontsize=14)
plt.title('Slope: mean Arctic temperature vs. total Arctic sea ice extent', fontsize=18)
plt.colorbar(label='Slope ((10$^{6}$ km$^{2}$)/K)')

In [None]:
# Plot R squared for all models, ordered by September slope
# Rows: 0 = March, 1 = September; one column per model, in r_sorted order.
# The array has to be created under the name that is filled and plotted
# (r_splot); allocating r_plot instead made the loop fail with a NameError.
models_rsplot = list(r_sorted.keys())
r_splot = np.zeros((2, len(models_rsplot)))

# Iterate in sorted order so the columns line up with the sorted labels.
for im, m in enumerate(models_rsplot):
    r_splot[0, im] = r_sorted[m][0] * r_sorted[m][0]
    r_splot[1, im] = r_sorted[m][1] * r_sorted[m][1]

fig = plt.figure(figsize=(10,6))
plt.pcolormesh(r_splot, cmap='Reds')
plt.clim(0, 0.75)
# Labels come from the same sorted key list that filled the columns.
plt.xticks(np.arange(len(models_rsplot)) + 0.5, models_rsplot,
           fontsize=14,rotation='vertical')
plt.yticks([0.5,1.5],['March','September'],fontsize=14)
plt.title('R$^{2}$: mean Arctic temperature vs. total Arctic sea ice extent', fontsize=18)
plt.colorbar(label='R$^{2}$')

In [None]:
import calendar

# Inputs for a single example regression: one model, one ensemble member.
MODEL = 'CESM2'
PLOTFLAG = True
MONTHS_IN = [0,1,2]

# Arctic-mean temperature and total Arctic sea ice extent of that member
TAS_ARCTIC_IN = dset_dict['tas'][MODEL]['tas_arc_mean'].sel(member_id='r1i1p1f1')
SIE_ARCTIC_IN = dset_dict['siconc'][MODEL]['sie_tot_arc'].sel(member_id='r1i1p1f1')

In [None]:
# Run the example regression on the inputs prepared in the previous cell and
# unpack the three values returned by au.scatter_linreg.
# NOTE(review): the model name is passed before the boolean here, whereas the
# au.scatter_tas_SIE_linreg call earlier in this notebook passes the boolean
# before the model name -- confirm the order against analysis_utils.
sall, rall, intall = au.scatter_linreg(
    TAS_ARCTIC_IN,
    SIE_ARCTIC_IN,
    MONTHS_IN,
    MODEL,
    True,
)