In [None]:
import numpy as np
import xarray as xr
import datetime as dt
import seaborn as sns
import mplotutils as mpu
import cartopy.crs as ccrs
from netCDF4 import Dataset
from dypy.lagranto import Tra
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import cartopy.feature as cfeature
from lagranto.plotting import plot_trajs
from matplotlib.colors import ListedColormap
from wrf import to_np, getvar, latlon_coords, CoordPair, ll_to_xy, Constants

%matplotlib inline

## First row of the trajectory-map panel plot: three maps (before, during, after) over the wider extent

In [None]:
case_study_name = 'case_study_5'
traj_variable_name = 'theta_e'

# Terrain height is read from this one WRF output file for every case study.
# NOTE(review): presumably all case studies share the same d02 domain, so the
# case-study-1 file is reused on purpose — confirm.
wrf_filename = '/scratch3/thomasl/work/data/case_study_1/' \
                    'wrfout_d02_2018-05-22_16:30:00'
ncfile = Dataset(wrf_filename)

# --- Case-study parameters ---
# start_time_* / end_time_* select the trajectory file of the panel
# before / during / after; end_cut_* is the number of trailing time steps
# removed from every trajectory of that file.
if case_study_name == 'case_study_1':
    case_study_abbr = 'cs1'
    date = '2018-05-22'
    start_time_before = '1630'
    end_time_before = '0800'
    end_cut_before = 6
    end_time_adj_before = '0830'
    start_time_during = '1700'
    end_time_during = '0900'
    end_cut_during = 0
    end_time_adj_during = '0900'
    start_time_after = '1730'
    end_time_after = '0900'
    end_cut_after = 6
    end_time_adj_after = '0930'
    initiation_location = [7.56971, 47.4961]
    low_level = 2.5
    mid_level = 6
    subset_extent = [5.5, 9, 46, 48.5]
    subset_small_extent = [6.75, 8.5, 46.75, 48.25]

elif case_study_name == 'case_study_2':
    case_study_abbr = 'cs2'
    date = '2018-05-09'
    initiation_location = [7.63891, 47.0546]
    low_level = 2
    mid_level = 7
    subset_extent = [7, 9, 46.75, 48.75]
    subset_small_extent = [7, 9, 46.9, 48.25]
    start_time_before = '1335'
    end_time_before = '0500'
    end_cut_before = 7
    start_time_during = '1405'
    end_time_during = '0600'
    end_cut_during = 1
    start_time_after = '1435'
    end_time_after = '0600'
    end_cut_after = 7

elif case_study_name == 'case_study_3':
    case_study_abbr = 'cs3'
    date = '2018-05-12'
    initiation_location = [8.64449, 47.5522]
    low_level = 3
    mid_level = 6.5
    subset_extent = [7.5, 12, 45.8, 48.5]
    subset_small_extent = [8, 12, 46.5, 48.25]
    start_time_before = '1945'
    end_time_before = '1100'
    end_cut_before = 9
    start_time_during = '2015'
    end_time_during = '1200'
    end_cut_during = 3
    start_time_after = '2045'
    end_time_after = '1200'
    end_cut_after = 9

elif case_study_name == 'case_study_4':
    case_study_abbr = 'cs4'
    date = '2018-05-10'
    initiation_location = [8.67235, 47.629]
    low_level = 2.5
    mid_level = 6
    subset_extent = [7, 12.5, 47, 49.4]
    subset_small_extent = [7, 9.5, 47.2, 48.75]
    start_time_before = '0030'
    end_time_before = '1600'
    end_cut_before = 6
    start_time_during = '0100'
    end_time_during = '1700'
    end_cut_during = 0
    start_time_after = '0130'
    end_time_after = '1700'
    end_cut_after = 6

elif case_study_name == 'case_study_5':
    case_study_abbr = 'cs5'
    date = '2018-05-30'
    initiation_location = [6.95974, 46.4482]
    low_level = 2.5
    mid_level = 8
    subset_extent = [6, 8.5, 45, 47]
    subset_small_extent = [6.5, 8.25, 45.5, 47.1]
    start_time_before = '1420'
    end_time_before = '0600'
    end_cut_before = 4
    start_time_during = '1450'
    end_time_during = '0600'
    end_cut_during = 10
    start_time_after = '1520'
    end_time_after = '0700'
    end_cut_after = 4

else:
    # Fail early instead of with a NameError (or silently stale values from a
    # previous run of this cell) further down.
    raise ValueError('Unknown case_study_name: {!r}'.format(case_study_name))

# Theta-e trajectories come from files whose abbreviation carries an "_eth" suffix
if traj_variable_name == 'theta_e':
    case_study_abbr += '_eth'

save_dir = '/scratch3/thomasl/work/retrospective_part/' \
            '{}/lagranto/'.format(case_study_name)

# --- Trajectory-variable settings ---
# traj_variable_name is mapped to the field name passed to plot_trajs; the
# colour-bar label, range and colormap are chosen alongside.
if traj_variable_name == 'height':
    traj_variable_name = 'z'
    colorbar_label_trajs = 'Height [$km$]'
    traj_variable_min = 0
    traj_variable_max = 10
    cpalette = ['#543005', '#8c510a', '#bf812d', '#dfc27d', '#f6e8c3', '#c7eae5', 
              '#80cdc1', '#35978f', '#01665e', '#003c30']
    cmap = ListedColormap(sns.color_palette(cpalette).as_hex())

elif traj_variable_name == 'water_vapor':
    traj_variable_name = 'QVAPOR'
    colorbar_label_trajs = 'WVMR [$g$ $kg^{-1}$]'
    traj_variable_min = 0
    traj_variable_max = 12 if case_study_name == 'case_study_5' else 10
    cmap = ListedColormap(sns.cubehelix_palette(10, start=.5, rot=-.75,))

elif traj_variable_name == 'updraft':
    traj_variable_name = 'W_UP_MAX'
    colorbar_label_trajs = 'Updraft [$m$ $s^-$$^1$]'
    traj_variable_min = 0
    traj_variable_max = 5
    cpalette = ['#fee0b6', '#e0e0e0', '#dadaeb', '#bcbddc', '#9e9ac8', '#807dba', 
                '#6a51a3', '#54278f', '#3f007d', '#2d004b']
    cmap = ListedColormap(sns.color_palette(cpalette).as_hex())

elif traj_variable_name == 'theta_e':
    colorbar_label_trajs = 'Theta-E [$K$]'
    cmap = plt.get_cmap('RdYlBu_r')
    # Colour-bar range [K] per case study (case_study_name was validated above)
    theta_e_range = {
        'case_study_1': (320, 325),
        'case_study_2': (313, 323),
        'case_study_3': (318, 328),
        'case_study_4': (313, 318),
        'case_study_5': (326, 336),
    }
    traj_variable_min, traj_variable_max = theta_e_range[case_study_name]

else:
    raise ValueError('Unknown traj_variable_name: {!r}'.format(traj_variable_name))

# Terrain height from the WRF output, divided by 1000 to get km
terrain_height = getvar(ncfile, 'HGT') / 1000
lats, lons = latlon_coords(terrain_height)

# Lambert conformal projection shared by all three map panels
cart_proj = ccrs.LambertConformal(central_latitude=46.73585,
                                  central_longitude=8.722206)

# One row of three equally sized map panels
fig, ax = plt.subplots(1, 3, figsize=(30, 10), dpi=300,
                       subplot_kw={'projection': cart_proj},
                       gridspec_kw={'width_ratios': [1, 1, 1],
                                    'height_ratios': [1]})

### --- Trajectory panels: before (ax[0]), during (ax[1]), after (ax[2]) --- ###
# (start time, end time, trailing time steps to cut) of each panel's trajectory file
traj_periods = [
    (start_time_before, end_time_before, end_cut_before),
    (start_time_during, end_time_during, end_cut_during),
    (start_time_after, end_time_after, end_cut_after),
]

# Contour levels shared by all panels: 11 trajectory-variable levels rounded
# to one decimal, and terrain height 0-4 km in 0.2 km steps
levels_trajs_all = np.linspace(traj_variable_min, traj_variable_max, 11)
rounded_levels_trajs_all = [round(elem, 1) for elem in levels_trajs_all]
levels = np.linspace(0, 4, 21)

# Gridline and tick-label positions (1 degree spacing)
lon = np.arange(0, 20, 1)
lat = np.arange(40, 60, 1)

# The terrain field is the same for every panel; convert it to numpy once
lons_np, lats_np, terrain_np = to_np(lons), to_np(lats), to_np(terrain_height)

for panel_idx, (start_time, end_time, end_cut) in enumerate(traj_periods):
    panel_ax = ax[panel_idx]

    # Load the trajectories and convert height to km
    traj_data_dir = '/scratch3/thomasl/work/retrospective_part/lagranto/' \
        'traj_{}_{}_{}.ll'.format(case_study_abbr, start_time, end_time)
    trajs = Tra()
    trajs.load_ascii(traj_data_dir)
    trajs['z'] = trajs['z']/1000
    # Drop the last `end_cut` time steps; guarded because t[:-0] would be empty
    if end_cut != 0:
        trajs = [t[:-end_cut] for t in trajs]

    # Keep trajectories whose last latitude is positive
    # NOTE(review): presumably removes trajectories that left the domain
    # (negative fill value) — confirm against the Lagranto output.
    trajs_all = [t for t in trajs if t['lat'][-1] > 0]

    panel_ax.set_extent(subset_extent, ccrs.PlateCarree())

    # Every 5th trajectory, counting back from the last one
    panel_trajs = plot_trajs(panel_ax, trajs_all[::-5], traj_variable_name,
                             linewidth=2, levels=rounded_levels_trajs_all,
                             cmap=cmap, zorder=4)

    # Terrain height (its colour bar is added later, from the first panel)
    panel_terrain = panel_ax.contourf(lons_np, lats_np, terrain_np,
                                      levels=levels, transform=ccrs.PlateCarree(),
                                      cmap=plt.get_cmap('Greys'), alpha=0.75, zorder=1)

    # Add borders and coastlines
    panel_ax.add_feature(cfeature.BORDERS.with_scale('10m'), linewidth=0.8)
    panel_ax.add_feature(cfeature.COASTLINE.with_scale('10m'), linewidth=0.8)

    # Cross at the start point of every trajectory (the initiation location);
    # a single marker-only plot call instead of one artist per trajectory
    panel_ax.plot([t['lon'][0] for t in trajs], [t['lat'][0] for t in trajs],
                  'kx', zorder=5, transform=ccrs.PlateCarree())

    # Gridlines and tick labels
    panel_ax.gridlines(xlocs=lon, ylocs=lat, zorder=2)
    mpu.yticklabels(lat, ax=panel_ax, fontsize=15)
    mpu.xticklabels(lon, ax=panel_ax, fontsize=15)

    # The colour bars are built from the first panel's mappables
    if panel_idx == 0:
        plt_trajs = panel_trajs
        terrain = panel_terrain

# Panel labels (a)-(c) in the upper-left corner of each map
for panel_ax, panel_label in zip(ax, ['(a)', '(b)', '(c)']):
    panel_ax.text(0.05, 0.95, panel_label, transform=panel_ax.transAxes,
                  fontsize=25, fontweight='bold', va='top', zorder=10)

# Horizontal terrain colour bar spanning from the centre of the left panel to
# the centre of the right panel; get_points() returns [[x0, y0], [x1, y1]]
p0 = ax[0].get_position().get_points().flatten()
p1 = ax[2].get_position().get_points().flatten()
cbar_left = (p0[2]+p0[0])/2
cbar_width = ((p1[2]+p1[0])/2) - cbar_left

# Vertical position of the colour bar (figure fraction), tuned per case study
bottom = {'case_study_4': 0.2, 'case_study_3': 0.1}.get(case_study_name, 0.05)

cbar_ax = fig.add_axes([cbar_left, bottom, cbar_width, 0.025])
# `aspect` / `use_gridspec` are omitted: they only apply when matplotlib
# creates the colour-bar axes itself, i.e. when no `cax` is passed
cbar_ter = fig.colorbar(terrain, cax=cbar_ax, orientation='horizontal')
cbar_ter.set_label('Terrain Height [$km$]', fontsize=25)
cbar_ter.set_ticks(levels)
cbar_ter.ax.tick_params(labelsize=20)

# Show only every other tick label (0, 0.4, 0.8, ... km) to avoid crowding
for label in cbar_ter.ax.xaxis.get_ticklabels()[1::2]:
    label.set_visible(False)

# Vertical colour bar for the trajectory variable, attached to the right panel;
# theta-e is extended at both ends, the other variables only at the top
extend = 'both' if traj_variable_name == 'theta_e' else 'max'
cbar_trajs = mpu.colorbar(plt_trajs, ax[2], orientation='vertical', 
                          aspect=15, shrink=.05, pad=0.1, extend=extend)
cbar_trajs.set_label(colorbar_label_trajs, fontsize=25)
cbar_trajs.set_ticks(rounded_levels_trajs_all)
cbar_trajs.ax.tick_params(labelsize=20)

# Draw once before saving so that any draw-time layout (presumably including
# the mplotutils colour-bar placement — verify) is settled before the tight
# bounding box is computed; saving then no longer depends on plt.show().
# NOTE(review): `save_dir` is defined above but unused — the file goes to the CWD.
fig.canvas.draw()
fig.savefig('{}_traj_{}_subset'.format(case_study_name,traj_variable_name), bbox_inches='tight', dpi=300)
plt.show()

## Second row of the trajectory-map panel plot: three maps (before, during, after) over the smaller extent

In [None]:
case_study_name = 'case_study_5'
traj_variable_name = 'height'

# Terrain height is read from this one WRF output file for every case study.
# NOTE(review): presumably all case studies share the same d02 domain, so the
# case-study-1 file is reused on purpose — confirm.
wrf_filename = '/scratch3/thomasl/work/data/case_study_1/' \
                    'wrfout_d02_2018-05-22_16:30:00'
ncfile = Dataset(wrf_filename)

# --- Case-study parameters ---
# start_time_* / end_time_* select the trajectory file of the panel
# before / during / after; end_cut_* is the number of trailing time steps
# removed from every trajectory of that file.
if case_study_name == 'case_study_1':
    case_study_abbr = 'cs1'
    date = '2018-05-22'
    start_time_before = '1630'
    end_time_before = '0800'
    end_cut_before = 6
    end_time_adj_before = '0830'
    start_time_during = '1700'
    end_time_during = '0900'
    end_cut_during = 0
    end_time_adj_during = '0900'
    start_time_after = '1730'
    end_time_after = '0900'
    end_cut_after = 6
    end_time_adj_after = '0930'
    initiation_location = [7.56971, 47.4961]
    low_level = 2.5
    mid_level = 6
    subset_extent = [5.5, 9, 46, 48.5]
    subset_small_extent = [6.75, 8.5, 46.75, 48.25]

elif case_study_name == 'case_study_2':
    case_study_abbr = 'cs2'
    date = '2018-05-09'
    initiation_location = [7.63891, 47.0546]
    low_level = 2
    mid_level = 7
    subset_extent = [7, 9, 46.75, 48.75]
    subset_small_extent = [7, 9, 46.9, 48.25]
    start_time_before = '1335'
    end_time_before = '0500'
    end_cut_before = 7
    start_time_during = '1405'
    end_time_during = '0600'
    end_cut_during = 1
    start_time_after = '1435'
    end_time_after = '0600'
    end_cut_after = 7

elif case_study_name == 'case_study_3':
    case_study_abbr = 'cs3'
    date = '2018-05-12'
    initiation_location = [8.64449, 47.5522]
    low_level = 3
    mid_level = 6.5
    subset_extent = [7.5, 12, 45.8, 48.5]
    subset_small_extent = [8, 12, 46.5, 48.25]
    start_time_before = '1945'
    end_time_before = '1100'
    end_cut_before = 9
    start_time_during = '2015'
    end_time_during = '1200'
    end_cut_during = 3
    start_time_after = '2045'
    end_time_after = '1200'
    end_cut_after = 9

elif case_study_name == 'case_study_4':
    case_study_abbr = 'cs4'
    date = '2018-05-10'
    initiation_location = [8.67235, 47.629]
    low_level = 2.5
    mid_level = 6
    subset_extent = [7, 12.5, 47, 49.4]
    subset_small_extent = [7, 9.5, 47.2, 48.75]
    start_time_before = '0030'
    end_time_before = '1600'
    end_cut_before = 6
    start_time_during = '0100'
    end_time_during = '1700'
    end_cut_during = 0
    start_time_after = '0130'
    end_time_after = '1700'
    end_cut_after = 6

elif case_study_name == 'case_study_5':
    case_study_abbr = 'cs5'
    date = '2018-05-30'
    initiation_location = [6.95974, 46.4482]
    low_level = 2.5
    mid_level = 8
    subset_extent = [6, 8.5, 45, 47]
    subset_small_extent = [6.5, 8.25, 45.5, 47.1]
    start_time_before = '1420'
    end_time_before = '0600'
    end_cut_before = 4
    start_time_during = '1450'
    end_time_during = '0600'
    end_cut_during = 10
    start_time_after = '1520'
    end_time_after = '0700'
    end_cut_after = 4

else:
    # Fail early instead of with a NameError (or silently stale values from a
    # previous run of this or the previous cell) further down.
    raise ValueError('Unknown case_study_name: {!r}'.format(case_study_name))

# Theta-e trajectories come from files whose abbreviation carries an "_eth" suffix
if traj_variable_name == 'theta_e':
    case_study_abbr += '_eth'

save_dir = '/scratch3/thomasl/work/retrospective_part/' \
            '{}/lagranto/'.format(case_study_name)

# --- Trajectory-variable settings ---
# traj_variable_name is mapped to the field name passed to plot_trajs; the
# colour-bar label, range and colormap are chosen alongside.
if traj_variable_name == 'height':
    traj_variable_name = 'z'
    colorbar_label_trajs = 'Height [$km$]'
    traj_variable_min = 0
    # The zoomed-in row caps the height colour bar at the case's mid level
    traj_variable_max = mid_level
    cpalette = ['#543005', '#8c510a', '#bf812d', '#dfc27d', '#f6e8c3', '#c7eae5', 
              '#80cdc1', '#35978f', '#01665e', '#003c30']
    cmap = ListedColormap(sns.color_palette(cpalette).as_hex())

elif traj_variable_name == 'water_vapor':
    traj_variable_name = 'QVAPOR'
    colorbar_label_trajs = 'WVMR [$g$ $kg^{-1}$]'
    traj_variable_min = 0
    traj_variable_max = 12 if case_study_name == 'case_study_5' else 10
    cmap = ListedColormap(sns.cubehelix_palette(10, start=.5, rot=-.75,))

elif traj_variable_name == 'updraft':
    traj_variable_name = 'W_UP_MAX'
    title_name = 'Trajectories of Updraft'
    colorbar_label_trajs = 'Updraft [$m$ $s^-$$^1$]'
    traj_variable_min = 0
    traj_variable_max = 5
    cpalette = ['#fee0b6', '#e0e0e0', '#dadaeb', '#bcbddc', '#9e9ac8', '#807dba', 
                '#6a51a3', '#54278f', '#3f007d', '#2d004b']
    cmap = ListedColormap(sns.color_palette(cpalette).as_hex())

elif traj_variable_name == 'theta_e':
    colorbar_label_trajs = 'Theta-E [$K$]'
    cmap = plt.get_cmap('RdYlBu_r')
    # Colour-bar range [K] per case study (case_study_name was validated above)
    theta_e_range = {
        'case_study_1': (320, 325),
        'case_study_2': (313, 323),
        'case_study_3': (318, 328),
        'case_study_4': (313, 318),
        'case_study_5': (326, 336),
    }
    traj_variable_min, traj_variable_max = theta_e_range[case_study_name]

else:
    raise ValueError('Unknown traj_variable_name: {!r}'.format(traj_variable_name))

# Terrain height from the WRF output, divided by 1000 to get km
terrain_height = getvar(ncfile, 'HGT') / 1000
lats, lons = latlon_coords(terrain_height)

# Lambert conformal projection shared by all three map panels
cart_proj = ccrs.LambertConformal(central_latitude=46.73585,
                                  central_longitude=8.722206)

# One row of three equally sized map panels
fig, ax = plt.subplots(1, 3, figsize=(30, 10), dpi=300,
                       subplot_kw={'projection': cart_proj},
                       gridspec_kw={'width_ratios': [1, 1, 1],
                                    'height_ratios': [1]})

### --- Trajectory panels (small extent): before (ax[0]), during (ax[1]), after (ax[2]) --- ###
# (start time, end time, trailing time steps to cut) of each panel's trajectory file
traj_periods = [
    (start_time_before, end_time_before, end_cut_before),
    (start_time_during, end_time_during, end_cut_during),
    (start_time_after, end_time_after, end_cut_after),
]

# Contour levels shared by all panels: 11 trajectory-variable levels rounded
# to one decimal, and terrain height 0-4 km in 0.2 km steps
levels_trajs_all = np.linspace(traj_variable_min, traj_variable_max, 11)
rounded_levels_trajs_all = [round(elem, 1) for elem in levels_trajs_all]
levels = np.linspace(0, 4, 21)

# Gridline and tick-label positions (1 degree spacing)
lon = np.arange(0, 20, 1)
lat = np.arange(40, 60, 1)

# The terrain field is the same for every panel; convert it to numpy once
lons_np, lats_np, terrain_np = to_np(lons), to_np(lats), to_np(terrain_height)

for panel_idx, (start_time, end_time, end_cut) in enumerate(traj_periods):
    panel_ax = ax[panel_idx]

    # Load the trajectories and convert height to km
    traj_data_dir = '/scratch3/thomasl/work/retrospective_part/lagranto/' \
        'traj_{}_{}_{}.ll'.format(case_study_abbr, start_time, end_time)
    trajs = Tra()
    trajs.load_ascii(traj_data_dir)
    trajs['z'] = trajs['z']/1000
    # Drop the last `end_cut` time steps; guarded because t[:-0] would be empty
    if end_cut != 0:
        trajs = [t[:-end_cut] for t in trajs]

    # Keep trajectories whose last latitude is positive
    # NOTE(review): presumably removes trajectories that left the domain
    # (negative fill value) — confirm against the Lagranto output.
    trajs_all = [t for t in trajs if t['lat'][-1] > 0]

    panel_ax.set_extent(subset_small_extent, ccrs.PlateCarree())

    # Every 2nd trajectory, counting back from the last one
    panel_trajs = plot_trajs(panel_ax, trajs_all[::-2], traj_variable_name,
                             linewidth=2, levels=rounded_levels_trajs_all,
                             cmap=cmap, zorder=4)

    # Terrain height (its colour bar is added later, from the first panel)
    panel_terrain = panel_ax.contourf(lons_np, lats_np, terrain_np,
                                      levels=levels, transform=ccrs.PlateCarree(),
                                      cmap=plt.get_cmap('Greys'), alpha=0.75, zorder=1)

    # Add borders and coastlines
    panel_ax.add_feature(cfeature.BORDERS.with_scale('10m'), linewidth=0.8)
    panel_ax.add_feature(cfeature.COASTLINE.with_scale('10m'), linewidth=0.8)

    # Cross at the start point of every trajectory (the initiation location);
    # a single marker-only plot call instead of one artist per trajectory
    panel_ax.plot([t['lon'][0] for t in trajs], [t['lat'][0] for t in trajs],
                  'kx', zorder=5, transform=ccrs.PlateCarree())

    # Gridlines and tick labels
    panel_ax.gridlines(xlocs=lon, ylocs=lat, zorder=2)
    mpu.yticklabels(lat, ax=panel_ax, fontsize=15)
    mpu.xticklabels(lon, ax=panel_ax, fontsize=15)

    # The colour bars are built from the first panel's mappables
    if panel_idx == 0:
        plt_trajs = panel_trajs
        terrain = panel_terrain

# Panel labels (d)-(f) in the upper-left corner of each map
for panel_ax, panel_label in zip(ax, ['(d)', '(e)', '(f)']):
    panel_ax.text(0.05, 0.95, panel_label, transform=panel_ax.transAxes,
                  fontsize=25, fontweight='bold', va='top', zorder=10)

# Horizontal terrain colour bar spanning from the centre of the left panel to
# the centre of the right panel; get_points() returns [[x0, y0], [x1, y1]]
p0 = ax[0].get_position().get_points().flatten()
p1 = ax[2].get_position().get_points().flatten()
cbar_left = (p0[2]+p0[0])/2
cbar_width = ((p1[2]+p1[0])/2) - cbar_left

# Vertical position of the colour bar (figure fraction), tuned per case study
bottom = {'case_study_2': 0.075,
          'case_study_3': 0.2,
          'case_study_4': 0.1}.get(case_study_name, 0.05)

cbar_ax = fig.add_axes([cbar_left, bottom, cbar_width, 0.025])
# `aspect` / `use_gridspec` are omitted: they only apply when matplotlib
# creates the colour-bar axes itself, i.e. when no `cax` is passed
cbar_ter = fig.colorbar(terrain, cax=cbar_ax, orientation='horizontal')
cbar_ter.set_label('Terrain Height [$km$]', fontsize=25)
cbar_ter.set_ticks(levels)
cbar_ter.ax.tick_params(labelsize=20)

# Show only every other tick label (0, 0.4, 0.8, ... km) to avoid crowding
for label in cbar_ter.ax.xaxis.get_ticklabels()[1::2]:
    label.set_visible(False)

# Vertical colour bar for the trajectory variable, attached to the right panel;
# theta-e is extended at both ends, the other variables only at the top
extend = 'both' if traj_variable_name == 'theta_e' else 'max'
cbar_trajs = mpu.colorbar(plt_trajs, ax[2], orientation='vertical', 
                          aspect=15, shrink=.05, pad=0.1, extend=extend)
cbar_trajs.set_label(colorbar_label_trajs, fontsize=25)
cbar_trajs.set_ticks(rounded_levels_trajs_all)
cbar_trajs.ax.tick_params(labelsize=20)

# Draw once before saving so that any draw-time layout (presumably including
# the mplotutils colour-bar placement — verify) is settled before the tight
# bounding box is computed; saving then no longer depends on plt.show().
# NOTE(review): `save_dir` is defined above but unused — the file goes to the CWD.
fig.canvas.draw()
fig.savefig('{}_traj_{}_subset_small'.format(case_study_name,traj_variable_name), bbox_inches='tight', dpi=300)
plt.show()