# Wavelet analysis

Create Figure 7. This requires standard pressure level output of temperature to normalize potential temperature perturbations. The perturbations are derived from horizontally filtered model level output. The natural logarithm of surface pressure from the model level output is required to interpolate model level output to fine pressure levels. The vertical wavelet power spectra used to create Figure 7 are stored in this repository because their computation is expensive. Torrence & Compo (1998) and https://github.com/regeirk/pycwt are excellent references for wavelet analysis.

In [None]:
import numpy as np
import xarray as xr
import numba
import dask.array
from scipy.signal import tukey
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import cmocean
import os

# use a 14 pt default font for all figures of this notebook
plt.rcParams['font.size'] = 14


## Wavelet transform


In [None]:
def zero_pad(da,dim):
    '''
        Append zeros along ``dim`` so that its length becomes a power of two.

        The target length is the smallest power of two strictly greater than
        the current length (a series that is already 2**k long is padded to
        2**(k+1)). The padded array is rechunked into a single chunk along
        ``dim``.
    '''
    length = len(da[dim])
    target = int(2 ** np.ceil(np.log2(length + 1)))

    padded = da.pad(pad_width={dim: (0, target - length)},
                    mode='constant',
                    constant_values=0)
    return padded.chunk({dim: -1})
    
    
    

def fft(da,dim,coord):
    '''
        Discrete Fourier transform of ``da`` along ``dim`` for xarray.DataArray.

        The transform is divided by the number of samples N (1/N convention of
        Torrence & Compo 1998). The transformed dimension is replaced by
        ``wavenumber_index`` whose coordinate holds the integer wavenumber
        indices in numpy ``fftfreq`` order (0, 1, ..., -1).

        Parameters
        ----------
        da : xr.DataArray
            Data to transform (numpy- or dask-backed). For dask data, ``dim``
            must be contained in a single chunk.
        dim : str
            Dimension to transform along.
        coord : sequence
            Coordinate along ``dim``; only its length is used.

        Returns
        -------
        xr.DataArray of complex128 with dimension ``wavenumber_index``.
    '''
    n = len(coord)
    k = xr.DataArray(np.fft.fftfreq(n, d=1/n), dims=('wavenumber_index'))

    def _forward(values, k):
        # ``k`` is only passed to introduce the new core dimension
        backend = dask.array.fft if isinstance(values, dask.array.Array) else np.fft
        return backend.fft(values) / len(k)

    transform = xr.apply_ufunc(_forward,
                               da, k,
                               input_core_dims=[[dim], ['wavenumber_index']],
                               output_core_dims=[['wavenumber_index'],],
                               dask='allowed',
                               # np.complex_ was removed in NumPy 2.0
                               output_dtypes=[np.complex128]
                               )
    transform['wavenumber_index'] = k

    return transform



def ifft(transform,dim,coord):
    '''
        Inverse discrete Fourier transform for xarray.DataArray, the
        counterpart of ``fft``.

        The result is multiplied by N so that ``ifft(fft(x))`` recovers ``x``.
        The ``wavenumber_index`` dimension is replaced by ``dim``, which takes
        its coordinate from ``coord``.

        Parameters
        ----------
        transform : xr.DataArray
            Spectral data with dimension ``wavenumber_index`` (numpy- or
            dask-backed; for dask it must be a single chunk).
        dim : str
            Name of the output dimension.
        coord : xr.DataArray
            Coordinate along ``dim``; must have as many entries as
            ``wavenumber_index``.

        Returns
        -------
        xr.DataArray of complex128 values (the inverse transform is complex
        in general).
    '''
    def _inverse(values, coord):
        # ``coord`` is only passed to define the output core dimension
        backend = dask.array.fft if isinstance(values, dask.array.Array) else np.fft
        return backend.ifft(values) * len(coord)

    da = xr.apply_ufunc(_inverse,
                        transform, coord,
                        input_core_dims=[['wavenumber_index'], [dim]],
                        output_core_dims=[[dim],],
                        dask='allowed',
                        # the result is complex; np.float_ was removed in NumPy 2.0
                        output_dtypes=[np.complex128]
                        )

    return da

In [None]:
def wavelet_wrapper(N,dz,dj=1/12):
    '''
        Scaled Morlet wavelets in Fourier space for a continuous wavelet
        transform as described by Torrence & Compo (1998).

        Parameters
        ----------
        N : int
            Number of equally spaced samples of the signal.
        dz : float
            Sample spacing (same units as the returned wavelength and coi).
        dj : float
            Scale resolution; scales are s0 * 2**(j*dj), j = 0..J.

        Returns
        -------
        psi_ft_bar : xr.DataArray (scale, wavenumber_index)
            Conjugate of the normalized, scaled wavelet in Fourier space,
            sqrt(2*pi*s/dz) * conj(psi_0(s*k)) (Torrence & Compo eq. 6).
            Multiply with the 1/N-normalized Fourier transform of the signal
            and inverse-transform (times N) to obtain the wavelet coefficients.
        wavelength : xr.DataArray (scale,)
            Fourier-equivalent wavelength of each scale.
        coi : xr.DataArray (scale,)
            Morlet e-folding distance sqrt(2)*s, i.e. the extent of the cone
            of influence measured from either edge of the signal.

        NOTE(review): earlier versions normalized with sqrt(s*k*N) for every
        wavenumber k and used sqrt(2)*s/flambda for the coi; spectra computed
        with those versions are scaled differently and may need recomputation.
    '''
    # Morlet wavelet with non-dimensional frequency omega0 = 6
    omega0 = 6
    flambda = (4 * np.pi) / (omega0 + np.sqrt(2 + omega0 ** 2))
    # wavelet power of an edge discontinuity drops by exp(-2) at distance s / coi_factor
    coi_factor = 1. / np.sqrt(2)
    psi_ft = lambda f: (np.pi ** -0.25) * np.exp(-0.5 * (f - omega0) ** 2)

    s0 = 2 * dz / flambda
    J = int(np.round(np.log2(N * dz / s0) / dj))
    sj = xr.DataArray(s0 * 2 ** (np.arange(0, J + 1) * dj),dims=('scale',))

    # Fourier equivalent wavelengths
    wavelength = flambda * sj
    # cone of influence: e-folding distance sqrt(2)*s of the Morlet wavelet
    coi = sj / coi_factor

    wavenum = xr.DataArray(2 * np.pi * np.fft.fftfreq(N,d=dz),
                           dims=('wavenumber_index'),
                           coords=[np.fft.fftfreq(N,d=1/N)])

    # scaled wavelet with the scale-only normalization sqrt(2 pi s / dz)
    # (equivalent to sqrt(s * k_1 * N) with the first non-zero wavenumber k_1)
    psi_ft_bar = (np.sqrt(2 * np.pi * sj / dz) *
                  np.conjugate(psi_ft(sj * wavenum)))

    return psi_ft_bar, wavelength, coi

## Interpolation to regular log-pressure grid

In [None]:
# scale height (m)
H = 7000
# reference pressure in hPa
ps = 1013.25
# exponent of the potential temperature definition (presumably R_d / c_p)
kappa = 0.2854


def log_pressure(p):
    '''Log-pressure height -H * ln(p / ps) for pressure p in hPa.'''
    return -H * np.log(p / ps)


def potential_temperature(t, p):
    '''Potential temperature t * (ps / p) ** kappa for pressure p in hPa.'''
    return t * np.exp(kappa * np.log(ps / p))

In [None]:
@numba.guvectorize(
    "(float64[:], float64[:], float64[:], float64[:])",
    "(n), (n), (m) -> (m)",
    nopython=True,
)
def vectorized_interp(data,x,xi,out):
    '''
        Vectorized 1-D linear interpolation of ``data`` (sampled at ``x``)
        onto the points ``xi``; the result is written into ``out``.

        - compiled as a NumPy generalized ufunc (signature (n),(n),(m)->(m)),
          so it broadcasts over all leading dimensions; much faster than
          looping or np.vectorize
        - np.interp expects ``x`` to be increasing (this is not checked);
          points of ``xi`` outside the range of ``x`` get the boundary values
        - scipy's cubic spline interpolation does not work in nopython mode:
          out[:] = interpolate.interp1d(x,data,kind='cubic',fill_value='extrapolate')(xi)
    '''
    out[:] = np.interp(xi,x,data)
    
    

def to_regular_grid(da,z,dz,zmax,dim='isobaricInhPa'):
    '''
        Linearly interpolate ``da`` from the heights ``z`` onto the regular
        height grid 0, dz, ..., zmax.

        ``da`` and ``z`` share the vertical dimension ``dim``, which is
        replaced by a new dimension ``height`` in the result. ``z`` is assumed
        to increase along ``dim`` (required by np.interp) — TODO confirm for
        new inputs. For dask input, ``dim`` must be a single chunk.
    '''
    heights = xr.DataArray(np.arange(0, zmax + dz, dz), dims=('height',))

    out = xr.apply_ufunc(
        vectorized_interp,
        da, z, heights,
        input_core_dims=[[dim], [dim], ['height']],
        output_core_dims=[['height']],
        dask='parallelized',
        output_dtypes=[da.dtype],
    )
    out['height'] = heights
    return out

## Power spectra

In [None]:
def regridding(da,dz,zmax,pressure=None):
    '''
        Convert temperature ``da`` to potential temperature and interpolate it
        to a regular log-pressure height grid (0 to ``zmax`` in steps ``dz``).

        The vertical dimension is the first of 'pressure', 'level',
        'isobaricInhPa', 'hybrid' found in ``da``. ``pressure`` (hPa) defaults
        to the coordinate of that dimension; pass it explicitly for model
        levels, where pressure varies in space and time.

        Raises KeyError if no known vertical dimension is present.
    '''
    candidates = ('pressure', 'level', 'isobaricInhPa', 'hybrid')
    found = [name for name in candidates if name in da.dims]
    if not found:
        raise KeyError('Vertical dimension not recognized')
    dim = found[0]

    p = da[dim] if pressure is None else pressure

    return to_regular_grid(potential_temperature(da, p),
                           log_pressure(p),
                           dz, zmax, dim=dim)



def normalization(pert,full):
    '''
        Non-dimensionalize potential temperature perturbations.

        ``pert`` is divided by the zonal mean of ``full`` (interpolated to the
        latitudes of ``pert``), multiplied by exp(-height/(2H)) (presumably to
        account for the exp(-z/H) density decrease) and finally divided by a
        single spread factor ``norm``.

        Returns
        -------
        normalized : xr.DataArray
        norm : xr.DataArray (scalar), computed as
            std(x * sqrt(cos(lat))) / mean(cos(lat)).

        NOTE(review): an area-weighted standard deviation would be
        sqrt(mean(cos * x**2) / mean(cos)); here the std itself is divided by
        mean(cos) — confirm this weighting is intended.

        Raises KeyError if neither 'longitude' nor 'lon' is a dimension.
    '''
    lat_for_lon = {'longitude': 'latitude', 'lon': 'lat'}
    for lon_name in ('longitude', 'lon'):
        if lon_name in pert.dims:
            lat_name = lat_for_lon[lon_name]
            break
    else:
        raise KeyError('Zonal dimension not recognized')

    cos_phi = np.cos(np.radians(pert[lat_name]))

    zonal_mean = full.mean(lon_name).compute()
    zonal_mean = zonal_mean.interp(**{lat_name: pert[lat_name]})

    scaled = (pert / zonal_mean) * np.exp(-pert['height'] / (2 * H))

    norm = (scaled * np.sqrt(cos_phi)).std().compute() / cos_phi.mean()

    return scaled / norm, norm


def tapering(da,height,dz):
    '''
        Multiply ``da`` by a Tukey (tapered cosine) window along ``height``.

        Parameters
        ----------
        da : xr.DataArray
            Data with dimension/coordinate ``height``.
        height : float
            Length of the cosine taper at each end (same units as ``dz``).
        dz : float
            Grid spacing of ``height``.

        Returns
        -------
        xr.DataArray : ``da`` times the window.
    '''
    # scipy.signal.tukey was removed in SciPy 1.13; scipy.signal.windows.tukey
    # is the supported location
    from scipy.signal.windows import tukey as tukey_window

    n = len(da['height'])
    # alpha: fraction of the window inside the two cosine tapers together,
    # i.e. int(height/dz) samples at each end
    alpha = 2 * int(height / dz) / n
    window = xr.DataArray(tukey_window(n, alpha=alpha), coords=dict(height=da['height']))

    return da * window
    



def transformation(da,dz,dj=1/6):
    '''
        Continuous Morlet wavelet transform of ``da`` along ``height``,
        computed in Fourier space via ``fft``/``wavelet_wrapper``/``ifft``.

        Returns
        -------
        transform : xr.DataArray
            Wavelet coefficients with the ``scale`` dimension relabelled as
            ``wavelength`` (Fourier-equivalent wavelength as coordinate).
        coi : xr.DataArray (wavelength,)
            Cone of influence from ``wavelet_wrapper``.
    '''
    n = len(da['height'])
    spectrum = fft(da, 'height', da['height'])
    psi_bar, lam, cone = wavelet_wrapper(n, dz, dj=dj)

    coeffs = ifft(spectrum * psi_bar, 'height', da['height'])

    def _by_wavelength(arr):
        return arr.assign_coords(dict(scale=lam)).rename(scale='wavelength')

    return _by_wavelength(coeffs), _by_wavelength(cone)
    
    
def mean_power(da,taper,dz):
    '''
        Wavelet power |W|**2 averaged over every dimension except
        ``wavelength`` and ``height``, divided by mean(window**2) of the Tukey
        taper to compensate for the variance removed by tapering.

        Points whose ``height`` coordinate is not finite are dropped.
        Raises ValueError if ``da`` lacks ``wavelength`` or ``height``.
    '''
    averaged_dims = list(da.dims)
    for keep in ('wavelength', 'height'):
        averaged_dims.remove(keep)

    power = (np.abs(da) ** 2).mean(averaged_dims)

    taper_loss = (tapering(xr.ones_like(power), taper, dz) ** 2).mean()
    power = power / taper_loss

    return power.where(np.isfinite(power.height), drop=True)
    

## Parametric bootstrap

In [None]:
@numba.guvectorize(
    "(float64[:],float64[:],float64[:,:])",
    "(n), (m) -> (m,n)",
    forceobj=True
)
def random_sample(a,nb,out):
    '''
        Bootstrap resampling with replacement: each of the len(nb) rows of
        ``out`` receives len(a) values drawn at random from ``a``.

        - only the length of ``nb`` is used; it fixes the size of dimension m
        - indices come from scipy.stats.uniform.rvs (scipy's default random
          state), truncated to integers
    '''
    n_rows = len(nb)
    n_vals = len(a)
    idx = stats.uniform.rvs(0, n_vals, n_vals * n_rows).astype(int)
    out[:, :] = a[idx.reshape(n_rows, n_vals)]
    

@numba.guvectorize(
    "(float64[:],float64[:],float64[:])",
    "(n), (m) -> (m)",
    forceobj=True
)
def ecdf(a,p,out):
    '''
        Empirical quantiles (inverse empirical cumulative distribution) of
        ``a`` at probabilities ``p``: the sorted sample at index
        int(p * len(a)), without interpolation. Requires 0 <= p < 1.
    '''
    ordered = np.sort(a)
    positions = np.int64(p * len(a))
    out[:] = np.take(ordered, positions)
    
    
def t_statistic(x1,x2,dim):
    '''
        t statistic for the difference of the means of two samples of equal
        length along ``dim``:

            (mean(x1) - mean(x2)) / sqrt((var(x1) + var(x2)) / n)

        with n the length of ``x1`` along ``dim``; the variances use the
        container's default ddof.
    '''
    n = len(x1[dim])
    spread = np.sqrt((x1.var(dim) + x2.var(dim)) / n)
    return (x1.mean(dim) - x2.mean(dim)) / spread


def parametric_bootstrap(sample1,sample2,nb=1000,confid=0.05):
    '''
        Two-sided bootstrap test for a difference of the ensemble means of
        ``sample1`` and ``sample2`` (ensemble dimension ``number``).

        Both samples are centred on their own ensemble mean, so that the
        resampled control data fulfil the null hypothesis of equal means.
        Each is resampled ``nb`` times with replacement, and the t statistic
        of the resampled pairs gives the null distribution. The observed t
        statistic is significant where it lies below the confid/2 quantile or
        above the 1-confid/2 quantile of that distribution.

        Returns
        -------
        sig : xr.DataArray of bool
            True where the null hypothesis is rejected.
    '''
    # Produce control samples that fulfil the null hypothesis
    c1 = sample1 - sample1.mean('number')
    c2 = sample2 - sample2.mean('number')

    # Resample control (the values of ``bootstrap`` only fix the sample count)
    bootstrap = xr.DataArray(np.arange(nb),dims=('random'))

    def _resample(control):
        return xr.apply_ufunc(random_sample,
                              control, bootstrap,
                              input_core_dims=[['number'], ['random']],
                              output_core_dims=[['random', 'number']],
                              dask='parallelized',
                              output_dtypes=[control.dtype])

    c1 = _resample(c1)
    c2 = _resample(c2)

    # t statistic for the resampled data: null distribution
    dist = t_statistic(c1,c2,'number')

    # empirical quantiles of the null distribution
    p = xr.DataArray(np.linspace(0,0.999,1000),dims=('percentile'))
    dist = xr.apply_ufunc(ecdf,
                          dist, p,
                          input_core_dims=[['random'], ['percentile']],
                          output_core_dims=[['percentile']],
                          dask='parallelized',
                          output_dtypes=[dist.dtype])
    dist['percentile'] = p

    # check whether the null hypothesis can be rejected (two-sided)
    t = t_statistic(sample1,sample2,'number')
    lower = dist.sel(percentile=confid/2, method='nearest')
    upper = dist.sel(percentile=1-confid/2, method='nearest')
    sig = (t < lower) | (t > upper)

    return sig

## Computation - filtered model level data

In [None]:
realization = 8
area = dict(latitude=slice(70,45))
time_slice = ['2018-02-22',
              '2018-03-22']


def _open_grib_field(path, name, chunks, time_slice, area):
    '''
        Open variable ``name`` of a GRIB file lazily with dask ``chunks``.

        The forecast ``step`` dimension is replaced by ``time`` (indexed by
        ``valid_time``), and the field is subset to ``time_slice`` (pair of
        date strings) and ``area`` (dict of coordinate slices).
    '''
    da = xr.open_dataset(path, chunks=chunks, engine='cfgrib')[name]
    # drop_vars replaces the deprecated DataArray.drop for removing variables
    da = da.drop_vars(['step', 'time']).set_index(step='valid_time').rename(step='time')
    period = slice(np.datetime64(time_slice[0]), np.datetime64(time_slice[1]))
    return da.sel(time=period, **area)


# NOTE(review): data_dir is not defined in this cell — presumably set beforehand
model_levels = data_dir+'TCo639_nudged/20180208_91L/model_levels_40Nto80N_0.2x0.2/'

# filtered temperature perturbations on model levels
pert = _open_grib_field(model_levels+'T_pert_data/T_40Nto80N_%d_21_639_0.2x0.2_gg.grb'%realization,
                        't', dict(step=1,latitude=2), time_slice, area)
# log surface pressure for the model level pressures (explicit cfgrib engine,
# consistent with the other GRIB files)
lnsp = _open_grib_field(model_levels+'lnsp/lnsp_%d.grb'%realization,
                        'lnsp', dict(step=1,latitude=8), time_slice, area)


pressure_levels = data_dir+'TCo639_nudged/20180208_91L/pressure_levels_F64/'

# full temperature on standard pressure levels
full = _open_grib_field(pressure_levels+'t/t_%d.grb'%realization,
                        't', dict(step=1,latitude=2), time_slice, area)


# hybrid level coefficients: columns level, A, B after one header line
AandB = np.loadtxt('./AandB_91L.txt',skiprows=1)
level = xr.DataArray(AandB[:,0],dims=('hybrid'),name='hybrid')
A = xr.DataArray(AandB[:,1],dims=('hybrid'),coords=dict(hybrid=level),name='A')
B = xr.DataArray(AandB[:,2],dims=('hybrid'),coords=dict(hybrid=level),name='B')

# pressure A + B * surface pressure, divided by 100 (presumably Pa -> hPa)
pressure = (A + B * np.exp(lnsp)) / 100

# Reverse the level order so that pressure decreases (log-pressure height
# increases) along hybrid, as np.interp requires in the regridding — assuming
# hybrid levels are numbered from the model top downwards; TODO confirm.
# NOTE(review): [:0:-1] also drops the first hybrid entry; confirm intended.
pressure = pressure.reindex(hybrid=pressure.hybrid[:0:-1])
pert = pert.reindex(hybrid=pert.hybrid[:0:-1])

pressure

In [None]:
# spacing and top (m) of the regular log-pressure height grid
dz = 100
zmax = 64000
# length of the cosine taper at each end of the profile (m)
taper = zmax/2

# potential temperature on the regular height grid: full field on pressure
# levels (for the normalization) and perturbations on model levels
full_regrid = regridding(full,dz,zmax)
pert_regrid = regridding(pert,dz,zmax,pressure)
# non-dimensionalize the perturbations with the zonal-mean background
normalized, norm = normalization(pert_regrid,full_regrid)
# taper both ends of the profiles before the wavelet transform
tapered = tapering(normalized,taper,dz)
transformed, coi = transformation(tapered,dz)

# power averaged over all dimensions but wavelength and height, taper-corrected
power = mean_power(transformed,taper,dz)
power = power.compute()

power

In [None]:
# The file name keeps the 'spectogram' spelling used by the loading code for
# Figure 7, so that saved spectra are found under the same pattern.
title = 'spectogram_'+str(realization)+'_'+time_slice[0]+'_'+time_slice[1]
# os.path.join also works when work_dir lacks a trailing separator
xr.Dataset(dict(power=power,norm=norm,coi=coi)).to_netcdf(os.path.join(work_dir, title+'.nc'))

## Figure 7

Ensemble-mean non-dimensional potential energy wavelet spectrum, horizontally averaged between 45 and 70° N for the period 22 February to 22 March 2018, in the (a) TCo639L91 and (b) TCo639L198 nudged simulations, and (c) the ensemble-mean difference TCo639L198 - TCo639L91. The hatched area indicates the theoretical cone of influence. Stippling in the bottom panel (c) indicates where the ensemble-mean energies are not significantly different, as estimated by a parametric bootstrap.

In [None]:
def _decorate_panel(ax, lambdas, coi_km, zmax, H, label):
    '''
        Hatch the cone of influence at the bottom and top of ``ax``, add a
        logarithmic pressure axis on the right and a bold row label on the left.
    '''
    # cone of influence from the lower and the upper edge of the height range
    ax.fill_between(lambdas,
                    np.zeros(len(lambdas)),
                    coi_km,
                    color='k', alpha=0.3, hatch='x')
    ax.fill_between(lambdas,
                    zmax - coi_km,
                    zmax * np.ones(len(lambdas)),
                    color='k', alpha=0.3, hatch='x')
    ax.set_ylabel('height [km]')

    # right axis: pressure of the log-pressure height z = -H ln(p / 1013.25 hPa)
    pressure_ax = ax.twinx()
    ylim = ax.get_ylim()
    pressure_ax.set_ylim(1013.25*np.exp(-np.array(ylim)*1000/H))
    pressure_ax.set_yscale('log')
    pressure_ax.set_ylabel('pressure [hPa]')

    # tick-less twin axis used only to place the row label left of the panel
    label_ax = ax.twinx()
    label_ax.set_yticks([])
    label_ax.set_ylabel(label,labelpad=80,size=18,weight='bold')
    label_ax.yaxis.set_label_position('left')


def compare_wavelet_spectra(L91,L198,H=7000):
    '''
        Plot Figure 7: ensemble-mean energy spectra of two experiments and
        their difference, stippled where not significant (parametric
        bootstrap).

        Parameters
        ----------
        L91, L198 : xr.Dataset
            Stored spectra with variables ``power``, ``norm`` and ``coi``,
            concatenated along ``number``; wavelength and height in m.
        H : float
            Scale height in m; also sets the largest plotted wavelength 2*H.
    '''
    zmax = L91['height'].max().values  / 1000
    coi = L91['coi'].mean('number')    / 1000
    coi['wavelength'] = coi['wavelength'] / 1000

    def _energy(ds):
        # power times normalization factor for wavelengths <= 2H, in km
        energy = (ds['power'] * ds['norm']).sel(wavelength=slice(0,2*H))
        energy['wavelength'] = energy['wavelength'] / 1000
        energy['height'] = energy['height'] / 1000
        return energy

    L91 = _energy(L91)
    L198 = _energy(L198)

    # Hatch the cone of influence only over the plotted wavelengths: the
    # stored coi covers all wavelengths, which would otherwise stretch the
    # shared x-axis beyond the 2H cut-off of the spectra.
    lambdas = L91['wavelength'].values
    coi_km = coi.sel(wavelength=lambdas).values

    difference = L198.mean('number') - L91.mean('number')

    sig = parametric_bootstrap(L198,L91,nb=4000)

    # plotting
    fig, axes = plt.subplots(nrows=3,sharex='all',figsize=(6,9))
    energy_levels = np.linspace(0,0.04,21)

    # (a) L91 energy
    C1 = L91.mean('number').plot.pcolormesh(ax=axes[0],
                                            y='height',
                                            levels=energy_levels,
                                            cmap=cmocean.cm.matter,
                                            extend='max',
                                            add_colorbar=False)
    _decorate_panel(axes[0], lambdas, coi_km, zmax, H, 'energy L91')
    axes[0].set_xlabel(None)

    # (b) L198 energy
    L198.mean('number').plot.pcolormesh(ax=axes[1],
                                        y='height',
                                        levels=energy_levels,
                                        cmap=cmocean.cm.matter,
                                        extend='max',
                                        add_colorbar=False)
    _decorate_panel(axes[1], lambdas, coi_km, zmax, H, 'energy L198')
    axes[1].set_xlabel(None)

    # (c) difference; '.' hatches where sig is False (not significant)
    C2 = difference.plot.pcolormesh(ax=axes[2],
                                    y='height',
                                    levels=np.linspace(-0.02,0.02,25),
                                    cmap=cmocean.cm.balance,
                                    extend='both',
                                    add_colorbar=False)
    sig.astype(np.double).plot.contourf(ax=axes[2],y='height',levels=[0,0.5,1],hatches=['.',''],
                                        alpha=0,add_colorbar=False)
    _decorate_panel(axes[2], lambdas, coi_km, zmax, H, 'difference L198-L91')
    axes[2].set_xlabel('wavelength [km]')

    # shared: panel letters, layout and colorbars
    trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)
    for ax, letter in zip(axes, 'abc'):
        ax.text(-0.06,1.1,letter+')',transform=ax.transAxes+trans,fontsize='large',va='bottom',fontfamily='serif')

    fig.subplots_adjust(0,0,1,0.95,0,0)

    plt.colorbar(C1,ax=axes[0:2],orientation='vertical',fraction=0.1,aspect=20,shrink=0.95,pad=0.15)
    plt.colorbar(C2,ax=axes[2],orientation='vertical',fraction=0.1,aspect=10,shrink=0.95,pad=0.15)
    
    

    
import glob

time_slice = ['2018-02-22',
              '2018-03-22']
suffix = '_'+time_slice[0]+'_'+time_slice[1]+'.nc'

spectra_dir = './vertical_wavelet_power_spectra/'

# Members with one-digit numbers first, then two-digit ones, each group
# sorted, which gives the same order as the former ``!ls`` calls. Plain
# glob replaces the IPython magic, and an empty list (instead of an ls error
# message) is returned when nothing matches.
experiment = 'TCo639_nudged_198L'
files_single = sorted(glob.glob(spectra_dir+experiment+'/spectogram_?'+suffix))
files_double = sorted(glob.glob(spectra_dir+experiment+'/spectogram_??'+suffix))
L198 = xr.open_mfdataset(files_single+files_double,combine='nested',concat_dim='number').load()

experiment = 'TCo639_nudged_91L'
files_single = sorted(glob.glob(spectra_dir+experiment+'/spectogram_?'+suffix))
files_double = sorted(glob.glob(spectra_dir+experiment+'/spectogram_??'+suffix))
L91 = xr.open_mfdataset(files_single+files_double,combine='nested',concat_dim='number').load()

compare_wavelet_spectra(L91,L198)