Add PyTorch dataloader #25

Merged · 9 commits · Feb 23, 2022
Changes from all commits
15 changes: 14 additions & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
- id: double-quote-string-fixer

- repo: https://github.com/psf/black
rev: 21.12b0
rev: 22.1.0
hooks:
- id: black
args: ["--line-length", "80", "--skip-string-normalization"]
@@ -37,3 +37,16 @@ repos:
hooks:
- id: prettier
language_version: system

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-setuptools,
types-pkg_resources,
# Dependencies that are typed
numpy,
xarray,
]
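
For reference, the new hook runs alongside the repo's other pre-commit hooks; it can also be exercised locally on demand with pre-commit's standard invocation, e.g. `pre-commit run mypy --all-files`.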
1 change: 1 addition & 0 deletions conftest.py
@@ -1,3 +1,4 @@
# type: ignore
import pytest


2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -1,4 +1,6 @@
pytest
torch
coverage
pytest-cov
adlfs
-r requirements.txt
20 changes: 14 additions & 6 deletions doc/api.rst
@@ -5,12 +5,6 @@ API reference

This page provides an auto-generated summary of Xbatcher's API.

Core
====

.. autoclass:: xbatcher.BatchGenerator
:members:

Dataset.batch and DataArray.batch
=================================

@@ -22,3 +16,17 @@ Dataset.batch and DataArray.batch

Dataset.batch.generator
DataArray.batch.generator

Core
====

.. autoclass:: xbatcher.BatchGenerator
:members:

Dataloaders
===========
.. autoclass:: xbatcher.loaders.torch.MapDataset
:members:

.. autoclass:: xbatcher.loaders.torch.IterableDataset
:members:
2 changes: 2 additions & 0 deletions doc/conf.py
@@ -12,6 +12,8 @@
# All configuration values have a default; values that are commented out
# serve to show the default.

# type: ignore

import os
import sys

2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=xbatcher
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
1 change: 1 addition & 0 deletions setup.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# type: ignore
import os

from setuptools import find_packages, setup
18 changes: 18 additions & 0 deletions xbatcher/accessors.py
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
Keyword arguments to pass to the `BatchGenerator` constructor.
'''
return BatchGenerator(self._obj, *args, **kwargs)


@xr.register_dataarray_accessor('torch')
class TorchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def to_tensor(self):
"""Convert this DataArray to a torch.Tensor"""
import torch

return torch.tensor(self._obj.data)

def to_named_tensor(self):
"""Convert this DataArray to a torch.Tensor with named dimensions"""
import torch

return torch.tensor(self._obj.data, names=self._obj.dims)
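
To illustrate the new accessor, here is a minimal sketch with made-up data (it assumes that importing xbatcher registers the `torch` accessor, which the tests below rely on as well):

import numpy as np
import xarray as xr

import xbatcher  # noqa: F401  (importing xbatcher registers the .torch accessor)

da = xr.DataArray(
    np.random.rand(4, 3),
    dims=('time', 'feature'),
    name='foo',
)

t = da.torch.to_tensor()  # plain torch.Tensor with the same shape as the DataArray
nt = da.torch.to_named_tensor()  # named tensor; nt.names == ('time', 'feature')
print(t.shape, nt.names)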
56 changes: 45 additions & 11 deletions xbatcher/generators.py
Expand Up @@ -2,6 +2,7 @@

import itertools
from collections import OrderedDict
from typing import Any, Dict, Hashable, Iterator

import xarray as xr

@@ -99,12 +100,12 @@ class BatchGenerator:

def __init__(
self,
ds,
input_dims,
input_overlap={},
batch_dims={},
concat_input_dims=False,
preload_batch=True,
ds: xr.Dataset,
input_dims: Dict[Hashable, int],
input_overlap: Dict[Hashable, int] = {},
batch_dims: Dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
):

self.ds = _as_xarray_dataset(ds)
@@ -115,7 +116,38 @@ def __init__(
self.concat_input_dims = concat_input_dims
self.preload_batch = preload_batch

def __iter__(self):
self._batches: Dict[
int, Any
] = self._gen_batches() # dict cache for batches
# in the future, we can make this a lru cache or similar thing (cachey?)

def __iter__(self) -> Iterator[xr.Dataset]:
for batch in self._batches.values():
yield batch

def __len__(self) -> int:
return len(self._batches)

def __getitem__(self, idx: int) -> xr.Dataset:

if not isinstance(idx, int):
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
)

if idx < 0:
idx = list(self._batches)[idx]

if idx in self._batches:
return self._batches[idx]
else:
raise IndexError('list index out of range')

def _gen_batches(self) -> dict:
# in the future, we will want to do the batch generation lazily
# going the eager route for now is allowing me to fill out the loader api
# but it is likely to perform poorly.
Comment on lines +146 to +149

@jhamman (Contributor, Author) commented on Aug 12, 2021:

Flagging this as something to discuss / work out a design for. It feels quite important that we are able to generate arbitrary batches on the fly. The current implementation eagerly generates batches, which will not scale well. However, the pure generator approach doesn't work if you need to randomly access batches (e.g. via `__getitem__`).

batches = []
for ds_batch in self._iterate_batch_dims(self.ds):
if self.preload_batch:
ds_batch.load()
@@ -130,15 +162,17 @@ def __iter__(self):
]
dsc = xr.concat(all_dsets, dim='input_batch')
new_input_dims = [
dim + new_dim_suffix for dim in self.input_dims
str(dim) + new_dim_suffix for dim in self.input_dims
]
yield _maybe_stack_batch_dims(dsc, new_input_dims)
batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
else:
for ds_input in input_generator:
yield _maybe_stack_batch_dims(
ds_input, list(self.input_dims)
batches.append(
_maybe_stack_batch_dims(ds_input, list(self.input_dims))
)

return dict(zip(range(len(batches)), batches))

def _iterate_batch_dims(self, ds):
return _iterate_through_dataset(ds, self.batch_dims)

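A short sketch of the indexable interface added above, using toy data (it mirrors the tests added below):

import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

ds = xr.Dataset({'foo': (('x',), np.arange(100))})
bg = BatchGenerator(ds, input_dims={'x': 10})

print(len(bg))    # 10 batches, each spanning 10 points along x
first = bg[0]     # integer indexing returns an xr.Dataset batch
last = bg[-1]     # negative indices resolve against the cached batch keys
for batch in bg:  # iteration over the cached batches is unchanged
    assert batch.dims['x'] == 10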
Empty file added xbatcher/loaders/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions xbatcher/loaders/torch.py
@@ -0,0 +1,88 @@
from typing import Any, Callable, Optional, Tuple

import torch

# Notes:
# This module includes two PyTorch datasets.
# - The MapDataset provides an indexable interface
# - The IterableDataset provides a simple iterable interface
# Both can be provided as arguments to the Torch DataLoader
# Assumptions made:
# - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
# TODOs:
# - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
# - need to test with additional dataset parameters (e.g. transforms)


class MapDataset(torch.utils.data.Dataset):
def __init__(
self,
X_generator,
y_generator,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher

Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
transform : callable, optional
A function/transform that takes in an array and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
'''
self.X_generator = X_generator
self.y_generator = y_generator
self.transform = transform
self.target_transform = target_transform

def __len__(self) -> int:
return len(self.X_generator)

def __getitem__(self, idx) -> Tuple[Any, Any]:
if torch.is_tensor(idx):
idx = idx.tolist()
if len(idx) == 1:
idx = idx[0]
else:
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
)

# TODO: figure out the dataset -> array workflow
# currently hardcoding a variable name
X_batch = self.X_generator[idx]['x'].torch.to_tensor()
y_batch = self.y_generator[idx]['y'].torch.to_tensor()
@jhamman (Contributor, Author) commented:

Flagging that we can't use named tensors here while we wait for pytorch/pytorch#29010.


if self.transform:
X_batch = self.transform(X_batch)

if self.target_transform:
y_batch = self.target_transform(y_batch)
return X_batch, y_batch


class IterableDataset(torch.utils.data.IterableDataset):
def __init__(
self,
X_generator,
y_generator,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher

Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
'''

self.X_generator = X_generator
self.y_generator = y_generator

def __iter__(self):
for xb, yb in zip(self.X_generator, self.y_generator):
yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
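
Following up on the module notes above, a minimal, hypothetical sketch of wiring these datasets into torch.utils.data.DataLoader (the variable names 'x' and 'y' match the names currently hardcoded in __getitem__ / __iter__; the data itself is made up):

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# toy feature/target variables sharing a 'time' dimension
ds = xr.Dataset(
    {
        'x': (('time', 'feature'), np.random.rand(100, 5)),
        'y': (('time',), np.random.rand(100)),
    }
)

X_gen = BatchGenerator(ds[['x']], input_dims={'time': 10})
y_gen = BatchGenerator(ds[['y']], input_dims={'time': 10})

dataset = MapDataset(X_gen, y_gen)
# batch_size=None: each item from the MapDataset is already a full batch
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for X_batch, y_batch in loader:
    assert isinstance(X_batch, torch.Tensor)
    assert X_batch.shape == (10, 5) and y_batch.shape == (10,)
    break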
22 changes: 22 additions & 0 deletions xbatcher/tests/test_accessors.py
@@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d):
assert isinstance(bg_acc, BatchGenerator)
for batch_class, batch_acc in zip(bg_class, bg_acc):
assert batch_class.equals(batch_acc)


def test_torch_to_tensor(sample_ds_3d):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == (None, None, None)
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)


def test_torch_to_named_tensor(sample_ds_3d):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_named_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == da.dims
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)
22 changes: 22 additions & 0 deletions xbatcher/tests/test_generators.py
@@ -41,6 +41,28 @@ def test_constructor_coerces_to_dataset():
assert bg.ds.equals(da.to_dataset())


@pytest.mark.parametrize('bsize', [5, 6])
def test_batcher_lenth(sample_ds_1d, bsize):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
assert len(bg) == sample_ds_1d.dims['x'] // bsize


def test_batcher_getitem(sample_ds_1d):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})

# first batch
assert bg[0].dims['x'] == 10
# last batch
assert bg[-1].dims['x'] == 10
# raises IndexError for out of range index
with pytest.raises(IndexError, match=r'list index out of range'):
bg[9999999]

# raises NotImplementedError for iterable index
with pytest.raises(NotImplementedError):
bg[[1, 2, 3]]


# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
# Should we enforce that each batch size always has to be the same
@pytest.mark.parametrize('bsize', [5, 10])