In [1]:
import ee
#%matplotlib inline
import math, json
import warnings
#import intake
warnings.filterwarnings('ignore')
from collections import defaultdict
#import matplotlib
#import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import seaborn as sns
from scipy import stats
#import cftime
import xarray as xr
xr.set_options(display_style='html')
import s3fs
#import spei

import datetime, calendar

#plt.style.use("seaborn-darkgrid")
import coiled
#ee.Authenticate()
ee.Initialize()

In [2]:
#https://pynative.com/python-serialize-numpy-ndarray-into-json/
from json import JSONEncoder
class NumpyArrayEncoder(JSONEncoder):
    """JSON encoder that understands NumPy arrays and NumPy scalar types.

    ``np.ndarray`` values are emitted as (nested) lists and NumPy scalars
    (``np.float32``, ``np.int64``, ``np.bool_`` ...) as the matching Python
    scalar. Everything else is delegated to ``JSONEncoder.default``, which
    raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars are not JSON serializable on their own (np.float64 is
        # the exception because it subclasses float); .item() converts them.
        if isinstance(obj, np.generic):
            return obj.item()
        return JSONEncoder.default(self, obj)

In [34]:
# Year window used when fetching historical observations (get_histobs) and when
# slicing 'historical' model runs (Dataset, quarters, seasonal_means).
HIST_START = 1980
HIST_END = 2014

# Year window used by Dataset when the scenario is not 'historical'.
FUTURE_START = 2080
FUTURE_END = 2100

# Year window used by the ERA5 percentile / threshold helpers.
# NOTE(review): most of those helpers loop over range(START, END), which stops
# one year before PERCENTILE_ENDYEAR - confirm whether that is intended.
PERCENTILE_STARTYEAR = 1980
PERCENTILE_ENDYEAR = 2019

# Number of models Location keeps per city (at most one per model family).
NUM_BEST_MODELS = 3

In [4]:
# Build the city lookup from ghsl_500k.csv: the key is 'city_<col 0>' and the
# value is (float(col 2), float(col 3), int(col 0)). Downstream code unpacks
# each value as (lat, lon, loc_id).
CITYLATLON = {}
with open('ghsl_500k.csv', 'r') as ifile:
    for row in ifile:
        fields = [field.strip() for field in row.split(',')]
        city_key = 'city_{0}'.format(fields[0])
        CITYLATLON[city_key] = (float(fields[2]), float(fields[3]), int(fields[0]))

In [4]:
# Rebuild the city lookup from citiesonly.csv (this replaces any earlier
# CITYLATLON). The first line is skipped as a header; each remaining row maps
# col 0 -> (float(col 1), float(col 2), <4th element>), where the 4th element
# is the 0-based row number when the row has exactly three columns.
with open('citiesonly.csv', 'r', encoding='utf-8') as ifile:
    data_rows = ifile.readlines()[1:]
    records = [row.split(',') + [row_number] for row_number, row in enumerate(data_rows)]
    CITYLATLON = {rec[0]: (float(rec[1]), float(rec[2]), rec[3]) for rec in records}

In [5]:
# Candidate climate models considered for each variable; Location ranks these
# against observations and keeps the best NUM_BEST_MODELS.
MODELS = {
    'tasmax': ('GFDL-ESM4', 'CanESM5', 'MRI-ESM2-0', 'IPSL-CM6A-LR', 'EC-Earth3-Veg-LR'),
    'tasmin': ('GFDL-ESM4', 'IPSL-CM6A-LR', 'CanESM5', 'MRI-ESM2-0', 'EC-Earth3-Veg-LR'),
    'pr': ('GFDL-ESM4', 'IPSL-CM6A-LR', 'CanESM5', 'EC-Earth3-Veg-LR'),
    'hurs': ('GFDL-ESM4', 'CanESM5', 'MRI-ESM2-0', 'IPSL-CM6A-LR', 'EC-Earth3-Veg-LR'),
    'sfcWind': ('GFDL-ESM4', 'CanESM5', 'IPSL-CM6A-LR')
}

# Presumably the CMIP6 variant label of the runs being used - not referenced
# elsewhere in the visible notebook; TODO confirm it is still needed.
MODELRUN = 'r1i1p1f1'

# Presumably the CMIP6 grid label per model - not referenced elsewhere in the
# visible notebook; TODO confirm it is still needed.
MODELGRID = {
    'GFDL-ESM4': 'gr1',
    'CanESM5': 'gn',
    'MRI-ESM2-0': 'gn',
    'IPSL-CM6A-LR': 'gr',
    'EC-Earth3-Veg-LR': 'gr'
}

# Days per year in each model's calendar. get_var() strips the extra day from
# models marked 366 and returns 365-day models unchanged.
YEARLENGTH = {
    'GFDL-ESM4': 365,
    'CanESM5': 365,
    'MRI-ESM2-0': 366,
    'IPSL-CM6A-LR': 366,
    'EC-Earth3-Veg-LR': 366
}

In [6]:
# Model name -> family label. Location uses this to avoid selecting two
# "best" models that share the same family.
MODEL_FAMILY = {'UKESM1-0-LL': 'HadAM',
 'NorESM2-MM': 'CCM',
 'NorESM2-LM': 'CCM',
 'MRI-ESM2-0': 'UCLA GCM',
 'MPI-ESM1-2-LR': 'ECMWF',
 'MPI-ESM1-2-HR': 'ECMWF',
 'MIROC6': 'MIROC',
 'MIROC-ES2L': 'MIROC',
 'KIOST-ESM': 'GFDL',
 'KACE-1-0-G': 'HadAM',
 'IPSL-CM6A-LR': 'IPSL',
 'INM-CM5-0': 'INM',
 'INM-CM4-8': 'INM',
 'HadGEM3-GC31-MM': 'HadAM',
 'HadGEM3-GC31-LL': 'HadAM',
 'GFDL-ESM4': 'GFDL',
 'GFDL-CM4_gr2': 'GFDL',
 'GFDL-CM4': 'GFDL',
 'FGOALS-g3': 'CCM',
 'EC-Earth3-Veg-LR': 'ECMWF',
 'EC-Earth3': 'ECMWF',
 'CanESM5': 'CanAM',
 'CNRM-ESM2-1': 'ECMWF',
 'CNRM-CM6-1': 'ECMWF',
 'CMCC-ESM2': 'CCM',
 'CMCC-CM2-SR5': 'CCM',
 'BCC-CSM2-MR': 'CCM',
 'ACCESS-ESM1-5': 'HadAM',
 'ACCESS-CM2': 'HadAM',
 'TaiESM1': 'CCM',
}

In [7]:
# Offset between Kelvin and degrees Celsius (0 degC == 273.15 K).
_KELVIN_OFFSET = 273.15

# Per-variable settings, keyed by the model (NEX) variable name:
#   'era_varname'   - ERA5 band passed to get_eravar(); None when the observed
#                     series is derived from other bands (see get_histobs).
#   'nex_transform' - unit conversion applied to model data (see Dataset).
#   'era_transform' - unit conversion applied to ERA5 values.
VARIABLES = {
    'tasmax': {
        'era_varname': 'maximum_2m_air_temperature',
        # Kelvin -> degC (assumes both sources deliver Kelvin - TODO confirm)
        'nex_transform': lambda x: x - _KELVIN_OFFSET,
        'era_transform': lambda x: x - _KELVIN_OFFSET
    },
    'tasmin': {
        'era_varname': 'minimum_2m_air_temperature',
        'nex_transform': lambda x: x - _KELVIN_OFFSET,
        'era_transform': lambda x: x - _KELVIN_OFFSET
    },

    'pr': {
        'era_varname': 'total_precipitation',
        # presumably kg m-2 s-1 -> mm/day (86400 s per day) - TODO confirm
        'nex_transform': lambda x: x * 86400,
        # presumably m -> mm - TODO confirm
        'era_transform': lambda x: x * 1000
    },
    'hurs': {
        'era_varname': None,
        'nex_transform': lambda x: x,
        'era_transform': lambda x: x
    },
    'sfcWind': {
        'era_varname': None,
        # presumably m/s -> km/h - TODO confirm
        'nex_transform': lambda x: x * 3600 / 1000,
        'era_transform': lambda x: x * 3600 / 1000
    },
}

In [8]:
def returnperiod_value_daily(nex_varname, rp, latlon):
    """Interpolate on the empirical CDF of positive daily ERA5 values.

    Fetches one year of ERA5 data at a time for ``latlon``, applies the
    variable's ``era_transform``, keeps values above 0.01, builds the
    empirical CDF of the remaining values, and returns
    ``np.interp(1 - targetfreq, vals, cdf_y)``.

    Parameters:
        nex_varname: key into VARIABLES.
        rp: return period; used as the divisor of the record length.
        latlon: location passed through to get_eravar().

    NOTE(review): as written the final np.interp() call is evaluated with the
    data values as x-coordinates and the CDF as y-coordinates, so it returns a
    CDF level rather than a data value; ``targetfreq`` is also a count of
    years divided by ``rp`` rather than a probability. The function name
    suggests a data value was intended - confirm the formula before use.
    """
    era_varname = VARIABLES[nex_varname]['era_varname']
    hist_start = PERCENTILE_STARTYEAR  # not used below
    hist_end = PERCENTILE_ENDYEAR  # not used below
    allyears = []
    # range() stops before PERCENTILE_ENDYEAR, while targetfreq below counts
    # END - START + 1 years. NOTE(review): confirm which one is intended.
    for year in range(PERCENTILE_STARTYEAR, PERCENTILE_ENDYEAR):
        allyears.append(VARIABLES[nex_varname]['era_transform'](get_eravar(era_varname, latlon, start_year=year, end_year=year, southern_hem=False)))
    d = np.sort(np.concatenate(allyears).flatten())
    d = d[d > 0.01]  # Only consider actual positive events
    # Empirical CDF: cumulative share of samples at or below each unique value.
    vals, counts = np.unique(d, return_counts=True)
    freqs = counts / d.size
    cdf_y = np.cumsum(freqs)
    targetfreq = (PERCENTILE_ENDYEAR - PERCENTILE_STARTYEAR + 1) / rp
    return np.interp(1-targetfreq, vals, cdf_y)
    

def calendardate_percentiles(nex_varname, q, latlon, sh_hem=False):
    """Per-calendar-day percentile of ERA5 values across the percentile years.

    One year of ERA5 data is fetched per iteration (Jan-Dec window, leap day
    removed by get_eravar), converted with the variable's ``era_transform``,
    and stacked as one row per year; the percentile ``q`` is then taken down
    each day-of-year column.

    When ``sh_hem`` is true the result is rotated so it starts at index 152.

    NOTE(review): the loop stops one year before PERCENTILE_ENDYEAR - confirm
    that excluding the end year is intended.
    """
    settings = VARIABLES[nex_varname]
    era_varname = settings['era_varname']
    convert = settings['era_transform']

    yearly_rows = [
        convert(get_eravar(era_varname, latlon, start_year=year, end_year=year, southern_hem=False))
        for year in range(PERCENTILE_STARTYEAR, PERCENTILE_ENDYEAR)
    ]
    by_day = np.percentile(np.vstack(yearly_rows), q, axis=0)
    if sh_hem:
        return np.concatenate([by_day[152:], by_day[:152]])
    return by_day

def wholeyear_percentile(nex_varname, q, latlon):
    """Percentile ``q`` of all daily ERA5 values pooled over the percentile years.

    Data is fetched one year at a time, converted with the variable's
    ``era_transform``, concatenated into a single flat sample and passed to
    ``np.percentile``.

    NOTE(review): the loop stops one year before PERCENTILE_ENDYEAR - confirm
    that excluding the end year is intended.
    """
    settings = VARIABLES[nex_varname]
    era_varname = settings['era_varname']
    convert = settings['era_transform']
    first_year = PERCENTILE_STARTYEAR
    stop_year = PERCENTILE_ENDYEAR

    pooled = [
        convert(get_eravar(era_varname, latlon, start_year=year, end_year=year, southern_hem=False))
        for year in range(first_year, stop_year)
    ]
    return np.percentile(np.concatenate(pooled).flatten(), q)

def yearextreme_percentile(nex_varname, q, latlon, wantmax):
    """Percentile ``q`` of the per-year extreme ERA5 value.

    For each year the converted daily series is reduced to its maximum
    (``wantmax`` true) or minimum (``wantmax`` false); the percentile is then
    taken over those yearly extremes.

    NOTE(review): the loop stops one year before PERCENTILE_ENDYEAR - confirm
    that excluding the end year is intended.
    """
    settings = VARIABLES[nex_varname]
    era_varname = settings['era_varname']
    convert = settings['era_transform']
    # int(wantmax) indexes the pair: 0 -> np.min, 1 -> np.max.
    reducer = (np.min, np.max)[int(wantmax)]

    yearly_extremes = [
        reducer(convert(get_eravar(era_varname, latlon, start_year=year, end_year=year, southern_hem=False)))
        for year in range(PERCENTILE_STARTYEAR, PERCENTILE_ENDYEAR)
    ]
    return np.percentile(np.array(yearly_extremes), q)

def thresholdexceedance_mediancount(nex_varname, threshold, latlon, want_gte):
    """Median number of days per year on one side of ``threshold``.

    Fetches the ERA5 series for PERCENTILE_STARTYEAR..PERCENTILE_ENDYEAR in a
    single call, converts it with ``era_transform``, reshapes it to one row
    of 365 values per year, counts the days per row that are ``>= threshold``
    (``want_gte`` true) or ``<= threshold`` (``want_gte`` false), and returns
    the median of those counts.

    Raises:
        Exception: if the series length is not a multiple of 365.
    """
    settings = VARIABLES[nex_varname]
    raw = get_eravar(settings['era_varname'], latlon, start_year=PERCENTILE_STARTYEAR, end_year=PERCENTILE_ENDYEAR, southern_hem=False)
    series = settings['era_transform'](raw)

    n_years, leftover = divmod(series.size, 365)
    if leftover != 0:
        raise Exception('Data array length is not an integer multiple of 365')

    per_year = series.reshape(n_years, 365)
    hits = per_year >= threshold if want_gte else per_year <= threshold
    return np.median(np.sum(hits, axis=1))


def get_rmsd(d1, d2):
    """Root-mean-square difference between the seasonal means of two series.

    Both inputs are reduced with seasonal_means() (one mean per season per
    year) and the RMS of the element-wise differences is returned.

    The previous implementation wrapped ``np.sum`` in ``np.mean``; averaging
    a scalar is a no-op, so it returned the root of the *sum* of squares.
    For inputs of equal shape the two differ only by a constant factor, so
    any ranking based on this value is unchanged.
    """
    seasonal_diff = seasonal_means(d1) - seasonal_means(d2)
    return np.sqrt(np.mean(seasonal_diff ** 2))

def count_runs(tf_array, min_runsize):
    """Count runs of consecutive truthy values at least ``min_runsize`` long.

    Parameters:
        tf_array: 1-D sequence or array of booleans / 0-1 flags.
        min_runsize: minimum run length for a run to be counted.

    Returns:
        int: number of maximal runs of truthy entries whose length is
        ``>= min_runsize``. An empty input yields 0.
    """
    flags = np.asarray(tf_array, dtype=bool).astype(np.int8)
    # Pad with zeros so runs touching either end still produce a +1 (start)
    # and a -1 (end) in the first difference.
    edges = np.diff(np.concatenate(([0], flags, [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return int(np.count_nonzero(run_ends - run_starts >= min_runsize))

def removeLeapDays(arr, start_year, end_year, extralong=False, southern_hem=False):
    """Trim a daily series so that every year-long block has exactly 365 days.

    The series is walked in year-long blocks; the first 365 entries of each
    block are kept, and when the block spans 366 days its last entry is
    skipped (so the dropped day is the final day of the block, not Feb 29
    itself).

    Layout expected for ``arr`` (matching get_eravar / get_var):
        * default: Jan 1 ``start_year`` .. Dec 31 ``end_year``.
        * ``southern_hem``: Jul 1 ``start_year - 1`` .. Jun 30 ``end_year``.
        * ``extralong``: Jul 1 ``start_year - 1`` .. Dec 31 ``end_year``; the
          leading 184 days (Jul 1 - Dec 31) are kept as they are.

    ``extralong`` takes precedence over ``southern_hem``.

    Returns:
        ``arr`` indexed with the list of kept positions.
    """
    if extralong:
        indices = list(range(184))
        block_start = 184
        leap_blocks = [calendar.isleap(year) for year in range(start_year, end_year + 1)]
    elif not southern_hem:
        indices = []
        block_start = 0
        leap_blocks = [calendar.isleap(year) for year in range(start_year, end_year + 1)]
    else:
        indices = []
        block_start = 0
        # A Jul 1 (year) .. Jun 30 (year + 1) block contains Feb 29 of
        # year + 1, so the leap test must look at the *second* calendar year.
        leap_blocks = [calendar.isleap(year + 1) for year in range(start_year - 1, end_year)]

    for has_leap_day in leap_blocks:
        indices.extend(range(block_start, block_start + 365))
        block_start += 366 if has_leap_day else 365
    return arr[indices]

def get_eravar(varname, latlon, start_year, end_year, southern_hem=False, extralong=False):
    """Fetch one band of the Earth Engine ERA5 daily collection at a point.

    Parameters:
        varname: band name to select from "ECMWF/ERA5/DAILY".
        latlon: (lat, lon) pair; the point geometry is built as (lon, lat).
        start_year, end_year: years delimiting the date filter.
        southern_hem: use a Jul 1 (start_year - 1) .. Jul 1 (end_year) window.
        extralong: use a Jul 1 (start_year - 1) .. Jan 1 (end_year + 1)
            window; takes precedence over ``southern_hem``.

    Returns:
        The values passed through removeLeapDays() with the same year and
        window arguments.
    """
    collection = ee.ImageCollection("ECMWF/ERA5/DAILY")
    point = ee.Geometry.Point((latlon[1], latlon[0]))

    if extralong:
        window = ('{0}-07-01'.format(start_year-1), '{0}-01-01'.format(end_year+1))
    elif southern_hem:
        window = ('{0}-07-01'.format(start_year-1), '{0}-07-01'.format(end_year))
    else:
        window = ('{0}-01-01'.format(start_year), '{0}-01-01'.format(end_year+1))

    filtered = collection.select(varname).filter(ee.Filter.date(window[0], window[1]))
    table = filtered.getRegion(point, 2500, 'epsg:4326').getInfo()
    # First row is skipped as a header; column 4 is presumably the band value
    # (after id, longitude, latitude, time) - confirm against the EE docs.
    values = [row[4] for row in table[1:]]
    return removeLeapDays(np.array(values), start_year, end_year, extralong=extralong, southern_hem=southern_hem)

def get_var(varname, model, loc_id, start_year, end_year, southern_hem=False, extralong=False, scenario='ssp585'):
    """Return the model time series for one location from the global ``datasets``.

    The date window mirrors get_eravar(): Jan-Dec by default, Jul-Jun for
    ``southern_hem``, and Jul (start_year - 1) .. Dec (end_year) for
    ``extralong``. Models whose YEARLENGTH is 366 are passed through
    removeLeapDays(); all others are returned unchanged as an array.

    NOTE(review): ``varname`` and the resolved ``scenario`` are not used to
    pick the dataset - only ``datasets[model]`` is consulted.
    """
    if start_year < 2015:
        scenario = 'historical'  # resolved but not used further below
    source = datasets[model]

    if extralong:
        first_day, last_day = '{0}-07-01'.format(start_year-1), '{0}-12-31'.format(end_year)
    elif southern_hem:
        first_day, last_day = '{0}-07-01'.format(start_year-1), '{0}-06-30'.format(end_year)
    else:
        first_day, last_day = '{0}-01-01'.format(start_year), '{0}-12-31'.format(end_year)

    series = source.get_timeseries((first_day, last_day), loc_id)
    if YEARLENGTH[model] != 366:
        return np.array(series)
    return removeLeapDays(np.array(series), start_year, end_year, extralong=extralong, southern_hem=southern_hem)

def quarters(d, start_year, end_year, southern_hem=False):
    """Split a 365-day-per-year series into four seasonal blocks per year.

    The series is walked in steps of 365 starting at index 365. For each step
    (``anchor``) one row is appended per season:

        * MAM: offsets 60-151 after ``anchor``
        * JJA: offsets 152-243 after ``anchor``
        * SON: offsets 244-334 after ``anchor``
        * DJF: the first 60 values of the *preceding* 365-block joined with
          offsets 335-364 after ``anchor``

    The number of steps is ``end_year - start_year``, or one more when
    ``southern_hem`` is true; the slices themselves are the same in both
    cases.

    Returns:
        (mam, jja, son, djf): 2-D arrays with one row per step.
    """
    stop_year = end_year + 1 if southern_hem else end_year

    mam_rows, jja_rows, son_rows, djf_rows = [], [], [], []
    anchor = 365
    for _ in range(start_year, stop_year):
        early_part = d[anchor - 365 : anchor - 305]
        late_part = d[anchor + 335 : anchor + 365]
        djf_rows.append(np.concatenate((early_part, late_part), axis=0))
        mam_rows.append(d[anchor + 60 : anchor + 152])
        jja_rows.append(d[anchor + 152 : anchor + 244])
        son_rows.append(d[anchor + 244 : anchor + 335])
        anchor += 365

    mam_res = np.vstack(mam_rows)
    jja_res = np.vstack(jja_rows)
    son_res = np.vstack(son_rows)
    djf_res = np.vstack(djf_rows)
    return mam_res, jja_res, son_res, djf_res
    
def seasonal_means(d):
    """Per-year mean of each season over the HIST_START..HIST_END window.

    Returns an array with one row per season, in the order quarters()
    returns them (MAM, JJA, SON, DJF), and one column per year block.
    """
    seasons = quarters(d, HIST_START, HIST_END)
    return np.array([np.mean(seasons[k], axis=1) for k in range(4)])

def calibration_function(hist_obs, hist_mod):
    """Build a P-P style calibration curve from observed and modeled samples.

    Both samples are flattened and sorted. For the k-th smallest observed
    value the curve stores the position of the first modeled value that is
    ``>=`` it (clamped to the last modeled position), divided by the number
    of observed values. Only the first ``min(len(obs), len(mod))`` observed
    values are used.

    Special cases:
        * all observed values are 0 -> ``arange(len(mod)) / len(mod)``
        * all modeled values are 0  -> ``arange(len(obs)) / len(obs)``

    The position lookup uses ``np.searchsorted`` on the sorted modeled
    sample instead of a full scan per observed value. Results match the
    scan for NaN-free input; NaN handling is not defined.
    """
    source = np.sort(np.asarray(hist_obs).flatten())
    target = np.sort(np.asarray(hist_mod).flatten())

    if (np.max(source) == 0 and np.min(source) == 0):
        return np.arange(0, target.size) / target.size
    if (np.max(target) == 0 and np.min(target) == 0):
        return np.arange(0, source.size) / source.size

    paired_source = source[:target.size]
    # side='left' gives the first index whose modeled value is >= the observed
    # value; values above every modeled value map to the last index.
    positions = np.searchsorted(target, paired_source, side='left')
    positions = np.minimum(positions, target.size - 1)
    return positions / source.size

def calibrate_component(uncalibrated_data, calibration_fxn):
    """Remap each value to another value of the same sample via a rank curve.

    Values are ranked (ties broken by original position). For rank ``j`` out
    of ``N`` the quantile ``j / (N + 1)`` is looked up in ``calibration_fxn``
    (indexed by ``floor(quantile * len(calibration_fxn))``); the looked-up
    level, scaled by ``N + 1`` and floored, is the rank whose value replaces
    the original, clamped to the highest rank.

    Returns:
        list: replacement values, in the original order of the input.
    """
    n_values = len(uncalibrated_data)
    ranked = sorted((value, position) for position, value in enumerate(uncalibrated_data))
    highest_rank = len(ranked) - 1
    calibrated = [0] * n_values

    for rank in range(n_values):
        quantile = rank / (n_values + 1)
        curve_slot = math.floor(quantile * len(calibration_fxn))
        mapped_level = calibration_fxn[curve_slot]
        mapped_rank = math.floor(mapped_level * (n_values + 1))
        original_position = ranked[rank][1]
        calibrated[original_position] = ranked[min(highest_rank, mapped_rank)][0]
    return calibrated

def calibrate(uncalibrated_data, calibration_fxn):
    """Calibrate a 365-day-per-year series season by season.

    Each position is assigned to a season from ``index % 365``:
    60-151 -> MAM, 152-243 -> JJA, 244-334 -> SON, everything else -> DJF.
    The values of each season are calibrated together with
    calibrate_component() using ``calibration_fxn[0..3]`` (MAM, JJA, SON,
    DJF) and written back to their original positions.

    Returns:
        np.ndarray with the same length as ``uncalibrated_data``.
    """
    season_names = ('mam', 'jja', 'son', 'djf')
    positions = {season: [] for season in season_names}

    for idx in range(len(uncalibrated_data)):
        day_of_year = idx % 365
        if 60 <= day_of_year < 152:
            positions['mam'].append(idx)
        elif 152 <= day_of_year < 244:
            positions['jja'].append(idx)
        elif 244 <= day_of_year < 335:
            positions['son'].append(idx)
        else:
            positions['djf'].append(idx)

    result = [0] * len(uncalibrated_data)
    for curve_idx, season in enumerate(season_names):
        season_positions = positions[season]
        season_values = np.array([uncalibrated_data[pos] for pos in season_positions])
        season_calibrated = calibrate_component(season_values, calibration_fxn[curve_idx])
        for pos, value in zip(season_positions, season_calibrated):
            result[pos] = value

    return np.array(result)

def get_gamma(count, size):
    """Draw ``size`` samples from a gamma distribution with shape ``count + 0.5``.

    Uses NumPy's global random state and the default scale.
    """
    shape_param = count + 0.5
    samples = np.random.gamma(shape_param, size=size)
    return samples
def get_beta(count, num, size):
    """Draw ``size`` samples from Beta(count + 0.5, num - count + 0.5).

    Uses NumPy's global random state.
    """
    alpha = count + 0.5
    beta = num - count + 0.5
    samples = np.random.beta(alpha, beta, size=size)
    return samples

In [28]:
# Registry of dataset locations read from modelinfo.csv. Each row must hold
# exactly four comma-separated fields: model, scenario, variable name, URI.
MODEL_URI = {}
with open('modelinfo.csv', 'r') as ifile:
    for row in ifile:
        model, scenario, varname, the_uri = (field.strip() for field in row.split(','))
        MODEL_URI[(model, scenario, varname)] = the_uri
        
def uri(model, scenario, varname):
    """Return the URI registered in MODEL_URI for (model, scenario, varname).

    Raises KeyError when the combination is not registered.
    """
    lookup_key = (model, scenario, varname)
    return MODEL_URI[lookup_key]

def s3open(path):
    """Wrap ``path`` in an ``s3fs.S3Map`` backed by an anonymous S3 filesystem."""
    anonymous_fs = s3fs.S3FileSystem(anon=True)
    store = s3fs.S3Map(path, s3=anonymous_fs)
    return store


class Dataset:
    """In-memory slice of one model/scenario/variable for a batch of cities.

    On construction the zarr store registered for (model, scenario, varname)
    is opened, restricted to the historical or future year window, sampled at
    the nearest grid cells for the cities in CITYLATLON[start_idx:end_idx],
    converted with the variable's ``nex_transform`` and loaded as a NumPy
    array in ``self.data``. ``self.times`` holds the matching dates as
    'YYYY-MM-DD' strings.
    """

    def __init__(self, varname, model, scenario, start_idx, end_idx):
        self.varname = varname
        self.model = model
        self.scenario = scenario

        print('Extracting {0} {1} {2}'.format(varname, scenario, model))

        store = s3open(uri(model, scenario, varname))
        opened = xr.open_mfdataset([store], engine='zarr', parallel=True)

        # Time window depends on whether this is the historical run.
        if scenario == 'historical':
            first_day = '{0}-01-01'.format(HIST_START)
            last_day = '{0}-12-31'.format(HIST_END)
        else:
            first_day = '{0}-07-01'.format(FUTURE_START)
            last_day = '{0}-12-31'.format(FUTURE_END)

        batch = list(CITYLATLON.values())[start_idx:end_idx]
        lats = [entry[0] for entry in batch]
        # Negative longitudes are shifted by +360 before the nearest lookup.
        lons = [entry[1] + 360 if entry[1] < 0 else entry[1] for entry in batch]

        ds = opened[varname].sel(time=slice(first_day, last_day)).sel(lat=lats, lon=lons, method='nearest')

        self.data = VARIABLES[varname]['nex_transform'](ds.to_numpy())
        self.times = [str(d)[:10] for d in ds.time.data]

    def get_timeseries(self, dates, loc_id):
        """Return the series between two dates (inclusive) for one batch slot.

        ``dates`` is a (first, last) pair of 'YYYY-MM-DD' strings that must
        both be present in ``self.times``; ``loc_id`` indexes the second and
        third axes of ``self.data`` with the same value.
        """
        first_pos = self.times.index(dates[0])
        stop_pos = self.times.index(dates[1]) + 1
        return self.data[first_pos:stop_pos, loc_id, loc_id]

In [29]:
def get_histobs(varname, latlon):
    """Fetch the observed (ERA5) historical series for one variable at a point.

    NOTE(review): an identical ``get_histobs`` is defined again in a later
    cell of this notebook; the later definition is the one in effect after a
    top-to-bottom run.

    Parameters:
        varname: 'hurs', 'sfcWind', or any VARIABLES key with an
            'era_varname'.
        latlon: (lat, lon) pair passed through to get_eravar().

    Returns:
        * 'hurs': relative humidity (%) derived from the daily maximum
          temperature and the dewpoint temperature.
        * 'sfcWind': wind speed from the u and v 10 m components.
        * otherwise: the transformed ERA5 band for the variable.
        * None if the v wind component still fails after 10 retries.
    """
    def relhum(T, Tdp):
        # Ratio of two exponentials of the form exp(17.625 * t / (243.04 + t)),
        # as a percentage; presumably expects degC inputs - TODO confirm.
        T = T.astype(np.float64)
        Tdp = Tdp.astype(np.float64)
        numerator = np.exp(17.625 * Tdp / (243.04 + Tdp))
        denominator = np.exp(17.625 * T / (243.04 + T))
        return 100 * numerator / denominator

    def fetch(era_varname, transform_key):
        # One ERA5 band over the historical window, unit-converted.
        return VARIABLES[transform_key]['era_transform'](get_eravar(era_varname, latlon, HIST_START, HIST_END, southern_hem=False))

    if varname == 'hurs':
        era_dewpoint = fetch('dewpoint_2m_temperature', 'tasmax')
        era_maxtemp = fetch(VARIABLES['tasmax']['era_varname'], 'tasmax')
        return relhum(era_maxtemp, era_dewpoint)

    if varname == 'sfcWind':
        era_windspeedu = fetch('u_component_of_wind_10m', 'sfcWind')
        counter = 0
        while True:
            try:
                era_windspeedv = fetch('v_component_of_wind_10m', 'sfcWind')
                break
            except Exception:
                # Catch Exception rather than everything so Ctrl-C / kernel
                # interrupt can still stop the retry loop.
                if counter == 10:
                    return None
                print('trying again', counter)
                counter += 1
        return np.sqrt(np.power(era_windspeedu, 2) + np.power(era_windspeedv, 2))

    return fetch(VARIABLES[varname]['era_varname'], varname)

In [78]:
class Location:
    """Best-model selection and calibration curves for one city and variable.

    ``params`` is a single tuple
    ``(name, loc_id, latlon, hist_obs, hist_mods, varname)`` so instances can
    be built through ``client.map(Location, list_of_params)``.

    Attributes:
        name, loc_id, latlon: identifiers copied from ``params``.
        hist_observed: observed historical series.
        hist_modeled: dict model name -> modeled historical series.
        best_models: up to NUM_BEST_MODELS model names, lowest get_rmsd()
            first, with at most one model per MODEL_FAMILY entry.
        calib_fxns: dict model name -> list of four calibration curves
            (one per season, in the order quarters() returns them).
    """

    def __init__(self, params):
        name, loc_id, latlon, hist_obs, hist_mods, varname = params
        self.name = name
        self.loc_id = loc_id
        self.latlon = latlon
        self.hist_observed = hist_obs
        self.hist_modeled = hist_mods
        self.best_models = None
        self.calib_fxns = None

        # Rank every candidate model by its distance to the observations.
        ranked = sorted(
            (get_rmsd(hist_obs, hist_mods[model]), model) for model in MODELS[varname]
        )

        # Walk the ranking and keep the best model of each family. The walk
        # ends with the ranking, so fewer than NUM_BEST_MODELS models are
        # returned (instead of an IndexError) when the candidates do not
        # cover enough distinct families.
        best_models = []
        families = []
        for _, model in ranked:
            if len(best_models) >= NUM_BEST_MODELS:
                break
            family = MODEL_FAMILY[model]
            if family not in families:
                best_models.append(model)
                families.append(family)
        self.best_models = best_models

        # Calibration curves: observed vs. modeled values, season by season.
        self.calib_fxns = {}
        if best_models:
            # The observed split does not depend on the model; compute it once.
            obs_quarters = quarters(hist_obs, HIST_START, HIST_END)
            for model in best_models:
                mod_quarters = quarters(hist_mods[model], HIST_START, HIST_END)
                self.calib_fxns[model] = [
                    calibration_function(obs_quarters[i].flatten(), mod_quarters[i].flatten())
                    for i in range(4)
                ]


In [31]:
def do_city(lat, lon, loc_id, hist_obs):
    """Submit a Location build for one city and return its results.

    Relies on the notebook globals ``client``, ``cityname``, ``varname`` and
    ``datasets``.

    NOTE(review): this helper looks stale and is not called by the driver
    cell below:
      * Location.__init__ takes a single ``params`` tuple, but five separate
        positional arguments are submitted here (and ``varname`` is missing).
      * ``client.submit`` presumably returns a Future, so ``loc.best_models``
        would need the gathered result first - confirm against the Dask docs.
      * ``datasets`` is indexed by model name here but by (varname, model)
        in the driver cell.
    """
    loc = client.submit(Location, cityname, loc_id, (lat, lon), hist_obs[loc_id], {m: datasets[m].get_timeseries(('{0}-01-01'.format(HIST_START), '{0}-12-31'.format(HIST_END)), loc_id) for m in MODELS[varname]})
    return((loc.best_models, loc.calib_fxns))

In [32]:
def get_histobs(varname, latlon):
    """Fetch the observed (ERA5) historical series for one variable at a point.

    NOTE(review): this repeats the ``get_histobs`` defined in an earlier cell
    of this notebook; being the later definition, this is the one in effect
    after a top-to-bottom run. Consider keeping only one copy.

    Parameters:
        varname: 'hurs', 'sfcWind', or any VARIABLES key with an
            'era_varname'.
        latlon: (lat, lon) pair passed through to get_eravar().

    Returns:
        * 'hurs': relative humidity (%) derived from the daily maximum
          temperature and the dewpoint temperature.
        * 'sfcWind': wind speed from the u and v 10 m components.
        * otherwise: the transformed ERA5 band for the variable.
        * None if the v wind component still fails after 10 retries.
    """
    def relhum(T, Tdp):
        # Ratio of two exponentials of the form exp(17.625 * t / (243.04 + t)),
        # as a percentage; presumably expects degC inputs - TODO confirm.
        T = T.astype(np.float64)
        Tdp = Tdp.astype(np.float64)
        numerator = np.exp(17.625 * Tdp / (243.04 + Tdp))
        denominator = np.exp(17.625 * T / (243.04 + T))
        return 100 * numerator / denominator

    def fetch(era_varname, transform_key):
        # One ERA5 band over the historical window, unit-converted.
        return VARIABLES[transform_key]['era_transform'](get_eravar(era_varname, latlon, HIST_START, HIST_END, southern_hem=False))

    if varname == 'hurs':
        era_dewpoint = fetch('dewpoint_2m_temperature', 'tasmax')
        era_maxtemp = fetch(VARIABLES['tasmax']['era_varname'], 'tasmax')
        return relhum(era_maxtemp, era_dewpoint)

    if varname == 'sfcWind':
        era_windspeedu = fetch('u_component_of_wind_10m', 'sfcWind')
        counter = 0
        while True:
            try:
                era_windspeedv = fetch('v_component_of_wind_10m', 'sfcWind')
                break
            except Exception:
                # Catch Exception rather than everything so Ctrl-C / kernel
                # interrupt can still stop the retry loop.
                if counter == 10:
                    return None
                print('trying again', counter)
                counter += 1
        return np.sqrt(np.power(era_windspeedu, 2) + np.power(era_windspeedv, 2))

    return fetch(VARIABLES[varname]['era_varname'], varname)

In [76]:
idx_start

12

In [77]:
%%time
# Driver: for each variable, process the cities in 83 batches of 12, build a
# Location per city on the cluster, and append the selected models and their
# calibration curves to 'bestmodels_<varname>.txt'.
# NOTE(review): the cluster is not shut down here; if the loop raises, it
# stays up until the separate cluster.shutdown() cell is run.
cluster = coiled.Cluster(n_workers=25)
client = cluster.get_client()
for varname in ['pr', 'sfcWind']:#VARIABLES:
    
    for idx_start, idx_end in [(i * 12, (i+1)*12) for i in range(83)]:
        
        # One Dataset per candidate model for this batch, built on the workers
        # and gathered back; keyed by (varname, model).
        datasets = client.gather({
            (varname, model): client.submit(Dataset, varname, model, 'historical', idx_start, idx_end) for model in MODELS[varname]
        })
        
        city_params = []
        for cityname in list(CITYLATLON.keys())[idx_start: idx_end]:
            lat, lon, loc_id = CITYLATLON[cityname]
            # NOTE(review): get_histobs() can return None after repeated
            # failures; that case is not handled before Location is built.
            hist_obs = get_histobs(varname, (lat, lon))
            # loc_id - idx_start is the city's position inside this batch.
            hist_mod = {m: datasets[(varname, m)].get_timeseries(('{0}-01-01'.format(HIST_START), '{0}-12-31'.format(HIST_END)), loc_id-idx_start) for m in MODELS[varname]}
            city_params.append((cityname, loc_id-idx_start, (lat, lon), hist_obs, hist_mod, varname))
            
        locs_futs = client.map(Location, city_params)

        loclist = client.gather(locs_futs)
        # loclist holds the resulting Location objects - each has best_models and calib_fxns.
        # The list index plus idx_start identifies the city within the full city list.
        for idx, loc in enumerate(loclist):
            best_models = loc.best_models
            calib_fxns = loc.calib_fxns
            # Appends one tab-separated row per selected model:
            # city index, variable, model, JSON list of the calibration curves.
            with open('bestmodels_{0}.txt'.format(varname), 'a') as ofile:
                for m in best_models:
                    ofile.write('{0}\t{1}\t{2}\t{3}\n'.format(idx+idx_start, varname, m, json.dumps([a.tolist() for a in calib_fxns[m]])))
            #cluster.shutdown()
            

Output()

Output()

KeyError: 'MRI-ESM2-0'

In [74]:
idx_end

24

In [79]:
cluster.shutdown()

In [65]:
varname

'pr'

In [211]:
datasets

{'GFDL-ESM4': <__main__.Dataset at 0x1ce60cd42b0>,
 'CanESM5': <__main__.Dataset at 0x1ce3aaeb9a0>,
 'MRI-ESM2-0': <Future: cancelled, type: __main__.Dataset, key: Dataset-ce4d43fa167265c7915596c8c10eaf43>,
 'IPSL-CM6A-LR': <Future: cancelled, type: __main__.Dataset, key: Dataset-706dbe19d75f2112886f24d16feedee9>,
 'EC-Earth3-Veg-LR': <Future: cancelled, type: __main__.Dataset, key: Dataset-cec8588e1415bcf10adabe65bc26a1b1>}

In [42]:
# To retrieve saved results, use this.
# NOTE(review): this reader expects 'bestmodels.txt' with at least 11
# tab-separated fields per line, whereas the driver cell writes
# 'bestmodels_<varname>.txt' with 4 fields per line - confirm which file
# layout this cell is meant for before relying on it.
calib_fxns = {}
with open('bestmodels.txt', 'r') as ifile:
    for line in ifile.readlines():
        items = [i.strip() for i in line.split('\t')]
        # Key: fields 0, 3 and 4; value: three (field -> next field) pairs
        # taken from fields 5-10.
        calib_fxns[(items[0], items[3], items[4])] = {items[5]: items[6], items[7]: items[8], items[9]: items[10]}

In [34]:
# Sanity check per season: does each selected model's historical range reach
# below the observed 10th percentile and above the observed 90th percentile?
# Relies on hist_obs_tx, hist_mods_tx and best_models_tx from earlier work
# that is not part of the visible cells.
# NOTE(review): the second message says "does not exceed" while the printed
# flag is True when the modeled maximum is >= the observed 90th percentile -
# the wording may be inverted.
for quarter in range(4):
    observed_season = quarters(hist_obs_tx, HIST_START, HIST_END)[quarter]
    obs_10 = np.percentile(observed_season, 10)
    obs_90 = np.percentile(observed_season, 90)
    for model in best_models_tx:
        modeled_season = quarters(hist_mods_tx[model] - 273.15, HIST_START, HIST_END)[quarter]
        mod = modeled_season.flatten()
        reaches_low = min(mod) <= obs_10
        print('{0}: min modeled value does not exceed observed 10th percentile  {1}'.format(model, reaches_low))
        reaches_high = max(mod) >= obs_90
        print('{0}: max modeled value does not exceed observed 90th percentile  {1}'.format(model, reaches_high))


GFDL-CM4: min modeled value does not exceed observed 10th percentile  True
GFDL-CM4: max modeled value does not exceed observed 90th percentile  True
CanESM5: min modeled value does not exceed observed 10th percentile  True
CanESM5: max modeled value does not exceed observed 90th percentile  True
ACCESS-CM2: min modeled value does not exceed observed 10th percentile  True
ACCESS-CM2: max modeled value does not exceed observed 90th percentile  True
GFDL-CM4: min modeled value does not exceed observed 10th percentile  True
GFDL-CM4: max modeled value does not exceed observed 90th percentile  True
CanESM5: min modeled value does not exceed observed 10th percentile  True
CanESM5: max modeled value does not exceed observed 90th percentile  True
ACCESS-CM2: min modeled value does not exceed observed 10th percentile  True
ACCESS-CM2: max modeled value does not exceed observed 90th percentile  True
GFDL-CM4: min modeled value does not exceed observed 10th percentile  True
GFDL-CM4: max modeled