In [1]:
from pathlib import Path
import warnings
import cftime
import tqdm
import numpy as np
import xarray as xr
import xclim.indices as xci
from xclim.core.calendar import percentile_doy
from xclim.core.units import convert_units_to, to_agg_units
from xclim.indices.generic import threshold_count
from xclim.indicators import icclim
from config import *
from luts import varid_idx_lu


In [2]:
def rx1day(pr):
    """'Max 1-day precip' - the largest single-day precip total in each year.

    Args:
        pr (xarray.DataArray): daily total precip values

    Returns:
        Max 1-day precip for each year, with the "units" attribute set to "mm"
    """
    annual_max = xci.max_n_day_precipitation_amount(pr, freq="YS")
    # label the result as a depth in millimetres
    annual_max.attrs.update(units="mm")

    return annual_max


def rx5day(pr):
    """'Max 5-day precip' - the largest 5-day precip total in each year.

    Args:
        pr (xarray.DataArray): daily total precip values

    Returns:
        Max 5-day precip for each year, with the "units" attribute set to "mm"
    """
    annual_max = xci.max_n_day_precipitation_amount(pr, 5, freq="YS")
    # label the result as a depth in millimetres
    annual_max.attrs.update(units="mm")

    return annual_max


def r10mm(pr):
    """'Heavy precip days' - count of days per year with more than 10mm of precip.

    Thin wrapper that delegates to the icclim R10mm indicator from xclim.

    Args:
        pr (xarray.DataArray): daily total precip values

    Returns:
        Number of heavy precip days for each year
    """
    heavy_precip_days = icclim.R10mm(pr)
    return heavy_precip_days


def cwd(pr):
    """'Consecutive wet days' - longest run of consecutive days with precip > 1 mm.

    Args:
        pr (xarray.DataArray): daily total precip values

    Returns:
        Max number of consecutive wet days for each year
    """
    wet_day_threshold = "1 mm/day"
    return xci.maximum_consecutive_wet_days(
        pr, thresh=wet_day_threshold, freq="YS"
    )


def cdd(pr):
    """'Consecutive dry days' - longest run of consecutive days with precip < 1 mm.

    Args:
        pr (xarray.DataArray): daily total precip values

    Returns:
        Max number of consecutive dry days for each year
    """
    dry_day_threshold = "1 mm/day"
    return xci.maximum_consecutive_dry_days(
        pr, thresh=dry_day_threshold, freq="YS"
    )


def convert_times_to_years(time_da):
    """Convert the values of a time axis to integer calendar years.

    Handles numpy.datetime64 values of any precision as well as any object that
    exposes a ``year`` attribute (e.g. cftime datetimes of any calendar).

    Args:
        time_da (xarray.DataArray): time axis; only its ``.values`` are read

    Returns:
        list of int: the year of every time value, in the original order

    Raises:
        TypeError: if the values are neither datetime64 nor objects with a ``year`` attribute
    """
    times = np.asarray(time_da.values)

    if np.issubdtype(times.dtype, np.datetime64):
        # truncating to year precision gives whole years relative to the 1970 epoch,
        # independent of the precision (ns, s, D, ...) the axis was stored with
        years_since_epoch = times.astype("datetime64[Y]").astype(np.int64)
        return (years_since_epoch + 1970).tolist()

    try:
        return [int(t.year) for t in times]
    except AttributeError as exc:
        # previously an unsupported type fell through both branches and surfaced
        # as an unrelated UnboundLocalError
        raise TypeError(
            f"Unsupported time values of dtype {times.dtype}; expected "
            "numpy.datetime64 or objects with a 'year' attribute"
        ) from exc


def compute_indicator(da, idx, coord_labels, kwargs=None):
    """Summarize a DataArray according to a specified index / aggregation function

    Args:
        da (xarray.DataArray): the DataArray object containing the base variable data to be summarized
        idx (str): String corresponding to the name of the indicator to compute (assumes value is equal to the name of the corresponding global function)
        coord_labels (dict): dict with model and scenario as keys for labeling resulting xarray dataset coordinates.
        kwargs (dict): additional arguments for the index function being called (defaults to none)

    Returns:
        A new data array named after idx, with the time dimension replaced by integer years
        and one new length-1 dimension per key in coord_labels
    """
    # None instead of a mutable default so no dict is shared between calls
    new_da = globals()[idx](da, **(kwargs or {}))
    new_da.name = idx

    # get the nodata mask from the first time slice
    nodata = np.broadcast_to(np.isnan(da.isel(time=0).values), new_da.shape)
    # remask, because xclim switches nans to 0
    # xclim is inconsistent about the types returned, so integer results get a
    # sentinel value where NaN cannot be represented
    fill_value = -9999 if np.issubdtype(new_da.dtype, np.signedinteger) else np.nan
    # `.values` of a dask-backed array is a freshly computed NumPy array, so writing
    # into it in place would be lost; mask an owned copy and put it back explicitly
    masked = np.array(new_da.values)
    masked[nodata] = fill_value
    new_da = new_da.copy(data=masked)

    new_dims = list(coord_labels.keys())
    new_da = new_da.assign_coords(coord_labels).expand_dims(new_dims)
    # convert the time dimension to integer years instead of CF time objects
    years = convert_times_to_years(new_da.time)
    new_da = new_da.rename({"time": "year"}).assign_coords({"year": years})

    return new_da


def run_compute_indicators(fps, idx_list, coord_labels, kwargs=None, hist_fps=None):
    """Open connections to data files for a particular model variable, scenario, and model and compute all requested indicators.

    Args:
        fps (path-like): paths to the files for the variable required for creating the indicators variables
        idx_list (list): indices to derive using data in provided filepath
        coord_labels (dict): dict with model and scenario as keys for labeling resulting xarray dataset coordinates.
        kwargs (dict): additional arguments passed to every indicator function except "wsdi" / "csdi"
        hist_fps (list of path-like): historical files used to derive percentiles for "wsdi" / "csdi".
            When omitted, a module-level ``hist_fps`` is used if one exists.

    Returns:
        list of xarray.DataArray: one result of compute_indicator per entry in idx_list, in the same order

    Raises:
        ValueError: if "wsdi" or "csdi" is requested and no historical file paths are available
    """
    base_kwargs = {} if kwargs is None else kwargs
    if hist_fps is None:
        # backward compatibility: this function used to read a global of the same name
        hist_fps = globals().get("hist_fps")

    out = []
    with xr.open_mfdataset(fps) as ds:
        var_id = ds.attrs["variable_id"]
        for idx in idx_list:
            if idx in ["wsdi", "csdi"]:
                # for these special indices we need to derive percentiles
                #  from the historical data
                if hist_fps is None:
                    raise ValueError(
                        f"hist_fps is required to compute {idx!r} but was not provided"
                    )
                with xr.open_mfdataset(hist_fps) as hist_ds:
                    # use a separate dict so hist_da is not passed on to the
                    # remaining indicators in idx_list
                    xsdi_kwargs = {"hist_da": hist_ds[var_id]}
                    out.append(
                        compute_indicator(
                            da=ds[var_id],
                            idx=idx,
                            coord_labels=coord_labels,
                            kwargs=xsdi_kwargs,
                        )
                    )
            else:
                out.append(
                    compute_indicator(
                        da=ds[var_id],
                        idx=idx,
                        coord_labels=coord_labels,
                        kwargs=base_kwargs,
                    )
                )

    return out

In [3]:
# root directory of the regridded CMIP6 data
regrid_dir = Path("/center1/CMIP6/kmredilla/cmip6_regridding/regrid/")

# every entry found directly under the regrid directory is taken as a model name
models = [entry.name for entry in regrid_dir.glob("*")]
scenarios = ["historical", "ssp245", "ssp585"]


In [4]:
def generate_base_kwargs(scenarios, models, var_id, idx_list, input_dir):
    """Build one set of run_compute_indicators arguments per scenario / model pair.

    Files are looked up under ``<input_dir>/<model>/<scenario>/day/<var_id>/*.nc``.
    Pairs with no matching files are left out.

    Args:
        scenarios (iterable of str): scenario names to look for
        models (iterable of str): model names to look for
        var_id (str): model variable whose files should be collected
        idx_list (list): indicators to compute from that variable
        input_dir (pathlib.Path): root directory of the input data

    Returns:
        list of dict: each with keys "fps", "idx_list" and "coord_labels"
    """
    combos = [(scen, mod) for scen in scenarios for mod in models]

    base_kwargs = []
    for scen, mod in combos:
        data_dir = input_dir.joinpath(f"{mod}/{scen}/day/{var_id}")
        nc_paths = list(data_dir.glob("*.nc"))

        # not all combinations of model, scenario, and model variable actually exist
        if not nc_paths:
            continue

        base_kwargs.append(
            {
                "fps": nc_paths,
                "idx_list": idx_list,
                "coord_labels": {"scenario": scen, "model": mod},
            }
        )

    return base_kwargs


def add_xsdi_kwargs(kwargs_list, input_dir, var_id=None):
    """Add kwargs for cold and warm spell duration indices. This should just be the historical data filepaths for computing percentiles.

    Each dict in kwargs_list is updated in place with a "hist_fps" entry listing the
    files under ``<input_dir>/<model>/historical/day/<var_id>/*.nc``, where model
    comes from the dict's "coord_labels".

    Args:
        kwargs_list (list of dict): dicts with "fps", "idx_list" and "coord_labels" keys
        input_dir (pathlib.Path): root directory of the input data
        var_id (str, optional): model variable to look up. If omitted, the dict's
            "var_id" entry is used when present, otherwise the name of the directory
            containing the first path in "fps".

    Returns:
        list of dict: the same kwargs_list, with "hist_fps" (sorted paths) added to every dict

    Raises:
        AssertionError: if a dict does not request "wsdi" or "csdi"
    """
    for kwargs in kwargs_list:
        # make sure at least one of the indicators is a "<warm/cold> spell duration index"
        assert any(idx in ["wsdi", "csdi"] for idx in kwargs["idx_list"]), (
            "add_xsdi_kwargs expects 'wsdi' or 'csdi' in idx_list"
        )

        model = kwargs["coord_labels"]["model"]
        # the base kwargs carry no "var_id", so fall back to the directory holding the input files
        entry_var_id = (
            var_id or kwargs.get("var_id") or Path(kwargs["fps"][0]).parent.name
        )
        # percentiles always come from the historical run, whatever the entry's own scenario is
        hist_dir = input_dir.joinpath(f"{model}/historical/day/{entry_var_id}")

        kwargs["hist_fps"] = sorted(hist_dir.glob("*.nc"))

    return kwargs_list

In [5]:
var_id = "pr"
# indicators that are derived from this variable
idx_list = varid_idx_lu[var_id]
kwargs_list = generate_base_kwargs(
    scenarios=scenarios,
    models=models,
    var_id=var_id,
    idx_list=idx_list,
    input_dir=regrid_dir,
)


In [6]:
out = []
with warnings.catch_warnings():
    # silence the "All-NaN slice/axis encountered" warnings for the duration of the loop only
    warnings.filterwarnings(
        action="ignore", message=r"All-NaN (slice|axis) encountered"
    )
    # one merged, fully computed dataset per entry in kwargs_list
    for kwarg in tqdm.tqdm(kwargs_list):
        out.append(
            xr.merge(run_compute_indicators(**kwarg)).compute()
        )


  0%|          | 0/34 [00:00<?, ?it/s]

In [None]:
# combine the datasets collected in `out` into a single dataset
indicators_ds = xr.merge(out)
