# Figure 4

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import os
import cmocean
import pandas as pd
import matplotlib.transforms as mtransforms
import scipy.signal
import scipy.fft 
import numba

from icon_util import regrid

# Base path prepended to every input file/directory used below; taken from the
# WORK environment variable (raises KeyError if the variable is not set).
work = os.environ['WORK']

## Hovmoeller composites

In [None]:
def construct_composite(data,dates,central,dt=0.25):
    '''
        Build lag-longitude windows centred on a set of events.

        For every pair (day, lon) taken in lockstep from `dates` and `central`:

        1. the onset is the time at which `data['ta']` is largest at the grid
           point nearest to `lon` within the time window [day, day+0.9];
        2. 40 time steps before the onset and 40 steps from the onset onwards
           are cut out of the (cyclically padded) data;
        3. the longitude axis is shifted so that `lon` becomes 0 and the window
           is interpolated to the integer relative longitudes -180 ... 179.

        Events whose window does not contain the full 80 time steps (too close
        to the start or end of the record) are skipped.

        Parameters
        ----------
        data : xarray.Dataset
            Contains the variable 'ta' with dimensions 'time' and 'longitude'.
            NOTE(review): the cyclic padding assumes longitudes in [-180, 180] - confirm.
        dates : object with `.values`
            Start of the search window of each event, in units of the time coordinate.
        central : object with `.values`
            Central longitude of each event, same units as `data['longitude']`.
        dt : float
            Spacing of the `step` coordinate assigned to each window
            (presumably the time step in days - confirm against the data).

        Returns
        -------
        xarray.Dataset
            All complete windows concatenated along the new dimension 'onset',
            with dimensions 'step' and 'longitude' (relative to the event).
    '''
    # number of time steps taken on either side of the onset
    nlag = 40
    steps = np.arange(-nlag,nlag)*dt
    relative_lon = range(-180,180)

    # cyclic data with -360 < lon < 360
    upper = data.sel(longitude=slice(-179,0))
    upper['longitude'] = upper['longitude'] + 360
    lower = data.sel(longitude=slice(0,179))
    lower['longitude'] = lower['longitude'] - 360

    cyclic = xr.concat([lower,data,upper],dim='longitude')

    # maps a time label to its integer position along the time axis
    rangeIndex = xr.DataArray(np.arange(len(data['time'])),coords=dict(time=data['time']))

    composite = []

    for day, lon in zip(dates.values,central.values):

        # locate the onset: time of maximum 'ta' at the event's grid point
        onset = cyclic.sel(time=slice(day,day+0.9)).sel(longitude=lon,method='nearest')
        onset = onset['time'].isel(time=onset['ta'].argmax('time'))
        i = int(rangeIndex.sel(time=onset).values)

        selection = cyclic.isel(time=slice(i-nlag,i+nlag))

        # incomplete windows are discarded; check before the (costly) interpolation
        if selection.sizes['time'] != 2*nlag:
            continue

        # replace absolute time by the lag relative to the onset
        selection = selection.drop_vars('time').rename(time='step')
        selection = selection.assign_coords(onset=onset.values)

        # subtract lon, interpolate onto the common relative-longitude grid
        selection['longitude'] = selection['longitude'] - lon
        selection = selection.interp(longitude=relative_lon)

        composite.append(selection.assign_coords(step=steps))

    return xr.concat(composite,dim='onset')

In [None]:
# Percentile file and heatwave catalogue of experiment 9 (tropical warming)
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp9_t1000_mean_percentiles.nc')
data = pd.read_json(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp9_t1000_mean_heatwaves.json')

# cells whose centre latitude (stored in radians) lies between 52 and 53 degrees
clat = dist['clat']
in_band = (clat<=np.radians(53)) & (clat>=np.radians(52))
indices = dist['ncells'].where(in_band,drop=True)

# heatwaves that occurred in one of these cells
is_selected = data['ncells'].isin(indices.values)
samples = data[is_selected]

# start date and central longitude (converted to degrees) of every selected heatwave
dates = samples['start'].to_xarray()
central = np.degrees(dist['clon'].isel(ncells=samples['ncells']))

samples

In [None]:
# temperature anomaly with respect to 90th percentile
directory = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp9/3d/ta/100000pa/'

# all NetCDF files of the directory, in lexicographic order
files = sorted(directory + f for f in os.listdir(directory) if f.endswith('.nc'))

ta = xr.open_mfdataset(files,combine='nested',concat_dim='time')['ta']

# restrict the time range and remove length-one dimensions
ta = ta.sel(time=slice(1300,501300)).squeeze()

# anomaly relative to the 90th percentile stored in `dist`
ta = ta - dist.sel(p=0.9)

# `drop_vars` replaces the deprecated `DataArray.drop` for removing a coordinate
ta = regrid(ta.drop_vars('plev'),lim=(52,52.5),res=1)

# average over the latitude band and load the result into memory
ta = ta.mean('latitude').compute()

ta

In [None]:
# Lag-longitude windows around all selected exp9 heatwaves
exp9_input = xr.Dataset({'ta': ta})
exp9 = construct_composite(exp9_input,dates,central)

exp9

In [None]:
# Percentile file and heatwave catalogue of experiment 4 (Arctic warming)
dist = xr.open_dataarray(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp4_t1000_mean_percentiles.nc')
data = pd.read_json(work+'/wolfgang/icon_postproc_t1000_mean/atm_heldsuarez_butler_exp4_t1000_mean_heatwaves.json')

# cells whose centre latitude (stored in radians) lies between 43 and 44 degrees
clat = dist['clat']
in_band = (clat<=np.radians(44)) & (clat>=np.radians(43))
indices = dist['ncells'].where(in_band,drop=True)

# heatwaves that occurred in one of these cells
is_selected = data['ncells'].isin(indices.values)
samples = data[is_selected]

# start date and central longitude (converted to degrees) of every selected heatwave
dates = samples['start'].to_xarray()
central = np.degrees(dist['clon'].isel(ncells=samples['ncells']))

samples

In [None]:
# temperature anomaly with respect to 90th percentile
directory = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp4/3d/ta/100000pa/'

# all NetCDF files of the directory, in lexicographic order
files = sorted(directory + f for f in os.listdir(directory) if f.endswith('.nc'))

ta = xr.open_mfdataset(files,combine='nested',concat_dim='time')['ta']

# restrict the time range and remove length-one dimensions
ta = ta.sel(time=slice(1300,501300)).squeeze()

# anomaly relative to the 90th percentile stored in `dist`
ta = ta - dist.sel(p=0.9)

# `drop_vars` replaces the deprecated `DataArray.drop` for removing a coordinate
ta = regrid(ta.drop_vars('plev'),lim=(43,43.5),res=1)

# average over the latitude band and load the result into memory
ta = ta.mean('latitude').compute()

ta

In [None]:
# Lag-longitude windows around all selected exp4 heatwaves
exp4_input = xr.Dataset({'ta': ta})
exp4 = construct_composite(exp4_input,dates,central)

exp4

## Phase and group velocity

In [None]:
@numba.guvectorize(
    "(float64[:],int16,complex128[:])",
    "(n), () -> (n)",
    forceobj=True
)
def _hilbert(y,N,out):
    '''
        Analytic signal of `y` via the FFT-based Hilbert transform technique (Marple, 1999).

        y   : real input signal
        N   : number of samples of `y`
        out : complex analytic signal, written in place
    '''
    # Index of the highest non-negative frequency bin:
    # N/2 (Nyquist) for even N, (N-1)/2 for odd N, 0 for a single sample
    top = int(N)//2
    has_nyquist = N%2==0

    spectrum = np.fft.fft(y)

    # Discard the negative frequencies ...
    spectrum[top+1:N] = 0
    # ... and double what is left, then undo the doubling for the mean (bin 0)
    spectrum = 2*spectrum
    spectrum[0] = spectrum[0]/2
    if has_nyquist:
        # The Nyquist bin only exists for even lengths. It is shared between positive
        # and negative frequencies, so it must be kept once only (see Marple 1999).
        spectrum[top] = spectrum[top]/2

    # Back to physical space: the result is the analytic signal
    out[:] = np.fft.ifft(spectrum)
    
    
@numba.vectorize([numba.float64(numba.float64, numba.float64)])
def _rad_diff(a,b):
    '''
        Difference a - b of two phase angles in radians.
        A difference above pi is reduced by 2pi, one below -pi is increased by 2pi.
    '''
    delta = a - b
    if delta > np.pi:
        return delta - 2*np.pi
    if delta < -np.pi:
        return delta + 2*np.pi
    return delta
    
    
@numba.guvectorize(
    "(float64[:],float64[:])",
    "(n) -> (n)",
    forceobj=True
)
def _finite_difference(a,out):
    '''
        Phase increment per index step along a 1D array of phase angles.
        One-sided differences at both ends, centered differences in between;
        every difference is wrapped by _rad_diff.
    '''
    # boundaries: difference to the only available neighbour
    out[0] = _rad_diff(a[1],a[0])
    out[-1] = _rad_diff(a[-1],a[-2])
    # interior: wrapped difference across two index steps, halved
    out[1:-1] = _rad_diff(a[2:],a[:-2])/2.

In [None]:
@numba.guvectorize(   
    "(float64[:],float64[:],float64[:])",
    "(n), (n) -> (n)",
    forceobj=True
)
def _fft_filter(y,mask,out):
    '''
        Filter by multiplication in spectral space

        y    : real input signal
        mask : weight applied to each FFT coefficient of `y` (same length as `y`)
        out  : real part of the filtered signal, written in place
    '''
    z = scipy.fft.fft(y)
    z = z * mask
    # The inverse transform is complex whereas `out` is float64. Take the real part
    # explicitly: assigning the complex array directly drops the imaginary part
    # implicitly and emits a ComplexWarning on every call.
    out[:] = scipy.fft.ifft(z).real
    
    


def filtering(da,kmin=2,kmax=10,alpha=0.5):
    '''
        Band-pass filter along longitude by weighting the FFT coefficients.

        The coefficients with index kmin ... kmax (inclusive) are multiplied by a
        Tukey window of length kmax-kmin+1, all other coefficients (including the
        negative frequencies) are set to zero.

        Parameters
        ----------
        da : xarray.DataArray
            Real data with a 'longitude' dimension.
        kmin, kmax : int
            First and last FFT index covered by the window; kmax must be smaller
            than the number of longitudes.
        alpha : float
            Shape parameter handed to the Tukey window.
            NOTE(review): for alpha > 0 the symmetric Tukey window is zero at its end
            points, so the coefficients kmin and kmax themselves get zero weight - confirm
            that this is intended.

        Returns
        -------
        xarray.DataArray
            Filtered data with the same dtype as `da`.
    '''
    # prepare mask
    N = len(da.longitude)
    # The window functions live in scipy.signal.windows; the alias scipy.signal.tukey
    # was deprecated and is no longer available from SciPy 1.13 on.
    taper = scipy.signal.windows.tukey(kmax-kmin+1,alpha=alpha)
    mask = np.zeros(N)
    mask[kmin:kmax+1] = taper
    mask = xr.DataArray(mask,dims=('freq'))
    
    # multiply data with mask in spectral space
    filtered = xr.apply_ufunc(_fft_filter,
                              *(da,mask),
                              input_core_dims=[['longitude'],['freq']],
                              output_core_dims=[['longitude']],
                              dask='parallelized',
                              output_dtypes=[da.dtype]
                             )
    
    # Since the ignored negative frequencies would contribute the same as the positive ones
    filtered *= 2 
    
    return filtered

In [None]:
def envelope_phase(da):
    '''
        Envelope and phase of `da` along longitude.

        The analytic (complex) signal is obtained with _hilbert; its absolute value
        is returned as 'env' and its phase angle as 'phase' in one Dataset.
    '''
    nlon = len(da.longitude)

    analytic = xr.apply_ufunc(_hilbert,
                              da,
                              nlon,
                              input_core_dims=[['longitude'],[]],
                              output_core_dims=[['longitude']],
                              dask='parallelized',
                              output_dtypes=[np.dtype('complex128')]
                             )

    amplitude = np.abs(analytic)
    angle = np.arctan2(np.imag(analytic),np.real(analytic))

    return xr.Dataset({'env': amplitude, 'phase': angle})


def wavenum_speed(phase,lat=51,dt=6*3600):
    '''
        Estimate wavenumber and phase speed from a phase field by finite differences.

        Parameters
        ----------
        phase : xarray.DataArray
            Phase angle in radians with dimensions 'step' and 'longitude'.
        lat : float
            Latitude in degrees; its cosine scales the length of the latitude circle.
        dt : float
            Seconds per time step.

        Returns
        -------
        xarray.Dataset
            'wavenum' in cycles per circumference and 'speed' in meter per second.

        NOTE(review): the speed conversion divides the circumference by (number of
        longitudes - 1) while the wavenumber conversion multiplies by the number of
        longitudes. Confirm which grid spacing is intended.
    '''
    def _phase_increment(dim):
        # wrapped finite difference of the phase along `dim`
        return xr.apply_ufunc(_finite_difference,
                              phase,
                              input_core_dims=[[dim]],
                              output_core_dims=[[dim]],
                              dask='parallelized',
                              output_dtypes=[np.double]
                             )

    # radians per time step
    freq = _phase_increment('step')
    # radians per grid spacing
    wavenum = _phase_increment('longitude')

    # grid spacing per time step -> grid spacing per second (note the sign flip)
    speed = (freq / wavenum) / (-dt)

    # meter per second: scale by the length of the latitude circle per grid interval
    circumference = 2*np.pi*6371000
    intervals = len(speed['longitude'])-1
    speed = speed * circumference * np.cos(np.radians(lat)) / intervals

    # cycles per circumference
    wavenum = wavenum / (2*np.pi) * len(wavenum['longitude'])

    return xr.Dataset({'wavenum': wavenum, 'speed': speed})

In [None]:
# Band-pass filtered mean over all exp9 events
da = filtering(exp9['ta'].mean('onset'),kmin=3,kmax=15,alpha=0.5)

# envelope and phase of the filtered signal, then phase of the envelope itself
ds = envelope_phase(da)
group = envelope_phase(ds['env'])

exp9_wave = xr.Dataset({
    'wave': ds['phase'],
    'envelope': ds['env'],
    'group': group['phase'],
})

In [None]:
# Band-pass filtered mean over all exp4 events
da = filtering(exp4['ta'].mean('onset'),kmin=3,kmax=15,alpha=0.5)

# envelope and phase of the filtered signal, then phase of the envelope itself
ds = envelope_phase(da)
group = envelope_phase(ds['env'])

exp4_wave = xr.Dataset({
    'wave': ds['phase'],
    'envelope': ds['env'],
    'group': group['phase'],
})

## Plotting

In [None]:
fig, axes = plt.subplots(nrows=3,figsize=(8,10))

# (composite, wave diagnostics, line colour used in panel c) per experiment
panels = [(exp9,exp9_wave,'#1f77b4'),(exp4,exp4_wave,'#ff7f0e')]


# Panels a) and b): shaded mean anomaly, zero-phase lines and envelope contours
shading = []
for ax, (comp,wave,_) in zip(axes[:2],panels):

    mesh = comp['ta'].mean('onset').plot(ax=ax,levels=np.linspace(-8,8,17),extend='both',add_colorbar=False)
    shading.append(mesh)

    # phase lines are only drawn where the envelope exceeds 0.5
    strong = wave['envelope']>0.5
    wave['wave'].where(strong).plot.contour(ax=ax,levels=[0,],colors='k',linestyles=':',alpha=0.6)
    wave['group'].where(strong).plot.contour(ax=ax,levels=[0,],colors='k',alpha=0.6)

    C0 = wave['envelope'].plot.contour(ax=ax,levels=np.arange(0.5,5,0.5),colors='w',alpha=0.8)

    plt.clabel(C0,levels=np.arange(1,4.5),fontsize='x-small',colors='k')

# the colorbar below refers to the shading of the first panel
C = shading[0]


# Panel c): zero contours of both experiments on top of each other
for comp, wave, color in panels:

    strong = wave['envelope']>0.5
    wave['wave'].where(strong).plot.contour(ax=axes[2],levels=[0,],colors=color,linestyles=':')
    wave['group'].where(strong).plot.contour(ax=axes[2],levels=[0,],colors=color,linewidths=1)
    comp['ta'].mean('onset').plot.contour(ax=axes[2],levels=[0,],colors=color)


# empty lines serve as legend handles, since contour sets carry no legend entry here
l1 = axes[2].plot([],[],color='#1f77b4')
l2 = axes[2].plot([],[],color='#ff7f0e')

axes[2].legend([*l1,*l2],['Tropical warming','Arctic warming'],loc='upper left',fontsize=10)


# Common formatting: thin dotted lines through the origin, ticks and limits
for ax in axes:

    ax.plot(ax.get_xlim(),[0,0],linestyle=':',linewidth=0.5,color='k')
    ax.plot([0,0],ax.get_ylim(),linestyle=':',linewidth=0.5,color='k')

    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_ylim(-8,8)
    ax.set_xticks([-180,-120,-60,0,60,120,180])
    ax.set_yticks([-8,-4,0,4,8],minor=False)


axes[0].set_title('Tropical warming',weight='bold',fontsize='smaller')
axes[1].set_title('Arctic warming',weight='bold',fontsize='smaller')

axes[2].set_xlabel('Relative longitude [°E] relative to grid point')
axes[1].set_ylabel('Lag [days] relative to heatwave onset',fontsize=16)

# only the lowest panel shows longitude tick labels
for ax in axes[:2]:
    ax.set_xticklabels([])


cbar = plt.colorbar(C,ax=axes[:2],aspect=30)
cbar.set_label('T1000 anomaly [K] from 90th percentile')

# the colorbar shrinks the upper panels; give panel c) the same width as panel b)
box = list(axes[2].get_position().bounds)
box[2] = axes[1].get_position().bounds[2]
axes[2].set_position(box)

# panel labels, shifted by a fixed offset in points relative to the axes corner
trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)

for ax, tag in zip(axes,['a)','b)','c)']):
    ax.text(-0.1,1.06,tag,transform=ax.transAxes+trans,fontsize='large',va='bottom')

