[SGD] Dataset API (#7839)
wuisawesome committed Jun 1, 2020
1 parent 21d5b49 commit dcf58a4
Showing 11 changed files with 344 additions and 67 deletions.
1 change: 1 addition & 0 deletions doc/source/index.rst
@@ -181,6 +181,7 @@ Getting Involved
raysgd/raysgd.rst
raysgd/raysgd_pytorch.rst
raysgd/raysgd_tensorflow.rst
raysgd/raysgd_dataset.rst
raysgd/raysgd_ref.rst

.. toctree::
48 changes: 48 additions & 0 deletions doc/source/raysgd/raysgd_dataset.rst
@@ -0,0 +1,48 @@
Distributed Dataset
===================

The RaySGD ``Dataset`` provides a simple abstraction for training with
distributed data.

.. tip:: Get in touch with us if you're using or considering using `RaySGD <https://forms.gle/26EMwdahdgm7Lscy9>`_!

Setting up a dataset
--------------------

A dataset can be constructed from any Python iterable or a ``ParallelIterator``. Optionally, a batch size, download function, maximum concurrency, and a final transformation can also be specified.

When constructing a dataset, a download function can be specified. For example, if a dataset is initialized with a set of paths, a download function can be specified which converts those paths to ``(input, label)`` tuples. The download function can be executed in parallel via ``max_concurrency``. This may be useful if the backing datastore has rate limits, there is high overhead associated with a download, or downloading is computationally expensive. Downloaded data is stored as objects in the plasma store.

An additional, final transformation can be specified via ``Dataset::transform``. This function is guaranteed to run on the same worker that training will take place on. It is good practice to perform operations that produce large outputs, such as converting images to tensors, as transformations.

Finally, the batch size can be specified. The batch size is the number of data points used per training step per worker.

.. note:: Batch size should be specified via the dataset's constructor, **not** the ``config["batch_size"]`` passed into the Trainer constructor. In general, datasets are configured via their own constructor, not the Trainer config, wherever possible.
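
A minimal sketch of putting these options together, assuming the dataset is initialized with a list of paths and ``load_example`` is a hypothetical download function that turns a path into an ``(input, label)`` tuple:

.. code-block:: python

    from ray.util.sgd.data import Dataset

    def load_example(path):
        # Hypothetical helper: fetch the file at ``path`` and return an
        # (input, label) tuple of tensors.
        ...

    paths = ["data/example-0.bin", "data/example-1.bin"]  # hypothetical paths
    dataset = Dataset(
        paths,                       # any Python iterable or ParallelIterator
        batch_size=32,               # data points per training step per worker
        download_func=load_example,  # executed with up to max_concurrency parallel calls
        max_concurrency=4,
        transform=lambda x: x)       # runs on the same worker that trains on the data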

Using a dataset
---------------

To use a dataset, pass it in as an argument to ``trainer.train()``. A dataset passed in to ``trainer.train`` will take precedence over the trainer's data creator during that training run.

.. code-block:: python

    trainer.train(dataset=dataset, num_steps=10)  # Trains using a dataset
    trainer.train()  # Trains with the original data creator
    trainer.train(dataset=dataset2, num_steps=20)  # Trains using a different dataset

Sharding and Sampling
---------------------

.. note:: These details may change in the future.

Datasets use ``ParallelIterator`` actors for sharding. To handle datasets that do not shard evenly, as well as streaming datasets (which may not have a defined size), shards are represented as repeated sequences of data. As a result, ``num_steps`` should always be specified when training, and some data may be oversampled if the data cannot be evenly sharded.

If the dataset is of a known length (and can be evenly sharded), training for an epoch is equivalent to setting ``num_steps = len(data) / (num_workers * batch_size)``.
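
For example, with a hypothetical dataset of 6,400 evenly shardable data points, 2 workers, and a batch size of 32, one epoch corresponds to:

.. code-block:: python

    num_points = 6400   # hypothetical dataset size (shards evenly)
    num_workers = 2
    batch_size = 32

    num_steps = num_points // (num_workers * batch_size)  # 100 steps
    trainer.train(dataset=dataset, num_steps=num_steps)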

Complete dataset example
------------------------

Below is an example of training a network with a single hidden layer to learn the identity function.

.. literalinclude:: ../../../python/ray/util/sgd/data/examples/mlp_identity.py
:language: python
9 changes: 9 additions & 0 deletions doc/source/raysgd/raysgd_ref.rst
@@ -33,3 +33,12 @@ TFTrainer
:members:

.. automethod:: __init__

Dataset
-------

.. autoclass:: ray.util.sgd.data.Dataset
:members:

.. automethod:: __init__

5 changes: 5 additions & 0 deletions python/ray/util/sgd/data/__init__.py
@@ -0,0 +1,5 @@
import logging

from ray.util.sgd.data.dataset import Dataset

logger = logging.getLogger(__name__)

__all__ = ["Dataset"]
92 changes: 92 additions & 0 deletions python/ray/util/sgd/data/dataset.py
@@ -0,0 +1,92 @@
from ray.util.iter import ParallelIterator, from_iterators


class Dataset():
    """A simple Dataset abstraction for RaySGD.

    This dataset is designed to work with RaySGD trainers (currently just
    Torch) to provide support for streaming large external datasets, and
    built-in sharding.

    .. code-block:: python

        def to_mat(x):
            return torch.tensor([[x]]).float()

        data = [i * 0.001 for i in range(1000)]
        p_iter = iter.from_items(data, num_shards=1, repeat=True)
        dataset = Dataset(
            p_iter,
            batch_size=32,
            max_concurrency=1,
            download_func=lambda x: (to_mat(x), to_mat(x)))

        trainer = TorchTrainer(
            model_creator=model_creator,
            data_creator=None,
            optimizer_creator=optimizer_creator,
            loss_creator=torch.nn.MSELoss,
            num_workers=5,
        )

        for i in range(10):
            # Train for another epoch using the dataset
            trainer.train(dataset=dataset, num_steps=200)

        model = trainer.get_model()
        print("f(0.5)=", float(model(to_mat(0.5))[0][0]))

    Args:
        data (iterable[U] or ParallelIterator[U]): Any existing Python
            iterable (or iterator), or an existing parallel iterator
            to use.
        batch_size (int): The batch size for training/inference (default 32).
        download_func (U -> (S, Y)): A function which returns two values, the
            input and the label (default is the identity function).
        max_concurrency (int): The maximum number of concurrent calls to the
            download function. See ``ParallelIterator::for_each`` for details.
        transform (S -> X): A final transformation to be applied to the *input
            only*. This is guaranteed to run on the same worker that training
            will occur on.
    """

    def __init__(self,
                 data,
                 batch_size=32,
                 download_func=None,
                 max_concurrency=0,
                 transform=None):
        par_iter = None
        if isinstance(data, ParallelIterator):
            # Collapse an existing parallel iterator to a single shard; it is
            # re-sharded later via set_num_shards.
            par_iter = data.repartition(1)
        else:
            # Wrap a plain iterable in a repeating parallel iterator.
            par_iter = from_iterators([data], repeat=True)
        if download_func:
            par_iter = par_iter.for_each(
                download_func, max_concurrency=max_concurrency)
        self.iter = par_iter.batch(batch_size)

        self.batch_size = batch_size
        self.max_concurrency = max_concurrency
        self.transform = transform

    def set_num_shards(self, num_shards):
        """
        Reshards the iterator if necessary.
        """
        if num_shards != self.iter.num_shards():
            print("Setting num shards", num_shards)
            self.iter = self.iter.repartition(num_shards)

    def get_shard(self, i):
        """
        Returns a single, iterable shard.
        """
        assert i < self.iter.num_shards(), (
            "Trying to get shard {} but there are only {} shards. "
            "Are you sure you called set_num_shards already?").format(
                i, self.iter.num_shards())

        return self.iter.get_shard(i)
1 change: 1 addition & 0 deletions python/ray/util/sgd/data/examples/.gitignore
@@ -0,0 +1 @@
ray-project/*
Empty file.
69 changes: 69 additions & 0 deletions python/ray/util/sgd/data/examples/mlp_identity.py
@@ -0,0 +1,69 @@
import ray
from ray.util.sgd.torch.torch_trainer import TorchTrainer
from ray.util.sgd.data.dataset import Dataset

import torch
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 128)
self.fc2 = nn.Linear(128, 1)

def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x


def model_creator(config):
return Net()


def optimizer_creator(model, config):
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))


def to_mat(x):
return torch.tensor([[x]]).float()


def dataset_creator():
num_points = 32 * 100 * 2
data = [i * (1 / num_points) for i in range(num_points)]
dataset = Dataset(
data,
batch_size=32,
max_concurrency=2,
download_func=lambda x: (to_mat(x), to_mat(x)))
return dataset


def main():
dataset = dataset_creator()
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=None,
optimizer_creator=optimizer_creator,
loss_creator=torch.nn.MSELoss,
num_workers=2,
)

for i in range(10):
# Train a full epoch using the data_creator
# trainer.train()

# Train for another epoch using the dataset
trainer.train(dataset=dataset, num_steps=100)

model = trainer.get_model()
print("f(0.5)=", float(model(to_mat(0.5))[0][0]))


if __name__ == "__main__":
ray.init()
main()
