# `clisops` regridding functionalities - powered by `xesmf`

The regridding functionalities of clisops consist of the regridding operator/function `regrid` in `clisops.ops`, allowing one-line remapping of `xarray.Datasets` or `xarray.DataArrays`, while orchestrating the use of classes and functions in `clisops.core`:
- the `Grid` and `Weights` classes, to check and pre-process input as well as output grids and to generate the remapping weights
- a `regrid` function, performing the remapping by applying the generated weights on the input data

For the weight generation and the regridding, the [xESMF](https://github.com/pangeo-data/xESMF) `Regridder` class is used, which itself allows an easy application of many of the remapping functionalities of [ESMF](https://earthsystemmodeling.org/)/[ESMPy](https://github.com/esmf-org/esmf/blob/develop/src/addon/ESMPy/README.md).

In [None]:
# Imports

%matplotlib inline
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import psyplot.project as psy
import numpy as np
import xarray as xr
import cf_xarray as cfxr

from pathlib import Path
from git import Repo
import os
# ESMPy requires the environment variable ESMFMKFILE, so it is put in place
# before xesmf is imported. If the variable is not configured yet, fall back
# to `esmf.mk` two directory levels above the standard library's `os` module
# file. A value that is already set is respected rather than overwritten.
os.environ.setdefault('ESMFMKFILE', str(Path(os.__file__).parent.parent / 'esmf.mk'))
import xesmf as xe

import clisops as cl # atm. the regrid-main-martin branch of clisops
import clisops.ops as clops
import clisops.core as clore
from clisops.utils import dataset_utils
from roocs_grids import get_grid_file, grid_dict, grid_annotations

# Report the versions of the main packages used in this notebook
print(f"Using xarray in version {xr.__version__}")
print(f"Using cf_xarray in version {cfxr.__version__}")
print(f"Using xESMF in version {xe.__version__}")
print(f"Using clisops in version {cl.__version__}")

# Use the HTML display style for xarray objects
xr.set_options(display_style='html')

# Ignore all warnings from here on; remove the two lines below to see them
import warnings
warnings.simplefilter("ignore")

In [None]:
# Initialize test data

# Repository holding the mini-esgf test data, and the branch to use
MINIESGF_URL = "https://github.com/roocs/mini-esgf-data"
branch = "master"
MINIESGF = Path.home() / ".mini-esgf-data" / branch

# Update an existing local clone, otherwise clone the repository first
if MINIESGF.is_dir():
    repo = Repo(MINIESGF)
    repo.git.checkout(branch)
    repo.remotes[0].pull()
else:
    repo = Repo.clone_from(MINIESGF_URL, MINIESGF)
    repo.git.checkout(branch)

# From here on MINIESGF refers to the test data directory inside the clone
MINIESGF = MINIESGF / "test_data"

## `clisops.ops.regrid`

One-line remapping with `clisops.ops.regrid`
```python
def regrid(
    ds: Union[xr.Dataset, str, Path],
    *,
    method: Optional[str] = "nearest_s2d",
    adaptive_masking_threshold: Optional[Union[int, float]] = 0.5,
    grid: Optional[
        Union[xr.Dataset, xr.DataArray, int, float, tuple, str]
    ] = "adaptive",
    output_dir: Optional[Union[str, Path]] = None,
    output_type: Optional[str] = "netcdf",
    split_method: Optional[str] = "time:auto",
    file_namer: Optional[str] = "standard",
    keep_attrs: Optional[Union[bool, str]] = True,
) -> List[Union[xr.Dataset, str]]   
```
The different options for the `method`, `grid` and `adaptive_masking_threshold` parameters are described in the sections below:

*  [clisops.core.Grid](#clisops.core.Grid)
*  [clisops.core.Weights](#clisops.core.Weights)
*  [clisops.core.regrid](#clisops.core.regrid)


### Remap a global `xarray.Dataset` to a global 2.5 degree grid using the bilinear method

#### Load the dataset

In [None]:
ds_vert_path = Path(MINIESGF, "badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/AERmon/"
                              "o3/gn/v20190710/o3_AERmon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc")
ds_vert = xr.open_dataset(ds_vert_path)
ds_vert

#### Take a look at the grid

In [None]:
# Build 2D coordinate arrays from the dataset's lon / lat coordinates
lon_values = ds_vert["lon"].data
lat_values = ds_vert["lat"].data
lon, lat = np.meshgrid(lon_values, lat_values)

# Scatter plot using every third point along both dimensions
plt.figure(figsize=(8, 5))
plt.scatter(lon[::3, ::3], lat[::3, ::3], s=0.5)
plt.xlabel("lon")
plt.ylabel("lat")

#### Remap to global 2.5 degree grid with the bilinear method

In [None]:
ds_remap = clops.regrid(ds_vert, method="bilinear", grid="2pt5deg", output_type="xarray")[0]
ds_remap

#### Plot the remapped data next to the source data

In [None]:
# One map panel per dataset, both with coastlines
fig, axes = plt.subplots(ncols=2, figsize=(18, 4), subplot_kw={"projection": ccrs.PlateCarree()})
for ax in axes:
    ax.coastlines()

# First time step and level of `o3`: source data first, remapped data second
o3_panels = (
    (ds_vert, "Source - MPI-ESM1-2-LR ECHAM6 (T63L47, ~1.9° resolution)"),
    (ds_remap, "Target - regular lat-lon (2.5° resolution)"),
)
for ax, (ds_panel, panel_title) in zip(axes, o3_panels):
    ds_panel.o3.isel(time=0, lev=0).plot.pcolormesh(ax=ax, x="lon", y="lat", shading="auto")
    # Replace the title set by the xarray plot call
    ax.title.set_text(panel_title)

### Remap regional `xarray.Dataset` to a regional grid of adaptive resolution using the bilinear method
Adaptive resolution means that the regular lat-lon target grid will have approximately the same resolution as the source grid.

#### Load the dataset

In [None]:
ds_cordex_path = Path(MINIESGF, "pool/data/CORDEX/data/cordex/output/EUR-22/GERICS/MPI-M-MPI-ESM-LR/"
                                "rcp85/r1i1p1/GERICS-REMO2015/v1/mon/tas/v20191029/"
                                "tas_EUR-22_MPI-M-MPI-ESM-LR_rcp85_r1i1p1_GERICS-REMO2015_v1_mon_202101.nc")
ds_cordex = xr.open_dataset(ds_cordex_path)
ds_cordex

#### Take a look at the grid

In [None]:
plt.figure(figsize=(8,5))
plt.scatter(ds_cordex['lon'][::4, ::4], ds_cordex['lat'][::4, ::4], s=0.1)  
plt.xlabel('lon')
plt.ylabel('lat')

#### Remap to regional regular lat-lon grid of adaptive resolution with the bilinear method

In [None]:
ds_remap = clops.regrid(ds_cordex, method="bilinear", grid="adaptive", output_type="xarray")[0]
ds_remap

#### Plot the remapped data next to the source data

In [None]:
# One map panel per dataset, both with coastlines
fig, axes = plt.subplots(ncols=2, figsize=(18, 4), subplot_kw={"projection": ccrs.PlateCarree()})
for ax in axes:
    ax.coastlines()

# First time step of `tas`: source data first, remapped data second
tas_panels = (
    (ds_cordex, "Source - GERICS-REMO2015 (EUR22, ~0.22° resolution)"),
    (ds_remap, "Target - regional regular lat-lon (adaptive resolution)"),
)
for ax, (ds_panel, panel_title) in zip(axes, tas_panels):
    ds_panel.tas.isel(time=0).plot.pcolormesh(ax=ax, x="lon", y="lat", shading="auto", cmap="RdBu_r")
    # Replace the title set by the xarray plot call
    ax.title.set_text(panel_title)

### Remap unstructured `xarray.Dataset` to a global grid of adaptive resolution using the nearest neighbour method

For unstructured grids, at least for the moment, only the nearest neighbour remapping method is supported.

#### Load the dataset

In [None]:
ds_icono_path = Path(MINIESGF, "badc/cmip6/data/CMIP6/CMIP/MPI-M/ICON-ESM-LR/historical/"
                               "r1i1p1f1/Omon/thetao/gn/v20210215/"
                               "thetao_Omon_ICON-ESM-LR_historical_r1i1p1f1_gn_185001.nc")
ds_icono = xr.open_dataset(ds_icono_path)
ds_icono

#### Take a look at the grid

In [None]:
plt.figure(figsize=(16,9))
plt.scatter(ds_icono['longitude'][::2], ds_icono['latitude'][::2], s=0.05)  
plt.xlabel('lon')
plt.ylabel('lat')

#### Remap to global grid of adaptive resolution with the nearest neighbour method

In [None]:
ds_remap = clops.regrid(ds_icono, method="nearest_s2d", grid="adaptive", output_type="xarray")[0]
ds_remap

#### Plot source data and remapped data

(Using [psyplot](https://psyplot.github.io/) to plot the unstructured data since xarray does not (yet?) support it.)

In [None]:
# Source data
maps=psy.plot.mapplot(ds_icono_path, cmap="RdBu_r", title="Source - ICON-ESM-LR ICON-O (Ruby-0, 40km resolution)", 
                      time=[0], lev=[0])

In [None]:
# Remapped data: first time step and level of `thetao` on a map
plt.figure(figsize=(9, 4))
ax = plt.axes(projection=ccrs.PlateCarree())
thetao_remapped = ds_remap.thetao.isel(time=0, lev=0)
thetao_remapped.plot.pcolormesh(
    ax=ax, x="lon", y="lat", shading="auto", cmap="RdBu_r", vmin=-1, vmax=40
)
ax.title.set_text("Target - regular lat-lon (adaptive resolution)")
ax.coastlines()

<a id='clisops.core.Grid'></a>

## `clisops.core.Grid`

### Create a grid object from an `xarray.Dataset`

#### Load the dataset

In [None]:
dso_path = Path(MINIESGF, "badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-HR/historical/r1i1p1f1/Omon/tos/gn/"
                          "v20190710/tos_Omon_MPI-ESM1-2-HR_historical_r1i1p1f1_gn_185001.nc")
dso = xr.open_dataset(dso_path)
dso

#### Create the Grid object

In [None]:
grido = clore.Grid(ds=dso)
grido

The `xarray.Dataset` is attached to the `clisops.core.Grid` object. Auxiliary coordinates and data variables have been (re)set appropriately.

In [None]:
grido.ds

#### Plot the data

In [None]:
plt.figure(figsize=(9,4))
ax = plt.axes(projection=ccrs.PlateCarree())
grido.ds.tos.isel(time=0).plot.pcolormesh(ax=ax, x=grido.lon, y=grido.lat, shading="auto",
                                          cmap="RdBu_r", vmin = -1, vmax=40)
ax.coastlines()

### Create a grid object from an `xarray.DataArray`

Note that the bounds of coordinate variables cannot be defined for `xarray.DataArray` objects.

#### Extract tos `DataArray`

In [None]:
dao = dso.tos
dao

#### Create Grid object for MPIOM tos dataarray:

In [None]:
grido_tos = clore.Grid(ds=dao)
grido_tos

### Create a grid object using a `grid_instructor`

* global grid: `grid_instructor = (lon_step, lat_step)` or `grid_instructor = step`
* regional grid: `grid_instructor = (lon_start, lon_end, lon_step, lat_start, lat_end, lat_step)` or `grid_instructor = (start, end, step)`

In [None]:
grid_1deg = clore.Grid(grid_instructor=1)
grid_1deg

In [None]:
grid_1degx2deg_regional = clore.Grid(grid_instructor=(0., 90., 1., 35., 50., 2. ))
grid_1degx2deg_regional

### Create a grid object using a `grid_id`

Makes use of the predefined grids of `roocs_grids`, which is a collection of grids used for example for the [IPCC Atlas](https://github.com/IPCC-WG1/Atlas/tree/main/reference-grids) and for [CMIP6 Regridding Weights generation](https://docs.google.com/document/d/1BfVVsKAk9MAsOYstwFSWI2ZBt5mrO_Nmcu7rLGDuL08/edit).

In [None]:
# Print each predefined grid id next to its annotation
for key, gridinfo in grid_annotations.items():
    print(f"- {key:20} {gridinfo}")

In [None]:
grid_era5 = clore.Grid(grid_id = "0pt25deg_era5")
grid_era5

### `clisops.core.Grid` objects can be compared to one another

Optional verbose output gives information on where the grids differ: lat, lon, lat_bnds, lon_bnds, mask?

#### Compare the tos dataset to the tos dataarray

In [None]:
comp = grido.compare_grid(grido_tos, verbose = True)
print("Grids are equal?", comp)

#### Compare both 0.25° ERA5 Grids

In [None]:
# Create the Grid object
grid_era5_lsm = clore.Grid(grid_id = "0pt25deg_era5_lsm", compute_bounds=True)

In [None]:
# Compare
comp = grid_era5.compare_grid(grid_era5_lsm, verbose=True)
print("Grids are equal?", comp)

### Strip `clisops.core.Grid` objects of all `data_vars` and `coords` unrelated to the horizontal grid

In [None]:
grid_era5_lsm.ds

The parameter `keep_attrs` can be set; the default is `False`.

In [None]:
grid_era5_lsm._drop_vars(keep_attrs=False)
grid_era5_lsm.ds

### Transfer coordinate variables that are unrelated to the horizontal grid between `clisops.core.Grid` objects

The parameter `keep_attrs` can be set; the default is `True`. All settings for `keep_attrs` are described later in section [clisops.core.regrid](#clisops.core.regrid).

#### Load the dataset

In [None]:
ds_vert_path = Path(MINIESGF, "badc/cmip6/data/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/historical/r1i1p1f1/"
                              "AERmon/o3/gn/v20190710/o3_AERmon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_185001.nc")
ds_vert = xr.open_dataset(ds_vert_path)
ds_vert

#### Create grid object

In [None]:
grid_vert = clore.Grid(ds_vert)
grid_vert

#### Transfer the coordinates to the ERA5 grid object

In [None]:
grid_era5_lsm._transfer_coords(grid_vert, keep_attrs=True)
grid_era5_lsm.ds

<a id='clisops.core.Weights'></a>

## `clisops.core.Weights`

Create regridding weights to regrid between two grids. The following of [xESMF's remapping methods](https://pangeo-xesmf.readthedocs.io/en/latest/notebooks/Compare_algorithms.html) are supported:
* `nearest_s2d`
* `bilinear`
* `conservative`
* `patch`

### Create 2-degree target grid

In [None]:
grid_2deg = clore.Grid(grid_id="2deg_lsm", compute_bounds=True)
grid_2deg

### Create conservative remapping weights using the `clisops.core.Weights` class
`grid_in` and `grid_out` are `Grid` objects

In [None]:
%time weights = clore.Weights(grid_in = grido, grid_out = grid_2deg, method="conservative")

### Local weights cache

Weights are cached on disk and do not have to be created more than once. The default cache directory is platform-dependent and set via the package `platformdirs`. For Linux it is `'/home/my_user/.local/share/clisops/weights_dir'` and can optionally be adjusted:

- permanently by modifying the parameter `grid_weights: local_weights_dir` in the `roocs.ini` configuration file that can be found in the clisops installation directory
- or temporarily via:
```python
from clisops import core as clore
clore.weights_cache_init("/dir/for/weights/cache")
```

In [None]:
from clisops.core.regrid import CONFIG
print(CONFIG["clisops:grid_weights"]["local_weights_dir"])

In [None]:
!ls -sh {CONFIG["clisops:grid_weights"]["local_weights_dir"]}

In [None]:
!cat {CONFIG["clisops:grid_weights"]["local_weights_dir"]}/weights_*_conservative.json

Now the weights will be read directly from the cache

In [None]:
%time weights = clore.Weights(grid_in = grido, grid_out = grid_2deg, method="conservative")

The weights cache can be flushed, which removes all weight and grid files as well as the JSON files holding the metadata. To see what would be removed, use the `dryrun=True` parameter. To re-initialize the weights cache in a different directory, use the `weights_dir_init="/new/dir/for/weights/cache"` parameter. Even when the weights cache is re-initialized under a new path using `clore.weights_cache_flush`, no directory is removed, only the files listed above. When `dryrun` is not set, the files that are deleted can be displayed with `verbose=True`.

In [None]:
clore.weights_cache_flush(dryrun=True)

In [None]:
clore.weights_cache_flush(verbose=True)

<a id='clisops.core.regrid'></a>

## `clisops.core.regrid`

This function performs the actual regridding and returns the resulting `xarray.Dataset`.

```python
def regrid(
    grid_in: Grid,
    grid_out: Grid,
    weights: Weights,
    adaptive_masking_threshold: Optional[float] = 0.5,
    keep_attrs: Optional[bool] = True,
):
```

*  `grid_in` and `grid_out` are `Grid` objects, `weights` is a `Weights` object.
*  `adaptive_masking_threshold` (AMT): A value within the [0., 1.] interval that defines the maximum `RATIO` of missing_values amongst the total number of data values contributing to the calculation of the target grid cell value. If the fraction of missing contributing source data exceeds AMT, the target grid cell will be set to missing_value; otherwise, it will be re-normalized by the factor `1./(1.-RATIO)`. Thus, if AMT is set to 1, all source grid cells that contribute to a target grid cell must be missing in order for the target grid cell to be defined as missing itself. Values greater than 1 or less than 0 will cause adaptive masking to be turned off. This adaptive masking technique allows reusing generated weights for differently masked data (e.g. land-sea masks or orographic masks that vary with depth / height).
* `keep_attrs` can have the following settings:
  *  `True` : The resulting `xarray.Dataset` will have all attributes of `grid_in.ds.attrs`, except for attributes that have to be added or altered due to the new grid.
  *  `False` : The resulting `xarray.Dataset` will have no attributes except for those generated by the regridding process.
  *  `"target"` : The resulting `xarray.Dataset` will have all attributes of `grid_out.ds.attrs`, except for attributes generated by the regridding process. Not recommended.
  
  
### The following example shows the application of the function and the effect of adaptive masking

In [None]:
ds_out_amt0 = clore.regrid(grido, grid_2deg, weights, adaptive_masking_threshold=-1)

In [None]:
ds_out_amt1 = clore.regrid(grido, grid_2deg, weights, adaptive_masking_threshold=0.5)

#### Plot the resulting data

In [None]:
# Global panel plot of the regridded data, without and with adaptive masking
fig, axes = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(18, 5),
    subplot_kw={"projection": ccrs.PlateCarree()},
)

# Same colour scale and colormap for both panels
tos_style = dict(vmin=0, vmax=30, cmap="plasma")

tos_no_amt = ds_out_amt0["tos"].isel(time=0)
tos_no_amt.plot.pcolormesh(ax=axes[0], **tos_style)
axes[0].title.set_text("Target (2° regular lat-lon) - No adaptive masking")

tos_amt = ds_out_amt1["tos"].isel(time=0)
tos_amt.plot.pcolormesh(ax=axes[1], **tos_style)
axes[1].title.set_text("Target (2° regular lat-lon) - Adaptive masking")

# Coastlines and axis labels for every panel
for axis in axes.flatten():
    axis.coastlines()
    axis.set_xlabel("lon")
    axis.set_ylabel("lat")

In [None]:
# Panel plot zoomed in on Japan: source data next to both regridded versions
fig, axes = plt.subplots(
    nrows=1,
    ncols=3,
    figsize=(18, 4),
    subplot_kw={"projection": ccrs.PlateCarree()},
)

# Same colour scale and colormap for all three panels
tos_style = dict(vmin=0, vmax=30, cmap="plasma")

# The source data is plotted against the lon / lat names held by the Grid object
tos_source = grido.ds.tos.isel(time=0)
tos_source.plot.pcolormesh(ax=axes[0], x=grido.lon, y=grido.lat, shading="auto", **tos_style)
axes[0].title.set_text("Source - MPI-ESM1-2-HR MPIOM (TP04, ~0.4° resolution)")

tos_no_amt = ds_out_amt0["tos"].isel(time=0)
tos_no_amt.plot.pcolormesh(ax=axes[1], **tos_style)
axes[1].title.set_text("Target - No adaptive masking")

tos_amt = ds_out_amt1["tos"].isel(time=0)
tos_amt.plot.pcolormesh(ax=axes[2], **tos_style)
axes[2].title.set_text("Target - Adaptive masking")

# Coastlines, axis labels and the common lon / lat window for every panel
for axis in axes.flatten():
    axis.coastlines()
    axis.set_xlabel("lon")
    axis.set_ylabel("lat")
    axis.set_xlim([125, 150])
    axis.set_ylim([25, 50])