# Gravity wave drag

Create Figures 4, 5, 6, A3, and B1. These require the resolved gravity wave momentum flux, the parameterized gravity wave drag, and the zonal wind from the following simulations:

- TCo319_free_running/20180208_137L (hm2h)
- TCo319_free_running/20180208_198L (hjy3)
- TCo319_free_running/20180208_91L (hjwz)
- TCo639_free_running/20060117_198L (hopf)
- TCo639_free_running/20060117_91L (hopg)
- TCo639_free_running/20100205_198L (hope)
- TCo639_free_running/20100205_91L (hopd)
- TCo639_free_running/20180208_198L (hokx)
- TCo639_free_running/20180208_91L (hokw)
- TCo639_nudged/20180208_198L (hq1m)
- TCo639_nudged/20180208_91L (hpxj)

In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import matplotlib.colors as mcolors
import matplotlib.cm as mcm
import cmocean
import os

# Use a 14 pt base font size for every figure produced in this notebook.
plt.rcParams['font.size'] = 14


In [None]:
def load_Grib(directory,chunks=None):
    '''
        Open one grib file for each ensemble member and concatenate into one xr.Dataset

        Parameters
        ----------
        directory : str
            Directory containing the ``.grb`` files. Files are concatenated in sorted
            filename order along a new ``number`` dimension.
        chunks : dict, optional
            Chunk sizes passed to ``xr.open_mfdataset``; ``None`` (default) means ``{}``.

        Returns
        -------
        xr.Dataset

        Raises
        ------
        FileNotFoundError
            If ``directory`` contains no ``.grb`` files.
    '''
    # avoid a shared mutable default; {} keeps the original chunking behaviour
    if chunks is None:
        chunks = {}
    # os.path.join works whether or not `directory` ends with a path separator
    files = sorted(os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.grb'))
    if not files:
        raise FileNotFoundError(f'No .grb files found in {directory}')
    ds = xr.open_mfdataset(files,engine='cfgrib',chunks=chunks,combine='nested',concat_dim='number')

    return ds


def resample2daily(ds):
    '''
        Resample a dataset to daily means, keeping each variable's name and units.

        The forecast time coordinates (``time`` + ``step``) are replaced by a single
        ``time`` dimension built from ``valid_time``, and ``isobaricInhPa`` is renamed
        to ``level``. Pandas' default behaviour for resampling is to put the time stamp
        at the beginning of the bin, so each daily mean is labelled with the start of
        its day.
    '''
    # Rename dimensions. If the dataset has no 'step' variable (e.g. the time axis was
    # already rebuilt by the caller), the renaming is skipped.
    try:
        ds = ds.drop_vars(['step','time']).set_index(step='valid_time').rename(step='time')
    except (ValueError, KeyError):
        print('Exception: Time dimension is not renamed')
    if 'isobaricInhPa' in ds.dims:
        ds = ds.rename(isobaricInhPa='level')

    # Remember name and units of every variable so they can be re-attached after
    # resampling. standard_name is preferred, long_name is the fallback; units are required.
    attrs = {}
    for var in ds.data_vars:
        var_attrs = ds[var].attrs
        name = var_attrs.get('standard_name', var_attrs.get('long_name'))
        if name is None or 'units' not in var_attrs:
            print('Exception: No name attribute')
            continue
        attrs[var] = dict(standard_name=name,units=var_attrs['units'])

    # Resample to daily mean
    ds = ds.resample(time='1D').mean()

    for var in ds.data_vars:
        # variables without name/units keep whatever attributes the resampling left
        if var in attrs:
            ds[var].attrs = attrs[var]

    return ds


def average_from_accumulated(ds,freq='1D'):
    '''
        Compute average tendencies from accumulated fields.

        The accumulated field is sampled at the start of every ``freq`` bin; the
        difference between consecutive samples is divided by the bin length in seconds.

        - each average is labelled with the start of its bin
        - time series is one element short of the resampled one (the last bin has no
          closing sample)

        ``freq`` must be a fixed-length frequency accepted by ``pd.Timedelta``
        (e.g. '1D', '6h'), and every bin start must be present in the data.
    '''
    # Rename dimensions
    ds = ds.drop_vars(['step','time']).set_index(step='valid_time').rename(step='time')
    if 'isobaricInhPa' in ds.dims:
        ds = ds.rename(isobaricInhPa='level')

    # Bin start time stamps, in chronological order
    bounds = sorted(ds.resample(time=freq).groups.keys())

    # Accumulation over a bin = value at its end minus value at its start.
    # The time coordinate is dropped so the two slices are subtracted positionally.
    ds = ds.sel(time=bounds).drop_vars('time')
    ds = ds.isel(time=slice(1,None)) - ds.isel(time=slice(None,-1))
    ds = ds.assign_coords(time=bounds[:-1])
    ds = ds / pd.Timedelta(freq).total_seconds()

    return ds

In [None]:
def area_weighted_mean(da):
    '''
        Cosine-of-latitude weighted mean over the ``latitude`` dimension.

        Works for xr.DataArray and xr.Dataset. Missing values are skipped and their
        weights are excluded from the normalisation, so NaNs do not bias the mean
        towards zero (the manual ``(da*w).mean() / w.mean()`` form does).
    '''
    weights = np.cos(np.deg2rad(da['latitude']))
    return da.weighted(weights).mean(dim='latitude')


def aggregation(da, north=70, south=45):
    '''
        Compute the area-weighted meridional mean of ``da`` between ``south`` and
        ``north`` (degrees latitude), by default 45-70 N.

        The selection is ``slice(north, south)``, which assumes the latitude coordinate
        is stored in descending order (north to south); with an ascending coordinate
        the selection would be empty.
    '''
    return area_weighted_mean(da.sel(latitude=slice(north, south)))


def wave_drag(flux):
    '''
        Wave drag from resolved momentum fluxes, obtained by vertical differentiation.

        Returns minus the derivative of ``flux['u']`` with respect to the ``level``
        coordinate. Dividing by 100 converts the derivative from per hPa to per Pa
        (``level`` is the renamed ``isobaricInhPa`` coordinate).
    '''
    hpa_to_pa = 100
    derivative_per_hpa = flux['u'].differentiate('level')
    return -derivative_per_hpa / hpa_to_pa

In [None]:
def two_sample_t_test(sample1,sample2,dim='number',broadcast_dims=('time','level')):
    r'''
        Welch's t test comparing the means of two independent samples.

        Assumes \sigma_1^2 \ne \sigma_2^2, but n_1 = n_2 = n.
        H0: \mu_1 = \mu_2

        Parameters
        ----------
        sample1, sample2 : xr.DataArray
            Samples along ``dim``; remaining dimensions are ``broadcast_dims``.
        dim : str
            Sample dimension that is reduced (e.g. ensemble members).
        broadcast_dims : tuple of str
            Dimensions over which the test is evaluated point by point.

        Returns
        -------
        xr.DataArray
            One-sided p-value: the probability that a difference x1 - x2 larger than
            the observed one could have occurred by chance. p < 0.025 or p > 0.975
            marks a difference significant at the 5% level (two-sided). Points where
            both sample variances are zero give NaN.

        Raises
        ------
        ValueError
            If the two samples have different lengths along ``dim``.
    '''
    n = len(sample1[dim])
    if len(sample2[dim]) != n:
        raise ValueError(f'samples must have equal size along {dim!r}')

    x1 = sample1.mean(dim)
    x2 = sample2.mean(dim)
    # Welch's test needs the unbiased sample variance; xarray's default is ddof=0,
    # which underestimates the spread and inflates t.
    s1 = sample1.var(dim, ddof=1)
    s2 = sample2.var(dim, ddof=1)

    t = (x1-x2) / np.sqrt((s1+s2)/n)
    # Welch-Satterthwaite degrees of freedom for n_1 = n_2 = n,
    # equivalent to (n-1) (s1 + s2)^2 / (s1^2 + s2^2)
    dof = n - 1 + (2 * n - 2) / (s1 / s2 + s2 / s1)

    p = xr.apply_ufunc(stats.t.sf,t,dof,
                       input_core_dims=[list(broadcast_dims),list(broadcast_dims)],
                       output_core_dims=[list(broadcast_dims)],
                       output_dtypes=[t.dtype])
    return p

## Processing data

In [None]:
# resampling and zonal average of resolved GW momentum flux
#
# (input directory, output experiment name) pairs; keeping each pair together
# prevents the raw run ids (hpxj, hq1m, ...) from drifting out of sync with the names.

runs = [(data_dir+'TCo639_GWflux_8feb2018/nudg/hpxj/', 'TCo639_nudged/20180208_91L'),
        (data_dir+'TCo639_GWflux_8feb2018/nudg/hq1m/', 'TCo639_nudged/20180208_198L'),
        (data_dir+'TCo639_GWflux_17jan2006/free/hopg/', 'TCo639_free_running/20060117_91L'),
        (data_dir+'TCo639_GWflux_17jan2006/free/hopf/', 'TCo639_free_running/20060117_198L'),
        (data_dir+'TCo639_GWflux_5feb2010/free/hopd/', 'TCo639_free_running/20100205_91L'),
        (data_dir+'TCo639_GWflux_5feb2010/free/hope/', 'TCo639_free_running/20100205_198L'),
        (data_dir+'TCo639_GWflux_8feb2018/free/hokw/', 'TCo639_free_running/20180208_91L'),
        (data_dir+'TCo639_GWflux_8feb2018/free/hokx/', 'TCo639_free_running/20180208_198L'),
        (data_dir+'TCo319_GWflux_8feb2018/free/hm2h/', 'TCo319_free_running/20180208_137L'),
        (data_dir+'TCo319_GWflux_8feb2018/free/hjy3/', 'TCo319_free_running/20180208_198L'),
        (data_dir+'TCo319_GWflux_8feb2018/free/hjwz/', 'TCo319_free_running/20180208_91L'),
       ]
directories = [directory for directory, _ in runs]
experiments = [experiment for _, experiment in runs]

for directory, experiment in runs:
    print(directory)
    files = sorted(os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.grb'))

    # open every member file (pressure-level fields only) and stack along 'number'
    members = []
    for f in files:
        try:
            members.append(xr.open_dataset(f,chunks={},filter_by_keys={'typeOfLevel': 'isobaricInhPa'}))
        except TypeError:
            # skip files that cannot be opened this way, but report which one
            print('TypeError:', f)
    if not members:
        raise FileNotFoundError(f'No readable .grb files in {directory}')
    ds = xr.concat(members,dim='number')

    # replace (time, step) by a single 'time' dimension holding the valid times
    time = ds.time + ds.step
    ds = ds.drop_vars(['step','time','valid_time']).assign_coords(time=time).set_index(step='time').rename(step='time')

    # the time axis is already rebuilt, so resample2daily reports that it skips renaming
    ds = resample2daily(ds)
    ds = ds.mean('longitude').compute()

    out_file = work_dir+experiment+'_gravity_wave_flux.nc'
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    ds.to_netcdf(out_file)

In [None]:
# zonal average of parameterized GW drag
# (the stored fields are accumulated; the plots convert them with average_from_accumulated)

experiments = ['TCo639_nudged/20180208_91L',
               'TCo639_nudged/20180208_198L',
               'TCo639_free_running/20060117_91L',
               'TCo639_free_running/20060117_198L',
               'TCo639_free_running/20100205_91L',
               'TCo639_free_running/20100205_198L',
               'TCo639_free_running/20180208_91L',
               'TCo639_free_running/20180208_198L',
               ]

for exp in experiments:
    ds = load_Grib(data_dir+exp+'/pressure_levels_F64/u_drag/',chunks=dict(step=16))
    print(exp)

    ds = ds.mean('longitude').compute()

    # make sure the per-configuration output directory exists
    out_dir = os.path.dirname(work_dir+exp)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # p102.128 is written as the orographic, p105.128 as the non-orographic drag
    xr.Dataset(dict(u=ds['p102.128'])).to_netcdf(work_dir+exp+'_orographic_drag.nc')
    xr.Dataset(dict(u=ds['p105.128'])).to_netcdf(work_dir+exp+'_non-orographic_drag.nc')

In [None]:
# resampling and zonal average of zonal wind

experiments = ['TCo639_free_running/20060117_198L',
               'TCo639_free_running/20100205_198L',
               'TCo639_free_running/20180208_198L',
               ]

for exp in experiments:
    ds = load_Grib(data_dir+exp+'/pressure_levels_F64/u/',chunks=dict(step=16))
    print(exp)

    ds = resample2daily(ds)
    ds = ds.mean('longitude').compute()

    # make sure the per-configuration output directory exists
    out_file = work_dir+exp+'_zonal_mean_U.nc'
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    ds.to_netcdf(out_file)

## Figure 4

TCo639L198 ensemble mean of (a) the full gravity wave drag (resolved + parameterized orographic + parameterized non-orographic) and (b) the resolved gravity wave momentum flux, both horizontally averaged between 45 and 70° N, with contours of the zonal-mean zonal wind [m s$^{-1}$] at 60° N.

In [None]:
def plot_one_exp(exp):
    '''
        Figure 4: ensemble-mean (a) full gravity wave drag (resolved + orographic +
        non-orographic) [m s-1 day-1] and (b) resolved gravity wave momentum flux [mPa],
        both averaged with ``aggregation``, with contours of the ensemble-mean
        zonal-mean zonal wind at 60 N.

        Reads ``<exp>_gravity_wave_flux.nc``, ``<exp>_orographic_drag.nc``,
        ``<exp>_non-orographic_drag.nc`` and ``<exp>_zonal_mean_U.nc`` from ``work_dir``.
    '''
    def _parameterized(suffix):
        # ensemble-mean average tendency of one parameterized drag component
        ds = average_from_accumulated(xr.open_dataset(work_dir+exp+suffix))
        return aggregation(ds)['u'].mean('number')

    def _draw_panel(ax, field, u, label):
        # colour mesh of `field`, zonal-wind contours and a bold row label; returns the mesh
        mesh = field.plot.pcolormesh(ax=ax,x='time',cmap=cmocean.cm.balance,extend='both',
                                     levels=np.linspace(-2,2,33),add_colorbar=False)
        contours = u.plot.contour(ax=ax,x='time',levels=np.arange(-60,70,10),colors='k',add_colorbar=False)
        ax.clabel(contours)

        # invisible twin axis, used only to place the row label left of the panel
        twin = ax.twinx()
        twin.set_yticks([])
        twin.set_ylabel(label,labelpad=80,size=18,weight='bold')
        twin.yaxis.set_label_position('left')

        # logarithmic pressure axis with pressure decreasing upwards
        ax.set_yscale('log')
        ax.set_ylim(ax.get_ylim()[::-1])
        ax.set_ylabel('pressure [hPa]')
        ax.set_xlabel('time')
        return mesh

    # resolved flux file: used for the resolved drag and for panel (b)
    flux_ds = xr.open_dataset(work_dir+exp+'_gravity_wave_flux.nc')

    # compute full drag (per-second tendencies converted to per day)
    res = aggregation(wave_drag(flux_ds)).mean('number')
    oro = _parameterized('_orographic_drag.nc')
    non = _parameterized('_non-orographic_drag.nc')
    drag = (res+oro+non)*86400

    u = xr.open_dataarray(work_dir+exp+'_zonal_mean_U.nc')
    u = u.sel(latitude=60,method='nearest').drop_vars('latitude')
    u = u.mean('number')

    # compute resolved momentum flux: divide by -g and convert Pa to mPa
    flux = aggregation(flux_ds['u']).mean('number') / (-9.81) * 1000

    # Plotting

    fig, axes = plt.subplots(nrows=2,sharex='all',figsize=(6,6))

    C1 = _draw_panel(axes[0], drag, u, 'wave drag')
    C2 = _draw_panel(axes[1], flux, u, 'momentum flux')

    trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)
    for ax, letter in zip(axes, 'ab'):
        ax.text(-0.06,1.0,letter+')',transform=ax.transAxes+trans,fontsize='large',va='bottom',fontfamily='serif')

    fig.subplots_adjust(0,0,1,1,0,0)

    cbar = fig.colorbar(C1,ax=axes[0],orientation='vertical',fraction=0.1,aspect=10,shrink=0.95)
    cbar.set_label(r'[m s$^{-1}$ day$^{-1}$]')

    cbar = fig.colorbar(C2,ax=axes[1],orientation='vertical',fraction=0.1,aspect=10,shrink=0.95)
    cbar.set_label(r'[mPa]')


plot_one_exp('TCo639_free_running/20180208_198L')

## Figure 5

Difference in ensemble-mean gravity wave drag between the TCo639L198 and TCo639L91 free-running hindcasts, split into the (a) resolved, (b) parameterized non-orographic, and (c) parameterized orographic components, and (d) the difference in ensemble-mean resolved gravity wave momentum flux between the two sets of hindcasts. Wave drag and momentum flux are horizontally averaged between 45 and 70° N. Hatching indicates areas where the difference between the two vertical resolutions is not significantly different from zero at the 95 % confidence level, as estimated by a two-sample t test.

In [None]:
def plot_difference(exp0,exp1):
    '''
        Figures 5 and 6: difference ``exp1 - exp0`` in ensemble-mean (a) resolved,
        (b) parameterized non-orographic and (c) parameterized orographic gravity wave
        drag [m s-1 day-1], and (d) resolved gravity wave momentum flux [mPa], all
        averaged with ``aggregation``.

        Hatching marks where the difference is NOT significant at the 5% level
        (two-sided, from ``two_sample_t_test`` across ensemble members).
    '''
    def _load_members(exp):
        # per-member fields of one experiment: drag components in m/s/day, flux in mPa
        flux_ds = xr.open_dataset(work_dir+exp+'_gravity_wave_flux.nc')
        fields = {'res': aggregation(wave_drag(flux_ds)) * 86400}
        for key, suffix in (('non','_non-orographic_drag.nc'),('oro','_orographic_drag.nc')):
            accumulated = xr.open_dataset(work_dir+exp+suffix)
            fields[key] = aggregation(average_from_accumulated(accumulated))['u'] * 86400
        fields['flux'] = aggregation(flux_ds['u']) / (-9.81) * 1000
        return fields

    def _diff_and_sig(new, old):
        # ensemble-mean difference and two-sided 5% significance mask
        p = two_sample_t_test(new, old)
        return new.mean('number') - old.mean('number'), (p < 0.025) | (p > 0.975)

    # (field key, row label, colour levels), top to bottom
    panels = [('res','resolved',np.linspace(-0.3,0.3,33)),
              ('non','non-orographic',np.linspace(-0.3,0.3,33)),
              ('oro','orographic',np.linspace(-0.3,0.3,33)),
              ('flux','momentum flux',np.linspace(-0.4,0.4,33))]

    # compute differences and significance
    old = _load_members(exp0)
    new = _load_members(exp1)
    results = [_diff_and_sig(new[key], old[key]) for key, _, _ in panels]

    # plotting

    fig, axes = plt.subplots(nrows=4,sharex='all',figsize=(6,9))

    meshes = []
    for ax, (_, label, levels), (diff, sig) in zip(axes, panels, results):
        meshes.append(diff.plot.pcolormesh(ax=ax,x='time',cmap=cmocean.cm.balance,extend='both',
                                           levels=levels,add_colorbar=False))

        # hatch the non-significant areas (sig == 0)
        sig.astype(np.double).plot.contourf(ax=ax,x='time',levels=[0,0.5,1],hatches=['\\',''],
                                            alpha=0,add_colorbar=False)

        # invisible twin axis, used only to place the row label left of the panel
        twin = ax.twinx()
        twin.set_yticks([])
        twin.set_ylabel(label,labelpad=80,size=18,weight='bold')
        twin.yaxis.set_label_position('left')

    for ax in axes:
        ax.set_yscale('log')
        ax.set_ylim(ax.get_ylim()[::-1])
        ax.set_ylabel('pressure [hPa]')
        ax.set_xlabel(None)

    axes[3].set_xlabel('time')

    trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)
    for ax, letter in zip(axes, 'abcd'):
        ax.text(-0.06,1.0,letter+')',transform=ax.transAxes+trans,fontsize='large',va='bottom',fontfamily='serif')

    fig.subplots_adjust(0,0,1,1,0,0)

    # one colour bar shared by the three drag rows, one for the momentum flux
    cbar = fig.colorbar(meshes[0],ax=axes[0:3],orientation='vertical',fraction=0.1,aspect=30,shrink=0.95)
    cbar.set_label(r'[m s$^{-1}$ day$^{-1}$]')

    cbar = fig.colorbar(meshes[3],ax=axes[3],orientation='vertical',fraction=0.1,aspect=10,shrink=0.95)
    cbar.set_label(r'[mPa]')


plot_difference('TCo639_free_running/20180208_91L',
                'TCo639_free_running/20180208_198L')

## Figure 6
Same as Fig. 5 but for the nudged simulations.

In [None]:
# Figure 6: same comparison (198L minus 91L) for the nudged simulations
plot_difference(exp0='TCo639_nudged/20180208_91L',
                exp1='TCo639_nudged/20180208_198L')

## Figure A3

Ensemble-mean full gravity wave drag and zonal-mean zonal wind as in Fig. 4a for the (a) 2006 and (b) 2010 SSW events. Differences in ensemble-mean gravity wave drag between the TCo639L198 and TCo639L91 free-running hindcasts as in Fig. 5 for the (c, e, g) 2006 and (d, f, h) 2010 SSW events.

In [None]:
def plot_eight_drag(date1,date2):
    '''
        Figure A3: gravity wave drag for two initialisation dates (left column
        ``date1``, right column ``date2``).

        Row 1: ensemble-mean full drag (resolved + orographic + non-orographic) of the
        198-level run, with contours of the ensemble-mean zonal-mean zonal wind at 60 N.
        Rows 2-4: 198L minus 91L difference in ensemble-mean resolved, non-orographic
        and orographic drag; hatching marks differences that are NOT significant at the
        5% level (two-sided, from ``two_sample_t_test``).

        ``date1``/``date2`` are experiment prefixes such as
        'TCo639_free_running/20060117'; '_91L' / '_198L' are appended to them.
    '''
    def _drag_members(exp):
        # per-member drag components of one experiment in m/s/day, averaged with `aggregation`
        res = aggregation(wave_drag(xr.open_dataset(work_dir+exp+'_gravity_wave_flux.nc'))) * 86400
        oro = aggregation(average_from_accumulated(
            xr.open_dataset(work_dir+exp+'_orographic_drag.nc')))['u'] * 86400
        non = aggregation(average_from_accumulated(
            xr.open_dataset(work_dir+exp+'_non-orographic_drag.nc')))['u'] * 86400
        return {'res': res, 'non': non, 'oro': oro}

    def _diff_and_sig(new, old):
        # ensemble-mean difference and two-sided 5% significance mask
        p = two_sample_t_test(new, old)
        return new.mean('number') - old.mean('number'), (p < 0.025) | (p > 0.975)

    components = ('res', 'non', 'oro')   # row order of the difference panels
    full, uwind = [], []
    diffs = {key: [] for key in components}
    sigs = {key: [] for key in components}

    for exp in (date1, date2):
        high = _drag_members(exp+'_198L')
        low = _drag_members(exp+'_91L')

        # full drag of the 198-level run
        full.append(high['res'].mean('number') + high['oro'].mean('number') + high['non'].mean('number'))

        # zonal wind at 60 N of the 198-level run
        u = xr.open_dataarray(work_dir+exp+'_198L_zonal_mean_U.nc')
        uwind.append(u.sel(latitude=60,method='nearest').drop_vars('latitude').mean('number'))

        for key in components:
            diff, sig = _diff_and_sig(high[key], low[key])
            diffs[key].append(diff)
            sigs[key].append(sig)

    # panels 2..7: resolved (date1, date2), non-orographic (...), orographic (...)
    panel_diffs = [diff for key in components for diff in diffs[key]]
    panel_sigs = [sig for key in components for sig in sigs[key]]

    # plotting

    fig, axes = plt.subplots(nrows=4,ncols=2,sharex='col',sharey=False,figsize=(9,12))
    axes = axes.flatten()

    for ax, drag, u in zip(axes[:2],full,uwind):
        full_mesh = drag.plot.pcolormesh(ax=ax,x='time',cmap=cmocean.cm.balance,extend='both',
                                         levels=np.linspace(-2,2,33),add_colorbar=False)
        C = u.plot.contour(ax=ax,x='time',levels=np.arange(-60,70,10),colors='k',add_colorbar=False)
        ax.clabel(C)

    for ax, diff, sig in zip(axes[2:],panel_diffs,panel_sigs):
        diff_mesh = diff.plot.pcolormesh(ax=ax,x='time',cmap=cmocean.cm.balance,extend='both',
                                         levels=np.linspace(-0.3,0.3,33),add_colorbar=False)
        # hatch the non-significant areas (sig == 0)
        sig.astype(np.double).plot.contourf(ax=ax,x='time',levels=[0,0.5,1],hatches=['\\',''],
                                            alpha=0,add_colorbar=False)

    for ax in axes:
        ax.set_yscale('log')
        ax.set_ylim(ax.get_ylim()[::-1])
        ax.set_ylabel(None)
        ax.set_xlabel(None)

    # pressure labels in the left column, no y tick labels in the right column
    for ax in axes[0::2]:
        ax.set_ylabel('pressure [hPa]')
    for ax in axes[1::2]:
        ax.set_yticklabels([])

    axes[6].set_xlabel('time')
    axes[7].set_xlabel('time')

    # bold row titles left of the left column, placed via invisible twin axes
    for ax, title in zip(axes[0::2], ('wave drag','resolved','non-orographic','orographic')):
        twin = ax.twinx()
        twin.set_yticks([])
        twin.set_ylabel(title,labelpad=80,size=18,weight='bold')
        twin.yaxis.set_label_position('left')

    fig.subplots_adjust(0,0,1,0.95,0.15,0)

    # Build the colour bars from the plotted meshes so they show exactly the discrete
    # levels and extension colours used in the panels.
    fig.colorbar(full_mesh,
                 ax=[axes[0],axes[1]],
                 extend='both',ticks=[-2,-1.5,-1,-0.5,0,0.5,1,1.5,2],
                 label=r'[m s$^{-1}$ day$^{-1}$]',fraction=0.1,aspect=10,shrink=0.95)

    fig.colorbar(diff_mesh,
                 ax=[axes[2],axes[3],axes[4],axes[5],axes[6],axes[7]],
                 extend='both',ticks=[-0.3,-0.225,-0.15,-0.075,0,0.075,0.15,0.225,0.3],
                 label=r'[m s$^{-1}$ day$^{-1}$]',fraction=0.1,aspect=30,shrink=0.95)

    trans = mtransforms.ScaledTranslation(-45/72, -20/72, fig.dpi_scale_trans)

    # panel letters: x offset -0.05 in the left column, 0.1 in the right column
    for idx, (ax, letter) in enumerate(zip(axes, 'abcdefgh')):
        x = -0.05 if idx % 2 == 0 else 0.1
        ax.text(x,1.0,letter+')',transform=ax.transAxes+trans,fontsize='large',va='bottom',fontfamily='serif')

plot_eight_drag('TCo639_free_running/20060117','TCo639_free_running/20100205',)

## Figure B1

Ensemble-mean resolved gravity wave drag horizontally averaged between 45 and 70° N for different model configurations initialized on 8 February 2018.

In [None]:
def plot_five_drag(experiments,labels):
    '''
        Figure B1: ensemble-mean resolved gravity wave drag [m s-1 day-1], averaged
        with ``aggregation``, one panel per experiment.

        Panels fill the first five cells of a 3 x 2 grid (further experiments are not
        drawn); the sixth cell holds the colour bar. ``labels`` supplies the panel
        titles in the same order as ``experiments``.
    '''
    drag = [aggregation(wave_drag(xr.open_dataset(work_dir+exp+'_gravity_wave_flux.nc'))).mean('number')*86400
            for exp in experiments]

    # Plotting

    fig, axes = plt.subplots(nrows=3,ncols=2,sharex=False,sharey=False,figsize=(9,12))
    axes = axes.flatten()
    panel_axes, cbar_ax = axes[:-1], axes[-1]

    for field, title, ax in zip(drag, labels, panel_axes):
        mesh = field.plot.pcolormesh(ax=ax,x='time',cmap=cmocean.cm.balance,extend='both',
                                     levels=np.linspace(-0.3,0.3,33),add_colorbar=False)

        # logarithmic pressure axis with pressure decreasing upwards
        ax.set_yscale('log')
        ax.set_ylim(ax.get_ylim()[::-1])
        ax.set_ylabel(None)
        ax.set_xlabel(None)

        ax.set_title(title,weight='bold')

    # y labels on the left column, x label on the bottom-left panel only
    for ax in axes[0:5:2]:
        ax.set_ylabel('pressure [hPa]')
    axes[4].set_xlabel('time')

    for ax in axes[:4]:
        ax.set_xticklabels([])
    for ax in (axes[1], axes[3]):
        ax.set_yticklabels([])

    # the last grid cell becomes a narrow colour-bar axis
    cbar_ax.set_aspect(6)
    cbar_ax.set_anchor((0.4,0.5))
    plt.colorbar(mesh,cax=cbar_ax,label=r'wave drag [m s$^{-1}$ day$^{-1}$]')

    plt.tight_layout()


plot_five_drag(experiments=['TCo639_free_running/20180208_91L',
                            'TCo639_free_running/20180208_198L',
                            'TCo319_free_running/20180208_91L',
                            'TCo319_free_running/20180208_137L',
                            'TCo319_free_running/20180208_198L'],
               labels=['TCo639L91',
                       'TCo639L198',
                       'TCo319L91',
                       'TCo319L137',
                       'TCo319L198'])