In [10]:
import warnings
import numpy as np
import xarray as xr
from numba import jit
import proplot as pplt
# proplot rc setting; 'reso' presumably controls the resolution of geographic features
# (coastlines etc.) -- TODO confirm it is still needed, since the figures below are not maps.
pplt.rc.reso='hi'
# Silences ALL Python warnings for the rest of the kernel session (numerical RuntimeWarnings,
# deprecations, ...). NOTE(review): consider narrowing this to specific categories/modules.
warnings.filterwarnings('ignore')

In [11]:
# Processed input data directory and figure output directory.
FILEDIR = '/global/cfs/cdirs/m4334/sferrett/monsoon-pod/data/processed'
SAVEDIR = '/global/cfs/cdirs/m4334/sferrett/monsoon-pod/figs'

# Lat/lon bounds (degrees) of each study region. Insertion order matters: calc_bin_mean_pr
# pairs the first three regions with MONTHPAIRS[0] and the remaining ones with MONTHPAIRS[1].
REGIONS = dict([
    ('Eastern Arabian Sea',     dict(latmin=9.0,  latmax=19.5, lonmin=64.0, lonmax=72.0)),
    ('Central India',           dict(latmin=18.0, latmax=24.0, lonmin=76.0, lonmax=83.0)),
    ('Central Bay of Bengal',   dict(latmin=9.0,  latmax=14.5, lonmin=86.5, lonmax=90.0)),
    ('Equatorial Indian Ocean', dict(latmin=5.0,  latmax=10.0, lonmin=62.0, lonmax=67.5)),
    ('Konkan Coast',            dict(latmin=15.0, latmax=19.5, lonmin=69.0, lonmax=72.5)),
])

# Binning range (min/max) and bin width for B_L and the 'cape' / 'subsat' terms
# (presumably consumed by calc_binned_stats -- not shown in these cells).
BINPARAMS = dict(
    bl=dict(min=-0.6, max=0.1, width=0.0025),
    cape=dict(min=-70.0, max=20.0, width=1.0),
    subsat=dict(min=-20.0, max=70.0, width=1.0),
)

SAMPLETHRESH  = 50               # bins with fewer samples than this get a NaN mean
MONTHPAIRS    = [(6, 7), (7, 8)] # (June, July) and (July, August)
NITERATIONS   = 1000             # number of bootstrap iterations
YEARSINSAMPLE = 5                # forwarded to get_bootstrap_samples (presumably years per sample)

# CASES   = {
#     'JJ':[(6,'June','#D42028'),(7,'July','#F2C85E')],
#     'JA':[(7,'July','#F2C85E'),(8,'August','#5BA7DA')]}

In [12]:
def load(filename,filedir=FILEDIR):
    """Read ``filedir/filename`` fully into memory and return it as an xarray.Dataset.

    ``xr.load_dataset`` loads every variable and then closes the underlying file,
    so no file handle stays open for the lifetime of the kernel.
    """
    filepath = f'{filedir}/{filename}'
    return xr.load_dataset(filepath)

In [13]:
# Processed B_L-term files paired with IMERG or GPCP precipitation
# (the 'LOW_' file is presumably on a coarser grid -- confirm against the preprocessing).
hiresimergds = load(filename='ERA5_IMERG_pr_bl_terms.nc')
loresimergds = load(filename='LOW_ERA5_IMERG_pr_bl_terms.nc')
loresgpcpds = load(filename='ERA5_GPCP_pr_bl_terms.nc')

In [None]:
def get_region(data,key,regions=REGIONS):
    """Subset ``data`` to the lat/lon box of the region named ``key``.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Must have ``lat`` and ``lon`` coordinates.
    key : str
        Region name; must be a key of ``regions``.
    regions : dict
        Maps region name -> dict with 'latmin', 'latmax', 'lonmin', 'lonmax'.

    Returns
    -------
    Same type as ``data``, restricted to the box (label-based, bounds inclusive).
    Works for both ascending and descending ``lat``/``lon`` coordinates.
    """
    region = regions[key]

    def _bounds(coord,lo,hi):
        # A label slice follows the coordinate's stored order: on a descending axis
        # (e.g. latitude stored north-to-south) slice(lo, hi) silently selects nothing.
        values = np.asarray(coord)
        if values.size>1 and values[0]>values[-1]:
            return slice(hi,lo)
        return slice(lo,hi)

    return data.sel(lat=_bounds(data['lat'],region['latmin'],region['latmax']),
                    lon=_bounds(data['lon'],region['lonmin'],region['lonmax']))

def get_month(data,months):
    """Keep only the time steps of ``data`` whose calendar month is in ``months``.

    ``months`` may be a single month number or a list/tuple of month numbers.
    """
    if isinstance(months,(list,tuple)):
        selected = months
    else:
        selected = [months]
    calendarmonth = data.time.dt.month
    return data.sel(time=calendarmonth.isin(selected))

def get_bin_mean_pr(monthstats,bintype,samplethresh=SAMPLETHRESH):
    """Conditional-mean precipitation per bin from per-bin sample counts and sums.

    Parameters
    ----------
    monthstats : xarray.Dataset
        Binned statistics (e.g. from ``calc_binned_stats``). For ``bintype='1D'`` it
        must hold ``Q0`` (per-bin sample count, judging by its use with
        ``samplethresh``) and ``Q1`` on the ``bl`` coordinate; for ``bintype='2D'``
        it must hold ``P0``/``P1`` on the ``subsat`` x ``cape`` coordinates.
    bintype : {'1D', '2D'}
        Which set of binned statistics to use.
    samplethresh : int or float
        Bins with fewer than this many samples are set to NaN.

    Returns
    -------
    xarray.DataArray
        ``Q1/Q0`` on ``bl`` (1D) or ``P1/P0`` on ``(subsat, cape)`` (2D); empty or
        under-sampled bins are NaN. ``monthstats`` is not modified.

    Raises
    ------
    ValueError
        If ``bintype`` is neither '1D' nor '2D'.
    """
    if bintype=='1D':
        counts = monthstats.Q0.values
        totals = monthstats.Q1.values
        coords = {'bl':monthstats.bl.values}
    elif bintype=='2D':
        counts = monthstats.P0.values
        totals = monthstats.P1.values
        coords = {'subsat':monthstats.subsat.values,'cape':monthstats.cape.values}
    else:
        raise ValueError(f"bintype must be '1D' or '2D', got {bintype!r}")
    # Float copies: ``.values`` can return the dataset's own array, so writing NaN into
    # it would mutate the caller's data (and fails outright for integer counts).
    counts = np.array(counts,dtype=float)
    totals = np.array(totals,dtype=float)
    with np.errstate(divide='ignore',invalid='ignore'):
        binmeanpr = totals/counts
    # Empty bins and bins below the sample threshold have no meaningful mean.
    binmeanpr[(counts==0.0)|(counts<samplethresh)] = np.nan
    return xr.DataArray(binmeanpr,coords=coords)

In [None]:
def calc_bin_mean_pr(data,bootstrap=False,regions=REGIONS,samplethresh=SAMPLETHRESH,
                     niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE):
    """Bin-mean precipitation as a function of B_L for every region and month.

    Each region is restricted to one pair of months: the first three regions (in
    ``regions`` order) use ``MONTHPAIRS[0]`` and the remaining ones ``MONTHPAIRS[1]``.

    Parameters
    ----------
    data : xarray.Dataset
        Gridded input accepted by ``get_region``, ``get_month`` and ``calc_binned_stats``.
    bootstrap : bool
        If True, compute one curve per sample returned by ``get_bootstrap_samples``.
    regions : dict
        Region name -> lat/lon bounds, as used by ``get_region``.
    samplethresh : int
        Bins with fewer samples are set to NaN (forwarded to ``get_bin_mean_pr``).
    niterations, yearsinsample : int
        Forwarded to ``get_bootstrap_samples`` when ``bootstrap`` is True.

    Returns
    -------
    xarray.DataArray
        Dims ``(region, month, bl)``, or ``(region, iteration, month, bl)`` when
        bootstrapping. ``month`` holds calendar month numbers (the union of all month
        pairs); a region's values for months outside its own pair are NaN.
    """
    def process_sample(sample,monthpair):
        # One curve per month of the pair, labelled with its calendar month number.
        curves = [get_bin_mean_pr(calc_binned_stats(get_month(sample,month)),bintype='1D',
                                  samplethresh=samplethresh)
                  for month in monthpair]
        return xr.concat(curves,dim='month').assign_coords(month=list(monthpair))

    def process_region(region,monthpair):
        regiondata = get_month(get_region(data,region,regions=regions),monthpair)
        if not bootstrap:
            return process_sample(regiondata,monthpair)
        samples = get_bootstrap_samples(regiondata,niterations,yearsinsample)
        samplecurves = [process_sample(sample,monthpair) for sample in samples]
        return xr.concat(samplecurves,dim='iteration').assign_coords(
            iteration=np.arange(len(samplecurves)))

    regioncurves = []
    for i,region in enumerate(regions):
        monthpair = MONTHPAIRS[0] if i<3 else MONTHPAIRS[1]
        regioncurves.append(process_region(region,monthpair))

    # The outer join on 'month' merges the different month pairs into one axis
    # (6, 7, 8 for the default MONTHPAIRS), filling uncomputed months with NaN.
    da = xr.concat(regioncurves,dim='region',join='outer')
    return da.assign_coords(region=list(regions))

def calc_confidence_intervals(actualvalues,bootstrapensemble,confidence=0.95):
    """Basic (pivotal) bootstrap confidence interval around ``actualvalues``.

    With ``delta = bootstrapensemble - actualvalues`` and quantiles taken over the
    'iteration' dimension at ``(1-confidence)/2`` (q_lo) and ``(1+confidence)/2`` (q_hi):

        cilower = actual - q_hi(delta)
        ciupper = actual - q_lo(delta)

    Returns ``(cilower, ciupper)``.
    """
    tail_lo = (1-confidence)/2
    tail_hi = (1+confidence)/2
    deviation = bootstrapensemble-actualvalues
    shift_for_lower = deviation.quantile(tail_hi,dim='iteration')
    shift_for_upper = deviation.quantile(tail_lo,dim='iteration')
    return actualvalues-shift_for_lower, actualvalues-shift_for_upper

def calc_bin_mean_pr_with_ci(data, niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE,regions=REGIONS,samplethresh=SAMPLETHRESH):
    """Bin-mean precipitation plus bootstrap confidence-interval error bars.

    Returns
    -------
    actualbinmeanpr : xarray.DataArray
        ``calc_bin_mean_pr`` on the full data.
    yerr : xarray.DataArray
        Error-bar lengths stacked along a new 'ci' dimension: position 0 is
        ``actual - cilower``, position 1 is ``ciupper - actual``.
    bootstrappedbinmeanpr : xarray.DataArray
        ``calc_bin_mean_pr`` for every bootstrap sample (extra 'iteration' dim).
    """
    # Everything is passed by keyword: the second positional parameter of
    # calc_bin_mean_pr is ``bootstrap``, so a positional ``regions`` would silently
    # turn bootstrapping on (and clash with an explicit ``bootstrap=True``).
    actualbinmeanpr = calc_bin_mean_pr(data,bootstrap=False,regions=regions,
                                       samplethresh=samplethresh)
    bootstrappedbinmeanpr = calc_bin_mean_pr(data,bootstrap=True,regions=regions,
                                             samplethresh=samplethresh,niterations=niterations,
                                             yearsinsample=yearsinsample)
    cilower,ciupper = calc_confidence_intervals(actualbinmeanpr,bootstrappedbinmeanpr)
    yerr = xr.concat([actualbinmeanpr-cilower,ciupper-actualbinmeanpr],dim='ci')
    return actualbinmeanpr,yerr,bootstrappedbinmeanpr

def plot_fig(data,yerr,datasetname,monthpairs=MONTHPAIRS,monthstyles=None):
    """Plot bin-mean precipitation against B_L with error bars, one panel per region.

    Parameters
    ----------
    data : xarray.DataArray
        Bin-mean precipitation with 'region', 'month' (calendar month numbers) and
        'bl' dimensions, e.g. the first output of ``calc_bin_mean_pr_with_ci``.
    yerr : xarray.DataArray
        Error-bar lengths with the same dims plus 'ci'; position 0 is the distance
        below ``data`` and position 1 the distance above.
    datasetname : str
        Figure super-title.
    monthpairs : sequence of month tuples
        Panels 0-2 show ``monthpairs[0]``, later panels ``monthpairs[1]``
        (the same split calc_bin_mean_pr uses).
    monthstyles : dict, optional
        Calendar month -> (legend label, color). Defaults to June/July/August styles.
    """
    if monthstyles is None:
        monthstyles = {6:('June','#D42028'),7:('July','#F2C85E'),8:('August','#5BA7DA')}
    nsplit = 3  # panels before this index use monthpairs[0]

    fig,axs = pplt.subplots(nrows=1,ncols=data.sizes['region'],share=True)
    axs.format(suptitle=f'{datasetname}',abcloc='l',abc='a)',
               xlabel=r'$\mathit{B_L}$ (m/s$^2$)',xlim=(-0.35,0.05),xticks=0.05,
               ylabel='Precipitation (mm/day)',ylim=(0,500),yticks=100)

    for i,region in enumerate(data.region.values):
        ax = axs[i]
        ax.format(titleloc='l',title=f'{region}')
        monthpair = monthpairs[0] if i<nsplit else monthpairs[1]
        for month in monthpair:
            label,color = monthstyles[month]
            monthdata = data.sel(region=region,month=month)
            monthyerr = yerr.sel(region=region,month=month)
            ax.scatter(monthdata.bl,monthdata,color=color,label=label,marker='.',s=20)
            # Select 'ci' by position (0 = below, 1 = above); it need not carry labels.
            ax.vlines(monthdata.bl,monthdata-monthyerr.isel(ci=0),monthdata+monthyerr.isel(ci=1),
                      color=color,alpha=0.5)
        # One legend per month pair, on the first panel that shows that pair.
        if i in (0,nsplit):
            ax.legend(loc='ul',ncols=1)
    pplt.show()

# Main execution: bin-mean precipitation with bootstrap error bars, one figure per dataset.
datasets = [
    ('ERA5/IMERG V06',hiresimergds),
    ('LOW-ERA5/IMERG V06',loresimergds),
    ('ERA5/GPCP',loresgpcpds)
]

for datasetname, data in datasets:
    actual,yerr,bootstrap = calc_bin_mean_pr_with_ci(data)
    plot_fig(actual,yerr,datasetname)


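# --- Legacy implementation, kept for reference only: separate actual/bootstrapped
# --- versions, superseded by calc_bin_mean_pr(..., bootstrap=...) above.
# def calc_actual_bin_mean_pr(data,regions=REGIONS,samplethresh=SAMPLETHRESH):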
#     regionnames   = []
#     binmeanprlist = []
#     blbins = None
#     for i,region in enumerate(regions):
#         regionnames.append(region)
#         monthpair = MONTHPAIRS[0] if i<3 else MONTHPAIRS[1]
#         regiondata = get_month(get_region(data,region),monthpair)
#         regionbinmeanpr = []
#         for month in monthpair:
#             monthdata  = get_month(regiondata,month)
#             monthstats = calc_binned_stats(monthdata)
#             binmeanpr  = get_bin_mean_pr(monthstats,bintype='1D')
#             regionbinmeanpr.append(binmeanpr)
#             if blbins is None:
#                 blbins = monthstats.bl.values
#         binmeanprlist.append(regionbinmeanpr)

# def calc_bootstrapped_bin_mean_pr(data, niterations=NITERATIONS, yearsinsample=YEARSINSAMPLE, regions=REGIONS, samplethresh=SAMPLETHRESH):
#     regionnames = []
#     binmeanprensemble = []
#     bl_bins = None
#     for i,region in enumerate(regions):
#         regionnames.append(region)
#         monthpair  = MONTHPAIRS[0] if i<3 else MONTHPAIRS[1]
#         regiondata = get_month(get_region(data,region),monthpair)
#         samples    = get_bootstrap_samples(regiondata,niterations,yearsinsample)
#         regionbinmeanprsamples = []
    #     for sample in samples:
    #         samplebinmeanpr = []
    #         for month in monthpair:
    #             monthdata  = get_month(sample,month)
    #             monthstats = calc_binned_stats(monthdata)
    #             binmeanpr  = get_bin_mean_pr(monthstats,bintype='1D')
    #             samplebinmeanpr.append(binmeanpr)
    #             if blbins is None:
    #                 blbins = monthstats.bl.values
    #         regionbinmeanprsamples.append(samplebinmeanpr)
    #     binmeanprensemble.append(regionbinmeanprsamples)
    
    # da = xr.DataArray(binmeanprensemble, 
    #                   dims=['region','iteration','month','bl'],
    #                   coords={'region':regionnames,
    #                           'iteration':range(niterations),
    #                           'month':['first','second'],
    #                           'bl':blbins})
    # da = da.rename({'bl':'blbin'})
    # return da

# def calc_confidence_intervals(actualvalues,bootstrapensemble,confidence=0.95):
#     actualdiff = bootstrapensemble-actualvalues
#     lowerpercentile = (1-confidence)/2
#     upperpercentile = 1-lowerpercentile
#     cilower = actualvalues-actualdiff.quantile(upperpercentile,dim='iteration')
#     ciupper = actualvalues-actualdiff.quantile(lowerpercentile,dim='iteration')
#     return cilower,ciupper

# def calc_bin_mean_pr_with_ci(data,niterations=NITERATIONS,yearsinsample=YEARSINSAMPLE,
#                              regions=REGIONS,samplethresh=SAMPLETHRESH):
#     actualbinmeanpr       = calc_actual_bin_mean_pr(data,regions,samplethresh)
#     bootstrappedbinmeanpr = calc_bootstrapped_bin_mean_pr(data,niterations,yearsinsample,regions,samplethresh)
#     cilower,ciupper = calc_confidence_intervals(actualbinmeanpr,bootstrappedbinmeanpr)
#     yerr = xr.concat([actualbinmeanpr-cilower,ciupper-actualbinmeanpr],dim='ci')
#     return actualbinmeanpr,yerr,bootstrappedbinmeanpr