# xbatcher / PyTorch DataLoader tutorial

If you have not installed xbatcher yet, you can install it via pip or conda. The docs for xbatcher are here:

https://xbatcher.readthedocs.io/en/latest/

In [1]:
# !pip install xbatcher

# or

# conda install -c conda-forge xbatcher

In [2]:
import numpy as np

import xarray as xr
import xbatcher

In [3]:
# PyTorch-specific loaders
import torch

from torch.utils.data import TensorDataset, DataLoader

In [4]:
# nice way to show imported packages
from watermark import watermark
%load_ext watermark

In [5]:
%watermark --iversions

xarray  : 2022.11.0
torch   : 1.13.0
xbatcher: 0.3.0
numpy   : 1.23.5



## xarray:

This is the standard toy dataset for xarray tutorials and testing: air temperature on a 25 × 53 latitude/longitude grid with 2920 time steps.

In [6]:
ds = xr.tutorial.load_dataset("air_temperature")
print('number of timesteps:', len(ds.time))

# add a second variable to be somewhat more realistic than a single-variable dataset
ds['air2'] = ds.air + 42.42  # arbitrary offset, not physically meaningful

ds

number of timesteps: 2920


## xbatcher intro:

Note: the xbatcher tutorial uses DataArrays; this notebook uses a Dataset instead.

In [7]:
# length of each batch along the time dimension (time steps per batch), adjust this!
time_steps_per_batch = 13

# how many batches this gives; only the whole part (224) is yielded
div = np.round((len(ds.time))/time_steps_per_batch, 3)
div

224.615

In [8]:
bgen = xbatcher.BatchGenerator(ds, input_dims={'time': time_steps_per_batch})

# run through the generator; afterwards `batch` holds the last batch
for batch in bgen:
    pass
batch

In [9]:
print('Are all batches the same length along time, even though 2920 is not divisible by 13?')
{b.sizes['time'] for b in bgen} == {time_steps_per_batch}

Are all batches the same length along time, even though 2920 is not divisible by 13?


True

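This behavior (xbatcher only yields complete batches, so the incomplete remainder at the end of the time axis is dropped) is explained in the xbatcher docs [here](https://xbatcher.readthedocs.io/en/latest/demo.html#Controlling-the-size/shape-of-batches).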
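A plain-Python sketch of that rule, assuming windows start every `size - overlap` steps and that any window which would run past the end of the dimension is discarded (the behavior the linked docs describe). `window_slices` is a hypothetical helper written for this illustration, not part of the xbatcher API.

```python
def window_slices(dim_size, size, overlap=0):
    """Hypothetical helper: slices for fixed-size windows along one dimension.

    Windows start every (size - overlap) steps; a window that would run past
    the end of the dimension is dropped, so every returned slice has length `size`.
    """
    stride = size - overlap
    slices = []
    for start in range(0, dim_size, stride):
        stop = start + size
        if stop <= dim_size:  # keep only complete windows
            slices.append(slice(start, stop))
    return slices


# 2920 time steps split into non-overlapping windows of 13 steps
slices = window_slices(2920, 13)
n_batches = len(slices)
leftover = 2920 - slices[-1].stop

print(n_batches)   # 224 complete batches
print(slices[-1])  # slice(2899, 2912, None)
print(leftover)    # 8 time steps at the end are never used
```

Passing an overlap to xbatcher (its `input_overlap` argument) shortens the stride in the same way, so more, overlapping batches are produced.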

### PyTorch DataLoader

PyTorch `DataLoader` docs: https://pytorch.org/docs/stable/data.html

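You can also set a batch size in PyTorch (the `batch_size` argument of `DataLoader`), but that gives less control over how the xarray dimensions are split and might not scale as well. We will not use this feature: xbatcher defines the batches, and `batch_size` is set to 1 below.

It is worth experimenting with `to_array()` to understand how it converts the Dataset to a NumPy array: the data variables are stacked along a new leading `variable` dimension.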

In [10]:
np.shape(batch.to_array().values)

(2, 1325, 13)
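The three numbers are (variable, sample, time): the two data variables `air` and `air2`, the 25 × 53 = 1325 grid points that xbatcher stacked into a single `sample` dimension, and the 13 time steps per batch. The NumPy sketch below reproduces that shape with random stand-in data. It only mimics the shape, not xbatcher's internals, and it assumes the `sample` axis is ordered lat-major (latitude varying slowest); `to_sample_time` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in data with the sizes of one batch: (time, lat, lon) = (13, 25, 53)
n_time, n_lat, n_lon = 13, 25, 53
air = rng.normal(273.0, 10.0, size=(n_time, n_lat, n_lon))
air2 = air + 42.42  # second variable, as in the notebook


def to_sample_time(field):
    """Hypothetical helper: flatten (lat, lon) into one lat-major 'sample' axis
    and move it in front of time: (time, lat, lon) -> (sample, time)."""
    n_t = field.shape[0]
    return field.reshape(n_t, -1).T


# stack the variables along a new leading axis, like to_array()'s 'variable' dim
batch_array = np.stack([to_sample_time(air), to_sample_time(air2)])
print(batch_array.shape)  # (2, 1325, 13)

# under the lat-major assumption, sample index i maps back to (i // n_lon, i % n_lon)
lat_idx, lon_idx = divmod(300, n_lon)
print(batch_array[0, 300, 5] == air[5, lat_idx, lon_idx])  # True
```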

In [11]:
# `batch` is the last batch from the loop above, shape (variable, sample, time)
a_tensor = torch.from_numpy(batch.to_array().values)

In [12]:
data_loader = DataLoader(TensorDataset(a_tensor), 
                         batch_size=1, 
                         shuffle=False)

In [13]:
data_loader

<torch.utils.data.dataloader.DataLoader at 0x7f9266e0aca0>
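`TensorDataset` indexes its tensors along the first axis, so with `a_tensor` of shape (2, 1325, 13) this loader has two items, one per variable, and with `batch_size=1` each iteration returns a one-element list holding a tensor of shape (1, 1325, 13). The NumPy sketch below mimics that unshuffled first-axis batching so the shapes can be checked without PyTorch; `first_axis_batches` is a hypothetical helper, not a PyTorch function.

```python
import numpy as np


def first_axis_batches(array, batch_size):
    """Hypothetical helper: yield consecutive chunks of `array` along axis 0,
    like iterating an unshuffled DataLoader over a single-tensor TensorDataset."""
    for start in range(0, array.shape[0], batch_size):
        yield array[start:start + batch_size]  # the last chunk may be smaller


# same shape as a_tensor in the notebook: (variable, sample, time)
a_array = np.zeros((2, 1325, 13), dtype=np.float32)

shapes = [chunk.shape for chunk in first_axis_batches(a_array, batch_size=1)]
print(shapes)  # [(1, 1325, 13), (1, 1325, 13)]
```

With the real loader, the equivalent check is `[x.shape for (x,) in data_loader]`.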

This dataset should probably be processed (normalized, etc.), and a different shape may suit your use case better. Still, this serves as an example of a relatively simple path from a multi-variable xarray Dataset to a PyTorch data loader.
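To give PyTorch every xbatcher batch instead of only the final `batch` from the loop, one option is to stack them along a new leading axis, e.g. `np.stack([b.to_array().values for b in bgen])`, so that each `DataLoader` item is one xbatcher batch. The sketch below does this with small random stand-in arrays instead of the real generator and adds a per-variable z-score normalization computed over all batches. Both the stacking layout and the normalization scheme are illustrative choices, not something xbatcher or PyTorch prescribes.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-ins for `b.to_array().values` of each batch b, shape (variable, sample, time).
# Sizes are kept small here; the notebook's real batches are (2, 1325, 13), 224 of them.
n_batches, n_samples, n_time = 6, 20, 13
batch_arrays = [
    np.stack([
        rng.normal(273.0, 10.0, size=(n_samples, n_time)),          # like 'air'
        rng.normal(273.0 + 42.42, 10.0, size=(n_samples, n_time)),  # like 'air2'
    ])
    for _ in range(n_batches)
]

# (batch, variable, sample, time): one entry per xbatcher batch
stacked = np.stack(batch_arrays)

# per-variable mean/std over every axis except 'variable' (axis 1)
mean = stacked.mean(axis=(0, 2, 3), keepdims=True)
std = stacked.std(axis=(0, 2, 3), keepdims=True)

# cast to float32, PyTorch's default floating-point dtype
normalized = ((stacked - mean) / std).astype(np.float32)

print(normalized.shape)  # (6, 2, 20, 13)
# torch.from_numpy(normalized) could then go into TensorDataset / DataLoader as above
```

In a real training setup, the mean and standard deviation would usually be computed on the training batches only and reused for validation and test data.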