Add PyTorch dataloader #25

Merged: 9 commits merged on Feb 23, 2022
Changes from 1 commit
2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=xbatcher
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
55 changes: 45 additions & 10 deletions xbatcher/generators.py
@@ -2,6 +2,8 @@

import itertools
from collections import OrderedDict
from collections.abc import Iterator
from typing import Any, Dict, Hashable

import xarray as xr

@@ -99,12 +101,12 @@ class BatchGenerator:

def __init__(
self,
ds,
input_dims,
input_overlap={},
batch_dims={},
concat_input_dims=False,
preload_batch=True,
ds: xr.Dataset,
input_dims: Dict[Hashable, int],
input_overlap: Dict[Hashable, int] = {},
batch_dims: Dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
):

self.ds = _as_xarray_dataset(ds)
@@ -115,7 +117,38 @@ def __init__(
self.concat_input_dims = concat_input_dims
self.preload_batch = preload_batch

def __iter__(self):
self._batches: Dict[
int, Any
] = self._gen_batches() # dict cache for batches
# in the future, we could make this an LRU cache or something similar (cachey?)

def __iter__(self) -> Iterator[xr.Dataset]:
for batch in self._batches.values():
yield batch

def __len__(self) -> int:
return len(self._batches)

def __getitem__(self, idx: int) -> xr.Dataset:

if not isinstance(idx, int):
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
)

if idx < 0:
idx = list(self._batches)[idx]

if idx in self._batches:
return self._batches[idx]
else:
raise IndexError('list index out of range')

def _gen_batches(self) -> dict:
# in the future, we will want to do the batch generation lazily
# going the eager route for now allows filling out the loader API,
# but it is likely to perform poorly.
Comment on lines +146 to +149
@jhamman (Contributor, Author) commented on Aug 12, 2021:
Flagging this as something to discuss / work out a design for. It feels quite important that we are able to generate arbitrary batches on the fly. The current implementation eagerly generates batches, which will not scale well. However, the pure generator approach doesn't work if you need to randomly access batches (e.g. via __getitem__).
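For illustration, one possible shape of a lazier design (a hypothetical sketch, not xbatcher's implementation; the names are made up and it only handles non-overlapping input_dims): store lightweight slice selectors per batch up front and slice the dataset only when a batch is requested.

```python
# Hypothetical sketch, not part of this PR: lazy, index-addressable batching.
import itertools

import numpy as np
import xarray as xr


def _slices(dim_size, length):
    # non-overlapping slices covering dim_size in chunks of `length`
    for start in range(0, dim_size - length + 1, length):
        yield slice(start, start + length)


class LazyBatches:
    def __init__(self, ds, input_dims):
        self.ds = ds
        dims = list(input_dims)
        per_dim = [list(_slices(ds.sizes[d], input_dims[d])) for d in dims]
        # one selector dict per batch; cheap compared to storing datasets
        self._selectors = [
            dict(zip(dims, combo)) for combo in itertools.product(*per_dim)
        ]

    def __len__(self):
        return len(self._selectors)

    def __getitem__(self, idx):
        # slicing happens here, so nothing is materialized until requested
        return self.ds.isel(self._selectors[idx])


ds = xr.Dataset({'foo': (['x'], np.arange(100))})
batches = LazyBatches(ds, {'x': 10})
assert len(batches) == 10
assert batches[3]['foo'].values[0] == 30
```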

batches = []
for ds_batch in self._iterate_batch_dims(self.ds):
if self.preload_batch:
ds_batch.load()
@@ -132,13 +165,15 @@ def __iter__(self):
new_input_dims = [
dim + new_dim_suffix for dim in self.input_dims
]
yield _maybe_stack_batch_dims(dsc, new_input_dims)
batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
else:
for ds_input in input_generator:
yield _maybe_stack_batch_dims(
ds_input, list(self.input_dims)
batches.append(
_maybe_stack_batch_dims(ds_input, list(self.input_dims))
)

return dict(zip(range(len(batches)), batches))

def _iterate_batch_dims(self, ds):
return _iterate_through_dataset(ds, self.batch_dims)

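For context, a small sketch (toy data, mirroring the new tests further down) of the indexing behavior these BatchGenerator changes add:

```python
import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

ds = xr.Dataset({'foo': (['x'], np.arange(100))})
bg = BatchGenerator(ds, input_dims={'x': 10})

len(bg)    # 10 -- batches are generated (eagerly, for now) in __init__
bg[0]      # first batch: an xr.Dataset with dims {'x': 10}
bg[-1]     # negative indices are resolved against the cached batch keys
# bg[[1, 2]] would raise NotImplementedError; only single integer keys work

for batch in bg:       # iteration keeps working as before
    assert batch.dims['x'] == 10
```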
Empty file added xbatcher/loaders/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions xbatcher/loaders/torch.py
@@ -0,0 +1,84 @@
from typing import Any, Callable, Optional, Tuple

import torch

# Notes:
# This module includes two PyTorch datasets.
# - The MapDataset provides an indexable interface
# - The IterableDataset provides a simple iterable interface
# Both can be provided as arguments to the Torch DataLoader
# Assumptions made:
# - Each dataset takes pre-configured X/y xbatcher generators (we may not always want two generators in a dataset)
# TODOs:
# - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
# - need to test with additional dataset parameters (e.g. transforms)


class MapDataset(torch.utils.data.Dataset):
def __init__(
self,
X_generator,
y_generator,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher

Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
transform : callable, optional
A function/transform that takes in an array and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
'''
self.X_generator = X_generator
self.y_generator = y_generator
self.transform = transform
self.target_transform = target_transform

def __len__(self) -> int:
return len(self.X_generator)

def __getitem__(self, idx) -> Tuple[Any, Any]:
if torch.is_tensor(idx):
idx = idx.tolist()
assert len(idx) == 1

# TODO: figure out the dataset -> array workflow
# currently hardcoding a variable name
X_batch = self.X_generator[idx]['x'].data
y_batch = self.y_generator[idx]['y'].data

if self.transform:
X_batch = self.transform(X_batch)

if self.target_transform:
y_batch = self.target_transform(y_batch)
return X_batch, y_batch


class IterableDataset(torch.utils.data.IterableDataset):
def __init__(
self,
X_generator,
y_generator,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher

Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
'''

self.X_generator = X_generator
self.y_generator = y_generator

def __iter__(self):
for xb, yb in zip(self.X_generator, self.y_generator):
yield (xb['x'].data, yb['y'].data)
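For context, a usage sketch for these adapters (mirroring the tests below; the toy variables are named 'x' and 'y' to match what the adapter currently hardcodes):

```python
import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

ds = xr.Dataset(
    {
        'x': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)
x_gen = BatchGenerator(ds['x'], input_dims={'sample': 10})
y_gen = BatchGenerator(ds['y'], input_dims={'sample': 10})

dataset = MapDataset(x_gen, y_gen)

# batch_size=None disables the DataLoader's own collation, since each
# xbatcher batch is already a full batch of 10 samples
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for x_batch, y_batch in loader:
    assert x_batch.shape == (10, 5)
    break
```

With the default batch_size=1, the DataLoader stacks each pre-built batch behind an extra length-1 leading dimension, which is most likely the extra dimension flagged in the TODOs in the tests below.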
22 changes: 22 additions & 0 deletions xbatcher/tests/test_generators.py
@@ -18,6 +18,28 @@ def sample_ds_1d():
return ds


@pytest.mark.parametrize('bsize', [5, 6])
def test_batcher_length(sample_ds_1d, bsize):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
assert len(bg) == sample_ds_1d.dims['x'] // bsize


def test_batcher_getitem(sample_ds_1d):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})

# first batch
assert bg[0].dims['x'] == 10
# last batch
assert bg[-1].dims['x'] == 10
# raises IndexError for out of range index
with pytest.raises(IndexError, match=r'list index out of range'):
bg[9999999]

# raises NotImplementedError for iterable index
with pytest.raises(NotImplementedError):
bg[[1, 2, 3]]


# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
# Should we enforce that each batch always has the same size?
@pytest.mark.parametrize('bsize', [5, 10])
78 changes: 78 additions & 0 deletions xbatcher/tests/test_torch_loaders.py
@@ -0,0 +1,78 @@
import numpy as np
import pytest
import xarray as xr

torch = pytest.importorskip('torch')

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import IterableDataset, MapDataset


@pytest.fixture(scope='module')
def ds_xy():
n_samples = 100
n_features = 5
ds = xr.Dataset(
{
'x': (
['sample', 'feature'],
np.random.random((n_samples, n_features)),
),
'y': (['sample'], np.random.random(n_samples)),
},
)
return ds


def test_map_dataset(ds_xy):

x = ds_xy['x']
y = ds_xy['y']

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})

dataset = MapDataset(x_gen, y_gen)

# test __getitem__
x_batch, y_batch = dataset[0]
assert len(x_batch) == len(y_batch)
assert isinstance(x_batch, np.ndarray)

# test __len__
assert len(dataset) == len(x_gen)

# test integration with torch DataLoader
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
assert len(x_batch) == len(y_batch)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
# TODO: also need to revisit the variable extraction bits here
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])


def test_iterable_dataset(ds_xy):

x = ds_xy['x']
y = ds_xy['y']

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})

dataset = IterableDataset(x_gen, y_gen)

# test integration with torch DataLoader
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
assert len(x_batch) == len(y_batch)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
# TODO: also need to revisit the variable extraction bits here
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])