# Preprocessing for downscaling exploratory book

This notebook collects, in one place, long-running operations that may need some babysitting and whose outputs are used by other notebooks in this exploratory project.

It includes the following pieces:


1. Regridding historical CMIP6 precipitation data to the target 4 km ERA5 grid (EPSG:3338) via conservative interpolation, then bias-adjusting it via quantile delta mapping (QDM).

   * Done here because the CMIP6 data available for exploring the bias adjustment have only been regridded via bilinear interpolation, and the regridding pipeline does not yet (as of April 2025) support conservative regridding.
   * This also includes computing the indicators for the downscaled data.

In [None]:
# config cell
import shutil
from pathlib import Path
import xesmf as xe
import xarray as xr
from xclim import units
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
import dask
import baeda
from xclim import units, sdba


# Raise the size at which the distributed client warns about sending a large
# task graph. distributed reads this setting from
# "distributed.admin.large-graph-warning-threshold"; a bare top-level
# "large-graph-warning-threshold" key is not consulted and has no effect.
dask.config.set({"distributed.admin.large-graph-warning-threshold": "100MB"})

# target grid file (ERA5 slice that defines the destination grid for regridding)
target_grid_file = Path("/beegfs/CMIP6/kmredilla/downscaling/era5_target_slice.nc")

# root of the local CMIP6 archive (searched with CMIP/<institution>/<model>/... globs)
cmip6_dir = Path("/beegfs/CMIP6/arctic-cmip6/CMIP6")

# directory holding the ERA5 Zarr stores used as the bias-adjustment reference
era5_dir = Path("/center1/CMIP6/kmredilla/era5_zarr")

# tmp dir for writing inputs/outputs
tmp_dir = Path("/beegfs/CMIP6/kmredilla/downscaling/eda")

## 1. Downscaling GFDL-ESM4 with conservative regridding and QDM adjustment

Setup: build the output file paths and spin up a cluster (a big one, for fast compute!).

In [None]:
# Zarr store name templates, filled per variable / model / scenario / method.
tmp_regrid_fn = "{var_id}_{model}_{scenario}_regrid_{interp_method}.zarr"
tmp_adj_fn = "{var_id}_{model}_{scenario}_adj_{interp_method}.zarr"

# Fields shared by both historical output stores in this section.
_hist_fn_fields = dict(
    var_id="pr",
    model="GFDL-ESM4",
    scenario="historical",
    interp_method="conservative",
)

# regridded (pre-adjustment) store
hist_regrid_path = tmp_dir / tmp_regrid_fn.format(**_hist_fn_fields)
# bias-adjusted (downscaled) store
hist_adj_path = tmp_dir / tmp_adj_fn.format(**_hist_fn_fields)

# SLURM-backed dask cluster: each job requests 28 cores and 128GB split across
# 14 worker processes, so scaling to 112 workers asks for 8 such jobs.
# For quick interactive testing, swap in queue="debug" / walltime="01:00:00".
cluster = SLURMCluster(
    account="cmip6",
    queue="t2small",
    walltime="12:00:00",
    cores=28,
    processes=14,
    memory="128GB",
    interface="ib0",
    log_directory="/beegfs/CMIP6/kmredilla/tmp",
)
client = Client(cluster)

cluster.scale(n=112)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 42605 instead


In [None]:
var_id = "pr"


def _find_hist_file(cmip6_root, var_id, year):
    """Return the GFDL-ESM4 historical daily file for ``var_id`` starting on 1 Jan ``year``.

    Raises FileNotFoundError naming the searched pattern when nothing matches,
    instead of the bare IndexError that indexing an empty glob result gives.
    """
    pattern = (
        f"CMIP/NOAA-GFDL/GFDL-ESM4/historical/r1i1p1f1/day/{var_id}/gr1/v20190726/"
        f"{var_id}_day_GFDL-ESM4_historical_r1i1p1f1_gr1_{year}0101-*1231.nc"
    )
    # sorted() so the chosen file does not depend on filesystem listing order
    matches = sorted(cmip6_root.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {cmip6_root / pattern}")
    return matches[0]


# load and prep data for regridding
hist_fps = [_find_hist_file(cmip6_dir, var_id, year) for year in [1950, 1970, 1990, 2010]]
hist_ds = xr.open_mfdataset(hist_fps, parallel=True, engine="h5netcdf")

# other datasets do not have time as a dimension for the bounds variables!
# GFDL-ESM4 has this wrong I believe. Should be no need for time as a dimension of bnds variables. Grid should be fixed through time.
# need to rectify this for GFDL-ESM4, as we get an assertion error in the underlying cf_xarray accessor get_bounds_dim_name
hist_ds["lon_bnds"] = hist_ds.lon_bnds.isel(time=0, drop=True)
hist_ds["lat_bnds"] = hist_ds.lat_bnds.isel(time=0, drop=True)

# load target
target_ds = xr.open_dataset(target_grid_file, engine="h5netcdf")

# add bounds variables (needed for conservative regridding)
target_ds = target_ds.cf.add_bounds("lon")
target_ds = target_ds.cf.add_bounds("lat")

# Initialize regridder; unmapped target cells become NaN
regridder = xe.Regridder(
    hist_ds,
    target_ds,
    "conservative",
    unmapped_to_nan=True,
    periodic=True,
    ignore_degenerate=True,
)
# lazy: the regrid only runs when the result is computed / written
hist_regrid_ds = regridder(hist_ds, keep_attrs=True)


# Prep the data for downscaling before writing to zarr, including chunking:
def prep_regrid(regrid_ds):
    var_id = list(regrid_ds.data_vars)[0]
    target_unit = baeda.units_lu[var_id]
    sim = baeda.drop_non_coord_vars(regrid_ds)[var_id]
    sim = units.convert_units_to(
        sim.assign_coords(time=sim.time.dt.floor("D")), target_unit
    )

    return sim


# Unit conversion + daily time stamps, then keep only 1965-2014.
hist = prep_regrid(hist_regrid_ds).sel(time=slice("1965-01-01", "2014-12-31"))

# Whole time series in one chunk (presumably required by the grouped sdba
# operations later on), space tiled in 10 x 10 blocks.
chunk_kwargs = dict(time=-1, x=10, y=10)
hist = hist.chunk(**chunk_kwargs)

Write to Zarr, which triggers the computation. Persisting the regridded data should make the later dask-based adjustment easier: the adjustment then starts from a dask array backed by the stored regridded data, rather than from one whose graph also contains the regrid operation, which keeps the task graph small.

In [None]:
# Start from an empty store: drop any regridded output from an earlier run.
if hist_regrid_path.exists():
    shutil.rmtree(hist_regrid_path)

# Writing is what actually executes the lazy regrid; the return value is discarded.
_ = hist.to_dataset().to_zarr(store=hist_regrid_path)

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


Now run the bias-adjustment.

Close the previously opened source dataset and connect to the new Zarr store.

In [None]:
# Release the lazily opened source NetCDF files and reopen the regridded data
# from the Zarr store, so the adjustment graph starts from the stored result.
hist_ds.close()
del hist_ds

hist_ds = xr.open_zarr(hist_regrid_path)

# connect to ERA5 data (reference for the bias adjustment)
era5_stores = {
    # "t2max": era5_dir.joinpath("t2max_era5.zarr"),
    "pr": era5_dir.joinpath("pr_era5.zarr"),
}
era5_ds = baeda.open_era5_dataset(era5_stores)

# Restrict the reference to the years covered by the simulated historical data
# so QDM is trained on matching periods (no-op if ERA5 already matches).
# String bounds are used because the model time axis may use a cftime calendar
# that cannot be compared directly against the ERA5 time index.
calib_start, calib_end = [
    str(t) for t in hist_ds.time.dt.strftime("%Y-%m-%d").values[[0, -1]]
]
ref = era5_ds[var_id].sel(time=slice(calib_start, calib_end)).chunk(chunk_kwargs)

# QDM: train the adjustment (50 quantiles, day-of-year groups with a 31-day
# window; the adjustment kind for this variable comes from baeda)
train_kwargs = dict(
    ref=ref,
    hist=hist_ds[var_id],
    nquantiles=50,
    group="time.dayofyear",
    window=31,
    kind=baeda.varid_adj_kind_lu[var_id],
)
# variables listed in adapt_freq_thresh_lu also get frequency adaptation and
# jittering of values under the threshold
if var_id in baeda.adapt_freq_thresh_lu:
    train_kwargs.update(
        adapt_freq_thresh=baeda.adapt_freq_thresh_lu[var_id],
        jitter_under_thresh_value=baeda.jitter_under_thresh_lu[var_id],
    )

qdm_train = sdba.QuantileDeltaMapping.train(**train_kwargs)

# set up the (still lazy) QDM adjustment for historical
hist_adj = qdm_train.adjust(
    hist_ds[var_id],
    extrapolation="constant",
    interp="nearest",
)
hist_adj.name = var_id
hist_adj = hist_adj.transpose("time", "y", "x")

Write the adjusted (downscaled) data to Zarr, triggering the adjustment computation. 

In [None]:
# Remove a stale adjusted store left by a previous run before writing.
if hist_adj_path.exists():
    shutil.rmtree(hist_adj_path)

# The write executes the lazy QDM adjustment; the return value is discarded.
_ = hist_adj.to_dataset().to_zarr(store=hist_adj_path)

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


Now compute the indicators and write them to Zarr.

(The downscaled data are loaded into memory first; this seems to perform much better for this step.)

In [None]:
# Loading the downscaled data into local memory first seems to make the
# indicator step much faster; expect this whole cell to take up to ~10 minutes.
hist_adj_ds = xr.open_zarr(hist_adj_path).load()

hist_idx = baeda.run_indicators(hist_adj_ds[var_id])

# Indicator store: same naming template as the adjusted data, with "idx"
# appended to the variable id.
hist_idx_path = tmp_dir / tmp_adj_fn.format(
    var_id=f"{var_id}idx",
    model="GFDL-ESM4",
    scenario="historical",
    interp_method="conservative",
)
if hist_idx_path.exists():
    shutil.rmtree(hist_idx_path)

_ = hist_idx.chunk(**chunk_kwargs).to_zarr(hist_idx_path)

Close the cluster. 

In [18]:
# Close the client connection first, then the cluster it is attached to, so
# the client is not left pointing at a scheduler that has been shut down.
client.close()
cluster.close()