
New dataset pipeline: draft #292

Open
albertz opened this issue May 19, 2020 · 23 comments
Labels
good first issue Should be a good starting point to get familiar with RETURNN, hopefully not too hard. TensorFlow

@albertz
Member

albertz commented May 19, 2020

Existing configs should work as before, without any change in behavior.
The datasets themselves (everything that derives from the class Dataset) will stay as they are, as will their API.

It would also be nice to support TensorFlow datasets (TFDS) (more) directly.
There is also TensorFlow I/O, which provides further dataset functions and can read MP3, Ogg, etc. directly.

This (here) is about what follows from there, i.e. how to get the data from the dataset into TF, and how to build up mini-batches.
The file TFDataPipeline.py is currently there for this purpose.
The standard and working implementation (FeedDictDataProvider) builds up the mini-batches with NumPy code and then uses the TF session-run feed dict.
There were plans (right from the beginning of the RETURNN TF implementation) to do this via TF queues instead, but that was never finished.
These plans are outdated now because tf.data is the better API for this. (Related is #171. See this tf.data example for the high level logic to loop over epochs, initialize tf.data.Iterator and tf.data.Dataset.)
Effectively this would be used instead of the current feed dict approach
(another DataProviderBase implementation).

Multi-GPU training is another aspect. For that, it makes sense to have the dataset loading and mini-batch building live in a separate (sub)process. This can make sense even for a single GPU.
This process would then send the mini-batches to each computing node.
We would use distributed TensorFlow for this, even for single GPU, to have unified code.
(See #296 for a draft on distributed TensorFlow support in RETURNN.)
(Alternatives would be ZeroMQ TF ops (e.g. this, or this, or this), or a custom TF op of our own.
However, distributed TensorFlow would allow us to extend our multi-GPU training code more easily later. Maybe Horovod does not cover all use cases.)
We would not use the input pipeline provided by the tf.distribute.Strategy API, and stay more flexible on this. Our own processing would by default not use sharding (because sharding is inefficient in general, and requires extra work to make efficient). Instead, we would have a single dedicated worker for the dataset, and its output would get distributed to the other train workers, i.e. training and data preprocessing would be decoupled. There is no TF Strategy implementation which covers this case (TF feature request: Decoupling preprocessing and training). (RETURNN would not be restricted to this, though; it would also support the TF dataset pipeline with sharding.)
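
To make the decoupling concrete, here is a minimal sketch of how such a cluster could be laid out with distributed TF. The job names, ports, and the one-producer/two-trainer layout are made up purely for illustration; see #296 for the actual distributed-TF draft.

import tensorflow as tf

# Hypothetical cluster: one dedicated dataset/producer process, two train workers.
cluster = tf.train.ClusterSpec({
  "dataset": ["localhost:2220"],
  "worker": ["localhost:2221", "localhost:2222"],
})

# Each process would start its own server with its own job name / task index,
# e.g. the dataset producer:
server = tf.distribute.Server(cluster, job_name="dataset", task_index=0)
# The dataset loading and mini-batch building would live in that process,
# and the train workers would receive the resulting batches via the TF runtime.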

TPU training is another aspect which has some more constraints (beyond multi-GPU training), like having fixed predefined batch sizes. By using the tf.data.Dataset pipeline, this should be relatively easy to accomplish. (Although we would maybe want to tell our RETURNN Data class that we have a fixed batch dim; this is a minor detail, and maybe also not too relevant for the new dataset pipeline.)

Further preprocessing on sequence level (such as custom feature extraction via pure TF, or custom data augmentation via pure TF) can be allowed by custom TF code, or a custom RETURNN TF network just for that. It might make sense to be able to run this on GPU as well.

A further aspect is to have the CPU->GPU transfer of the data asynchronously, so that the session run call of the training does not first need to wait for that. TF queues or tf.data or some custom solution can be used for that. If we already did the data preprocessing on GPU, we also probably want to avoid the GPU->CPU->GPU roundtrip (if it is all the same single GPU).
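
As a small illustration of this async CPU->GPU transfer with tf.data (a toy dataset; assumes a GPU "/gpu:0" is available):

import tensorflow as tf

dataset = tf.data.Dataset.range(1000).batch(8)
# Elements are copied to the GPU in the background, ahead of the step that consumes
# them, so the train session run does not have to wait for the host-to-device copy.
dataset = dataset.apply(
  tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=1))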

Another aspect to keep in mind might be streaming / online decoding (see e.g. #276). Maybe the same data network could be used in an online setting.
We could use the keep_over_epoch (hidden state) functionality for that, which provides mostly exactly what is needed here. Then the dataset would split up an incoming stream into chunks and pass on those chunks. This is also more efficient than just passing on individual frames.
Maybe the dataset could optionally provide the information when to reset the state.
(This aspect would likely not be part of the initial implementation, but should be easy to add later.)

Another note: the current computation of "computing time" (the log output like train epoch 1, finished after 2941 steps, 0:28:58 elapsed (99.3% computing time)) would not be valid anymore, or should be extended. More specifically, we should measure how much time is spent waiting for data.

High level ideas:

  • A tf.data.Dataset would wrap the existing RETURNN dataset API as a TF dataset. It would give access to individual sequences. It would call dataset.load_seqs and dataset.get_data. This would live (by default) in a dataset subprocess (although it should be flexible enough to possibly run also in the main process).

  • Reimplement the mini-batch building logic (Dataset._generate_batches + FeedDictDataProvider.get_next_batch) in pure TF. tf.data already provides pure TF code for this purpose, to prepare mini-batches based on the sequences. This is a good base, which we can extend if needed.

  • There could be some generic way to define the mini-batch building logic.
    Originally the idea was to also use RETURNN layers for this, but this would not quite work, as RETURNN layers work on tensors, while here we work with TF datasets
    (which represent a (possibly infinite) stream of tensors (or possibly nested structures of tensors),
    and do operations on the stream (e.g. bucketing, batching), not necessarily on individual tensors).
    We would assume a very simple default behavior, which would mostly do some standard mini-batching logic.

  • The user could write a custom data_network dict for the config, just like network, to define a RETURNN network.
    This would be separate from the mini-batch building logic, as this would run on the sequences (or a single sequence).
    This can and should use existing RETURNN layers. E.g. if the dataset is configured such that it returns the raw audio, you can do the feature extraction in pure TF, easily add data augmentation and other things.
    This data network would be some optional aspect for now.

  • The dataset (and maybe the data network) could (by default) be created in the dataset subprocess, and be executed in that process. The output of it would be forwarded to the main RETURNN computing instances via distributed TensorFlow. If there are multiple computing instances (for multi-GPU), the mini-batches would be distributed to them.

So, the pipeline would look like this:

  • RETURNN Dataset wrapped as a TF tf.data.Dataset. (We would pass it the Dataset instance, where the init_func would initialize the epoch. And an ExternData instance to specify what data we expect and get from the dataset.)
  • Optionally shuffle.
  • Optionally chunking (we need to implement this). Maybe followed by additional shuffle.
  • PaddedBatchDataset (via dataset.padded_batch) to form batches,
    or alternatively bucket_by_sequence_length for bucketing.
  • Optionally shuffle.
  • Optionally PrefetchDataset (in the dataset process).
  • The data loading itself could be distributed, across multiple data loaders. E.g. via tf.data.experimental.service.distribute and DispatchServer.
  • If multi node / multi GPU training, some way to distribute the batches to each node.
    (Could be implemented with pure TF, quite straight-forward, via _GeneratorDataset, via distributed TensorFlow.)
    (MultiDeviceIterator/_PerDeviceGenerator does exactly that via distributed TF. See code.)
  • PrefetchDataset in the main process, living on GPU to have async CPU->GPU transfer.

Optionally we could also run multiple datasets in subprocesses, each handling a subset, and then interleave (interleave, sample_from_datasets) them. This would help if the dataset loading is slow but parallelization would help.
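
As a toy illustration of the interleave idea above (the per-shard datasets here are just stand-ins; in practice each would be a wrapped RETURNN dataset, possibly running in its own subprocess):

import tensorflow as tf

num_shards = 3

def make_shard_dataset(shard_idx):
  # Stand-in for one per-shard RETURNN dataset wrapper.
  return tf.data.Dataset.from_tensors(shard_idx).repeat(4)

dataset = tf.data.Dataset.range(num_shards).interleave(
  make_shard_dataset,
  cycle_length=num_shards,  # pull from all shards round-robin
  block_length=1)
# Elements come out as: 0, 1, 2, 0, 1, 2, ...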

In the config, you can provide an option dataset_pipeline, which is supposed to be a function returnn.InputContext -> tf.data.Dataset. This function is supposed to return the final dataset as used directly as-is for training. The final dataset is supposed to match the data keys from extern_data. More specifically, the elements should be a dict where the key is the data-key, and the value the tensor. Size placeholders would have the special key "size:%s:%i" % (key, i) (similar as in TFNetwork.get_fetches_dict).
In distributed TF case, this function will get executed on every worker, and also on dedicated dataset pipeline worker(s). It would be controlled via scopes that some part of this only runs on the dataset pipeline worker, and then some remaining part on all the train workers. (So we might end up with some parts in the TF computation graph which are not used, depending on the worker type, but this is not really critical or relevant.)

returnn.InputContext would be a class with an API such as:

  • get_dataset_name(): Would return "train", "dev" or so (corresponding to the config, or also eval_datasets in the config).
  • get_returnn_dataset(**kwargs): This would get the initial RETURNN dataset (configured via train, dev etc as usual). kwargs could maybe overwrite parts (if you would want to do sharding here or so -- by default you would use a single dataset worker).
    (This would also include seq length/size information.)
  • (get_dataset_worker_scope() (or get_producer_worker_scope()): To be used as with context.get_dataset_worker_scope():, for the dataset/producer worker(s). Not sure if needed?)
  • Everything outside the dataset worker scope would run on all workers. (Or do we need a
    get_train_worker_scope() or get_std_worker_scope() or get_consumer_worker_scope()?)
  • (Maybe using MultiDeviceIterator. Maybe using tf.distribute.InputContext. We could (or even should) make use of tf.distribute.get_strategy(), as this would always return something sensible, and then our code will work if any TF distribute strategies are used.)
  • map_producer_to_consumer(dataset: tf.data.Dataset) -> tf.data.Dataset: Operating on the batches, which maps the dataset(s) from dataset worker(s) to the consumer (train worker).
    With distributed TF, if there are multiple dataset workers, it would join the data (e.g. using interleave). If there are multiple consumers (train workers), it would evenly distribute the batches such that each consumer gets exactly the same amount (e.g. via MultiDeviceIterator, or tf.data.experimental.service.distribute / DispatchServer).
    With Horovod and without distributed TF, this would use shard, but print a warning that this is inefficient, and that distributed TF should be enabled (can be together with Horovod).
  • get_default_max_seqs(): Returns an int, the number of seqs for a batch (the batch size in number of seqs, i.e. the batch dimension).
    The default implementation would return something like config.int("max_seqs") (asserting that "max_seqs" is defined).
  • add_seq_length(dataset: Dataset) -> Dataset: Basically dataset.map(lambda item: {**item, 'size:data:0': tf.shape(item['data'])[0]}) or similar (via).
  • get_consumer_device() -> str: E.g. "gpu", for the consumer (trainer) worker.
  • prefetch_to_consumer_device(dataset: Dataset) -> Dataset: Basically tf.data.experimental.prefetch_to_device(context.get_consumer_device())(dataset)
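
Collected into a rough Python skeleton, purely to illustrate the shape of the proposed interface (only the main methods from the list above; bodies and the tentative worker-scope helpers are left out, so this is not the final API):

class InputContext:
  """Draft interface, illustration only."""

  def get_dataset_name(self):
    """:rtype: str  # e.g. "train" or "dev" """

  def get_returnn_dataset(self, **kwargs):
    """:rtype: tensorflow.data.Dataset  # per-sequence elements, incl. size info"""

  def map_producer_to_consumer(self, dataset):
    """:rtype: tensorflow.data.Dataset"""

  def get_default_max_seqs(self):
    """:rtype: int  # batch size in number of seqs (batch dim)"""

  def add_seq_length(self, dataset):
    """:rtype: tensorflow.data.Dataset"""

  def get_consumer_device(self):
    """:rtype: str  # e.g. "gpu" """

  def prefetch_to_consumer_device(self, dataset):
    """:rtype: tensorflow.data.Dataset"""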

If just dataset_pipeline = True, we would have a default implementation, like this:

def dataset_pipeline(context):
  dataset = context.get_returnn_dataset()
  dataset = dataset.padded_batch(context.get_default_max_seqs())
  dataset = context.map_producer_to_consumer(dataset)
  dataset = context.prefetch_to_consumer_device(dataset)
  return dataset

Here is an example for the multi-producer case (which would not be implemented/supported directly at first, but should be easy to add later):

def dataset_pipeline(context):
  context.set_multiple_producers([{"path": "file1.zip"}, {"path": "file2.zip"}, {"path": "file3.zip"}])
  dataset = context.get_returnn_dataset(**context.get_kwargs_per_producer())
  dataset = dataset.padded_batch(context.get_default_max_seqs())
  dataset = context.map_producer_to_consumer(dataset)
  dataset = context.prefetch_to_consumer_device(dataset)
  return dataset

Example for Librispeech via TFDS:

def dataset_pipeline(context):
  import tensorflow_datasets as tfds
  dataset = tfds.load(name="librispeech", split=context.get_split_name())
  dataset = dataset.shuffle(1024)
  dataset = context.add_seq_length(dataset)
  dataset = dataset.padded_batch(32)
  dataset = context.map_producer_to_consumer(dataset)
  dataset = context.prefetch_to_consumer_device(dataset)
  return dataset
@albertz albertz changed the title New dataset pipeline New dataset pipeline: draft May 19, 2020
@albertz
Member Author

albertz commented May 22, 2020

Some notes about tf.data (mostly for myself) (maybe I will move that somewhere later, e.g. to TFDataPipeline.py, or to the wiki...):

Why use tf.data at all? What were the problems with TF queues (tf.FIFOQueue etc.)?
One of the main problems was that there was no easy way to reopen/reinit a queue after it got closed (e.g. after reading the dev set). See the initial TF queue code in TFDataPipeline.py, or here, here.

Some relevant resources, StackOverflow questions:

For reference, the tf.data example use is (from here):

iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
  # Initialize the iterator to `dataset_range`
  sess.run(range_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

  # Initialize the iterator to `dataset_evens`
  sess.run(evens_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

Take PrefetchDataset as an example:

def prefetch_dataset(input_dataset, buffer_size, output_types, output_shapes, slack_period=0, legacy_autotune=True, name=None):
  r"""Creates a dataset that asynchronously prefetches elements from `input_dataset`.

  Args:
    input_dataset: A `Tensor` of type `variant`.
    buffer_size: A `Tensor` of type `int64`.
      The maximum number of elements to buffer in an iterator over
      this dataset.
    output_types: A list of `tf.DTypes` that has length `>= 1`.
    output_shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`) that has length `>= 1`.
    ...

  Returns:
    A `Tensor` of type `variant`.
  """
  ...
  _, _, _op = _op_def_lib._apply_op_helper(
        "PrefetchDataset", input_dataset=input_dataset,
                           buffer_size=buffer_size, output_types=output_types,
                           output_shapes=output_shapes,
                           slack_period=slack_period,
                           legacy_autotune=legacy_autotune, name=name)
  _result = _op.outputs[:]
  ...
  return _result

In C++: class PrefetchDatasetOp : public UnaryDatasetOpKernel,

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel { ... };

and

// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
// graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) final;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
};

and

void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
  }
}

So pretty straight-forward. The dataset (instance of DatasetBase) is wrapped in a DT_VARIANT tensor. So all interesting things are there. Mostly in the header dataset.h.

Some interesting bits:

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker { ... };

Datasets usually operate asynchronously.
E.g. take PrefetchDatasetOp::Dataset.
Most logic actually seems to be in PrefetchDatasetOp::Dataset::Iterator. That one owns the internal buffer (std::deque<BufferElement> buffer_ GUARDED_BY(*mu_);).
When is that async background fetching code run though? During the session run of the main train step? Or always during the lifetime of the session? It looks like on the first call to GetNextInternal of the iterator, it will call EnsurePrefetchThreadStarted which starts the background thread (PrefetchThread).

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what we
// are doing is safe. We should formalize the properties here.
class IteratorContext { ... };

So it means that dataset iterators can be fetched outside a session run step.
That effectively means that any dataset iterator cannot call back to some graph op. Except maybe for MapDatasetOp::Dataset, which calls back to the mapping function?
Also there is _GeneratorDataset which also seems to allow this.
Small demo (via):

  # (Fragment from a test; assumes TF1 graph mode, an alias tf_compat_v1 = tf.compat.v1,
  #  and an integer tf.Variable `v` created beforehand.)
  # noinspection PyProtectedMember
  from tensorflow.python.data.ops.dataset_ops import _GeneratorDataset as GeneratorDataset
  from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter

  def raise_out_of_range_error():
    empty_dataset = tf.data.Dataset.from_tensor_slices(tf.fill([0], 0))
    return DatasetV1Adapter(empty_dataset).make_one_shot_iterator().get_next()

  @tf.function(autograph=False)
  def init_func(x):
    with tf.control_dependencies([tf.print(["init_func", x]), v.assign(0)]):
      return tf.identity(x)

  @tf.function(autograph=False)
  def next_func(x):
    res = tf.identity(v)
    with tf.control_dependencies([res]):
      with tf.control_dependencies([tf.print(["next_func", x, res])]):
        end_check = tf.cond(
          pred=tf.greater_equal(res, 13),
          true_fn=raise_out_of_range_error,
          false_fn=lambda: tf.constant(0))
        with tf.control_dependencies([end_check]):
          with tf.control_dependencies([v.assign_add(1)]):
            return tf.identity(res)

  @tf.function(autograph=False)
  def finalize_func(x):
    with tf.control_dependencies([tf.print(["finalize_func", x, v])]):
      return tf.identity(x)

  generator_dataset = GeneratorDataset(
    init_args=tf.constant("dummy_init_args"),
    init_func=init_func,
    next_func=next_func,
    finalize_func=finalize_func)
  generator_dataset_v1 = DatasetV1Adapter(generator_dataset)
  ds_iter = tf_compat_v1.data.make_initializable_iterator(generator_dataset_v1)
  ds_iter_init = ds_iter.make_initializer(generator_dataset_v1)

  with tf.compat.v1.Session() as session:
    session.run(ds_iter_init)
    while True:
      try:
        print(session.run(ds_iter.get_next()))
      except tf.errors.OutOfRangeError:
        print("OutOfRangeError")
        break

This allows for a very straight forward implementation of wrapping up a RETURNN Dataset as a TF tf.data.Dataset. (We would pass it the Dataset instance, where the init_func would initialize the epoch. And an ExternData instance to specify what data we expect and get from the dataset.)
Together with PaddedBatchDataset (via dataset.padded_batch) and then PrefetchDataset, this gives us already what we need (for single GPU). Optionally also shuffle in between.
Optionally we could also run the dataset in a subprocess, and then communicate with the main process. But how?
Optionally we could also run multiple datasets in subprocesses, each handling a subset, and then interleave (interleave, sample_from_datasets) them. This would help if the dataset loading is slow but parallelization would help.
Optionally we could use bucket_by_sequence_length for bucketing.

How to design the configuration such that this becomes very straight-forward to configure and also flexible?


How to make this efficient for our multi-GPU training?
In our multi-GPU training, there is an independent RETURNN instance for every GPU, which communicate via Horovod. Horovod does not provide a simple distribute op (or does it?).
So we need a simple way to distribute the output of iterator.get_next() on each node. I think MultiDeviceIterator or _PerDeviceGenerator does exactly this already, via distributed TF.
The dataset itself could either run in node 0, or in some independent process. Node 0 could maybe start this as a subprocess.
See RETURNN wiki: distributed TensorFlow for further thoughts about this.
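
For reference, a minimal sketch of MultiDeviceIterator usage (TF1 graph mode, toy dataset, assumes two visible GPUs; purely illustrative, not how RETURNN would necessarily integrate it):

import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(4)
devices = ["/gpu:0", "/gpu:1"]

# One shared iterator; consecutive batches are distributed across the devices
# and prefetched onto them in the background.
multi_dev_iter = tf.data.experimental.MultiDeviceIterator(dataset, devices)
next_per_device = [multi_dev_iter.get_next(dev) for dev in devices]

with tf.compat.v1.Session() as session:
  session.run(multi_dev_iter.initializer)
  print(session.run(next_per_device))  # two different batches, one per device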

@albertz
Member Author

albertz commented May 22, 2020

(Note, I assigned you to this issue, but this does not really mean an assignment. It is rather that you might be interested in this, and maybe want to take part in the planning, brainstorming, maybe also the implementation.)

@albertz
Member Author

albertz commented May 22, 2020

@curufinwe Ha, the tf.data.Dataset has this function:

  def shard(self, num_shards, index):
    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.

    This dataset operator is very useful when running distributed training, as
    it allows each worker to read a unique subset.

    When reading a single input file, you can skip elements as follows:

    ```python
    d = tf.data.TFRecordDataset(input_file)
    d = d.shard(num_workers, worker_index)
    d = d.repeat(num_epochs)
    d = d.shuffle(shuffle_buffer_size)
    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
    ```

That sounds very much like our implementation...
But I really wonder whether that is how they do distributed training?

@albertz
Member Author

albertz commented May 22, 2020

Lingvo does not seem to use tf.data. I wonder why. Maybe just because the project is older than tf.data. (Maybe @kazuki-irie knows?)
(Question on StackOverflow.)

@curufinwe
Collaborator

Regarding Lingvo: yes, the project is older than the tf.data API

@albertz
Member Author

albertz commented May 26, 2020

(Sorry, clicked wrong button. Actually I also wanted to add @patrick-wilken but somehow that's not possible? Maybe because he is not part of the GitHub RETURNN team? I added you now. Pending invite.)

@albertz
Member Author

albertz commented May 26, 2020

Btw, update: The plan on the pipeline is mostly finished (see initial comment). Please review.

What's missing is mostly the design of how you would configure that in the config. Which should be flexible, simple, and extensible (such that we can easily add other features later, esp all those mentioned above). I'm still thinking about that. (Maybe you have suggestions?)

@JackTemaki
Collaborator

Before I start to review your Ideas, I want to clarify if I understood the current situation correctly (for TF backend):

  • Currently only the FeedDictDataProvider is used (according to _get_new_data_provider in TFEngine)
  • FeedDictDataProvider contains a single thread, which creates a batch with its own get_next_batch and adds this to a Python queue
  • get_next_batch calls the dataset.get_data for all sequences in the batch, which then loads the actual data (or uses the cached data from the cached datasets).
  • get_next_batch uses the information from the Batch object provided from a batch generator to determine the correct padding.
  • The Runner which calls the sess.run calls get_feed_dict on the FeedDictDataProvider, which will dequeue the next batch and use it for the current training step.

My additional questions would be:

  • If the FeedDictDataProvider thread can not queue the data fast enough, at which position would the training thread wait? At queue.get()?
  • Does that mean the data loading is only executed in a single thread, and as such using only a single CPU core (unless there are subthreads from numpy/librosa calls)?

@albertz
Member Author

albertz commented May 27, 2020

Before I start to review your Ideas, I want to clarify if I understood the current situation correctly (for TF backend):

  • Currently only the FeedDictDataProvider is used (according to _get_new_data_provider in TFEngine)

Yes, that's the only complete implementation, and also the only used implementation (you will see when you check code usages).

  • FeedDictDataProvider contains a single thread, which creates a batch with its own get_next_batch and adds this to a Python queue

Yes, the dataset.load_seqs, dataset.get_data and the batch construction work in parallel in their own thread, and there is a simple Python queue.

  • get_next_batch calls the dataset.get_data for all sequences in the batch, which then loads the actual data (or uses the cached data from the cached datasets).

Yes. Or depending on the dataset implementation. Originally loading the data was done in load_seqs. But this is an implementation detail of each dataset.

  • get_next_batch uses the information from the Batch object provided from a batch generator to determine the correct padding.

Yes, but this is an implementation detail, which we can easily change.

  • The Runner which calls the sess.run calls get_feed_dict on the FeedDictDataProvider, which will dequeue the next batch and use it for the current training step.

Yes. But we can extend or modify this API if this would match our new pipeline.

You are asking about somewhat irrelevant implementation details here. We can easily change any of that, as we want.

This draft here is not too much about these implementation details, but rather about a high level organization.

My additional questions would be:

  • If the FeedDictDataProvider thread can not queue the data fast enough, at which position would the training thread wait? At queue.get()?

Yes, sure. This is just a standard Python Queue. Queue.get will block if there are no items in the queue. See the Python documentation on Queue.

  • Does that mean the data loading is only executed in a single thread, and as such using only a single CPU core (unless there are subthreads from numpy/librosa calls)?

There is no point here to have multiple threads. Multiple threads would not help in any way here.

Multiple threads would help inside specific dataset implementations. But this is independent of all this discussion here. E.g. RASR could use multiple threads.

We could also extend CombinedDataset such that each sub-dataset in there runs in its own thread (or even subprocess). But this is also independent from the discussion here.

But in any case, a single thread is enough to read from the RETURNN dataset (load_seqs and get_data). The further processing pipeline after that might use multiple threads or not, depending on what you do. This brings us to the proposal here. The draft would make RETURNN flexible enough that you can do that in any way you want. It could e.g. also allow having multiple instances of RETURNN datasets in there (each of which would run in its own thread or subprocess).

@JackTemaki
Collaborator

JackTemaki commented May 28, 2020

I also tried to create an example of my own to understand how tf.data works:

import tensorflow as tf
import numpy

class DummyDataset():

    def get_data_generator(self):
        for i in range(1,10,1):
            print("build sequence %i" % i)
            data = {'sequence': i*numpy.ones((i,), dtype="float32"),
                    'target': i}
            yield data

    def get_data_types(self):
        return {'sequence': tf.float32, 'target': tf.int64}

    def get_data_shapes(self):
        return  {'sequence': tf.TensorShape([None]),
                 'target': tf.TensorShape([])}

def transform(target):
    target['sequence'] = target['sequence'] + 0.1
    shape = (tf.shape(target['sequence'])[0] -1,)
    target['stop_token'] = tf.concat([tf.zeros(shape), tf.ones((1,))], axis=0)
    return target

dummy_dataset = DummyDataset()

data = tf.data.Dataset.from_generator(generator=dummy_dataset.get_data_generator,
                                      output_shapes=dummy_dataset.get_data_shapes(),
                                      output_types=dummy_dataset.get_data_types())
data = data.map(transform)
data = data.padded_batch(4, padded_shapes={'sequence': [None], 'target': [], 'stop_token': [None]})

iter = data.make_initializable_iterator()
batch_dict = iter.get_next()

with tf.Session() as sess:
    sess.run(iter.initializer)
    while True:
        try:
            res = sess.run([batch_dict])
            print(res)
        except tf.errors.OutOfRangeError:
            break

The output is

build sequence 1
build sequence 2
build sequence 3
build sequence 4
[{'sequence': array([[1.1, 0. , 0. , 0. ],
       [2.1, 2.1, 0. , 0. ],
       [3.1, 3.1, 3.1, 0. ],
       [4.1, 4.1, 4.1, 4.1]], dtype=float32),
       'target': array([1, 2, 3, 4]),
       'stop_token': array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)}]
build sequence 5
build sequence 6
build sequence 7
build sequence 8
[{'sequence': array([[5.1, 5.1, 5.1, 5.1, 5.1, 0. , 0. , 0. ],
       [6.1, 6.1, 6.1, 6.1, 6.1, 6.1, 0. , 0. ],
       [7.1, 7.1, 7.1, 7.1, 7.1, 7.1, 7.1, 0. ],
       [8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1]], dtype=float32),
       'target': array([5, 6, 7, 8]),
       'stop_token': array([[0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)}]
build sequence 9
[{'sequence': array([[9.1, 9.1, 9.1, 9.1, 9.1, 9.1, 9.1, 9.1, 9.1]], dtype=float32),
  'target': array([9]),
  'stop_token': array([[0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)}]

(I apologize for starting with 1)

So wouldn't it be a possible approach to just extend the RETURNN Dataset with a generator function (and provide the correct types/shapes)? Or is there a drawback to doing this?

Edit:
Providing the shapes is not even necessary; it seems only the generator and the types are sufficient.

@albertz
Member Author

albertz commented May 28, 2020

Yes, but you are again discussing minor implementation details. I even mentioned this already in the draft ("RETURNN Dataset wrapped as a TF tf.data.Dataset"). This is really the simple and straight-forward part. Let's not discuss this here (for now at least, but really, this is anyway straight forward).

The main discussion here is about the remaining parts. E.g. how would the configuration in the config look like. Basically let's clarify the open outstanding parts of the draft (see the TODO).

@albertz
Member Author

albertz commented May 29, 2020

Update: I think the draft is mostly ready now. Please check if this suits all possible use cases (multi GPU training, TPU training, having multiple dataset workers, etc, whatever you can imagine). It might be annoying/tricky to change the design/API later.

@JackTemaki
Collaborator

Yes, but you are again discussing minor implementation details.

I wanted to start with the easier parts before I assume wrong things.

So far I did not fully understand how the distributed training works (I read the Wiki page and some of the API descriptions for TF1 but I need more time), so I will skip this for now.

The parts that I would like to discuss are the config design and the possible multithreading of the feature pipeline. So the default pipeline you mentioned is clear:

def dataset_pipeline(context):
  dataset = context.get_returnn_dataset()
  dataset = dataset.padded_batch(context.get_default_max_seqs())
  dataset = context.map_producer_to_consumer(dataset)
  return dataset

And also the returnn.InputContext() seems reasonable. The question now is, how would this look in practice? If I understand it correctly, all non-TF processing should stay inside the original datasets and be fixed, and the context.get_returnn_dataset() call will return a tf.data.Dataset containing a dict of single sequences, such as in my example.

So my questions are:

  • At which point should the user be able to add his processing, and in which way?
    Is this a separate data_network that will be applied at the end when all datasets are joined? (E.g. with a MetaDataset or CombinedDataset)

  • Will there be a possibility to bypass MetaDataset and CombinedDataset, by using tf.data.Dataset.zip and tf.data.Dataset.interleave instead? So the data pipeline will get a list of multiple returnn.InputContext?

  • Will the feature processing only rely on the RETURNN layers? And what about possible parameters then? And how would you deal with the situation when you want different processing for train and dev? You can access it from the InputContext API, but how can this be used in the network? Or do you want to have something like: train_data_network and dev_data_network?

  • How does the parallel processing work? I know that there currently is a limitation for multithreading, as e.g. in the case of CachedDataset2 each sequence is loaded in a single load_seqs(idx, idx+1) call when the result of get_seq_length is needed for the batch generator.
    So would I have to define 4 OggZipDatasets on my own with 4 exclusive segment lists which run in their own process each, and then interleave those in a custom defined tf.data pipeline?

  • Will the user be able to define a custom tf.data pipeline before or after the processing network, or both? (right now only after the processing network seems to make sense for me, but maybe I am missing something)

Maybe some questions are unnecessary details, but those are the most interesting questions for me from the user perspective.

@albertz
Member Author

albertz commented May 29, 2020

Yes, but you are again discussing minor implementation details.

I wanted to start with the easier parts before I assume wrong things.

So far I did not fully understand how the distributed training works (I read the Wiki page and some of the API descriptions for TF1 but I need more time), so I will skip this for now.

The parts that I would like to discuss are the config design and the possible multithreading of the feature pipeline.

The multithreading would be covered by distributed TF.
But we do not need to support it fully in advance. It's just important that we design it in a way that it would directly work and it's trivial to add.

So the default pipeline you mentioned is clear:

def dataset_pipeline(context):
  dataset = context.get_returnn_dataset()
  dataset = dataset.padded_batch(context.get_default_max_seqs())
  dataset = context.map_producer_to_consumer(dataset)
  return dataset

And also the returnn.InputContext() seems reasonable. The question now is, how would this look in practice? If I understand it correctly, all non-TF processing should stay inside the original datasets and be fixed, and the context.get_returnn_dataset() call will return a tf.data.Dataset containing a dict of single sequences, such as in my example.

So my questions are:

  • At which point should the user be able to add his processing, and in which way?

In the config. Under the option dataset_pipeline. Exactly as in the example you just posted. Or by just setting dataset_pipeline = True, which would use the default pipeline. dataset_pipeline = None would not use it at all but fall back to our old pipeline.

Is this a separate data_network that will be applied at the end when all datasets are joined? (E.g. with a MetaDataset or CombinedDataset)

This data_network would not be part of the initial design. This is somewhat optional. We could later add this as a function to the InputContext, something like apply_data_network(net_dict: dict, dataset: Dataset) -> Dataset or so.

  • Will there be a possibility to bypass MetaDataset and CombinedDataset, by using tf.data.Dataset.zip and tf.data.Dataset.interleave instead? So the data pipeline will get a list of multiple returnn.InputContext?

What do you mean by bypass? You can configure the RETURNN dataset as before, in whatever way you want (using CombinedDataset or not, what you want).
You can also use Dataset.zip or Dataset.interleave in whatever way you want (in dataset_pipeline).
I don't quite understand the question.

  • Will the feature processing only rely on the RETURNN layers? And what about possible parameters then? And how would you deal with the situation when you want different processing for train and dev? You can access it from the InputContext API, but how can this be used in the network? Or do you want to have something like: train_data_network and dev_data_network?

As said, this initial design does not cover the feature processing at all. It is simply that function dataset_pipeline for now. Although you could do some feature processing in there, but it would be pure TF code (although you can of course also directly use RETURNN layers there, as you want).

  • How does the parallel processing work? I know that there currently is a limitation for multithreading, as e.g. in the case of CachedDataset2 each sequence is loaded in a single load_seqs(idx, idx+1) call when the result of get_seq_length is needed for the batch generator.

This whole draft is not about the RETURNN dataset. As I explained, this is totally separate, and irrelevant here. (Although it would allow working around it.)

You could load multiple RETURNN datasets, though. They would automatically be parallelized (in multiple dataset/producer workers, via distributed TF). (This probably would not be implemented initially, but the design of the API should make this possible.)

So would I have to define 4 OggZipDatasets on my own with 4 exclusive segment lists which run in their own process each, and then interleave those in a custom defined tf.data pipeline?

You could do that. This would again be via distributed TF. (But as said, this would probably not be implemented/supported initially. We should just make sure that this is easy to add later, and that the API supports it directly, or can easily be extended.)
We should maybe already prepare an example for it. That's what the kwargs in get_returnn_dataset are for. I imagine something like this:

def dataset_pipeline(context):
  context.set_multiple_producers([{"path": "file1.zip"}, {"path": "file2.zip"}, {"path": "file3.zip"}])
  dataset = context.get_returnn_dataset(**context.get_kwargs_per_producer())
  dataset = dataset.padded_batch(context.get_default_max_seqs())
  dataset = context.map_producer_to_consumer(dataset)
  return dataset

Note that in most of our cases, the dataset loading is probably fast enough.
In the current multi-GPU implementation, it was only slow because we did it on each node, and wasted lots of resources.

  • Will the user be able to define a custom tf.data pipeline before or after the processing network, or both? (right now only after the processing network seems to make sense for me, but maybe I am missing something)

Both. It's all part of dataset_pipeline.

@albertz
Member Author

albertz commented May 31, 2020

One small remaining question:

Should this new dataset pipeline (i.e. when you set dataset_pipeline) use distributed TensorFlow by default (i.e. have one dedicated worker for the dataset, and one worker for the training/evaluation or whatever)? (This is only a question about the default. Of course it should be possible to work without distributed TF. Also, even with this default, the training/evaluation would not be distributed; only the dataset would run in an own process and communication would work via distributed TF.)
Or should it not use distributed TF, and thus run in the same process? (Unless you explicitly enable distributed TF; or we implement something like SubprocessDataset, which wraps any dataset and runs it in a subprocess, and then you can explicitly use that if you want.)

Maybe it is more sensible to not use distributed TF by default, and keep these two things mostly orthogonal (although they are very related, as you see here in this issue).
Distributed TF could be enabled via a separate option, like distributed_tf = True or so.
We definitely should test distributed TF right from the beginning, though. E.g. functions like map_producer_to_consumer would be a no-op without distributed TF. To really test whether the design makes sense, it would be good to cover that.

Another point: not using distributed TF, but using the new dataset pipeline together with Horovod, would be somewhat incompatible, or would require explicit sharding (just as we do now in FeedDictDataProvider). I'm not sure if we should handle that somewhat explicitly (e.g. in map_producer_to_consumer). We also should print a warning that this is inefficient, and recommend enabling distributed TF for that.
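
For illustration, the explicit-sharding fallback that map_producer_to_consumer could apply in the Horovod-without-distributed-TF case might look roughly like this (a sketch only; assumes Horovod is initialized). It is inefficient because every worker still loads and prepares all the data, only to throw most of it away:

import horovod.tensorflow as hvd  # hvd.init() assumed to have been called already

def map_producer_to_consumer(dataset):
  # Fallback sketch: each Horovod worker keeps only every hvd.size()-th batch.
  if hvd.size() > 1:
    print("Warning: sharding the dataset per Horovod worker is inefficient; "
          "consider enabling distributed TF instead.")
    dataset = dataset.shard(num_shards=hvd.size(), index=hvd.rank())
  return dataset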

@kkm000
Contributor

kkm000 commented Jun 1, 2020

I'm not sure I understand if you guys are using the word "distributed" in the same sense it's used in TF. Distributed across GPUs within a single machine, or distributed across machines on a network? Their "distributed" is the former.

I also wanted to add [someone] but somehow that's not possible? Maybe because he is not part of the GitHub RETURNN team?

Yup. A long-standing FR. Long as in long enough to likely never be implemented.

@JackTemaki: Be careful about Dataset.from_generator, the first note re process affinity. This may bottleneck if you start to spread the data pipeline across machines.

@albertz
Member Author

albertz commented Jun 1, 2020

I'm not sure I understand if you guys are using the word "distributed" in the same sense it's used in TF. Distributed across GPUs within a single machine, or distributed across machines on a network? Their "distributed" is the former.

No, distributed TF (tf.distribute) covers both, distributed across GPUs within a single machine, and also distributed across machines. The latter via the server (tf.distribute.Server) / client architecture. No matter if that is between-graph replication or in-graph replication. I collected an overview of all that here. This is already supported since the very first public TF release. E.g. see the whitepaper from 2015, figure 3, and section 3.3.

@JackTemaki: Be careful about Dataset.from_generator, the first note re process affinity. This may bottleneck if you start to spread the data pipeline across machines.

I think this was just a demo. We would probably not implement it that way. We would probably instead use GeneratorDataset more directly/explicitly.
We are well aware of the process affinity, though. We can have multiple dataset producer workers (optionally), or also only a single producer and multiple consumers (across the machines) (communicating via distributed TF). Or some other configuration. This draft here is mostly about planning an API on the RETURNN side which will make it very flexible to allow all of that in an easy way.

@kkm000
Contributor

kkm000 commented Jun 2, 2020 via email

@albertz
Member Author

albertz commented Jun 2, 2020

Just as a note: I started implementing this. Beginning with only the bare minimum. The first goal is to get single-GPU training to work. I will soon push some first commits.

As this is an optional (non-default) feature, it should not interfere with anything else when not used, so it is safe to push this directly to master, and do the work directly there. (See also the contribution guidelines.)

@albertz
Member Author

albertz commented Jun 2, 2020

Yup, that was my point. Too much to my taste to call the suckers the same word. Keeping a parameter server on a CPU in a multi-GPU host is a 2-orders-of-magnitude difference w.r.t. latency compared to a PS on a different host. There are too many "common in principle" things that make the difference between a go and a no-go in the real world.

Note that our default type of multi-GPU training is anyway as much async as possible. I.e. we never used parameter servers, and probably also will not do so in the future (this is somewhat inconsistent with what TF means by "async training", which is almost always the use of a parameter server). I.e. this latency should not matter too much for us.

albertz added a commit that referenced this issue Jun 2, 2020
albertz added a commit that referenced this issue Jun 2, 2020
albertz added a commit that referenced this issue Jun 2, 2020
Seems to work. At least for this simple case.
#292
@albertz
Member Author

albertz commented Jun 2, 2020

The simple case (no distributed TF, no dedicated dataset loader workers, no Horovod, i.e. no multi-GPU training) should work now, at least with the default pipeline. You can just set dataset_pipeline = True in your config to try it out. It should already be faster for the single-GPU case (because the CPU->GPU transfer is async now).

There are many outstanding TODOs (check the code). Although most of them are only relevant for distributed TF (#281). I guess this is the next step.

Anyway, already this current support should be tested a bit. So please test it, and report your experience, or any problems (and also debug them if possible). (Test with some reasonably new TF 1.* version.)

Also, this issue can be almost closed then. This was about the design of the API (not really about the underlying implementation). If the design looks fine, i.e. is flexible enough to allow all features we need, and simple enough to configure, and also straight-forward, then this is what we want. Otherwise please comment now! Now it's still possible to change all of it. This will get much uglier and more complicated later.

@albertz
Member Author

albertz commented Jun 3, 2020

I'm now trying such an implementation for dynamic batch sizes via bucket_by_sequence_length:

def dataset_pipeline(context):
    """
    :param InputContext context:
    :rtype: tensorflow.data.Dataset
    """
    import tensorflow as tf
    dataset = context.get_returnn_dataset()
    # max_seq_length = {"bpe": 75}
    for key, limit in max_seq_length.items():
        dataset = dataset.filter(lambda elem: tf.less_equal(elem["size:%s:0" % key], limit))
    dataset = dataset.prefetch(100)

    # dataset = context.padded_batch_dataset(dataset)
    bucket_boundaries = []
    bucket_batch_sizes = [1]
    bs_nseqs = 2
    while bs_nseqs <= max_seqs:  # max_seqs = 200
        # batch_size = 10000
        bs_frames_per_seq = batch_size // bs_nseqs
        if bucket_boundaries:
            assert bucket_boundaries[0] > bs_frames_per_seq
        bucket_boundaries.insert(0, bs_frames_per_seq)
        bucket_batch_sizes.insert(0, bs_nseqs)
        if bs_nseqs == max_seqs:
            break
        bs_nseqs = bs_nseqs * 3 // 2  # We don't want to have too many buckets.
        if bs_nseqs > max_seqs:
            bs_nseqs = max_seqs
    dataset = tf.data.experimental.bucket_by_sequence_length(
        lambda elem: elem["size:data:0"], bucket_boundaries, bucket_batch_sizes)(dataset)

    dataset = context.map_producer_to_consumer(dataset)
    dataset = context.prefetch_to_consumer_device(dataset)
    return dataset

This is also pretty generic, and takes max_seqs, max_seq_length and batch_size into account.

I wonder if we might want to use something like this as the default?
However, I don't want the default to become too complicated, with further options added for it (to e.g. control this magic number * 3 // 2, or to switch between bucketing and our old logic, etc.). The idea is that if you want to customize any of that, you would just define your own dataset_pipeline in your code. This is intended to be simple and straightforward. (This is what this issue here is about.)

We could (should) also move that bucketing to a function of its own in the context.

@sarahberanek sarahberanek pinned this issue Jun 4, 2020
@albertz
Member Author

albertz commented Jun 5, 2020

Small status report:
I think this is mostly done. This issue was anyway only really about the API design, and that seems good (no objections so far by anyone).
The remaining TODOs (work-in-progress):

  • Implement chunking (see the sketch after this list for one possible pure-tf.data variant).
    Might not be too complicated using tf.function and GeneratorDataset. Otherwise we could also introduce a native op of our own if we need to (but there are already a lot of relevant functions; I don't think anything is missing, or that a native op would make it easier).
    We should also support different time scales here, just as we do currently.
    We should also support Dataset.context_window (maybe there is no easy/clean way to pass this option through from the dataset, but then we should just have something equivalent).
  • Implement the same batching logic as we currently have.
    Again, this might not be too complicated using tf.function and GeneratorDataset.
  • Once we have both chunking + the old batching logic, this might actually be a good default behavior (not the bucketing as suggested before).
  • seq_tag and seq_idx are sometimes needed and used (see the feed-dict logic for that). We would need this as well for the new pipeline. Maybe just provide them always?
  • Currently the pipeline uses all data-keys defined in extern_data.
    This is different from before, where we first built the network, then checked which of the data-keys were actually used, and only used those.
    This is not really possible anymore, since we must know which data-keys to provide before we build the network.
    This might be solved by having an explicit flag in extern_data for those keys which are really provided by the dataset (something like provided_by_dataset=True, which would be the default).
  • Test more. There might be some edge cases sometimes (which should all be simple to fix or extend).
  • Support for distributed TF (and also Horovod, but that is implicitly covered).
    E.g. use TF MultiDeviceIterator (the underlying internal dataset for that, i.e. _PerDeviceGenerator and _ReincarnatedPerDeviceGenerator).
    Or tf.data.experimental.service.distribute.
  • Better profiling. The computing time measurement currently will not say anything about the dataset anymore.
    There is tf.data.experimental.latency_stats / LatencyStatsDataset which might measure exactly what we want to know. This could be the last dataset in the pipeline (even after prefetch). Any latency here means that the data was not ready fast enough. If the data is available, there should not be any latency.
    See also the tf.data performance guide and the tf.data performance analysis guide.
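
Regarding the chunking TODO above, here is a rough sketch of one possible pure-tf.data variant (using tf.signal.frame plus flat_map instead of GeneratorDataset): a single time axis, only the "data" key, fixed chunk size and step. RETURNN's real chunking additionally handles multiple time scales, further data keys and context windows, so this is only an illustration.

import tensorflow as tf

chunk_size = 50  # assumed values; would come from the config ("chunking")
chunk_step = 25

def chunk_seq(elem):
  """Split one element {"data": [T, D], "size:data:0": T} into overlapping chunks."""
  data = elem["data"]
  num_frames = tf.shape(data)[0]
  # [num_chunks, chunk_size, D], zero-padded at the end where needed.
  chunks = tf.signal.frame(
    data, frame_length=chunk_size, frame_step=chunk_step, pad_end=True, axis=0)
  starts = tf.range(tf.shape(chunks)[0]) * chunk_step
  sizes = tf.minimum(chunk_size, num_frames - starts)  # real length per chunk
  return tf.data.Dataset.from_tensor_slices({"data": chunks, "size:data:0": sizes})

# Usage on the per-sequence dataset, turning it into a per-chunk dataset:
# dataset = dataset.flat_map(chunk_seq)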

@albertz albertz unpinned this issue Jun 10, 2020
Spotlight0xff pushed a commit to Spotlight0xff/returnn that referenced this issue Sep 5, 2020
Spotlight0xff pushed a commit to Spotlight0xff/returnn that referenced this issue Sep 5, 2020
Spotlight0xff pushed a commit to Spotlight0xff/returnn that referenced this issue Sep 5, 2020
Spotlight0xff pushed a commit to Spotlight0xff/returnn that referenced this issue Sep 5, 2020
@albertz albertz added the good first issue Should be a good starting point to get familiar with RETURNN, hopefully not too hard. label Oct 10, 2022