-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into feature/tf_dataset_adapter
- Loading branch information
Showing 32 changed files with 339 additions and 49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from jax._src.lax.lax import remaining | ||
from torch.utils.data import DataLoader | ||
|
||
from .data_adapter import DataAdapter | ||
from .utils import is_none_or_empty, map_structure, list_to_tuple | ||
|
||
|
||
class TorchDataLoaderAdapter(DataAdapter):
    """Adapter that handles torch Dataloaders.

    Wraps a `torch.utils.data.DataLoader` and exposes it as a generator
    factory whose batches are converted from torch tensors to NumPy arrays.
    Targets and sample weights must come from the loader itself; passing
    `y` or `sample_weights` separately is rejected.
    """

    @staticmethod
    def can_handle(x, y=None):
        """Return True when `x` is a `torch.utils.data.DataLoader`."""
        return isinstance(x, DataLoader)

    def __init__(
        self,
        x: DataLoader,
        y=None,
        steps=None,
        sample_weights=None,
        training=True,
        **kwargs,
    ):
        """Store the loader and the iteration settings.

        Arguments:
            x: the DataLoader to iterate over.
            y: must be None or empty; anything else raises.
            steps: number of steps requested per epoch, or None. Only read by
                `should_recreate_iterator`.
            sample_weights: must be None or empty; anything else raises.
            training: whether the adapter is used for training (True) or
                evaluation (False). Only read by `should_recreate_iterator`.
            **kwargs: forwarded unchanged to `DataAdapter.__init__`.

        Raises:
            ValueError: if `y` or `sample_weights` is provided.
        """
        if not is_none_or_empty(y):
            raise ValueError(
                "`y` argument is not supported when using "
                "torch Dataloader as input."
            )
        if not is_none_or_empty(sample_weights):
            raise ValueError(
                "`sample_weight` argument is not supported when using "
                "torch Dataloader as input."
            )

        super().__init__(x, y, **kwargs)

        self.training = training
        self.steps = steps
        # May be None, e.g. when the loader was built with a `batch_sampler`.
        self._batch_size = x.batch_size
        self._dataset = x

        # Number of batches yielded by the most recently started generator.
        self.current_step = 0

    def get_dataset(self):
        """Return a zero-argument generator function over NumPy batches.

        Every call of the returned function starts a fresh pass over the
        DataLoader and resets `current_step`, which is then incremented once
        per yielded batch.
        """

        def to_numpy(tensor):
            # detach() first: Tensor.numpy() raises a RuntimeError on tensors
            # that require grad. For tensors that do not, detach() changes
            # nothing about the resulting array.
            return tensor.detach().cpu().numpy()

        def parse_dataloader_gen():
            self.current_step = 0
            for batch in self._dataset:
                self.current_step += 1
                # NOTE(review): presumably `map_structure` applies `to_numpy`
                # to every tensor of the (possibly nested) batch -- confirm
                # against `.utils`.
                yield map_structure(to_numpy, list_to_tuple(batch))

        return parse_dataloader_gen

    def get_size(self):
        """Return `len()` of the DataLoader, or None if it has no length."""
        try:
            return len(self._dataset)
        except Exception:
            # Deliberate best-effort: a loader without a usable length
            # (e.g. over an iterable-style dataset) is reported as unknown
            # size instead of failing.
            return None

    @property
    def batch_size(self):
        """Alias of `representative_batch_size`."""
        return self.representative_batch_size

    @property
    def representative_batch_size(self):
        """The `batch_size` attribute of the wrapped DataLoader."""
        return self._batch_size

    def has_partial_batch(self):
        """Always False: partial batches are not tracked by this adapter."""
        return False

    @property
    def partial_batch_size(self):
        """Always None, consistent with `has_partial_batch`."""
        return None

    def should_recreate_iterator(self):
        """Tell whether a new iterator should be created over the loader.

        Returns:
            False in evaluation mode. True in training mode when `steps` was
            not provided. In training mode with `steps` provided: False when
            the loader size is unknown, otherwise True only when fewer
            batches remain in the current pass than `steps`.
        """
        # if in eval mode should not recreate iterator
        # but if in train mode and steps not provided, should recreate at end of each epoch
        if not self.training or self.steps is None:
            return self.training

        steps_dataset = self.get_size()
        if steps_dataset is None:
            return False

        remaining_steps = steps_dataset - self.current_step
        # if remaining steps less than needed steps, should recreate dataloader
        # TODO: This will drop the last steps of data, how to avoid this?
        return remaining_steps < self.steps
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import math | ||
from unittest import TestCase | ||
|
||
import numpy as np | ||
import torch | ||
from elegy.data.torch_dataloader_adapter import TorchDataLoaderAdapter | ||
from torch.utils.data import DataLoader, TensorDataset | ||
|
||
|
||
class ArrayDataAdapterTest(TestCase):
    """Exercises TorchDataLoaderAdapter on an in-memory TensorDataset."""

    def test_basic(self):
        """One pass over the loader yields ordered, full-size NumPy batches."""
        batch_size = 10
        epochs = 1
        x = np.array(np.random.uniform(size=(100, 32, 32, 3)))
        y = np.array(np.random.uniform(size=(100, 1)))

        loader = DataLoader(
            TensorDataset(torch.from_numpy(x), torch.from_numpy(y)),
            batch_size=batch_size,
        )
        data_adapter = TorchDataLoaderAdapter(loader)

        n_samples = x.shape[0]
        num_steps = math.ceil(n_samples / batch_size) * epochs
        make_iterator = data_adapter.get_dataset()
        for i, batch in zip(range(num_steps), make_iterator()):
            batch_x, batch_y = batch
            assert batch_x.shape == (batch_size, *x.shape[1:])
            assert batch_y.shape == (batch_size, *y.shape[1:])
            # The loader is not shuffled, so batch i must be the i-th
            # contiguous slice of the source array.
            start = (i * batch_size) % n_samples
            np.testing.assert_array_equal(batch_x, x[start : start + batch_size])

        assert data_adapter.get_size() * batch_size == x.shape[0]
        assert data_adapter.batch_size == batch_size
        assert i == num_steps - 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.