In [None]:
import numpy as np
import xarray as xr
from wrf import ll_to_xy, Constants
from netCDF4 import Dataset
from dypy.lagranto import Tra
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

%matplotlib inline

# Temporal Evolution of Trajectories

In [None]:
# Root directory of the retrospective study data.
_WORK_DIR = '/scratch3/thomasl/work/'

# Per-case-study configuration:
#   abbr                -- case prefix of the LAGRANTO trajectory file name
#   date                -- date part of the WRF output file name (YYYY-MM-DD)
#   initiation_location -- [longitude, latitude] of the initiation point
#   low_level/mid_level -- upper bounds (km) of the low and mid trajectory bunches
_CASE_STUDIES = {
    'case_study_1': {'abbr': 'cs1', 'date': '2018-05-22',
                     'initiation_location': [7.56971, 47.4961],
                     'low_level': 2.5, 'mid_level': 6},
    'case_study_2': {'abbr': 'cs2', 'date': '2018-05-09',
                     'initiation_location': [7.63891, 47.0546],
                     'low_level': 2, 'mid_level': 7},
    'case_study_3': {'abbr': 'cs3', 'date': '2018-05-12',
                     'initiation_location': [8.64449, 47.5522],
                     'low_level': 3, 'mid_level': 6.5},
    'case_study_4': {'abbr': 'cs4', 'date': '2018-05-10',
                     'initiation_location': [8.67235, 47.629],
                     'low_level': 2.5, 'mid_level': 6},
    'case_study_5': {'abbr': 'cs5', 'date': '2018-05-30',
                     'initiation_location': [6.95974, 46.4482],
                     'low_level': 2.5, 'mid_level': 8},
}

# Per-variable plotting configuration: trajectory field name, title text,
# y-axis label and y-limits for each case study.
_VARIABLES = {
    'water_vapor': {
        'variable': 'QVAPOR',
        'title_name': 'Water Vapor',
        'y_label': 'Water Vapor Mixing Ratio [$g$ $kg^{-1}$]',
        'ylim': {'case_study_1': (0, 10), 'case_study_2': (0, 10),
                 'case_study_3': (0, 10), 'case_study_4': (0, 10),
                 'case_study_5': (0, 12)},
    },
    'height': {
        'variable': 'z',
        'title_name': 'Height',
        'y_label': 'Height of Trajectories [$km$]',
        'ylim': {'case_study_1': (0, 10), 'case_study_2': (0, 10),
                 'case_study_3': (0, 10), 'case_study_4': (0, 10),
                 'case_study_5': (0, 10)},
    },
    'updraft': {
        'variable': 'W_UP_MAX',
        'title_name': 'Updraft',
        'y_label': 'Max Z-Wind Updraft [$m$ $s^-$$^1$]',
        'ylim': {'case_study_1': (0, 10), 'case_study_2': (0, 6),
                 'case_study_3': (0, 10), 'case_study_4': (0, 5),
                 'case_study_5': (0, 12)},
    },
    'theta_e': {
        'variable': 'theta_e',
        'title_name': 'Theta-E',
        'y_label': 'Equivalent Potential Temperature [$K$]',
        'ylim': {'case_study_1': (317, 327), 'case_study_2': (310, 325),
                 'case_study_3': (316.5, 330), 'case_study_4': (312, 322),
                 'case_study_5': (317.5, 337.5)},
    },
}

# Bunches drawn for each trajs_bunch_level, in drawing order:
# (bunch name, colour of mean/percentile lines, line width, zorder or None).
_BUNCH_STYLES = {
    'low': [('low', 'r', 2, None)],
    'low_mid': [('low', 'blue', 2, None), ('mid', 'orange', 2, None)],
    'mid': [('mid', 'orange', 2, None)],
    'high': [('high', 'b', 2, None)],
    'all': [('low', 'r', 2.5, 4), ('mid', 'orange', 2.5, 3), ('high', 'b', 2.5, 2)],
}


def _rotated_ll_to_xy(wrf_nc, latitude, longitude):
    '''Wrap wrf.ll_to_xy for points on the far side of the standard longitude.

    If the longitude is >= -STAND_LON, the x index is shifted by one earth
    circumference expressed in grid points (2*pi*R / DX). Only the first
    element is written back, so this expects a single point.
    '''
    res = ll_to_xy(wrf_nc, latitude=latitude, longitude=longitude)
    circumference_pix = np.round(
        (2 * Constants.WRF_EARTH_RADIUS * Constants.PI) / wrf_nc.DX, 0)
    xs = np.where(np.ravel(longitude) >= wrf_nc.STAND_LON * -1,
                  res[0] + circumference_pix, res[0])
    res[0] = xs[0]
    return res


def _plot_bunch(ax, xtime, bunch, variable, color, lw, zorder, every):
    '''Draw one trajectory bunch on ax.

    Every `every`-th trajectory is drawn in thin grey. The bunch mean
    (solid) and 10th/90th percentiles (dashed) are drawn in `color`.
    Values are flipped along time so they line up with `xtime`.
    '''
    extra = {} if zorder is None else {'zorder': zorder}
    for traj in bunch[::every]:
        ax.plot(xtime, np.flipud(traj[variable][:]), '-', color='grey', lw=0.5)
    values = bunch[variable]
    for q in (90, 10):
        ax.plot(xtime, np.flipud(np.percentile(values, q, axis=0)),
                '--', color=color, lw=lw, **extra)
    ax.plot(xtime, np.flipud(np.mean(values, 0)), '-', color=color, lw=lw, **extra)


def _legend_handles(trajs_bunch_level, low_level, mid_level):
    '''Build proxy legend handles.

    The handles are the mean/percentile line styles followed by one coloured
    entry per bunch drawn for trajs_bunch_level.
    '''
    low_prefix = 'SFC' if trajs_bunch_level == 'all' else '0'
    labels = {
        'low': '{} up to {} km Starting Height'.format(low_prefix, low_level),
        'mid': '{} up to {} km Starting Height'.format(low_level, mid_level),
        'high': '{} km up to 10 km Starting Height'.format(mid_level),
    }
    handles = [
        mlines.Line2D([], [], color='k', linestyle='-', linewidth=2, label='Mean'),
        mlines.Line2D([], [], color='k', linestyle='--', linewidth=2,
                      label='10th and 90th Percentile'),
    ]
    for bunch_name, color, _, _ in _BUNCH_STYLES[trajs_bunch_level]:
        handles.append(mlines.Line2D([], [], color=color, linestyle='-', linewidth=2,
                                     label=labels[bunch_name]))
    return handles


def _legend_location(trajs_bunch_level, variable_name, case_study_name):
    '''Return the legend position, tuned per bunch level, variable and case study.'''
    if trajs_bunch_level == 'low_mid':
        return 'upper right'
    if trajs_bunch_level != 'all':
        return 'upper left'
    if variable_name == 'water_vapor':
        return {'case_study_3': 'upper left',
                'case_study_5': 'lower left'}.get(case_study_name, 'upper right')
    if variable_name == 'theta_e':
        if case_study_name in ('case_study_2', 'case_study_5'):
            return 'upper left'
        return 'upper right'
    return 'upper left'


def temporal_evolution_trajectories(variable_name, case_study_name, start_time, end_time,
                                    time_cut, trajs_bunch_level='all', title=True, 
                                    legend=True, save=False, panel_label='(g)'):
    '''Plot the temporal evolution of a variable along LAGRANTO trajectories.

    Trajectories are split into three bunches by their starting height (first
    value of 'z', converted to km):
      - low:  z <= low_level
      - mid:  low_level < z <= mid_level
      - high: z > mid_level
    The level bounds are defined per case study. Only trajectories whose last
    latitude is positive are used. For every drawn bunch, every 5th trajectory
    is shown in grey, together with the bunch mean (solid) and the 10th/90th
    percentiles (dashed).

    Parameters
    ----------
    variable_name : str
        'water_vapor', 'height', 'updraft' or 'theta_e'.
    case_study_name : str
        'case_study_1' ... 'case_study_5'.
    start_time, end_time : str
        'HHMM' strings used in the trajectory file name. start_time also
        selects the WRF output file ('HH:MM:00').
    time_cut : int
        Index into the time axis used as the left x-limit.
    trajs_bunch_level : str
        'low', 'low_mid', 'mid', 'high' or 'all'.
    title : bool
        Draw the left/right titles.
    legend : bool
        Draw the legend (honoured for every bunch level).
    save : bool
        Save the figure as PNG (300 dpi) into the case-study directory.
    panel_label : str or None
        Text drawn in the upper-left corner; None or '' to omit.

    Raises
    ------
    ValueError
        If case_study_name, variable_name or trajs_bunch_level is not
        supported, or no trajectory falls into any requested bunch.
    '''
    import warnings

    # Validate everything before touching any file.
    if case_study_name not in _CASE_STUDIES:
        raise ValueError('Unsupported case_study_name: {!r}'.format(case_study_name))
    if variable_name not in _VARIABLES:
        raise ValueError('Unsupported variable_name: {!r}'.format(variable_name))
    if trajs_bunch_level not in _BUNCH_STYLES:
        raise ValueError('Unsupported trajs_bunch_level: {!r}'.format(trajs_bunch_level))

    number_trajs_plot = 5  # draw every n-th trajectory in grey
    dt = 5  # delta time between data files

    case = _CASE_STUDIES[case_study_name]
    date = case['date']
    low_level = case['low_level']
    mid_level = case['mid_level']
    case_study_abbr = case['abbr'] + ('_eth' if variable_name == 'theta_e' else '')

    var_cfg = _VARIABLES[variable_name]
    variable = var_cfg['variable']
    ylim = var_cfg['ylim'][case_study_name]

    traj_data_dir = '{}retrospective_part/lagranto/traj_{}_{}_{}.ll'.format(
        _WORK_DIR, case_study_abbr, start_time, end_time)
    save_dir = '{}retrospective_part/{}/lagranto_evo/'.format(_WORK_DIR, case_study_name)
    filename = '{}data/{}/wrfout_d02_{}_{}:{}:00'.format(
        _WORK_DIR, case_study_name, date, start_time[:2], start_time[2:])

    # Load trajectories; heights are divided by 1000 so they are in km.
    trajs = Tra()
    trajs.load_ascii(traj_data_dir)
    trajs['z'] = trajs['z'] / 1000

    # Terrain height at the initiation point; both handles are closed afterwards.
    lon, lat = case['initiation_location']
    with Dataset(filename) as ncfile, xr.open_dataset(filename) as wrf_file:
        location = _rotated_ll_to_xy(ncfile, lat, lon)
        hgt = wrf_file.HGT.sel(west_east=location[0], south_north=location[1]).values
    print('Terrain height: {} m'.format(hgt))

    # Split into bunches by starting height. Comparisons are written so that
    # NaN values fall into no bunch. The lat > 0 filter presumably drops
    # trajectories flagged with a missing-value latitude (TODO confirm).
    bunches = {'low': [], 'mid': [], 'high': []}
    for t in trajs:
        if not t['lat'][-1] > 0:
            continue
        z0 = t['z'][0]
        if z0 <= low_level:
            bunches['low'].append(t)
        elif z0 <= mid_level:
            bunches['mid'].append(t)
        elif z0 > mid_level:
            bunches['high'].append(t)

    to_draw = []
    for bunch_name, color, lw, zorder in _BUNCH_STYLES[trajs_bunch_level]:
        members = bunches[bunch_name]
        if not members:
            warnings.warn('No trajectories in the {} bunch for {}; skipping it.'
                          .format(bunch_name, case_study_name))
            continue
        to_draw.append((np.array(members), color, lw, zorder))
    if not to_draw:
        raise ValueError('No trajectories selected for trajs_bunch_level={!r}'
                         .format(trajs_bunch_level))

    # Time axis in hours, ending at 0; all trajectories share the same length.
    n_times = len(to_draw[0][0]['time'][0])
    xtime = (np.arange(-n_times * dt, 0, dt) + dt) / 60

    fig, ax = plt.subplots(figsize=(15, 10))
    for bunch, color, lw, zorder in to_draw:
        _plot_bunch(ax, xtime, bunch, variable, color, lw, zorder, number_trajs_plot)

    # Axes labels, limits and ticks
    ax.set_ylabel(var_cfg['y_label'], fontsize=25)
    ax.set_ylim(ylim[0], ylim[1])
    ax.tick_params(axis='both', which='major', labelsize=17.5, length=8, width=1.5)
    ax.tick_params(axis='both', which='minor', length=4, width=1)
    ax.set_xlabel('Time [$h$]', fontsize=25)
    ax.set_xlim(xtime[time_cut], 0)
    ax.minorticks_on()
    ax.grid(axis='y', linestyle='--', color='grey')

    if legend:
        ax.legend(handles=_legend_handles(trajs_bunch_level, low_level, mid_level),
                  loc=_legend_location(trajs_bunch_level, variable_name, case_study_name),
                  fontsize=10,
                  framealpha=0.9 if trajs_bunch_level == 'all' else None)

    if title:
        ax.set_title('Temporal Evolution of {} along Trajectories'.format(var_cfg['title_name']),
                     fontsize=15, loc='left')
        ax.set_title('Starting Time: {} - {} UTC'.format(date, start_time), loc='right',
                     fontsize=12.5)

    if panel_label:
        ax.text(0.05, 0.95, panel_label, transform=ax.transAxes,
                fontsize=40, fontweight='bold', va='top', zorder=10)

    # Save before showing so the file is written from the live figure.
    # NOTE(review): '_eth' is appended for every variable, not only theta_e — confirm intended.
    if save:
        fig.savefig('{}lagranto_evo_{}_{}_{}_{}_eth.png'.format(save_dir, variable_name,
                                                                trajs_bunch_level, start_time, end_time),
                    bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
# Theta-E along all three trajectory bunches for case study 5 (14:20 -> 06:00),
# without title or legend, saved to disk.
temporal_evolution_trajectories(
    variable_name='theta_e',
    case_study_name='case_study_5',
    start_time='1420',
    end_time='0600',
    time_cut=4,
    trajs_bunch_level='all',
    title=False,
    legend=False,
    save=True,
)