## Hurricane Florence PyDDA retrieval

In [1]:
import warnings
warnings.filterwarnings("ignore")
import glob
import pyart
import pydda
import datetime
import cartopy.crs as ccrs
import os
import dask.bag as db
import gc
from boto.s3.connection import S3Connection

from scipy.interpolate import NearestNDInterpolator, LinearNDInterpolator
%pylab inline


## You are using the Python ARM Radar Toolkit (Py-ART), an open source
## library for working with weather radar data. Py-ART is partly
## supported by the U.S. Department of Energy as part of the Atmospheric
## Radiation Measurement (ARM) Climate Research Facility, an Office of
## Science user facility.
##
## If you use this software to prepare a publication, please cite:
##
##     JJ Helmus and SM Collis, JORS 2016, doi: 10.5334/jors.119

Populating the interactive namespace from numpy and matplotlib


In [None]:
# Level II volumes for the two radars used in the dual-Doppler retrieval.
florence_path = '/lcrc/group/earthscience/radar/florance/'
ltx_list = sorted(glob.glob(florence_path + '/**/KLTX*V06.ar2v', recursive=True))
mhx_list = sorted(glob.glob(florence_path + '/**/KMHX*V06.ar2v', recursive=True))

# Volumes from the surrounding radars (these files carry no '.ar2v' extension).
other_rads_path = '/lcrc/group/earthscience/rjackson/florence/'
# The KCAE list must match KCAE files; matching 'KLTX*' here would just
# duplicate the LTX volumes under the wrong name.
cae_list = sorted(glob.glob(other_rads_path + '/**/KCAE*V06', recursive=True))
clx_list = sorted(glob.glob(other_rads_path + '/**/KCLX*V06', recursive=True))
fcx_list = sorted(glob.glob(other_rads_path + '/**/KFCX*V06', recursive=True))
gsp_list = sorted(glob.glob(other_rads_path + '/**/KGSP*V06', recursive=True))
rax_list = sorted(glob.glob(other_rads_path + '/**/KRAX*V06', recursive=True))

print(len(gsp_list))

In [None]:
print(mhx_list[0])

def parse_dt(file_path):
    """Return the scan time encoded in a file name ending in '_V06.ar2v'.

    The 15 characters that precede the 9-character tail of ``file_path`` are
    parsed as ``YYYYmmdd_HHMMSS``.
    """
    stamp = file_path[-24:-9]
    return datetime.datetime.strptime(stamp, '%Y%m%d_%H%M%S')

def parse_dt_no_ext(file_path):
    """Return the scan time encoded in a file name ending in '_V06'.

    The 15 characters that precede the 4-character tail of ``file_path`` are
    parsed as ``YYYYmmdd_HHMMSS``.
    """
    stamp = file_path[-19:-4]
    return datetime.datetime.strptime(stamp, '%Y%m%d_%H%M%S')
# One scan time per file, in the same order as the corresponding file list.
# NOTE(review): mhx_times is built from ltx_list and ltx_times from mhx_list.
# The cells that use these arrays index ltx_list with indices taken from
# mhx_times (and mhx_list with indices from ltx_times), so the lookups are
# self-consistent, but the "mhx"/"ltx" labels look swapped -- confirm before
# renaming anything.
mhx_times = np.array(list(map(parse_dt, ltx_list)))
ltx_times = np.array(list(map(parse_dt, mhx_list)))
cae_times = np.array(list(map(parse_dt_no_ext, cae_list)))
clx_times = np.array(list(map(parse_dt_no_ext, clx_list)))
fcx_times = np.array(list(map(parse_dt_no_ext, fcx_list)))
gsp_times = np.array(list(map(parse_dt_no_ext, gsp_list)))
rax_times = np.array(list(map(parse_dt_no_ext, rax_list)))

In [None]:
# Analysis time for the example retrieval.
the_time = datetime.datetime(2018, 9, 14, 6, 50)

# Index of the volume closest in time for each radar.
the_ind_mhx = np.abs(mhx_times - the_time).argmin()
the_ind_ltx = np.abs(ltx_times - the_time).argmin()

# NOTE(review): the object called mhx_radar is read from ltx_list and
# ltx_radar from mhx_list; this matches how mhx_times/ltx_times were built,
# so the indices are valid, but the variable names appear swapped -- confirm.
mhx_radar = pyart.io.read(ltx_list[the_ind_mhx])
ltx_radar = pyart.io.read(mhx_list[the_ind_ltx])

# Gate filters excluding gates below the given thresholds in each field.
gf_mhx = pyart.filters.GateFilter(mhx_radar)
gf_ltx = pyart.filters.GateFilter(ltx_radar)
for gate_filter in (gf_mhx, gf_ltx):
    gate_filter.exclude_below('cross_correlation_ratio', 0.5)
    gate_filter.exclude_below('reflectivity', -20)

In [None]:
display_mhgx = pyart.graph.RadarMapDisplay(mhx_radar)
display_mhgx.plot_ppi_map('reflectivity', resolution='l', gatefilter=gf_mhx)

In [None]:
display_ltx = pyart.graph.RadarMapDisplay(ltx_radar)
display_ltx.plot_ppi_map('reflectivity', resolution='l', gatefilter=gf_ltx)

In [None]:
plt.figure(figsize=(9,9))
display_mhgx.plot_ppi_map('velocity', sweep=1, resolution='l')

In [None]:
plt.figure(figsize=(9,9))
display_ltx = pyart.graph.RadarMapDisplay(ltx_radar)
display_ltx.plot_ppi_map('velocity', sweep=1, resolution='l')

In [None]:
plt.figure(figsize=(9,9))
dealiased_vel_mhx = pyart.correct.dealias_region_based(mhx_radar)
mhx_radar.add_field('corrected_velocity', dealiased_vel_mhx, replace_existing=True)
display_mhgx.plot_ppi_map('corrected_velocity', sweep=1, resolution='l')

In [None]:
plt.figure(figsize=(9,9))
dealiased_vel_ltx = pyart.correct.dealias_region_based(ltx_radar)
ltx_radar.add_field('corrected_velocity', dealiased_vel_ltx, replace_existing=True)
display_ltx.plot_ppi_map('corrected_velocity', sweep=1, resolution='l', vmin=-100, vmax=100)

In [None]:
mhx_radar.fields.keys()

In [None]:
# Map both radars onto the same Cartesian grid. Both calls use the position
# of the radar object named mhx_radar as the grid origin so that the two
# grids line up point for point.
# The grid is 31 x 351 x 401 points spanning 0-15 km, -100..200 km and
# -150..300 km (Py-ART orders these as z, y, x -- TODO confirm).
grid_mhx, grid_ltx = [
    pyart.map.grid_from_radars(
        radar, (31, 351, 401),
        ((0., 15000.), (-100000., 200000.), (-150000., 300000.)),
        fields=['reflectivity', 'corrected_velocity'],
        refl_field='reflectivity', roi_func='dist_beam',
        h_factor=0., nb=0.6, bsp=1., min_radius=200.,
        grid_origin=(mhx_radar.latitude['data'], mhx_radar.longitude['data']))
    for radar in (mhx_radar, ltx_radar)
]

# Cache the grids on disk so later cells can reload them without regridding.
pyart.io.write_grid('grid_mhx.nc', grid_mhx)
pyart.io.write_grid('grid_ltx.nc', grid_ltx)

In [None]:
grid_mhx = pyart.io.read_grid('grid_mhx.nc')
grid_ltx = pyart.io.read_grid('grid_ltx.nc')

In [None]:
# Vertical cross-section of the dealiased velocity along a fixed longitude.
plt.figure(figsize=(10,4))
grid_disp_mhx = pyart.graph.GridMapDisplay(grid_mhx)
# -78 is a longitude (degrees east), so it has to be passed as ``lon``.
# Passing it as ``lat`` does not select the slice location for a longitude
# slice, and -78 is not a latitude anywhere near this grid.
grid_disp_mhx.plot_longitude_slice('corrected_velocity', lon=-78)

In [None]:
grid_disp_mhx.plot_grid('corrected_velocity', level=5)

In [None]:
u_init, v_init, w_init = pydda.initialization.make_constant_wind_field(grid_mhx, (0.0, 0.0, 0.0))
out_grids = pydda.retrieval.get_dd_wind_field([grid_mhx, grid_ltx], u_init, v_init, w_init, Co=1.0, Cm=100.0,
                                             mask_outside_opt=True, vel_name='corrected_velocity')

In [None]:
pyart.io.write_grid('grid0.nc', out_grids[0])
pyart.io.write_grid('grid1.nc', out_grids[1])

In [None]:
out_grids[0].projection_proj

In [None]:
out_grids = [pyart.io.read_grid('grid0.nc'),
             pyart.io.read_grid('grid1.nc')]

In [None]:
fig = plt.figure(figsize=(15, 10))
ax = plt.axes(projection=ccrs.PlateCarree())

# Hide retrieved winds wherever either radar has no dealiased velocity.
no_velocity = np.logical_or(
    out_grids[0].fields['corrected_velocity']['data'].mask,
    out_grids[1].fields['corrected_velocity']['data'].mask)
for wind_component in ('u', 'v'):
    out_grids[1].fields[wind_component]['data'] = np.ma.masked_where(
        no_velocity, out_grids[1].fields[wind_component]['data'])

# Wind barbs every 20 km at grid level 3, drawn over the last grid in the list.
ax = pydda.vis.plot_horiz_xsection_barbs_map(
    out_grids, ax=ax, bg_grid_no=-1, level=3,
    barb_spacing_x_km=20.0, barb_spacing_y_km=20.0)

plt.title(out_grids[0].time['units'][13:] + ' winds at 1.5 km')

In [None]:
grid_array = np.ma.stack([x.fields['reflectivity']['data'] for x in out_grids])
plt.imshow(grid_array.max(axis=0)[3])

### Load historical HRRR data

In [None]:
import cfgrib

In [None]:
hrrr_data_path = '/lcrc/group/earthscience/rjackson/florence_hrrr/20180914/hrrr.t06z.wrfprsf00.grib2'
the_grib = cfgrib.Dataset.from_path(hrrr_data_path, filter_by_keys={'typeOfLevel': 'isobaricInhPa'})

In [None]:
the_grib.variables.keys()

In [None]:
grb_u = the_grib.variables['u']
grb_v = the_grib.variables['v']
gh = the_grib.variables['gh']

lat = the_grib.variables['latitude'].data[:,:]
lon = the_grib.variables['longitude'].data[:,:]
lon[lon > 180] = lon[lon>180]-360
print(lon.shape)

In [None]:
# Earth radius used in the height conversion below (value presumably in metres).
EARTH_MEAN_RADIUS = 6.3781e6
# Replace the cfgrib variable by its full data array. Because ``gh`` is rebound
# here, re-running this cell would apply ``.data`` to the already-extracted array.
gh = gh.data[:,:,:]
# height = R*gh / (R - gh); presumably converts geopotential height to
# geometric height -- TODO confirm.
height = (EARTH_MEAN_RADIUS*gh)/(EARTH_MEAN_RADIUS-gh)

In [None]:
the_grib.variables.keys()

In [None]:
# Quick-look map of one HRRR u-wind slice over the area of interest.
from scipy.interpolate import griddata
ax = plt.axes(projection=ccrs.PlateCarree())
u = grb_u.data[1,:,:]
# Slice at index 1 along the first axis of the u variable.
v = grb_v.data[0,:,:]
# NOTE(review): v is sliced at index 0 while u uses index 1, and neither v nor
# r is used below -- confirm whether they are still needed.

r = the_grib.variables['gh']
u = u[:,:]
p = ax.pcolormesh(lon, lat, u, transform=ccrs.PlateCarree())
ax.coastlines(resolution='10m')
# Restrict the view to 80W-75W, 33N-36N.
ax.set_xlim([-80, -75])
ax.set_ylim([33, 36])
plt.colorbar(p)

## Grid the HRRR data onto the analysis grid

In [None]:
# We do not need the entire box, just the radar domain
radar_grid_lat = out_grids[0].point_latitude['data']
radar_grid_lon = out_grids[0].point_longitude['data']
radar_grid_alt = out_grids[0].point_z['data']
lat_min = radar_grid_lat.min()
lat_max = radar_grid_lat.max()
lon_min = radar_grid_lon.min()
lon_max = radar_grid_lon.max()

# Repeat the 2-D model lat/lon for every model level so that they line up
# element for element with the flattened 3-D height and wind arrays.
lon_r = np.tile(lon, (height.shape[0],1,1))
lat_r = np.tile(lat, (height.shape[0],1,1))
lon_flattened = lon_r.flatten()
lat_flattened = lat_r.flatten()
# Interpolate against the converted ``height`` array rather than the raw
# ``gh`` values: the target coordinate (radar_grid_alt) is a height, and
# ``height`` was computed from ``gh`` for exactly this purpose.
height_flattened = height.flatten()

# Keep only the model points inside the lat/lon bounding box of the radar grid.
the_box = np.where(np.logical_and.reduce((lon_flattened >= lon_min, lat_flattened >= lat_min,
                                          lon_flattened <= lon_max, lat_flattened <= lat_max)))[0]

lon_flattened = lon_flattened[the_box]
lat_flattened = lat_flattened[the_box]
height_flattened = height_flattened[the_box]

u_flattened = grb_u.data[:,:,:].flatten()
u_flattened = u_flattened[the_box]
# rescale=True normalises the three coordinates before the nearest-neighbour
# search, so metres and degrees are weighted comparably.
u_interp = NearestNDInterpolator((height_flattened, lat_flattened, lon_flattened), u_flattened, rescale=True)
u_new = u_interp(radar_grid_alt, radar_grid_lat, radar_grid_lon)

In [None]:
print(u_new.shape)
u_dict = {'data': u_new, 'long_name': "U from HRRR", 'units': "m/s"}
out_grids[0].add_field("U_hrrr", u_dict, replace_existing=True)

In [None]:
disp = pyart.graph.GridMapDisplay(out_grids[0])
disp.plot_grid('U_hrrr', level=2)

In [None]:
plt.pcolormesh(u_new[2])

In [None]:
np.where(np.logical_and.reduce((lon_r.flatten() >= lon_min,
                                          lon_r.flatten() <= lon_max)))[0]


In [None]:
print(lon_max)

In [None]:
out_grids

## Now do the retrieval with HRRR data, dude!

In [None]:
grid_mhx = pyart.io.read_grid('grid_mhx.nc')
grid_ltx = pyart.io.read_grid('grid_ltx.nc')
grid_mhx = pydda.initialization.add_hrrr_constraint_to_grid(grid_mhx,
    '/lcrc/group/earthscience/rjackson/florence_hrrr/20180914/hrrr.t06z.wrfprsf00.grib2')
disp = pyart.graph.GridMapDisplay(grid_mhx)
disp.plot_grid('U_hrrr', level=2)

In [None]:
u_init, v_init, w_init = pydda.initialization.make_constant_wind_field(grid_mhx, (0.0, 0.0, 0.0))
out_grids = pydda.retrieval.get_dd_wind_field([grid_mhx, grid_ltx], u_init, v_init, w_init, Co=10.0, Cm=50.0,
                                              Cmod=0.0, mask_outside_opt=True, vel_name='corrected_velocity',
                                               
                                              )

In [None]:
# Wind barbs every 20 km at grid level 1, drawn over the last grid in the list.
fig = plt.figure(figsize=(15, 10))
ax = plt.axes(projection=ccrs.PlateCarree())
ax = pydda.vis.plot_horiz_xsection_barbs_map(
    out_grids,
    ax=ax,
    bg_grid_no=-1,
    level=1,
    barb_spacing_x_km=20.0,
    barb_spacing_y_km=20.0,
)

plt.title(out_grids[0].time['units'][13:] + ' winds at 0.5 km')

In [None]:
max_w = out_grids[1].fields['u']['data'][2]
plt.contourf(max_w)
plt.colorbar()

## Hey there, it's time for Dask-jobqueue!

In [None]:
import dask_jobqueue
import dask.bag as db

In [None]:
out_img_path = '/lcrc/group/earthscience/rjackson/florence_winds/png/'
out_grid_path = '/lcrc/group/earthscience/rjackson/florence_winds/grids/'

def _run_florence_retrieval(the_time, ltx_list, mhx_list, do_hrrr,
                            grid_suffix, img_suffix, Co, Cm, Cmod):
    """Run one dual-Doppler retrieval and save its grids and quick-look plot.

    Shared implementation behind ``make_retrieved_grid`` and
    ``make_retrieved_grid_only_hrrr``. Reads the module-level ``mhx_times``,
    ``ltx_times``, ``out_grid_path`` and ``out_img_path``.

    Parameters
    ----------
    the_time : datetime.datetime
        Requested analysis time; the closest volume of each radar is used.
    ltx_list, mhx_list : list of str
        Radar file lists, index-aligned with ``mhx_times`` and ``ltx_times``
        respectively (see the NOTE in the body).
    do_hrrr : bool
        If true, add the HRRR fields of the nearest hour to the first grid
        and pass ``model_fields=["hrrr"]`` to the retrieval.
    grid_suffix, img_suffix : str
        Text appended after the ``YYYYmmdd.HHMM`` stamp of the output names.
    Co, Cm, Cmod : float
        Weights forwarded to ``pydda.retrieval.get_dd_wind_field``.

    Returns
    -------
    None
        Returns early (writing nothing) when both output grids already
        exist, when the two closest volumes are more than 5 minutes apart,
        when a file cannot be read, or when dealiasing raises ``KeyError``.
    """
    day_tag = the_time.strftime('%Y%m%d')
    stamp = day_tag + '.' + the_time.strftime('%H%M')
    out_grid_dir = out_grid_path + '/' + day_tag + '/'
    out_img_dir = out_img_path + '/' + day_tag + '/'
    # exist_ok avoids a FileExistsError when several workers try to create
    # the same day directory at the same moment.
    os.makedirs(out_img_dir, exist_ok=True)
    os.makedirs(out_grid_dir, exist_ok=True)

    out_grid_mhx_file_path = out_grid_dir + '05kmwinds_gridmhx' + stamp + grid_suffix
    out_grid_ltx_file_path = out_grid_dir + '05kmwinds_gridltx' + stamp + grid_suffix
    out_img_file_path = out_img_dir + '05kmwinds' + stamp + img_suffix

    # Nothing to do if this time step has already been processed.
    if os.path.isfile(out_grid_mhx_file_path) and os.path.isfile(out_grid_ltx_file_path):
        return

    print("## Loading data...")
    the_ind_mhx = np.argmin(np.abs(mhx_times - the_time))
    the_ind_ltx = np.argmin(np.abs(ltx_times - the_time))
    if np.abs(mhx_times[the_ind_mhx] - ltx_times[the_ind_ltx]) > datetime.timedelta(minutes=5):
        print("No simultaneous coverage!")
        return
    try:
        # NOTE(review): mhx_radar is read from ltx_list and ltx_radar from
        # mhx_list. This matches how the module-level mhx_times/ltx_times are
        # built, so the indices are valid, but the names look swapped -- confirm.
        mhx_radar = pyart.io.read(ltx_list[the_ind_mhx])
        ltx_radar = pyart.io.read(mhx_list[the_ind_ltx])
    except Exception as err:
        # Skip unreadable volumes, but report why; a bare ``except`` would
        # also have swallowed KeyboardInterrupt/SystemExit.
        print(str(the_time) + " Failed!")
        print(repr(err))
        return

    # NOTE(review): these gate filters are built but never passed to the
    # dealiasing or gridding calls below -- confirm whether they should be.
    gf_mhx = pyart.filters.GateFilter(mhx_radar)
    gf_mhx.exclude_below('cross_correlation_ratio', 0.5)
    gf_mhx.exclude_below('reflectivity', -20)
    gf_ltx = pyart.filters.GateFilter(ltx_radar)
    gf_ltx.exclude_below('cross_correlation_ratio', 0.5)
    gf_ltx.exclude_below('reflectivity', -20)

    print("## Dealiasing...")
    try:
        dealiased_vel_mhx = pyart.correct.dealias_region_based(mhx_radar)
        mhx_radar.add_field('corrected_velocity', dealiased_vel_mhx, replace_existing=True)
        dealiased_vel_ltx = pyart.correct.dealias_region_based(ltx_radar)
        ltx_radar.add_field('corrected_velocity', dealiased_vel_ltx, replace_existing=True)
    except KeyError:
        print("No velocity information available!")
        return

    print("## Gridding...")
    # Both radars are mapped onto the same grid, centred on mhx_radar.
    grid_origin = (mhx_radar.latitude['data'], mhx_radar.longitude['data'])
    grid_mhx = pyart.map.grid_from_radars(mhx_radar, (31, 351, 401),
                   ((0., 15000.), (-100000., 200000.), (-150000., 300000.)),
                   fields=['reflectivity', 'corrected_velocity'],
                   refl_field='reflectivity', roi_func='dist_beam',
                   h_factor=0., nb=0.6, bsp=1., min_radius=200.,
                   grid_origin=grid_origin)
    grid_ltx = pyart.map.grid_from_radars(ltx_radar, (31, 351, 401),
                   ((0., 15000.), (-100000., 200000.), (-150000., 300000.)),
                   fields=['reflectivity', 'corrected_velocity'],
                   refl_field='reflectivity', roi_func='dist_beam',
                   h_factor=0, nb=0.6, bsp=1., min_radius=200.,
                   grid_origin=grid_origin)

    if do_hrrr:
        print("## Processing HRRR data...")
        # Use the HRRR file of the nearest hour (round up past minute 30).
        hrrr_date = datetime.datetime(the_time.year, the_time.month, the_time.day, the_time.hour)
        if the_time.minute > 30:
            hrrr_date += datetime.timedelta(hours=1)
        hrrr_path = ('/lcrc/group/earthscience/rjackson/florence_hrrr/' +
                     hrrr_date.strftime('%Y%m%d') +
                     '/hrrr.t' + hrrr_date.strftime('%H') + 'z.wrfprsf00.grib2')
        grid_mhx = pydda.initialization.add_hrrr_constraint_to_grid(grid_mhx, hrrr_path)
        model_fields = ["hrrr"]
    else:
        model_fields = None

    print("## Running PyDDA...")
    u_init, v_init, w_init = pydda.initialization.make_constant_wind_field(grid_mhx, (0.0, 0.0, 0.0))
    out_grids = pydda.retrieval.get_dd_wind_field(
        [grid_mhx, grid_ltx], u_init, v_init, w_init, Co=Co, Cm=Cm,
        Cmod=Cmod, mask_outside_opt=True, vel_name='corrected_velocity',
        model_fields=model_fields)

    print('## Making plot..')
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax = pydda.vis.plot_horiz_xsection_barbs_map(out_grids, ax=ax, bg_grid_no=-1, level=1,
                                                 barb_spacing_x_km=20.0, barb_spacing_y_km=20.0)
    plt.title(out_grids[0].time['units'][13:] + ' winds at 0.5 km')
    print("## Saving plot...")
    plt.savefig(out_img_file_path)
    # Release the figure; otherwise every processed time step keeps one
    # open and memory grows over a batch run.
    plt.close(fig)

    pyart.io.write_grid(out_grid_mhx_file_path, out_grids[0])
    pyart.io.write_grid(out_grid_ltx_file_path, out_grids[1])
    # Drop the large objects before collecting so their memory is actually
    # reclaimable while this function is still on the stack.
    del out_grids, grid_mhx, grid_ltx, mhx_radar, ltx_radar, u_init, v_init, w_init
    gc.collect()


def make_retrieved_grid(the_time, ltx_list, mhx_list, do_hrrr=True):
    """Retrieve winds near ``the_time`` from both radars and save the results.

    Uses ``Co=10.0`` and ``Cm=50.0``. With ``do_hrrr`` true the HRRR fields
    are added with ``Cmod=5e-6`` and outputs end in ``.nc`` / ``.png``;
    otherwise ``Cmod=0.0`` and outputs end in ``nohrrr.nc`` / ``.nohrr.png``.

    Parameters
    ----------
    the_time : datetime.datetime
        Requested analysis time.
    ltx_list, mhx_list : list of str
        Radar volume file lists.
    do_hrrr : bool, optional
        Whether to include the HRRR model constraint.
    """
    if do_hrrr:
        grid_suffix, img_suffix, Cmod = '.nc', '.png', 5e-6
    else:
        grid_suffix, img_suffix, Cmod = 'nohrrr.nc', '.nohrr.png', 0.0
    _run_florence_retrieval(the_time, ltx_list, mhx_list, do_hrrr,
                            grid_suffix, img_suffix, Co=10.0, Cm=50.0, Cmod=Cmod)


def make_retrieved_grid_only_hrrr(the_time, ltx_list, mhx_list, do_hrrr=True):
    """Same workflow as ``make_retrieved_grid`` but with ``Co=0.0``, ``Cm=0.0``.

    The retrieval is always called with ``Cmod=1e-3``. With ``do_hrrr`` true
    outputs end in ``.onlyhrrr.nc`` / ``.onlyhrrr.png``; otherwise no model
    fields are passed and outputs end in ``hrrronly.nc`` / ``.nohrr.png``.

    Parameters
    ----------
    the_time : datetime.datetime
        Requested analysis time.
    ltx_list, mhx_list : list of str
        Radar volume file lists.
    do_hrrr : bool, optional
        Whether to add the HRRR fields and pass them to the retrieval.
    """
    if do_hrrr:
        grid_suffix, img_suffix = '.onlyhrrr.nc', '.onlyhrrr.png'
    else:
        grid_suffix, img_suffix = 'hrrronly.nc', '.nohrr.png'
    _run_florence_retrieval(the_time, ltx_list, mhx_list, do_hrrr,
                            grid_suffix, img_suffix, Co=0.0, Cm=0.0, Cmod=1e-3)

In [None]:
make_retrieved_grid(ltx_times[241], ltx_list, mhx_list)

In [None]:
print(len(mhx_times))

In [None]:
# Describe the SLURM job that each batch of Dask workers runs in.
from dask_jobqueue import SLURMCluster
cluster = SLURMCluster(cores=2, project='rainfall', walltime='2:00:00', 
                       processes=2, memory='128GB')

cluster.scale(8)         # Scale the cluster to 8 workers

from dask.distributed import Client
client = Client(cluster)  # Connect this local process to remote workers

# wait for jobs to arrive, depending on the queue, this may take some time

import dask.array as da
from distributed import wait

In [None]:
client

In [None]:
cluster.stop_all_jobs()

In [None]:
# Indices of the scans between 08:00:01 and 10:00:01 UTC on 14 Sep 2018.
# NOTE(review): time_inds is not used below; the map call works on the
# hard-coded slice ltx_times[229:426] instead -- confirm which is intended.
time_inds = np.where(np.logical_and(mhx_times >= datetime.datetime(2018, 9, 14, 8, 0, 1), 
                                    mhx_times <= datetime.datetime(2018, 9, 14, 10, 0, 1)))[0]
# Single-argument wrappers so the retrieval can be mapped over a list of times.
make_grid = lambda x: make_retrieved_grid(x, ltx_list, mhx_list, True)
make_grid_no_hrrr = lambda x: make_retrieved_grid(x, ltx_list, mhx_list, False)
# Submit one retrieval per time to the cluster and block until all are done.
futures = client.map(make_grid, ltx_times[229:426])
wait(futures)

In [None]:
make_grid = lambda x: make_retrieved_grid(x, ltx_list, mhx_list, True)
make_grid_no_hrrr = lambda x: make_retrieved_grid(x, ltx_list, mhx_list, False)
make_grid_hrrr_only = lambda x: make_retrieved_grid_only_hrrr(x, ltx_list, mhx_list, True)
make_grid_hrrr_only(ltx_times[230])

In [None]:
from scipy.signal import correlate2d
# Retrieval using both radars (one grid file per radar) ...
grid_all_ltx = pyart.io.read_grid('/lcrc/group/earthscience/rjackson/florence_winds/grids/20180914/05kmwinds_gridltx20180914.0624.nc')
grid_all_mtx = pyart.io.read_grid('/lcrc/group/earthscience/rjackson/florence_winds/grids/20180914/05kmwinds_gridmhx20180914.0624.nc')
# ... and the matching "only HRRR" retrieval for the same time.
grid_only_hrrr  = pyart.io.read_grid('/lcrc/group/earthscience/rjackson/florence_winds/grids/20180914/05kmwinds_gridltx20180914.0624.onlyhrrr.nc')

# Compare like with like: the only-HRRR grid is the 'gridltx' file, so it is
# correlated against grid_all_ltx (the name ``grid_all`` was never defined).
# NOTE(review): if these fields are masked arrays, correlate2d may ignore the
# mask and use the underlying fill values -- confirm the inputs are suitable.
correlation_u = correlate2d(grid_all_ltx.fields["u"]["data"][1], grid_only_hrrr.fields["u"]["data"][1])
correlation_v = correlate2d(grid_all_ltx.fields["v"]["data"][1], grid_only_hrrr.fields["v"]["data"][1])

plt.pcolormesh(correlation_u)
plt.colorbar()

In [None]:
# NOTE: from here on ltx_list/mhx_list hold retrieved-grid files, not radar volumes.
ltx_list = sorted(glob.glob('/lcrc/group/earthscience/rjackson/florence_winds/grids/20180914/*ltx*.nc'))
mhx_list = sorted(glob.glob('/lcrc/group/earthscience/rjackson/florence_winds/grids/20180914/*mhx*.nc'))
from copy import deepcopy

# Font settings are global rc state, so set them once rather than per figure.
font = {'family' : 'normal',
        'weight' : 'bold',
        'size'   : 44}
plt.rc('font', **font)

# The two sorted lists are paired positionally; iteration stops at the
# shorter list if they differ in length.
for ltx_path, mhx_path in zip(ltx_list, mhx_list):
    ltx_grid = pyart.io.read_grid(ltx_path)
    mhx_grid = pyart.io.read_grid(mhx_path)
    fig = plt.figure(figsize=(35, 20))

    # Derive a rainfall-rate field from reflectivity for both grids:
    # z = 10**(dBZ/10), then R = (z/300)**(1/1.4).
    for grid in (ltx_grid, mhx_grid):
        rain = deepcopy(grid.fields["reflectivity"])
        rain["standard_name"] = "rainfall_rate"
        rain["long_name"] = "rainfall rate"
        rain["units"] = "mm hr-1"
        rain["data"] = (10**(grid.fields["reflectivity"]["data"]/10)/300)**(1/1.4)
        grid.fields["rainfall_rate"] = rain

    # Only mask points that are missing in BOTH grids; points missing in
    # just one of them are set to 0 there.
    the_mask = np.logical_and(ltx_grid.fields["rainfall_rate"]["data"].mask,
                              mhx_grid.fields["rainfall_rate"]["data"].mask)
    for grid in (ltx_grid, mhx_grid):
        zero_filled = grid.fields["rainfall_rate"]["data"].filled(0)
        grid.fields["rainfall_rate"]["data"] = np.ma.masked_where(the_mask, zero_filled)

    ax = plt.axes(projection=ccrs.PlateCarree())
    ax = pydda.vis.plot_horiz_xsection_streamlines_map([ltx_grid, mhx_grid], ax=ax,
                                                       background_field='rainfall_rate',
                                                       bg_grid_no=-1, level=2,
                                                       vmin=0, vmax=50, show_lobes=False)

    # Overlay contours of horizontal wind speed at levels 28 and 32,
    # using every 4th grid point of level 2.
    wind_speed = np.sqrt(ltx_grid.fields["u"]["data"]**2 + ltx_grid.fields["v"]["data"]**2)
    wind_speed = wind_speed.filled(np.nan)
    lons = ltx_grid.point_longitude["data"]
    lats = ltx_grid.point_latitude["data"]
    cs = ax.contour(lons[2, ::4, ::4], lats[2, ::4, ::4], wind_speed[2, ::4, ::4], levels=[28, 32],
                    linewidths=8, colors=['b', 'r', 'k'])
    # Label through the axes method; ``ax`` is not a clabel keyword.
    ax.clabel(cs, inline=1, fontsize=15)
    ax.set_xticks(np.arange(-80, -75, 0.5))
    ax.set_yticks(np.arange(33, 35.8, 0.5))
    ax.set_title(ltx_grid.time["units"][-20:])
    plt.savefig(ltx_path[-15:] + '.png')
    # ``del fig`` alone leaves the figure registered with pyplot; close it
    # so memory does not grow with every file.
    plt.close(fig)

In [None]:
ltx_list = sorted(glob.glob('/lcrc/group/earthscience/rjackson/ddop_grids/grids/2006/20060120/berrwinds*'))
mhx_list = sorted(glob.glob('/lcrc/group/earthscience/rjackson/ddop_grids/grids/2006/20060120/cpolwinds*'))
ltx_grid = pyart.io.read_grid(ltx_list[0])
mhx_grid = pyart.io.read_grid(mhx_list[0])
fig = plt.figure(figsize=(10,5)) 
ax = plt.axes(projection=ccrs.PlateCarree())
pydda.vis.plot_horiz_xsection_barbs_map([ltx_grid, mhx_grid], ax=ax, bg_grid_no=-1, level=2, barb_spacing_x_km=5.0,
                                            barb_spacing_y_km=5.0, vmin=0, vmax=50)

In [None]:
pydda.vis.plot_horiz_xsection_barbs_map?

## Download NEXRAD level 2 data

In [None]:
#first lets connect to the bucket
conn = S3Connection(anon = True)
bucket = conn.get_bucket('noaa-nexrad-level2')

In [None]:
my_pref = '2018/09/14/KRAX/'
bucket_list = list(bucket.list(prefix = my_pref))

In [None]:
kvnx_download_path = '/lcrc/group/earthscience/rjackson/florence/'
import os

In [None]:
print(bucket_list)

In [None]:
# Download every listed object, mirroring the bucket key as a relative path
# below kvnx_download_path.
for item in bucket_list:
    local_path = os.path.join(kvnx_download_path, item.key)
    # Keys contain sub-directories (the listing prefix is 'YYYY/MM/DD/SITE/'),
    # which must exist before the file can be written.
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    item.get_contents_to_filename(local_path)

In [None]:
out_img_path = '/lcrc/group/earthscience/rjackson/florence_winds/png/'
out_grid_path = '/lcrc/group/earthscience/rjackson/florence_winds/grids/'

from scipy.interpolate import griddata
from copy import deepcopy
import os

def reduce_pyart_grid_res(Grid, skip_factor):
    """Return a thinned copy of a Py-ART grid.

    Every ``skip_factor``-th point is kept along axes 1 and 2 of each field
    and along the ``x`` and ``y`` coordinate arrays; axis 0 and ``z`` are
    left unchanged. The input grid is deep-copied first and is not modified.

    Parameters
    ----------
    Grid : pyart.core.Grid
        Grid to thin.
    skip_factor : int
        Stride applied along the two thinned axes.

    Returns
    -------
    pyart.core.Grid
        New grid built from the thinned fields and coordinates.
    """
    grid_copy = deepcopy(Grid)
    stride = slice(None, None, skip_factor)

    # Shallow-copy each field dictionary and swap in the strided data.
    thinned_fields = {
        name: dict(field, data=field["data"][:, stride, stride])
        for name, field in grid_copy.fields.items()
    }

    x = grid_copy.x
    x["data"] = x["data"][stride]
    y = grid_copy.y
    y["data"] = y["data"][stride]

    return pyart.core.Grid(
        grid_copy.time, thinned_fields, grid_copy.metadata,
        grid_copy.origin_latitude, grid_copy.origin_longitude,
        grid_copy.origin_altitude, x, y, grid_copy.z,
        grid_copy.projection, grid_copy.radar_latitude,
        grid_copy.radar_longitude, grid_copy.radar_altitude,
        grid_copy.radar_time, grid_copy.radar_name)

def split_pyart_grid(Grid, split_factor, axis=1):
    """
    Split a Py-ART grid into split_factor pieces along one axis.

    Each field and the coordinate belonging to the chosen axis are divided
    with np.array_split; the other two coordinates are carried over unchanged.
    The input grid is not modified.

    Parameters
    ==========
    Grid: Py-ART Grid
        The grid to split.
    split_factor: int
        Number of pieces to produce.
    axis: int
        Field axis to split along: 0 splits z, 1 splits y and 2 splits x.

    Returns
    =======
    grid_splits: list of Py-ART Grids
        The pieces, in order along the split axis.

    Raises
    ======
    ValueError
        If axis is not 0, 1 or 2. Other values would split the fields without
        splitting a matching coordinate.
    """
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0 (z), 1 (y) or 2 (x), got %r" % (axis,))
    coord_name = {0: "z", 1: "y", 2: "x"}[axis]

    split_field = {}
    field_attrs = {}
    for field_name, field in Grid.fields.items():
        data = field["data"]
        field_attrs[field_name] = {key: deepcopy(value)
                                   for key, value in field.items()
                                   if key != "data"}
        if isinstance(data, np.ma.MaskedArray):
            # Split a NaN-filled plain array and re-mask each piece afterwards.
            # filled() already returns a new array, so no extra copy is needed.
            pieces = np.array_split(data.filled(np.nan), split_factor, axis=axis)
            pieces = [np.ma.masked_where(np.isnan(arr), arr) for arr in pieces]
        else:
            pieces = np.array_split(data.copy(), split_factor, axis=axis)
        split_field[field_name] = pieces

    # Only the small coordinate and metadata objects are deep-copied; the
    # field data were already copied above.
    coords = {"x": deepcopy(Grid.x), "y": deepcopy(Grid.y), "z": deepcopy(Grid.z)}
    coord_pieces = np.array_split(coords[coord_name]["data"], split_factor)

    gtime = deepcopy(Grid.time)
    metadata = deepcopy(Grid.metadata)
    origin_latitude = deepcopy(Grid.origin_latitude)
    origin_longitude = deepcopy(Grid.origin_longitude)
    origin_altitude = deepcopy(Grid.origin_altitude)
    projection = deepcopy(Grid.projection)
    radar_latitude = deepcopy(Grid.radar_latitude)
    radar_longitude = deepcopy(Grid.radar_longitude)
    radar_altitude = deepcopy(Grid.radar_altitude)
    radar_time = deepcopy(Grid.radar_time)
    radar_name = deepcopy(Grid.radar_name)

    grid_splits = []
    for i in range(split_factor):
        grid_dic = {name: dict(field_attrs[name], data=split_field[name][i])
                    for name in split_field}
        # Shallow copies are enough: only the "data" entry of the split
        # coordinate is replaced.
        axes = {name: coord.copy() for name, coord in coords.items()}
        axes[coord_name]["data"] = coord_pieces[i]

        new_grid = pyart.core.Grid(
            gtime, grid_dic, metadata, origin_latitude, origin_longitude,
            origin_altitude, axes["x"], axes["y"], axes["z"], projection,
            radar_latitude, radar_longitude, radar_altitude, radar_time,
            radar_name)
        grid_splits.append(new_grid)

    return grid_splits

def concatenate_pyart_grids(grid_list, axis=1):
    """
    Join Py-ART grids end to end along one axis.

    The result starts as a deep copy of grid_list[0]. Every field, the
    coordinate belonging to the chosen axis and the matching point count
    (nz, ny or nx) are then replaced by the concatenation over all grids.

    Parameters
    ==========
    grid_list: list of Py-ART Grids
        Grids to join, in order along the axis. All grids must provide the
        fields of the first grid.
    axis: int
        Field axis to join along: 0 joins z, 1 joins y and 2 joins x.

    Returns
    =======
    new_grid: Py-ART Grid
        The joined grid.

    Raises
    ======
    ValueError
        If axis is not 0, 1 or 2. Other values would join the fields without
        updating a matching coordinate and point count.
    """
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0 (z), 1 (y) or 2 (x), got %r" % (axis,))
    coord_name = {0: "z", 1: "y", 2: "x"}[axis]
    count_name = "n" + coord_name

    new_grid = deepcopy(grid_list[0])
    for field_name in new_grid.fields.keys():
        new_grid.fields[field_name]["data"] = np.ma.concatenate(
            [g.fields[field_name]["data"] for g in grid_list], axis=axis)

    getattr(new_grid, coord_name)["data"] = np.ma.concatenate(
        [getattr(g, coord_name)["data"] for g in grid_list])
    setattr(new_grid, count_name,
            np.sum([getattr(g, count_name) for g in grid_list]))
    # NOTE(review): derived per-point attributes of the copied grid (e.g.
    # point_x or point_longitude) are not refreshed here -- confirm whether
    # they need to be re-initialised after the coordinates change.
    return new_grid

# Procedure:
# 1. Do a first pass of the retrieval on a reduced-resolution grid.
# 2. Use the reduced-resolution retrieval as the initial state of the
#    high-resolution retrieval in each sub-domain.
# 3. Tile the sub-domain retrievals back together into full-domain grids.
# NOTE: continuity at the sub-domain boundaries is not checked here.
def do_dd_wind_field_nested(grid_list, u_init, v_init, w_init, reduction_factor=2,
                            num_splits=2, **kwargs):
    """
    This function performs a wind retrieval using a nested domain. This is useful for
    grids that are larger than about 400 by 400 by 40 points, since the use of larger
    grids on a single machine will exceed memory limitations.

    This procedure relies on a dask distributed cluster to be set up. The retrieval is
    first performed at a resolution that is coarser than the analysis grid by
    reduction_factor. This provides the initial state for the nested, full-resolution
    retrievals.

    The domain is split into num_splits**2 sub-domains for the nested retrieval step, and
    each nested retrieval is mapped onto a distributed worker for parallel processing. If
    NumPy and SciPy are already set up to use parallel numerical analysis libraries, it is
    recommended that a single machine be dedicated to each nest rather than a single core
    for best performance.

    The names ``client`` and ``wait`` are looked up in the notebook namespace
    (presumably a dask.distributed Client and dask.distributed.wait -- confirm).
    The sub-domain grids are exchanged through temporary NetCDF files with
    relative names, so the workers presumably need to share the working
    directory of the calling process. The files are removed again before the
    function returns or raises.

    Parameters
    ==========
    grid_list: list of Py-ART Grids
        The analysis grids, one per radar. The point coordinates of the first
        grid are used as the target of the regridding step.
    u_init: 3D array
        Initial guess of the u wind component on the analysis grid.
    v_init: 3D array
        Initial guess of the v wind component on the analysis grid.
    w_init: 3D array
        Initial guess of the w wind component on the analysis grid.
    reduction_factor: int
        Horizontal stride used to build the coarse grid of the first pass.
    num_splits: int
        Number of pieces per horizontal axis; num_splits**2 sub-domains are
        retrieved.
    **kwargs: dict
        This function will take the same keyword arguments as get_dd_wind_field, as these
        arguments are passed into each call of get_dd_wind_field.

    Returns
    =======
    new_grid_list: list of Py-ART Grids
        One full-domain grid per entry of grid_list, tiled together from the
        sub-domain retrievals.
    """
    # First, we do retrieval on whole grid with fraction of resolution
    grid_lo_res_list = [reduce_pyart_grid_res(G, reduction_factor) for G in grid_list]

    first_pass = pydda.retrieval.get_dd_wind_field(grid_lo_res_list,
                                                   u_init[::, ::reduction_factor, ::reduction_factor],
                                                   v_init[::, ::reduction_factor, ::reduction_factor],
                                                   w_init[::, ::reduction_factor, ::reduction_factor], **kwargs)

    # Take the first pass field and regrid it (nearest neighbour) to the analysis grid
    reduced_x = first_pass[0].point_x["data"].flatten()
    reduced_y = first_pass[0].point_y["data"].flatten()
    reduced_z = first_pass[0].point_z["data"].flatten()
    x = grid_list[0].point_x["data"].flatten()
    y = grid_list[0].point_y["data"].flatten()
    z = grid_list[0].point_z["data"].flatten()
    coarse_points = (reduced_z, reduced_y, reduced_x)
    fine_points = (z, y, x)
    u_init_new = griddata(coarse_points, first_pass[0].fields["u"]["data"].flatten(),
                          fine_points, method='nearest')
    v_init_new = griddata(coarse_points, first_pass[0].fields["v"]["data"].flatten(),
                          fine_points, method='nearest')
    w_init_new = griddata(coarse_points, first_pass[0].fields["w"]["data"].flatten(),
                          fine_points, method='nearest')
    u_init_new = np.reshape(u_init_new, u_init.shape)
    v_init_new = np.reshape(v_init_new, v_init.shape)
    w_init_new = np.reshape(w_init_new, w_init.shape)

    # Clear out unneeded variables (do not need lo-res grids in memory anymore)
    del first_pass, reduced_x, reduced_y, reduced_z, x, y, z, grid_lo_res_list
    del coarse_points, fine_points
    gc.collect()

    # Taken here so that the worker function below does not have to reference
    # (and therefore ship) the full-resolution grid_list.
    n_grids = len(grid_list)

    tempfile_name_base = datetime.datetime.now().strftime('%y%m%d.%H%M%S')
    written_files = []
    try:
        # Split the analysis grids into num_splits**2 pieces and save them as
        # temporary files; the workers load them when the retrieval runs.
        # tiny_grids[k][i][j] is the file of grid k, x-piece i, y-piece j.
        tiny_grids = []
        for k, G in enumerate(grid_list):
            cur_list = []
            split_grids_x = split_pyart_grid(G, num_splits, axis=2)
            for i, sgrid in enumerate(split_grids_x):
                grid_fns = []
                for j, g in enumerate(split_pyart_grid(sgrid, num_splits)):
                    fn = tempfile_name_base + str(k) + '.' + str(i) + '.' + str(j) + '.nc'
                    # Registered before writing so that a partially written
                    # file is cleaned up as well.
                    written_files.append(fn)
                    pyart.io.write_grid(fn, g)
                    grid_fns.append(fn)
                cur_list.append(grid_fns)
            del split_grids_x
            tiny_grids.append(cur_list)

        # Split the regridded first guess the same way: x first, then y.
        u_init_split = [np.array_split(ux, num_splits, axis=1)
                        for ux in np.array_split(u_init_new, num_splits, axis=2)]
        v_init_split = [np.array_split(vx, num_splits, axis=1)
                        for vx in np.array_split(v_init_new, num_splits, axis=2)]
        w_init_split = [np.array_split(wx, num_splits, axis=1)
                        for wx in np.array_split(w_init_new, num_splits, axis=2)]
        gc.collect()

        def do_tiny_retrieval(tile_fns, u_tile, v_tile, w_tile):
            # Everything a tile needs is passed in explicitly, so only the
            # data of that tile travels to the worker.
            tgrids = [pyart.io.read_grid(tile_fn) for tile_fn in tile_fns]
            print(tgrids)
            new_grids = pydda.retrieval.get_dd_wind_field(tgrids, u_tile, v_tile,
                                                          w_tile, **kwargs)
            del tgrids
            gc.collect()
            return new_grids

        futures_array = []
        for i in range(num_splits):
            for j in range(num_splits):
                tile_fns = [tiny_grids[k][i][j] for k in range(n_grids)]
                futures_array.append(client.submit(do_tiny_retrieval, tile_fns,
                                                   u_init_split[i][j],
                                                   v_init_split[i][j],
                                                   w_init_split[i][j]))

        print("Waiting for nested grid to be retrieved...")
        wait(futures_array)

        tiny_retrieval2 = client.gather(futures_array)

        # Then just tile the pieces back together. The futures were submitted
        # with the x index outermost, so the y-pieces of x-strip i sit at
        # positions i*num_splits ... i*num_splits + num_splits - 1.
        tiny_retrieval = []
        for i in range(num_splits):
            strip_grids = []
            for j in range(n_grids):
                tiles = [tiny_retrieval2[k + i*num_splits][j] for k in range(0, num_splits)]
                print(tiles)
                strip_grids.append(concatenate_pyart_grids(tiles, axis=1))
            tiny_retrieval.append(strip_grids)

        # Combine the x-strips of every grid into one full-domain grid.
        new_grid_list = []
        for i in range(n_grids):
            new_grid_list.append(concatenate_pyart_grids(
                [tiny_retrieval[k][i] for k in range(num_splits)], axis=2))
        return new_grid_list
    finally:
        # Remove exactly the temporary files written above, also on failure.
        for fn in written_files:
            try:
                os.remove(fn)
            except FileNotFoundError:
                pass
             
def make_retrieved_grid_extended(the_time, ltx_list, mhx_list, cae_list,
                                 clx_list, fcx_list, rax_list, gsp_list, do_hrrr=True):
    """
    Run the six-radar wind retrieval for the scans closest to the_time and
    save the resulting grids and a wind barb plot.

    The radar volumes nearest in time are read, filtered, dealiased and
    gridded onto a common 31 x 1101 x 1101 grid centred on the first radar.
    Model constraints are then added to the first grid, PyDDA's nested
    retrieval is run, and one NetCDF grid per radar plus a PNG are written
    below out_grid_path and out_img_path in a YYYYMMDD sub-directory.

    Besides its arguments the function reads these notebook globals:
    out_grid_path, out_img_path, client, and the scan time arrays mhx_times,
    ltx_times, fcx_times, gsp_times, rax_times and clx_times.

    Parameters
    ==========
    the_time: datetime.datetime
        Requested analysis time.
    ltx_list, mhx_list, clx_list, fcx_list, rax_list, gsp_list: list of str
        Radar file paths, indexed by the position of the nearest entry in the
        matching global time array.
    cae_list: list of str
        Accepted for interface compatibility; not used by this function.
    do_hrrr: bool
        If true, the HRRR constraint is added and used in addition to the
        ERA-Interim constraint. If false, output file names carry a
        'nohrrr'/'nohrr' marker.

    Returns
    =======
    None. The function returns early, after printing a message, when the two
    primary radars are more than five minutes apart, when a file cannot be
    read, or when dealiasing raises a KeyError.
    """
    # Order in which the radars are gridded, handed to the retrieval and
    # written to disk.
    radar_order = ('mhx', 'ltx', 'fcx', 'gsp', 'rax', 'clx')

    date_str = "%04d" % the_time.year + "%02d" % the_time.month + "%02d" % the_time.day
    time_str = "%02d" % the_time.hour + "%02d" % the_time.minute
    out_grid_dir = out_grid_path + '/' + date_str + '/'
    out_img_dir = out_img_path + '/' + date_str + '/'
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(out_img_dir, exist_ok=True)
    os.makedirs(out_grid_dir, exist_ok=True)

    if do_hrrr:
        grid_suffix, img_suffix = '.nc', '.png'
    else:
        # The differing spellings are kept so existing output names stay valid.
        grid_suffix, img_suffix = 'nohrrr.nc', '.nohrr.png'
    out_grid_file_paths = {
        name: (out_grid_dir + '05kmwinds_grid' + name + 'ext' + date_str + '.' +
               time_str + grid_suffix)
        for name in radar_order}
    out_img_file_path = (out_img_dir + '05kmwinds' + date_str + '.' +
                         time_str + img_suffix)

    print("## Loading data...")
    # Pair every radar with its time array and its file list.
    # NOTE(review): 'mhx' is paired with ltx_list and 'ltx' with mhx_list. This
    # matches the head of the notebook, where mhx_times is built from ltx_list
    # and ltx_times from mhx_list -- confirm whether that swap is intended.
    sources = {
        'mhx': (mhx_times, ltx_list),
        'ltx': (ltx_times, mhx_list),
        'fcx': (fcx_times, fcx_list),
        'gsp': (gsp_times, gsp_list),
        'rax': (rax_times, rax_list),
        'clx': (clx_times, clx_list),
    }
    nearest = {name: np.argmin(np.abs(times - the_time))
               for name, (times, _) in sources.items()}

    if(np.abs(mhx_times[nearest['mhx']] - ltx_times[nearest['ltx']]) > datetime.timedelta(minutes=5)):
        print("No simultaneous coverage!")
        return

    try:
        radars = {name: pyart.io.read(sources[name][1][nearest[name]])
                  for name in radar_order}
    except Exception as err:
        # Exception instead of a bare except, so that KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        print(str(the_time) + " Failed!")
        print(repr(err))
        return

    gatefilters = {}
    for name in radar_order:
        gatefilter = pyart.filters.GateFilter(radars[name])
        gatefilter.exclude_below('cross_correlation_ratio', 0.5)
        gatefilter.exclude_below('reflectivity', -20)
        gatefilters[name] = gatefilter

    print("## Dealiasing...")
    # Dealias
    try:
        for name in radar_order:
            dealiased_vel = pyart.correct.dealias_region_based(
                radars[name], gatefilter=gatefilters[name])
            radars[name].add_field('corrected_velocity', dealiased_vel,
                                   replace_existing=True)
    except KeyError:
        print("No velocity information available!")
        return

    print("## Gridding...")
    # Grid: all radars go onto the same grid, whose origin is the first radar.
    grid_spec = (31, 1101, 1101)
    grid_z = (0., 15000.)
    grid_y = (-650000., 650000.)
    grid_x = (-650000., 650000.)
    grid_origin = (radars['mhx'].latitude['data'], radars['mhx'].longitude['data'])
    grids = [pyart.map.grid_from_radars(radars[name], grid_spec,
                                        (grid_z, grid_y, grid_x),
                                        fields=['reflectivity', 'corrected_velocity'],
                                        refl_field='reflectivity', roi_func='dist_beam',
                                        h_factor=0., nb=0.6, bsp=1., min_radius=200.,
                                        grid_origin=grid_origin)
             for name in radar_order]

    # Get HRRR data from nearest hour
    if do_hrrr:
        print("## Processing HRRR data...")
        hrrr_date = datetime.datetime(the_time.year, the_time.month, the_time.day, the_time.hour)
        if(the_time.minute > 30):
            hrrr_date += datetime.timedelta(hours=1)

        hrrr_path = ('/lcrc/group/earthscience/rjackson/florence_hrrr/' +
                     "%04d" % hrrr_date.year +
                     "%02d" % hrrr_date.month +
                     "%02d" % hrrr_date.day +
                     '/hrrr.t' + "%02d" % hrrr_date.hour + 'z.wrfprsf00.grib2')
        grids[0] = pydda.constraints.add_hrrr_constraint_to_grid(grids[0], hrrr_path)
        Cmod = 5e-6
        model_fields = ["hrrr", "erainterim"]
    else:
        Cmod = 0.0
        model_fields = ["erainterim"]
    grids[0] = pydda.constraints.make_constraint_from_era_interim(grids[0])

    print("## Running PyDDA...")
    u_init, v_init, w_init = pydda.initialization.make_constant_wind_field(grids[0], (0.0, 0.0, 0.0))
    out_grids = pydda.retrieval.get_dd_wind_field_nested(grids,
                                                  u_init, v_init, w_init, Co=1.0, Cm=100.0,
                                                  Cmod=Cmod, mask_outside_opt=True, vel_name='corrected_velocity',
                                                  model_fields=model_fields, client=client)
    print('## Making plot..')
    print(out_grids[0].fields["u"]["data"].shape)
    print(out_grids[0].nz)
    print(out_grids[0].ny)
    print(out_grids[0].nx)
    # NOTE(review): the figure is left open after saving; confirm whether it
    # should be closed when this function is called in a loop.
    fig = plt.figure(figsize=(15,10))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax = pydda.vis.plot_horiz_xsection_barbs_map(out_grids, ax=ax, bg_grid_no=-1, level=1, barb_spacing_x_km=60.0,
                                             barb_spacing_y_km=60.0)

    plt.title(out_grids[0].time['units'][13:] + ' winds at 0.5 km')
    print("## Saving plot...")

    plt.savefig(out_img_file_path)

    # out_grids is in radar_order, i.e. the order the grids were passed in.
    for index, name in enumerate(radar_order):
        pyart.io.write_grid(out_grid_file_paths[name], out_grids[index])

    # Drop the large objects right away; this function may run repeatedly.
    del out_grids, grids, radars, gatefilters, u_init, v_init, w_init
    gc.collect()

In [None]:
# Free memory left over from earlier cells, then run the retrieval with the
# HRRR constraint for the scans closest to 2018-09-14 06:50.
gc.collect()
the_time = datetime.datetime(2018,9,14,6,50)
make_retrieved_grid_extended(the_time, ltx_list, mhx_list, cae_list,
                                 clx_list, fcx_list, rax_list, gsp_list, do_hrrr=True)

In [None]:
# Run the retrieval again for the_time. do_hrrr is left at its default (True),
# so this repeats the call of the previous cell.
make_retrieved_grid_extended(the_time, ltx_list, mhx_list, cae_list,
                                 clx_list, fcx_list, rax_list, gsp_list)

In [None]:
def scale_bar(ax, length, location=(0.5, 0.05), linewidth=3, utm_zone=17):
    """
    Draw a horizontal scale bar with a "<length> km" label on a map axes.

    ax is the axes to draw the scalebar on.
    length is the length of the scalebar in km.
    location is center of the scalebar in axis coordinates ie. 0.5 is the middle of the plot
    linewidth is the thickness of the scalebar.
    utm_zone is the UTM zone of the metric projection in which the bar is
    laid out. The default (17) is the zone that was previously hard-coded;
    change it to suit the region shown in your own figure.
    """
    # Projection in metres
    utm = ccrs.UTM(utm_zone)
    # Get the extent of the plotted area in coordinates in metres
    x0, x1, y0, y1 = ax.get_extent(utm)
    # Turn the specified scalebar location into coordinates in metres
    sbcx = x0 + (x1 - x0) * location[0]
    sbcy = y0 + (y1 - y0) * location[1]
    # Half the bar length in metres (length is in km) on either side of the centre
    half_length_m = length * 500
    bar_xs = [sbcx - half_length_m, sbcx + half_length_m]
    # Plot the scalebar
    ax.plot(bar_xs, [sbcy, sbcy], transform=utm, color='k', linewidth=linewidth)
    # Plot the scalebar label
    ax.text(sbcx, sbcy, str(length) + ' km', transform=utm,
            horizontalalignment='center', verticalalignment='bottom')

In [None]:
# Collect the retrieved grid files matching the pattern below and load each of
# them as a Py-ART grid.
# NOTE(review): glob.glob does not guarantee an ordering, while a later cell
# labels grids[i] from a fixed list of radar names -- confirm the orders agree.
grid_list = glob.glob('/lcrc/group/earthscience/rjackson/florence_winds/grids/test_extended/05kmwinds_grid*ext*4.0650.nc')

grids = []
for fn in grid_list:
    grids.append(pyart.io.read_grid(fn))
print(grid_list)

In [None]:
# Barb map of the retrieved winds with wind speed contours, radar labels and a
# scale bar.
fig = plt.figure(figsize=(15,10))
font = {'family' : 'monospace',
        'weight' : 'bold',
        'size'   : 20}

plt.rc('font', **font)
ax = plt.axes(projection=ccrs.PlateCarree())
ax = pydda.vis.plot_horiz_xsection_barbs_map(grids, ax=ax, bg_grid_no=-1, level=1, barb_spacing_x_km=50.0,
                                             barb_spacing_y_km=50.0, show_lobes=False)

# Horizontal wind speed. Both components are taken from the same grid
# (previously u came from grids[0] and v from grids[1]).
wind_speed = np.sqrt(grids[0].fields["u"]["data"]**2 + grids[0].fields["v"]["data"]**2)
wind_speed = wind_speed.filled(np.nan)
lons = grids[0].point_longitude["data"]
lats = grids[0].point_latitude["data"]

# Use one vertical index for longitude, latitude and wind speed (previously the
# latitude was taken from index 1 and the other two from index 2).
# NOTE(review): the barbs above are drawn at level=1 -- confirm which level the
# contours are meant to show.
contour_level = 2
cs = ax.contour(lons[contour_level, ::2, ::2], lats[contour_level, ::2, ::2],
                wind_speed[contour_level, ::2, ::2], levels=[28, 32],
                linewidths=3, colors=['b', 'r', 'k'])

# Label every radar at its location; the names are matched to grids by position.
rad_list = ["GSP", "RAX", "FCX", "MHX", "CLX", "LTX"]
for i in range(len(grids)):
    ax.text(grids[i].radar_longitude["data"], grids[i].radar_latitude["data"], rad_list[i],
        fontsize=20, horizontalalignment="center")

ax.set_xticks(np.arange(-85, -70, 1))
ax.set_yticks(np.arange(27, 40, 1))
ax.set_xlim([-84, -73])
ax.set_ylim([31, 39])
scale_bar(ax, length=100, location=(0.5, 0.05))

In [None]:
# Print the (index, grid) pairs. The previous expression, [i in enumerate(grids)],
# was a membership test on a leftover loop variable and printed a single boolean.
print(list(enumerate(grids)))

In [None]:
# Quick visual check of reduce_pyart_grid_res: thin the first grid by a factor
# of 2 and plot one level of the result.
new_grid = reduce_pyart_grid_res(grids[0], 2)
disp = pyart.graph.GridMapDisplay(new_grid)
#new_grid.y
# NOTE(review): the grids written by the retrieval are built from the fields
# 'reflectivity' and 'corrected_velocity' -- confirm that a field named
# 'velocity' exists in these files.
disp.plot_grid('velocity', level=7, vmin=0, vmax=60)
#new_grid.fields['reflectivity']["data"]

In [None]:
# Show the shape of the reflectivity array of the first grid.
grids[0].fields["reflectivity"]["data"].shape

In [None]:
# Quick visual check of split_pyart_grid: split the first grid into three
# pieces along axis 2 (x) and plot the reflectivity of the middle piece.
split_grids_x= split_pyart_grid(grids[0], 3, axis=2)
#fig, ax = plt.subplots(3, 1, figsize=(30,30))
#for i in range(3):
disp = pyart.graph.GridMapDisplay(split_grids_x[1])
disp.plot_grid('reflectivity', level=7, vmin=0, vmax=60)

# Disabled: the same plot for pieces split along y (split_grids_y is not
# defined in this cell).
#for i in range(3):
#    disp = pyart.graph.GridMapDisplay(split_grids_y[i])
#    disp.plot_grid('reflectivity', ax=ax[i,1], level=7, vmin=0, vmax=60)

In [None]:
# Demonstrate np.array_split along axis=1: the 2 x 4 input is divided into two
# 2 x 2 column blocks.
np.array_split([[2,2,2,4], [3,3,3,4]], 2, axis=1)