In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import seaborn as sns
import mplotutils as mpu
import cartopy.crs as ccrs
from netCDF4 import Dataset
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.cm import get_cmap
import matplotlib.patheffects as path_effects
from cartopy.feature import NaturalEarthFeature
from matplotlib.colors import from_levels_and_colors, ListedColormap
from wrf import to_np, getvar, CoordPair, vertcross, latlon_coords, get_cartopy, interpline, ll_to_xy, Constants

from wrf import Constants
from wrf.extension import _tk, _eth
from wrf.util import extract_vars

%matplotlib inline

# Vertical Cross-section

In [None]:
# Date and time (used to build the wrfout filename) and the initiation
# latitude / longitude of every supported case study.
_CASE_STUDIES = {
    'case_study_1': ('2018-05-22', '17:00', 47.4961, 7.56971),
    'case_study_2': ('2018-05-09', '14:05', 47.0546, 7.63891),
    'case_study_3': ('2018-05-12', '20:15', 47.5522, 8.64449),
    'case_study_4': ('2018-05-10', '01:00', 47.629, 8.67235),
    'case_study_5': ('2018-05-30', '14:50', 46.4482, 6.95974),
}

# Colour-scale limits (min, max) in K used for theta-e in each case study.
_THETA_E_LIMITS = {
    'case_study_1': (313, 323),
    'case_study_2': (310, 320),
    'case_study_3': (315, 325),
    'case_study_4': (310, 320),
    'case_study_5': (322, 332),
}


def _variable_settings(variable_name, case_study_name):
    '''Return the plot settings of one supported variable.

    The result is the tuple ``(getvar_name, getvar_kwargs, title_name,
    colorbar_label, variable_min, variable_max)``.

    Raises ValueError if ``variable_name`` is not supported.'''

    if variable_name == 'vertical_velocity':
        return ('wa', {'units': 'kt'}, 'Vertical Velocity',
                'Vertical Velocity [$kn$]', -2, 2)

    if variable_name == 'rh':
        return ('rh', {}, 'Relative Humidity',
                'Relative Humidity [$pct$]', 0, 100)

    if variable_name == 'omega':
        # Case study 4 uses a narrower colour scale than the other cases
        limit = 5 if case_study_name == 'case_study_4' else 15
        return ('omega', {}, 'Vertical Motion (Omega)',
                'Omega [$Pa$ $s^-$$^1$]', -limit, limit)

    if variable_name == 'absolute_vorticity':
        return ('avo', {}, 'Absolute Vorticity',
                'Absolute Vorticity [$10^{-5}$' '$s^{-1}$]', -50, 100)

    if variable_name == 'theta_e':
        variable_min, variable_max = _THETA_E_LIMITS[case_study_name]
        return ('theta_e', {}, 'Theta-E', 'Theta-E [$K$]',
                variable_min, variable_max)

    if variable_name == 'reflectivity':
        return ('REFL_10CM', {}, 'Reflectivity',
                'Reflectivity [$dBZ$]', 5, 75)

    raise ValueError(
        "Unsupported variable_name {!r}; expected one of "
        "'vertical_velocity', 'rh', 'omega', 'absolute_vorticity', "
        "'theta_e' or 'reflectivity'.".format(variable_name))


def _rotated_ll_to_xy(wrf_file, latitude, longitude):
    '''Convert a lat/lon point to grid indices with a longitude correction.

    Wraps ``wrf.ll_to_xy``: when the longitude is >= -STAND_LON, the
    x index is shifted by one Earth circumference expressed in grid
    cells (2*pi*R/DX). The original author added this for points whose
    longitude exceeds STAND_LON.'''

    res = ll_to_xy(wrf_file, latitude=latitude, longitude=longitude)
    diff_pix = np.round(
        (2 * Constants.WRF_EARTH_RADIUS * Constants.PI) / wrf_file.DX, 0)
    shifted = np.where(np.ravel(longitude) >= wrf_file.STAND_LON * -1,
                       res[0] + diff_pix, res[0])
    res[0] = shifted[0]
    return res


def _vertical_cross(field, ht, ncfile, pivot_point, angle, is_dbz=False):
    '''Interpolate ``field`` onto the vertical cross section.

    Only reflectivity (``is_dbz=True``) is converted from dBZ to linear Z
    before the interpolation and back afterwards. Every other field is
    interpolated directly, because routing it through 10**(x/10) would
    bias the interpolated values towards the larger neighbour.'''

    if is_dbz:
        linear = 10**(field / 10)
        linear_cross = vertcross(linear, ht, wrfin=ncfile,
                                 pivot_point=pivot_point, angle=angle,
                                 latlon=True, meta=True)
        return 10.0 * np.log10(linear_cross)

    return vertcross(field, ht, wrfin=ncfile, pivot_point=pivot_point,
                     angle=angle, latlon=True, meta=True)


def _fill_below_terrain(cross):
    '''Return a copy of ``cross`` with the missing values at the bottom
    of each column replaced by the lowest available value.'''

    filled = np.ma.copy(to_np(cross))

    for i in range(filled.shape[-1]):
        column_vals = filled[:, i]
        # NaN and masked entries both fail this comparison
        valid_idx = (column_vals > -200).nonzero()[0]
        if valid_idx.size == 0:
            # Nothing to copy in a column without any valid value
            continue
        first_idx = int(valid_idx[0])
        filled[0:first_idx, i] = filled[first_idx, i]

    return filled


def _contour_style(variable_name, variable_min, variable_max):
    '''Return ``(levels, cmap, norm, extend)`` for the filled contours.'''

    if variable_name == 'reflectivity':
        levels = np.arange(variable_min, variable_max, 5)

        # Create the dbz color table found on NWS pages.
        dbz_rgb = np.array([[4,233,231],
                            [1,159,244], [3,0,244],
                            [2,253,2], [1,197,1],
                            [0,142,0], [253,248,2],
                            [229,188,0], [253,149,0],
                            [253,0,0], [212,0,0],
                            [188,0,0],[248,0,253],
                            [152,84,198]], np.float32) / 255.0

        dbz_cmap, dbz_norm = from_levels_and_colors(levels, dbz_rgb,
                                                   extend='max')
        return levels, dbz_cmap, dbz_norm, 'both'

    levels = np.linspace(variable_min, variable_max, 11)

    if variable_name in ('omega', 'vertical_velocity'):
        cmap = plt.get_cmap('PRGn')
    elif variable_name == 'theta_e':
        cmap = plt.get_cmap('RdYlBu_r')
    else:
        cmap = ListedColormap(sns.cubehelix_palette(20, start=.5, rot=-.75))

    # Relative humidity is bounded by its levels, so no extension arrows
    extend = 'neither' if variable_name == 'rh' else 'both'

    return levels, cmap, None, extend


def cross_section(variable_name, case_study_name, save=False,
                  show_title=False):

    '''Plot a vertical cross section of the chosen variable.

    The cross section passes through the initiation point of the case
    study with ``angle=0`` and is limited to +-0.5 degrees latitude
    around that point and to the lowest 12 km. For vertical_velocity
    and omega, wind barbs built from the u and w wind components are
    drawn on top of the filled contours.

    Parameters
    ----------
    variable_name : str
        One of vertical_velocity, rh, omega, absolute_vorticity,
        theta_e or reflectivity.
    case_study_name : str
        One of case_study_1 ... case_study_5.
    save : bool, optional
        If true, the figure is written as a PNG into the
        cross_sections directory of the case study.
    show_title : bool, optional
        If true, add the variable name and the valid time of the file
        as titles. Off by default, as in the original figures.

    Raises
    ------
    ValueError
        If ``variable_name`` or ``case_study_name`` is not supported.'''

    import os

    ### Predefine some variables ###
    if case_study_name not in _CASE_STUDIES:
        raise ValueError(
            'Unsupported case_study_name {!r}; expected one of {}.'.format(
                case_study_name, ', '.join(sorted(_CASE_STUDIES))))

    date, time, initiation_lat, initiation_lon = _CASE_STUDIES[case_study_name]

    (getvar_name, getvar_kwargs, title_name, colorbar_label,
     variable_min, variable_max) = _variable_settings(variable_name,
                                                      case_study_name)

    # Define data filename
    data_dir = '/mnt/climstor/wrfout2/WRF4.0_test51/'

    filename = '{}wrfout_d02_{}_{}:00'.format(data_dir, date, time)

    # Define save directory
    save_dir = '/scratch3/thomasl/work/retrospective_part/' \
               '{}/cross_sections/'.format(case_study_name)

    # Angle handed to vertcross / interpline for the line through the pivot
    cross_angle = 0

    is_dbz = variable_name == 'reflectivity'

    ### Read and interpolate the data ###
    # The file is only needed in this section and is closed afterwards
    with Dataset(filename) as ncfile:
        # Create pivot point for cross section
        xy = _rotated_ll_to_xy(ncfile, initiation_lat, initiation_lon)
        pivot_point = CoordPair(x=xy[0], y=xy[1])

        # Extract the model height, terrain height and variable
        ht = getvar(ncfile, 'z')/1000 # change to km
        ter = getvar(ncfile, 'ter')/1000
        variable = getvar(ncfile, getvar_name, **getvar_kwargs)

        # Compute the vertical cross-section interpolation
        variable_cross = _vertical_cross(variable, ht, ncfile, pivot_point,
                                         cross_angle, is_dbz=is_dbz)
        variable_cross_filled = _fill_below_terrain(variable_cross)

        ter_line = interpline(ter, wrfin=ncfile, pivot_point=pivot_point,
                              angle=cross_angle)

        # Wind components for the barbs (only for the two motion variables)
        u_cross_filled = None
        w_cross_filled = None
        if variable_name in ('vertical_velocity', 'omega'):
            u = getvar(ncfile, 'ua', units='kt')
            u_cross_filled = _fill_below_terrain(
                _vertical_cross(u, ht, ncfile, pivot_point, cross_angle))

            if variable_name == 'vertical_velocity':
                # The plotted variable already is the w wind in knots
                w_cross_filled = variable_cross_filled
            else:
                w = getvar(ncfile, 'wa', units='kt')
                w_cross_filled = _fill_below_terrain(
                    _vertical_cross(w, ht, ncfile, pivot_point, cross_angle))

        # Grid indices 0.5 degrees south and north of the initiation point
        x_min = _rotated_ll_to_xy(ncfile, initiation_lat-0.5, initiation_lon)
        x_max = _rotated_ll_to_xy(ncfile, initiation_lat+0.5, initiation_lon)
        x_limits = (int(x_min[1]), int(x_max[1]))

    ### Start plotting procedure ###
    # Create the figure
    fig = plt.figure(figsize=(15,10))
    ax = plt.axes()

    ys = to_np(variable_cross.coords['vertical'])
    xs = np.arange(0, variable_cross.shape[-1], 1)

    # Make contour plot
    levels, cmap, norm, extend = _contour_style(variable_name, variable_min,
                                                variable_max)
    variable_contours = ax.contourf(xs, ys,
                                    to_np(variable_cross_filled),
                                    levels=levels, cmap=cmap,
                                    norm=norm, extend=extend)

    # Plot wind barbs (every third point in both directions)
    if u_cross_filled is not None:
        ax.barbs(xs[::3], ys[::3], to_np(u_cross_filled[::3, ::3]),
                 to_np(w_cross_filled[::3, ::3]), length=7, zorder=1,
                 color='k')

    # Add color bar
    cbar = mpu.colorbar(variable_contours, ax, orientation='vertical',
                        aspect=20, shrink=.05, pad=0.05)
    cbar.set_label(colorbar_label, fontsize=20)
    cbar.set_ticks(levels)
    cbar.ax.tick_params(labelsize=15)

    # Set x-ticks to use latitude and longitude labels
    coord_pairs = to_np(variable_cross.coords['xy_loc'])
    x_ticks = np.arange(coord_pairs.shape[0])
    x_labels = [pair.latlon_str(fmt='{:.2f}, {:.2f}')
                for pair in coord_pairs]

    # Set desired number of x ticks below
    num_ticks = 30
    # At least 1, otherwise short cross sections give a zero slice step
    thin = max(1, int((len(x_ticks) / num_ticks) + .5))
    ax.set_xticks(x_ticks[::thin])
    ax.set_xticklabels(x_labels[::thin], rotation=-15, fontsize=15, ha='left')
    ax.set_xlim(*x_limits)

    # Set y-ticks and limit the height
    ax.set_yticks(np.linspace(0,12,13))
    ax.set_ylim(0, 12)

    # Set x-axis and y-axis labels
    ax.set_xlabel('Latitude, Longitude', fontsize=20)
    ax.set_ylabel('Height [$km$]', fontsize=20)

    # Fill in mountain area
    ax.fill_between(xs, 0, to_np(ter_line),
                    facecolor='saddlebrown', zorder=2)

    # Mark both ends of the cross section (A, B) and its centre (I)
    outline = [path_effects.Stroke(linewidth=3, foreground='white'),
               path_effects.Normal()]
    ax.text(0.01, 0.001, 'A', ha='left', va='bottom', transform=ax.transAxes,
            fontsize=30, weight='bold', zorder=5, path_effects=outline)
    ax.text(0.99, 0.001, 'B', ha='right', va='bottom', transform=ax.transAxes,
            fontsize=30, weight='bold', zorder=5, path_effects=outline)

    ax.text(0.5, -0.0265, 'I', ha='right', va='bottom', transform=ax.transAxes,
            fontsize=15, zorder=5, weight='bold', c='r')

    # Add title
    if show_title:
        # Read the valid time of the file for the right-hand title
        with xr.open_dataset(filename) as xr_file:
            nicetime = pd.to_datetime(
                xr_file.QVAPOR.isel(Time=0).XTIME.values)

        ax.set_title('Vertical Cross-Section of {}'.format(title_name),
                     fontsize=20, loc='left')
        ax.set_title('Valid Time: {} {}'.format(
                         nicetime.strftime('%Y-%m-%d %H:%M'), 'UTC'),
                     fontsize=15, loc='right')

    # Add grid for y axis
    ax.grid(axis='y', linestyle='--', color='grey')

    ax.tick_params(axis='both', which='major', labelsize=17.5, length=8, width=1.5)
    ax.tick_params(axis='both', which='minor', length=4, width=1)

    ### Save figure ###
    # Save before showing, so the file does not depend on how the active
    # backend treats the figure after plt.show()
    if save:
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig('{}cross_section_{}_{}_2.png'.format(save_dir, variable_name, time),
                    bbox_inches='tight', dpi=300)

    plt.show()

In [None]:
# Plot the theta-e cross section of case study 1 and save the figure as a PNG
cross_section('theta_e', 'case_study_1', save=True)