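# Figure 2

Panels a–c show the wavenumber–phase speed spectrum of `va` at 250 hPa (25000 Pa) in the ICON Held–Suarez reference run (a). They also show its change in the tropical-warming (b) and Arctic-warming (c) experiments. Panels d–e show latitude profiles of the spectrum's phase-speed and zonal-wavenumber centroids, with bootstrap confidence intervals. These profiles cover the reference run and all three forcing experiments (tropical, Arctic, combined).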

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import os
import cmocean
import numba

from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from scipy import signal, stats

from icon_util import rfft, spectrum

# Data root taken from the WORK environment variable. Fail early with a clear
# message instead of the TypeError raised by `None + '/'` when WORK is unset.
work = os.environ.get('WORK')
if work is None:
    raise RuntimeError("environment variable WORK is not set; it must point to the directory that contains DATA/")
work = work+'/'
plt.rcParams.update({'font.size': 14})

# Make the bootstrap resampling reproducible: scipy.stats.uniform.rvs (used by
# random_sample below) draws from NumPy's global RandomState when no
# random_state is passed.
np.random.seed(0)

## Fourier transform in longitude

In [None]:
# Reference run: all NetCDF files of `va` at 25000 Pa, concatenated along time
directory = work+'/DATA/icon_simulations/atm_heldsuarez_default/3d/va/25000pa/'
files = sorted(directory + fname for fname in os.listdir(directory) if fname.endswith('.nc'))

array = xr.open_mfdataset(files, chunks=dict(), combine='nested', concat_dim='time')['va'].sel(time=slice(1300, 501300))

# anomalies from the time-mean climatology
clim = array.mean('time').compute()
array = array - clim

# Fourier coefficients in longitude, loaded into memory
coeff = rfft(array)
ref_coeff = coeff.compute()

In [None]:
# Tropical-warming run (exp9): all NetCDF files of `va` at 25000 Pa, concatenated along time
directory = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp9/3d/va/25000pa/'
files = sorted(directory + fname for fname in os.listdir(directory) if fname.endswith('.nc'))

array = xr.open_mfdataset(files, chunks=dict(), combine='nested', concat_dim='time')['va'].sel(time=slice(1300, 501300))

# anomalies from the time-mean climatology
clim = array.mean('time').compute()
array = array - clim

# Fourier coefficients in longitude, loaded into memory
coeff = rfft(array)
exp9_coeff = coeff.compute()

In [None]:
# Arctic-warming run (exp4): all NetCDF files of `va` at 25000 Pa, concatenated along time
directory = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp4/3d/va/25000pa/'
files = sorted(directory + fname for fname in os.listdir(directory) if fname.endswith('.nc'))

array = xr.open_mfdataset(files, chunks=dict(), combine='nested', concat_dim='time')['va'].sel(time=slice(1300, 501300))

# anomalies from the time-mean climatology
clim = array.mean('time').compute()
array = array - clim

# Fourier coefficients in longitude, loaded into memory
coeff = rfft(array)
exp4_coeff = coeff.compute()

In [None]:
# Combined-forcing run (exp8): all NetCDF files of `va` at 25000 Pa, concatenated along time
directory = work+'/DATA/icon_simulations/atm_heldsuarez_butler_exp8/3d/va/25000pa/'
files = sorted(directory + fname for fname in os.listdir(directory) if fname.endswith('.nc'))

array = xr.open_mfdataset(files, chunks=dict(), combine='nested', concat_dim='time')['va'].sel(time=slice(1300, 501300))

# anomalies from the time-mean climatology
clim = array.mean('time').compute()
array = array - clim

# Fourier coefficients in longitude, loaded into memory
coeff = rfft(array)
exp8_coeff = coeff.compute()

## Fourier transform in time

In [None]:
def construct_rolling(array,nstep=20*4):
    '''
        Cut `array` into windows of `nstep` consecutive time steps and stack
        them along a new leading 'rolling' dimension.

        Windows start every nstep//2 steps (50 % overlap for even `nstep`).
        Only complete windows are kept. A shorter trailing window could not be
        concatenated with the others once the 'time' coordinate has been
        dropped.

        Parameters
        ----------
        array : DataArray with a 'time' dimension and coordinate
        nstep : int, window length in time steps (>= 1)

        Returns
        -------
        DataArray with a leading 'rolling' dimension whose coordinate is the
        mean time of each window. 'time' remains a dimension, without a
        coordinate, so all windows share the same index.

        Raises
        ------
        ValueError if the time axis is shorter than one window.
    '''
    length = len(array['time'])
    step = max(nstep // 2, 1)  # nstep=1 would otherwise give a zero range step

    windows = [array.isel(time=slice(start, start + nstep))
               for start in range(0, length - nstep + 1, step)]
    if not windows:
        raise ValueError(f"time axis has {length} steps, fewer than one window of nstep={nstep}")

    # Drop the time coordinate so windows from different periods can be
    # stacked; label each window by its mean time instead.
    windows = [w.drop_vars('time').assign_coords(rolling=w.time.mean()) for w in windows]

    return xr.concat(windows,dim='rolling')

def wavenum_freq_spect(coeff,timestep=6*3600,window=signal.windows.tukey,window_args=(0.5,),pad=0):
    '''
        Wavenumber-frequency power spectral density of Fourier coefficients.

        Parameters
        ----------
        coeff : DataArray with a 'time' dimension and coordinate
        timestep : sampling interval along 'time' (default 6*3600, i.e. 6 h in seconds)
        window : taper function, called as window(len(time), *window_args).
                 Defaults to scipy.signal.windows.tukey. The scipy.signal.tukey
                 alias has been removed from recent SciPy releases.
        window_args : extra positional arguments passed to `window`
        pad : extra points in the frequency axis, whose length is len(time) + pad

        Returns
        -------
        DataArray with 'time' replaced by a trailing 'frequency' dimension.
        Its coordinate is np.fft.fftfreq(len(time) + pad, d=timestep), in FFT order.
    '''
    ntime = len(coeff.time)
    freq = xr.DataArray(np.fft.fftfreq(ntime+pad,d=timestep),dims=('frequency'))
    taper = xr.DataArray(window(ntime,*window_args),coords=dict(time=coeff.time),dims=('time'))

    coeff = coeff * taper

    # NOTE(review): `spectrum` comes from icon_util; presumably it returns the
    # squared magnitude of the (zero-padded) FFT on the `freq` axis -- confirm.
    spect = xr.apply_ufunc(spectrum,
                           coeff, freq,
                           input_core_dims=[['time'],['frequency']],
                           output_core_dims=[['frequency']],
                           dask='parallelized',
                           output_dtypes=[np.float64])
    spect['frequency'] = freq

    # periodogram normalisation: divide by the number of samples, multiply by the time step
    spect = spect / ntime * timestep

    # factor 2: presumably accounts for the one-sided (rfft) wavenumber spectrum -- TODO confirm
    spect = spect * 2

    # compensate for the variance removed by the taper
    spect = spect / (taper**2).mean()

    return spect


@numba.guvectorize(
    "(float64[:],float64[:],float64[:],float64[:])",
    "(n), (n), (m) -> (m)",
    nopython=True
)
def interp_freq(spect,freq,fc,out):
    # gufunc kernel: linearly interpolate `spect`, given at the frequencies
    # `freq`, onto the target frequencies `fc`, and write the result into `out`.
    # np.interp requires `freq` to be increasing (it does not check this) and
    # returns the end values of `spect` for targets outside the range of `freq`.
    interpolated = np.interp(fc, freq, spect)
    for i in range(interpolated.shape[0]):
        out[i] = interpolated[i]


def freq2phase_speed_interp(spect,dc=1):
    '''
        Calculate wavenumber-phase speed spectra from wavenumber-frequency spectra
        following Randel & Held (1991).

        At zonal wavenumber k and latitude phi, the phase speed c corresponds to
        the frequency f = k c / (2 pi a cos(phi)). The spectrum is linearly
        interpolated to these frequencies and multiplied by df/dc = k / (2 pi a cos(phi)).
        The result is therefore a density per unit phase speed.

        Parameters
        ----------
        spect : DataArray with a 'frequency' dimension/coordinate and 'wavenumber'
                and 'latitude' (degrees) coordinates
        dc : phase speed resolution. Units follow from a = 6371000 m and the
             unit of frequency (m s-1 for frequencies in 1/s).

        Returns
        -------
        DataArray with 'frequency' replaced by a trailing 'phase_speed' dimension
        of 2*(N//2) points, where N = len(spect.frequency).

        -positive phase speed is eastward (the phase speed coordinate is
         sign-flipped and hence descending)
    '''
    # Phase speed axis built from an integer range: np.arange with a float step
    # (e.g. dc=1/3) can gain or lose an end point through rounding.
    half = len(spect.frequency) // 2
    c = xr.DataArray(dc * np.arange(-half, half),dims=('phase_speed'))
    c = c.assign_coords(phase_speed=c)

    # Frequencies that correspond to these phase speeds;
    # fc has dimensions latitude, phase_speed and wavenumber
    a = 6371000  # Earth radius [m]
    factor = 1 / (2*np.pi*a)
    factor = factor / np.cos(spect.latitude/180*np.pi)
    fc = factor * c * spect.wavenumber

    # Interpolate linearly to these frequencies (interp_freq needs ascending frequency)
    spect = spect.sortby('frequency')
    new_spect = xr.apply_ufunc(interp_freq,
                               spect, spect.frequency, fc,
                               input_core_dims=[['frequency'],['frequency'],['phase_speed']],
                               output_core_dims=[['phase_speed']],
                               dask='parallelized',
                               output_dtypes=[spect.dtype])

    # Scale the power spectral density into units of phase speed (Jacobian df/dc)
    new_spect = new_spect * spect.wavenumber * factor
    # positive phase speed is eastward
    new_spect['phase_speed'] = -1 * new_spect['phase_speed']

    return new_spect


def centroid(da,dim='phase_speed',sum_dims=('wavenumber','phase_speed')):
    '''
        Weighted mean (centroid) of the coordinate `dim`, using `da` as weights:

            sum(da * da[dim]) / sum(da)

        Both sums are taken over `sum_dims`. The default sums over the whole
        wavenumber-phase speed plane; all other dimensions (e.g. latitude,
        rolling) are kept.

        Parameters
        ----------
        da : DataArray of weights, e.g. a power spectral density
        dim : name of the coordinate whose centroid is computed
        sum_dims : dimensions to sum over
    '''
    total = da.sum(sum_dims)
    return (da * da[dim]).sum(sum_dims) / total

In [None]:
# Hayashi spectra for rolling windows:
# windows of 14*4 time steps (14 days at the 6-hourly timestep used below)
# with 50 % overlap over the first 71960 steps; keeping every second window
# leaves adjacent, non-overlapping windows.
rolling = construct_rolling(ref_coeff.isel(time=slice(None,71960)),nstep=14*4).chunk(dict(time=-1,rolling=100))

reduced = rolling.isel(rolling=range(0,len(rolling['rolling']),2))

# scipy.signal.windows.tukey: the scipy.signal.tukey alias was removed from SciPy
spect = wavenum_freq_spect(reduced,timestep=6*3600,window=signal.windows.tukey,window_args=(0.5,),pad=2*len(rolling['time']))

spect = freq2phase_speed_interp(spect.rename(latitude_bin='latitude'),dc=1/3)

spect = spect.sel(phase_speed=slice(20,-20),latitude=slice(20,80),wavenumber=slice(1,10))

# meridional mean over 35-65°N
ref_spect = spect.sel(latitude=slice(35,65)).mean(('latitude')).compute().squeeze()

# centroid as a function of latitude (and window)
ds = xr.Dataset(dict(phase_speed=centroid(spect,'phase_speed'),wavenumber=centroid(spect,'wavenumber')))

ref_centroid = ds.compute()

In [None]:
# Hayashi spectra for rolling windows (tropical warming, exp9):
# windows of 14*4 time steps (14 days at the 6-hourly timestep used below)
# with 50 % overlap over the first 71960 steps; keeping every second window
# leaves adjacent, non-overlapping windows.
rolling = construct_rolling(exp9_coeff.isel(time=slice(None,71960)),nstep=14*4).chunk(dict(time=-1,rolling=100))

reduced = rolling.isel(rolling=range(0,len(rolling['rolling']),2))

# scipy.signal.windows.tukey: the scipy.signal.tukey alias was removed from SciPy
spect = wavenum_freq_spect(reduced,timestep=6*3600,window=signal.windows.tukey,window_args=(0.5,),pad=2*len(rolling['time']))

spect = freq2phase_speed_interp(spect.rename(latitude_bin='latitude'),dc=1/3)

spect = spect.sel(phase_speed=slice(20,-20),latitude=slice(20,80),wavenumber=slice(1,10))

# meridional mean over 35-65°N
exp9_spect = spect.sel(latitude=slice(35,65)).mean(('latitude')).compute().squeeze()

# centroid as a function of latitude (and window)
ds = xr.Dataset(dict(phase_speed=centroid(spect,'phase_speed'),wavenumber=centroid(spect,'wavenumber')))

exp9_centroid = ds.compute()

In [None]:
# Hayashi spectra for rolling windows (Arctic warming, exp4):
# windows of 14*4 time steps (14 days at the 6-hourly timestep used below)
# with 50 % overlap over the first 71960 steps; keeping every second window
# leaves adjacent, non-overlapping windows.
rolling = construct_rolling(exp4_coeff.isel(time=slice(None,71960)),nstep=14*4).chunk(dict(time=-1,rolling=100))

reduced = rolling.isel(rolling=range(0,len(rolling['rolling']),2))

# scipy.signal.windows.tukey: the scipy.signal.tukey alias was removed from SciPy
spect = wavenum_freq_spect(reduced,timestep=6*3600,window=signal.windows.tukey,window_args=(0.5,),pad=2*len(rolling['time']))

spect = freq2phase_speed_interp(spect.rename(latitude_bin='latitude'),dc=1/3)

spect = spect.sel(phase_speed=slice(20,-20),latitude=slice(20,80),wavenumber=slice(1,10))

# meridional mean over 35-65°N
exp4_spect = spect.sel(latitude=slice(35,65)).mean(('latitude')).compute().squeeze()

# centroid as a function of latitude (and window)
ds = xr.Dataset(dict(phase_speed=centroid(spect,'phase_speed'),wavenumber=centroid(spect,'wavenumber')))

exp4_centroid = ds.compute()

In [None]:
# Hayashi spectra for rolling windows (combined forcing, exp8):
# windows of 14*4 time steps (14 days at the 6-hourly timestep used below)
# with 50 % overlap over the first 71960 steps; keeping every second window
# leaves adjacent, non-overlapping windows.
rolling = construct_rolling(exp8_coeff.isel(time=slice(None,71960)),nstep=14*4).chunk(dict(time=-1,rolling=100))

reduced = rolling.isel(rolling=range(0,len(rolling['rolling']),2))

# scipy.signal.windows.tukey: the scipy.signal.tukey alias was removed from SciPy
spect = wavenum_freq_spect(reduced,timestep=6*3600,window=signal.windows.tukey,window_args=(0.5,),pad=2*len(rolling['time']))

spect = freq2phase_speed_interp(spect.rename(latitude_bin='latitude'),dc=1/3)

spect = spect.sel(phase_speed=slice(20,-20),latitude=slice(20,80),wavenumber=slice(1,10))

# meridional mean over 35-65°N
exp8_spect = spect.sel(latitude=slice(35,65)).mean(('latitude')).compute().squeeze()

# centroid as a function of latitude (and window)
ds = xr.Dataset(dict(phase_speed=centroid(spect,'phase_speed'),wavenumber=centroid(spect,'wavenumber')))

exp8_centroid = ds.compute()

## Inference

In [None]:
@numba.guvectorize(
    "(float64[:],float64[:],float64[:,:])",
    "(n), (m) -> (m,n)",
    forceobj=True
)
def random_sample(a,nb,out):
    '''
        Bootstrap resampling with replacement.

        Fills each of the len(nb) rows of `out` with len(a) values drawn from
        `a` with replacement.

        - nb is a dummy array; only its length (number of resamples) is used
    '''
    n_values = len(a)
    n_resamples = len(nb)
    # continuous uniform variates on [0, n_values), truncated to integer indices
    draws = stats.uniform.rvs(0, n_values, n_values * n_resamples)
    indices = draws.astype(int)
    out[:, :] = a[indices.reshape(n_resamples, n_values)]

    
@numba.guvectorize(
    "(float64[:],float64[:],float64[:])",
    "(n), (m) -> (m)",
    forceobj=True
)
def icdf(a,p,out):
    '''
        Inverse empirical cumulative distribution function of array `a`
        evaluated at probabilities `p` (expected in [0, 1]):

            out[j] = sort(a)[int(p[j] * len(a))]

        The index is clipped to [0, len(a) - 1]. As a result, p = 1 returns the
        largest value instead of raising IndexError, and a negative p cannot
        wrap around to the end of the sorted array.
    '''
    ordered = np.sort(a)
    index = (p * len(a)).astype(np.int64)
    out[:] = ordered[np.clip(index, 0, len(a) - 1)]


def confid(da,alpha=0.05,nb=10000):
    '''
        Estimate a confidence interval for the mean of `da` along 'rolling' by
        bootstrapping. The interval is read from the inverse empirical cumulative
        distribution function of the bootstrap distribution of the mean.

        Parameters
        ----------
        da : DataArray with a 'rolling' dimension
        alpha : two-sided significance level; the interval spans the
                alpha/2 and 1-alpha/2 quantiles
        nb : number of bootstrap resamples (default 10000)

        Returns
        -------
        DataArray without 'rolling' and with a trailing 'percentile' dimension
        whose coordinate is [alpha/2, 1-alpha/2].
    '''
    n_bootstrap = xr.DataArray(np.arange(nb),dims=('random'))
    # nb resamples of da along 'rolling' -> dims (..., 'random', 'rolling')
    sample = xr.apply_ufunc(random_sample,
                            da, n_bootstrap,
                            input_core_dims=[['rolling'],['random']],
                            output_core_dims=[['random','rolling']],
                            dask='parallelized',
                            output_dtypes=[np.float64])

    # bootstrap distribution of the mean
    dist = sample.mean('rolling')

    p = xr.DataArray([alpha/2, 1-alpha/2],dims=('percentile'))
    values = xr.apply_ufunc(icdf,
                            dist, p,
                            input_core_dims=[['random'],['percentile']],
                            output_core_dims=[['percentile']],
                            dask='parallelized',
                            output_dtypes=[dist.dtype])
    values['percentile'] = p

    return values

In [None]:
def t_statistic(x1,x2,dim):
    '''
        Welch t-statistic for the difference of the means of x1 and x2 along `dim`:

            (mean(x1) - mean(x2)) / sqrt(var(x1)/n1 + var(x2)/n2)

        Here n1 and n2 are the sizes of `dim` in x1 and x2, and the variances
        are whatever `.var(dim)` returns (no ddof is passed). For samples of
        equal length this equals (mean1 - mean2) / sqrt((var1 + var2) / n).
        Unlike that form, it stays valid when the lengths differ.
    '''
    diff = x1.mean(dim) - x2.mean(dim)
    n1 = x1.sizes[dim]
    n2 = x2.sizes[dim]
    err = np.sqrt(x1.var(dim)/n1 + x2.var(dim)/n2)
    return diff / err


def parametric_bootstrap(sample1,sample2,nb=1000,confid=0.05):
    '''
        Two-sided bootstrap test of the difference of the means of `sample1`
        and `sample2` along 'rolling'.

        The null distribution of the t-statistic is built from `nb` resamples
        (with replacement) of the mean-removed samples, which fulfil the null
        hypothesis of equal means. The null hypothesis is rejected where the
        observed t-statistic lies below the confid/2 or above the 1-confid/2
        empirical quantile of that distribution.

        Parameters
        ----------
        sample1, sample2 : DataArrays with a 'rolling' dimension
        nb : number of bootstrap resamples
        confid : significance level of the two-sided test
                 (NOTE: inside this function it shadows the function `confid`)

        Returns
        -------
        Boolean DataArray without 'rolling'; True where the null hypothesis is rejected.
    '''
    # Produce control samples that fulfil the null hypothesis
    c1 = sample1 - sample1.mean('rolling')
    c2 = sample2 - sample2.mean('rolling')

    bootstrap = xr.DataArray(np.arange(nb),dims=('random'))

    def _resample(control):
        # nb resamples of `control` along 'rolling' -> dims (..., 'random', 'rolling')
        return xr.apply_ufunc(random_sample,
                              control, bootstrap,
                              input_core_dims=[['rolling'],['random']],
                              output_core_dims=[['random','rolling']],
                              dask='parallelized',
                              output_dtypes=[control.dtype])

    # t statistic for the resampled data; c1 is resampled before c2, so the
    # random-number sequence is consumed in a fixed order
    dist = t_statistic(_resample(c1),_resample(c2),'rolling')

    # Empirical quantiles of the null distribution, evaluated exactly at the two
    # test levels. A fixed 0.001-spaced grid with nearest-neighbour selection
    # cannot represent e.g. confid/2 = 0.0005 and would shift the thresholds.
    p = xr.DataArray([confid/2, 1-confid/2],dims=('percentile'))
    bounds = xr.apply_ufunc(icdf,
                            dist, p,
                            input_core_dims=[['random'],['percentile']],
                            output_core_dims=[['percentile']],
                            dask='parallelized',
                            output_dtypes=[dist.dtype])

    # check whether the null hypothesis can be rejected
    t = t_statistic(sample1,sample2,'rolling')
    sig = (t < bounds.isel(percentile=0)) | (t > bounds.isel(percentile=1))

    return sig

In [None]:
# Bootstrap confidence intervals of the latitude-dependent centroids. Afterwards,
# the per-window centroids are replaced by their mean over all windows.
# NOTE: the *_centroid variables are overwritten, so this cell is not
# idempotent -- re-run the spectra cells before running it again.
def _centroid_CI(centroids):
    '''Dataset of `confid` intervals for the 'phase_speed' and 'wavenumber' centroids.'''
    return xr.Dataset({key: confid(centroids[key]) for key in ('phase_speed', 'wavenumber')})

ref_CI = _centroid_CI(ref_centroid)
ref_centroid = ref_centroid.mean('rolling')

exp9_CI = _centroid_CI(exp9_centroid)
exp9_centroid = exp9_centroid.mean('rolling')

exp4_CI = _centroid_CI(exp4_centroid)
exp4_centroid = exp4_centroid.mean('rolling')

exp8_CI = _centroid_CI(exp8_centroid)
exp8_centroid = exp8_centroid.mean('rolling')

In [None]:
# Where does the Arctic-warming (exp4) spectrum differ significantly from the reference?
sig_exp4 = parametric_bootstrap(sample1=exp4_spect, sample2=ref_spect, nb=2000, confid=0.001)

In [None]:
# Where does the tropical-warming (exp9) spectrum differ significantly from the reference?
sig_exp9 = parametric_bootstrap(sample1=exp9_spect, sample2=ref_spect, nb=2000, confid=0.001)

## Plotting

In [None]:
# Figure layout: panels a-c in the left column of a 3x2 grid,
# panels d-e in the right-most column of a 2x4 grid.
fig = plt.figure(figsize=(6,8))

axes = [fig.add_subplot(3,2,1),fig.add_subplot(3,2,3),fig.add_subplot(3,2,5),
       fig.add_subplot(2,4,4),fig.add_subplot(2,4,8)]

# panel a: reference spectrum averaged over all windows ('rolling')

C1 = ref_spect.mean('rolling').plot.pcolormesh(ax=axes[0],levels=np.arange(0,7,0.5),extend='max',cmap=cmocean.cm.matter,add_colorbar=False)


# centroid of the mean reference spectrum, drawn as an open circle
l = axes[0].plot(centroid(ref_spect.mean('rolling'),'phase_speed').values,
                 centroid(ref_spect.mean('rolling'),'wavenumber').values,
                 marker='o',markeredgecolor='k',markersize=10,markeredgewidth=1.5)
l[0].set_markerfacecolor((0,0,0,0))  # transparent face -> open marker

# configure axes

axes[0].set_xlim(-20,20)
axes[0].set_xticks([-20,-10,0,10,20])
axes[0].set_xticks([-15,-5,5,15],minor=True)
axes[0].set_ylim(1,10)
axes[0].set_yticks([2,4,6,8,10])
axes[0].set_yticks([1,3,5,7,9],minor=True)
axes[0].grid(axis='both')

axes[0].set_ylabel('Zonal wavenumber')
axes[0].set_xlabel('')
axes[0].set_title('Reference',weight='bold',fontsize='smaller')

# colorbar for panel a only; panels b and c show contoured differences
cbar = plt.colorbar(C1,ax=axes[0],ticks=range(7))
cbar.set_label(r'Power spectral density [m s$^{-1}$]',fontsize=12)


# panel b: difference of the window-averaged spectra, tropical warming (exp9)
# minus reference; positive contours solid, negative contours dotted

levels = np.concatenate([np.arange(-6.5,0,0.5),np.arange(0.5,7,0.5)])

C = (exp9_spect.mean('rolling')-ref_spect.mean('rolling')).plot.contour(ax=axes[1],levels=levels,cmap=cmocean.cm.rain,
                                        linestyles=np.where(levels>0.1,'solid','dotted'))
axes[1].clabel(C)

# Hatch where the difference is NOT significant: sig_exp9 is 0 there, which falls
# into the hatched [0, 0.5] band. alpha=0 is intended to hide the filled colours
# so that only the hatching shows.
sig_exp9.astype(np.double).plot.contourf(ax=axes[1],levels=[0,0.5,1],hatches=['////',''],alpha=0,add_colorbar=False)

# centroid of the reference spectrum (open circle) ...
l = axes[1].plot(centroid(ref_spect.mean('rolling'),'phase_speed').values,
                 centroid(ref_spect.mean('rolling'),'wavenumber').values,
                 marker='o',markeredgecolor='k',markersize=10,markeredgewidth=1.5)
l[0].set_markerfacecolor((0,0,0,0))

# ... and of the tropical-warming spectrum (cross)
axes[1].plot(centroid(exp9_spect.mean('rolling'),'phase_speed').values,
                 centroid(exp9_spect.mean('rolling'),'wavenumber').values,
                 marker='x',markeredgecolor='k',markersize=10,markeredgewidth=1.5)

# configure axes

axes[1].set_xlim(-20,20)
axes[1].set_xticks([-20,-10,0,10,20])
axes[1].set_xticks([-15,-5,5,15],minor=True)
axes[1].set_ylim(1,10)
axes[1].set_yticks([2,4,6,8,10])
axes[1].set_yticks([1,3,5,7,9],minor=True)
axes[1].grid(axis='both')

axes[1].set_ylabel('Zonal wavenumber')
axes[1].set_xlabel('')

axes[1].set_title('Tropical warming',weight='bold',fontsize='smaller')


# panel c: difference of the window-averaged spectra, Arctic warming (exp4)
# minus reference; uses the contour `levels` defined for panel b

C = (exp4_spect.mean('rolling')-ref_spect.mean('rolling')).plot.contour(ax=axes[2],levels=levels,cmap=cmocean.cm.rain,
                                        linestyles=np.where(levels>0.1,'solid','dotted'))
axes[2].clabel(C)

# hatch where the difference is NOT significant (sig_exp4 == 0 -> hatched band)
sig_exp4.astype(np.double).plot.contourf(ax=axes[2],levels=[0,0.5,1],hatches=['////',''],alpha=0,add_colorbar=False)

# centroid of the reference spectrum (open circle) ...
l = axes[2].plot(centroid(ref_spect.mean('rolling'),'phase_speed').values,
                 centroid(ref_spect.mean('rolling'),'wavenumber').values,
                 marker='o',markeredgecolor='k',markersize=10,markeredgewidth=1.5)
l[0].set_markerfacecolor((0,0,0,0))

# ... and of the Arctic-warming spectrum (cross)
axes[2].plot(centroid(exp4_spect.mean('rolling'),'phase_speed').values,
                 centroid(exp4_spect.mean('rolling'),'wavenumber').values,
                 marker='x',markeredgecolor='k',markersize=10,markeredgewidth=1.5)


# configure axes

axes[2].set_xlim(-20,20)
axes[2].set_xticks([-20,-10,0,10,20])
axes[2].set_xticks([-15,-5,5,15],minor=True)
axes[2].set_ylim(1,10)
axes[2].set_yticks([2,4,6,8,10])
axes[2].set_yticks([1,3,5,7,9],minor=True)
axes[2].grid(axis='both')

axes[2].set_ylabel('Zonal wavenumber')
axes[2].set_xlabel(r'Phase speed [m s$^{-1}$]')

axes[2].set_title('Arctic warming',weight='bold',fontsize='smaller')


# configure figure

# positional arguments: left, bottom, right, top, wspace, hspace
fig.subplots_adjust(0,0,1,0.9,-0.3,0.3)

# Give panels b and c (no colorbar) the same width as panel a, whose axes
# share space with its colorbar; bounds = (x0, y0, width, height).
box = list(axes[1].get_position().bounds)
box[2] = axes[0].get_position().bounds[2]
axes[1].set_position(box)

box = list(axes[2].get_position().bounds)
box[2] = axes[0].get_position().bounds[2]
axes[2].set_position(box)



# panels d and e: latitude profiles of the window-averaged centroids.
# l1..l4 (lists of Line2D) are kept so that the CI bands and reference lines
# reuse the colours of the corresponding curves.

l1 = ref_centroid['phase_speed'].plot(ax=axes[3],linestyle='-')
l2 = exp9_centroid['phase_speed'].plot(ax=axes[3],linestyle='--')
l3 = exp4_centroid['phase_speed'].plot(ax=axes[3],linestyle=':')
l4 = exp8_centroid['phase_speed'].plot(ax=axes[3],linestyle='-.')

# bootstrap CI bands (2.5 % - 97.5 %). The .sel(percentile=...) labels must match
# exactly the percentile coordinate produced by confid with alpha=0.05.
axes[3].fill_between(ref_CI['latitude'].values,
                     ref_CI['phase_speed'].sel(percentile=0.025).values.squeeze(),
                     ref_CI['phase_speed'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l1[0].get_c())
axes[3].fill_between(exp9_CI['latitude'].values,
                     exp9_CI['phase_speed'].sel(percentile=0.025).values.squeeze(),
                     exp9_CI['phase_speed'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l2[0].get_c())
axes[3].fill_between(exp4_CI['latitude'].values,
                     exp4_CI['phase_speed'].sel(percentile=0.025).values.squeeze(),
                     exp4_CI['phase_speed'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l3[0].get_c())
axes[3].fill_between(exp8_CI['latitude'].values,
                     exp8_CI['phase_speed'].sel(percentile=0.025).values.squeeze(),
                     exp8_CI['phase_speed'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l4[0].get_c())

# wavenumber centroids; the colour cycle of axes[4] starts afresh, so the
# colours match those of the phase speed curves in axes[3]
ref_centroid['wavenumber'].plot(ax=axes[4],label='Reference',linestyle='-')
exp9_centroid['wavenumber'].plot(ax=axes[4],label='Tropical warming',linestyle='--')
exp4_centroid['wavenumber'].plot(ax=axes[4],label='Arctic warming',linestyle=':')
exp8_centroid['wavenumber'].plot(ax=axes[4],label='Combined forcing',linestyle='-.')

axes[4].fill_between(ref_CI['latitude'].values,
                     ref_CI['wavenumber'].sel(percentile=0.025).values.squeeze(),
                     ref_CI['wavenumber'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l1[0].get_c())
axes[4].fill_between(exp9_CI['latitude'].values,
                     exp9_CI['wavenumber'].sel(percentile=0.025).values.squeeze(),
                     exp9_CI['wavenumber'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l2[0].get_c())
axes[4].fill_between(exp4_CI['latitude'].values,
                     exp4_CI['wavenumber'].sel(percentile=0.025).values.squeeze(),
                     exp4_CI['wavenumber'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l3[0].get_c())
axes[4].fill_between(exp8_CI['latitude'].values,
                     exp8_CI['wavenumber'].sel(percentile=0.025).values.squeeze(),
                     exp8_CI['wavenumber'].sel(percentile=0.975).values.squeeze(),
                     alpha=0.3,color=l4[0].get_c())


for ax in axes[3:]:
    ax.set_title('')
    ax.set_xlabel('Latitude [°N]')
    ax.grid()
    ax.set_xticks([35,50,65])
    ax.set_xlim([22.5,77.5])
    
    # Dotted vertical lines at fixed latitudes in the colours of the reference,
    # tropical-warming and Arctic-warming curves. Presumably these mark the
    # respective jet latitudes (no line for combined forcing) -- TODO confirm
    # where 41.5, 45.75 and 37.5 come from. The y-limits are restored afterwards.
    ylim = ax.get_ylim()
    ax.plot([41.5,41.5],ylim,color=l1[0].get_c(),linewidth=1,linestyle=':')
    ax.plot([45.75,45.75],ylim,color=l2[0].get_c(),linewidth=1,linestyle=':')
    ax.plot([37.5,37.5],ylim,color=l3[0].get_c(),linewidth=1,linestyle=':')
    ax.set_ylim(ylim)
    
    # y label and tick labels on the right, tick marks on both sides
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    ax.tick_params(axis='y',left=True,right=True)
    
axes[3].set_ylabel(r'Phase speed centroid [m s$^{-1}$]')
axes[4].set_ylabel(r'Zonal wavenumber centroid')

axes[4].legend(fontsize=10)


# Panel labels a)-e): positioned in axes-fraction coordinates, then shifted
# 45 pt left and 20 pt down (dpi_scale_trans works in inches; 1 pt = 1/72 in).
trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)

axes[0].text(-0.1,1.06,'a)',transform=axes[0].transAxes+trans,fontsize='large',va='bottom')
axes[1].text(-0.1,1.06,'b)',transform=axes[1].transAxes+trans,fontsize='large',va='bottom')
axes[2].text(-0.1,1.06,'c)',transform=axes[2].transAxes+trans,fontsize='large',va='bottom')
axes[3].text(0.1,1.06,'d)',transform=axes[3].transAxes+trans,fontsize='large',va='bottom')
axes[4].text(0.1,1.06,'e)',transform=axes[4].transAxes+trans,fontsize='large',va='bottom')