# Figure 3

Estimating the confidence intervals (CI) for the metrics takes 2 hours.

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import os
import cmocean
import pandas as pd
import numba

from dask.distributed import Client

from icon_util import regrid

# root directory for the input data; raises KeyError if the environment variable WORK is not set
work = os.environ['WORK']

In [None]:
# create a dask.distributed client with default arguments
client = Client()

# last expression of the cell, so the notebook displays the client
client

In [None]:
# shut the distributed client down again
client.close()

## Heatwave metrics

In [None]:
def metrics(data,dist,nseason=50):
    '''
        Build a Dataset with heatwave metrics per grid cell

        data    : DataFrame with the columns 'ncells' and 'length'
        dist    : object providing the cell coordinates clon and clat
        nseason : divisor applied to the summed lengths

        frequency is the sum of 'length' per cell divided by nseason,
        length is the mean of 'length' per cell.
    '''
    per_cell = data.groupby('ncells')['length']

    total_days = per_cell.sum().to_xarray()
    mean_length = per_cell.mean().to_xarray()

    out = xr.Dataset({'frequency': total_days / nseason, 'length': mean_length})

    # attach the cell coordinates taken from the temperature distribution
    return out.assign_coords({'clon': dist.clon, 'clat': dist.clat})

@numba.guvectorize(
    "(float64[:],float64[:],float64[:])",
    "(n), (m) -> (m)",
    forceobj=True
)
def icdf(a,p,out):
    '''
        Inverse empirical cumulative distribution function of array at percentiles p

        Generalized ufunc with layout "(n), (m) -> (m)": for every entry of p the
        element of the sorted sample a at position int(p*n) is written to out.
        The position is capped at n-1, so that p = 1 yields the largest sample
        instead of indexing one element past the end of the array.
    '''
    n = len(a)
    ordered = np.sort(a)
    # truncation towards zero as before; the cap only takes effect for p*n >= n
    position = np.minimum(np.int64(p*n),n-1)
    out[:] = ordered[position]


def confid(dist,alpha):
    '''
        Estimate confidence intervals using the inverse empirical cumulative distribution function
        for the distribution of bootstrap samples.

        dist  : DataArray of bootstrap samples with a dimension 'random'
        alpha : the interval bounds are evaluated at the probabilities alpha/2 and 1-alpha/2

        Returns a DataArray in which the dimension 'random' is replaced by a dimension
        'percentile' whose coordinate holds the two probabilities.
    '''
    p = xr.DataArray([alpha/2, 1-alpha/2],dims='percentile')
    values = xr.apply_ufunc(icdf,
                          dist,
                          p,
                          input_core_dims=[['random'],['percentile']],
                          output_core_dims=[['percentile']],
                          dask='parallelized',
                          # flat list with one dtype per output (a nested list is not a valid dtype spec)
                          output_dtypes=[dist.dtype])
    # label the new dimension with the probabilities so that bounds can be selected by value
    values['percentile'] = p

    return values

In [None]:
# --- reference simulation ---
# bootstrap samples -> bounds at the 2.5% and 97.5% level of the metrics averaged over longitude
directory = work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_default_bootstrap_p0.1/'
files = sorted(directory+f for f in os.listdir(directory) if f.startswith('sample'))

bootstrap = xr.open_mfdataset(files,combine='nested',concat_dim='random').chunk(dict(random=100))
bootstrap = regrid(bootstrap,lim=(0,90)).mean('longitude').compute()

# NOTE(review): the factor 4 is applied to the frequency bounds only; its origin is not visible here - confirm
frequency = confid(bootstrap['frequency'],0.05) * 4
length = confid(bootstrap['length'],0.05)

exp_CI = xr.Dataset({'frequency': frequency, 'length': length})

# metrics of the simulation itself, reduced over longitude in the same way
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_default_t1000_mean_percentiles.nc')
data = pd.read_json(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_default_t1000_mean_heatwaves.json')
exp = regrid(metrics(data,dist),lim=(0,90)).mean('longitude').compute()

# store metrics and confidence bounds in one file for the plotting section
xr.Dataset({
    'frequency': exp['frequency'],
    'length': exp['length'],
    'frequency_CI': exp_CI['frequency'],
    'length_CI': exp_CI['length'],
}).to_netcdf('./heatwaves_ref_t1000_mean.nc')

In [None]:
# --- experiment exp9 ---
# bootstrap samples -> bounds at the 2.5% and 97.5% level of the metrics averaged over longitude
directory = work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp9_bootstrap_p0.1/'
files = sorted(directory+f for f in os.listdir(directory) if f.startswith('sample'))

bootstrap = xr.open_mfdataset(files,combine='nested',concat_dim='random').chunk(dict(random=100))
bootstrap = regrid(bootstrap,lim=(0,90)).mean('longitude').compute()

# NOTE(review): the factor 4 is applied to the frequency bounds only; its origin is not visible here - confirm
frequency = confid(bootstrap['frequency'],0.05) * 4
length = confid(bootstrap['length'],0.05)

exp_CI = xr.Dataset({'frequency': frequency, 'length': length})

# metrics of the simulation itself, reduced over longitude in the same way
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp9_t1000_mean_percentiles.nc')
data = pd.read_json(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp9_t1000_mean_heatwaves.json')
exp = regrid(metrics(data,dist),lim=(0,90)).mean('longitude').compute()

# store metrics and confidence bounds in one file for the plotting section
xr.Dataset({
    'frequency': exp['frequency'],
    'length': exp['length'],
    'frequency_CI': exp_CI['frequency'],
    'length_CI': exp_CI['length'],
}).to_netcdf('./heatwaves_exp9_t1000_mean.nc')

In [None]:
# --- experiment exp4 ---
# bootstrap samples -> bounds at the 2.5% and 97.5% level of the metrics averaged over longitude
directory = work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp4_bootstrap_p0.1/'
files = sorted(directory+f for f in os.listdir(directory) if f.startswith('sample'))

bootstrap = xr.open_mfdataset(files,combine='nested',concat_dim='random').chunk(dict(random=100))
bootstrap = regrid(bootstrap,lim=(0,90)).mean('longitude').compute()

# NOTE(review): the factor 4 is applied to the frequency bounds only; its origin is not visible here - confirm
frequency = confid(bootstrap['frequency'],0.05) * 4
length = confid(bootstrap['length'],0.05)

exp_CI = xr.Dataset({'frequency': frequency, 'length': length})

# metrics of the simulation itself, reduced over longitude in the same way
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp4_t1000_mean_percentiles.nc')
data = pd.read_json(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp4_t1000_mean_heatwaves.json')
exp = regrid(metrics(data,dist),lim=(0,90)).mean('longitude').compute()

# store metrics and confidence bounds in one file for the plotting section
xr.Dataset({
    'frequency': exp['frequency'],
    'length': exp['length'],
    'frequency_CI': exp_CI['frequency'],
    'length_CI': exp_CI['length'],
}).to_netcdf('./heatwaves_exp4_t1000_mean.nc')

In [None]:
# --- experiment exp8 ---
# bootstrap samples -> bounds at the 2.5% and 97.5% level of the metrics averaged over longitude
# NOTE(review): unlike every other path in this notebook section, this directory has no
# '/wolfgang' component after `work` - confirm that the exp8 samples really live here
directory = work+'/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp8_bootstrap_p0.1/'
files = sorted(directory+f for f in os.listdir(directory) if f.startswith('sample'))

bootstrap = xr.open_mfdataset(files,combine='nested',concat_dim='random').chunk(dict(random=100))
bootstrap = regrid(bootstrap,lim=(0,90)).mean('longitude').compute()

# NOTE(review): the factor 4 is applied to the frequency bounds only; its origin is not visible here - confirm
frequency = confid(bootstrap['frequency'],0.05) * 4
length = confid(bootstrap['length'],0.05)

exp_CI = xr.Dataset({'frequency': frequency, 'length': length})

# metrics of the simulation itself, reduced over longitude in the same way
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp8_t1000_mean_percentiles.nc')
data = pd.read_json(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp8_t1000_mean_heatwaves.json')
exp = regrid(metrics(data,dist),lim=(0,90)).mean('longitude').compute()

# store metrics and confidence bounds in one file for the plotting section
xr.Dataset({
    'frequency': exp['frequency'],
    'length': exp['length'],
    'frequency_CI': exp_CI['frequency'],
    'length_CI': exp_CI['length'],
}).to_netcdf('./heatwaves_exp8_t1000_mean.nc')

## Hot day persistence

In [None]:
@numba.jit(nopython=True)
def count_duration(array,index):
    '''
        Histogram of the lengths of runs of consecutive falsy entries of array

        array is visited in the order given by index. A truthy entry ends the
        current run; the run length is then counted in the returned array:

        - element k (1 <= k <= max_duration) counts runs of exactly k entries
        - element 0 counts runs that are longer than max_duration

        A run that is still open when index is exhausted is not counted.
    '''
    max_duration = 14
    histogram = np.zeros(max_duration+1,np.int_)
    run = 0

    for i in index:
        if not array[i]:
            # still inside a run
            run += 1
            continue

        # a truthy entry closes the run (if there was one)
        if run > max_duration:
            histogram[0] += 1
        elif run > 0:
            histogram[run] += 1

        run = 0

    return histogram



def loop_cells(array):
    '''
        Apply count_duration to every cell and collect the results in a Dataset

        array : DataArray with the dimensions time and ncells

        Returned variables:
        - days           : run length times number of runs, for each length and cell
        - missing        : 0.1 * number of time steps minus the sum of days over all lengths
        - missing_events : first element of the count_duration histogram for each cell
    '''
    ntime = len(array.time)
    ncell = len(array.ncells)
    index = np.arange(ntime)

    # the first cell is evaluated once up front only to learn the histogram size
    nbins = len(count_duration(array.isel(ncells=0).values,index))
    hist = np.zeros((nbins,ncell),np.int_)

    for cell in range(ncell):
        hist[:,cell] = count_duration(array.isel(ncells=cell).values,index)

    # row 0 holds the overflow counts, rows 1.. belong to the lengths 1..nbins-1
    length = xr.DataArray(range(1,nbins),dims=('length'))
    events = xr.DataArray(hist[1:,:],dims=('length','ncells'),coords=dict(length=length))
    days = length * events

    missing_events = xr.DataArray(hist[0,:],dims='ncells')
    # NOTE(review): 0.1 presumably is the expected fraction of hot days (90th percentile threshold) - confirm
    missing_days = 0.1*ntime - days.sum('length')

    return xr.Dataset(dict(days=days,missing=missing_days,missing_events=missing_events))



def non_hot_days(paths,dist):
    '''
        Flag the days on which the daily-mean temperature lies below the threshold

        paths : passed to xr.open_mfdataset; the variable 'ta' is read
        dist  : percentiles with a coordinate p; the entry p=0.90 is the threshold

        Returns a boolean DataArray that is True where the daily mean is below the threshold.
    '''
    # read the temperature field and keep the time range 1300 to 501300
    temperature = xr.open_mfdataset(paths,combine='nested',concat_dim='time')['ta']
    temperature = temperature.squeeze().sel(time=slice(1300,501300))

    # casting the time stamps to integers gives one label per day for the daily means
    temperature['time'] = temperature['time'].astype(np.int64)
    daily = temperature.groupby('time').reduce(np.mean).compute()

    threshold = dist.sel(p=0.90).squeeze()

    return daily < threshold

In [None]:
# hot-day persistence of the reference simulation
paths = work+'/DATA/icon_simulations/atm_heldsuarez_default/3d/ta/100000pa/atm_heldsuarez_default_pl_*.nc'
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_default_t1000_mean_percentiles.nc')

failing = non_hot_days(paths,dist)

# count the runs per cell, attach the cell coordinates, then regrid and average over longitude
counts_ref = loop_cells(failing).assign_coords(dict(clon=dist.clon,clat=dist.clat))
counts_ref = regrid(counts_ref,lim=(0,90)).mean('longitude')

# NOTE(review): 50 equals the default nseason of metrics(); presumably the number of seasons - confirm
counts_ref['days'] = counts_ref['days'] / 50

counts_ref

In [None]:
# hot-day persistence of experiment exp9
paths = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp9/3d/ta/100000pa/atm_heldsuarez_butler_exp9_pl_*.nc'
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp9_t1000_mean_percentiles.nc')

failing = non_hot_days(paths,dist)

# count the runs per cell, attach the cell coordinates, then regrid and average over longitude
counts_exp9 = loop_cells(failing).assign_coords(dict(clon=dist.clon,clat=dist.clat))
counts_exp9 = regrid(counts_exp9,lim=(0,90)).mean('longitude')

# NOTE(review): 50 equals the default nseason of metrics(); presumably the number of seasons - confirm
counts_exp9['days'] = counts_exp9['days'] / 50

counts_exp9

In [None]:
# hot-day persistence of experiment exp4
paths = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp4/3d/ta/100000pa/atm_heldsuarez_butler_exp4_pl_*.nc'
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp4_t1000_mean_percentiles.nc')

failing = non_hot_days(paths,dist)

# count the runs per cell, attach the cell coordinates, then regrid and average over longitude
counts_exp4 = loop_cells(failing).assign_coords(dict(clon=dist.clon,clat=dist.clat))
counts_exp4 = regrid(counts_exp4,lim=(0,90)).mean('longitude')

# NOTE(review): 50 equals the default nseason of metrics(); presumably the number of seasons - confirm
counts_exp4['days'] = counts_exp4['days'] / 50

counts_exp4

In [None]:
# confidence bounds of the reference persistence from the stored bootstrap samples
directory = work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_default_hot_persistence_bootstrap_p0.03_regrid/'
files = sorted(directory+f for f in os.listdir(directory) if f.startswith('regrid'))

bootstrap = xr.open_mfdataset(files,combine='nested',concat_dim='random')['days'].compute()

# same division by 50 as applied to the 'days' counts of the simulations
CI_ref = confid(bootstrap,0.05) / 50

CI_ref

## Plotting

In [None]:
# read back the files written in the heatwave metrics section
ref, exp9, exp4, exp8 = map(xr.open_dataset, (
    './heatwaves_ref_t1000_mean.nc',
    './heatwaves_exp9_t1000_mean.nc',
    './heatwaves_exp4_t1000_mean.nc',
    './heatwaves_exp8_t1000_mean.nc',
))

In [None]:
def stormtrack_latitude(da):
    '''
        Latitude at which da attains its maximum along the dimension latitude
    '''
    return da['latitude'].isel(latitude=da.argmax('latitude'))


# all EKE files are located in the same directory
directory = work+'/wolfgang/icon_storm_track/'

# for every simulation: EKE between 0 and 90 degrees latitude, averaged over time, then the latitude of its maximum
eke = xr.open_dataset(directory+'atm_heldsuarez_default_EKE_zonal_mean_10day_highpass.nc')['EKE'].sel(latitude=slice(0,90))
ref_latitude = stormtrack_latitude(eke.mean('time'))

eke = xr.open_dataset(directory+'atm_heldsuarez_butler_exp9_EKE_zonal_mean_10day_highpass.nc')['EKE'].sel(latitude=slice(0,90))
exp9_latitude = stormtrack_latitude(eke.mean('time'))

eke = xr.open_dataset(directory+'atm_heldsuarez_butler_exp4_EKE_zonal_mean_10day_highpass.nc')['EKE'].sel(latitude=slice(0,90))
exp4_latitude = stormtrack_latitude(eke.mean('time'))

eke = xr.open_dataset(directory+'atm_heldsuarez_butler_exp8_EKE_zonal_mean_10day_highpass.nc')['EKE'].sel(latitude=slice(0,90))
exp8_latitude = stormtrack_latitude(eke.mean('time'))

In [None]:
# Figure with 2x2 panels. After flatten(): axes[0] = top left, axes[1] = top right,
# axes[2] = bottom left, axes[3] = bottom right.
#   axes[0] / axes[2] : heatwave frequency / length against latitude for all four simulations
#   axes[1] / axes[3] : difference of the persistence 'days' of exp9 / exp4 to the reference
fig, axes = plt.subplots(nrows=2,ncols=2,figsize=(8,6))
axes = axes.flatten()

# Frequency
# the line handles are kept so that the vertical lines below can reuse their colours
l1 = ref['frequency'].plot(ax=axes[0],linestyle='-')
l2 = exp9['frequency'].plot(ax=axes[0],linestyle='--')
l3 = exp4['frequency'].plot(ax=axes[0],linestyle=':')
l4 = exp8['frequency'].plot(ax=axes[0],linestyle='-.')

# shaded band between the stored 0.025 and 0.975 bounds
axes[0].fill_between(ref['latitude'].values,ref['frequency_CI'].sel(percentile=0.025).values,ref['frequency_CI'].sel(percentile=0.975).values,alpha=0.3)
axes[0].fill_between(exp9['latitude'].values,exp9['frequency_CI'].sel(percentile=0.025).values,exp9['frequency_CI'].sel(percentile=0.975).values,alpha=0.3)
axes[0].fill_between(exp4['latitude'].values,exp4['frequency_CI'].sel(percentile=0.025).values,exp4['frequency_CI'].sel(percentile=0.975).values,alpha=0.3)
axes[0].fill_between(exp8['latitude'].values,exp8['frequency_CI'].sel(percentile=0.025).values,exp8['frequency_CI'].sel(percentile=0.975).values,alpha=0.3)

# remember the limits before drawing the vertical lines and restore them afterwards,
# so that the lines span the panel without changing its y-range
ylim = axes[0].get_ylim()

# vertical lines at the latitude of maximum time-mean EKE of each simulation
axes[0].plot([ref_latitude,ref_latitude],ylim,color=l1[0].get_color(),linestyle=':',linewidth=1)
axes[0].plot([exp9_latitude,exp9_latitude],ylim,color=l2[0].get_color(),linestyle=':',linewidth=1)
axes[0].plot([exp4_latitude,exp4_latitude],ylim,color=l3[0].get_color(),linestyle=':',linewidth=1)
axes[0].plot([exp8_latitude,exp8_latitude],ylim,color=l4[0].get_color(),linestyle=':',linewidth=1)

axes[0].set_ylim(ylim)

# length
# same construction as for the frequency panel, drawn into axes[2]
ref['length'].plot(ax=axes[2],linestyle='-')
exp9['length'].plot(ax=axes[2],linestyle='--')
exp4['length'].plot(ax=axes[2],linestyle=':')
exp8['length'].plot(ax=axes[2],linestyle='-.')

axes[2].fill_between(ref['latitude'].values,ref['length_CI'].sel(percentile=0.025).values,ref['length_CI'].sel(percentile=0.975).values,alpha=0.3)
axes[2].fill_between(exp9['latitude'].values,exp9['length_CI'].sel(percentile=0.025).values,exp9['length_CI'].sel(percentile=0.975).values,alpha=0.3)
axes[2].fill_between(exp4['latitude'].values,exp4['length_CI'].sel(percentile=0.025).values,exp4['length_CI'].sel(percentile=0.975).values,alpha=0.3)
axes[2].fill_between(exp8['latitude'].values,exp8['length_CI'].sel(percentile=0.025).values,exp8['length_CI'].sel(percentile=0.975).values,alpha=0.3)

ylim = axes[2].get_ylim()

# the colours are taken from the frequency lines l1..l4
axes[2].plot([ref_latitude,ref_latitude],ylim,color=l1[0].get_color(),linestyle=':',linewidth=1)
axes[2].plot([exp9_latitude,exp9_latitude],ylim,color=l2[0].get_color(),linestyle=':',linewidth=1)
axes[2].plot([exp4_latitude,exp4_latitude],ylim,color=l3[0].get_color(),linestyle=':',linewidth=1)
axes[2].plot([exp8_latitude,exp8_latitude],ylim,color=l4[0].get_color(),linestyle=':',linewidth=1)

axes[2].set_ylim(ylim)


# persistence tropical heating
# sig is True where exp9 lies outside the bounds of the reference; '+' on booleans acts as a logical or
sig = (counts_exp9['days'] < CI_ref.sel(percentile=0.025)) + (counts_exp9['days'] > CI_ref.sel(percentile=0.975))

C = (counts_exp9['days']-counts_ref['days']).plot(ax=axes[1],levels=np.arange(-3,3.25,0.25),add_colorbar=False,extend='both',cmap=cmocean.cm.delta)
                                                
# the hatch '//' belongs to the level interval 0 to 0.5, i.e. to the region where sig is False;
# alpha=0 makes the fill itself transparent
sig.astype('double').plot.contourf(ax=axes[1],levels=[0,0.5,1],hatches=['//',''],alpha=0,add_colorbar=False)

# contours of the reference 'days' with an extra row of zeros at length=-1000 prepended
# NOTE(review): the zero row presumably closes the contours at the lower edge - confirm
zeros = xr.zeros_like(counts_ref['latitude']).assign_coords(length=-1000)
C0 = xr.concat([zeros,counts_ref['days']],dim='length').plot.contour(ax=axes[1],x='latitude',levels=[1,3,5,7,9,11],colors='k')

plt.clabel(C0,fontsize='x-small')


# persistence arctic heating
# same construction as above for exp4, drawn into axes[3]
sig = (counts_exp4['days'] < CI_ref.sel(percentile=0.025)) + (counts_exp4['days'] > CI_ref.sel(percentile=0.975))

C = (counts_exp4['days']-counts_ref['days']).plot(ax=axes[3],levels=np.arange(-3,3.25,0.25),add_colorbar=False,extend='both',cmap=cmocean.cm.delta)
                                                
sig.astype('double').plot.contourf(ax=axes[3],levels=[0,0.5,1],hatches=['//',''],alpha=0,add_colorbar=False)

zeros = xr.zeros_like(counts_ref['latitude']).assign_coords(length=-1000)
C0 = xr.concat([zeros,counts_ref['days']],dim='length').plot.contour(ax=axes[3],x='latitude',levels=[1,3,5,7,9,11],colors='k')

plt.clabel(C0,fontsize='x-small')


# settings shared by all panels; labels and titles set by the plot calls are cleared here
for ax in axes:
    
    ax.set_xlabel('')
    ax.set_xlim(0,90)
    ax.set_xticks([0,30,60,90])
    ax.set_xticks([15,45,75],minor=True)
    ax.set_title('')


# right column (persistence panels)
for ax in axes[[1,3]]:
    
    ax.set_ylabel('Length [days]',fontsize=13)
    ax.set_ylim(0.5,14)
    ax.set_yticks([3,6,9,12])
    ax.set_yticks([1,2,4,5,7,8,10,11,13,14],minor=True)
    
    # horizontal magenta line at a length of 2.5 across the full x-range; the x-limits are restored afterwards
    xlim = ax.get_xlim()
    ax.plot(xlim,[2.5,2.5],linestyle=':',linewidth=1.5,color='m')
    ax.set_xlim(xlim)

    
# left column (line panels): vertical grid lines at major and minor ticks
for ax in axes[[0,2]]:
    
    ax.xaxis.grid(which='both')
        
axes[2].set_xlabel('Latitude [°N]',fontsize=13)
axes[3].set_xlabel('Latitude [°N]',fontsize=13)
axes[0].set_ylabel('Heatwave frequency [days/year]',fontsize=13,labelpad=10)
axes[2].set_ylabel('Mean duration [days]',fontsize=13,labelpad=15)



axes[1].set_title('Tropical warming',weight='bold',fontsize='smaller')



axes[3].set_title('Arctic warming',weight='bold',fontsize='smaller')

# the labels are listed in the order in which the four curves were plotted: ref, exp9, exp4, exp8
axes[2].legend(['Reference','Tropical warming','Arctic warming','Combined forcing'],fontsize=11,loc='upper center',ncols=2,columnspacing=0.8,handlelength=1.)
    

# positional arguments: left, bottom, right, top, wspace, hspace
fig.subplots_adjust(0,0,1,1,0.3,0.3)

# one horizontal colorbar for both persistence panels; C is the plot of the last panel,
# both panels were drawn with the same levels and colormap
cbar = plt.colorbar(C,ax=axes[[1,3]],orientation='horizontal',shrink=0.9,pad=0.13)
cbar.set_label(r'Density difference [year$^{-1}$]',fontsize=13)
cbar.set_ticks([-3,-2,-1,0,1,2,3])

# manual adjustment of the left panels: bounds are (x0, y0, width, height), so the panels are
# moved up and reduced in height
box = list(axes[2].get_position().bounds)
box[1] += 0.11
box[3] -= 0.03
axes[2].set_position(box)

box = list(axes[0].get_position().bounds)
box[1] += 0.03
box[3] -= 0.03
axes[0].set_position(box)

# panel labels: positioned in axes coordinates and shifted by a fixed offset given in inches (-45/72, -20/72)
trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)

axes[0].text(-0.1,1.1,'a)',transform=axes[0].transAxes+trans,fontsize='large',va='bottom')
axes[2].text(-0.1,1.1,'b)',transform=axes[2].transAxes+trans,fontsize='large',va='bottom')
axes[1].text(0.06,1.1,'c)',transform=axes[1].transAxes+trans,fontsize='large',va='bottom')
axes[3].text(0.06,1.1,'d)',transform=axes[3].transAxes+trans,fontsize='large',va='bottom')