Commit

Merge branch 'master' into feature/tf_dataset_adapter
charlielito committed Jan 22, 2021
2 parents 146165b + 1a2fcc9 commit 338217f
Showing 32 changed files with 339 additions and 49 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -82,7 +82,7 @@ And you are done! For more information check out:

## Why Jax & Elegy?

-Given all the well-stablished Deep Learning framework like TensorFlow + Keras or Pytorch + Pytorch-Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy? Here are some of the reasons why this framework exists.
+Given all the well-established Deep Learning framework like TensorFlow + Keras or Pytorch + Pytorch-Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy? Here are some of the reasons why this framework exists.

#### Why Jax?

@@ -127,7 +127,7 @@ class Linear(elegy.Module):
For more information checkout the **Reference API** section in the [Documentation](https://poets-ai.github.io/elegy).

## Contributing
-Deep Learning is evolving at an incredible pace, there is so much to do and so few hands. If you wish to contibute anything from a loss or metric to a new awesome feature for Elegy just open an issue or send a PR! For more information check out our [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing).
+Deep Learning is evolving at an incredible pace, there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy just open an issue or send a PR! For more information check out our [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing).

## About Us
We are some friends passionate about ML.
2 changes: 1 addition & 1 deletion docs/guides/modules-losses-metrics.md
@@ -107,7 +107,7 @@ model.fit(
...
)
```
-Here only `n` and `o` are requested by name and you get as input its values `b` and `c`, the variable `m` with the content of `a` is safely ignored. If you want to request all the avaiable inputs you can use `**kwargs`.
+Here only `n` and `o` are requested by name, and you get as input its values `b` and `c`, the variable `m` with the content of `a` is safely ignored. If you want to request all the available inputs you can use `**kwargs`.
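
As a rough illustration of requesting inputs by name, the sketch below mimics the behaviour described in this paragraph with a hypothetical `inject` helper built on `inspect.signature`. It is not Elegy's implementation; the names `inject`, `by_name`, and `everything` are invented for this example.

```python
import inspect


def inject(fn, **available):
    """Call `fn` with only the keyword arguments it asks for by name.

    Hypothetical helper for illustration; not part of Elegy.
    """
    params = inspect.signature(fn).parameters
    # A function declaring **kwargs receives every available input.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**available)
    # Otherwise pass only the names the function requests; the rest are ignored.
    return fn(**{k: v for k, v in available.items() if k in params})


def by_name(n, o):
    return (n, o)


def everything(**kwargs):
    return sorted(kwargs)


# `m`, `n`, `o` carry the contents `a`, `b`, `c`, as in the text above.
available = {"m": "a", "n": "b", "o": "c"}
requested = inject(by_name, **available)
all_names = inject(everything, **available)
```

Here `requested` holds only the values of `n` and `o`, while `all_names` lists every available input because `everything` declares `**kwargs`.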

## Losses
Losses can request all the available parameters that Elegy provides for dependency injection. A typical loss will request the `y_true` and `y_pred` values (as its common / enforced in Keras). The Mean Squared Error loss for example is easily defined in these terms:
4 changes: 2 additions & 2 deletions docs/index.md
@@ -82,7 +82,7 @@ And you are done! For more information check out:

## Why Jax & Elegy?

-Given all the well-stablished Deep Learning framework like TensorFlow + Keras or Pytorch + Pytorch-Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy? Here are some of the reasons why this framework exists.
+Given all the well-established Deep Learning framework like TensorFlow + Keras or Pytorch + Pytorch-Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy? Here are some of the reasons why this framework exists.

#### Why Jax?

@@ -127,7 +127,7 @@ class Linear(elegy.Module):
For more information checkout the **Reference API** section in the [Documentation](https://poets-ai.github.io/elegy).

## Contributing
-Deep Learning is evolving at an incredible pace, there is so much to do and so few hands. If you wish to contibute anything from a loss or metric to a new awesome feature for Elegy just open an issue or send a PR! For more information check out our [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing).
+Deep Learning is evolving at an incredible pace, there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy just open an issue or send a PR! For more information check out our [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing).

## About Us
We are some friends passionate about ML.
3 changes: 3 additions & 0 deletions elegy/callbacks/progbar_logger.py
@@ -331,6 +331,9 @@ def update(self, current, values=None, finalize=None):
elif self.verbose == 3:
self.compact_table_progress(current, finalize)

+elif self.verbose == 4 and finalize:
+self.compact_table_progress(current, finalize)

self._last_update = now

def add(self, n, values=None):
5 changes: 3 additions & 2 deletions elegy/data/array_adapter.py
@@ -4,6 +4,7 @@

import math
import typing as tp
+from operator import itemgetter

import jax.numpy as jnp
import numpy as np
@@ -103,9 +104,9 @@ def dataset_generator():

# # Drop last batch
# if drop_remainder and len(indices) < batch_size:
-# print("Droping!")
+# print("Dropping!")
# continue
-inputs_slices = map_structure(lambda x: x[indices], inputs)
+inputs_slices = map_structure(itemgetter(indices), inputs)

yield inputs_slices
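
The swap from `lambda x: x[indices]` to `itemgetter(indices)` should be behaviour-preserving, since `operator.itemgetter(item)` returns a callable that evaluates `obj[item]`. A minimal sketch of that equivalence follows, assuming a simplified, hypothetical `map_structure` (Elegy's own helper may differ) and using plain lists with a `slice` in place of arrays and index arrays:

```python
from operator import itemgetter


def map_structure(fn, structure):
    """Hypothetical stand-in: apply `fn` to every leaf of a nested structure."""
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, tuple):
        return tuple(map_structure(fn, v) for v in structure)
    return fn(structure)


# Leaves are plain lists here; in the adapter they would be arrays.
inputs = {"x": list(range(10)), "y": (list(range(10, 20)), list(range(20, 30)))}
indices = slice(2, 5)

with_lambda = map_structure(lambda a: a[indices], inputs)
with_itemgetter = map_structure(itemgetter(indices), inputs)
```

Both calls produce the same nested structure of slices.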

2 changes: 1 addition & 1 deletion elegy/data/data_adapter.py
@@ -31,7 +31,7 @@ class DataAdapter(object):
def can_handle(x, y=None):
"""Whether the current DataAdapter could handle the input x and y.
Structure wise, x and y can be single object, or list of objects if there
-multiple input/output, or dictionary of objects when the intput/output are
+multiple input/output, or dictionary of objects when the input/output are
named.
Arguments:
x: input features.
8 changes: 7 additions & 1 deletion elegy/data/data_handler.py
@@ -14,6 +14,10 @@
from .tf_dataset_adapter import TFDatasetAdapter
except ImportError:
TFDatasetAdapter = None
+try:
+from .torch_dataloader_adapter import TorchDataLoaderAdapter
+except ImportError:
+TorchDataLoaderAdapter = None

ALL_ADAPTER_CLS = [
ArrayDataAdapter,
@@ -24,6 +28,8 @@

if TFDatasetAdapter is not None:
ALL_ADAPTER_CLS.append(TFDatasetAdapter)
+if TorchDataLoaderAdapter is not None:
+ALL_ADAPTER_CLS.append(TorchDataLoaderAdapter)


class DataHandler(object):
@@ -99,7 +105,7 @@ def catch_stop_iteration(self):
"Make sure that your dataset or generator can generate at "
"least `steps_per_epoch * epochs` batches (in this case, "
"{} batches). You may need to use the repeat() function "
-"when building your dataset.".format(
+"if using tf.data.Dataset.".format(
total_epochs * self._inferred_steps
)
)
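
The `try`/`except ImportError` blocks in this file register an adapter only when its backing library can be imported. A minimal sketch of the same optional-dependency pattern, using a hypothetical `optional_import` helper and a deliberately missing module name (`not_installed_pkg_xyz` is an assumption for illustration only):

```python
import importlib


def optional_import(module_name, attr):
    """Return `module_name.attr`, or None when the module cannot be imported.

    Hypothetical helper; the commit itself uses inline try/except blocks.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr, None)


# Start with the always-available entries, then add the optional ones.
registry = ["ArrayDataAdapter"]
for module_name, attr in [
    ("json", "JSONDecoder"),  # stdlib module, stands in for an installed backend
    ("not_installed_pkg_xyz", "Adapter"),  # stands in for a missing backend
]:
    cls = optional_import(module_name, attr)
    if cls is not None:
        registry.append(cls)
```

Only the importable entry ends up in `registry`, so code that iterates over it never touches a backend that is not installed.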
4 changes: 2 additions & 2 deletions elegy/data/dataset.py
@@ -234,10 +234,10 @@ def dispatch_tasks(self, batch_of_indices):

def shutdown(self):
self.worker_pool.close()
-for aresult in self.async_results_queue:
+for a_result in self.async_results_queue:
# wait for remaining tasks to finish
# process workers will hang otherwise
-aresult.wait(timeout=self.timeout)
+a_result.wait(timeout=self.timeout)
self.worker_pool.terminate()
self.worker_pool.join()
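
The shutdown order in this hunk (close, wait on outstanding results, terminate, join) can be sketched in isolation. The example below is an assumption-laden illustration: it uses `multiprocessing.pool.ThreadPool` rather than a process pool so that it runs without a `__main__` guard, and the names `square` and `async_results` are invented here.

```python
from multiprocessing.pool import ThreadPool


def square(v):
    return v * v


pool = ThreadPool(2)
async_results = [pool.apply_async(square, (i,)) for i in range(4)]

pool.close()  # no new tasks may be submitted
for a_result in async_results:
    # wait for remaining tasks to finish before tearing the pool down
    a_result.wait(timeout=5)
pool.terminate()  # stop the workers
pool.join()  # block until the workers have exited

# every task completed before terminate(), so the results are available
values = [r.get(timeout=0) for r in async_results]
```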

4 changes: 1 addition & 3 deletions elegy/data/generator_adapter.py
@@ -1,15 +1,13 @@
import itertools
import typing as tp

-import numpy as np
-
from .data_adapter import DataAdapter
from .utils import (
assert_not_namedtuple,
+flatten,
is_none_or_empty,
pack_x_y_sample_weight,
unpack_x_y_sample_weight,
-flatten,
)


91 changes: 91 additions & 0 deletions elegy/data/torch_dataloader_adapter.py
@@ -0,0 +1,91 @@
from jax._src.lax.lax import remaining
from torch.utils.data import DataLoader

from .data_adapter import DataAdapter
from .utils import is_none_or_empty, map_structure, list_to_tuple


class TorchDataLoaderAdapter(DataAdapter):
"""Adapter that handles torch Dataloaders."""

@staticmethod
def can_handle(x, y=None):
return isinstance(x, DataLoader)

def __init__(
self,
x: DataLoader,
y=None,
steps=None,
sample_weights=None,
training=True,
**kwargs,
):

if not is_none_or_empty(y):
raise ValueError(
"`y` argument is not supported when using " "torch Dataloader as input."
)
if not is_none_or_empty(sample_weights):
raise ValueError(
"`sample_weight` argument is not supported when using "
"torch Dataloader as input."
)

super().__init__(x, y, **kwargs)

self.training = training
self.steps = steps
self._batch_size = x.batch_size
self._dataset = x

self.current_step = 0

def get_dataset(self):
def parse_dataloader_gen():
self.current_step = 0
for batch in iter(self._dataset):
self.current_step += 1
batch = map_structure(lambda x: x.cpu().numpy(), list_to_tuple(batch))
yield batch

return parse_dataloader_gen

def get_size(self):
try:
return len(self._dataset)
except Exception:
return None

@property
def batch_size(self):
return self.representative_batch_size

@property
def representative_batch_size(self):
return self._batch_size

def has_partial_batch(self):
return False

@property
def partial_batch_size(self):
return

def should_recreate_iterator(self):
# if in eval mode should not recreate iterator
# but if in train mode and steps not provided, should recreate at end of each epoch
if not self.training or self.steps is None:
return self.training

steps_dataset = self.get_size()
if steps_dataset is None:
return False

remaining_steps = steps_dataset - self.current_step
# if remaining steps less than needed steps, should recreate dataloader
# TODO: This will drop the last steps of data, how to avoid this?
if remaining_steps < self.steps:
return True
else:
return False
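
To make the branches of `should_recreate_iterator` easier to follow, here is a standalone sketch of the same decision written as a pure function. The function signature and argument names are assumptions for illustration; the method in the commit reads these values from `self`.

```python
def should_recreate_iterator(training, steps, dataset_size, current_step):
    """Standalone restatement of the adapter's decision, for illustration."""
    # Eval mode never recreates; train mode without `steps` always recreates.
    if not training or steps is None:
        return training
    # Unknown dataset length: keep consuming the current iterator.
    if dataset_size is None:
        return False
    # Recreate when fewer batches remain than one epoch needs.
    return dataset_size - current_step < steps


cases = {
    "eval": should_recreate_iterator(False, None, 10, 0),
    "train_no_steps": should_recreate_iterator(True, None, 10, 10),
    "train_enough_left": should_recreate_iterator(True, 4, 10, 4),
    "train_too_few_left": should_recreate_iterator(True, 4, 10, 8),
    "train_unknown_size": should_recreate_iterator(True, 4, None, 8),
}
```

In the `train_too_few_left` case, two batches remain unread when the iterator is recreated, which is the data loss the `TODO` in the method refers to.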
41 changes: 41 additions & 0 deletions elegy/data/torch_dataloader_adapter_test.py
@@ -0,0 +1,41 @@
import math
from unittest import TestCase

import numpy as np
import torch
from elegy.data.torch_dataloader_adapter import TorchDataLoaderAdapter
from torch.utils.data import DataLoader, TensorDataset


class ArrayDataAdapterTest(TestCase):
def test_basic(self):
batch_size = 10
epochs = 1
x = np.array(np.random.uniform(size=(100, 32, 32, 3)))
y = np.array(np.random.uniform(size=(100, 1)))

dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
dataloader = DataLoader(dataset, batch_size=batch_size)

data_adapter = TorchDataLoaderAdapter(dataloader)

dataset_length = x.shape[0]
num_steps = math.ceil(dataset_length / batch_size) * epochs
iterator_fn = data_adapter.get_dataset()
for i, batch in zip(range(num_steps), iterator_fn()):
batch_x, batch_y = batch
assert batch_x.shape == (batch_size, *x.shape[1:])
assert batch_y.shape == (batch_size, *y.shape[1:])
np.testing.assert_array_equal(
batch_x,
x[
(i * batch_size)
% dataset_length : (i * batch_size)
% dataset_length
+ batch_size
],
)

assert data_adapter.get_size() * batch_size == x.shape[0]
assert data_adapter.batch_size == batch_size
assert i == num_steps - 1
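
The shape assertions in this test hold because the dataset length (100) is a multiple of `batch_size` (10). A small sketch of the per-step batch sizes, assuming a loader that keeps the final short batch; `batch_sizes` is a hypothetical helper, not part of the commit:

```python
import math


def batch_sizes(dataset_length, batch_size):
    """Sizes of each batch when the last, possibly shorter, batch is kept."""
    num_steps = math.ceil(dataset_length / batch_size)
    return [
        min(batch_size, dataset_length - step * batch_size)
        for step in range(num_steps)
    ]


even = batch_sizes(100, 10)  # every batch is full
uneven = batch_sizes(95, 10)  # the final batch is shorter
```

With a length such as 95, the final batch would have 5 rows, so an assertion of the form `batch_x.shape == (batch_size, ...)` would fail on the last step.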
2 changes: 1 addition & 1 deletion elegy/losses/cosine_similarity.py
@@ -36,7 +36,7 @@ def cosine_similarity(
Arguments:
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
-axis: The dimension along which the cosinemsimilarity is computed.
+axis: The dimension along which the cosine similarity is computed.
Returns:
cosine similarity Values. If reduction is NONE, this has
3 changes: 1 addition & 2 deletions elegy/metrics/binary_accuracy_test.py
@@ -56,8 +56,7 @@ def test_compatibility(self):
),
)

-#
-def test_cummulative(self):
+def test_cumulative(self):

tm = tfk.metrics.BinaryAccuracy(threshold=0.3)
em = elegy.metrics.BinaryAccuracy(threshold=0.3)
2 changes: 1 addition & 1 deletion elegy/metrics/f1.py
@@ -32,7 +32,7 @@ def f1(

class F1(Metric):
"""
-The metric calculates the armonic mean between precision and recall. This value is ultimately returned as
+The metric calculates the Harmonic mean between precision and recall. This value is ultimately returned as
`f1`.
If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.
2 changes: 1 addition & 1 deletion elegy/metrics/f1_test.py
@@ -47,7 +47,7 @@ def test_basic(self):
)

#
-def test_cummulative(self):
+def test_cumulative(self):
em = elegy.metrics.F1(threshold=0.3)
tm = tfa.metrics.F1Score(2, average="micro", threshold=0.3)

2 changes: 1 addition & 1 deletion elegy/metrics/mean_absolute_percentage_error.py
@@ -9,7 +9,7 @@

class MeanAbsolutePercentageError(Mean):
"""
-Computes the cumulative mean absoluted percetage error between `y_true` and `y_pred`.
+Computes the cumulative mean absoluted percentage error between `y_true` and `y_pred`.
Usage:
```python
2 changes: 1 addition & 1 deletion elegy/metrics/mean_absolute_percentage_error_test.py
@@ -28,7 +28,7 @@ def test_basic(self):
)

#
-def test_cummulative(self):
+def test_cumulative(self):

tm = tfk.metrics.MeanAbsolutePercentageError()
em = elegy.metrics.MeanAbsolutePercentageError()
2 changes: 1 addition & 1 deletion elegy/metrics/precision_test.py
@@ -63,7 +63,7 @@ def test_compatibility(self):
)

#
-def test_cummulative(self):
+def test_cumulative(self):
tm = tfk.metrics.Precision(thresholds=0.3)
em = elegy.metrics.Precision(threshold=0.3)

2 changes: 1 addition & 1 deletion elegy/metrics/recall_test.py
@@ -49,7 +49,7 @@ def test_basic(self):
),
)

-def test_cummulative(self):
+def test_cumulative(self):
tm = tfk.metrics.Recall(thresholds=0.3)
em = elegy.metrics.Recall(threshold=0.3)

2 changes: 1 addition & 1 deletion elegy/metrics/reduce.py
@@ -108,7 +108,7 @@ def call(
sample_weight: Optional weighting of each example. Defaults to 1.
Returns:
-Array with the cummulative reduce.
+Array with the cumulative reduce.
"""
total = self.add_parameter(
"total",
2 changes: 1 addition & 1 deletion elegy/metrics/reduce_confusion_matrix.py
@@ -109,7 +109,7 @@ def call(
sample_weight: Optional weighting of each example. Defaults to 1.
Returns:
-Array with the cummulative reduce metric.
+Array with the cumulative reduce metric.
"""

cm_metric = self.add_parameter(