In [None]:
import os
import numpy as np
import xarray as xr
import pandas as pd
import seaborn as sns
import mplotutils as mpu
from netCDF4 import Dataset
import matplotlib.pyplot as plt
import matplotlib.ticker as tkr
import matplotlib.colors as colors
from wrf import getvar, ll_to_xy, Constants
from matplotlib.colors import ListedColormap

import sys

# Temporal Evolution of Variable

In [None]:
# Collect the full path of every domain-2 WRF output file for the chosen date.
# The directory tree below `data_directory` is searched recursively.
data_directory = '/mnt/climstor/wrfout2/WRF4.0_test51/'

date = '2018-05-22'

# File names of interest start with this prefix (str.startswith also accepts a
# tuple of prefixes if more than one date has to be collected).
wrf_prefix = 'wrfout_d02_{}'.format(date)

wrflist = [
    os.path.join(root, name)
    for root, _subdirs, names in os.walk(data_directory)
    for name in names
    if name.startswith(wrf_prefix)
]

In [None]:
# Offsets (d_lat, d_lon) in degrees of the nine sampling points: the target
# location first, followed by its eight neighbours on a +/-0.05 degree box.
_SAMPLE_OFFSETS = (
    (0.0, 0.0),
    (-0.05, -0.05),
    (0.0, -0.05),
    (0.05, -0.05),
    (-0.05, 0.0),
    (0.05, 0.0),
    (-0.05, 0.05),
    (0.0, 0.05),
    (0.05, 0.05),
)

# Settings per case study. 'time' is the latest time step that is read; the
# x-axis labels below treat it as 0.5 h after initiation.
_CASE_STUDIES = {
    'case_study_1': {'date': '2018-05-22', 'time': '17:30',
                     'latitude': 47.4961, 'longitude': 7.56971},
    'case_study_2': {'date': '2018-05-09', 'time': '14:35',
                     'latitude': 47.0546, 'longitude': 7.63891},
    'case_study_3': {'date': '2018-05-12', 'time': '20:45',
                     'latitude': 47.5522, 'longitude': 8.64449},
    'case_study_4': {'date': '2018-05-10', 'time': '01:30',
                     'latitude': 47.629, 'longitude': 8.67235},
    'case_study_5': {'date': '2018-05-30', 'time': '15:20',
                     'latitude': 46.4482, 'longitude': 6.95974},
}

# Colour range (min, max) used for theta_e, per case study.
_THETA_E_RANGES = {
    'case_study_1': (313, 323),
    'case_study_2': (310, 320),
    'case_study_3': (315, 325),
    'case_study_4': (310, 320),
    'case_study_5': (322, 332),
}


def _variable_settings(variable_name, case_study_name):
    '''Return (colorbar_label, variable_min, variable_max) for a supported variable.'''
    if variable_name == 'wind_speed':
        return 'Wind Speed [$m$ $s^{-1}$]', 0, 30
    if variable_name == 'rh':
        return 'Relative Humidity [$pct$]', 0, 100
    if variable_name == 'omega':
        limit = 5 if case_study_name == 'case_study_4' else 15
        return 'Omega [$Pa$ $s^-$$^1$]', -limit, limit
    if variable_name == 'absolute_vorticity':
        return 'Absolute Vorticity [$10^{-5}$$s^{-1}$]', -50, 100
    if variable_name == 'theta_e':
        variable_min, variable_max = _THETA_E_RANGES[case_study_name]
        return 'Theta-E [$K$]', variable_min, variable_max
    raise ValueError(
        'Unsupported variable_name {!r}; expected one of wind_speed, rh, '
        'omega, absolute_vorticity, theta_e'.format(variable_name))


def _rotated_ll_to_xy(wrfFile, latitude, longitude):
    '''Convert latitude/longitude to grid indices, shifting x for points whose
    longitude is >= -STAND_LON by the number of grid cells around the globe.'''
    res = ll_to_xy(wrfFile, latitude=latitude, longitude=longitude)
    # Number of grid cells of width DX along the Earth's circumference.
    diffPix = np.round((2*Constants.WRF_EARTH_RADIUS*Constants.PI)/wrfFile[0].DX, 0)
    # NOTE(review): the comparison uses the negated STAND_LON - confirm intent.
    xs = np.where(np.ravel(longitude) >= wrfFile[0].STAND_LON*-1, res[0]+diffPix, res[0])
    res[0] = xs[0]
    return res


def temporal_evolution_variable(variable_name, periods, case_study_name,
                                save=False):
    '''Plot a time-height section of a variable averaged over nine grid points.

    For every 5-minute time step of the selected case study the vertical
    profile of the variable is read at the case-study location and at its
    eight neighbours (+/-0.05 degrees in latitude and longitude). The nine
    profiles are averaged and drawn as filled contours, with the x-axis
    labelled in hours before initiation.

    Parameters
    ----------
    variable_name : str
        One of 'wind_speed', 'rh', 'omega', 'absolute_vorticity', 'theta_e'.
    periods : int
        Number of 5-minute steps to go back from the case-study end time
        (periods + 1 time steps are read).
    case_study_name : str
        'case_study_1' ... 'case_study_5'; selects date, time and location.
    save : bool, optional
        If true, the figure is also written as a PNG file.

    Raises
    ------
    ValueError
        If the case study or the variable name is not supported.
    FileNotFoundError
        If the module-level `wrflist` holds no file for a required time step
        (it has to be built for the date of the chosen case study).
    '''
    try:
        case = _CASE_STUDIES[case_study_name]
    except KeyError:
        raise ValueError(
            'Unknown case_study_name {!r}; expected one of {}'.format(
                case_study_name, sorted(_CASE_STUDIES)))
    date = case['date']
    time = case['time']
    latitude = case['latitude']
    longitude = case['longitude']

    colorbar_label, variable_min, variable_max = _variable_settings(
        variable_name, case_study_name)

    # wrf-python knows absolute vorticity under the name 'avo'; the short
    # name is also used in the output file name.
    if variable_name == 'absolute_vorticity':
        variable_name = 'avo'

    ### Plotting Iteration ###
    print('------ Start iteration process ------')
    # Define frequency for time range
    frequency = '5min'
    step_minutes = 5  # has to match `frequency`

    # Make date range
    initiation_time = date + ' ' + time + ':00'
    date_range = pd.date_range(end=initiation_time, periods=periods+1, freq=frequency)
    date_range_str = date_range.strftime('%Y-%m-%d_%H:%M:%S')

    # One list of vertical profiles per sampling point, filled time step by time step.
    profiles = [[] for _ in _SAMPLE_OFFSETS]
    height_point = None

    # Walk backwards in time: index 0 on the time axis is the latest time step.
    for timeSteps in date_range_str[::-1]:
        print('Time step: {}'.format(timeSteps))

        paths = [x for x in wrflist if x.endswith(timeSteps)]
        if not paths:
            raise FileNotFoundError(
                'No WRF output file for time step {} in wrflist; make sure '
                'wrflist was built for {}'.format(timeSteps, date))

        ncfile = [Dataset(x) for x in paths]
        try:
            if variable_name == 'wind_speed':
                # 'wspd_wdir' returns speed and direction; index 0 is the speed.
                variable = getvar(ncfile, 'wspd_wdir')[0, :]
            else:
                variable = getvar(ncfile, variable_name)

            locations = [
                _rotated_ll_to_xy(ncfile, latitude + d_lat, longitude + d_lon)
                for d_lat, d_lon in _SAMPLE_OFFSETS
            ]

            with xr.open_dataset(paths[0]) as ds:
                ds[variable_name] = (['bottom_top', 'south_north', 'west_east'],
                                     variable)
                for location, series in zip(locations, profiles):
                    ds_point = ds.sel(west_east=location[0], south_north=location[1])
                    series.append(ds_point[variable_name])

            # The height profile of the last processed (earliest) time step is
            # the one that ends up labelling the y-axis.
            height = getvar(ncfile, 'z')
            height_point = height.sel(west_east=locations[0][0],
                                      south_north=locations[0][1])
        finally:
            for nc in ncfile:
                nc.close()

    # Concatenate each point's profiles along time and average the nine points.
    # (Each point uses its own list; previously the central point was
    # concatenated nine times, so the neighbours never entered the mean.)
    series_over_time = [xr.concat(series, dim='time') for series in profiles]
    variables_point_sum = series_over_time[0]
    for neighbour in series_over_time[1:]:
        variables_point_sum = variables_point_sum + neighbour
    variables_point_sum = variables_point_sum / len(_SAMPLE_OFFSETS)

    variables_point_sum_ds = xr.Dataset(data_vars={variable_name: variables_point_sum,
                                                   'height': height_point})

    print('------ Start plotting process ------')
    fig = plt.figure(figsize=(16, 9))
    ax = plt.axes()

    n_times = len(variables_point_sum_ds['time'].values)
    n_levels = len(variables_point_sum_ds['height'].values)
    x, y = np.mgrid[:n_times, :n_levels]

    if variable_name == 'omega':
        cmap = plt.get_cmap('PRGn')
    elif variable_name == 'theta_e':
        cmap = plt.get_cmap('RdYlBu_r')
    else:
        cmap = ListedColormap(sns.cubehelix_palette(10, start=.5, rot=-.75))

    levels = np.linspace(variable_min, variable_max, 11)

    if variable_name == 'rh':
        extend = 'neither'
    elif variable_name == 'wind_speed':
        extend = 'max'
    else:
        extend = 'both'

    variable_plot = ax.contourf(x, y, variables_point_sum_ds[variable_name],
                                levels=levels, cmap=cmap, extend=extend)

    cbar = fig.colorbar(variable_plot, ax=ax, shrink=.95, aspect=20)
    cbar.set_label(colorbar_label, fontsize=20)
    cbar.set_ticks(levels)
    cbar.ax.tick_params(labelsize=15)

    ax.set_ylabel('Height [$km$]', fontsize=20)

    # One x tick every 30 minutes. Index 0 is 0.5 h after initiation, so the
    # labels are the elapsed hours shifted by -0.5.
    steps_per_tick = 30 // step_minutes
    tick_positions = np.arange(0, periods + 1, steps_per_tick)
    hours_before = tick_positions * step_minutes / 60 - 0.5
    labels = [str(hours) for hours in hours_before]
    labels[0] = '+0.5'
    if len(labels) > 1:
        labels[1] = '0'
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels)

    if variable_name == 'rh':
        vline_color = 'w'
    else:
        vline_color = 'k'

    ax.axvline(x=0, color=vline_color, linestyle='--')
    ax.axvline(x=6, color=vline_color, linestyle='--')
    ax.axvline(x=12, color=vline_color, linestyle='--')

    ax.minorticks_on()

    ax.set_xlabel('Time Before Initiation [$h$]', fontsize=20)

    # Index 0 is the latest time step; invert so that time runs left to right.
    ax.invert_xaxis()

    # Set y-ticks (every second model level, labelled in km) and limit the height
    ytick_labels = [ round(label/1000, 1) for label in variables_point_sum_ds['height'].values ]
    ax.set_yticks(np.arange(0, n_levels, 2))
    ax.set_yticklabels(ytick_labels[::2])
    ax.set_ylim(0.5, 25)

    ax.grid(axis='y', linestyle='--', color='grey')

    ax.tick_params(axis='both', which='major', labelsize=17.5, length=8, width=1.5)
    ax.tick_params(axis='both', which='minor', length=4, width=1)

    # Save before showing so the written file never depends on how the
    # backend treats the figure after plt.show().
    if save:
        fig.savefig('/scratch3/thomasl/work/retrospective_part/'
                    '{}/lidar/lidar_{}_{}.png'.format(case_study_name, variable_name, periods),
                    bbox_inches='tight', dpi=300)

    plt.show()

In [None]:
# Relative humidity for case study 1 over 102 five-minute steps; the figure is
# also written to disk.
temporal_evolution_variable(
    variable_name='rh',
    periods=102,
    case_study_name='case_study_1',
    save=True,
)