In [None]:
#| default_exp utilities.xarray

In [None]:
%%capture
%load_ext autoreload
%autoreload 2

# utilities.xarray

> Extra functionality for [xarray](https://docs.xarray.dev/) `Coordinates` objects

In [None]:
#| hide
from fastcore.test import test_eq

In [None]:
#| export
import math
import numpy as np
from fastcore.basics import patch
from xarray import Coordinates

In [None]:
#| export
@patch
def __add__(self:Coordinates, other):
    """Adding two `Coordinates` objects combines their coordinates.

    Returns a new `Coordinates`: a copy of `self` that is then updated with the
    entries of `other` via `Coordinates.update`. Because the copy is updated,
    `self` is not mutated.

    For a non-`Coordinates` operand this returns `NotImplemented`, so Python can
    try the reflected operation and otherwise raise its standard `TypeError`
    ("unsupported operand type(s) for +").
    """
    # Binary-operator protocol: signal "unsupported" instead of raising a bare,
    # message-less TypeError ourselves.
    if not isinstance(other, Coordinates): return NotImplemented
    combined = self.copy()
    combined.update(other)
    return combined

The `+` operator combines two `Coordinates` objects into a new `Coordinates` holding the coordinates of both operands:

In [None]:
coords = Coordinates({'foo': [1, 2]}) + Coordinates({'bar': [1, 2, 3]})
coords

Coordinates:
  * foo      (foo) int64 16B 1 2
  * bar      (bar) int64 24B 1 2 3

In [None]:
#|hide
test_eq(coords, Coordinates({'foo': [1, 2], 'bar': [1, 2, 3]}))

### complement -


In [None]:
#| export
@patch
def complement(self:Coordinates, other:Coordinates):
    """Coordinates of `self` whose names are absent from `other`.

    Only names are compared: an entry of `self` is dropped whenever `other`
    has a coordinate with the same name, regardless of its values.
    """
    remaining = {}
    for name, coord in self.items():
        if name in other: continue
        remaining[name] = coord
    return Coordinates(remaining)

In [None]:
coords.complement(Coordinates({'bar': [1, 2, 3]}))

Coordinates:
  * foo      (foo) int64 16B 1 2

In [None]:
#|hide
test_eq(coords.complement(Coordinates({'bar': [1, 2, 3]})), Coordinates({'foo': [1, 2]}))

### shape -


In [None]:
#| export
@patch(as_prop=True)
def shape(self:Coordinates):
    """Tuple of dimension lengths, in the order the dimensions appear in `self.sizes`."""
    sizes = self.sizes
    return tuple(sizes[dim] for dim in sizes)

In [None]:
coords.shape

(2, 3)

In [None]:
#|hide
test_eq(coords.shape, (2, 3))

### size -


In [None]:
#| export
@patch(as_prop=True)
def size(self:Coordinates):
    """Return product of coordinate lengths, i.e. the product of `self.shape`.

    Uses `math.prod` so the result is a plain Python `int`. A `Coordinates`
    with no dimensions gives `1`, where `np.prod(())` would give the float
    `1.0`. Large grids cannot overflow a fixed-width integer.
    """
    return math.prod(self.shape)

In [None]:
coords.size

6

In [None]:
#|hide
test_eq(coords.size, 6)

### intersection -


In [None]:
#| export
@patch
def intersection(self:Coordinates, other:Coordinates):
    """Coordinates of `self` that also appear in `other` with equal values.

    A name is kept only if `other` has it too and the two coordinates compare
    equal with `.equals`. The kept entries are taken from `self`, in `self`'s
    order.
    """
    shared = {}
    for name, coord in self.items():
        if name not in other: continue
        if not coord.equals(other[name]): continue
        shared[name] = coord
    return Coordinates(shared)

In [None]:
coords.intersection(Coordinates({'foo': [1, 2]}))

Coordinates:
  * foo      (foo) int64 16B 1 2

In [None]:
#|hide
test_eq(coords.intersection(Coordinates({'foo': [1, 2]})), Coordinates({'foo': [1, 2]}))

### contain -


In [None]:
#| export
@patch
def contain(self:Coordinates, other:Coordinates):
    """Check whether `self` contains every coordinate of `other`.

    Returns `True` when intersecting `self` with `other` leaves exactly
    `other`, i.e. each coordinate of `other` is present in `self` with equal
    values. Otherwise returns `False`.
    """
    overlap = self.intersection(other)
    return overlap.equals(other)

In [None]:
coords.contain(Coordinates({'foo': [1, 2]}))

True

In [None]:
#|hide
test_eq(coords.contain(Coordinates({'foo': [1, 2]})), True)

In [None]:
coords.contain(Coordinates({'foo': [1, 2], 'baz': [4, 5]}))

False

In [None]:
#|hide
test_eq(coords.contain(Coordinates({'foo': [1, 2], 'baz': [4, 5]})), False)

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()