In [None]:
import multiprocessing as mp

import matplotlib.pyplot as plt
import numpy as np
import pygrib
from scipy.integrate import quadrature as quad
from scipy.stats import gamma

In [None]:
# read in grib file
# Available forecast cases: month key -> (date string, cycle hour string).
forecast_cases = {
    'jan': ('20220111', '00'),
    'march': ('20220324', '12'),
    'april': ('20220425', '12'),
    'july': ('20220703', '00'),
    'sept': ('20210920', '12'),
    'oct': ('20211015', '00'),
    'dec': ('20211205', '12'),
}
month = 'jan'
if month not in forecast_cases:
    # Fail here with a clear message instead of printing and then hitting a
    # NameError on `date` when the file name is built.
    raise ValueError(
        f"Issue with month input. {month!r} is not one of {sorted(forecast_cases)}"
    )
date, lead_time = forecast_cases[month]
forecast_time = '018'  # identical for every case above
fn_grb = 'data/blend' + date + '.t' + lead_time + 'z.qmd.f' + forecast_time + '.co.grib2'

# Number of consecutive GRIB messages read, starting at message 2.
# NOTE(review): presumably one message per percentile level 1..99 — confirm
# against the file inventory.
n_levels = 99

ds_grb = pygrib.open(fn_grb)
try:
    # latitude and longitude grid values
    lat, long = ds_grb.message(2).data()[1:]

    # extracting data
    precip_shape = lat.shape
    precip = np.zeros(shape=(n_levels,) + precip_shape)
    for i in range(n_levels):
        precip[i, :, :] = ds_grb.message(i + 2).data()[0]  # ACPC:surface:12-18 hour acc fcst
finally:
    # Release the file handle even if one of the reads above raises.
    ds_grb.close()

In [None]:
def cdf(x, data):
    """Evaluate the piecewise-linear CDF defined by 99 quantile values.

    ``data`` holds the quantiles that belong to the probability levels
    0.01, 0.02, ..., 0.99. The probability at ``x`` is linearly interpolated
    between them; ``np.interp`` clamps the result to 0.01 below the first
    quantile and to 0.99 above the last one.
    """
    prob_levels = np.linspace(0.01, 0.99, 99)
    return np.interp(x, xp=data, fp=prob_levels)

def approx_gamma_params(data):
    """Replace a vector of quantiles by the quantiles of a moment-matched gamma.

    Parameters
    ----------
    data : 1-D array_like
        Quantile values in ascending order, assumed non-negative. Element k is
        taken to belong to probability level ``np.linspace(0.01, 0.99, n)[k]``
        (n = 99 in this notebook).

    Returns
    -------
    numpy.ndarray
        Array with the same length as ``data``. Entries before the first
        non-zero input value are zero; the remaining entries are gamma
        quantiles at ``np.linspace(0.01, 0.99, n - loc_idx)``. An all-zero
        input returns all zeros.

    Notes
    -----
    The CDF implied by ``data`` is the piecewise-linear interpolant through
    (data[k], level[k]), held at 0.01 on [0, data[0]] (this is what
    ``np.interp`` does in ``cdf``). With S = 1 - CDF the moments are
    mean = integral of S and E[X^2] = integral of 2*x*S, both taken over
    [0, data[-1]]. Because S is piecewise linear these integrals have an exact
    closed form, so no numerical quadrature (and no dependency on
    ``scipy.integrate.quadrature``, which newer SciPy releases no longer
    provide) is required.
    """
    data = np.asarray(data, dtype=float)
    n_levels = data.size

    nonzero = np.flatnonzero(data != 0)
    if nonzero.size == 0:
        # Nothing to fit: the original code failed here on `.min()` of an
        # empty index array.
        return np.zeros(n_levels)
    loc_idx = int(nonzero.min())

    # Knots of the survival function S = 1 - CDF, with x = 0 prepended so the
    # flat piece on [0, data[0]] is included. Tied knots (e.g. leading zeros)
    # produce zero-width segments that contribute nothing.
    probs = np.linspace(0.01, 0.99, n_levels)
    knots = np.concatenate(([0.0], data))
    surv = 1.0 - np.concatenate((probs[:1], probs))
    lo, hi = knots[:-1], knots[1:]
    s_lo, s_hi = surv[:-1], surv[1:]
    width = hi - lo

    # Exact integrals of a linear S over each segment [lo, hi].
    mom1 = np.sum(width * (s_lo + s_hi) / 2.0)
    mom2 = np.sum(width * (s_lo * (2.0 * lo + hi) + s_hi * (lo + 2.0 * hi)) / 3.0)

    mean = mom1
    var = mom2 - mom1**2
    shape = mean**2 / var
    # scipy's `scale` is theta = var / mean. The previous `mean / var` is the
    # rate (1 / theta); the two only coincide when theta == 1.
    scale = var / mean

    # NOTE(review): the moments describe the full distribution (zeros
    # included), while the quantile levels below are re-spread over the
    # non-zero entries only — kept as in the original, confirm this is intended.
    levels = np.linspace(0.01, 0.99, n_levels - loc_idx)
    output_zeros = np.zeros(loc_idx)
    output_gamma = gamma.ppf(levels, shape, scale=scale)
    return np.hstack([output_zeros, output_gamma])

In [None]:
# initializing output
# NOTE(review): level_width used to size the output as floor(100/level_width)
# = 3 layers, but approx_gamma_params returns one value per input level, so a
# 3-layer array cannot hold its result. The allocation below follows the input
# levels instead; level_width is kept only in case other cells read it —
# confirm and remove if unused.
level_width = 30
precip_levels_approx_gamma = np.zeros(shape=(precip.shape[0],) + lat.shape)
# Grid points whose last (highest) level is non-zero. `precip` is the array
# filled by the GRIB-loading cell; the previously referenced `precip_levels`
# is not defined anywhere in this notebook.
nonzero_idx = np.where(precip[-1, :, :] != 0)

In [None]:
# wrapped function for parallel processing
def wrap(n):
    """Fit the gamma approximation at the n-th grid point listed in nonzero_idx.

    The fitted column is stored in ``precip_levels_approx_gamma`` of the
    process that executes the call and is also returned. The return value is
    what makes the function usable from a worker pool: a worker process only
    modifies its own copy of the output array, so the parent has to receive
    the column and store it itself.
    """
    i = nonzero_idx[0][n]
    j = nonzero_idx[1][n]
    fitted = approx_gamma_params(precip[:, i, j])
    precip_levels_approx_gamma[:, i, j] = fitted
    return fitted

# parallel code using multiprocessing - doesn't seem to speed up code with 8 cores though!
# NOTE(review): the workers need to see `wrap` and the module-level arrays,
# which relies on how the platform starts worker processes — confirm on the
# target machine.
if __name__ == '__main__':
    n_points = nonzero_idx[0].shape[0]
    n_cpus = mp.cpu_count()
    # Same sizing rule as before: leave 16 cores free on machines that have
    # more than 16, otherwise use all of them (was `np.cpu_count()`, which
    # does not exist).
    n_workers = n_cpus - 16 if n_cpus > 16 else n_cpus
    # The context manager tears the pool down even if a worker raises.
    with mp.Pool(processes=n_workers) as pool:
        # imap yields results in input order and re-raises worker exceptions
        # here; map_async with a discarded result swallowed them. Chunking
        # cuts the per-task inter-process overhead.
        for n, fitted in enumerate(pool.imap(wrap, range(n_points), chunksize=256)):
            precip_levels_approx_gamma[:, nonzero_idx[0][n], nonzero_idx[1][n]] = fitted

In [None]:
# Reference case: exact gamma quantiles at the 99 probability levels, one
# column per shape parameter in `alpha`, all with rate `beta`.
alpha = [0.01, 0.1, 1.0, 2.0, 4.0]
beta = 1.0
scale = 1 / beta
q = np.linspace(start=0.01, stop=0.99, num=99)
data = np.column_stack(
    [gamma.ppf(q=q, a=shape_param, scale=scale) for shape_param in alpha]
)

In [None]:
# Gamma approximation of every reference column (same layout as `data`:
# one row per probability level, one column per alpha).
approx = np.column_stack(
    [approx_gamma_params(data[:, col]) for col in range(len(alpha))]
)

In [None]:
# Signed difference between approximation and reference, normalised per
# column by the larger of the two column maxima.
col_norms = np.array(
    [max(data[:, col].max(), approx[:, col].max()) for col in range(len(alpha))]
)
errors = (approx - data) / col_norms

In [None]:
# One error curve per alpha. Labels are derived from `alpha` so the plot
# follows the list if it changes; colours are reused cyclically when there
# are more curves than colours.
cs = ['xkcd:black', 'xkcd:blue', 'xkcd:red', 'xkcd:green', 'xkcd:orange']
labels = [f'alpha={a}' for a in alpha]
for i, label in enumerate(labels):
    plt.plot(q, errors[:, i], c=cs[i % len(cs)], label=label)

plt.xlabel('probability levels')
# `errors` holds signed differences divided by a per-column maximum, so the
# axis is labelled as a (signed) relative error rather than 'absolute relative'.
plt.ylabel('relative error (signed)')
plt.legend()
plt.show()