# TaskLoader tour

This notebook demonstrates how to set up `TaskLoader` objects that load different kinds of tasks, which can be used to train models with different objectives. The `TaskLoader` is flexible: it can load tasks for spatial interpolation, forecasting, downscaling, or a combination of these. `TaskLoader`s can also be used to automatically construct ConvNP models in TensorFlow or PyTorch, which adds to the convenience of `deepsensor`.

In this notebook, we will demonstrate a few different tasks that can be loaded using `TaskLoader` objects. Antarctic temperature will be our target variable, using gridded ERA5 data and off-grid station data to construct tasks. However, the same principles can be applied to any other target variable and data source, as long as the data is in `xarray` or `pandas` format.

## Imports/set-up

In [4]:
# Load the "autoreload" extension so that code can change
%load_ext autoreload
# Always reload modules so that as you change code in src, it gets loaded
%autoreload 2



In [5]:
import deepsensor.torch as deepsensor
# import deepsensor.tensorflow as deepsensor

In [6]:
from deepsensor.data.processor import DataProcessor
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP

In [7]:
import pandas as pd
import xarray as xr

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("white")

In [8]:
date = "2018-01-01"

## Load data

In [9]:
# era5_raw_ds = xr.open_mfdataset('../../deepsensor_old/data/antarctica/gridded/processed/*/*.nc')
era5_raw_ds = xr.open_mfdataset('../../deepsensor_old/data/antarctica/gridded/interim/tas_anom/*.nc')
era5_raw_ds

Dask-backed arrays in `era5_raw_ds`:

| | Shape | Chunk shape | Size | Chunk size | Dtype | Dask graph |
|---|---|---|---|---|---|---|
| 1-D array | (25933,) | (366,) | 202.60 kiB | 2.86 kiB | int64 | 71 chunks in 143 graph layers |
| 3-D array | (25933, 280, 280) | (366, 280, 280) | 7.57 GiB | 109.46 MiB | float32 | 71 chunks in 143 graph layers |

In [10]:
aux_raw_ds = xr.open_mfdataset('../../deepsensor_old/data/antarctica/auxiliary/interim/*25000m/*.nc')
aux_raw_ds = aux_raw_ds[['surface', 'mask']]
aux_raw_ds

Dask-backed arrays in `aux_raw_ds` (each stored as a single chunk):

| | Shape | Chunk shape | Size | Dtype | Dask graph |
|---|---|---|---|---|---|
| Array 1 | (280, 280) | (280, 280) | 306.25 kiB | float32 | 1 chunk in 5 graph layers |
| Array 2 | (280, 280) | (280, 280) | 306.25 kiB | float32 | 1 chunk in 5 graph layers |
| Array 3 | (280, 280) | (280, 280) | 306.25 kiB | float32 | 1 chunk in 2 graph layers |
| Array 4 | (280, 280) | (280, 280) | 612.50 kiB | float64 | 1 chunk in 2 graph layers |

In [15]:
station_raw_df = pd.read_csv('../../deepsensor_old/data/antarctica/station/interim/XY_station.csv')
station_raw_df = station_raw_df.rename(columns={'date': 'time'})
station_raw_df['time'] = pd.to_datetime(station_raw_df['time'])
station_raw_df = station_raw_df.set_index(['time', 'y', 'x', 'station']).sort_index()[['tas']]
station_raw_df

| time | y | x | station | tas |
|---|---|---|---|---|
| 1948-04-01 | -3.638593e+06 | 1.401585e+06 | Macquarie_Island | 5.885714 |
| 1948-04-02 | -3.638593e+06 | 1.401585e+06 | Macquarie_Island | 6.675000 |
| 1948-04-03 | -3.638593e+06 | 1.401585e+06 | Macquarie_Island | 4.775000 |
| 1948-04-04 | -3.638593e+06 | 1.401585e+06 | Macquarie_Island | 5.150000 |
| 1948-04-05 | -3.638593e+06 | 1.401585e+06 | Macquarie_Island | 4.037500 |
| ... | ... | ... | ... | ... |
| 2022-06-14 | 1.447591e+06 | -6.904655e+05 | Halley_6a | -22.577778 |
| 2022-06-15 | 5.380525e+05 | -1.606516e+06 | Sky_Blu | -26.587500 |
| 2022-06-15 | 8.072843e+05 | -1.351549e+06 | Limbert | -23.887500 |
| 2022-06-15 | 9.270759e+05 | -2.308923e+06 | Rothera | -9.895833 |
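
The loading cell reshapes the station CSV into a `DataFrame` indexed by `(time, y, x, station)` with a single `tas` column. The sketch below repeats those same pandas steps on a few made-up rows, so the resulting layout can be inspected without the Antarctic data files; the station names, coordinates, and values are hypothetical.

```python
import pandas as pd

# Made-up station records in a flat, CSV-like layout (values are illustrative only)
raw = pd.DataFrame({
    "date": ["2018-01-02", "2018-01-01", "2018-01-01"],
    "y": [-3.6e6, -3.6e6, 1.4e6],
    "x": [1.4e6, 1.4e6, -0.7e6],
    "station": ["Station_A", "Station_A", "Station_B"],
    "tas": [5.0, 6.0, -20.0],
    "other": [1, 2, 3],  # an extra column that the final column selection drops
})

# Same steps as the loading cell: rename, parse dates, index by (time, y, x, station)
df = raw.rename(columns={"date": "time"})
df["time"] = pd.to_datetime(df["time"])
df = df.set_index(["time", "y", "x", "station"]).sort_index()[["tas"]]

# Cross-section on the time level: every station observation for one date
day = df.xs(pd.Timestamp("2018-01-01"), level="time")
print(df)
print(day)
```

Sorting the index keeps all observations for a date contiguous, which makes per-date selections like the one above straightforward.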


## Normalise data

In [16]:
data_processor = DataProcessor(x1_name='y', x1_map=(0, 3.5e6), x2_name='x', x2_map=(0, 3.5e6))

In [17]:
era5_ds, station_df = data_processor([era5_raw_ds, station_raw_df])
aux_ds = data_processor(aux_raw_ds, method="min_max")
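
The normalisation cell does not show the arithmetic involved. The numpy sketch below illustrates plausible linear transforms under stated assumptions: coordinates mapped so that `x1_map[0]` goes to 0 and `x1_map[1]` goes to 1, mean/std standardisation as the assumed default for the ERA5 and station data, and min-max scaling to an assumed [-1, 1] range for the auxiliary data (`method="min_max"`). The helper names are hypothetical, and `DataProcessor`'s actual conventions may differ.

```python
import numpy as np

# Hypothetical helpers illustrating linear normalisation; these are NOT
# deepsensor's DataProcessor internals, only the arithmetic under stated assumptions.

def map_coord(x, x_map):
    """Linearly map coordinates so that x_map[0] -> 0 and x_map[1] -> 1 (assumed convention)."""
    lo, hi = x_map
    return (x - lo) / (hi - lo)

def mean_std(v):
    """Standardise to zero mean and unit (population) standard deviation."""
    return (v - v.mean()) / v.std()

def min_max(v, out_range=(-1.0, 1.0)):
    """Rescale so min(v) -> out_range[0] and max(v) -> out_range[1] (range is an assumption)."""
    a, b = out_range
    return a + (v - v.min()) * (b - a) / (v.max() - v.min())

y = np.array([-3.5e6, 0.0, 1.75e6, 3.5e6])
print(map_coord(y, (0, 3.5e6)))  # values -1, 0, 0.5, 1

t = np.array([-30.0, -20.0, -10.0])
print(mean_std(t))  # zero mean, unit standard deviation

elevation = np.array([0.0, 500.0, 1000.0])
print(min_max(elevation))  # values -1, 0, 1
```

Under these assumptions, with `x1_map=(0, 3.5e6)` a y coordinate of -3.5e6 m maps to -1 rather than 0, so normalised coordinates need not fall within [0, 1].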

## Let's load some tasks!

### ERA5 spatial interpolation

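A spatial interpolation model can be trained by randomly sampling grid cells from `xarray` objects, which is simple to do with `TaskLoader`. The benchmark below times 1000 task generations for three sampling configurations: continuous sampling with linear or nearest-neighbour interpolation (`xarray_interp_method`), and discrete sampling (`discrete_xarray_sampling=True`).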
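
As a rough mental model for these configurations (an assumption, not a description of deepsensor's internals): discrete sampling picks existing grid cells, while continuous sampling draws arbitrary coordinates within the grid extent and interpolates the field at them. The numpy sketch below implements both on a toy 20 × 30 grid, using nearest-neighbour lookup for the continuous case; the function names are hypothetical.

```python
import numpy as np

# Hypothetical sketch of two ways to draw a random fraction of points from a gridded field.
# This is not deepsensor's implementation; it only illustrates the idea.

def sample_discrete(field, x1, x2, frac, rng):
    """Pick existing grid cells at random, without replacement (no interpolation needed)."""
    n = int(round(frac * field.size))
    flat = rng.choice(field.size, size=n, replace=False)
    i, j = np.unravel_index(flat, field.shape)
    return np.stack([x1[i], x2[j]]), field[i, j]

def sample_continuous_nearest(field, x1, x2, frac, rng):
    """Draw arbitrary coordinates within the grid extent, then read the nearest grid cell."""
    n = int(round(frac * field.size))
    c1 = rng.uniform(x1.min(), x1.max(), size=n)
    c2 = rng.uniform(x2.min(), x2.max(), size=n)
    # Nearest-neighbour lookup: index of the closest grid coordinate along each axis
    i = np.abs(x1[:, None] - c1[None, :]).argmin(axis=0)
    j = np.abs(x2[:, None] - c2[None, :]).argmin(axis=0)
    return np.stack([c1, c2]), field[i, j]

rng = np.random.default_rng(0)
x1 = np.linspace(0.0, 1.0, 20)
x2 = np.linspace(0.0, 1.0, 30)
field = np.arange(20 * 30, dtype=float).reshape(20, 30)

X_d, Y_d = sample_discrete(field, x1, x2, 0.1, rng)
X_c, Y_c = sample_continuous_nearest(field, x1, x2, 0.1, rng)
print(X_d.shape, Y_d.shape)
print(X_c.shape, Y_c.shape)
```

Locating and interpolating at off-grid coordinates is extra work that discrete sampling avoids, which is consistent with (though not proof of) the discrete configuration being the fastest in the timings.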

In [23]:
import time

# Time 1000 task generations for each xarray sampling configuration
sampling_configs = {
    "Continuous and linear": dict(xarray_interp_method="linear", discrete_xarray_sampling=False),
    "Continuous and nearest": dict(xarray_interp_method="nearest", discrete_xarray_sampling=False),
    "Discrete": dict(discrete_xarray_sampling=True),
}
for name, kwargs in sampling_configs.items():
    task_loader = TaskLoader(context=era5_ds['t2m'], target=era5_ds['t2m'], **kwargs)
    tic = time.time()
    for i in range(1000):
        # Sample a fraction (0.1) of the points for both the context and target sets
        task = task_loader(date, 0.1, 0.1)
    print(f"{name}, time taken: {time.time() - tic:.2f} s")

Continuous and linear, time taken: 12.62 s
Continuous and nearest, time taken: 12.23 s
Discrete, time taken: 3.78 s


In [None]:
task_loader = TaskLoader(context=[era5_ds['t2m'], aux_ds], target=era5_ds['t2m'])
print(task_loader)

In [None]:
model = ConvNP(data_processor, task_loader, verbose=False)
task = task_loader(date, (0.1, "all"), "all")
print(task)
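
The call `task_loader(date, (0.1, "all"), "all")` passes one sampling spec per context set. The sketch below shows one way such specs could be resolved into point counts, assuming a float means a fraction of the available points, an int means a number of points, and `"all"` means every point; these rules and the helper name `resolve_n_points` are assumptions for illustration, not deepsensor's code.

```python
from typing import Union

Spec = Union[str, int, float]

# Hypothetical helper: resolve a per-set sampling spec into a number of points.
# Rules assumed here: "all" -> every point, int -> count, float -> fraction.
def resolve_n_points(spec: Spec, n_available: int) -> int:
    if spec == "all":
        return n_available
    if isinstance(spec, bool):
        raise TypeError("bool is not a valid sampling spec")
    if isinstance(spec, int):
        if spec > n_available:
            raise ValueError("cannot sample more points than are available")
        return spec
    if isinstance(spec, float):
        if not 0.0 <= spec <= 1.0:
            raise ValueError("fractions must lie in [0, 1]")
        return int(round(spec * n_available))
    raise ValueError(f"unsupported sampling spec: {spec!r}")

# One spec per context set: 10% of the 280 x 280 ERA5 grid, all of the 280 x 280 auxiliary grid
n_cells = 280 * 280
context_specs = (0.1, "all")
print([resolve_n_points(s, n_cells) for s in context_specs])  # [7840, 78400]
```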

In [None]:
fig = deepsensor.plot.context_encoding(model, task, task_loader)
plt.show()

### Station spatial interpolation

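Generating interpolation tasks from `pandas` station data is slightly more involved than with `xarray` gridded data. When instantiating the `TaskLoader`, we must set up a 'link' between the station context set and the station target set; here `links=[(0, 0)]` links context set 0 to target set 0. The link is used to split the station observations into context and target sets when generating `Task`s: with `"split"` sampling, `split_frac=0.7` assigns 70% of the observations to the context set and the remainder to the target set.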
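
Conceptually, the split keeps the linked context and target sets disjoint. The numpy sketch below shuffles the stations observed on one date and partitions them, assuming `split_frac` is the fraction sent to the context set; the helper `split_stations` and the station names are hypothetical, and this is not deepsensor's implementation.

```python
import numpy as np

# Hypothetical sketch of "split" sampling for a linked context/target pair:
# shuffle the stations observed on one date, then partition them so that
# no station appears in both the context and target sets.
def split_stations(station_ids, split_frac, rng):
    ids = np.asarray(station_ids)
    perm = rng.permutation(len(ids))
    n_context = int(round(split_frac * len(ids)))
    return ids[perm[:n_context]], ids[perm[n_context:]]

rng = np.random.default_rng(42)
stations = [f"station_{k}" for k in range(10)]  # made-up station IDs
context_ids, target_ids = split_stations(stations, split_frac=0.7, rng=rng)
print(len(context_ids), len(target_ids))  # 7 3
```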

In [None]:
task_loader = TaskLoader(context=[station_df, aux_ds], target=station_df, links=[(0, 0)])
print(task_loader)

In [None]:
model = ConvNP(data_processor, task_loader, verbose=False)
task = task_loader(date, ("split", "all"), "split", split_frac=0.7)
print(task)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 7), subplot_kw=dict(projection=ccrs.LambertAzimuthalEqualArea(0, -90)))
ax.set_extent([-3.5e6, 3.5e6, -3.5e6, 3.5e6], crs=ccrs.LambertAzimuthalEqualArea(0, -90))
ax.stock_img()
ax.coastlines(linewidth=0.25)
deepsensor.plot.offgrid_context(ax, task, data_processor, task_loader, plot_target=True, add_legend=True, linewidths=0.5)
plt.show()

In [None]:
# Zoom-in on the station context set
fig = deepsensor.plot.context_encoding(model, task, task_loader, context_set_idxs=0, size=7, return_axes=True)
plt.show()

In [None]:
# Plot the whole encoding
fig = deepsensor.plot.context_encoding(model, task, task_loader)
plt.show()

### ERA5 forecasting

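Using the `context_delta_t` and `target_delta_t` arguments, you can specify a time offset, in time steps, between each context or target set and the task date.
In this case, the context sets are the ERA5 field at the previous and current time steps (`context_delta_t=[-1, 0, 0]`, where the final 0 belongs to the static auxiliary data), and the target is the ERA5 field at the next time step (`target_delta_t=1`).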

By printing the `repr` of the `TaskLoader` object, we get more verbose variable IDs showing the time indexes of the context and target sets.
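
To make the offsets concrete, the sketch below resolves each set's date relative to the task date, assuming one time step is one day (daily data). The helper `resolve_dates` is hypothetical and is not deepsensor's implementation; it uses `import datetime` rather than `from datetime import date` so it does not overwrite the notebook's `date` variable.

```python
import datetime

# Hypothetical sketch: resolve each set's time offset relative to the task date,
# assuming one time step is one day. Not deepsensor's implementation.
def resolve_dates(task_date, context_delta_t, target_delta_t, step=datetime.timedelta(days=1)):
    if isinstance(target_delta_t, int):
        target_delta_t = [target_delta_t]  # allow a scalar, as in target_delta_t=1
    context_dates = [task_date + offset * step for offset in context_delta_t]
    target_dates = [task_date + offset * step for offset in target_delta_t]
    return context_dates, target_dates

# Offsets from the forecasting TaskLoader; the third context set is static auxiliary data,
# so its resolved date does not change what is loaded
ctx_dates, tgt_dates = resolve_dates(datetime.date(2018, 1, 1), [-1, 0, 0], 1)
print([d.isoformat() for d in ctx_dates])  # ['2017-12-31', '2018-01-01', '2018-01-01']
print([d.isoformat() for d in tgt_dates])  # ['2018-01-02']
```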

In [None]:
task_loader = TaskLoader(context=[era5_ds['t2m'], era5_ds['t2m'], aux_ds], target=era5_ds['t2m'],
                         context_delta_t=[-1, 0, 0], target_delta_t=1)
print(repr(task_loader))

In [None]:
model = ConvNP(data_processor, task_loader, verbose=False)
task = task_loader(date, "all", "all")
print(task)

In [None]:
fig = deepsensor.plot.context_encoding(model, task, task_loader)
plt.show()

### ERA5 downscaling using station targets

We will pass the station data as the target set. This set-up can be used to train a model to downscale gridded data to station data.

TODO: Passing hi-res auxiliary information via an Anna-style output MLP is not currently supported.
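
In this set-up, the context data lives on the ERA5 grid while the targets sit at off-grid station locations. A simple point of comparison for what a downscaling model must learn is the value of the grid cell nearest each station. The numpy sketch below performs that lookup for the Sky_Blu and Rothera coordinates from the station table, using random stand-in data, an assumed grid extent of ±3.5e6 m, and a hypothetical helper name; it is a naive baseline, not a description of how ConvNP predicts at stations.

```python
import numpy as np

# Hypothetical sketch of the downscaling geometry: context values on a regular grid,
# targets at off-grid station coordinates. Nearest-grid-cell lookup is a naive baseline.
def nearest_grid_values(field, y_grid, x_grid, station_y, station_x):
    """Return the value of the grid cell nearest to each (station_y, station_x) location."""
    i = np.abs(y_grid[:, None] - station_y[None, :]).argmin(axis=0)
    j = np.abs(x_grid[:, None] - station_x[None, :]).argmin(axis=0)
    return field[i, j]

# Assumed 280 x 280 grid spanning +/-3.5e6 m; the field is random stand-in data
y_grid = np.linspace(-3.5e6, 3.5e6, 280)
x_grid = np.linspace(-3.5e6, 3.5e6, 280)
field = np.random.default_rng(0).normal(size=(280, 280))

# Sky_Blu and Rothera (y, x) coordinates from the station table
station_y = np.array([5.380525e5, 9.270759e5])
station_x = np.array([-1.606516e6, -2.308923e6])
baseline = nearest_grid_values(field, y_grid, x_grid, station_y, station_x)
print(baseline)
```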

In [None]:
task_loader = TaskLoader(context=[era5_ds['t2m'], aux_ds], target=station_df)
print(task_loader)

In [None]:
model = ConvNP(data_processor, task_loader, verbose=False)
task = task_loader(date, "all", "all")
print(task)

In [None]:
fig = deepsensor.plot.context_encoding(model, task, task_loader)
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 7), subplot_kw=dict(projection=ccrs.LambertAzimuthalEqualArea(0, -90)))
ax.set_extent([-3.5e6, 3.5e6, -3.5e6, 3.5e6], crs=ccrs.LambertAzimuthalEqualArea(0, -90))
ax.stock_img()
ax.coastlines(linewidth=0.25)
deepsensor.plot.offgrid_context(ax, task, data_processor, task_loader, plot_target=True, add_legend=True, linewidths=0.5)
plt.show()

### TODO: Satellite data interpolation

## What's missing?

Is some functionality missing that you would like to see? Please open an issue on the [GitHub repository](https://github.com/tom-andersson/deepsensor/tree/main).

