In [None]:
import multiprocessing as mp

import matplotlib.pyplot as plt
import numpy as np
import pygrib
# NOTE(review): `quadrature` is deprecated in recent SciPy releases (removed in
# 1.15, to be confirmed against the installed version) - consider scipy.integrate.quad
from scipy.integrate import quadrature as quad
from scipy.stats import gamma

In [None]:
# read in grib file
fn_grb = 'blend.t00z.qmd.f012.co.grib2'
ds_grb = pygrib.open(fn_grb)
try:
    # latitude and longitude grid values: data() yields a 3-sequence here, of
    # which the last two entries are taken as the coordinate grids
    lat, long = ds_grb.message(2).data()[1:]

    # extracting precipitation levels for 6 hr forecast
    # NOTE(review): the file name says f012 - confirm the forecast window.
    # Messages 2..100 are read, one 2-D field per level, into a
    # (99, ny, nx) array so that axis 0 indexes the level.
    precip_shape = lat.shape
    precip_levels = np.zeros(shape=(99,)+precip_shape)
    for i in range(99):
        precip_levels[i,:,:] = ds_grb.message(i+2).data()[0]
finally:
    # everything needed has been copied into numpy arrays, so release the
    # file handle instead of leaving it open for the life of the kernel
    ds_grb.close()

In [None]:
def cdf(x, data):
    """Interpolated cumulative probability of ``x`` given 99 quantile values.

    ``data`` supplies the x-coordinates and 99 evenly spaced probabilities
    from 0.01 to 0.99 supply the y-coordinates of a piecewise-linear curve;
    ``x`` (scalar or array) is evaluated on that curve. Values of ``x``
    outside the range of ``data`` are clamped to 0.01 / 0.99 by ``np.interp``.

    ``data`` must have exactly 99 entries; it is presumably sorted in
    non-decreasing order, as ``np.interp`` expects - verify against callers.
    """
    probability_levels = np.linspace(0.01, 0.99, 99)
    return np.interp(x, data, probability_levels)

def approx_gamma_params(data):
    """Replace the nonzero part of a quantile vector by fitted gamma quantiles.

    The first nonzero entry of ``data`` is used as the gamma location
    ``loc``. Mean and variance of the excess over ``loc`` are obtained by
    integrating the survival function ``1 - cdf`` between ``loc`` and the
    largest value, and converted to gamma shape/scale by the method of
    moments (``shape = mean**2/var``, ``scale = var/mean``).

    Parameters
    ----------
    data : array-like of length 99
        Values at the probability levels used by ``cdf``. Presumably
        non-decreasing - verify against the caller.

    Returns
    -------
    numpy.ndarray
        99 values: one zero for every entry before the first nonzero one,
        followed by gamma quantiles at evenly spaced levels in [0.01, 0.99].
        All zeros when ``data`` has no nonzero entry; the constant ``loc``
        in place of the gamma quantiles when no positive mean/variance can
        be computed (e.g. a single nonzero value).
    """
    data = np.asarray(data)
    n_levels = 99  # must agree with the probability grid hard-coded in cdf()

    nonzero = np.flatnonzero(data)
    if nonzero.size == 0:
        # nothing to fit: every level is zero
        return np.zeros(n_levels)
    loc_idx = int(nonzero.min())
    loc = data[loc_idx]

    def survival(x):
        return 1-cdf(x, data)

    def excess_sq_integrand(x):
        # E[(X-loc)^2] = integral of 2*(x-loc)*(1-F(x)); the shift by loc keeps
        # the second moment consistent with the first one, which is also
        # measured from loc because the integration starts there
        return 2*(x-loc)*survival(x)

    # NOTE(review): the survival function is not renormalised by the
    # probability of a nonzero value, although the output levels below span
    # 0.01-0.99 over the nonzero part only - confirm this is intended.
    data_max = data.max()
    mom1, _ = quad(survival, loc, data_max)
    mom2, _ = quad(excess_sq_integrand, loc, data_max)
    mean = mom1
    var = mom2 - mom1**2

    output_zeros = np.zeros(loc_idx)
    if not (mean > 0 and var > 0):
        # degenerate nonzero part: a gamma cannot be fitted, and passing
        # NaN/inf parameters to gamma.ppf would silently produce NaNs
        return np.hstack([output_zeros, np.full(n_levels-loc_idx, loc)])

    # method of moments for a gamma: mean = shape*scale, var = shape*scale**2
    shape = mean**2/var
    scale = var/mean
    levels = np.linspace(0.01, 0.99, n_levels-loc_idx)
    output_gamma = gamma.ppf(levels, shape, loc=loc, scale=scale)
    return np.hstack([output_zeros, output_gamma])

In [None]:
# initializing output
level_width = 30  # NOTE(review): not used for sizing any more; kept in case other cells read it
# One slot per input level: wrap() stores the vector returned by
# approx_gamma_params along axis 0, and that vector has as many entries as
# precip_levels has levels. Sizing the axis as floor(100/level_width) = 3
# made that assignment fail with a shape mismatch.
precip_levels_approx_gamma = np.zeros_like(precip_levels)
# (row indices, column indices) of the grid points whose last level is nonzero
nonzero_idx = np.where(precip_levels[-1,:,:] != 0)

In [None]:
# wrapped function for parallel processing
def wrap(n):
    """Fit the n-th grid point listed in ``nonzero_idx``.

    Looks up the (i, j) position of entry ``n``, runs ``approx_gamma_params``
    on that point's level vector, stores the result in
    ``precip_levels_approx_gamma[:, i, j]`` of the calling process and also
    returns it, so that a parent process can collect results computed by
    pool workers.
    """
    i = nonzero_idx[0][n]
    j = nonzero_idx[1][n]
    fitted = approx_gamma_params(precip_levels[:,i,j])
    precip_levels_approx_gamma[:,i,j] = fitted
    return fitted

# parallel code using multiprocessing - doesn't seem to speed up code with 8 cores though!
# NOTE(review): worker processes need to be able to import/pickle `wrap`; this
# may not work from a notebook when the start method is "spawn" - confirm.
if __name__ == '__main__':
    n_cpu = mp.cpu_count()
    # leave 16 cores free on large machines, otherwise use every core
    n_proc = n_cpu-16 if n_cpu > 16 else n_cpu
    n_points = nonzero_idx[0].shape[0]
    with mp.Pool(processes=n_proc) as pool:
        # Each worker writes into its own copy of the output array, so the
        # fitted vectors are sent back and stored here in the parent process.
        # imap yields results in input order, re-raises worker exceptions and
        # avoids holding every result in memory at once; chunking keeps the
        # per-task communication overhead low.
        for n, fitted in enumerate(pool.imap(wrap, range(n_points), chunksize=1024)):
            precip_levels_approx_gamma[:, nonzero_idx[0][n], nonzero_idx[1][n]] = fitted

In [None]:
# write the fitted array to disk (np.save adds the ".npy" extension itself)
np.save(file='precip_levels_approx_gamma', arr=precip_levels_approx_gamma)

In [None]:
# Pick one of the grid points listed in nonzero_idx. The pick is seeded so the
# cell runs on a fresh kernel (n used to be defined only in a commented-out
# line) and always shows the same point.
n = int(np.random.default_rng(0).integers(nonzero_idx[0].shape[0]))
i = nonzero_idx[0][n]
j = nonzero_idx[1][n]
approx = approx_gamma_params(precip_levels[:,i,j])
# x axis: level number 1..99, one per entry of the level vectors
levels = np.linspace(1,99,99)
plt.plot(levels, precip_levels[:,i,j], c='xkcd:black', label='orig')
plt.plot(levels, approx, c='xkcd:red', label='approx')
plt.xlabel('percentile level')
plt.ylabel('precipitation')
plt.legend();

In [None]:
# show the fitted values from index 67 to the end for the plotted grid point
approx[67:]

In [None]:
# index of the lowest level whose value is nonzero at grid point (i, j)
np.flatnonzero(precip_levels[:,i,j]).min()

In [None]:
# level vector at grid point (i, j), the index of its first nonzero entry and
# the value found there (the same quantities approx_gamma_params derives)
data = precip_levels[:,i,j]
loc_idx = np.flatnonzero(data).min()
loc = data[loc_idx]
print(loc_idx, loc)

In [None]:
def cdf_(x):
    """Interpolated CDF of the module-level ``data`` vector, evaluated at ``x``."""
    cumulative = cdf(x, data)
    return cumulative
def cdf_shift(x):
    """Complement of the interpolated CDF, ``1 - cdf_(x)``."""
    below = cdf_(x)
    return 1-below
def x_cdf_shift(x):
    """Complement of the interpolated CDF weighted by ``2*x``."""
    tail = 1-cdf_(x)
    return 2*x*tail

In [None]:
# Upper integration bound: the largest value in the level vector, exactly as
# approx_gamma_params computes it. cdf_(.99) evaluated the CDF at x = 0.99 and
# therefore produced a probability, not a bound on the data axis.
data_max = data.max()
# sanity check: the bound together with the interpolated CDF evaluated at it
print(data_max, cdf_(data_max))

In [None]:
# integrate both integrands over [loc, data_max]; quad returns a pair and only
# its first element (the integral) is kept, the second is discarded into `_`
(mom1, _), (mom2, _) = [quad(integrand, loc, data_max) for integrand in (cdf_shift, x_cdf_shift)]
print(mom1, mom2)