### xESMF - problems when creating an `xarray.Dataset` with `xesmf.Regridder.__call__`

* `xesmf.Regridder.__call__` does not forward some coordinate variables to the regridded `xarray.Dataset` if they are registered under `xarray.Dataset.data_vars` rather than `xarray.Dataset.coords`, while it forwards others (such as horizontal bounds) that should instead be replaced.
* Loading with `xarray.open_dataset(path_to_ds, decode_coords='all')` registers all coordinate variables properly, but it removes significant metadata, so `cf_xarray.accessor._get_item.drop_bounds` can no longer distinguish a coordinate variable from its bounds. It also registers, for example, `ps` (surface pressure), which is required for a hybrid sigma vertical axis, under `xarray.Dataset.coords`, which prevents `xESMF` from remapping it (it gets dropped).
* I am not aware of a `cf_xarray.CFAccessor` method or `cf_xarray` function that resets `xarray.Dataset.data_vars` and `xarray.Dataset.coords` of an `xarray.Dataset` in the desired manner, so I set up a custom function that makes use of `cf_xarray`.
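
To make the distinction between `data_vars` and `coords` concrete, here is a minimal sketch on a synthetic dataset. The variable names mimic a CMIP6-like file, but the dataset, its values and the `horizontal` set are assumptions made for illustration only. It splits the data variables into those that span the full horizontal grid (candidates for remapping) and those that are coordinate-like metadata stored as data.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for a file loaded without coordinate decoding
# (names and values are illustrative, not taken from the real files).
ds = xr.Dataset(
    data_vars={
        "o3": (("lev", "lat", "lon"), np.zeros((2, 3, 4))),
        "ps": (("lat", "lon"), np.full((3, 4), 101325.0)),
        "ap": (("lev",), np.array([100.0, 50.0])),        # hybrid sigma coefficient
        "b": (("lev",), np.array([0.9, 0.5])),             # hybrid sigma coefficient
        "lat_bnds": (("lat", "bnds"), np.zeros((3, 2))),   # cell bounds
    },
    coords={
        "lev": [0.95, 0.6],
        "lat": [-45.0, 0.0, 45.0],
        "lon": [0.0, 90.0, 180.0, 270.0],
    },
)

horizontal = {"lat", "lon"}  # assumed names of the horizontal dimensions

# Variables defined on the full horizontal grid are candidates for remapping;
# everything else is coordinate-like metadata that happens to be stored as data.
remappable = sorted(v for v in ds.data_vars if horizontal <= set(ds[v].dims))
coordinate_like = sorted(v for v in ds.data_vars if not horizontal <= set(ds[v].dims))

print(remappable)       # ['o3', 'ps']
print(coordinate_like)  # ['ap', 'b', 'lat_bnds']
```

In this toy case `ap`, `b` and `lat_bnds` are the kind of variables that would need to be registered as coordinates, while `ps` has to stay a data variable so that it can be remapped.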

In [1]:
import numpy as np
import xarray as xr
import cf_xarray as cfxr
import xesmf as xe
print("Using cf-xarray in version %s" % cfxr.__version__)
print("Using xESMF in version %s" % xe.__version__)
#print(xe.__file__)

import warnings
warnings.simplefilter("ignore") 
#with warnings.catch_warnings():
#        warnings.simplefilter("ignore")

xr.set_options(display_style='html');

Using cf-xarray in version 0.6.2.dev1+g93041e1
Using xESMF in version 0.6.1


## 1 Default approach

### Load the dataset

In [2]:
ds_path_o3 = "o3_AERmon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_201001-201412.nc"
ds_path_tos = "tos_Omon_MPI-ESM1-2-LR_historical_r1i1p1f1_gn_201001-201412.nc"
ds_o3 = xr.open_dataset(ds_path_o3).isel(time=0)
ds_tos = xr.open_dataset(ds_path_tos).isel(time=0)

In [3]:
ds_o3

In [4]:
ds_tos

### Calculate the regridding weights

In [5]:
# Specify a global 1 deg grid as target grid
ds_out = xe.util.grid_global(1,1)
ds_out

In [6]:
# Create regridding weights
def regrid(ds_in, ds_out, method='nearest_s2d', locstream_in=False):
    """Convenience function that creates a regridder and thereby calculates the regridding weights"""
    return xe.Regridder(ds_in, ds_out, method, locstream_in=locstream_in)

In [7]:
regridder_o3 = regrid(ds_o3, ds_out)
regridder_o3

xESMF Regridder 
Regridding algorithm:       nearest_s2d 
Weight filename:            nearest_s2d_96x192_180x360.nc 
Reuse pre-computed weights? False 
Input grid shape:           (96, 192) 
Output grid shape:          (180, 360) 
Periodic in longitude?      False

In [8]:
regridder_tos = regrid(ds_tos, ds_out)
regridder_tos

xESMF Regridder 
Regridding algorithm:       nearest_s2d 
Weight filename:            nearest_s2d_220x256_180x360.nc 
Reuse pre-computed weights? False 
Input grid shape:           (220, 256) 
Output grid shape:          (180, 360) 
Periodic in longitude?      False

### Perform remapping

Important vertical coordinate variables are lost in the regridded dataset `ds_o3_g1`, while the old horizontal bounds are kept in the regridded dataset `ds_tos_g1`.
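
A quick way to see which variables went missing and which were carried over is to compare the variable names before and after remapping. The helper below is a hypothetical convenience function (not part of xESMF or cf_xarray), and the two datasets are synthetic stand-ins for the input and the regridded output.

```python
import numpy as np
import xarray as xr

def compare_variables(ds_before, ds_after):
    """Return (lost, kept) variable names, comparing two datasets by name only."""
    before = set(ds_before.variables)
    after = set(ds_after.variables)
    return sorted(before - after), sorted(before & after)

# Synthetic input with a hybrid sigma coefficient stored as a data variable.
ds_before = xr.Dataset(
    {
        "o3": (("lev", "lat"), np.zeros((2, 3))),
        "ap": (("lev",), np.array([100.0, 50.0])),
    },
    coords={"lev": [0.95, 0.6], "lat": [-45.0, 0.0, 45.0]},
)
# Stand-in for a regridded dataset in which 'ap' was not forwarded.
ds_after = ds_before.drop_vars("ap")

lost, kept = compare_variables(ds_before, ds_after)
print("lost:", lost)
print("kept:", kept)
```

Applied to the real data, the call would be `compare_variables(ds_o3, ds_o3_g1)`. The comparison is by name only, so it does not detect bounds that were kept but no longer match the new grid.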

In [9]:
ds_tos_g1 = regridder_tos(ds_tos, keep_attrs=True)
ds_o3_g1 = regridder_o3(ds_o3, keep_attrs=True)

In [10]:
ds_tos_g1

In [11]:
ds_o3_g1

### What can be done?

* Make xarray recognize all coordinate variables and auxiliary coordinate variables as coordinates,
  * by using the option `decode_coords` when loading the dataset
  * by making use of a custom function and `cf_xarray`
* Store the remapped variables in `ds_out` and transfer all necessary (e.g. non-horizontal) coordinate variables to `ds_out`
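
The building blocks for the custom-function option are `xarray.Dataset.set_coords` (data variable to coordinate) and `xarray.Dataset.reset_coords` (coordinate to data variable). A minimal sketch on a synthetic curvilinear dataset follows; the variable `area` and all values are assumptions chosen for illustration.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "tos": (("j", "i"), np.zeros((2, 3))),
        # Bounds of the 2D latitude, stored as a data variable
        "vertices_latitude": (("j", "i", "vertices"), np.zeros((2, 3, 4))),
    },
    coords={
        "latitude": (("j", "i"), np.zeros((2, 3))),
        "longitude": (("j", "i"), np.zeros((2, 3))),
        # A field on the horizontal grid that was (wrongly) registered as coordinate
        "area": (("j", "i"), np.ones((2, 3))),
    },
)

# Promote the bounds to a coordinate so they are no longer treated as data.
promoted = ds.set_coords(["vertices_latitude"])

# Demote the gridded field so that it counts as a data variable again.
demoted = promoted.reset_coords(["area"])

print(sorted(demoted.data_vars))
print(sorted(demoted.coords))
```

Both methods return a new dataset and leave the original untouched, so they can be chained freely.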

## 2 Approach using `xarray.open_dataset(path_to_ds, decode_coords='all')`

This causes the variable `ps` in `ds_o3` to be registered under `xarray.Dataset.coords`, so xESMF drops it when `xesmf.Regridder.__call__` is called.
For `ds_tos`, the `cf_xarray.CFAccessor` cannot uniquely identify the horizontal coordinates, since xarray drops some essential metadata when loading the data and decoding the coordinates.
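
The effect on `ps` can be reproduced without the original files. The sketch below is an assumption-laden toy: it builds a small in-memory dataset with a `formula_terms` attribute, decodes it with `xarray.decode_cf(..., decode_coords="all")` (which I assume behaves like the corresponding `open_dataset` option, given a reasonably recent xarray), and then moves `ps` back to the data variables with `reset_coords`.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the undecoded o3 dataset (values are illustrative).
raw = xr.Dataset(
    {
        "o3": (("lev", "lat"), np.zeros((2, 3))),
        "ps": (("lat",), np.full(3, 101325.0)),
        "ap": (("lev",), np.array([100.0, 50.0])),
        "b": (("lev",), np.array([0.9, 0.5])),
    },
    coords={"lev": [0.95, 0.6], "lat": [-45.0, 0.0, 45.0]},
)
# CF attribute linking the vertical axis to the terms of its formula
raw["lev"].attrs["formula_terms"] = "ap: ap b: b ps: ps"

# Decoding with 'all' registers every variable named in formula_terms as coordinate.
decoded = xr.decode_cf(raw, decode_coords="all")
print("ps" in decoded.coords)

# Possible workaround: turn ps back into a data variable before regridding.
restored = decoded.reset_coords("ps")
print("ps" in restored.data_vars)
```

This only addresses the dropped `ps`; it does not bring back the metadata that `cf_xarray` needs to tell coordinates and bounds apart.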

In [12]:
ds_o3 = xr.open_dataset(ds_path_o3, decode_coords='all').isel(time=0)
ds_tos = xr.open_dataset(ds_path_tos, decode_coords='all').isel(time=0)

In [13]:
ds_o3

In [14]:
ds_tos

In [15]:
regridder_o3 = regrid(ds_o3, ds_out)
regridder_o3

xESMF Regridder 
Regridding algorithm:       nearest_s2d 
Weight filename:            nearest_s2d_96x192_180x360.nc 
Reuse pre-computed weights? False 
Input grid shape:           (96, 192) 
Output grid shape:          (180, 360) 
Periodic in longitude?      False

In [16]:
# variable ps is dropped, but should be remapped as well
ds_o3_g1 = regridder_o3(ds_o3, keep_attrs=True)
ds_o3_g1

In [17]:
# cf_xarray cannot identify longitude or latitude:
ds_tos.cf["latitude"]

KeyError: "Receive multiple variables for key 'latitude': {'vertices_latitude', 'latitude'}. Expected only one. Please pass a list ['latitude'] instead to get all variables matching 'latitude'."

In [18]:
ds_tos.cf.bounds

{}

## 3 Approach using custom function to reset `xarray.Dataset.data_vars` and `xarray.Dataset.coords` (making use of `cf_xarray`)

Variable `ps` in `ds_o3` is now remapped as well, and all necessary coordinate variables are kept.
For `ds_tos`, the horizontal bounds are dropped.

In [19]:
# It still matters that the datasets have been loaded without letting xarray decode the coords,
#  since that drops metadata that is essential for cf_xarray to do its magic.
def set_data_vars_and_coords(ds):
    "Set all non data vars as coordinates and all non-coords as data_vars."
    to_coord = []
    to_datavar = []
    gridtype=""
    
    #########################################
    # Usually dealt with in another routine within clisops.core.regrid.Grid
    # (horizontal grids without lon or lat coord would be filtered out):
    lat=cfxr.accessor._get_with_standard_name(ds, "latitude")#[0]
    lon=cfxr.accessor._get_with_standard_name(ds, "longitude")#[0]
    lat=lat[0]
    lon=lon[0]
    
    # Usually dealt with in another routine
    #  setting the gridtype, nlat, nlon, ncells
    if ds[lat].ndim == 2:
        nlat = ds[lat].shape[0]
        nlon = ds[lon].shape[1]
        ncells = nlat * nlon        
    elif ds[lat].ndim == 1:
        if (
            ds[lat].shape == ds[lon].shape
            and ds[lat].dims[0]==ds[lon].dims[0]
            and len(ds[lat].dims) == 1
        ):
            nlat = ds[lat].shape[0]
            nlon = nlat
            ncells = nlat
            gridtype = "irregular" # locstream-like arranged coordinate variables
        else:
            nlat = ds[lat].shape[0]
            nlon = ds[lon].shape[0]
            ncells = nlat * nlon            
    else:
        # Unsupported coordinate layout: return the dataset unchanged instead of None
        return ds
    #########################################
    
    # Check by horizontal shape
    if ds[lat].ndim == 2:
        for data_var in ds.data_vars:
            if ds[data_var].ndim < 2:
                to_coord.append(data_var)
            elif ds[data_var].shape[-2:] != ds[lat].shape:
                to_coord.append(data_var)
    elif ds[lat].ndim == 1:
        for data_var in ds.data_vars:
            if gridtype == "irregular":                
                if ( len(ds[data_var].shape)>0 
                    and (ds[data_var].shape[-1],) != ds[lat].shape
                   ):
                    to_coord.append(data_var)
            else:
                if not (
                    ds[data_var].shape[-2:] == (nlat, nlon)
                    or ds[data_var].shape[-2:] == (nlon, nlat)
                ):
                    to_coord.append(data_var)
    
    # Check by attributes and names
    bounds=[]
    for bnds in ds.cf.bounds.values(): bounds+=bnds
    for bnds in bounds:
        if bnds in ds.data_vars:
            to_coord.append(bnds)   
    
    for var in ds.coords:        
        if var not in [lat, lon] + bounds:
            if gridtype=="irregular":                
                if (len(ds[var].shape)>0 and
                    ( ds[var].shape[-1]==ncells
                     and ds[var].dims[-1] in ds[lat].dims
                     and var not in ds.dims
                    )):
                    to_datavar.append(var)
            else:               
                if (len(ds[var].shape)>0 and   
                    ( ds[var].shape[-2:] == (nlat, nlon)
                     or ds[var].shape[-2:] == (nlon, nlat)
                    ) and all([dim in ds[var].dims for dim in list(ds[lat].dims) + list(ds[lon].dims)])
                   ):
                    to_datavar.append(var)                    
   
    if to_coord:
        ds = ds.set_coords(list(set(to_coord)))
    if to_datavar:
        ds = ds.reset_coords(list(set(to_datavar)))
    return ds

In [20]:
ds_o3 = xr.open_dataset(ds_path_o3).isel(time=0)
ds_o3 = set_data_vars_and_coords(ds_o3)
ds_o3

In [21]:
ds_tos = xr.open_dataset(ds_path_tos).isel(time=0)
ds_tos = set_data_vars_and_coords(ds_tos)
ds_tos

In [22]:
regridder_o3 = regrid(ds_o3, ds_out)
regridder_o3

xESMF Regridder 
Regridding algorithm:       nearest_s2d 
Weight filename:            nearest_s2d_96x192_180x360.nc 
Reuse pre-computed weights? False 
Input grid shape:           (96, 192) 
Output grid shape:          (180, 360) 
Periodic in longitude?      False

In [23]:
regridder_tos = regrid(ds_tos, ds_out)
regridder_tos

xESMF Regridder 
Regridding algorithm:       nearest_s2d 
Weight filename:            nearest_s2d_220x256_180x360.nc 
Reuse pre-computed weights? False 
Input grid shape:           (220, 256) 
Output grid shape:          (180, 360) 
Periodic in longitude?      False

In [24]:
ds_o3_g1 = regridder_o3(ds_o3, keep_attrs=True)
ds_o3_g1

In [25]:
ds_tos_g1 = regridder_tos(ds_tos, keep_attrs=True)
ds_tos_g1

## 4 Approach storing remapped data in `ds_out`

With this approach, `xarray.Dataset.attrs` and essential `xarray.Dataset.coords` have to be moved manually, since they are not transferred (even if `keep_attrs=True` is set and the coordinates are registered as `xarray.Dataset.coords`).
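
The manual transfer could be wrapped in a small helper. The function below is a hypothetical sketch, not part of xESMF: it copies the global attributes and every coordinate of the source dataset that has no horizontal dimension. The datasets are synthetic, and `horizontal_dims` names the horizontal dimensions of the source, which is an assumption the caller has to supply.

```python
import numpy as np
import xarray as xr

def transfer_metadata(ds_src, ds_dst, horizontal_dims):
    """Copy global attrs and non-horizontal coordinates from ds_src to ds_dst."""
    extra = {
        name: coord
        for name, coord in ds_src.coords.items()
        if not set(coord.dims) & set(horizontal_dims) and name not in ds_dst.coords
    }
    # Both calls return new objects, so ds_dst itself is left unchanged.
    return ds_dst.assign_attrs(ds_src.attrs).assign_coords(extra)

ds_src = xr.Dataset(
    {"o3": (("lev", "lat", "lon"), np.zeros((2, 3, 4)))},
    coords={
        "lev": [0.95, 0.6],
        "lat": [-45.0, 0.0, 45.0],
        "lon": [0.0, 90.0, 180.0, 270.0],
        "ap": (("lev",), np.array([100.0, 50.0])),
    },
    attrs={"source_id": "MPI-ESM1-2-LR"},
)
# Stand-in for the remapped data on a different horizontal grid.
ds_dst = xr.Dataset(
    {"o3": (("lev", "lat", "lon"), np.zeros((2, 2, 2)))},
    coords={"lat": [-45.0, 45.0], "lon": [0.0, 180.0]},
)

result = transfer_metadata(ds_src, ds_dst, horizontal_dims=("lat", "lon"))
print(sorted(result.coords))
print(result.attrs)
```

Variable attributes are not handled here; they would still have to be copied per variable or kept via `keep_attrs=True`.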

In [26]:
ds_out["tos"]=regridder_tos(ds_tos.tos)
ds_out["ps"]=regridder_o3(ds_o3.ps)
ds_out["o3"]=regridder_o3(ds_o3.o3, keep_attrs=True)
ds_out

In [27]:
ds_o3

In [28]:
for key in ds_o3.attrs:
    ds_out.attrs[key]=ds_o3.attrs[key]

In [29]:
ds_out