# Hayashi spectra

Climatological-mean (upper panel) and composite-mean anomalous (middle and lower panels) Hayashi spectra, i.e. meridionally averaged (35°N-65°N) power spectral density of meridional wind (250 hPa) in coordinates of zonal phase speed and zonal wavenumber. The thick solid line indicates the centroid of the climatological-mean spectrum and the hatching denotes areas where the composite-mean anomaly is not statistically significant. Contour lines of composite-mean anomalous power spectral density are dotted where non-positive, with a contour spacing of 0.6 m s$^{-1}$.

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cmocean
import os

from dask.distributed import Client
from scipy import stats
from matplotlib.colors import Normalize
from matplotlib.colorbar import ColorbarBase

from wave_util import construct_rolling_dataset, remove_climatology, compute_spectra, centroid, integral
from composite_util import composite_dates, ecdf, bootstrap_correlation, parametric_bootstrap

# Root of the working directory tree (with trailing slash), read from the
# WORK environment variable. Index os.environ directly so that an unset
# variable fails with "KeyError: 'WORK'" instead of the opaque TypeError that
# os.environ.get('WORK') + '/' produces when get() returns None.
work = os.environ['WORK']+'/'

# Larger default font size for all figures drawn in this notebook.
plt.rcParams.update({'font.size': 14})

In [None]:
# Start a dask.distributed Client with default arguments.
client = Client()

# Bare expression: in a notebook this displays the client's summary.
client

In [None]:
# Shut the dask client down.
# NOTE(review): this cell directly follows the one that creates the client;
# presumably it is meant to be run manually once all computations are done.
client.close()

## Computation

In [None]:
# Construct a nested list of filenames: one inner list of three files
# (names ending in 06.nc, 07.nc and 08.nc) per group.

directory = work+'DATA/ERA5/eth/plev/'

prefix = 'era5_an_vwind_reg2_6h'
suffixes = ('06.nc','07.nc','08.nc')

files = sorted(directory + name
               for name in os.listdir(directory)
               if name.startswith(prefix) and name.endswith(suffixes))

# number of complete groups of three; surplus files at the end are dropped
n_valid_years = len(files)//3

# NOTE(review): the grouping is purely positional after sorting; it presumably
# relies on every year providing exactly the three files - verify on the data.
files = [files[start:start+3] for start in range(0,3*n_valid_years,3)]

In [None]:
# Preparation for the meridional-mean spectra: rolling windows and anomalies.

# 'lat' slice from 65 to 35 and the 25000 level of 'plev'
selection = {'lat': slice(65,35), 'plev': 25000}

# 30*4 samples per window (presumably 30 days of 6-hourly data - TODO confirm)
rolling = construct_rolling_dataset(
    files,
    selection=selection,
    n_per_window=30*4,
)

anomalies = remove_climatology(rolling)

In [None]:
# Spectra of variable 'var132' of the anomalies for wavenumbers 1 to 10;
# compute_spectra is handed the coordinates under the names
# 'latitude'/'longitude'.
# NOTE(review): the meaning of dc=1/3 is defined in wave_util - confirm there.
spectra = compute_spectra(
    anomalies['var132'].rename({'lat': 'latitude', 'lon': 'longitude'}),
    wavenumber=slice(1,10),
    dc=1/3,
)

# average over latitude, then evaluate the result
spectra = spectra.mean('latitude').compute()

In [None]:
def estimate_composite_mean(series,field,percentile=0.9,nrandom=10000,pvalue=0.05):
    """Composite mean of `field` over dates selected from `series`.

    Parameters
    ----------
    series
        Time series along the 'rolling' dimension. It is renamed to 'time'
        and handed to `composite_dates` to select the composite dates.
    field
        Data with a 'rolling' dimension from which the composite is drawn.
    percentile
        Forwarded to `composite_dates`.
    nrandom
        Forwarded to `parametric_bootstrap`.
    pvalue
        Forwarded to `parametric_bootstrap`. The default of 0.05 is the value
        that used to be hard-coded here.

    Returns
    -------
    xr.Dataset
        'mean' holds the composite averaged over 'time'; 'sig' holds the
        result of `parametric_bootstrap` (presumably a significance mask -
        confirm in composite_util).
    """
    dates = composite_dates(series.rename(rolling='time'),percentile=percentile)

    # pick the composite members out of the full field
    composite = field.sel(rolling=dates)
    sig = parametric_bootstrap(composite,field,nrandom=nrandom,pvalue=pvalue)

    return xr.Dataset(dict(mean=composite.mean('time'),sig=sig))


# Time series for composite construction: spectral power on either side of
# the centroid of the time-mean spectrum, summed over wavenumbers 4 to 8.
slow = []
fast = []

for w in range(4,9):

    field = spectra.sel(wavenumber=w)

    # the centroid of the time-mean spectrum splits the range [-30, 30]
    # (presumably phase speed - confirm in wave_util.integral)
    bound = centroid(field.mean('rolling'))
    slow.append(integral(field,bound.values,-30))
    fast.append(integral(field,30,bound.values))


slow = sum(slow)
fast = sum(fast)

# Composite-mean anomalies, estimated separately for every wavenumber
composite_slow = []
composite_fast = []

for w in spectra['wavenumber']:

    field = spectra.sel(wavenumber=w)

    # Deviation from the mean over 'rolling'. Deliberately not called
    # `anomalies`: that name holds the output of remove_climatology() which
    # the spectra cell reads, and rebinding it here would break re-running
    # that cell.
    field_anomaly = field - field.mean('rolling')
    composite_slow.append(estimate_composite_mean(slow,field_anomaly))
    composite_fast.append(estimate_composite_mean(fast,field_anomaly))


composite_slow = xr.concat(composite_slow,dim='wavenumber')
composite_fast = xr.concat(composite_fast,dim='wavenumber')

In [None]:
# Save the wavenumber 5-8 subset of the spectra as variable 'spect'
# so that other notebooks can reuse it.
output_path = work+'wolfgang/spectra_30days_65-35N_wave5-8.nc'
xr.Dataset({'spect': spectra.sel(wavenumber=slice(5,8))}).to_netcdf(output_path)

## Figure 1

In [None]:
# Figure 1: climatological-mean spectrum (top) and the two composite-mean
# anomalies (middle: slow, bottom: fast).
fig, axes = plt.subplots(nrows=3,figsize=(6,9))

# The centroid of the climatological-mean spectrum is drawn in every panel;
# compute it once instead of once per panel.
climatology = spectra.mean('rolling')
centroid_speed = centroid(climatology).values
# NOTE(review): wavenumbers 1-10 are hard-coded for the centroid line; this
# assumes centroid() returns one value per wavenumber 1..10 - confirm.
centroid_wavenumber = np.arange(1,11)


def _configure_panel(panel,title,xlabel):
    """Apply the limits, ticks, labels and title shared by all three panels."""
    panel.set_xlim(-20,20)
    panel.set_xticks([-20,-10,0,10,20])
    panel.set_ylim(1,10)
    panel.set_yticks([2,4,6,8,10])
    panel.set_yticks([1,3,5,7,9],minor=True)

    panel.set_ylabel('Zonal wavenumber')
    panel.set_xlabel(xlabel)
    panel.set_title(title,weight='bold')


# plot climatological mean

C1 = climatology.plot.pcolormesh(ax=axes[0],levels=np.linspace(0,3,21),extend='max',cmap=cmocean.cm.matter,add_colorbar=False)

axes[0].plot(centroid_speed,centroid_wavenumber,'-k',label='centroid',linewidth=3)
axes[0].legend(loc='upper left')
axes[0].grid(axis='x')

_configure_panel(axes[0],'JJA climatological mean','')

plt.colorbar(C1,ax=axes[0],label=r'Power spectral density [m s$^{-1}$]',ticks=[0,0.6,1.2,1.8,2.4,3])


# plot 'slow' and 'fast' composites

contour_levels = np.arange(-2,2.5,0.5)
# solid for positive levels, dotted for zero and negative ones
contour_styles = np.where(contour_levels>0.1,'solid','dotted')

composite_panels = [
    (axes[1],composite_slow,"'Amplified Slow' anomaly",''),
    (axes[2],composite_fast,"'Amplified Fast' anomaly",r'Phase speed [m s$^{-1}$]'),
]

for panel, composite, title, xlabel in composite_panels:

    C = composite['mean'].plot.contour(ax=panel,levels=contour_levels,cmap=cmocean.cm.rain,
                                       linestyles=contour_styles)
    panel.clabel(C)

    # invisible filled contours (alpha=0) that only carry the hatching:
    # values of 'sig' below 0.5 are hatched, values above are left blank
    composite['sig'].astype(np.double).plot.contourf(ax=panel,levels=[0,0.5,1],hatches=['//',''],alpha=0,add_colorbar=False)

    panel.plot(centroid_speed,centroid_wavenumber,'-k',label='centroid',linewidth=3)

    # labels and title go last because the xarray plot calls set their own
    _configure_panel(panel,title,xlabel)


# configure figure

fig.subplots_adjust(0,0,1,1,0,0.3)

# give the lower panels the width of the upper one, which shares its space
# with the colorbar
for panel in axes[1:]:
    box = list(panel.get_position().bounds)
    box[2] = axes[0].get_position().bounds[2]
    panel.set_position(box)


## Correlation between timeseries of wave energy


In [None]:
# Wave-energy time series for every wavenumber, split at the centroid of
# the time-mean spectrum of that wavenumber.
slow_parts, fast_parts = [], []

for w in spectra['wavenumber']:

    spectrum_w = spectra.sel(wavenumber=w)
    bound = centroid(spectrum_w.mean('rolling'))

    slow_parts.append(integral(spectrum_w,bound.values,-30))
    fast_parts.append(integral(spectrum_w,30,bound.values))

# stack the per-wavenumber series along a 'wavenumber' dimension
slow = xr.concat(slow_parts,dim='wavenumber')
fast = xr.concat(fast_parts,dim='wavenumber')

In [None]:
# Fill the symmetric matrix of pairwise correlations between all slow and
# fast wave-energy time series. Rows/columns 0..n_wave-1 hold the slow
# series, rows/columns n_wave..2*n_wave-1 the fast ones.

n_wave = slow.sizes['wavenumber']

# the diagonal keeps its initial value of 1 and is never marked significant
correlation = np.ones((2*n_wave,2*n_wave))
significance = np.zeros((2*n_wave,2*n_wave),dtype='bool')


def _store_correlation(row,col,x,y):
    """Correlate x with y and write the result symmetrically into both matrices."""
    ds = bootstrap_correlation(x,y)

    correlation[row,col] = ds['correlation'].values
    correlation[col,row] = ds['correlation'].values

    # flagged when 'p_value' lies above 0.975 or below 0.025
    # (presumably a two-sided 5 % test - confirm in composite_util)
    if (ds['p_value'] > 0.975) | (ds['p_value'] < 0.025):
        significance[row,col] = True
        significance[col,row] = True


# upper left quadrant (slow-slow)
for i in range(n_wave):
    for j in range(i+1,n_wave):
        _store_correlation(i,j,slow.isel(wavenumber=i),slow.isel(wavenumber=j))

# lower right quadrant (fast-fast)
for i in range(n_wave):
    for j in range(i+1,n_wave):
        _store_correlation(i+n_wave,j+n_wave,fast.isel(wavenumber=i),fast.isel(wavenumber=j))

# neighbor quadrants (slow-fast)
for i in range(n_wave):
    for j in range(n_wave):
        _store_correlation(i,j+n_wave,slow.isel(wavenumber=i),fast.isel(wavenumber=j))


## Figure S1

In [None]:
# Figure S1: the correlation matrix drawn as a coloured 20x20 table

fig = plt.figure()
ax = plt.axes(position=[0.2,0,0.7,0.7])

# thick black cross separating the four quadrants
ax.plot([0.5,0.5],[1,0],color='k',linewidth=3)
ax.plot([1,0],[0.5,0.5],color='k',linewidth=3)

ax.set_xlim(0,1)
ax.set_ylim(0,1)

# cell colours: grey where the correlation is exactly 1 (the untouched
# diagonal), colour-mapped where significant, fully transparent otherwise

norm = Normalize(vmin=-1,vmax=1)

cell_colors = []
for value, is_significant in zip(correlation.ravel(),significance.ravel()):
    if value == 1:
        cell_colors.append((0.5,0.5,0.5,0.5))
    elif is_significant:
        cell_colors.append(cmocean.cm.balance(norm(value)))
    else:
        cell_colors.append((0,0,0,0))

colors = np.array(cell_colors).reshape((20,20,4))

# cell text: two decimals, empty where the correlation is exactly 1
cell_text = ['' if value == 1 else f'{value:.2f}' for value in correlation.ravel()]
strings = np.array(cell_text).reshape(correlation.shape)

table = ax.table(strings,colors,bbox=[0,0,1,1],)

# bold text for slow/fast pairs of the same wavenumber
cells = table.get_celld()
for k in range(10):
    cells[(10+k,k)].set_text_props(weight='bold')
    cells[(k,10+k)].set_text_props(weight='bold')

# colorbar
cax = plt.axes(position=[0.95,0,0.05,0.7])
cbar = ColorbarBase(cax,cmap=cmocean.cm.balance,norm=norm,
                    values=np.arange(-0.4,0.4,0.01),extend='both',
                    ticks=[-0.3,0,0.3],
                    label='Correlation')

# ticks and labels; the y axis runs downwards so that row 0 is at the top

ax.set_ylim(1,0)

major_ticks = np.arange(1/40,1,1/10)      # centre of every second cell
major_labels = [1,3,5,7,9]*2              # wavenumbers, slow then fast
minor_ticks = np.linspace(0,1,21)         # cell edges

ax.set_xticks(major_ticks,minor=False)
ax.set_xticklabels(major_labels,minor=False)
ax.set_xticks(minor_ticks,minor=True)
ax.set_xticklabels([],minor=True)

ax.set_yticks(major_ticks,minor=False)
ax.set_yticklabels(major_labels,minor=False)
ax.set_yticks(minor_ticks,minor=True)
ax.set_yticklabels([],minor=True)

for which, length in (('major',5),('minor',0)):
    ax.tick_params(axis='both',which=which,length=length,bottom=False,
                   top=True,labelbottom=False,labeltop=True)

# white grid along the cell edges, white frame

ax.grid(which='minor',color='w',linestyle='-',linewidth=3)
for side in ('top','bottom','left','right'):
    ax.spines[side].set(color='w',linewidth=3)

# quadrant labels

for position, name in ((0.25,'Slow'),(0.75,'Fast')):
    ax.text(-0.15,position,name,weight='bold',rotation=90,verticalalignment='center')

for position, name in ((0.25,'Slow'),(0.75,'Fast')):
    ax.text(position,-0.15,name,weight='bold',horizontalalignment='center')
