# Figure 5

In [None]:
import os

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import matplotlib.transforms as mtransforms
import matplotlib as mpl
import matplotlib.colors
import pandas as pd
import cmocean
import numba
import scipy.stats

from icon_util import regrid

# global figure style: 14 pt default font size
plt.rcParams['font.size'] = 14
# root of the WORK file system; the input paths below are all built from it
work = os.environ['WORK']

## Heatwave metrics and EKE

In [None]:
def metrics(data,dist,nseason=50):
    '''
        Produce Dataset with per-cell heatwave metrics from a heatwave DataFrame

        data    : DataFrame of heatwave events with (at least) the columns
                  'ncells' (grid cell) and 'length' (event length)
        dist    : temperature distribution; only its clon/clat coordinates are used
        nseason : number of seasons the summed event length is divided by

        Returns a Dataset with
            frequency : summed event length per cell divided by nseason
            length    : mean event length per cell
        and clon/clat attached as coordinates.
    '''
    per_cell = data.groupby('ncells')['length']

    ds = xr.Dataset(dict(
        frequency=per_cell.sum().to_xarray() / nseason,
        length=per_cell.mean().to_xarray(),
    ))

    return ds.assign_coords(dict(clon=dist.clon,clat=dist.clat))


In [None]:
directory = f'{work}/wolfgang/icon_storm_track/'
postproc_directory = f'{work}/wolfgang/icon_postproc_t1000_mean/'

# long names appear in the file names, short names label the 'exp' coordinate
experiments_long = ['default','butler_exp1','butler_exp2','butler_exp3','butler_exp4',
                    'butler_exp8','butler_exp9','butler_exp10','butler_exp11',
                    'butler_exp12','butler_exp13','butler_exp14','butler_exp15']
experiments_short = ['ref','exp1','exp2','exp3','exp4',
                     'exp8','exp9','exp10','exp11',
                     'exp12','exp13','exp14','exp15',]

eke = []
freq = []

for long_name, short_name in zip(experiments_long,experiments_short):

    # zonal-mean, 10-day high-pass EKE for latitudes 0..90, averaged in time and
    # divided by 1000 (the figure labels EKE in kJ m^-2)
    eke_file = f'{directory}atm_heldsuarez_{long_name}_EKE_zonal_mean_10day_highpass.nc'
    eke_profile = xr.open_dataset(eke_file)['EKE'].sel(latitude=slice(0,90))
    eke.append((eke_profile.mean('time') / 1000).assign_coords(exp=short_name))

    # heatwave frequency per cell, regridded with lim=(0,90) and zonally averaged
    prefix = f'{postproc_directory}atm_heldsuarez_{long_name}_t1000_mean'
    percentiles = xr.open_dataarray(f'{prefix}_percentiles.nc')
    heatwaves = pd.read_json(f'{prefix}_heatwaves.json')

    frequency = regrid(metrics(heatwaves,percentiles)['frequency'],lim=(0,90))
    freq.append(frequency.mean('longitude').compute().assign_coords(exp=short_name))

# stack all experiments along a new 'exp' dimension
eke = xr.concat(eke,dim='exp')
freq = xr.concat(freq,dim='exp')

## Plotting

In [None]:
class OLS:
    '''
        Ordinary least squares fit of the simple linear model y = a0 + a1*x

        based on vonStorch & Zwiers (2010), Ch. 8.3.1-3

        x, y : 1-D numpy arrays or xarray DataArrays of equal length
               (DataArray arithmetic aligns x and y by coordinate label)

        Attributes
            a1    : slope estimate
            a0    : intercept estimate
            R2    : coefficient of determination
            sig2E : unbiased estimate of the error variance, SSE / (n-2)
    '''

    def __init__(self,x,y):
        n = len(x)
        # float() turns 0-d numpy/xarray results into plain scalars, so plain
        # numpy arrays work as well as DataArrays (no .values needed)
        Mx = float(np.mean(x))
        My = float(np.mean(y))

        # centred sums of squares and cross products; building them from the
        # anomalies avoids the catastrophic cancellation of sum(x**2) - n*Mx**2
        # when the spread is small compared with the mean (e.g. latitudes)
        dx = x - Mx
        dy = y - My
        Sxx = float(np.sum(dx**2))
        Syy = float(np.sum(dy**2))
        Sxy = float(np.sum(dx * dy))

        # least-squares estimates of slope and intercept
        self.a1 = Sxy / Sxx
        self.a0 = My - self.a1 * Mx

        # sum of squared errors (Syy is the total sum of squares)
        SSE = Syy - self.a1 * Sxy

        # coefficient of determination
        self.R2 = 1 - SSE / Syy

        # unbiased estimate of the error variance (two fitted parameters)
        self.sig2E = SSE / (n-2)

    def fit(self,xi):
        '''Fitted response a0 + a1*xi; xi may be a scalar or an array.'''
        return self.a0 + self.a1*xi
    
    
class Confidence:
    '''
        Estimate CI half-widths for the OLS slope and for the response

        model : fitted OLS instance; only its error variance `sig2E` is used
        x     : predictor values the model was fitted with
                (1-D numpy array or xarray DataArray)
        alpha : significance level; intervals cover 1-alpha (default 95 %)
    '''

    def __init__(self,model,x,alpha=0.05):

        self.n = len(x)
        self.Mx = float(np.mean(x))
        # centred sum of squares of x; avoids the cancellation of
        # sum(x**2) - n*Mx**2 when the spread is small compared with the mean
        self.Sxx = float(np.sum((x - self.Mx)**2))

        # two-sided critical value of Student's t with n-2 degrees of freedom
        self.t = scipy.stats.t.isf(alpha/2,self.n-2)

        # standard deviation of the regression errors
        self.sigE = np.sqrt(model.sig2E)

    def slope(self):
        '''
            Half-width of the (1-alpha) confidence interval of the slope

            based on Ch. 8.3.7
        '''
        return self.t * self.sigE / np.sqrt(self.Sxx)

    def response(self,xi):
        '''
            Half-width of the (1-alpha) interval for the response at xi

            based on Ch. 8.3.11. The leading 1 under the square root adds the
            error variance of a single new observation, so this is the interval
            for an individual response, wider than the one for the mean response.
            xi may be a scalar or an array.
        '''
        return self.t * self.sigE * np.sqrt(1 + 1/self.n + (xi - self.Mx)**2/self.Sxx)
    

In [None]:
def get_trop_color(exp):
    '''
        Marker edge colour for experiment `exp`

        'ref' is black. Other experiments look up a value in the table below
        (presumably the tropical forcing amplitude - confirm): non-negative values
        map onto 'Oranges', negative ones (by magnitude) onto 'Blues', with the
        colour scale normalised to [0, 0.7] and clipped.
    '''
    if exp == 'ref':
        return 'k'

    amplitude = dict(exp1=0.5,exp2=0,exp3=0.25,exp4=0,exp8=0.2,exp9=0.2,
                     exp10=0,exp11=0,exp12=-0.2,exp13=0.25,exp14=0.25,exp15=0.5)[exp]
    scale = matplotlib.colors.Normalize(vmin=0,vmax=0.7,clip=True)

    if amplitude < 0:
        return mpl.colormaps['Blues'](scale(np.abs(amplitude)))
    return mpl.colormaps['Oranges'](scale(amplitude))
        
def get_aa_color(exp):
    '''
        Marker face colour for experiment `exp`

        'ref' is white. Other experiments look up a value in the table below
        (presumably the polar / Arctic-amplification forcing amplitude - confirm):
        non-negative values map onto 'Purples', negative ones (by magnitude) onto
        'Blues', with the colour scale normalised to [0, 1.7] and clipped.
    '''
    if exp == 'ref':
        return 'w'

    amplitude = dict(exp1=0,exp2=0.5,exp3=0,exp4=1.0,exp8=1.,exp9=0,
                     exp10=1.5,exp11=-1,exp12=0,exp13=0.5,exp14=1,exp15=1)[exp]
    scale = matplotlib.colors.Normalize(vmin=0,vmax=1.7,clip=True)

    if amplitude < 0:
        return mpl.colormaps['Blues'](scale(np.abs(amplitude)))
    return mpl.colormaps['Purples'](scale(amplitude))


def subplot(x,y,ax,add_legend=False):
    '''
        Scatter y against x (one marker per experiment), overlay the OLS fit
        with its response interval, and put R^2 and the slope +- its CI
        half-width in the title. Intercept, slope and sqrt(R^2) are printed.

        x, y       : DataArrays with one value per experiment along 'exp'
        ax         : matplotlib Axes to draw into
        add_legend : draw a legend of the experiment markers
    '''

    model = OLS(x,y)
    ci = Confidence(model,x)

    print('Intercept %.2f'%(model.a0))
    print('Slope %.4f'%(model.a1))
    print('R %.4f'%(np.sqrt(model.R2)))

    # scatter: face colour from get_aa_color, edge colour from get_trop_color
    for exp in x['exp']:
        name = exp.item()
        ax.plot(x.sel(exp=exp),y.sel(exp=exp),linestyle='',marker='o',label=name,
                markersize=9,markeredgewidth=2,
                markerfacecolor=get_aa_color(name),markeredgecolor=get_trop_color(name))

    if add_legend:
        ax.legend(fontsize=10,ncols=4,handlelength=1,loc='lower center',frameon=False)

    # fit and response interval over the x-range chosen by the scatter. The
    # interval half-width is curved in xi, so evaluate it on a dense grid;
    # sampling only at the data points drew it as straight segments,
    # most visibly in the extrapolated parts near the axis limits.
    xlim = ax.get_xlim()
    values = np.linspace(xlim[0],xlim[1],200)

    fitted = model.fit(values)
    halfwidth = ci.response(values)

    ax.plot(values,fitted,color='grey')
    ax.fill_between(values,fitted-halfwidth,fitted+halfwidth,color='grey',alpha=0.1)

    # keep the x-range of the scatter; drawing the fit must not widen it
    ax.set_xlim(xlim)

    # coefficient of determination and slope with CI half-width in the title
    ax.set_title(r'R$^2$=%.2f, a$_1$=%.2f$\pm$%.2f'%(model.R2,model.a1,ci.slope()))
    

In [None]:
fig, axes = plt.subplots(nrows=2,ncols=2,figsize=(8,10))
axes = axes.flatten()

# storm-track predictors and heatwave predictands, one value per experiment
eke_position = eke['latitude'].isel(latitude=eke.argmax('latitude'))          # latitude of EKE maximum
eke_maximum = eke.max('latitude')                                             # EKE maximum
freq_min_position = freq['latitude'].isel(latitude=freq.argmin('latitude'))   # latitude of frequency minimum
freq_minimum = freq.min('latitude')                                           # frequency minimum

position_label = 'Position of EKE maximum [°N]'
magnitude_label = r'Maximum EKE [kJ m$^{-2}$]'
min_position_label = 'Position of heatwave frequency minimum [°N]'
minimum_label = 'Minimum heatwave frequency [days/year]'

# one entry per panel a)-d): (x, y, x-label, y-label, draw legend?)
panels = [
    # storm track position vs position of minimum
    (eke_position, freq_min_position, position_label, min_position_label, False),
    # storm track magnitude vs position of minimum
    (eke_maximum, freq_min_position, magnitude_label, min_position_label, False),
    # storm track position vs minimal frequency
    (eke_position, freq_minimum, position_label, minimum_label, False),
    # storm track magnitude vs minimal frequency
    (eke_maximum, freq_minimum, magnitude_label, minimum_label, True),
]

for ax, (x, y, xlabel, ylabel, legend) in zip(axes, panels):
    subplot(x,y,ax,add_legend=legend)
    ax.grid()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

# both panels of a row show the same y quantity: give them the same y-limits
axes[0].set_ylim(axes[1].get_ylim())
axes[3].set_ylim(axes[2].get_ylim())

fig.subplots_adjust(0,0,1,1,0.25,0.25)

# panel letters, offset 45 pt left and 20 pt down from their axes-fraction anchor
trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)

for ax, letter, xpos in zip(axes, ['a)','b)','c)','d)'], [-0.0, 0.1, 0.0, 0.1]):
    ax.text(xpos,1.1,letter,transform=ax.transAxes+trans,fontsize='large',va='bottom')
