# xbatcher / PyTorch DataLoader tutorial

If you have not installed xbatcher yet, you can install it with pip or conda. The xbatcher docs are here:

https://xbatcher.readthedocs.io/en/latest/

```python
# !pip install xbatcher

# or

# !conda install -y -c conda-forge xbatcher
```

```python
import numpy as np

import xarray as xr

# importing xbatcher
import xbatcher
from xbatcher import BatchGenerator
from xbatcher.loaders.torch import IterableDataset
```

```python
# PyTorch-specific loaders
import torch

from torch.utils.data import TensorDataset, DataLoader
```

```python
# watermark: a convenient way to show the versions of the imported packages
from watermark import watermark
%load_ext watermark
```

```python
%watermark --iversions
```

```
xbatcher: 0.3.0
torch   : 1.13.0
xarray  : 2022.11.0
numpy   : 1.23.5
```



## xarray:

This is the standard toy dataset from xarray's tutorial data (`xr.tutorial`), used throughout the xarray docs and examples.

```python
ds = xr.tutorial.load_dataset("air_temperature")
print("number of timesteps:", len(ds.time))

# add a second input variable and a target variable, to be somewhat more
# realistic than a single-variable dataset (the offsets are arbitrary)
ds["air2"] = ds.air + 42.42
ds["target_air"] = ds.air - 100

ds
```

```
number of timesteps: 2920
```


## xbatcher intro:

The size and shape of batches are explained [here](https://xbatcher.readthedocs.io/en/latest/demo.html#Controlling-the-size/shape-of-batches).
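To see how a window length along `time` turns into a number of batches, here is a small arithmetic sketch. `count_windows` is a hypothetical helper, not part of xbatcher, and it assumes non-overlapping windows by default with any trailing partial window dropped; the xbatcher docs linked above are the authority on the exact rules.

```python
def count_windows(dim_size: int, window: int, overlap: int = 0) -> int:
    """Hypothetical helper: how many full windows of length `window` fit along a
    dimension of length `dim_size` when stepping by `window - overlap`.

    Assumption: a trailing partial window is dropped.
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    if dim_size < window:
        return 0
    stride = window - overlap
    return (dim_size - window) // stride + 1


# 2920 time steps, 12 time steps per window
print(count_windows(2920, 12))             # 243
print(count_windows(2920, 12, overlap=6))  # 485
```

Overlapping windows roughly double the number of batches here, at the cost of each time step appearing in two batches.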

### pytorch stuff

PyTorch data loader info: https://pytorch.org/docs/stable/data.html

You can also define batch sizes in PyTorch (the `batch_size` argument of `DataLoader`), but you have less control over the shape of each batch, and it might not scale as well. We will not use this feature here.
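For comparison, a minimal sketch of letting PyTorch do the batching instead: each sample is one time step, `TensorDataset` pairs inputs with targets, and `DataLoader(batch_size=12)` stacks consecutive samples. The arrays below are random stand-ins with made-up sizes, not the air temperature data.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Random stand-in arrays: 100 samples (time steps), 2 variables, 4 x 5 grid
rng = np.random.default_rng(0)
x_np = rng.normal(size=(100, 2, 4, 5)).astype(np.float32)
y_np = x_np[:, 0] - np.float32(100.0)  # one target field per sample: (100, 4, 5)

toy_dataset = TensorDataset(torch.from_numpy(x_np), torch.from_numpy(y_np))

# PyTorch stacks 12 consecutive samples per batch; the last batch is smaller
toy_loader = DataLoader(toy_dataset, batch_size=12, shuffle=False)

batch_sizes = [xb.shape[0] for xb, yb in toy_loader]
print(len(toy_loader))  # 9
print(batch_sizes)      # [12, 12, 12, 12, 12, 12, 12, 12, 4]
```

Here PyTorch only batches along the first (sample) axis; anything more elaborate, such as windows over space or overlapping windows, has to be prepared beforehand.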

It is worth experimenting with `Dataset.to_array()` to understand how the variables of a multi-variable dataset end up stacked into a single NumPy array.
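A small, self-contained illustration of `to_array()` on a synthetic two-variable dataset (sizes and values are made up). The variables are stacked along a new leading `variable` dimension; transposing it behind `time` gives a (batch, channel, lat, lon) layout, which is one common choice for PyTorch models but not something this tutorial requires.

```python
import numpy as np
import xarray as xr

# Synthetic dataset with the same dimension names as air_temperature
rng = np.random.default_rng(0)
air_small = rng.normal(size=(24, 4, 5))  # (time, lat, lon), made-up sizes
ds_small = xr.Dataset(
    {
        "air": (("time", "lat", "lon"), air_small),
        "air2": (("time", "lat", "lon"), air_small + 42.42),
    }
)

# to_array() stacks the data variables along a new leading "variable" dimension
stacked = ds_small.to_array()
print(stacked.dims)                         # ('variable', 'time', 'lat', 'lon')
print(stacked.shape)                        # (2, 24, 4, 5)
print(stacked["variable"].values.tolist())  # ['air', 'air2']

# Moving "variable" behind "time" treats the variables as channels
as_channels = stacked.transpose("time", "variable", "lat", "lon").values
print(as_channels.shape)                    # (24, 2, 4, 5)
```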

```python
x = ds[["air", "air2"]]  # inputs: two variables
y = ds["target_air"]     # target: one variable
```

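```python
# input_dims sets the batch *size*, not the number of batches: each batch holds
# 12 consecutive time steps, so the 2920 time steps give about 243 batches
time_steps_per_batch = 12

x_gen = BatchGenerator(x, input_dims={"time": time_steps_per_batch})
y_gen = BatchGenerator(y, input_dims={"time": time_steps_per_batch})
```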
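The input and target generators are later consumed side by side, which only makes sense if the n-th input batch and the n-th target batch cover the same time stamps. The sketch below checks that idea with `iter_time_batches`, a hypothetical, simplified stand-in for a generator with `{"time": size}` windows; it is not xbatcher's implementation, and the data is made up.

```python
import numpy as np
import xarray as xr


def iter_time_batches(obj, size):
    """Hypothetical, simplified stand-in for windowing along "time".

    Yields consecutive, non-overlapping windows of `size` time steps; a trailing
    partial window is dropped (an assumption about xbatcher's default behavior).
    """
    n = obj.sizes["time"]
    for start in range(0, n - size + 1, size):
        yield obj.isel(time=slice(start, start + size))


# Made-up data: 30 time steps on a 4 x 5 grid, with an integer time coordinate
air_toy = xr.DataArray(
    np.random.default_rng(0).normal(size=(30, 4, 5)),
    dims=("time", "lat", "lon"),
    coords={"time": np.arange(30)},
)
x_toy = xr.Dataset({"air": air_toy, "air2": air_toy + 42.42})
y_toy = air_toy - 100

pairs = list(zip(iter_time_batches(x_toy, 12), iter_time_batches(y_toy, 12)))
print(len(pairs))  # 2 (the last 6 time steps do not fill a window)
for xb, yb in pairs:
    # Each input window and its target window cover the same time stamps
    assert (xb.time.values == yb.time.values).all()
```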

```python
# xbatcher's IterableDataset (from xbatcher.loaders.torch) wraps the two generators
dataset = IterableDataset(x_gen, y_gen)
```
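Conceptually, a dataset like this pairs each input batch with the matching target batch and hands both to PyTorch as tensors. The class below is a hypothetical, simplified sketch of that idea built on `torch.utils.data.IterableDataset`; it is not xbatcher's source code, and the batch shapes are made up.

```python
import numpy as np
import torch


class PairedIterableDataset(torch.utils.data.IterableDataset):
    """Hypothetical sketch (not xbatcher's code): zip input and target batches."""

    def __init__(self, x_batches, y_batches):
        super().__init__()
        self.x_batches = x_batches
        self.y_batches = y_batches

    def __iter__(self):
        # Iterating both sources in lockstep pairs batch i of x with batch i of y
        for xb, yb in zip(self.x_batches, self.y_batches):
            yield torch.as_tensor(np.asarray(xb)), torch.as_tensor(np.asarray(yb))


# Made-up, already-batched arrays: 3 batches of 12 time steps on a 4 x 5 grid
x_batches = [np.full((2, 12, 4, 5), i, dtype=np.float32) for i in range(3)]
y_batches = [np.full((12, 4, 5), i, dtype=np.float32) for i in range(3)]

paired = PairedIterableDataset(x_batches, y_batches)
for xb, yb in paired:
    print(xb.shape, yb.shape)  # torch.Size([2, 12, 4, 5]) torch.Size([12, 4, 5])
```

Lists can be iterated again every epoch; a one-shot Python generator passed in their place would be exhausted after the first pass.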

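```python
# batch_size=None: the batches are already built by xbatcher, so PyTorch should not re-batch them
loader = torch.utils.data.DataLoader(dataset, batch_size=None)
```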
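Passing `batch_size=None` turns off PyTorch's automatic batching, so each already-batched item is yielded as-is. A toy comparison with made-up tensors shows what would happen otherwise:

```python
import torch
from torch.utils.data import DataLoader


class PreBatched(torch.utils.data.IterableDataset):
    """Made-up dataset whose items are already whole (x, y) batches."""

    def __iter__(self):
        for _ in range(4):
            yield torch.zeros(12, 2, 4, 5), torch.zeros(12, 4, 5)


# batch_size=None: automatic batching is off, each item comes through unchanged
unbatched = DataLoader(PreBatched(), batch_size=None)
xb, yb = next(iter(unbatched))
print(xb.shape)  # torch.Size([12, 2, 4, 5])

# batch_size=2: PyTorch stacks two pre-made batches, adding a leading dimension
rebatched = DataLoader(PreBatched(), batch_size=2)
xb2, yb2 = next(iter(rebatched))
print(xb2.shape)  # torch.Size([2, 12, 2, 4, 5])
```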

```python
type(loader)
```

```
torch.utils.data.dataloader.DataLoader
```

This dataset should probably be processed, normalized, etc., and a different shape might suit your use case better. Still, it serves as an example of a relatively simple path from a multi-variable xarray dataset to a PyTorch data loader.
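As one example of the normalization step mentioned above, here is a minimal sketch of per-variable z-score standardization with xarray, applied before building the generators. Z-scoring is just one common choice, not something this tutorial prescribes, and the data is a made-up stand-in for `ds[["air", "air2"]]`.

```python
import numpy as np
import xarray as xr

# Made-up stand-in for the input variables: 40 time steps on a 4 x 5 grid
rng = np.random.default_rng(0)
air_norm_demo = xr.DataArray(
    rng.normal(loc=280.0, scale=10.0, size=(40, 4, 5)),
    dims=("time", "lat", "lon"),
)
x_raw = xr.Dataset({"air": air_norm_demo, "air2": air_norm_demo + 42.42})

# Per-variable statistics over all dimensions
# (in practice, compute them on the training period only)
mean = x_raw.mean()
std = x_raw.std()
x_norm = (x_raw - mean) / std

# Each variable now has mean close to 0 and standard deviation close to 1
print(float(x_norm["air"].mean()), float(x_norm["air"].std()))
```

Because the statistics are computed per variable, `air` and `air2` end up identical after standardization here, since one is just a shifted copy of the other. The target would typically be standardized with its own statistics, which are needed again to convert predictions back to physical units.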