## Import Necessary Packages

In [13]:
import warnings
import numpy as np
import xarray as xr
import proplot as pplt
# Set proplot's 'reso' rc option to 'hi'.
# NOTE(review): presumably this is the resolution of geographic map features;
# none of the visible cells draw map axes, so confirm it is still needed.
pplt.rc.reso='hi'
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides numerical RuntimeWarnings (e.g. NaN casts,
# divide-by-zero) raised by the analysis cells below.
warnings.filterwarnings('ignore')

## User-Defined Fields

In [22]:
# Analysis boxes keyed by display name. Each entry holds the latitude and
# longitude bounds that get_region() turns into .sel() slices; the keys are
# also used as the column labels of the figure.
REGIONS   = {
    'Eastern Arabian Sea':{'latmin':9.,'latmax':18.,'lonmin':67.,'lonmax':72.5}, 
    'Central India':{'latmin':18.,'latmax':24.,'lonmin':76.,'lonmax':83.},
    'Northern Bay of Bengal':{'latmin':18.,'latmax':21.5,'lonmin':87.5,'lonmax':90.},
    'Central Bay of Bengal':{'latmin':8.5,'latmax':15.,'lonmin':86.,'lonmax':90.}} 
# Binning parameters per variable: lower bound, upper bound and bin width,
# consumed by get_bin_edges() and get_binned_stats().
# NOTE(review): units are not stated here; the figure labels B_L in m/s^2 —
# confirm the units assumed for 'cape' and 'subsat'.
BINPARAMS = {
    'bl':{'min':-0.6,'max':0.1,'width':0.0025},
    'cape':{'min':-70.,'max':20.,'width':1.},
    'subsat':{'min':-20.,'max':70.,'width':1.}}
# Directory that open_dataset() prepends to the requested file name.
# NOTE(review): absolute, user-specific path — the notebook will not run
# elsewhere without editing this value.
FILEDIR   = '/ocean/projects/atm200007p/sferrett/Repos/monsoon-pr/data/processed'
# Precipitation threshold (described as mm/day in the output attributes) used
# both for the precipitating mask and for the "precipitating" bin counts.
PRTHRESH  = 0.25
# Default `numthresh` of get_bin_mean_pr().
# NOTE(review): the name suggests a minimum sample count per bin — confirm.
NUMTHRESH = 50
# One tuple per plotted series: (calendar month number, legend label, color
# name passed to scatter()).
CASES     = [(6,'June','red6'),(7,'July','green6'),(8,'August','blue6')]

## Import Data

In [23]:
def convert_common_units(ds):
    """Build a new dataset in which ``cape`` and ``subsat`` are rescaled.

    ``pr`` and ``bl`` are copied through untouched. ``cape`` becomes
    ``coeff*wb*cape`` and ``subsat`` becomes ``coeff*wl*subsat`` with
    ``coeff = 9.8/(3*340)``.
    NOTE(review): the constants look like g/(3*340 K) from a buoyancy
    formulation — confirm their meaning and units.

    Parameters
    ----------
    ds : xarray.Dataset
        Must expose ``pr``, ``bl``, ``cape``, ``subsat``, ``wb``, ``wl`` and
        the coordinates ``time``, ``lat``, ``lon``. Every array is relabelled
        as (time, lat, lon) without checking its original dimension order.

    Returns
    -------
    xarray.Dataset
        Holds only ``pr``, ``bl``, ``cape`` and ``subsat``; it is rebuilt from
        raw arrays, so variable and dataset attributes are not carried over.
    """
    scale = 9.8/(3*340)
    axes  = ['time','lat','lon']
    scaledcape   = scale*ds.wb.data*ds.cape.data
    scaledsubsat = scale*ds.wl.data*ds.subsat.data
    variables = {
        'pr':(axes,ds.pr.data),
        'bl':(axes,ds.bl.data),
        'cape':(axes,scaledcape),
        'subsat':(axes,scaledsubsat)}
    coordinates = {'time':ds.time.data,'lat':ds.lat.data,'lon':ds.lon.data}
    return xr.Dataset(data_vars=variables,coords=coordinates)

def apply_precipitating_mask(data,prthresh=PRTHRESH):
    """Keep only the points whose precipitation exceeds ``prthresh``.

    The mask is built from ``data.pr`` and handed to ``data.where``, so it is
    applied to the whole object, not just to ``pr``; points where ``pr`` is
    not strictly greater than the threshold are not kept.
    """
    israining = data.pr>prthresh
    return data.where(israining)

def open_dataset(filename,precipitating=False,convertunits=False,prthresh=PRTHRESH,filedir=FILEDIR):
    """Open one or more files as a single dataset, optionally masked/rescaled.

    Parameters
    ----------
    filename : str
        File name or glob pattern, joined to ``filedir`` with a ``/``.
    precipitating : bool
        If true, mask the dataset with ``apply_precipitating_mask``.
    convertunits : bool
        If true, pass the (possibly masked) dataset through
        ``convert_common_units``.
    prthresh : float
        Precipitation threshold forwarded to the precipitating mask.
    filedir : str
        Directory containing the files.

    Returns
    -------
    The dataset returned by ``xr.open_mfdataset`` after the requested steps.
    """
    filepath = f'{filedir}/{filename}'
    ds = xr.open_mfdataset(filepath)
    if precipitating:
        # Forward the caller's threshold; without this the argument would be
        # ignored and the mask would always use its own default.
        ds = apply_precipitating_mask(ds,prthresh=prthresh)
    if convertunits:
        ds = convert_common_units(ds)
    return ds

def get_region(data,key,regions=REGIONS):
    """Subset ``data`` to the latitude/longitude box stored under ``key``.

    The box bounds are read from ``regions[key]`` and applied as label slices
    on the ``lat`` and ``lon`` coordinates.
    NOTE(review): label slices assume both coordinates are stored in
    ascending order — confirm for the input files.
    """
    bounds   = regions[key]
    latrange = slice(bounds['latmin'],bounds['latmax'])
    lonrange = slice(bounds['lonmin'],bounds['lonmax'])
    return data.sel(lat=latrange,lon=lonrange)

def get_month(data,month):
    """Select the time steps of ``data`` that fall in calendar month ``month``.

    ``month`` is compared against ``data.time.dt.month`` and the resulting
    boolean mask is used to select along ``time``.
    """
    inmonth = (data.time.dt.month==month)
    return data.sel(time=inmonth)

In [24]:
# Open every file matching the OBS_bl-pr_* pattern in FILEDIR, keep only the
# points that pass the precipitating mask (default threshold PRTHRESH), and
# call .load() on the result. Unit conversion is not requested here.
data = open_dataset('OBS_bl-pr_*.nc',precipitating=True).load()

## Plot $\text{P}-B_L$ Relationship

In [53]:
def get_bin_edges(key,binparams=BINPARAMS):
    """Return the evenly spaced bin values for the variable ``key``.

    The values start at ``binparams[key]['min']``, advance by ``'width'`` and
    run up to ``'max'`` (or to the first value past it when the range is not
    a whole number of widths).

    The number of values is computed explicitly instead of relying on
    ``np.arange(min, max+width, width)``: with a fractional width the float
    rounding of the stop value can add or drop the last element.

    Parameters
    ----------
    key : str
        Variable name to look up in ``binparams``.
    binparams : dict
        Mapping of variable name to a dict with 'min', 'max' and 'width'.

    Returns
    -------
    numpy.ndarray
        1-D array of bin values.

    Raises
    ------
    ValueError
        If the configured width is zero.
    """
    spec = binparams[key]
    start,stop,width = spec['min'],spec['max'],spec['width']
    if width==0:
        raise ValueError(f"bin width for '{key}' must be non-zero")
    nsteps  = (stop-start)/width
    nearest = np.round(nsteps)
    # Snap to the nearest integer so that a range which is a whole number of
    # widths is not pushed over/under by floating-point error.
    if np.isclose(nsteps,nearest):
        nsteps = nearest
    nbins = int(np.ceil(nsteps))+1
    return start+width*np.arange(nbins)

def _bin_positions(values,params,offset):
    """Return the bin position of every value, truncated toward zero, as floats.

    The result stays floating point so that NaN inputs remain NaN instead of
    being cast to an arbitrary integer.
    """
    return np.trunc((np.ravel(values)-params['min'])/params['width']+offset)

def _in_range(positions,nbins):
    """Boolean mask of positions that index an existing bin (NaN gives False)."""
    return (positions>=0)&(positions<nbins)

def get_binned_stats(data,binparams=BINPARAMS,prthresh=PRTHRESH):
    """Accumulate precipitation statistics in B_L bins and in (subsat, cape) bins.

    For every point with finite ``pr`` whose bin index is in range, the
    function accumulates per bin: the number of points (Q0/P0), the number of
    points with ``pr > prthresh`` (QE/PE), the sum of ``pr`` (Q1/P1) and the
    sum of ``pr**2`` (Q2/P2). ``Q*`` are 1-D over ``bl``; ``P*`` are 2-D over
    ``(subsat, cape)``.

    NOTE(review): the ``bl`` index uses a +0.5 offset while ``cape`` and
    ``subsat`` use -0.5, and indices are truncated toward zero (so positions
    in (-1, 0) land in bin 0). Both are kept as found — confirm the intent.

    Parameters
    ----------
    data : xarray.Dataset
        Must provide ``pr``, ``bl``, ``cape`` and ``subsat`` arrays of
        identical shape.
    binparams : dict
        Mapping of variable name to a dict with 'min', 'max' and 'width'; used
        for both the bin coordinates and the bin indices.
    prthresh : float
        Threshold above which a point counts as precipitating.

    Returns
    -------
    xarray.Dataset
        Variables Q0, QE, Q1, Q2 on ``bl`` and P0, PE, P1, P2 on
        ``(subsat, cape)``.

    Raises
    ------
    ValueError
        If the four input arrays do not have the same number of elements.
    """
    # Build the coordinates from the same parameters that drive the indices.
    blbins      = get_bin_edges('bl',binparams)
    capebins    = get_bin_edges('cape',binparams)
    subsatbins  = get_bin_edges('subsat',binparams)
    nblbins     = blbins.size
    ncapebins   = capebins.size
    nsubsatbins = subsatbins.size
    pr        = np.ravel(data.pr.values)
    blpos     = _bin_positions(data.bl.values,binparams['bl'],0.5)
    capepos   = _bin_positions(data.cape.values,binparams['cape'],-0.5)
    subsatpos = _bin_positions(data.subsat.values,binparams['subsat'],-0.5)
    if not (pr.size==blpos.size==capepos.size==subsatpos.size):
        raise ValueError('pr, bl, cape and subsat must all have the same shape')
    validpr = np.isfinite(pr)
    # 1-D statistics binned by bl.
    use1d = validpr&_in_range(blpos,nblbins)
    blidx = blpos[use1d].astype(int)
    pr1d  = pr[use1d]
    Q0 = np.bincount(blidx,minlength=nblbins).astype(float)
    Q1 = np.bincount(blidx,weights=pr1d,minlength=nblbins)
    # Squared precipitation goes into Q2 (not Q1), mirroring P1/P2 below.
    Q2 = np.bincount(blidx,weights=pr1d**2,minlength=nblbins)
    QE = np.bincount(blidx[pr1d>prthresh],minlength=nblbins).astype(float)
    # 2-D statistics binned by (subsat, cape); the two indices are flattened
    # row-major so a single bincount can fill the whole grid.
    use2d   = validpr&_in_range(capepos,ncapebins)&_in_range(subsatpos,nsubsatbins)
    flatidx = subsatpos[use2d].astype(int)*ncapebins+capepos[use2d].astype(int)
    pr2d    = pr[use2d]
    shape2d = (nsubsatbins,ncapebins)
    ncells  = nsubsatbins*ncapebins
    P0 = np.bincount(flatidx,minlength=ncells).astype(float).reshape(shape2d)
    P1 = np.bincount(flatidx,weights=pr2d,minlength=ncells).reshape(shape2d)
    P2 = np.bincount(flatidx,weights=pr2d**2,minlength=ncells).reshape(shape2d)
    PE = np.bincount(flatidx[pr2d>prthresh],minlength=ncells).astype(float).reshape(shape2d)
    ds = xr.Dataset(data_vars={'Q0':(('bl'),Q0),'QE':(('bl'),QE),'Q1':(('bl'),Q1),'Q2':(('bl'),Q2),
                               'P0':(('subsat','cape'),P0),'PE':(('subsat','cape'),PE),'P1':(('subsat','cape'),P1),'P2':(('subsat','cape'),P2)},
                          coords={'subsat':subsatbins,'cape':capebins,'bl':blbins})
    ds.Q0.attrs = dict(long_name='Count of points in each bin')
    ds.QE.attrs = dict(long_name=f'Count of precipitating ( > {prthresh} mm/day) points in each bin')
    ds.Q1.attrs = dict(long_name='Sum of precipitation in each bin',units='mm/day')
    ds.Q2.attrs = dict(long_name='Sum of squared precipitation in each bin',units='mm²/day²')
    ds.P0.attrs = dict(long_name='Count of points in each bin')
    ds.PE.attrs = dict(long_name=f'Count of precipitating ( > {prthresh} mm/day) points in each bin')
    ds.P1.attrs = dict(long_name='Sum of precipitation in each bin',units='mm/day')
    ds.P2.attrs = dict(long_name='Sum of squared precipitation in each bin',units='mm²/day²')
    return ds

def get_bin_mean_pr(montthstats,numthresh=NUMTHRESH):
    """Return the mean precipitation in each ``bl`` bin.

    The mean is ``Q1/Q0`` (sum of precipitation over number of points). Bins
    with no points, or with fewer than ``numthresh`` points, are set to NaN.
    The input dataset is not modified.

    Parameters
    ----------
    montthstats : xarray.Dataset
        Binned statistics providing ``Q0`` (counts) and ``Q1`` (sums).
        The parameter spelling is kept as-is so existing keyword callers
        continue to work.
    numthresh : int
        Minimum number of points a bin needs for its mean to be reported.

    Returns
    -------
    xarray.DataArray
        Bin-mean precipitation on the ``bl`` coordinate.
    """
    counts = montthstats.Q0
    totals = montthstats.Q1
    # Use .where() instead of item assignment so the caller's Q0 is untouched;
    # empty bins become NaN rather than triggering a division by zero.
    populated = counts.where(counts>0)
    binmean   = totals/populated
    # The threshold is a sample-count criterion, so test the counts, not the
    # mean precipitation itself.
    return binmean.where(counts>=numthresh)

In [65]:
# One panel per region; inside each panel, one scatter series per entry of CASES.
fig,axs = pplt.subplots(nrows=1,ncols=len(REGIONS),share=False,span=False)
# Raw string: '\m' is not a valid Python escape, so the mathtext label must
# not go through escape processing.
axs.format(collabels=list(REGIONS),xlabel=r'$\mathit{B_L}$ (m/s$^2$)',ylabel='Precipitation (mm/day)')
for i,region in enumerate(REGIONS):
    regiondata = get_region(data,region)
    for month,label,color in CASES:
        monthdata  = get_month(regiondata,month)
        monthstats = get_binned_stats(monthdata)
        binmeanpr  = get_bin_mean_pr(monthstats)
        axs[i].scatter(monthstats.bl,binmeanpr,color=color,alpha=0.5,label=label)
# The legend is drawn once, on the first panel only.
axs[0].legend(loc='ul',ncols=1)
pplt.show()