In [None]:
import itertools
import os
import re

import pandas as pd
import xarray as xr
from multiprocess import Pool

# Load the global-warming-level (GWL) baseline periods. calculate_regional_mon_dev
# expects columns 'modelscenarioensemble' (lookup key), 'beg' and 'end'
# (first and last baseline year), with one row per key.
df_baseline = pd.read_csv(
    "../data/input/cmip6_wl_0p84_baselines_41years.csv",
)

def calculate_regional_mon_dev(filename, mon_tot_path="../../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/mon_tot_regrid", output_path="../../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/new_mon_dev_regrid", df_baseline_used = df_baseline):
    """Compute the yearly, climatology-weighted monthly precipitation deviation for one file.

    Implements the Kotz et al. (2022) monthly-deviation metric. For each month,
    the precipitation anomaly relative to the baseline monthly mean is divided
    by the baseline monthly standard deviation. It is then weighted by that
    month's share of the baseline annual total and summed over each calendar year.

    The baseline period is specific to each model/scenario/ensemble. It is looked
    up in ``df_baseline_used`` using the key found between ``_pr_`` and ``_g025``
    in ``filename``.

    Parameters
    ----------
    filename : str
        Name of the NetCDF file inside ``mon_tot_path``. It must contain a ``pr``
        variable with a ``time`` coordinate, plus the region attribute variables
        copied to the output (GID_0, NAME_0, GID_1, NAME_1, VARNAME_1,
        NL_NAME_1, TYPE_1, ENGTYPE_1, CC_1, HASC_1).
    mon_tot_path : str
        Directory holding the input files.
        NOTE(review): the default points at ``../../../sharepoint/.../mon_tot_regrid``,
        whereas the notebook's calls pass ``../../sharepoint/.../mon_totd_regrid``.
        Confirm which is current.
    output_path : str
        Directory the result is written to; it is created if missing. The output
        file name is ``filename`` with ``"mon_totd"`` replaced by ``"new_mon_dev"``.
    df_baseline_used : pandas.DataFrame
        Baseline table with columns ``modelscenarioensemble``, ``beg`` and ``end``
        (first and last baseline years).

    Returns
    -------
    None
        The result is written to disk. If the key cannot be parsed from
        ``filename``, or does not match exactly one baseline row, a warning is
        printed and nothing is written.
    """
    # Variables copied unchanged from the input dataset into the output dataset.
    region_vars = ("GID_0", "NAME_0", "GID_1", "NAME_1", "VARNAME_1",
                   "NL_NAME_1", "TYPE_1", "ENGTYPE_1", "CC_1", "HASC_1")

    # Extract the CMIP6 model/scenario/ensemble key from the filename.
    # Skip files that do not follow the naming pattern instead of crashing the worker.
    key_match = re.search(r'_pr_(.*?)_g025', filename)
    if key_match is None:
        print("Warning: could not extract a model/scenario key from "+filename+". Please inspect")
        return None
    pattern_match = key_match.group(1)
    baseline = df_baseline_used.loc[df_baseline_used['modelscenarioensemble'] == pattern_match]

    # If we get no match or multiple ones, print a warning and skip this file.
    if len(baseline) > 1:
        print("Warning: baseline contains more than one row corresponding to "+pattern_match+". Please inspect")
        return None

    if len(baseline) == 0:
        print("Warning: baseline contains zero rows corresponding to "+pattern_match+". Please inspect")
        return None

    # Year-only strings make .sel() use partial-date slicing. That includes the whole
    # first and last baseline years under any calendar; a hard-coded "-12-30" end
    # date would drop 31 December timestamps in standard calendars. int() guards
    # against the years having been read as floats (which would give "1979.0").
    baseline_beg = str(int(baseline["beg"].values[0]))
    baseline_end = str(int(baseline["end"].values[0]))

    os.makedirs(output_path, exist_ok=True)

    # The context manager closes the file. All reads and the final write stay
    # inside it because xarray reads data from the open file lazily.
    with xr.open_dataset(os.path.join(mon_tot_path, filename)) as ds:
        # Monthly mean and standard deviation over this model's baseline period.
        baseline_by_month = ds.pr.sel(time=slice(baseline_beg, baseline_end)).groupby("time.month")
        monthclim = baseline_by_month.mean()
        monthstd = baseline_by_month.std()
        anntotal = monthclim.sum('month')

        # Kotz et al. 2022: standardised monthly anomaly, weighted by the month's
        # share of the baseline annual total, then summed per calendar year.
        ratio = (ds.pr.groupby("time.month") - monthclim).groupby('time.month') / monthstd
        bymonth = ratio.groupby('time.month') * monthclim / anntotal
        byyear = bymonth.groupby('time.year').sum()

        ds2 = byyear.to_dataset()
        for var_name in region_vars:
            ds2[var_name] = ds[var_name]

        ds2.to_netcdf(os.path.join(output_path, filename.replace("mon_totd", "new_mon_dev")))


In [None]:
# Quick test on a single file (commented out; uncomment to run)
# calculate_regional_mon_dev(filename = os.listdir("../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/mon_totd_regrid")[1], mon_tot_path = "../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/mon_totd_regrid", output_path = "../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/new_mon_dev_regrid" )

In [None]:
# Compute the monthly deviation for every bias-corrected file in parallel.
# The `with` block shuts the worker pool down even if a worker raises. Because
# starmap blocks until every file is done, no explicit close()/join() is needed.
bc_mon_tot_dir = "../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/bc_mon_totd_regrid"
bc_mon_dev_dir = "../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/bc_new_mon_dev_regrid"
with Pool(processes = 36) as pool:
    pool.starmap(
        calculate_regional_mon_dev,
        zip(
            os.listdir(bc_mon_tot_dir),
            itertools.repeat(bc_mon_tot_dir),
            itertools.repeat(bc_mon_dev_dir),
            itertools.repeat(df_baseline),
        ),
    )

In [None]:
# Repeat for the raw (not bias-corrected) CMIP6 data.
# The `with` block shuts the worker pool down even if a worker raises. Because
# starmap blocks until every file is done, no explicit close()/join() is needed.
raw_mon_tot_dir = "../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/mon_totd_regrid"
raw_mon_dev_dir = "../../sharepoint/Data/Climate projections/cmip6_ng_adm1_new/new_mon_dev_regrid"
with Pool(processes = 36) as pool:
    pool.starmap(
        calculate_regional_mon_dev,
        zip(
            os.listdir(raw_mon_tot_dir),
            itertools.repeat(raw_mon_tot_dir),
            itertools.repeat(raw_mon_dev_dir),
            itertools.repeat(df_baseline),
        ),
    )