Feature/tf dataset adapter #144

Merged · 6 commits · Jan 22, 2021
16 changes: 7 additions & 9 deletions elegy/data/data_handler.py
@@ -10,6 +10,10 @@
from .list_adapter import ListsOfScalarsDataAdapter
from .dataset import DataLoaderAdapter

try:
    from .tf_dataset_adapter import TFDatasetAdapter
except ImportError:
    TFDatasetAdapter = None
try:
    from .torch_dataloader_adapter import TorchDataLoaderAdapter
except ImportError:
@@ -22,6 +26,8 @@
    DataLoaderAdapter,
]

if TFDatasetAdapter is not None:
    ALL_ADAPTER_CLS.append(TFDatasetAdapter)
if TorchDataLoaderAdapter is not None:
    ALL_ADAPTER_CLS.append(TorchDataLoaderAdapter)

@@ -81,6 +87,7 @@ def catch_stop_iteration(self):
        try:
            yield
            # context.async_wait()

        except (StopIteration):
            if (
                self._adapter.get_size() is None
@@ -137,15 +144,6 @@ def _infer_steps(self, steps, dataset):
        raise ValueError(
            "When passing a generator, you " "must specify how many steps to draw."
        )
        # size = cardinality.cardinality(dataset)
        # if size == cardinality.INFINITE and steps is None:
        #     raise ValueError(
        #         "When passing an infinitely repeating dataset, you "
        #         "must specify how many steps to draw."
        #     )
        # if size >= 0:
        #     return size.numpy().item()
        # return None

    @property
    def _samples(self):
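For context, the adapter classes registered in ALL_ADAPTER_CLS are matched against the input via each class's can_handle. A minimal sketch of that dispatch, modeled on the select_data_adapter helper in Keras' data_adapter.py that this code is based on (the helper name here is an assumption, not necessarily elegy's exact API):

def select_data_adapter(x, y=None):
    # Walk the registry and return the first adapter class that claims
    # the input type; tf.data.Dataset inputs match TFDatasetAdapter.
    for adapter_cls in ALL_ADAPTER_CLS:
        if adapter_cls.can_handle(x, y):
            return adapter_cls
    raise ValueError(
        f"Failed to find a data adapter that can handle input: {type(x)}"
    )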
98 changes: 98 additions & 0 deletions elegy/data/tf_dataset_adapter.py
@@ -0,0 +1,98 @@
# Implementation based on tf.keras.engine.data_adapter.py
# https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/engine/data_adapter.py


from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops

from .data_adapter import DataAdapter
from .utils import is_none_or_empty, map_structure, flatten


class TFDatasetAdapter(DataAdapter):
    """Adapter that handles `tf.data.Dataset`."""

    @staticmethod
    def can_handle(x, y=None):
        return isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2))

    def __init__(self, x, y=None, sample_weights=None, steps=None, **kwargs):
        super().__init__(x, y, **kwargs)
        # Note that the dataset instance is immutable; it's fine to reuse the
        # user-provided dataset.
        self._dataset = x

        # The user-provided steps.
        self._user_steps = steps

        self._validate_args(y, sample_weights, steps)

        # Since we need to know the structure of the data up front, peek at
        # one batch to infer it.
        peek = next(iter(x))

        self._first_batch_size = int(list(flatten(peek))[0].shape[0])

    def get_dataset(self):
        def parse_tf_data_gen():
            for batch in iter(self._dataset):
                batch = map_structure(lambda x: x.numpy(), batch)
                yield batch

        return parse_tf_data_gen

    def get_size(self):
        size = cardinality.cardinality(self._dataset)
        if size == cardinality.INFINITE and self._user_steps is None:
            raise ValueError(
                "When passing an infinitely repeating tf.data.Dataset, you "
                "must specify how many steps to draw."
            )
        elif size == cardinality.INFINITE:
            return self._user_steps
        elif size >= 0:
            return size.numpy().item()
        # UNKNOWN cardinality falls through and returns None

    @property
    def batch_size(self):
        return self.representative_batch_size

    @property
    def representative_batch_size(self):
        return self._first_batch_size

    @property
    def partial_batch_size(self):
        return

    def has_partial_batch(self):
        return False

    def should_recreate_iterator(self):
        # If the user doesn't supply `steps`, or if they supply `steps` that
        # exactly equals the size of the `Dataset`, create a new iterator
        # each epoch.
        return (
            self._user_steps is None
            or cardinality.cardinality(self._dataset).numpy() == self._user_steps
        )

    def _validate_args(self, y, sample_weights, steps):
        """Validates `__init__` arguments."""
        # Arguments that shouldn't be passed.
        if not is_none_or_empty(y):
            raise ValueError(
                "`y` argument is not supported when using "
                "a tf.data.Dataset as input."
            )
        if not is_none_or_empty(sample_weights):
            raise ValueError(
                "`sample_weight` argument is not supported when using "
                "a tf.data.Dataset as input."
            )

        size = cardinality.cardinality(self._dataset).numpy()
        if size == cardinality.INFINITE and steps is None:
            raise ValueError(
                "When providing an infinitely repeating tf.data.Dataset, you must specify "
                "the number of steps to run."
            )
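A minimal usage sketch of the adapter, mirroring the tests below (assumes tensorflow and numpy are installed; in practice elegy constructs the adapter internally rather than users calling it directly):

import numpy as np
import tensorflow as tf
from elegy.data.tf_dataset_adapter import TFDatasetAdapter

x = np.random.uniform(size=(100, 8)).astype("float32")
ds = tf.data.Dataset.from_tensor_slices(x).batch(10)

assert TFDatasetAdapter.can_handle(ds)
adapter = TFDatasetAdapter(ds, steps=None)
assert adapter.get_size() == 10   # 100 samples / batch size 10
assert adapter.batch_size == 10   # inferred by peeking at one batch

# get_dataset() returns a generator function that yields NumPy batches,
# which is what elegy's JAX training loop consumes.
for batch in adapter.get_dataset()():
    assert isinstance(batch, np.ndarray)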
88 changes: 88 additions & 0 deletions elegy/data/tf_dataset_adapter_test.py
@@ -0,0 +1,88 @@
import math
from unittest import TestCase

import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from elegy.data.tf_dataset_adapter import TFDatasetAdapter


class TFDatasetAdapterTest(TestCase):
    def test_basic(self):
        batch_size = 10
        epochs = 1
        x = np.array(np.random.uniform(size=(100, 32, 32, 3)))
        y = np.array(np.random.uniform(size=(100, 1)))
        dataset = tf.data.Dataset.from_tensor_slices((x, y))
        dataset = dataset.batch(batch_size)

        data_adapter = TFDatasetAdapter(dataset, steps=None)

        num_steps = math.ceil(x.shape[0] / batch_size) * epochs
        iterator_fn = data_adapter.get_dataset()
        for i, batch in zip(range(num_steps), iterator_fn()):
            batch_x, batch_y = batch
            assert batch_x.shape == (batch_size, *x.shape[1:])
            assert batch_y.shape == (batch_size, *y.shape[1:])
            np.testing.assert_array_equal(
                batch_x, x[i * batch_size : (i + 1) * batch_size]
            )

        assert data_adapter.get_size() * batch_size == x.shape[0]
        assert data_adapter.batch_size == batch_size

    def test_only_x_repeat(self):
        batch_size = 10
        epochs = 2

        x = np.array(np.random.uniform(size=(100, 32, 32, 3)))
        dataset = tf.data.Dataset.from_tensor_slices(x)
        dataset = dataset.batch(batch_size)
        dataset = dataset.repeat()

        dataset_length = x.shape[0]
        num_steps = math.ceil(dataset_length / batch_size) * epochs

        data_adapter = TFDatasetAdapter(
            dataset, steps=math.ceil(dataset_length / batch_size)
        )

        iterator_fn = data_adapter.get_dataset()
        for i, batch in zip(range(num_steps), iterator_fn()):
            batch_x = batch
            assert batch_x.shape == (batch_size, *x.shape[1:])
            np.testing.assert_array_equal(
                batch_x,
                x[
                    (i * batch_size)
                    % dataset_length : (i * batch_size)
                    % dataset_length
                    + batch_size
                ],
            )

        assert data_adapter.get_size() * batch_size == x.shape[0]
        assert data_adapter.batch_size == batch_size
        assert i == num_steps - 1

    def test_error(self):
        batch_size = 10
        epochs = 2
        x = np.array(np.random.uniform(size=(100, 32, 32, 3)))
        dataset = tf.data.Dataset.from_tensor_slices(x)
        dataset = dataset.batch(batch_size)

        data_adapter = TFDatasetAdapter(dataset, steps=None)

        num_steps = math.ceil(x.shape[0] / batch_size) * epochs
        iterator_fn = data_adapter.get_dataset()
        iterator = iterator_fn()

        with self.assertRaises(StopIteration):
            for i in range(num_steps):
                batch = next(iterator)
                batch_x = batch
                assert batch_x.shape == (batch_size, *x.shape[1:])
                np.testing.assert_array_equal(
                    batch_x, x[i * batch_size : (i + 1) * batch_size]
                )
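The test_error case relies on generator exhaustion: a finite, non-repeating dataset of 100 samples batched by 10 yields exactly 10 batches, so driving it for 20 steps raises StopIteration, which the DataHandler.catch_stop_iteration context manager in the data_handler.py diff above then handles as end-of-epoch. A standalone sketch of that behavior (assumes tensorflow is installed):

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(list(range(100))).batch(10)
it = ds.as_numpy_iterator()
for _ in range(10):
    next(it)  # consumes all 10 batches
# one more next(it) here would raise StopIteration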
135 changes: 135 additions & 0 deletions examples/mnist_tf_data.py
@@ -0,0 +1,135 @@
import os
from datetime import datetime

import elegy
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf
import typer
from elegy.callbacks.tensorboard import TensorBoard
from tensorboardX.writer import SummaryWriter

from utils import plot_history


def main(debug: bool = False, eager: bool = False, logdir: str = "runs"):

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

    def preprocess_images(images):
        images = images.reshape((images.shape[0], 28, 28, 1)) / 255.0
        return images.astype("float32")

    X_train = preprocess_images(X_train)
    X_test = preprocess_images(X_test)

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    class CNN(elegy.Module):
        def call(self, image: jnp.ndarray, training: bool):
            @elegy.to_module
            def ConvBlock(x, units, kernel, stride=1):
                x = elegy.nn.Conv2D(units, kernel, stride=stride, padding="same")(x)
                x = elegy.nn.BatchNormalization()(x, training)
                x = elegy.nn.Dropout(0.2)(x, training)
                return jax.nn.relu(x)

            # images are already scaled to [0, 1] by preprocess_images
            x: jnp.ndarray = image.astype(jnp.float32)

            # base
            x = ConvBlock()(x, 32, [3, 3])
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 128, [3, 3], stride=2)

            # GlobalAveragePooling2D
            x = jnp.mean(x, axis=[1, 2])

            # output layer
            x = elegy.nn.Linear(10)(x)

            return x

    model = elegy.Model(
        module=CNN(),
        loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=optax.adam(1e-3),
        run_eagerly=eager,
    )

    # show summary
    model.summary(X_train[:64])

    batch_size = 64
    train_size = 60000
    test_size = 10000
    # Create tf.data datasets
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((X_train, y_train))
        .shuffle(train_size)
        .batch(batch_size)
        .repeat()
    )
    test_dataset = (
        tf.data.Dataset.from_tensor_slices((X_test, y_test))
        .shuffle(test_size)
        .batch(batch_size)
    )

    history = model.fit(
        train_dataset,
        epochs=10,
        steps_per_epoch=200,
        validation_data=test_dataset,
        callbacks=[TensorBoard(logdir=logdir)],
    )

    plot_history(history)

    model.save("models/conv")

    model = elegy.model.load("models/conv")

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples
    idxs = np.random.randint(0, 10000, size=(9,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(3):
            for j in range(3):
                k = 3 * i + j
                plt.subplot(3, 3, k + 1)

                plt.title(f"{np.argmax(y_pred[k])}")
                # drop the channel axis: imshow expects (28, 28), not (28, 28, 1)
                plt.imshow(x_sample[k][..., 0], cmap="gray")
        tbwriter.add_figure("Conv classifier", figure, 100)

    plt.show()


if __name__ == "__main__":
    typer.run(main)
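Since the entry point is typer.run(main), the function's keyword arguments become command-line options; a hedged invocation sketch, with flag names inferred from main's signature:

python examples/mnist_tf_data.py --eager --logdir runs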
Empty file removed: tests/__init__.py