# Custom Indexes

While `xoak` provides some built-in index wrappers, it also makes it easy to wrap and register new indexes.

In [None]:
import numpy as np
import xarray as xr
import xoak

`xoak.indexes` is the registry of indexes available in `xoak`. Initially it contains only the built-in ones:

In [None]:
xoak.indexes

## Example: add a brute-force index

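Every index supported by `xoak` is wrapped in a subclass of `xoak.IndexAdapter`, which must implement two methods: `build`, which builds the index from an array of points, and `query`, which queries the index and returns the distances and positions of the nearest index points. The `xoak.register_index` decorator registers a new adapter under a given name.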
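The build/query contract and the decorator-based registration can be illustrated without `xoak` itself. The sketch below is hypothetical: `Adapter`, `REGISTRY`, `register` and `IdentityAdapter` are illustrative names that mirror the two-method contract described above, not `xoak` internals.

```python
from abc import ABC, abstractmethod

# Hypothetical registry: maps a name to an adapter class (not xoak's implementation)
REGISTRY = {}


def register(name):
    """Return a class decorator that stores the decorated class under `name`."""
    def decorator(cls):
        REGISTRY[name] = cls
        return cls  # return the class unchanged so it can still be used directly
    return decorator


class Adapter(ABC):
    """Hypothetical base class: build an index from points, then query it."""

    @abstractmethod
    def build(self, points):
        """Build and return an index from a sequence of points."""

    @abstractmethod
    def query(self, index, points):
        """Return a (distances, positions) tuple, one entry per query point."""


@register("identity")
class IdentityAdapter(Adapter):
    def build(self, points):
        # keep the points themselves as the "index"
        return list(points)

    def query(self, index, points):
        # exact-match lookup, purely for illustration
        positions = [index.index(p) for p in points]
        distances = [0.0] * len(points)
        return distances, positions
```

Because `register` returns the class unchanged, the decorated class stays usable as a normal class, while other code can look it up by name, e.g. `REGISTRY["identity"]`.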

Let's create and register a new "index" that performs a brute-force nearest-neighbor lookup: it computes the pairwise distances between all index and query points and, for each query point, selects the index point at the minimum distance.
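For reference, here is roughly what this brute-force lookup computes, written with plain NumPy instead of scikit-learn. `brute_force_nearest` is a hypothetical helper; it assumes Euclidean distance (the default metric of `pairwise_distances_argmin_min`) and materializes the full query-by-index distance matrix, so it is only practical for small inputs.

```python
import numpy as np


def brute_force_nearest(index_points, query_points):
    """Return (distances, positions) of the nearest index point for each query point.

    Both inputs are (n, n_dims) arrays; Euclidean distance is used.
    """
    # (n_query, n_index) matrix of pairwise Euclidean distances via broadcasting
    diff = query_points[:, np.newaxis, :] - index_points[np.newaxis, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # position of the closest index point for every query point
    positions = dist.argmin(axis=1)
    distances = dist[np.arange(len(query_points)), positions]
    return distances, positions


index_points = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
query_points = np.array([[1.0, 1.0], [9.0, -1.0], [0.0, 7.0]])
distances, positions = brute_force_nearest(index_points, query_points)
print(positions)  # [0 1 2]
```

The helper returns `(distances, positions)`, the same order that `query` must return.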

In [None]:
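from sklearn.metrics.pairwise import pairwise_distances_argmin_min


@xoak.register_index('brute_force')
class BruteForceIndex(xoak.IndexAdapter):

    def build(self, points):
        # there is no index to build here, just return the points
        return points

    def query(self, index, points):
        # for each query point, find the closest index point (Euclidean distance);
        # scikit-learn returns (argmin, min distance) in that order
        positions, distances = pairwise_distances_argmin_min(points, index)
        # xoak expects the results in (distances, positions) order
        return distances, positions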


The new index now appears in the `xoak` index registry:

In [None]:
xoak.indexes

Let's use this index in the basic example below:

In [None]:
# create mesh
shape = (100, 100)
lat = np.random.uniform(-90, 90, size=shape)
lon = np.random.uniform(-180, 180, size=shape)

field = lat + lon

ds_mesh = xr.Dataset(
    coords={'lat': (('x', 'y'), lat), 'lon': (('x', 'y'), lon)},
    data_vars={'field': (('x', 'y'), field)},
)

# set the brute-force index (doesn't really build any index in this case)
ds_mesh.xoak.set_index(['lat', 'lon'], 'brute_force')

# create trajectory points
ds_trajectory = xr.Dataset({
    'latitude': ('trajectory', np.linspace(-10, 40, 30)),
    'longitude': ('trajectory', np.linspace(-150, 150, 30))
})

# select mesh points
ds_selection = ds_mesh.xoak.sel(
    lat=ds_trajectory.latitude,
    lon=ds_trajectory.longitude
)

# plot results
ds_trajectory.plot.scatter(x='longitude', y='latitude', c='k', alpha=0.7);
ds_selection.plot.scatter(x='lon', y='lat', hue='field', alpha=0.9);
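To use a point-based index on a 2-D mesh, the `lat` and `lon` coordinates have to be flattened into an `(n_points, 2)` array before `build`, and the flat positions returned by `query` have to be mapped back onto the `(x, y)` dimensions. The NumPy-only sketch below reproduces that round trip. It assumes C-order flattening and columns stacked in the order the coordinates were passed to `set_index`; it illustrates the idea and is not `xoak`'s actual implementation.

```python
import numpy as np

# same kind of random mesh as above, with a fixed seed for reproducibility
rng = np.random.default_rng(0)
shape = (100, 100)
lat = rng.uniform(-90, 90, size=shape)
lon = rng.uniform(-180, 180, size=shape)
field = lat + lon

# flatten the 2-D coordinates into an (n_points, 2) array, columns ordered (lat, lon)
index_points = np.column_stack([lat.ravel(), lon.ravel()])

# trajectory points, stacked in the same (lat, lon) column order
query_points = np.column_stack([np.linspace(-10, 40, 30), np.linspace(-150, 150, 30)])

# brute-force nearest neighbor on the flat points (Euclidean distance in degrees)
diff = query_points[:, np.newaxis, :] - index_points[np.newaxis, :, :]
flat_positions = np.sqrt((diff ** 2).sum(axis=-1)).argmin(axis=1)

# map flat positions back to (x, y) indices of the original 2-D mesh
x_idx, y_idx = np.unravel_index(flat_positions, shape)
selected_field = field[x_idx, y_idx]
```

In xarray, integer arrays like `x_idx` and `y_idx` can be wrapped as `DataArray` indexers sharing a `trajectory` dimension and passed to `Dataset.isel` for pointwise (vectorized) selection.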