In [None]:
import numpy as np
import pygrib
from scipy import interpolate, optimize 

import matplotlib as mpl
import matplotlib.pyplot as plt 
import matplotlib.colors as colors
from mpl_toolkits.basemap import Basemap, cm

import multiprocessing as mp

In [None]:
# linear spline functions

def linear_splines_unif(data, num_knots=10, level_width=1, zero_inflated=True):
    '''
    Calculates piecewise linear splines for quantile data using specified number of
    knots uniformly spaced and returning interpolated approximation at every
    quantile level with level_width.

    Parameters
    ----------
    data : 1-D array of 99 values, data[k] being the value at quantile level k+1 (%).
        Assumed non-decreasing (a quantile function), so data[-1] is the maximum.
    num_knots : number of knots (duplicates from integer rounding are dropped).
    level_width : return the approximation at levels level_width, 2*level_width, ... <= 99.
    zero_inflated : if True and the leading levels are zero, place one knot at
        index 0 and spread the remaining knots over the non-zero part.

    Returns
    -------
    1-D float array of length 99 // level_width (also for an all-zero input).
    '''
    data = np.asarray(data)

    # quantile levels (in %) matching data[0..98]; defined for both branches
    levels = np.arange(1, 100)
    out_levels = levels[level_width-1::level_width]

    # checking if cdf is all zero; same length as the interpolated result
    if data[-1] == 0:
        return np.zeros(len(out_levels))

    if zero_inflated:
        # index just before the cdf starts being nonzero (all zero cdf's handled above)
        knot_ = np.where(data > 0)[0].min() - 1
        if knot_ > 1:
            knots = np.unique(np.linspace(knot_-1, 98, num_knots-1, dtype=int))
            knots = np.insert(knots, 0, 0)
        else:
            knots = np.unique(np.linspace(0, 98, num_knots, dtype=int))
    else:
        knots = np.unique(np.linspace(0, 98, num_knots, dtype=int))

    # knot index k sits at quantile level k+1
    return np.interp(out_levels, knots+1, data[knots])

def linear_splines(x, num_knots, *params):
    '''
    Function to be used in scipy.optimize.curve_fit in linear_splines_var function.

    Evaluates a piecewise linear spline at `x`. `params[0]` is a flat sequence
    holding num_knots knot values followed by num_knots knot positions
    (linear_splines_var passes the whole curve_fit parameter tuple as one argument).
    Knots are sorted by position before interpolating, because np.interp requires
    increasing x-coordinates and the optimizer may move knots past each other.
    '''
    flat = np.asarray(params[0], dtype=float)
    knot_vals = flat[:num_knots]
    knots = flat[num_knots:]
    order = np.argsort(knots, kind='stable')
    return np.interp(x, knots[order], knot_vals[order])

def linear_splines_var(data, num_knots=10, level_width=1, zero_inflated=True):
    '''
    Calculates piecewise linear splines for quantile data using specified number of
    knots with optimized placement and returning interpolated approximation at every
    quantile level with level_width.

    Knot values and positions are fitted jointly with scipy.optimize.curve_fit
    (model: linear_splines). If the fit does not converge (RuntimeError), falls
    back to linear_splines_unif with uniformly spaced knots.

    Parameters
    ----------
    data : 1-D array of 99 values, data[k] being the value at quantile level k+1 (%).
        Assumed non-decreasing, so data[-1] is the maximum.
    num_knots : number of knots to fit.
    level_width : return the approximation at levels level_width, 2*level_width, ... <= 99.
    zero_inflated : only forwarded to linear_splines_unif in the fallback.

    Returns
    -------
    1-D float array of length 99 // level_width (also for an all-zero input).
    '''
    data = np.asarray(data)

    # full set of quantile levels matching data[0..98], and the subset returned
    levels = np.linspace(1, 99, 99)
    out_levels = levels[level_width-1::level_width]

    # checking if cdf is all zero; same length as the fitted result
    if data[-1] == 0:
        return np.zeros(len(out_levels))

    # initial parameters: values at uniformly spaced indices, with knots placed at
    # the matching levels (index k <-> level k+1, as in linear_splines_unif)
    init_idx = np.linspace(0, 98, num_knots).astype(int)
    p_0 = np.hstack([data[init_idx], init_idx + 1])

    try:
        fit, _ = optimize.curve_fit(
            lambda x, *params: linear_splines(x, num_knots, params),
            levels, data, p_0)
    except RuntimeError:
        # curve_fit did not converge: use uniformly spaced knots instead
        return linear_splines_unif(data, num_knots, level_width, zero_inflated)

    # fitted knots may come back out of order; np.interp needs increasing xp
    knot_vals, knots = fit[:num_knots], fit[num_knots:]
    order = np.argsort(knots, kind='stable')
    return np.interp(out_levels, knots[order], knot_vals[order])

In [None]:
# read in grib file (alternative forecast files kept for reference)
fn_grb = 'data/blend20220324.t12z.qmd.f018.co.grib2'
# fn_grb = 'data/blend20220210.t12z.qmd.f018.co.grib2'
# fn_grb = 'data/blend.20220525t12z.qmd.f018.co.grib2'
ds_grb = pygrib.open(fn_grb)

# latitude and longitude grid values, taken from message 2 (data() -> values, lats, lons)
first_msg = ds_grb.message(2)
lat, long = first_msg.data()[1:]

# Messages 2..100 are stacked along axis 0 of `precip` (99 entries).
# Per the original note they are ACPC:surface:12-18 hour acc fcst.
# The commented temp lines read messages 215.. (TMP:2 m above ground:0-18 hour max).
precip_shape = lat.shape
precip = np.zeros(shape=(99,) + precip_shape)
# temp_shape = lat.shape
# temp = np.zeros(shape=(99,) + temp_shape)
for msg_offset in range(99):
    precip[msg_offset] = ds_grb.message(msg_offset + 2).data()[0]
    # temp[msg_offset] = ds_grb.message(msg_offset + 215).data()[0]

In [None]:
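# checking for monotonicity: quantile values should be non-decreasing in level.
# Left disabled. The `temp` check only works if the temp extraction in the
# reading cell is enabled. A vectorized way to find offending precip points:
# bad_i, bad_j = np.where((np.diff(precip, axis=0) < 0).any(axis=0))
# for i in range(lat.shape[0]):
#     for j in range(lat.shape[1]):
#         for level in range(99-1):
#             if precip[level,i,j] > precip[level+1,i,j]:
#                 print(f'Issue with precip at {i}, {j}.')
#             if temp[level,i,j] > temp[level+1,i,j]:
#                 print(f'Issue with temp at {i}, {j}.')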

In [None]:
# Uniform-knot spline approximation at every `level_width`-th quantile level.
level_width = 30
# linear_splines_unif evaluates levels level_width, 2*level_width, ... <= 99,
# i.e. 99 // level_width values. floor(100/level_width) over-allocates whenever
# level_width divides 100, and the per-point assignment would then fail.
n_levels_out = 99 // level_width
precip_unif = np.zeros(shape=(n_levels_out,) + lat.shape)
for i in range(precip_shape[0]):
    for j in range(precip_shape[1]):
        precip_unif[:, i, j] = linear_splines_unif(data=precip[:, i, j],
                                                   num_knots=10, level_width=level_width,
                                                   zero_inflated=True)
np.save('precip_unif_0324', precip_unif)
# precip_unif = np.load('precip_unif_????.npy')

In [None]:
# Variable-knot (optimized) spline approximation at every `level_width`-th level.
level_width = 30
# one output row per evaluated level: level_width, 2*level_width, ... <= 99
n_levels_out = 99 // level_width
n_points = precip_shape[0] * precip_shape[1]
# grid-point counts at which to report progress (the old `count` list was never
# incremented, so no progress was ever printed)
progress_marks = {100000, 500000, 1000000, 1500000, 2000000, 2500000, 3000000, 3500000}
precip_var = np.zeros(shape=(n_levels_out,) + lat.shape)
n_done = 0
for i in range(precip_shape[0]):
    for j in range(precip_shape[1]):
        precip_var[:, i, j] = linear_splines_var(precip[:, i, j], level_width=level_width)
        n_done += 1
        if n_done in progress_marks:
            print(n_done, n_points)
np.save('precip_var_0324', precip_var)
# precip_var = np.load('precip_var_????.npy')

In [None]:
# graphing using basemap

def Basemap_plot(data, long, lat, level, diff=False, unif=True, scale=1):
    '''
    Plot a gridded field over the continental US on a Lambert conformal Basemap.

    Parameters
    ----------
    data : 2-D array on the same grid as `long`/`lat`.
    long, lat : 2-D longitude / latitude arrays of the grid.
    level : quantile level in percent; only used in the plot title.
    diff : if True, `data` is drawn as an error field with the diverging
        'seismic' colormap on a range symmetric around zero; otherwise it is
        drawn as filled contours.
    unif : when diff is True, selects the "Uniform" or "Variable" node-error title.
    scale : half-width of the symmetric colour range for diff plots; 0 means use
        the field's largest absolute value, rounded up to an integer (at least 1).
    '''
    if diff:
        node_kind = 'Uniform' if unif else 'Variable'
        var_name = f'{node_kind} node error at {level}% level'
    else:
        var_name = f'Precipitation at probability level {level}%'

    # `bmap` rather than `map` to avoid shadowing the builtin
    bmap = Basemap(llcrnrlon=-123., llcrnrlat=20.,
                   urcrnrlon=-59., urcrnrlat=48.,
                   projection='lcc',
                   lat_1=38.5,
                   lat_0=38.5,
                   lon_0=-97.5,
                   resolution='l')

    # draw coastlines, country boundaries, fill continents
    bmap.drawcoastlines(linewidth=0.25)
    bmap.drawcountries(linewidth=0.25)
    bmap.fillcontinents(color='xkcd:white', lake_color='xkcd:white')

    # draw the edge of the map projection region (the projection limb)
    bmap.drawmapboundary(fill_color='xkcd:white')
    bmap.drawstates()

    # draw lat/lon grid lines every 30 degrees.
    bmap.drawmeridians(np.arange(-180, 180, 30))
    bmap.drawparallels(np.arange(-90, 90, 30))

    # project lon/lat into map coordinates
    x, y = bmap(long, lat)

    if diff:
        if scale == 0:
            bound = int(np.ceil(max(np.abs(data.min()), np.abs(data.max()))))
            # an all-zero field would otherwise give vmin == vmax == 0
            bound = max(bound, 1)
        else:
            bound = scale
        plt.pcolormesh(x, y, data,
                       norm=colors.Normalize(vmin=-bound, vmax=bound),
                       cmap='seismic', shading='nearest')
    else:
        bmap.contourf(x, y, data, 16, linewidths=1.5)

    # a single labelled colorbar for the field just drawn
    cbar = bmap.colorbar()
    cbar.set_label('mm of precipitation')

    plt.title(var_name)
    plt.show()

In [None]:
plot_level = 30  # percent; axis 0 of precip is indexed by level - 1
Basemap_plot(data=precip[plot_level - 1], long=long, lat=lat,
             level=plot_level, diff=False, scale=1)

In [None]:
# With level_width = 30, row 0 of precip_unif is the approximation at the 30% level;
# compare it against the raw 30% quantile field (precip index 29).
data = np.subtract(precip_unif[0], precip[29])
Basemap_plot(data, long, lat, 30, diff=True, scale=1)