In [None]:
import numpy as np
import xarray as xr
import datetime as dt
import seaborn as sns
import mplotutils as mpu
import cartopy.crs as ccrs
from netCDF4 import Dataset
from dypy.lagranto import Tra
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import cartopy.feature as cfeature
from lagranto.plotting import plot_trajs
from matplotlib.colors import ListedColormap
from wrf import to_np, getvar, latlon_coords, CoordPair, ll_to_xy, Constants

%matplotlib inline

# Map of Trajectories from Lagranto

In [None]:
# Root directories of input/output data (edit here to relocate the project).
RETROSPECTIVE_DIR = '/scratch3/thomasl/work/retrospective_part/'
WRF_DATA_DIR = '/scratch3/thomasl/work/data/'

# Per-case-study constants.
# - low_level / mid_level are in km (trajectory heights are converted to km
#   before being compared with them).
# - Extents are [lon_min, lon_max, lat_min, lat_max] in degrees, as passed to
#   ``ax.set_extent(..., ccrs.PlateCarree())``.
# - initiation_location is [lon, lat]; it is kept for reference and is not
#   used by the plot.
# - theta_e_range is the (min, max) colour range for theta_e plots.
_CASE_STUDIES = {
    'case_study_1': dict(abbr='cs1', date='2018-05-22',
                         initiation_location=[7.56971, 47.4961],
                         low_level=2.5, mid_level=6,
                         subset_extent=[5.5, 9, 46, 48.5],
                         subset_small_extent=[6.75, 8.5, 46.75, 48.25],
                         theta_e_range=(320, 325)),
    'case_study_2': dict(abbr='cs2_eth', date='2018-05-09',
                         initiation_location=[7.63891, 47.0546],
                         low_level=2, mid_level=7,
                         subset_extent=[7, 9, 46.75, 48.75],
                         subset_small_extent=[7, 9, 46.9, 48.25],
                         theta_e_range=(313, 323)),
    'case_study_3': dict(abbr='cs3_eth', date='2018-05-12',
                         initiation_location=[8.64449, 47.5522],
                         low_level=3, mid_level=6.5,
                         subset_extent=[7.5, 12, 45.8, 48.5],
                         subset_small_extent=[8, 12, 46.5, 48.25],
                         theta_e_range=(318, 328)),
    'case_study_4': dict(abbr='cs4_eth', date='2018-05-10',
                         initiation_location=[8.67235, 47.629],
                         low_level=2.5, mid_level=6,
                         subset_extent=[7, 12.5, 47, 49.4],
                         subset_small_extent=[7, 9.5, 47.2, 48.75],
                         theta_e_range=(313, 318)),
    'case_study_5': dict(abbr='cs5_eth', date='2018-05-30',
                         initiation_location=[6.95974, 46.4482],
                         low_level=2.5, mid_level=8,
                         subset_extent=[6, 8.5, 45, 47],
                         subset_small_extent=[6.5, 8.25, 45.5, 47.1],
                         theta_e_range=(326, 336)),
}

# Supported values of ``trajs_bunch_level``.
_BUNCH_LEVELS = ('low', 'mid', 'high', 'all')


def _case_study_settings(case_study_name):
    '''Return the constants of one case study from ``_CASE_STUDIES``.

    Raises
    ------
    ValueError
        If ``case_study_name`` is not a known case study.
    '''
    try:
        return _CASE_STUDIES[case_study_name]
    except KeyError:
        raise ValueError('Unknown case_study_name {!r}; expected one of {}'.format(
            case_study_name, sorted(_CASE_STUDIES))) from None


def _variable_settings(traj_variable_name, case_study_name, subset_small,
                       start_time, end_time):
    '''Return the plotting settings for one trajectory variable.

    The result is a dict with these keys:
    - 'field': the trajectory field plotted;
    - 'title', 'label': figure title and colorbar label;
    - 'save_name': file-name stem;
    - 'vmin', 'vmax': colour range;
    - 'cmap': colormap;
    - 'extend': colorbar extension.

    Raises
    ------
    ValueError
        If ``traj_variable_name`` is not one of 'water_vapor', 'updraft',
        'height' or 'theta_e'.
    '''
    case = _case_study_settings(case_study_name)
    time_tag = '{}_{}'.format(start_time, end_time)

    if traj_variable_name == 'water_vapor':
        # Case study 5 needs a wider range, except in the small subset view.
        vmax = 12 if (case_study_name == 'case_study_5' and not subset_small) else 10
        return dict(field='QVAPOR',
                    title='Trajectories of Water Vapor',
                    label='Water Vapor Mixing Ratio [$g$ $kg^{-1}$]',
                    save_name='trajectory_water_vapor_' + time_tag,
                    vmin=0, vmax=vmax,
                    cmap=ListedColormap(sns.cubehelix_palette(10, start=.5, rot=-.75)),
                    extend='max')

    if traj_variable_name == 'updraft':
        cpalette = ['#fee0b6', '#e0e0e0', '#dadaeb', '#bcbddc', '#9e9ac8', '#807dba',
                    '#6a51a3', '#54278f', '#3f007d', '#2d004b']
        return dict(field='W_UP_MAX',
                    title='Trajectories of Updraft',
                    label='Max Z-Wind Updraft [$m$ $s^-$$^1$]',
                    save_name='trajectory_updraft_' + time_tag,
                    vmin=0, vmax=5,
                    cmap=ListedColormap(sns.color_palette(cpalette).as_hex()),
                    extend='max')

    if traj_variable_name == 'height':
        cpalette = ['#543005', '#8c510a', '#bf812d', '#dfc27d', '#f6e8c3', '#c7eae5',
                    '#80cdc1', '#35978f', '#01665e', '#003c30']
        return dict(field='z',
                    title='Height of Trajectories',
                    label='Height of Trajectories [$km$]',
                    save_name='trajectory_height_' + time_tag,
                    vmin=0,
                    vmax=case['mid_level'] if subset_small else 10,
                    cmap=ListedColormap(sns.color_palette(cpalette).as_hex()),
                    extend='max')

    if traj_variable_name == 'theta_e':
        vmin, vmax = case['theta_e_range']
        return dict(field='theta_e',
                    title='Trajectories of Theta-E',
                    label='Equivalent Potential Temperature [$K$]',
                    save_name='trajectory_theta_' + time_tag,
                    vmin=vmin, vmax=vmax,
                    cmap=plt.get_cmap('RdYlBu_r'),
                    extend='both')

    raise ValueError("Unknown traj_variable_name {!r}; expected 'water_vapor', "
                     "'updraft', 'height' or 'theta_e'".format(traj_variable_name))


def _select_bunch(trajs, trajs_bunch_level, low_level, mid_level):
    '''Return the trajectories of one vertical bunch.

    Trajectories are classified by their first height value ``z0`` (km):

    - 'low': ``z0 <= low_level``
    - 'mid': ``low_level < z0 <= mid_level``
    - 'high': ``z0 > mid_level``
    - 'all': every trajectory

    Only trajectories whose last latitude is positive are kept. Presumably
    Lagranto writes a negative fill value once a trajectory leaves the
    domain (TODO confirm).
    '''
    in_bunch = {
        'low': lambda z0: z0 <= low_level,
        'mid': lambda z0: low_level < z0 <= mid_level,
        'high': lambda z0: z0 > mid_level,
        'all': lambda z0: True,
    }[trajs_bunch_level]
    return [t for t in trajs if t['lat'][-1] > 0 and in_bunch(t['z'][0])]


def _figure_path(save_dir, save_name, subset, subset_small, trajs_bunch_level,
                 start_locations):
    '''Build the PNG path.

    The layout is
    ``<save_dir><save_name>[_subset|_subset_small][_<level>]_<start_locations>_eth.png``.
    The level tag is omitted for ``trajs_bunch_level == 'all'``.
    '''
    if subset:
        extent_tag = '_subset'
    elif subset_small:
        extent_tag = '_subset_small'
    else:
        extent_tag = ''
    level_tag = '' if trajs_bunch_level == 'all' else '_{}'.format(trajs_bunch_level)
    return '{}{}{}{}_{}_eth.png'.format(save_dir, save_name, extent_tag, level_tag,
                                        start_locations)


def lagranto_plotting(traj_variable_name, case_study_name, start_time, end_time,
                      end_cut, end_time_adj, trajs_bunch_level='all', subset=False,
                      subset_small=False, save=False):
    '''Plot Lagranto trajectories, coloured by one variable, over WRF terrain.

    Parameters
    ----------
    traj_variable_name : str
        Variable used to colour the trajectories: 'water_vapor', 'updraft',
        'height' or 'theta_e'.
    case_study_name : str
        'case_study_1' ... 'case_study_5' (see ``_CASE_STUDIES``).
    start_time, end_time : str
        Time strings used in the trajectory file name. ``start_time`` is also
        split into ``start_time[:2]`` and ``start_time[2:]`` to build the WRF
        output file name, so it is expected as 'HHMM'.
    end_cut : int
        If non-zero, every trajectory is truncated to ``t[:end_cut]``. A
        negative value drops the last ``-end_cut`` records.
    end_time_adj : str
        End time shown in the figure title.
    trajs_bunch_level : str
        Which vertical bunch to draw: 'low', 'mid', 'high' or 'all' (see
        ``_select_bunch``). 'mid' no longer includes the 'low' trajectories.
    subset : bool
        Zoom to the case study's subset extent. Takes precedence over
        ``subset_small`` for the map extent.
    subset_small : bool
        Zoom to the small subset extent. Also draws every 2nd instead of
        every 5th trajectory, and narrows some colour ranges.
    save : bool
        Also save the figure as a 300 dpi PNG under the case study's
        lagranto directory.

    Raises
    ------
    ValueError
        For an unknown case study, variable name or bunch level. The check
        happens before any file is opened.
    '''
    # Validate every string option before touching the file system.
    case = _case_study_settings(case_study_name)
    if trajs_bunch_level not in _BUNCH_LEVELS:
        raise ValueError('Unknown trajs_bunch_level {!r}; expected one of {}'.format(
            trajs_bunch_level, _BUNCH_LEVELS))
    var = _variable_settings(traj_variable_name, case_study_name, subset_small,
                             start_time, end_time)

    # Stride used when drawing the trajectories of the selected bunch.
    number_trajs_plot = 2 if subset_small else 5
    # Distinction between single and multiple starting points of trajectories.
    start_locations = 'area'

    traj_data_dir = RETROSPECTIVE_DIR + 'lagranto/traj_{}_{}_{}.ll'.format(
        case['abbr'], start_time, end_time)
    save_dir = RETROSPECTIVE_DIR + '{}/lagranto/'.format(case_study_name)
    wrf_filename = WRF_DATA_DIR + '{}/wrfout_d02_{}_{}:{}:00'.format(
        case_study_name, case['date'], start_time[:2], start_time[2:])

    ### Load trajectories ###
    trajs = Tra()
    trajs.load_ascii(traj_data_dir)
    # Heights in km, matching the colorbar label and the *_level thresholds.
    trajs['z'] = trajs['z'] / 1000
    if end_cut != 0:
        trajs = [t[:end_cut] for t in trajs]

    trajs_bunch = _select_bunch(trajs, trajs_bunch_level,
                                case['low_level'], case['mid_level'])

    ### Terrain height of the WRF domain ###
    # Only HGT is needed from the file, so close it as soon as it is read.
    with Dataset(wrf_filename) as ncfile:
        terrain_height = getvar(ncfile, 'HGT') / 1000  # change to km
    lats, lons = latlon_coords(terrain_height)

    ### Figure and map extent ###
    cart_proj = ccrs.LambertConformal(central_longitude=8.722206,
                                      central_latitude=46.73585)
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection=cart_proj)

    domain_extent = [3.701088, 13.814863, 43.85472, 49.49499]
    if subset:
        extent = case['subset_extent']
    elif subset_small:
        extent = case['subset_small_extent']
    else:
        # Trim the outer edges of the WRF domain.
        extent = [domain_extent[0] + 0.7, domain_extent[1] - 0.7,
                  domain_extent[2] + 0.1, domain_extent[3] - 0.1]
    ax.set_extent(extent, ccrs.PlateCarree())

    ### Trajectories ###
    levels_trajs = np.linspace(var['vmin'], var['vmax'], 11)
    rounded_levels_trajs = [round(elem, 1) for elem in levels_trajs]

    # [::-stride] draws the bunch in reverse order, thinned by the stride.
    plt_trajs = plot_trajs(ax, trajs_bunch[::-number_trajs_plot], var['field'],
                           linewidth=2, levels=rounded_levels_trajs,
                           cmap=var['cmap'], zorder=4)

    ### Terrain with colorbar ###
    levels = np.linspace(0, 4, 21)
    terrain = ax.contourf(to_np(lons), to_np(lats), to_np(terrain_height),
                          levels=levels, transform=ccrs.PlateCarree(),
                          cmap=plt.get_cmap('Greys'), alpha=0.75, zorder=1)

    cbar = mpu.colorbar(terrain, ax, orientation='horizontal',
                        aspect=40, shrink=.05, pad=0.075)
    cbar.set_label('Terrain Height [$km$]', fontsize=15)
    cbar.set_ticks(levels)
    # Make only every second color bar tick label visible.
    for label in cbar.ax.xaxis.get_ticklabels()[1::2]:
        label.set_visible(False)

    ### Colorbar for the trajectory variable ###
    cbar_trajs = mpu.colorbar(plt_trajs, ax, orientation='vertical',
                              aspect=40, shrink=.05, pad=0.05, extend=var['extend'])
    cbar_trajs.set_label(var['label'], fontsize=15)
    cbar_trajs.set_ticks(rounded_levels_trajs)

    ### Borders, start points, grid ###
    ax.add_feature(cfeature.BORDERS.with_scale('10m'), linewidth=0.8)
    ax.add_feature(cfeature.COASTLINE.with_scale('10m'), linewidth=0.8)

    # Mark the first point of every loaded trajectory with a cross
    # (a single marker-only line instead of one artist per trajectory).
    ax.plot([t['lon'][0] for t in trajs], [t['lat'][0] for t in trajs], 'kx',
            zorder=5, transform=ccrs.PlateCarree())

    lon = np.arange(0, 20, 1)
    lat = np.arange(40, 60, 1)
    ax.gridlines(xlocs=lon, ylocs=lat, zorder=2)
    mpu.yticklabels(lat, ax=ax, fontsize=12.5)
    mpu.xticklabels(lon, ax=ax, fontsize=12.5)

    ### Titles ###
    ax.set_title(var['title'], loc='left', fontsize=15)
    ax.set_title('Time Range: {} - {} UTC'.format(start_time, end_time_adj),
                 loc='right', fontsize=12.5)

    plt.show()

    ### Save figure ###
    if save:
        fig.savefig(_figure_path(save_dir, var['save_name'], subset, subset_small,
                                 trajs_bunch_level, start_locations),
                    bbox_inches='tight', dpi=300)

In [None]:
# Height of the case-study-1 trajectories between 1630 and 0800 UTC
# (last 6 records of every trajectory cut, title shows 0830), zoomed to the
# subset extent and saved to disk.
lagranto_plotting(
    traj_variable_name='height',
    case_study_name='case_study_1',
    start_time='1630',
    end_time='0800',
    end_cut=-6,
    end_time_adj='0830',
    subset=True,
    save=True,
)