In [None]:
import numpy as np
import pygrib
from scipy import optimize

import matplotlib as mpl
import matplotlib.pyplot as plt 
import matplotlib.colors as colors
from mpl_toolkits.basemap import Basemap, cm

import time

In [None]:
# Read the blend QMD forecast GRIB file holding the precipitation quantile levels.
# Forecast runs analysed so far: key -> (init date YYYYMMDD, cycle hour, forecast hour),
# which fill the '.t<cycle>z' and '.f<forecast hour>' parts of the file name below.
FORECAST_RUNS = {
    'jan': ('20220111', '00', '018'),
    'march': ('20220324', '12', '018'),
    'april': ('20220425', '12', '018'),
    'june': ('20220626', '12', '006'),
    'july': ('20220703', '00', '018'),
    'sept': ('20210920', '12', '018'),
    'oct': ('20211015', '00', '018'),
    'dec': ('20211205', '12', '018'),
    'wet': ('', '00', '012'),  # NOTE(review): empty date kept as-is; confirm the intended file name
}
month = 'june'
if month not in FORECAST_RUNS:
    # Fail here with a clear message instead of a NameError on `date` further down.
    raise ValueError(f'Issue with month input: {month!r}; expected one of {list(FORECAST_RUNS)}')
date, lead_time, forecast_time = FORECAST_RUNS[month]
fn_grb = 'data/blend' + date + '.t' + lead_time + 'z.qmd.f' + forecast_time + '.co.grib2'

N_LEVELS = 99  # quantile levels, read from GRIB messages 2 .. N_LEVELS + 1
ds_grb = pygrib.open(fn_grb)
try:
    # latitude and longitude grid values
    lat, long = ds_grb.message(2).data()[1:]

    # extracting data: one 2-D field per quantile level
    precip_shape = lat.shape
    precip = np.zeros(shape=(N_LEVELS,) + precip_shape)
    for i in range(N_LEVELS):
        # ACPC:surface accumulated precipitation; the accumulation window depends on forecast_time
        precip[i, :, :] = ds_grb.message(i + 2).data()[0]
finally:
    # close the file even if a message read fails
    ds_grb.close()

In [None]:
# URMA 6-hour precipitation analysis used as the verifying observation.
# NOTE(review): the file is hard-coded to the 2022-06-26 case, so it only matches month == 'june'.
import warnings

if month != 'june':
    warnings.warn(f"Observation file is hard-coded for the 'june' case but month={month!r}")
fn_grb = 'data/urma2p5.2022062612.pcp_06h.wexp.grb2'
ds_grb = pygrib.open(fn_grb)
try:
    obs = ds_grb.message(1).data()[0]
finally:
    # close the file even if the read fails
    ds_grb.close()

In [None]:
# Optionally read the 2 m max-temperature quantile fields from the blend forecast file.
include_temp = False
if include_temp:
    # Rebuild the blend file name explicitly: `fn_grb` is reassigned to the URMA observation
    # file before this cell, so reopening `fn_grb` would read the wrong file.
    fn_blend = 'data/blend' + date + '.t' + lead_time + 'z.qmd.f' + forecast_time + '.co.grib2'
    ds_grb = pygrib.open(fn_blend)
    try:
        temp_shape = lat.shape
        temp = np.zeros(shape=(99,) + temp_shape)
        for i in range(99):
            temp[i, :, :] = ds_grb.message(i + 215).data()[0]  # TMP:2 m above ground:0-18 hour max
    finally:
        ds_grb.close()

In [None]:
# Masking precip at grid points whose quantiles are not monotonic: if any level is strictly
# greater than the next one, every level at that grid point is masked.
# Read the raw values so that re-running this cell on the already-masked `precip` gives the same mask.
precip_values = np.ma.getdata(precip)
nonmonotonic = np.any(precip_values[:-1, :, :] > precip_values[1:, :, :], axis=0)
mask = np.zeros(precip.shape)
mask[:, nonmonotonic] = 1.0

precip = np.ma.masked_array(precip, mask)

In [None]:
# Keep a reference to the raw observations, then hide the grid points flagged as non-monotonic
# in the forecast (the mask is the same at every level, so the last level's slice is used).
obs_unmasked, obs = obs, np.ma.masked_array(obs, mask[-1, :, :])

In [None]:
# linear spline functions

def linear_splines_unif(data, num_knots=10, zero_inflated=True):
    '''
    Approximate a quantile function by a piecewise-linear spline through uniformly spaced knots.

    Parameters
    ----------
    data : 1-D array of quantile values; index k holds quantile level k + 1.
        Assumed non-decreasing.
    num_knots : number of knots (duplicates created by integer rounding are dropped).
    zero_inflated : if True and the CDF starts with a run of zeros, pin one knot at index 0
        and spread the remaining num_knots - 1 knots from two indices before the first
        positive value up to the last index, so the knots are spent on the nonzero part.

    Returns
    -------
    ndarray with len(data) entries: the spline evaluated at every level. All zeros when the
    last quantile is 0 (for non-decreasing data, the whole CDF is zero).
    '''
    n_levels = len(data)
    if n_levels == 0:
        return np.zeros(0)
    last = n_levels - 1

    # checking if cdf is all zero
    if data[-1] == 0:
        return np.zeros(n_levels)

    if zero_inflated:
        # index just before the first positive value (all-zero CDFs were handled above)
        knot_ = np.where(data > 0)[0].min() - 1
        if knot_ > 1:
            knots = np.unique(np.linspace(knot_ - 1, last, num_knots - 1, dtype=int))
            knots = np.insert(knots, 0, 0)
        else:
            knots = np.unique(np.linspace(0, last, num_knots, dtype=int))
    else:
        knots = np.unique(np.linspace(0, last, num_knots, dtype=int))

    # knots are 0-based indices while levels are 1-based, hence knots + 1
    levels = np.arange(1, n_levels + 1)
    return np.interp(levels, knots + 1, data[knots])

def linear_splines(x, num_knots, *params):
    '''
    Piecewise-linear model evaluated at x, for use with scipy.optimize.curve_fit in
    linear_splines_var.

    params[0] is the flat parameter vector: its first num_knots entries are the knot values
    and the remaining entries are the knot positions. The knots are sorted by position
    before interpolating, because np.interp requires increasing sample points (it does not
    check) and the optimizer is free to move knots past one another.
    '''
    flat = np.asarray(params[0], dtype=float)
    knot_vals = flat[:num_knots]
    knots = flat[num_knots:]
    order = np.argsort(knots, kind='stable')
    return np.interp(x, knots[order], knot_vals[order])

def linear_splines_var(data, num_knots=5, zero_inflated=True):
    '''
    Approximate a quantile function by a piecewise-linear spline whose knot values AND knot
    positions are fitted with scipy.optimize.curve_fit.

    Parameters
    ----------
    data : 1-D array of quantile values; index k holds quantile level k + 1.
    num_knots : number of fitted knots (2 * num_knots free parameters).
    zero_inflated : only used by the fallback below.

    Returns
    -------
    ndarray with len(data) entries: the fitted spline evaluated at levels 1..len(data).
    All zeros when the last quantile is 0. If curve_fit raises RuntimeError (no convergence),
    falls back to linear_splines_unif with 2 * num_knots uniform knots (same parameter count).
    '''
    n_levels = len(data)

    # checking if cdf is all zero
    if data[-1] == 0:
        return np.zeros(n_levels)

    # the fit is done on the 1-based level axis
    levels = np.arange(1, n_levels + 1)

    # initial guess: uniformly spaced knots; the value stored at index k sits at level k + 1
    init_idx = np.linspace(0, n_levels - 1, num_knots).astype(int)
    p_0 = np.hstack([data[init_idx], init_idx + 1])

    try:
        fit, _ = optimize.curve_fit(lambda x, *params: linear_splines(x, num_knots, params),
                                    levels.astype(float), data, p_0)
    except RuntimeError:
        return linear_splines_unif(data, num_knots * 2, zero_inflated)

    knot_vals, knots = fit[:num_knots], fit[num_knots:]
    # np.interp needs increasing knot positions; the optimizer may have swapped some
    order = np.argsort(knots, kind='stable')
    return np.interp(levels, knots[order], knot_vals[order])
    
# Probability of each of the 99 stored quantiles (0.01, 0.02, ..., 0.99), and the number of
# evenly spaced evaluation points used when integrating CDF differences.
qs = np.linspace(0.01, 0.99, num=99)
N = 1_000

def calc_errors(orig, approx, levels=None, n_points=None):
    '''
    Compare two quantile functions through the CDFs they imply.

    Each quantile array is turned into a CDF by linear interpolation against `levels`
    (default: the module-level `qs`). The CDFs are compared at `n_points` (default: the
    module-level `N`) evenly spaced values spanning [orig.min(), orig.max()]; integrals
    are right Riemann sums over that grid.

    Returns
    -------
    [KS, L1, L2sq] where
      KS   = max |F_orig - F_approx|
      L1   = integral of |F_orig - F_approx| dx
      L2sq = integral of (F_orig - F_approx)^2 dx
    '''
    if levels is None:
        levels = qs
    if n_points is None:
        n_points = N
    xs = np.linspace(orig.min(), orig.max(), n_points)
    dx = xs[1:] - xs[:-1]
    differences = np.abs(np.interp(xs[1:], orig, levels) - np.interp(xs[1:], approx, levels))
    differences_weighted = differences * dx
    return [differences.max(), np.sum(differences_weighted), np.sum(differences * differences_weighted)]

def obs_CRPS(obs, approx, levels=None, n_points=None):
    '''
    CRPS of a quantile forecast against a scalar observation:
    integral of (F(x) - 1{x >= obs})^2 dx.

    F is obtained by linear interpolation of the quantile values `approx` against `levels`
    (default: the module-level `qs`); np.interp holds F constant at levels[0] / levels[-1]
    outside the forecast range. The integral is a right Riemann sum on `n_points` (default:
    the module-level `N`) evenly spaced points.
    '''
    if levels is None:
        levels = qs
    if n_points is None:
        n_points = N
    # Integrate over the forecast range extended to include the observation. When the
    # observation lies outside the forecast range, the forecast CDF and the observation step
    # disagree on the gap between them; truncating at the forecast range dropped that part
    # (e.g. an all-zero forecast against positive observed rain scored 0).
    lo = min(approx.min(), obs)
    hi = max(approx.max(), obs)
    xs = np.linspace(lo, hi, n_points)
    obs_cdf = (xs >= obs).astype(float)
    forecast_cdf = np.interp(xs[1:], approx, levels)
    return np.sum((forecast_cdf - obs_cdf[1:]) ** 2 * (xs[1:] - xs[:-1]))

In [None]:
# Spline approximations of every grid point's quantile function are expensive (one fit per
# point), so they are cached under results/ and loaded from there by default.
# Set RECOMPUTE = True to regenerate the cache from `precip` for the current month.
RECOMPUTE = False
approx_type = 'both'  # 'unif', 'var', or 'both': which approximations to regenerate
error_calc = True     # also compute calc_errors() for each regenerated approximation

# Only these quantile levels (in %) of each approximation are written to disk, in this order.
# The zero-based index of level L is L - 1 (so the median is index 49, not 50).
saved_levels = [5, 25, 50, 75, 95]
level_idxs = [lvl - 1 for lvl in saved_levels]


def _approximate_field(approx_fn, with_errors):
    '''Apply approx_fn at every unmasked grid point with a nonzero CDF; return masked (approx, errors).'''
    approx = np.zeros(shape=(99,) + lat.shape)
    errors = np.zeros((3,) + lat.shape)
    point_mask = mask[-1, :, :]
    for i, j in zip(*np.where(point_mask == 0)):
        if precip[-1, i, j] != 0:
            approx[:, i, j] = approx_fn(data=precip[:, i, j], num_knots=10, zero_inflated=True)
            if with_errors:
                errors[:, i, j] = calc_errors(precip[:, i, j], approx[:, i, j])
    approx = np.ma.masked_array(approx, mask)
    errors = np.ma.masked_array(errors, [point_mask, point_mask, point_mask])
    return approx, errors


if RECOMPUTE:
    approx_fns = {'unif': linear_splines_unif, 'var': linear_splines_var}
    if approx_type not in ('unif', 'var', 'both'):
        raise ValueError(f"approx_type must be 'unif', 'var', or 'both', not {approx_type!r}")
    kinds = ['unif', 'var'] if approx_type == 'both' else [approx_type]
    for kind in kinds:
        approx, errors = _approximate_field(approx_fns[kind], error_calc)
        approx[level_idxs, :, :].dump('results/precip_' + kind + '_' + month)
        if error_calc:
            errors.dump('results/errors_' + kind + '_' + month)

# The cache files are pickles (allow_pickle=True): only load files this notebook wrote itself.
precip_unif = np.load('results/precip_unif_' + month, allow_pickle=True)
precip_var = np.load('results/precip_var_' + month, allow_pickle=True)
errors_unif = np.load('results/errors_unif_' + month, allow_pickle=True)
errors_var = np.load('results/errors_var_' + month, allow_pickle=True)

In [None]:
# CRPS of the forecast quantiles against the observation at every unmasked grid point.
# Points masked as non-monotonic are masked in the result as well, instead of being left
# at 0 (which would read as a perfect score on the map).
obs_crps = np.zeros(lat.shape)
point_mask = mask[-1, :, :]
idx = np.where(point_mask == 0)
for i, j in zip(*idx):
    obs_crps[i, j] = obs_CRPS(obs[i, j], precip[:, i, j])
obs_crps = np.ma.masked_array(obs_crps, point_mask)

In [None]:
# graphing using basemap

def Basemap_plot(data, long, lat, diff=False, name=None, color_label='mm of precipitation', boundary=None):
    '''
    Plot a 2-D field on a Lambert conformal map covering the contiguous US, then show it.

    Parameters
    ----------
    data : 2-D (possibly masked) array on the same grid as `long`/`lat`.
    long, lat : 2-D longitude/latitude arrays of the grid, passed to the Basemap projection.
    diff : if True, draw a pcolormesh with the diverging 'seismic' colormap over
        [-boundary, boundary]; otherwise draw 16 filled contours.
    name : optional figure title.
    color_label : colorbar label.
    boundary : colour limit for diff plots; defaults to ceil(max |data|), at least 1.
    '''
    m = Basemap(llcrnrlon=-123., llcrnrlat=20., urcrnrlon=-59., urcrnrlat=48., projection='lcc',
                lat_1=38.5, lat_0=38.5, lon_0=-97.5, resolution='l')

    # draw coastlines, country boundaries, fill continents
    m.drawcoastlines(linewidth=0.25)
    m.drawcountries(linewidth=0.25)
    m.fillcontinents(color='xkcd:white', lake_color='xkcd:white')

    # draw the edge of the map projection region (the projection limb)
    m.drawmapboundary(fill_color='xkcd:white')
    m.drawstates()

    # draw lat/lon grid lines every 30 degrees.
    m.drawmeridians(np.arange(-180, 180, 30))
    m.drawparallels(np.arange(-90, 90, 30))

    x, y = m(long, lat)

    if diff:
        if boundary is None:
            # symmetric range around 0; at least 1 so an all-zero field still has a valid range
            boundary = max(1, int(np.ceil(max(np.abs(data.min()), np.abs(data.max())))))
        plt.pcolormesh(x, y, data, norm=colors.Normalize(vmin=-boundary, vmax=boundary), cmap='seismic')
    else:
        m.contourf(x, y, data, 16, linewidths=1.5)

    # Create exactly one colorbar; calling m.colorbar() twice draws two of them.
    cbar = m.colorbar()
    cbar.set_label(color_label)

    if name is not None:
        plt.title(name)

    plt.show()

In [None]:
# Compare the cached spline approximations with the full quantile field at one cached level,
# then map the per-point error metrics from calc_errors. Those metrics are computed over the
# whole CDF, so their titles do not refer to a single quantile level.
cached_levels = [5, 25, 50, 75, 95]  # quantile levels stored in the results/ files, in slot order
# NOTE(review): confirm the cache's median slot holds zero-based index 49 (the 50% level).
level = 50
slot = cached_levels.index(level)  # ValueError if `level` is not one of the cached levels
precip_level = precip[level - 1, :, :]

Basemap_plot(data=precip_level, long=long, lat=lat, name=f'Precipitation at {level}% quantile')
Basemap_plot(data=precip_unif[slot, :, :] - precip_level, long=long, lat=lat, diff=True,
             name=f'Uniform node error at {level}% quantile')
Basemap_plot(data=precip_var[slot, :, :] - precip_level, long=long, lat=lat, diff=True,
             name=f'Variable node error at {level}% quantile')

error_metrics = [('KS statistic', 'KS'), ('L_1 norm', 'L_1 norm'), ('CRPS', 'CRPS')]
for node_kind, errors in (('uniform', errors_unif), ('variable', errors_var)):
    for k, (metric_name, metric_label) in enumerate(error_metrics):
        Basemap_plot(data=errors[k, :, :], long=long, lat=lat,
                     name=f'{metric_name} for {node_kind} nodes', color_label=metric_label)

In [None]:
# Forecast quartile fields, then the observations (with and without the monotonicity mask)
# and the per-point CRPS of the forecast against the observation.
for q in (25, 50, 75):
    Basemap_plot(data=precip[q - 1, :, :], long=long, lat=lat, name=f'Precipitation at {q}% quantile')

summary_fields = (
    (obs, 'Observed (masked)'),
    (obs_unmasked, 'Observed (unmasked)'),
    (obs_crps[:, :], 'CRPS'),
)
for field, title in summary_fields:
    Basemap_plot(data=field, long=long, lat=lat, name=title)