# Interpolation of a time series

This example shows how to interpolate a time series of grids with pyinterp, at the positions and dates of float measurements.

In this example, we consider the time series of MSLA maps distributed by AVISO/CMEMS.

## Initialize Dataset

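Here we load the dataset from the Zarr store referenced by the Pangeo intake catalog. Although the dataset is very large, it opens almost instantly: only the metadata (the full list of variables and coordinates) is read up front, while the data themselves are exposed as lazy dask arrays and read only when needed.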

In [1]:
import intake

# Open the Pangeo ocean catalog; the variables are loaded lazily as dask arrays
cat = intake.open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
ds = cat["sea_surface_height"].to_dask()
ds

<xarray.Dataset>
Dimensions:    (latitude: 720, longitude: 1440, nv: 2, time: 8901)
Coordinates:
    crs        int32 ...
    lat_bnds   (time, latitude, nv) float32 dask.array<chunksize=(5, 720, 2), meta=np.ndarray>
  * latitude   (latitude) float32 -89.875 -89.625 -89.375 ... 89.625 89.875
    lon_bnds   (longitude, nv) float32 dask.array<chunksize=(1440, 2), meta=np.ndarray>
  * longitude  (longitude) float32 0.125 0.375 0.625 ... 359.375 359.625 359.875
  * nv         (nv) int32 0 1
  * time       (time) datetime64[ns] 1993-01-01 1993-01-02 ... 2017-05-15
Data variables:
    adt        (time, latitude, longitude) float64 dask.array<chunksize=(5, 720, 1440), meta=np.ndarray>
    err        (time, latitude, longitude) float64 dask.array<chunksize=(5, 720, 1440), meta=np.ndarray>
    sla        (time, latitude, longitude) float64 dask.array<chunksize=(5, 720, 1440), meta=np.ndarray>
    ugos       (time, latitude, longitude) float64 dask.array<chunksize=(5, 720, 1440), meta=np.ndarray
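To put "very large" into numbers, the arithmetic below estimates the uncompressed in-memory size of one data variable such as `sla`, using the dimensions and chunk size printed above and 8 bytes per float64 value. The 9-map "weekly block" assumes one week of daily maps plus one extra map on each side, which matches the periods fetched later in this example (e.g. 2015-12-27 to 2016-01-04).

```python
# Back-of-the-envelope memory footprint of one float64 variable (e.g. "sla"),
# using the dimensions printed in the dataset summary.
n_time, n_lat, n_lon = 8901, 720, 1440
bytes_per_value = 8  # float64

bytes_per_map = n_lat * n_lon * bytes_per_value
bytes_per_chunk = 5 * bytes_per_map        # chunksize=(5, 720, 1440)
bytes_weekly_block = 9 * bytes_per_map     # 7 daily maps + 1 on each side
bytes_full_series = n_time * bytes_per_map

for label, size in [("one map", bytes_per_map),
                    ("one dask chunk", bytes_per_chunk),
                    ("weekly block", bytes_weekly_block),
                    ("full series", bytes_full_series)]:
    print(f"{label:>15}: {size / 1e9:8.3f} GB")
```

Reading the whole variable at once would need about 73.8 GB, whereas a weekly block needs about 75 MB, which is why the grids are loaded period by period.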

In [None]:
# The "crs" coordinate must be removed to avoid a bug when pyinterp
# determines the third (time) axis.
ds = ds.drop_vars("crs")

## Handle the time series

We implement a class that handles the time series and loads, on demand, only the data required to interpolate over a specific time period.

In [2]:
import datetime

import numpy as np
import pandas as pd
import pyinterp.backends.xarray


class GridSeries:
    """Handling of MSLA AVISO maps"""

    def __init__(self, ds):
        self.ds = ds
        self.series, self.dt = self._load_ts()

    @staticmethod
    def _is_sorted(array):
        """True if the values are in increasing order."""
        return bool(np.all(array[:-1] <= array[1:]))

    def _load_ts(self):
        """Loading the time series into memory."""
        time = self.ds.time.values
        if not self._is_sorted(time):
            raise ValueError("The time axis must be sorted")

        series = pd.Series(time)
        # Distinct steps, in seconds, between two consecutive grids
        frequency = set(np.diff(time.astype("datetime64[s]")).astype("int64"))
        if len(frequency) != 1:
            raise RuntimeError(
                "Time series does not have a constant step between two "
                f"grids: {frequency} seconds")
        return series, datetime.timedelta(seconds=float(frequency.pop()))

    def load_dataset(self, varname, start, end):
        """Loading the time series into memory for the defined period.

        Args:
            varname (str): Name of the variable to be loaded into memory.
            start (datetime.datetime): Date of the first map to be loaded.
            end (datetime.datetime): Date of the last map to be loaded.

        Returns:
            pyinterp.backends.xarray.Grid3D: The interpolator handling the
            interpolation of the grid series.
        """
        if start < self.series.min() or end > self.series.max():
            raise IndexError(
                f"period [{start}, {end}] out of range [{self.series.min()}, "
                f"{self.series.max()}]")
        # One extra time step on each side, so that the dates at the edges
        # of the period are bracketed by two grids.
        first = start - self.dt
        last = end + self.dt

        selected = self.series[(self.series >= first) & (self.series < last)]
        print(f"fetch data from {selected.min()} to {selected.max()}")

        # Use the dataset held by the instance, not a global variable
        data_array = self.ds[varname].isel(time=selected.index.values)
        return pyinterp.backends.xarray.Grid3D(data_array)
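The two steps performed by `GridSeries` (checking that the step between grids is constant, then selecting the maps with a one-step margin around the requested period) can be tried on a small synthetic series with pandas alone. This is a minimal sketch that mirrors the logic above without xarray or pyinterp; the helpers `constant_step` and `window_indices` and the ten-day series are hypothetical.

```python
import datetime

import numpy as np
import pandas as pd


def constant_step(times):
    """Return the constant step of a sorted time series, or raise an error."""
    # Distinct differences, in seconds, between consecutive dates
    steps = set(np.diff(times.values.astype("datetime64[s]")).astype("int64"))
    if len(steps) != 1:
        raise RuntimeError(f"non-constant step between grids: {steps} seconds")
    return datetime.timedelta(seconds=float(steps.pop()))


def window_indices(times, start, end, dt):
    """Positional indices of the dates in [start - dt, end + dt)."""
    mask = (times >= start - dt) & (times < end + dt)
    return np.flatnonzero(mask.to_numpy())


# Hypothetical daily series of ten maps starting on 2016-01-01
times = pd.Series(pd.date_range("2016-01-01", periods=10, freq="D"))
dt = constant_step(times)
idx = window_indices(times, pd.Timestamp("2016-01-04"),
                     pd.Timestamp("2016-01-06"), dt)
print(dt, times.iloc[idx].dt.strftime("%Y-%m-%d").tolist())
```

For the period from 4 to 6 January, the margin adds the map of 3 January before the period, and the half-open upper bound keeps the map of 6 January, so maps 3 to 6 January are selected.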

In [3]:
x = ds["sla"]

In [4]:
x.dims

('time', 'latitude', 'longitude')

## Load the test dataset

Next, we add the functions needed to load the test set into memory. The file `positions.csv` contains one row per measurement, with columns giving the float identifier, the date of the measurement (in seconds since 1950-01-01), and the longitude and latitude of the measurement.

In [5]:
def cnes_jd_to_datetime(seconds):
    """Convert a date expressed in seconds since 1950-01-01 into a calendar
    date."""
    return datetime.datetime(1950, 1, 1) + datetime.timedelta(seconds=seconds)


def load_positions():
    """Loading and formatting the dataset."""
    df = pd.read_csv("../tests/dataset/positions.csv",
                     header=None,
                     sep=";",
                     usecols=[0, 1, 2, 3],
                     names=["id", "time", "lon", "lat"],
                     dtype=dict(id=np.uint32,
                                time=np.float64,
                                lon=np.float64,
                                lat=np.float64))
    # Values equal to 2**64 - 1 (stored as a float64) are treated as missing
    df.mask(df == 1.8446744073709552e+19, np.nan, inplace=True)
    df["time"] = df["time"].apply(cnes_jd_to_datetime)
    df.set_index("time", inplace=True)
    # Column that will receive the interpolated SLA
    df["sla"] = np.nan
    return df.sort_index()


df = load_positions()
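A quick sanity check of the two conventions used when loading the positions: converting seconds since 1950-01-01 into a calendar date relies on the offset between the CNES epoch (1950-01-01) and the Unix epoch (1970-01-01), which is 7305 days; and the value 1.8446744073709552e+19, masked to NaN above, is 2**64 - 1 rounded to a float64. The rows below are hypothetical, not taken from `positions.csv`.

```python
import datetime

import numpy as np
import pandas as pd

# Offset between the CNES epoch (1950-01-01) and the Unix epoch (1970-01-01):
# 20 years of 365 days plus 5 leap days (1952, 1956, 1960, 1964, 1968).
EPOCH_OFFSET_DAYS = (datetime.date(1970, 1, 1) - datetime.date(1950, 1, 1)).days
CNES_EPOCH = datetime.datetime(1950, 1, 1)

# Largest uint64 value, 2**64 - 1, rounded to the nearest float64
FILL_VALUE = float(2**64 - 1)

# Hypothetical rows in the layout of positions.csv: id, time (seconds since
# 1950-01-01), lon, lat; the second row has a missing longitude.
sample = pd.DataFrame({"id": [1, 2],
                       "time": [2082758400.0, 2082801600.0],
                       "lon": [10.5, FILL_VALUE],
                       "lat": [-40.25, -41.0]})
sample = sample.mask(sample == FILL_VALUE, np.nan)
sample["time"] = [CNES_EPOCH + datetime.timedelta(seconds=s)
                  for s in sample["time"]]
print(EPOCH_OFFSET_DAYS)
print(sample)
```

The sample is named `sample` so that running it in the notebook does not overwrite `df`.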

## Implementation of interpolation

We create the object that will download the data needed for each interpolation period.

In [6]:
gs = GridSeries(ds)

The function below splits the processing period into sub-periods so that the grids can be loaded in blocks.

In [7]:
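def periods(df, grid_series, frequency="W"):
    """Yield the (start, end) sub-periods covering the positions loaded in
    memory."""
    # Start of every period (one week by default) containing at least one
    # position
    period_start = df.groupby(
        df.index.to_period(frequency))["sla"].count().index
    # No sub-period may start before the first grid of the series
    first_grid = grid_series.series.iloc[0]

    # Defined before the loop so that a single period is also handled
    end = max(period_start[0].to_timestamp(), first_grid)
    for start, end in zip(period_start, period_start[1:]):
        start = max(start.to_timestamp(), first_grid)
        end = end.to_timestamp()
        yield start, end
    # The last sub-period ends one grid step after the last position
    yield end, df.index[-1] + grid_series.dt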
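To see how `to_period("W")` turns measurement dates into block boundaries, the sketch below applies it to a few hypothetical timestamps. Pandas weekly periods end on Sunday by default, so `to_timestamp()` returns the Monday that starts each week; consecutive starts then form the (start, end) pairs, and a week without any measurement is simply absorbed into the previous block.

```python
import pandas as pd

# Hypothetical measurement times; no measurement falls in the week of 2016-01-18
index = pd.DatetimeIndex(["2016-01-01 06:00", "2016-01-02 18:00",
                          "2016-01-05 12:00", "2016-01-13 00:00"])

# Weekly periods ("W" = weeks ending on Sunday) containing at least one time
weeks = index.to_period("W").unique()
starts = [week.to_timestamp() for week in weeks]

# Consecutive starts give the (start, end) sub-periods; in `periods` the last
# one is closed separately, one grid step after the final measurement.
pairs = list(zip(starts, starts[1:]))
print([str(s.date()) for s in starts])
print(pairs)
```

The first week starts on Monday 2015-12-28, which is consistent with the first fetch of the loop below starting one day earlier, on 2015-12-27.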


Finally, we write the function that interpolates the SLA at the positions falling within one of the sub-periods returned by `periods`.

In [8]:
def interpolate(df, grid_series, start, end):
    """Interpolate the SLA at the positions located in [start, end)."""
    # Load the grids covering the period (plus one time step on each side)
    interpolator = grid_series.load_dataset("sla", start, end)
    mask = (df.index >= start) & (df.index < end)
    selected = df.loc[mask, ["lon", "lat"]]
    df.loc[mask, "sla"] = interpolator.trivariate(
        dict(longitude=selected["lon"].values,
             latitude=selected["lat"].values,
             time=selected.index.values),
        interpolator="inverse_distance_weighting",
        num_threads=0)  # 0: use all available CPUs
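`trivariate` is called above with `interpolator="inverse_distance_weighting"`. As far as we understand pyinterp, the chosen method is applied in the longitude/latitude plane on the maps surrounding each date, and the two results are then interpolated linearly along the time axis. The sketch below illustrates that two-step idea in plain NumPy; it is not pyinterp's code. Planar distances, the use of only the four corners of the cell, the exponent of 2, and the helper names `idw_cell` and `idw_then_linear_in_time` are all assumptions made for illustration.

```python
import numpy as np


def idw_cell(x, y, corners, values, power=2.0):
    """Inverse distance weighting from the four corners of a grid cell.

    Simplified illustration with planar distances (assumption); pyinterp's
    own implementation is not reproduced here.
    """
    corners = np.asarray(corners, dtype=float)
    values = np.asarray(values, dtype=float)
    distances = np.hypot(corners[:, 0] - x, corners[:, 1] - y)
    # A point located exactly on a node takes that node's value
    if np.any(distances == 0):
        return float(values[np.argmin(distances)])
    weights = 1.0 / distances**power
    return float(np.sum(weights * values) / np.sum(weights))


def idw_then_linear_in_time(t, t0, t1, value_t0, value_t1):
    """Linear interpolation in time between the estimates of two maps."""
    w = (t - t0) / (t1 - t0)
    return (1.0 - w) * value_t0 + w * value_t1


# Hypothetical unit cell and SLA values (m) on two consecutive daily maps
corners = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
sla_day0 = idw_cell(0.5, 0.5, corners, [0.1, 0.2, 0.3, 0.4])
sla_day1 = idw_cell(0.5, 0.5, corners, [0.2, 0.3, 0.4, 0.5])
print(sla_day0, sla_day1,
      idw_then_linear_in_time(0.25, 0.0, 1.0, sla_day0, sla_day1))
```

At the centre of the cell all corners are equidistant, so each spatial estimate is the plain mean of the four values; a point a quarter of the way between the two maps then gets 75 % of the first estimate and 25 % of the second.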

In [None]:
for start, end in periods(df, gs):
    interpolate(df, gs, start, end)

fetch data from 2015-12-27 00:00:00 to 2016-01-04 00:00:00
fetch data from 2016-01-03 00:00:00 to 2016-01-11 00:00:00
fetch data from 2016-01-10 00:00:00 to 2016-01-18 00:00:00
fetch data from 2016-01-17 00:00:00 to 2016-01-25 00:00:00
fetch data from 2016-01-24 00:00:00 to 2016-02-01 00:00:00
fetch data from 2016-01-31 00:00:00 to 2016-02-08 00:00:00
fetch data from 2016-02-07 00:00:00 to 2016-02-15 00:00:00
fetch data from 2016-02-14 00:00:00 to 2016-02-22 00:00:00
fetch data from 2016-02-21 00:00:00 to 2016-02-29 00:00:00
fetch data from 2016-02-28 00:00:00 to 2016-03-07 00:00:00
fetch data from 2016-03-06 00:00:00 to 2016-03-14 00:00:00
fetch data from 2016-03-13 00:00:00 to 2016-03-21 00:00:00
fetch data from 2016-03-20 00:00:00 to 2016-03-28 00:00:00
fetch data from 2016-03-27 00:00:00 to 2016-04-04 00:00:00
fetch data from 2016-04-03 00:00:00 to 2016-04-11 00:00:00
fetch data from 2016-04-10 00:00:00 to 2016-04-18 00:00:00
fetch data from 2016-04-17 00:00:00 to 2016-04-25 00:00:

Visualization of the interpolated SLA along the trajectory of one float.

In [None]:
float_id = 62423050
selected_float = df[df.id == float_id]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
%matplotlib inline


def plot(ax, x, y, title):
    ax.plot(x, y)
    # A new locator and formatter for each axis: Matplotlib locators must
    # not be shared between several axes.
    rule = mdates.rrulewrapper(mdates.MONTHLY, bymonthday=1, interval=1)
    ax.xaxis.set_major_locator(mdates.RRuleLocator(rule))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y/%m/%d"))
    ax.grid(True)
    ax.set_title(title)
    ax.tick_params(axis="x", labelrotation=30, labelsize=10)

In [None]:
fig = plt.figure(figsize=(9, 5))
ax = fig.add_subplot(131)
plot(ax, selected_float.index, selected_float.sla, "SLA")
ax = fig.add_subplot(132)
plot(ax, selected_float.index, selected_float.lon, "longitude")
ax = fig.add_subplot(133)
plot(ax, selected_float.index, selected_float.lat, "latitude")
