# Climate projections

In [None]:
# basic
import os 
import numpy as np
import xarray as xr
import geopandas as gpd
import glob 

# climate-related
import pyet
import gcsfs
import intake
import cftime
import regionmask
from xclim import core 
from xclim import sdba
from xclim import set_options

# others
import xesmf as xe
from datetime import datetime
from tqdm.notebook import tqdm

import warnings
warnings.simplefilter("ignore") 

import beepy as beep

In [None]:
# spatial extent: 0.5-degree target grid (np.arange leaves the stop value out)
lat_coords = np.arange(-56, -40, 0.5)
lon_coords = np.arange(-76, -67, 0.5)

# periods (string bounds, used as time selectors)
full_period         = slice("1980-01-01", "2099-12-31") # whole record kept from the GCMs
baseline_period     = slice("1985-01-01", "2019-12-31") # ISIMIP3b bias adjustment protocol
inter_future_period = slice("2020-01-01", "2059-12-31") # first future period to bias correct
far_future_period   = slice("2060-01-01", "2099-12-31") # second future period to bias correct

periods = [baseline_period, inter_future_period, far_future_period]

# encoding / chunks
chunks_dict = dict(lon = 10, lat = 10, time = -1)


def _packed_int16():
    """Return a fresh int16 packing spec: 0.01 resolution and -9999 as fill value."""
    return {'dtype': 'int16', 'scale_factor': 0.01, '_FillValue': -9999}


# NOTE: int16 packed with a 0.01 scale factor can only hold values up to about +/-327.67
encode_t2m    = {'t2m':    _packed_int16()}
encode_tasmin = {'tasmin': _packed_int16()}
encode_tasmax = {'tasmax': _packed_int16()}
encode_pr     = {"pr": dict(zlib = True, complevel = 1, dtype = "float32")} # unpacked, compressed alternative
encode_pr_alt = {'pr':     _packed_int16()}
encode_pet    = {'pet':    _packed_int16()}

# shallow merge: the per-variable specs are shared with the dictionaries above
encode_all    = dict(**encode_tasmin, **encode_tasmax, **encode_pr_alt)

# GCMs with known issues (presumably why they are absent from gcm_list -- TODO confirm):
#   - not all variables or scenarios available: ACCESS-CM2, CMCC-CM2-SR5, IITM-ESM, FGOALS-g3

#   - problem with the member id:
#     BCC-CSM2-MR

#   - two outputs per ssp:
#     MPI-ESM1-2-HR

# working directory: the relative "Patagonia/..." paths used in this notebook are resolved from here
os.chdir('/home/rooda/Dropbox/')

# TODO: check what gcm to use (not only generic list)
gcm_list  = ["GFDL-ESM4", "IPSL-CM6A-LR", "MIROC6", "MPI-ESM1-2-LR", "MRI-ESM2-0"]  

# scenarios requested from the catalogue (in addition to "historical") and looped over below
ssp_list  = ["ssp126", "ssp585"] 

## 1. Download and preprocess selected GCMs

In [None]:
def fix_time(ds):
    """Convert ``ds`` to the gregorian calendar and trim it to ``full_period``.

    The calendar conversion aligns on the date and uses ``np.nan`` as the
    ``missing`` value; the converted dataset is then restricted along
    ``time`` to the module-level ``full_period`` slice and returned.
    """
    on_gregorian = ds.convert_calendar(calendar = 'gregorian', align_on = 'date', missing = np.nan)
    return on_gregorian.sel(time = full_period)

# CMIP6 catalogue (JSON descriptor hosted on Google Cloud Storage)
url = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
dataframe = intake.open_esm_datastore(url)

# narrow the catalogue down to the daily fields needed for the selected models
dataframe = dataframe.search(
    experiment_id = ['historical'] + ssp_list,  # scenarios
    table_id      = 'day',                      # time-step
    variable_id   = ['tasmax', 'tasmin', 'pr'], # variables
    source_id     = gcm_list,                   # models
    member_id     = "r1i1p1f1",                 # ensemble member
)

kwargs = dict(zarr_kwargs = dict(consolidated = True, use_cftime = True), aggregate = True)

# fix_time is handed over as the preprocess hook of the loader
datasets = dataframe.to_dataset_dict(preprocess = fix_time, **kwargs)
datasets.keys()
beep.beep(1) # signal that loading has finished

In [None]:
# display the keys of the loaded datasets (they are matched by model/experiment substring in the next cell)
datasets.keys()

In [None]:
# For every model: join the historical run with each scenario, regrid to the
# target grid, convert units and store one raw NetCDF file per model/scenario.
for gcm in tqdm(gcm_list):
    # historical run: first dataset whose key contains "<gcm>.historical"
    gcm_historical = next(val for key, val in datasets.items() if gcm + ".historical" in key)
    gcm_historical = gcm_historical.convert_calendar(calendar = 'gregorian', align_on = 'date', missing = np.nan)
    gcm_historical = gcm_historical.sel(time = slice("1980-01-01", "2014-12-31"))
    
    for ssp in tqdm(ssp_list, leave = False):
        # scenario run: first dataset whose key contains "<gcm>.<ssp>"
        gcm_ssp = next(val for key, val in datasets.items() if gcm + "." + ssp in key)
        gcm_ssp = gcm_ssp.convert_calendar(calendar = 'gregorian', align_on = 'date', missing = np.nan)
        
        # historical + scenario as one continuous series; time kept in a single chunk
        # so that the gaps introduced by the calendar conversion can be interpolated
        gcm_ssp = xr.concat([gcm_historical, gcm_ssp], dim = "time")
        gcm_ssp = gcm_ssp.chunk(dict(time=-1, lat = 10, lon = 10))
        gcm_ssp = gcm_ssp.interpolate_na(dim="time", method="linear")

        # NOTE(review): the coordinate variables are dropped first, so the two .sel calls
        # presumably pick the first member / init year by position -- confirm
        gcm_ssp = gcm_ssp.drop(["height", "dcpp_init_year", "member_id", "time_bounds"])
        gcm_ssp = gcm_ssp.sel(member_id=0, drop=True).sel(dcpp_init_year=0, drop=True)

        # longitudes from [0, 360) to [-180, 180), sorted, then interpolated to the target grid
        gcm_ssp.coords['lon'] = (gcm_ssp.coords['lon'] + 180) % 360 - 180
        gcm_ssp = gcm_ssp.sortby(gcm_ssp.lon)
        gcm_ssp = gcm_ssp.interp(lat = lat_coords, lon = lon_coords)

        # keep a single entry of the bounds dimension when the model provides one
        if 'bnds' in gcm_ssp:
            if len(gcm_ssp.bnds.values) > 0:
                gcm_ssp = gcm_ssp.sel(bnds = 1, drop = True)
        
        # change units
        gcm_ssp["tasmin"] = gcm_ssp.tasmin - 273.15 # K to degC
        gcm_ssp["tasmax"] = gcm_ssp.tasmax - 273.15 # K to degC
        # per-second rate to mm day-1: one day is 24 * 60 * 60 = 86400 s
        # (the previous factor, 84600, was a typo that underestimated precipitation by ~2%)
        gcm_ssp["pr"]  = gcm_ssp.pr * 86400
        gcm_ssp.pr.attrs["units"]      = "mm d-1"
        gcm_ssp.tasmax.attrs["units"]  = "degC"
        gcm_ssp.tasmin.attrs["units"]  = "degC"
        gcm_ssp.to_netcdf("/home/rooda/Hydro_results/future_raw/" + gcm + "_" + ssp + ".nc")
        beep.beep(1) # one file finished
        
beep.beep(2) # all models finished

## 2. Bias correction

## 2.1 Baseline data

In [None]:
# Build the reference dataset (PMET) for the baseline period:
# precipitation, maximum and minimum temperature on one grid, masked to the study basins.

# source data
pmet_pp   = xr.open_dataset("Patagonia/Data/Zenodo/v11/PP_PMETsim_1980_2020_v11d.nc",   chunks = chunks_dict)
pmet_tmax = xr.open_dataset("Patagonia/Data/Zenodo/v11/Tmax_PMETsim_1980_2020_v11d.nc", chunks = chunks_dict)
pmet_tmin = xr.open_dataset("Patagonia/Data/Zenodo/v11/Tmin_PMETsim_1980_2020_v11d.nc", chunks = chunks_dict)

# put precipitation on the Tmax grid (nearest source-to-destination) so the three variables can be merged
regridder = xe.Regridder(pmet_pp, pmet_tmax, "nearest_s2d")
pmet_pp   = regridder(pmet_pp)

# single dataset restricted to the baseline period, renamed to the variable/coordinate names used for the GCMs
pmet_hist = xr.merge([pmet_pp, pmet_tmax, pmet_tmin]).sel(time = baseline_period)
pmet_hist = pmet_hist.rename({'longitude': 'lon','latitude': 'lat', 'pp':'pr', 'tmax':'tasmax','tmin':'tasmin'})

# subset area
shape = gpd.read_file("Patagonia/GIS South/Basins_Patagonia_all.shp")[["geometry"]]
shape = shape.buffer(0.20) # buffer in the units of the layer CRS (presumably degrees -- TODO confirm)
# NOTE(review): the mask presumably holds a region index inside the polygons and NaN outside,
# so ">= 0" yields True for cells inside any buffered basin -- confirm
mask  = regionmask.mask_geopandas(shape, pmet_hist)   >= 0
pmet_hist   = pmet_hist.where(mask, drop = True) # cells outside the mask become NaN; fully masked rows/columns are dropped

# units attributes, written in the same form as for the GCM files
pmet_hist.pr.attrs["units"]    = "mm d-1"
pmet_hist.tasmax.attrs["units"]  = "degC"
pmet_hist.tasmin.attrs["units"]  = "degC"

# precipitation stored as float32; chunk sizes then left to the "auto" setting
pmet_hist["pr"]  = pmet_hist["pr"].astype("float32")
pmet_hist = pmet_hist.chunk("auto")

In [None]:
# One file per variable, due to RAM constraints
_baseline_files = {
    "pr":     "/home/rooda/Hydro_results/PMETsim_historical_pr.nc",
    "tasmax": "/home/rooda/Hydro_results/PMETsim_historical_tasmax.nc",
    "tasmin": "/home/rooda/Hydro_results/PMETsim_historical_tasmin.nc",
}
for _var, _path in _baseline_files.items():
    pmet_hist[_var].to_netcdf(_path)

## 2.2 Interpolation to the baseline grid

In [None]:
# Interpolate every raw GCM file to the reference grid and write one file per
# variable and period (baseline + both future periods).

# load: only the coordinates and the NaN pattern of the reference are needed, so the tasmax file is enough
pmet_hist = xr.open_dataset("/home/rooda/Hydro_results/PMETsim_historical_tasmax.nc", chunks = "auto")
valid_cells = pmet_hist.tasmax[0].notnull() # first time step: True where the reference has data

for gcm in tqdm(gcm_list):    
    for ssp in tqdm(ssp_list, leave = False):

        # open the raw file once per model/scenario; the context manager closes it after all periods are written
        with xr.open_dataset("/home/rooda/Hydro_results/future_raw/" + gcm + "_" + ssp + ".nc", chunks = "auto") as model_raw:

            for period in tqdm(periods, leave = False):

                model_ssp = model_raw.sel(time  = period)

                # linear interpolation to the reference grid (extrapolating at the edges), then mask to valid cells
                model_ssp = model_ssp.interp(lat = pmet_hist.lat, lon = pmet_hist.lon,  method = "linear", kwargs={"fill_value": "extrapolate"})
                model_ssp = model_ssp.where(valid_cells)
                model_ssp = model_ssp.chunk("auto")
                model_ssp = model_ssp.astype("float32").unify_chunks()
                model_ssp['time'] = model_ssp.indexes['time'].normalize() # time stamps set to midnight

                # "YYYY_YYYY" tag read from the bounds of the slice (not from its printed representation)
                years = period.start[:4] + "_" + period.stop[:4]

                # save file
                model_ssp.tasmax.to_netcdf("/home/rooda/Hydro_results/future_interp/TASMAX_" + gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_tasmax)
                model_ssp.tasmin.to_netcdf("/home/rooda/Hydro_results/future_interp/TASMIN_" +  gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_tasmin)
                model_ssp.pr.to_netcdf("/home/rooda/Hydro_results/future_interp/PP_" + gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_pr_alt)

## 2.3 Multivariate bias correction (MBC)

In [None]:
# larger spatial chunks for the bias-correction step (replaces the earlier chunks_dict); time stays in one chunk
chunks_dict = dict(lon = 50, lat = 50, time = -1)

# load the reference variables written in section 2.1 as DataArrays
pmet_hist_pr     = xr.open_dataset("/home/rooda/Hydro_results/PMETsim_historical_pr.nc",     chunks = chunks_dict)["pr"]
pmet_hist_tasmax = xr.open_dataset("/home/rooda/Hydro_results/PMETsim_historical_tasmax.nc", chunks = chunks_dict)["tasmax"]
pmet_hist_tasmin = xr.open_dataset("/home/rooda/Hydro_results/PMETsim_historical_tasmin.nc", chunks = chunks_dict)["tasmin"]

In [None]:
# Multivariate bias correction of tasmax, tasmin and pr for each model, scenario
# and future period: univariate quantile delta mapping, N-dimensional pdf
# transform, then reordering of the univariate result to restore the trend.
for gcm in tqdm(gcm_list):    
    for ssp in tqdm(ssp_list, leave = False):
        
        # GCM series over the baseline period, as written by the interpolation step
        model_baseline_tasmax = xr.open_dataset("/home/rooda/Hydro_results/future_interp/TASMAX_" + gcm + "_" + ssp + "_1985_2019.nc", chunks = chunks_dict).tasmax
        model_baseline_tasmin = xr.open_dataset("/home/rooda/Hydro_results/future_interp/TASMIN_" + gcm + "_" + ssp + "_1985_2019.nc", chunks = chunks_dict).tasmin
        model_baseline_pr     = xr.open_dataset("/home/rooda/Hydro_results/future_interp/PP_" + gcm + "_" + ssp + "_1985_2019.nc", chunks = chunks_dict).pr

        # frequency adaptation of the baseline precipitation against the reference (0.1 mm d-1 threshold);
        # pth and dP0 are returned alongside and are not used below
        model_baseline_pr_ad, pth, dP0 = sdba.processing.adapt_freq(pmet_hist_pr, model_baseline_pr, thresh="0.1 mm d-1", group="time")

        for period in tqdm(periods[1:3], leave = False): # future periods only
            
            # "YYYY_YYYY" tag read from the bounds of the slice (not from its printed representation)
            years = period.start[:4] + "_" + period.stop[:4]

            model_future_tasmax = xr.open_dataset("/home/rooda/Hydro_results/future_interp/TASMAX_" + gcm + "_" + ssp + "_" + years + ".nc", chunks = chunks_dict).tasmax
            model_future_tasmin = xr.open_dataset("/home/rooda/Hydro_results/future_interp/TASMIN_" + gcm + "_" + ssp + "_" + years + ".nc", chunks = chunks_dict).tasmin
            model_future_pr     = xr.open_dataset("/home/rooda/Hydro_results/future_interp/PP_" + gcm + "_" + ssp + "_" + years + ".nc", chunks = chunks_dict).pr
            
            # a) Perform an initial univariate adjustment (additive for temperature, multiplicative for precipitation)
            qdm_tmax = sdba.QuantileDeltaMapping.train(ref = pmet_hist_tasmax,  hist = model_baseline_tasmax,  kind = "+", nquantiles=20, group="time.month")
            qdm_tmin = sdba.QuantileDeltaMapping.train(ref = pmet_hist_tasmin,  hist = model_baseline_tasmin,  kind = "+", nquantiles=20, group="time.month")
            qdm_pp   = sdba.QuantileDeltaMapping.train(ref = pmet_hist_pr,      hist = model_baseline_pr_ad,   kind = "*", nquantiles=20, group="time.month")

            # from here on the qdm_* names hold the adjusted future series, not the trained objects
            qdm_tmax  = qdm_tmax.adjust(model_future_tasmax, extrapolation="constant", interp="nearest").transpose('time', 'lat', 'lon')
            qdm_tmin  = qdm_tmin.adjust(model_future_tasmin, extrapolation="constant", interp="nearest").transpose('time', 'lat', 'lon')
            qdm_pp    = qdm_pp.adjust(model_future_pr,       extrapolation="constant", interp="nearest").transpose('time', 'lat', 'lon')

            dref      = xr.Dataset(dict(tasmax = pmet_hist_tasmax,      tasmin = pmet_hist_tasmin,      pr = pmet_hist_pr))
            scen_hist = xr.Dataset(dict(tasmax = model_baseline_tasmax, tasmin = model_baseline_tasmin, pr = model_baseline_pr_ad))
            scen_ssp  = xr.Dataset(dict(tasmax = qdm_tmax,              tasmin = qdm_tmin,              pr = qdm_pp))
            #scen_hist["time"] = dref.time # correct date (15 -> 01)

            ## b) Stack the variables to multivariate arrays and standardize them
            ref   = sdba.processing.stack_variables(dref) # Stack the variables (tasmax, tasmin and pr)
            scenh = sdba.processing.stack_variables(scen_hist)
            scens = sdba.processing.stack_variables(scen_ssp)

            ref, _, _          = sdba.processing.standardize(ref) # Standardize
            allsim, savg, sstd = sdba.processing.standardize(xr.concat((scenh, scens), "time"))

            hist = allsim.sel(time = scenh.time)
            sim  = allsim.sel(time = scens.time)

            ## c) Perform the N-dimensional probability density function transform
            out = sdba.adjustment.NpdfTransform.adjust(ref, hist, sim, base=sdba.QuantileDeltaMapping, base_kws={"nquantiles": 20, "group": "time.month"}, n_iter=20)  

            ## d) Restoring the trend
            # The univariate-adjusted series (scens) is reordered following the ranks of the
            # N-pdf output. Passing `sim` here instead of `out` would return scens unchanged
            # (standardizing does not alter ranks) and silently discard step c).
            model_ssp_bc = sdba.processing.reordering(out, scens, group="time")
            model_ssp_bc = sdba.processing.unstack_variables(model_ssp_bc)
            model_ssp_bc = model_ssp_bc.transpose('time', 'lat', 'lon')
            model_ssp_bc["tasmin"] = model_ssp_bc.tasmin.where(model_ssp_bc.tasmax > model_ssp_bc.tasmin, model_ssp_bc.tasmax ) # force check Tmax > Tmin
            
            # save file
            model_ssp_bc.tasmax.to_netcdf("/home/rooda/Hydro_results/future_corrected/TASMAX_" + gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_tasmax)
            model_ssp_bc.tasmin.to_netcdf("/home/rooda/Hydro_results/future_corrected/TASMIN_" +  gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_tasmin)           
            model_ssp_bc.pr.to_netcdf("/home/rooda/Hydro_results/future_corrected/PP_" + gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_pr_alt)


# Potential evaporation (+ average temperature)

In [None]:
# Mean temperature and Hargreaves potential evaporation from the bias-corrected
# tasmax / tasmin of each model, scenario and future period.
for gcm in tqdm(gcm_list):   
    for ssp in tqdm(ssp_list, leave = False):
        for period in tqdm(periods[1:3], leave = False): # future periods only
            
            # "YYYY_YYYY" tag read from the bounds of the slice (not from its printed representation)
            years = period.start[:4] + "_" + period.stop[:4]

            path_tmax = "/home/rooda/Hydro_results/future_corrected/TASMAX_" +  gcm + "_" + ssp + "_" + years + ".nc"
            path_tmin = "/home/rooda/Hydro_results/future_corrected/TASMIN_" +  gcm + "_" + ssp + "_" + years + ".nc"

            # context managers close both inputs once the outputs are written
            # (these per-period files are deleted later by the concat step)
            with xr.open_dataset(path_tmax, chunks = "auto") as model_tmax, xr.open_dataset(path_tmin, chunks = "auto") as model_tmin:

                # mean temperature as the average of the daily extremes
                model_tavg = (model_tmax.tasmax + model_tmin.tasmin)/2
                model_tavg = model_tavg.rename("t2m").to_dataset()

                lat  = model_tavg.lat * np.pi / 180 # degrees to radians

                model_pet  = pyet.hargreaves(model_tavg.t2m, model_tmax.tasmax, model_tmin.tasmin, lat = lat, clip_zero = True)
                model_pet  = model_pet.rename("pet")

                model_tavg.to_netcdf("/home/rooda/Hydro_results/future_corrected/T2M_" +  gcm + "_" + ssp + "_" + years + ".nc", encoding = encode_t2m)
                model_pet.to_netcdf("/home/rooda/Hydro_results/future_corrected/PET_" +  gcm + "_" + ssp + "_" + years + ".nc",  encoding = encode_pet)

## Concatenate the future periods into one file per variable

In [None]:
# Merge the per-period files of each variable into a single file per model and
# scenario, then delete the per-period files.
corrected_dir = "/home/rooda/Hydro_results/future_corrected/"

# file prefix -> encoding used when writing the merged file
merged_outputs = {
    "TASMAX": encode_tasmax,
    "T2M":    encode_t2m,
    "TASMIN": encode_tasmin,
    "PET":    encode_pet,
    "PP":     encode_pr_alt,
}

for gcm in tqdm(gcm_list):   
    for ssp in tqdm(ssp_list, leave = False):

        for prefix, encoding in merged_outputs.items():
            # read: the trailing "_*" only matches the per-period files ("..._<ssp>_YYYY_YYYY.nc"),
            # never a merged "..._<ssp>.nc" left over from a previous run
            period_files = corrected_dir + "{}_{}_{}_*.nc".format(prefix, gcm, ssp)

            # save: the dataset is closed on leaving the block, i.e. before its source files are removed
            with xr.open_mfdataset(period_files, chunks = "auto") as merged:
                merged.to_netcdf(corrected_dir + "{}_{}_{}.nc".format(prefix, gcm, ssp), encoding = encoding)

        # remove the per-period files of this model/scenario (all variables)
        files = glob.glob("/home/rooda/Hydro_results/future_corrected/*{}_{}_*.nc".format(gcm,ssp))
        for file in files: 
            os.remove(file)

## Test

In [None]:
for gcm in tqdm(gcm_list):   
    for ssp in tqdm(ssp_list, leave = False):

            # read
            model_tmax = xr.open_dataset("/home/rooda/Hydro_results/future_corrected/TASMAX_{}_{}.nc".format(gcm,ssp), chunks = "auto").tasmax
            model_tavg = xr.open_dataset("/home/rooda/Hydro_results/future_corrected/T2M_{}_{}.nc".format(gcm,ssp),    chunks = "auto").t2m
            model_tmin = xr.open_dataset("/home/rooda/Hydro_results/future_corrected/TASMIN_{}_{}.nc".format(gcm,ssp), chunks = "auto").tasmin
            model_pet  = xr.open_dataset("/home/rooda/Hydro_results/future_corrected/PET_{}_{}.nc".format(gcm,ssp),    chunks = "auto").pet
            model_prcp = xr.open_dataset("/home/rooda/Hydro_results/future_corrected/PP_{}_{}.nc".format(gcm,ssp),     chunks = "auto").pr

            # correct coordinates
            assert model_tmax.lon.min() < 0
            assert model_tavg.lon.min() < 0
            assert model_tmin.lon.min() < 0
            assert model_pet.lon.min()  < 0
            assert model_prcp.lon.min() < 0

            # check for na values
            assert np.all(model_tmax.sel(lat = -45, lon = -72, method = "nearest")[:].data < 1e5)
            assert np.all(model_tavg.sel(lat = -45, lon = -72, method = "nearest")[:].data < 1e5)
            assert np.all(model_tmin.sel(lat = -45, lon = -72, method = "nearest")[:].data < 1e5)
            assert np.all(model_pet.sel(lat = -45, lon = -72, method = "nearest")[:].data < 1e5)
            assert np.all(model_prcp.sel(lat = -45, lon = -72, method = "nearest")[:].data < 1e5)

            # "normal" values  
            assert model_prcp.max().values > 10
            assert model_tavg.mean().values > 5

            # check for negative values