# Creating custom accessors

## Introduction

An accessor is a way of attaching a custom function to xarray types so that it can be called as if it were a method while retaining a clear separation between "core" xarray API and custom API. It enables you to easily extend and customize xarray's functionality while limiting naming conflicts and minimizing the chances of your code breaking with xarray upgrades.

If you've used [rioxarray](https://corteva.github.io/rioxarray/stable/) (e.g. `da.rio.crs`) or [hvplot](https://hvplot.holoviz.org/) (e.g. `ds.hvplot()`), you may have already used an xarray accessor without knowing it!

The [Xarray documentation](https://docs.xarray.dev/en/stable/internals/extending-xarray.html) has some more technical details, and this tutorial provides example custom accessors and their uses.
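Here is a minimal sketch of that separation, using a hypothetical `geo` accessor whose `mean` method is a geometric mean. Because it lives under `.geo`, it sits alongside xarray's own `DataArray.mean` instead of replacing it:

```python
import numpy as np
import xarray as xr


# Hypothetical accessor, for illustration only: its `mean` is a geometric mean
@xr.register_dataarray_accessor("geo")
class GeoMeanAccessor:
    def __init__(self, da):
        # xarray passes in the DataArray the accessor is attached to
        self._da = da

    def mean(self, dim=None):
        # Geometric mean: exponential of the arithmetic mean of the logs
        return np.exp(np.log(self._da).mean(dim=dim))


da = xr.DataArray([1.0, 4.0, 16.0], dims="x")
arithmetic = da.mean()      # core xarray method: 7.0
geometric = da.geo.mean()   # accessor method with the same name: approximately 4.0
```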

## Why create a custom accessor

- You can easily create a custom suite of tools that work on Xarray objects
- It keeps your workflows cleaner and simpler
- Your project-specific code is easy to share
- It's easy to implement: you don't need to integrate any code into Xarray
- It makes it easier to perform checks and write code documentation, because you only have to write them once!
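Below is a sketch of the "write the checks once" idea. It assumes a hypothetical `timecheck` accessor that requires a `time` dimension. A check placed in `__init__` runs the first time the accessor is accessed on a given object, and xarray then caches the accessor instance on that object. The sketch raises `ValueError` on purpose: when an accessor's `__init__` raises `AttributeError`, xarray re-raises it as a generic `RuntimeError`.

```python
import xarray as xr


# Hypothetical accessor name and requirement, for illustration only
@xr.register_dataset_accessor("timecheck")
class TimeCheckAccessor:
    def __init__(self, ds):
        # Validate up front; this runs when `.timecheck` is first accessed on an object.
        # Use ValueError: an AttributeError here would be re-raised by xarray as a RuntimeError.
        if "time" not in ds.dims:
            raise ValueError("Required dimension 'time' is missing")
        self._ds = ds

    def first(self):
        """Return the first time step."""
        return self._ds.isel(time=0)


good = xr.Dataset({"t": ("time", [1.0, 2.0, 3.0])})
bad = xr.Dataset({"t": ("x", [1.0, 2.0, 3.0])})

print(float(good.timecheck.first()["t"]))  # 1.0
try:
    bad.timecheck
except ValueError as err:
    print(err)  # Required dimension 'time' is missing
```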

## Easy steps to create your own accessor

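1. Create your custom class, including the mandatory `__init__` method, which receives the xarray object the accessor is attached to
2. Decorate the class with `xr.register_dataarray_accessor("name")` or `xr.register_dataset_accessor("name")`, where `"name"` is the namespace you will call it through
3. Use your custom functions through that namespace, e.g. `da.name.my_function()`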
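The three steps above can be sketched as follows. The accessor name `anom` and its members are hypothetical. They show that an accessor can expose both properties (like `da.rio.crs`) and methods:

```python
import xarray as xr


# Step 1: a class whose __init__ receives the xarray object it is attached to
# Step 2: register it under a namespace name ("anom" is a hypothetical choice)
@xr.register_dataset_accessor("anom")
class AnomalyAccessor:
    def __init__(self, ds):
        self._ds = ds

    @property
    def n_vars(self):
        """Number of data variables (accessors can expose properties too)."""
        return len(self._ds.data_vars)

    def from_mean(self, dim):
        """Subtract the mean along `dim` from every data variable."""
        return self._ds - self._ds.mean(dim=dim)


# Step 3: use it like a method (or attribute) of the Dataset
ds = xr.Dataset({"t": ("time", [1.0, 2.0, 3.0])})
print(ds.anom.n_vars)                         # 1
print(ds.anom.from_mean("time")["t"].values)  # [-1.  0.  1.]
```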

## Example 1: accessing scipy functionality

For example, imagine you're a statistician who regularly uses a special `skewness` function that acts on DataArrays but is only of interest to people in your specific field.

You can create a method that applies this skewness function to an xarray object, and then register the method under a custom `stats` accessor like this:

In [None]:
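import xarray as xr
from scipy.stats import skew


@xr.register_dataarray_accessor("stats")
class StatsAccessor:
    def __init__(self, da):
        # xarray passes in the DataArray that the accessor is attached to
        self._da = da

    def skewness(self, dim):
        """Sample skewness along `dim`, computed with scipy.stats.skew."""
        return self._da.reduce(func=skew, dim=dim)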

Now we can conveniently access this functionality via the `stats` accessor:

In [None]:
ds = xr.tutorial.open_dataset("air_temperature")
ds['air'].stats.skewness(dim="time")

Notice how the presence of `.stats` clearly differentiates our new "accessor method" from core xarray methods.
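The accessor method is a thin wrapper around `DataArray.reduce`. `reduce` converts the named dimension into an integer axis and passes it to `skew`, so the result should match calling `scipy.stats.skew` on the raw NumPy array. A minimal check on synthetic data (the random gamma-distributed sample is an illustrative assumption, not the tutorial dataset):

```python
import numpy as np
import xarray as xr
from scipy.stats import skew

# Synthetic, right-skewed sample data (not the tutorial dataset)
rng = np.random.default_rng(0)
da = xr.DataArray(rng.gamma(2.0, size=(100, 3)), dims=("time", "space"))

# DataArray.reduce turns dim="time" into axis=0 and calls skew(values, axis=0)
via_reduce = da.reduce(skew, dim="time")
direct = skew(da.values, axis=0)

print(via_reduce.dims)                         # ('space',)
print(np.allclose(via_reduce.values, direct))  # True
```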

## Example 2: creating your own workflows

Perhaps you find yourself running similar code for multiple xarray objects or across related projects. Packing that code into an accessor makes it easy to repeat the same operation while reducing the likelihood of human-introduced errors.

Consider someone who frequently converts their elevations to be relative to the geoid (rather than the ellipsoid) using a custom, local conversion (otherwise, we'd recommend using an established conversion library like [pyproj](https://pypi.org/project/pyproj/) to switch between datums).
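For reference, the conversion behind this workflow is the standard relation between ellipsoidal height $h$, geoid height (undulation) $N$, and orthometric height $H$, i.e. height relative to the geoid:

$$
H \approx h - N
$$

Subtracting a geoid layer $N$ from ellipsoidal elevations $h$ therefore gives geoid-referenced elevations. This holds only if $N$ is on the same grid, in the same units, and referenced to the same ellipsoid as $h$.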

In [None]:
@xr.register_dataset_accessor("geoidxr")
class GeoidXR:
    """
    An extension for an xarray Dataset that calculates geoidal elevations from a local source file.
    """
    # ----------------------------------------------------------------------
    # Constructors

    def __init__(self, xrds):
        self._xrds = xrds
        # Running this check on init ensures my dataset has all the dimensions and variables
        # specific to my workflow, saving time and headaches later instead of having the
        # computation fail partway through.
        self._validate(req_dim=['x', 'y', 'dtime'], req_vars={'elevation': ['x', 'y', 'dtime']})

    # ----------------------------------------------------------------------
    # Methods

    def _validate(self, req_dim=None, req_vars=None):
        '''
        Make sure the xarray dataset has the correct dimensions and variables.
        Raises ValueError listing anything that is missing (an AttributeError raised
        during __init__ would be re-raised by xarray as a RuntimeError).

        req_dim : list of str
            List of all required dimension names

        req_vars : dict of {str: list of str}
            Required variable names mapped to their expected dimensions
            (only the variable names are checked here)
        '''
        if req_dim is not None:
            missing_dims = [dim for dim in req_dim if dim not in self._xrds.dims]
            if missing_dims:
                raise ValueError(f"Required dimensions are missing: {missing_dims}")
        if req_vars is not None:
            missing_vars = [var for var in req_vars if var not in self._xrds.variables]
            if missing_vars:
                raise ValueError(f"Required variables are missing: {missing_vars}")

    # Notice that 'geoid' has been added to the dimensions listed for 'elevation' in req_vars
    def to_geoid(self, req_dim=['dtime', 'x', 'y'], req_vars={'elevation': ['x', 'y', 'dtime', 'geoid']},
                 source=None):
        """
        Get the geoid layer from your local file, which is provided to the function as "source",
        and apply the offset to all elevation values.
        Appends 'geoid_offset' to the "offset_names" attribute so you know the geoid offset was applied.
        Returns a new Dataset; the original is left unchanged.

        req_dim : list of str
            List of all required dimension names

        req_vars : dict of {str: list of str}
            Required variable names mapped to their expected dimensions

        source : str
            Full path to your source file containing geoid offsets
        """
        if source is None:
            raise ValueError("Please provide the path to your geoid source file via `source`")

        # check to make sure you haven't already run this function (and are thus applying the offset twice)
        offsets = self._xrds.attrs.get('offset_names', [])
        if isinstance(offsets, str):
            offsets = [offsets]
        offsets = list(offsets)
        if 'geoid_offset' in offsets:
            raise ValueError("You've already applied the geoid offset!")
        offsets.append('geoid_offset')

        self._validate(req_dim, req_vars)

        # read in your geoid values ('geoid_varname' is a placeholder for the variable name in your file)
        # WARNING: this implementation assumes your geoid values are in the same CRS and grid as the data you are applying
        # them to. If not, you will need to reproject and/or resample them to match the data to which you are applying them.
        # That step is not included here to emphasize the accessor aspect of the workflow.
        with xr.open_dataset(source) as src:
            geoid = src['geoid_varname'].load()

        # As noted above, this step will fail or produce unreliable results if your data is not properly gridded
        out = self._xrds.copy()
        out['elevation'] = out['elevation'] - geoid
        out.attrs['offset_names'] = offsets

        return out


Now, each time we want to convert ellipsoid-referenced elevations to geoid-referenced ones, we only have to run one line of code. That line also checks that the required dimensions and variables are present and that the geoid offset hasn't already been applied, so we can be sure we're performing exactly the operation we expect. Imagine the possibilities (and the decrease in frustration)!

In [None]:
ds = ds.geoidxr.to_geoid(source='/Path/to/Custom/source/file.nc')
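At its core, `to_geoid` is a single broadcasted subtraction, and the grid caveat in its comments is worth seeing in action. The sketch below uses small synthetic arrays (hypothetical values, no source file) in place of real elevations and a real geoid model:

```python
import numpy as np
import xarray as xr

# Synthetic stand-ins (hypothetical values): ellipsoidal elevations on a
# (dtime, y, x) grid and a static geoid undulation on the same (y, x) grid
x = [0.0, 1.0, 2.0]
y = [10.0, 11.0]
ds = xr.Dataset(
    {"elevation": (("dtime", "y", "x"), np.full((2, 2, 3), 100.0))},
    coords={"dtime": [0, 1], "y": y, "x": x},
)
geoid = xr.DataArray(np.full((2, 3), 30.0), dims=("y", "x"), coords={"y": y, "x": x})

# Matching grids: geoid is broadcast over dtime and aligned on the x/y labels
on_geoid = ds["elevation"] - geoid
print(on_geoid.dims, float(on_geoid.max()))  # ('dtime', 'y', 'x') 70.0

# Offset grid: by default xarray keeps only the x labels both objects share
# (an inner join), so nothing survives along x
shifted = geoid.assign_coords(x=[0.5, 1.5, 2.5])
print((ds["elevation"] - shifted).sizes["x"])  # 0
```

When the grids don't match, xarray's `interp_like` or `reindex_like(..., method="nearest")` can resample the geoid onto the data grid first. Which one is appropriate depends on your data.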