
[RFC] DataLoader architecture updates and TarDataset implementation #49440

Open
VitalyFedyunin opened this issue Dec 16, 2020 · 50 comments

Labels: feature (A request for a proper, new feature), module: dataloader (Related to torch.utils.data.DataLoader and Sampler), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

@VitalyFedyunin
Contributor

VitalyFedyunin commented Dec 16, 2020

DataLoader architecture updates and TarDataset implementation

Problem statement

This proposal aims to construct a modular, user-friendly, and performant toolset to address the ambiguous activity referred to as “dataloading” within PyTorch, an ambiguity that stems from the indivisibility of today's monolithic DataLoader abstraction. In reality, “dataloading” is a diverse set of operations that should be supported by extensible building blocks, out of which the present abstractions, and far more, could easily be built. Some typical needs that are scarcely supported by the present implementation include:

  • Lazy loading - Users want to point PyTorch to a remote data source (e.g. http, S3, GCP, Azure, Manifold, Hive) and iterate over the contents without downloading the entire dataset, ideally downloading samples only as they are needed. Further, if a user writes such code for a remote storage platform, there is no natural place to contribute it for public or private reuse.
  • Structured data, heterogeneous storage
    • There are hundreds of ways to store a single structured dataset, each requiring a custom or highly configured DataLoader. Users want to take advantage of modularity and not reimplement complete DataSets over and over again. Suppose we have a simple dataset of sample:label pairs of images and ints. There are a number of dimensions whose outer product enumerates the possible storage formats for this data, each requiring a distinct (or specifically configured) DataLoader:

      • Primitive formats - Are images stored in an image storage format (e.g. one of these), as a tensor (.pt), as a serialized (e.g. pickle, json) Python data structure, etc.?
      • Grouping - Are pairs grouped together by directory, by an archive format (tar, HDF5), by a serialization format (e.g. json, protobuf), by common file string (e.g. image00012321.jpg, label00012321.txt), by meaningful filenames (e.g. image_00023423_dog.jpg), by contents of file headers, by pickled Python data structures, etc? Are filenames otherwise meaningful? Are file headers otherwise meaningful?
      • Sharding - Is the dataset partitioned for performance reasons into arbitrary groups, each containing grouped pairs? In which grouping format (e.g. directories, tar, arrow, parquet, HDF5)?
      • Compression - Are files or groups compressed or binarized, e.g. gz, zip, protobuf?
      • Locale - Are the files local, remote via custom request format (e.g. proprietary data, public REST API, kaggle dataset), on an http server, in cloud object storage?
    • The above example is only for an extremely simple data structure case. The reality of data is often dramatically more heterogeneous and complex (e.g. variable-length lists of bounding box points and strings in object detection, highly nested structures of user or product features in ranking).

    • Further, users want to find or contribute ops to decode specific file types (e.g. HDF5) and accelerated kernels (e.g. GPU mp3 decoding).

    • Given that PyTorch maintains decoders for many storage formats, users want more powerful top-level abstractions, such as simply pointing PyTorch at a local or remote directory and receiving an iterator over best-effort deserializations of the files within.

  • Shuffling - Users want control over when shuffling occurs within “Dataloading.” It often makes a big performance and accuracy difference whether samples within a shard are shuffled, samples are globally shuffled, shards are shuffled, etc.
  • Pipelining and parallelism - Users want to be able to pipeline their loading and preprocessing (rather than make multiple CPU passes, for example), specify a number of workers to read and preprocess data, and not worry about whether reading, preprocessing, or model execution are starved. This can include asynchronous processes which prefetch data to feed to others.

TensorFlow addresses many of the above needs with TFRecord, dramatically simplifying the problem by taking a strong opinion on the data format with which TensorFlow works best. This has been extremely successful from a performance perspective. However, by prescribing a single storage format, all others are demoted, and the diversity of data needs and entrenched formats has made ubiquitous adoption of TFRecord for storage practically impossible. We’ve heard directly from users that they do not want to be forced into a single first-class format, and the public datasets (which Google rehosts in TFRecord) tend to agree, by completely disagreeing on format. For this reason, we prefer extensibility over prescription: we provide performant support for a basic set of formats in-tree (e.g. Hive and Manifold internally, tar shard and Arrow externally), but users can easily plug in modular extensions for new formats.

Underlying DataLoader Issues

Beyond the needs described above, the existing DataLoader is also a frequent source of user requests and github issues. Such feedback includes, but is not limited to:

  • Fork and general multiprocessing memory usage patterns - There are multiple reports on GitHub of users being confused about how fork’s copy-on-write and Python’s reference counting interact, which leads to OOMs. PyTorch users end up shopping for custom solutions such as separate list-management processes, shared binary segments, etc.
  • Threading vs Multiprocessing - Different use cases require one or the other. For example, threading generally performs better while multiprocessing works better with third-party libraries with non-threadlocal state.
  • Overcomplication of solutions - TarDataset requires custom shuffling and sampling implemented as Datasets, while our built-in solution requires altering the DataLoader. It would be best to separate data processing (reordering included) from process management.
  • Multiprocessing support - Today, proper pre-fetching is not possible due to the synchronous nature of Datasets. In order to bypass this, users must implement custom multiprocessing-enabled Datasets and DataLoaders themselves.
  • Manual sharding - Sharding is increasingly becoming a mandatory feature, allowing better multiprocessing and distributed execution. Currently users must implement it manually.

Finally, the ubiquity of the DataLoader necessitates strong backward compatibility. For this reason we do not plan to deprecate any existing functionality, but in some cases may offer a more modern way of doing things.

Solution

Break Dataset down into smaller components, DataPipes, each reducing its logic to a queue of data in and a queue of data out.

DataLoader observes the acyclic graph of DataPipes and provides the necessary level of parallelism using multiprocessing and multithreading.

Bear in mind that even though we use IterDataPipe in the examples below, all of this is also applicable to MapDataPipe.

Separating into smaller DataPipes and connecting them together

class ListFiles(datapipes.iter.IterDataPipe):
  # ...
  def __iter__(self):
      # yield file_names

class LoadFiles(datapipes.iter.IterDataPipe):
  def __init__(self, listfiles_dp):
      self._listfiles_dp = listfiles_dp
      # ...

  def __iter__(self):
      for file_name in self._listfiles_dp:
          yield (file_name, load_file(file_name))

This will allow us to simplify DataPipe code and make it reusable across various implementations (for example ImageFolder and TarDataset). It is also necessary when moving memory-consuming DataPipes into separate processes.

Turning IterDataPipe (or IterableDataset) and MapDataPipe (or MapDataset) into NonBlockingIterDataPipe and NonBlockingMapDataPipe

Multiprocessing/threading support makes us prefer nonblocking_next over the __next__ function. The key difference is that nonblocking_next might throw a NotAvailable exception, meaning that data is not yet available and should be requested again with nonblocking_next.

A DataPipe (or older Dataset) which implements only nonblocking_next can easily be used as a standard DataPipe because the parent class provides the necessary API:

class NonBlockingIterDataPipe(datapipes.iter.IterDataPipe):
  def __iter__(self):
      return self

  def __next__(self):
      while True:
          try:
              return self.nonblocking_next()
          except StopIteration:
              raise
          except NotAvailable:
              time.sleep(DELAY)
              EventLoop.iteration()

  def nonblocking_next(self):
      raise NotImplementedError

Existing synchronous DataPipes (and older Datasets) can be turned into non-blocking DataPipes using a helper function:

def EnsureNonBlockingNextDataPipe(validated_datapipe):
    if not isinstance(validated_datapipe, IterDataPipe):
        raise Exception('Not an IterDataPipe')
    if isinstance(validated_datapipe, NonBlockingIterDataPipe):
        return validated_datapipe
    if not hasattr(validated_datapipe, '_as_iterator'):
        setattr(validated_datapipe, '_as_iterator', None)
    if not hasattr(validated_datapipe, 'nonblocking_next'):
        def nonblocking_next(self):
            if self._as_iterator is None:
                self._as_iterator = iter(self)
            return next(self._as_iterator)
        # Bind the function as a method on this particular instance
        validated_datapipe.nonblocking_next = types.MethodType(nonblocking_next, validated_datapipe)
    return validated_datapipe

The combination of these two approaches will allow mixing old-style DataPipes (and Datasets) with new non-blocking DataPipes.
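
For illustration, a minimal sketch of such a mix, assuming the NotAvailable exception and the EnsureNonBlockingNextDataPipe helper above; LegacyNumbers is a hypothetical old-style pipe:

class LegacyNumbers(datapipes.iter.IterDataPipe):
    # Old-style pipe: only implements __iter__
    def __iter__(self):
        return iter(range(5))

dp = EnsureNonBlockingNextDataPipe(LegacyNumbers())

results = []
while True:
    try:
        results.append(dp.nonblocking_next())  # the legacy pipe now speaks the new protocol
    except NotAvailable:
        continue  # data not ready yet; a real caller would yield control here
    except StopIteration:
        break
# results == [0, 1, 2, 3, 4]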

As nonblocking_next does not guarantee that a result is returned, it can be used to schedule requests ahead of time:

class Prefetcher(datapipes.iter.NonBlockingIterDataPipe):
  def __init__(self, source_dp, buffer_size = 10):
      self._source_dp = source_dp
      self._buffer_size = buffer_size
      self._buffer = []
      self._source_depleted = False

  def nonblocking_next(self):
      if not self._source_depleted:
          while len(self._buffer) < self._buffer_size:
              try:
                  data = self._source_dp.nonblocking_next()
              except NotAvailable:
                  # break or issue more requests, depends on the implementation
                  break
              except StopIteration:
                  self._source_depleted = True
                  break
              self._buffer.append(data)
      if len(self._buffer):
          data = self._buffer.pop(0)
          return data
      else:
          if self._source_depleted:
              raise StopIteration
          else:
              raise NotAvailable

A similar approach will be applied to MapDataPipe with nonblocking_get(id).
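
A sketch of that map-style counterpart under the same assumptions; the base class and the __getitem__ plumbing below are illustrative, not a final API:

class NonBlockingMapDataPipe(datapipes.map.MapDataPipe):
    def __getitem__(self, index):
        # Blocking access built on top of the non-blocking primitive
        while True:
            try:
                return self.nonblocking_get(index)
            except NotAvailable:
                time.sleep(DELAY)
                EventLoop.iteration()

    def nonblocking_get(self, index):
        raise NotImplementedError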

Connecting blocks with queues

Having all DataPipes non-blocking (asynchronous) allows connecting them with a pair of queues.

For example, in the multiprocessing version, the subprocess main loop can look like this:

def IteratableDataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
  steps = 0
  EventLoop.enabled = False
  for _ in IteratableDataPipeBehindQueues(source_datapipe, req_queue, res_queue, raise_stop=True):
      steps += 1
      time.sleep(DELAY)

def IteratableDataPipeBehindQueues(source_datapipe, req_queue, res_queue, raise_stop = False):
  source_datapipe = EnsureNonBlockingNextDataPipe(source_datapipe)
  while True:
      try:
          req_queue.get(block = False)
      except Exception:  # no request in the queue yet
          yield True
          continue
      while True:
          try:
              value = source_datapipe.nonblocking_next()
          except NotAvailable:
              yield True
              continue
          except StopIteration:
              res_queue.put(StopIteration())
              if raise_stop:
                  return  # generators must return rather than raise StopIteration (PEP 479)
              else:
                  yield True
              continue
          res_queue.put(value)
          yield True  # Returns control

Then the main process can transparently access this DataPipe with a simple wrapper:

class QIteratableDataPipe(datapipes.iter.NonBlockingIterDataPipe):
  def __init__(self, request_queue, response_queue, response_wait_time = 0.00001):
      self._req_q = request_queue
      self._res_q = response_queue
      self._req_sent = False
      self.counter = 0
      self._stop_iteration = False
      self._response_wait_time = response_wait_time

  def nonblocking_next(self):
      if self._stop_iteration:
          raise Exception('next called after receiving StopIteration')
      if not self._req_sent:
          self._req_q.put(self.counter)
          self.counter += 1
          self._req_sent = True
      try:
          value = self._res_q.get(block = True, timeout = self._response_wait_time)
      except Exception:  # timed out waiting for a response
          raise NotAvailable
      self._req_sent = False
      if isinstance(value, StopIteration):
          self._stop_iteration = True
          raise StopIteration
      return value

This allows sending a DataPipe into a separate process with a few lines of code:

req_queue = multiprocessing.Queue()
res_queue = multiprocessing.Queue()
p2 = multiprocessing.Process(target=IteratableDataPipeToQueuesLoop, args=(source_datapipe, req_queue, res_queue))
p2.start()
separated_source_datapipe = QIteratableDataPipe(req_queue, res_queue)

Please note that allowing only one outstanding request in the queue is an implementation restriction, not something enforced by the design.
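
A rough usage sketch continuing the example above (it assumes the child loop ends once the source is depleted, as with raise_stop=True):

for value in separated_source_datapipe:  # NonBlockingIterDataPipe.__next__ retries nonblocking_next for us
    print(value)

p2.join()  # the child process exits after forwarding StopIteration through res_queue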

DataLoaderQueue

The above examples use the standard multiprocessing Queue, but it is not the best choice performance-wise in some cases and does not work at all in others. Instead, we suggest replacing it with a higher-level abstraction, DataLoaderQueue.

DataLoaderQueue is used to pass data between elements of a pipeline inside a single thread, between threads, between processes, or in a distributed environment. DataLoader will replace the queue with the best implementation for the situation, but all implementations should follow these requirements:

  • Non-blocking
  • Guaranteed delivery
  • Guaranteed no duplicates
  • Guaranteed order
  • Customizable length
  • Queue is always between TWO processes/threads

API:

  • def get(blocking=True) - returns any Python structure, or raises NotAvailableException, or raises QueueError
  • def put(data, blocking=True) - data is any Python structure, may raise QueueError

The DataLoaderQueue implementation also defines the ‘serialization’ technique, ranging from simply passing an object reference inside the same thread, to IPC calls, to full object serialization to be passed over the network.
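
As an illustration only, a minimal DataLoaderQueue sketch backed by multiprocessing.Queue; the class name is hypothetical, and NotAvailableException/QueueError are the exceptions named in the API above:

import multiprocessing
import queue

class MultiprocessingDataLoaderQueue:
    # Sketch: satisfies get/put with the semantics listed above; the delivery and
    # ordering guarantees between two processes come from multiprocessing.Queue.
    def __init__(self, maxsize=0):
        self._q = multiprocessing.Queue(maxsize)  # customizable length

    def get(self, blocking=True):
        try:
            return self._q.get(block=blocking)
        except queue.Empty:
            raise NotAvailableException()  # nothing to return right now

    def put(self, data, blocking=True):
        try:
            self._q.put(data, block=blocking)
        except queue.Full:
            raise QueueError('queue is full')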

Users API

DataPipes should work as standard iterators (or implement __getitem__) outside of the DataLoader.

numbers_dp = datapipes.iter.Numbers() # Returns range of integers
dp1, dp2, dp3 = datapipes.iter.Multiply(numbers_dp, 3) # Creates 3 copies of input data
def mult100(x):
  return x * 100
dp2_modified = datapipes.iter.Callable(dp2, mult100)
def mult111(x):
  return x * 111
dp3_modified = datapipes.iter.Callable(dp3, mult111)
joined_dp = datapipes.iter.GreedyJoin(dp1, dp2_modified, dp3_modified)
for i in iter(joined_dp):
  print(i) # 0 0 0 1 100 111 222 200 2 ......

The DataLoader output should be exactly the same, but different pieces of the graph might be executed as separate threads/processes.

for i in DataLoader(joined_dp):
   print(i) # 0 0 0 1 100 111 222 200 2 ......

Naming

There are a number of concepts which we would like to take this refactoring opportunity to clarify, though we also emphasize the importance of backward compatibility. We propose the following naming scheme for the components described within the scope of this doc, including typical end-user code samples.

  • Dataset - A factory producing a data preparation iterator, a graph of DataPipes.
    • ImageNet() -> function or class (doesn’t matter) returning an iterator (DataPipe) over ImageNet batches.
    • There is no Dataset “base class” for the purposes of a given function signature or functionality; it is now only a name (see the sketch after this list). The existing Dataset classes remain for BC, but they simply wrap DataPipes.
  • DataPipe - A node in the data preparation graph, taking one iterator or indexable source to another (e.g. Untar(ListTarFiles(...))).
  • DataLoader - An execution engine for passing data through the datapipe graph and persisting loading settings, taking advantage of device and parallelism opportunities.
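
A hypothetical example of such a Dataset factory, composing DataPipes from the earlier examples; the specific pipe names and the (image, label) output are assumptions, not a committed API:

def ImageNet(root, shuffle_buffer=10000):
    # "Dataset" is only a name: a plain function assembling a DataPipe graph.
    files_dp = datapipes.iter.ListFiles(root)
    loaded_dp = datapipes.iter.LoadFiles(files_dp)
    samples_dp = datapipes.iter.DecodeImages(loaded_dp)  # assumed to yield (image, label) pairs
    return datapipes.iter.Shuffle(samples_dp, buffer_size=shuffle_buffer)

for image, label in DataLoader(ImageNet('/datasets/imagenet')):
    pass  # train step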

Sharding

Sharding should be implemented at the framework level and hidden from DataPipe users. DataPipe developers will get control over sharding settings and running configurations. DataLoader will decide how to split the DataPipe graph into shards and which configuration to run.

DataPipe blocks will tell the DataLoader whether they support sharding via datapipe.is_shardable(). If this function is not defined, the DataPipe will be considered non-shardable.
DataLoader will call back into DataPipe objects with sharding settings using datapipe.sharding_settings(total_num_of_shards, id_of_shard).

Example:

list_files_dp = datapipes.map.ListFiles(root = '.')            # marked as non-shardable
load_bins_dp = datapipes.map.LoadFiles(list_files_dp)          # marked as shardable
decode_images_dp = datapipes.map.DecodeImages(load_bins_dp)    # marked as shardable
transform_dp = datapipes.map.TransformImages(decode_images_dp) # marked as shardable
shuffle_dp = datapipes.map.Shuffle(transform_dp)               # marked as non-shardable
sampler_dp = datapipes.iter.Sampler(shuffle_dp)                # marked as non-shardable
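
A rough sketch of how the DataLoader side of this contract could look; the graph-traversal helper is hypothetical:

def apply_sharding(datapipe_graph, total_num_of_shards, id_of_shard):
    # Walk every node in the graph (traverse() is an assumed helper) and configure
    # only the ones that declare themselves shardable; the rest are left untouched.
    for dp in traverse(datapipe_graph):
        if getattr(dp, 'is_shardable', None) and dp.is_shardable():
            dp.sharding_settings(total_num_of_shards, id_of_shard)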

Individual Process (Thread)

Situations like prefetching and large non-forkable arrays require spawning a separate process for a DataPipe. DataPipe blocks will tell the DataLoader whether they are recommended to be executed as a separate process via datapipe.is_separate_process().
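
A sketch of how the DataLoader might honor this hint, reusing the queue helpers from the multiprocessing example above; maybe_isolate is a hypothetical name:

def maybe_isolate(dp):
    # Move a datapipe into its own process only when it asks for it.
    if getattr(dp, 'is_separate_process', None) and dp.is_separate_process():
        req_queue, res_queue = multiprocessing.Queue(), multiprocessing.Queue()
        worker = multiprocessing.Process(
            target=IteratableDataPipeToQueuesLoop, args=(dp, req_queue, res_queue))
        worker.start()
        return QIteratableDataPipe(req_queue, res_queue)
    return dp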

Lazy Initialization

In some cases it is inefficient to initialize DataPipe data before usage. For example, we need to postpone loading the full list of files until after forking out a file scanner. For this purpose, a lazy_init function will be called prior to any __len__, __getitem__, or __iter__ calls.
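
A small sketch of this contract, assuming the framework calls lazy_init before the first use; the file-listing pipe and the os.listdir choice are illustrative only:

import os

class LazyListFiles(datapipes.iter.IterDataPipe):
    def __init__(self, root):
        self._root = root
        self._file_names = None  # nothing is scanned at construction (or fork) time

    def lazy_init(self):
        # Called by the framework before any __len__ / __getitem__ / __iter__ call,
        # possibly already inside the worker process.
        if self._file_names is None:
            self._file_names = sorted(os.listdir(self._root))

    def __len__(self):
        self.lazy_init()
        return len(self._file_names)

    def __iter__(self):
        self.lazy_init()  # defensive: safe even if the framework already called it
        return iter(self._file_names)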

Functional DataPipe

DataLoader should not care about any data logic (including sampling, shuffle, and collate).

Moving Sampler from DataLoader into separate DataPipe

We are planning to create a Sampler DataPipe for each existing sampling strategy, as well as a wrapper around existing Sampler classes.
PR: #49363

# use default sequential sampler (basically do nothing)
sequential_sampled_ds = datapipes.iter.Sample(iter_ds) 

# use random sampler with replacement to generate random item from input datapipe
random_sampled_ds = datapipes.iter.Sample(iter_ds, sampler=RandomSampler, replacement=True) 

Note:

All Sampler DataPipes can be replaced by another iterable DataPipe, and a Sampler is not required in the data pipeline.

  • RandomSampler without replacement or SubsetSampler can be replaced by ShuffleIterableDataset with different buffer sizes
  • WeightedSampler -> WeightedShuffleDataset (if needed)
  • BatchSampler -> BatchDataset
  • Other customized samplers can be replaced by a Callable DataPipe running a customized sample function

In general, the Sampler DataPipe is not recommended in the new pipeline; we keep it only to avoid breaking BC.
Example of replacing SubsetSampler:
def subset_sampler(ds, buffer_size=100):  # buffer_size was implied in the original; made explicit here
    buffer = []
    for x in ds:
        if len(buffer) == buffer_size:
            idx = random.randint(0, buffer_size - 1)
            yield buffer[idx]
            buffer[idx] = x
        else:
            buffer.append(x)
    random.shuffle(buffer)
    while buffer:
        yield buffer.pop()

out = datapipes.iter.Callable(ds, subset_sampler)

Moving Collate functions from DataLoader into separate DataPipes

We are going to move collate logic out of DataLoader and implement it as an IterDataPipe; it will accept old collate functions as an argument or can be rewritten entirely.
PR: #48933

batch_dp = datapipes.iter.BatchNumbers() # Returns batches of integers [1,2,3],[4,5,6],..
default_collated_dp = datapipes.iter.Collate(batch_dp) # use original default collate function
for i in DataLoader(default_collated_dp):
   print(i) # tensor([1, 2, 3]), tensor([4, 5, 6]), ...

def collate_fn(batch):
    sum = batch[0] + batch[1] + batch[2]
    return torch.tensor(sum, dtype=torch.float)
custom_collated_dp = datapipes.iter.Collate(batch_dp, collate_fn=collate_fn)
for i in DataLoader(custom_collated_dp):
   print(i) # tensor([6.]), tensor([15.]), ...

Moving Shuffle from DataLoader into separate Datasets

class BufferedShuffleDataset(IterableDataset[T_co]):
    r"""Dataset shuffled from the original dataset.

    This class is useful to shuffle an existing instance of an IterableDataset.
    The buffer with `buffer_size` is filled with the items from the dataset first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.

    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size`
    is required to be greater than or equal to the size of dataset.

    When it is used with :class:`~torch.utils.data.DataLoader`, each item in the
    dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator.
    And, the method to set up a random seed is different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is required to
    be set before the :class:`~torch.utils.data.DataLoader` in the main process.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

    For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable
    function in each worker.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> def init_fn(worker_id):
        ...     random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))

    Arguments:
        dataset (IterableDataset): The original IterableDataset.
        buffer_size (int): The buffer size for shuffling.
    """
    dataset: IterableDataset[T_co]
    buffer_size: int

    def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
        super(BufferedShuffleDataset, self).__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[T_co]:
        buf: List[T_co] = []
        for x in self.dataset:
            if len(buf) == self.buffer_size:
                idx = random.randint(0, self.buffer_size - 1)
                yield buf[idx]
                buf[idx] = x
            else:
                buf.append(x)
        random.shuffle(buf)
        while buf:
            yield buf.pop()

iter_dp = dp # Returns 0, 1, 2, 3, 4, 5, 6, 7, 8,...
shuffled_dp = datapipes.iter.Shuffle(iter_dp) # Returns 5, 2, 9, 0,...

Other functional DataPipes

In order to provide a more versatile API, we plan to add more functional DataPipes for users.

  • Batch
  • PaddedBatch
  • Unbatch
  • Repeat
  • Cache
  • Filter
  • Zip
  • ...

Reproducibility and randomness

This should be part of the DataLoader implementation, so that the random seed can be defined under the various parallelization techniques.

Async (non-blocking) operations also introduce non-determinism of order, so we will need a DataLoader attribute to order the fulfillment of non-blocking calls and guarantee deterministic ordering.
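
One possible seeding scheme, sketched under the assumption that the DataLoader derives a per-shard seed from a single base seed; the function name and derivation are illustrative only:

import random

import torch

def seed_shard(base_seed, shard_id):
    # Derive a deterministic per-shard/per-worker seed so results do not depend on
    # how many processes or threads the DataLoader chooses to use.
    shard_seed = base_seed + shard_id
    random.seed(shard_seed)
    torch.manual_seed(shard_seed)
    return shard_seed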

To Do

This document doesn’t touch the problem of varying batch size across different phases of processing. It is achievable by passing a list of objects into the queue and will be considered during the queue implementation phase. However, it would be better to put a code example here.

This document doesn't cover distributed training in detail. We are going to extend on this topic using additional sharding parameters and queue implementations.

Considerations

User-defined sharding was considered unnecessary at the early stages; however, nothing in the proposed architecture prevents implementing it later.

A C++ implementation was considered insufficiently flexible. However, nothing prevents users from creating DataPipes with C++ internals.

TorchScript can be used inside DataPipes, but we are not limited to it.

Arrow/Proto/… can be used to pass data between DataPipes.

Error Tracing?

C++

cc @ssnl @VitalyFedyunin @ejguan

@VitalyFedyunin added the module: dataloader (Related to torch.utils.data.DataLoader and Sampler) label Dec 16, 2020
@pritamdamania87
Contributor

@VitalyFedyunin Is it possible to update the RFC with a simple example of what the dataloader would look like while training with DDP where we need to shard data?

@ezyang added the feature (A request for a proper, new feature) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Dec 16, 2020
@ezyang
Contributor

ezyang commented Dec 16, 2020

Reminder that https://github.com/pytorch/rfcs/ is a thing; it might be easier to do in-depth comments on the RFC there.

@ejguan ejguan pinned this issue Dec 16, 2020
@VitalyFedyunin
Contributor Author

Answering the sharding question. Let's say we have a dataset with access to some database of N elements. We need to shard it and yield only 1/X of the elements from this database, non-overlapping with other shards.

class IterateableDataset:
    pass

class NumbersIteratableDataset(IterateableDataset):
    def is_shardable(self):
        return True

    def __init__(self, range = 100):
        self.range = range

    def sharding_settings(self, total_num_of_shards, id_of_shard, seed = 0):
        # Called before __iter__
        self.total_num_of_shards = total_num_of_shards
        self.id_of_shard = id_of_shard
        self.seed = seed

    def __iter__(self):
        for i in range(self.id_of_shard, self.range, self.total_num_of_shards):
            yield (i + self.seed) % self.range

Now let's assume X == 3, so we have three consumers (and it doesn't matter how they are separated across physical/logical nodes). For example, it can be one shard inside one process and two shards on a different machine.

# First
number_of_shards = 3
shard = NumbersIteratableDataset()
shard.sharding_settings(number_of_shards, 0)
# Second
number_of_shards = 3
shard_0 = NumbersIteratableDataset()
shard_0.sharding_settings(number_of_shards, 1)
shard_1 = NumbersIteratableDataset()
shard_1.sharding_settings(number_of_shards, 2)

So we basically only need to share between machines the total number of shards and some cross-system incremental id.

for i in iter(shards_1):
    print(i)
# 2, 5, ...., 98

It is likely that we will need access to other data on the next epoch; for this we can introduce a seed or epoch number to sharding_settings.

# Next epoch 
for i, shard in enumerate(shards):
    shard.sharding_settings(number_of_shards, i, seed = 777 * epoch_number )

for i in iter(shards_2):
    print(i)
# 79, 82 .... 97, 0, 3, ... 75

@VitalyFedyunin
Contributor Author

An alternative option is available when we only have a legacy dataset without sharding support: nesting it in something like a sharding filter will help.

class OldDataset:
    def __iter__(self):
        for x in range(100):
            yield "str%d" % x

class ShardedFilterIterableDataset:
    def __init__(self, source_ds):
        self.source_ds = source_ds

    def sharding_settings(self, total_num_of_shards, id_of_shard, seed = 0):
        # Called before __iter__
        self.total_num_of_shards = total_num_of_shards
        self.id_of_shard = id_of_shard
        self.seed = seed

    def __iter__(self):
        for i, value in enumerate(self.source_ds):
            if i % self.total_num_of_shards == self.id_of_shard:
                yield value

number_of_shards = 3
legacy_ds = OldDataset()
sharded = ShardedFilterIterableDataset(legacy_ds)
sharded.sharding_settings(number_of_shards, 2)

for x in sharded:
    print(x)
# str2, str5 ....

@bgedik

bgedik commented Dec 18, 2020

Two features I’d suggest:

  • Checkpointing
  • Uneven end of data handling (for distributed training)

@VitalyFedyunin
Contributor Author

VitalyFedyunin commented Dec 18, 2020

Can you please elaborate on the "Uneven end of data handling (for distributed training)" topic? What exactly are you looking for?

Good point about checkpointing. Just to be clear, we want to be able to stop training in the 'middle' of an epoch (data loop) and resume it after a cold restart, right?

@bgedik

bgedik commented Dec 19, 2020

Say you have n data loader instances; in many distributed training scenarios, if one instance runs out of data then all have to run out of data. Otherwise the training hangs.

@carmocca
Contributor

* Uneven end of data handling (for distributed training)

Similarly, supporting #25162 (comment)

@jspisak
Contributor

jspisak commented Dec 21, 2020

cc @jph00 for viz and any inputs. Also anyone else from the fast.ai community you think should comment here?

@nairbv
Collaborator

nairbv commented Dec 22, 2020

nit: IteratableDataset -> IterableDataset

@nlgranger

nlgranger commented Dec 28, 2020

Here are a few thoughts to contribute to this discussion:

  • Pre-processing cost varies greatly across domains and buying one-size-fits-all GPU nodes is complicated. Consequently, distributed data loading ought to be a targeted use-case for the data loader re-design so that this variability can be off-loaded to CPU-only servers.
  • I might be missing a point, but why is sharding a topic by itself? Isn't it easy to manage it in a similar fashion to DDP sampler? It's difficult to anticipate all possible use-cases so unless it requires a lot of boilerplate code from users it's probably best to leave this outside the core of pytorch dataloader. Maybe just provide a few utilities and wrappers?
  • Reading the above posts didn't give me a clear idea on how the ordering of requested items is preserved.
  • What motivates the need for async item access? Workers idling during IO? Is this a real-world problem that can't be solved by adding more workers?

I have made a few experiments with data loading before, here is my modest feedback and sorry for the self-advertisement:

  • Iterable and indexable datasets are more easily handled via separate code-bases and APIs; there might be a bit of code duplication, but it will simplify both implementations greatly. The access to items and the supported data loading strategies vary too much to uniformize them. For instance, shuffling an iterable dataset requires building a buffer where items are shuffled.
  • Sampling and multi-processing don't play well together. A sampler API is needed to standardize when (RL or active learning researchers might want to comment on this maybe) and how (multiprocess-friendly) the next items are sampled and submitted to the dataloader.
  • "Moving Sampler and Shuffle from DataLoader into separate Datasets": see example here (jump to make_sequence method).

Extra minor remarks:

  • How about bringing people who know about the mpi/rdma stuff in the discussion about distributed data-loader?
  • About the OOM due to COW, persistent workers which start when the dataloader should help, it's also more intuitive.
  • About OOM due to shared memory, maybe use a fixed shm memory pool. The previous example uses that as well, it is actually implemented by this function which uses refcounting hooks to track free memory.
  • I don't think a use-case where the order of items is not important should be considered because it doesn't encourage reproducibility
  • Python 3.8 introduced features to help implement Apache Arrow-like zero-copy transfers with pickle: PEP 574, sample use-case which is also used by the dataloader example above.

@npuichigo

npuichigo commented Jan 2, 2021

Answering the sharding question. Let's say we have a dataset with access to some database of N elements. We need to shard it and yield only 1/X of the elements from this database, non-overlapping with other shards. [...]

@VitalyFedyunin The same question: what if the shards cannot be evenly dispatched to different workers (maybe GPUs), or the numbers of items in the shards are different?

I saw that in PyTorch 1.7 the join op can be used to address the problem of uneven datasets. But in the case of sharding, the last worker or GPU may still have a whole shard (possibly containing several items) to consume while the others are idle.

Does this need more consideration?

@npuichigo

any update? @VitalyFedyunin

@ppwwyyxx
Contributor

ppwwyyxx commented Jan 13, 2021

Reposting this sharding question from the google doc version of this RFC:

dataset.sharding_settings changes the internal state of dataset. What do you think about the alternative which lets dataset.sharding(...) returns a new dataset instead?
I'm thinking about a situation where a dataset is sharded into multiple levels, e.g. in a setup with A GPUs, each with B dataloader worker processes, each worker uses C threads. The alternative seems easier to me for such cases, but with the current proposal I'm not sure how to create multiple levels of sharding.

@npuichigo

Reposting this sharding question from the google doc version of this RFC:

dataset.sharding_settings changes the internal state of dataset. What do you think about the alternative which lets dataset.sharding(...) returns a new dataset instead?
I'm thinking about a situation where a dataset is sharded into multiple levels, e.g. in a setup with A GPUs, each with B dataloader worker processes, each worker uses C threads. The alternative seems easier to me for such cases, but with the current proposal I'm not sure how to create multiple levels of sharding.

Can u share the link? Is that public?

@VitalyFedyunin
Contributor Author

Hello!

Doc URL is https://docs.google.com/document/d/1ILJQl1QFUWFRbmFW50askG5l4t_8xcvSMbIwHFK6DoU/edit?usp=sharing
but please refer to this PR as the most up-to-date version, as keeping everything in sync is a nightmare. And sorry for the delay; I am aggregating multiple pieces of feedback now into a single update, due next Tuesday.

@tmbdev

tmbdev commented Feb 8, 2021

It looks like an ambitious redesign; I'm looking forward to its implementation. In particular, cleaning up the DataLoader and multiprocessing facilities will be very useful.

I want to point people at related resources that we have developed for these kinds of representations and processing pipelines.

  • The webdataset library is a library of standard IterableDataset implementations that interoperate with the existing PyTorch library and provide functionality similar to this design.
  • WebDataset is also a file format, namely storing datasets in POSIX tar archives with a simple naming convention. If you store your datasets in WebDataset format, you can read them efficiently with the current webdataset library, and you will later be able to read them natively with the new TarIterator
  • The tarp program is a command line program for quick and easy processing of large datasets in WebDataset format; it's a kind of "xargs for tar archives"
  • The AIStore server is an open source, high performance storage server that provides infinitely linearly scalable I/O performance for WebDataset; it also provides server-side shuffling, offline ETL, and online data transformations for WebDataset-style datasets

@tmbdev

tmbdev commented Feb 8, 2021

A comment about the design and objectives of the project in general. I think from a software engineering point of view, it is nice to have support for all the Python I/O and concurrency facilities: async, threading, queues, and multiprocessing, and to allow for arbitrary sources and sinks. If you have the software development resources, by all means, it's good to implement that.

Keep in mind, though, that any I/O pipeline you write as part of the deep learning job is limited to the CPU cores and PCI bandwidth on the machine with the GPU cards. On 8 or 16 GPU nodes, that greatly limits what you can do.

What we do for large scale training jobs is to run the I/O pipelines on separate nodes from the GPU pipelines. There are two common and easy ways of doing that:

  • The AIStore server allows on-the-fly processing pipelines to run on the storage server itself; that is, a sharded dataset can be stored on the distributed AIStore server and when you open each shard, the data augmentation and other kinds of data transformations are run transparently by AIStore.
  • The Tensorcom library lets you write preprocessing, shuffling, batching, and data augmentation pipelines in Python and then sends training batches directly to the GPU nodes using either ZMQ or Direct-to-GPU. The latter gives you 50-100 Gbps network bandwidth per GPU on high-end machines.

Note that this stuff works today and scales linearly (limited only by available hardware).

AIStore data augmentation pipelines are very simple for users, since they really just look like opening a dataset; the fact that the dataset is processed on the fly doesn't concern the DL job. We generally schedule Tensorcom-based pipelines using Kubernetes.

So, while I'm glad that async and threading are making it into the redesigned I/O pipelines, high performance I/O solutions for large scale deep learning probably just end up running a large number of distributed jobs, each of which is fairly simple and sequential.

@maxluk

maxluk commented Feb 8, 2021

Thank you for the write up! It moves forward the design of DataLoader and addresses some common problems, so it's a great effort. It's hard to make conclusions on the solution though without end-to-end examples. The proposal is heavy on architectural principles which is good, but lacks description of how end user will interact with the system and how the end solution will look like. Can you include examples of how user code will look like for at least some of the problems you state in the beginning?
Other comments:

  • How APIs for composing data pipelines will look like? What are the goals for the pipelines? Which use cases will they cover? Which transformations are going to be supported out of the box? For which data types?
  • Is there any provision for offloading dataloading to an external set of CPU nodes? Or is that a non-goal?
  • How would global shuffling look like at architectural level?
  • More details on Caching in addition to mention that there is a functional DataPipe for that?
  • More details on modularity? Here, example of how Azure Blob or some other data reader component would integrate into user code would be especially interesting. Or use structured data source for example.

@VitalyFedyunin
Contributor Author

Thanks, @tmbdev, we are going to put in some time to review the potential of adding AIStore and/or Tensorcom as DataPipes.

@maxluk we are currently working on the next iteration of the RFC with extra attention to usage examples, your feedback is important and we will take it into account as well.

Answering other questions:

  • Yes, we are considering supporting Distributed Data Loading in the future, but can't promise an out-of-the-box solution sooner than early 2022.
  • Shuffle will be implemented as a DataPipe; by default we will provide buffered shuffle.
  • Likely we will have a simple in-memory caching DataPipe as an example, and will help users develop a set of various caching DataPipes.
  • I will add more examples on modularity, as we had it in mind during the design phase.

@tmbdev

tmbdev commented Feb 10, 2021

@maxluk Here are YouTube videos explaining how shuffling, distribution, caching, and parallelization work for sequential loaders and pipes.

@npuichigo

I saw the recent updates in torch.utils.data.datapipes and tried to utilize that for my own project.

The distributed sharding solution is to add a SamplerIterDataPipe after ListDirFilesIterDataPipe, and the sampler used in SamplerIterDataPipe splits the shards across different workers on different GPUs.

class DistributedPartitionSampler(Sampler[int]):

    def __init__(self, data_source, *, shuffle=True, seed=0):
        super().__init__(data_source)
        self.data_source = data_source
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.data_source), generator=g).tolist()
        else:
            indices = list(range(len(self.data_source)))

        worker_info = get_worker_info()
        if worker_info is None:
            num_workers = 1
            worker_rank = 0
        else:
            num_workers = worker_info.num_workers
            worker_rank = worker_info.id

        if not dist.is_initialized():
            gpu_rank = 0
            gpu_size = 1
        else:
            gpu_rank = dist.get_rank()
            gpu_size = dist.get_world_size()

        skip = max(1, gpu_size * num_workers)
        offset = gpu_rank * num_workers + worker_rank

        shared_index = indices[offset::skip]

        for index in shared_index:
            yield self.data_source[index]

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

Here the data source needs to be Sized, but ListDirFilesIterDataPipe uses yield and the total number is unknown, so I need to retrieve the file list before partitioning.

pipe = ListDirFilesIterDataPipe('.', masks='*.tar')
pipe = SamplerIterDataPipe(list(pipe),  # or [x for x in pipe] because __len__ may throw an exception
                           sampler=DistributedPartitionSampler)

Now I can get different shards on different workers (carefully setting the epoch when training):

>>> # spawn with 2 GPUs and 3 workers per GPU...(12 shards)
>>> # GPU-0 WORKER-0 gets [0,6]
>>> # GPU-0 WORKER-1 gets [1,7]
>>> # GPU-0 WORKER-2 gets [2,8]
>>> # GPU-1 WORKER-0 gets [3,9]
>>> # GPU-1 WORKER-1 gets [4,10]
>>> # GPU-1 WORKER-2 gets [5,11]

@ejguan
Contributor

ejguan commented Mar 9, 2021

@zhuzilin
Thanks for asking for MapDataPipe.

The problem with the tf.data-style iterative dataset is that it's hard to implement efficient shard or skip and can only do shuffling in a local pool instead of global shuffling and the cache operation can only cache all the data... For example, we have to shard the tfrecord filenames instead of sharding the data entries to get a good performance in tensorflow, otherwise tf.data has to at least read the header of all records. (Before tensorflow/tensorflow#46358, each shard would even have to read all the data...)

Yep, these are facts for iter-style DataPipes. But there are still lots of benefits to iter-style DataPipes, like reduced memory overhead, better multiprocessing/threading support, etc. And I do think there are ways to work around the problems you just pointed out. Take shuffling: you can still have global shuffling by using a buffer that can hold the whole dataset, even though it's not recommended. And, just like you said, you could do global shuffling or sharding or skipping at the beginning of the whole pipeline (over filenames, file handlers, web requests, etc.), so you don't necessarily need to do the operation over the actual data.
And I believe map-style DataPipes have the same issues. For shuffling, you need to load indices or filenames (which will be similar to iter-style) into memory and then shuffle; the memory overhead is the same as iter-style.
Especially for large-scale datasets, you can't really load everything into memory; even operations over indices may cause OOM.

Therefore, I really hope we could have a separate pipeline logic for the map datasets instead of turning the map ones into iteratives.

We have received lots of requests for map-style DataPipes, and we are gathering more feedback about it. We would really appreciate it if you could elaborate on the use cases that you think only map-style can achieve and iter-style cannot.

@ejguan
Contributor

ejguan commented Mar 9, 2021

@npuichigo

@VitalyFedyunin Another question about sharding. Since we may use ListDirFilesIterDataPipe to yield shards, sometimes we don't know the actual number of shards. So how to deal with this situation.

I am actually quite curious: why does ListDirFilesIterDataPipe prevent you from having the actual number of shards?

@ejguan
Contributor

ejguan commented Mar 9, 2021

@sbrodehl

About a dataset, where sharding results in multiple workers having an uneven amount of data, which can result in "incomplete last batches". Currently, data from different shards is not merged, will this change?

Related issue: #44108

Yes, it will. In the future, you will be able to define when you want batching, before or after merged from shards.

@nlgranger

I second @zhuzilin's comment. The data preprocessing API should not get in the way of what we want to implement like tf.data does; random access to data should remain readily available. Whether or not it leads to atrocious IO and subpar performance is secondary. Ideally, the new API should allow random access and transparently take advantage of a more IO-friendly sampling strategy, which would be documented in tutorials and selected by default in the dataloader, for example.

@zhuzilin
Contributor

zhuzilin commented Mar 9, 2021

@ejguan Thank you for your reply:)

I understand there are great advantages to using iter-style datasets. And it will be much easier to deal with all kinds of data formats or data sources if we switch to it. tf.data and the underlying io abstraction have proved their power in practice. I just hope that we could have a map-style counterpart to achieve the things that are hard for iter-style.

There are two use cases on my mind at the moment: sharding and caching.

  • First, about sharding.

I've been working on improving the performance of shard in the tensorflow community recently. As I mentioned in the comment above, because tf.data is completely iter-style (from the dataset ops, to the underlying io streams or file abstractions, to tfrecord the format itself), it's hacky to do operations that need to skip records. Though I added a Skip api to the iterator base class (tensorflow/tensorflow#40963), we still have to disable the BufferedInputStream underneath to avoid reading data that is skipped (tensorflow/tensorflow#46358 (comment)). Those require the user to have much knowledge of the framework architecture; otherwise, a shard(128, i) will make the io speed 100 times slower.

The workaround by tensorflow community is to shard the file names and read each file without the shard:

d = tf.data.Dataset.from_tensor_slices(filenames)
d = d.shard(num_workers, worker_index)
d = d.flat_map(lambda x: tf.data.TFRecordDataset(x))

instead of the intuitive:

d = tf.data.TFRecordDataset(filenames)
d = d.shard(num_workers, worker_index)

This workaround would fail if the user has uneven-sized tfrecord files or maybe only one large file, which is common in practice.

On the other hand, a map-style dataset can easily achieve sharding with index = index + num_workers. And for a tfrecord file, we could extract the offset of each record in the dataset ahead of time and use this metadata to do random reads.

  • Second, about caching.

In our scenario, the user would save some large dataset on a remote file system like ceph, and the latency fluctuation of the FS may affect the training speed a lot. Therefore, we need to cache part of the dataset to local disk to make IO more stable. However, it's hard to implement caching if we only need to store part of the data (the whole dataset is too large (TBs) to be held on the local disk (GBs)), because we cannot check whether a data entry is cached or not. Even if we can, we cannot get the entries that are not cached, since we cannot have random access to the dataset. Therefore, we only implemented the caching functionality for pytorch, but not tf at the moment.

Again, I need to emphasize that I agree iter-style is a really nice model for most use cases. I just hope we could have some support for map-style as well. The current dataset and dataloader model supports map-style and I really love that. I hope we can find a way to keep supporting at least part of that.

There are some comments on your reply.

Like shuffling, you can still have the global shuffling by using a buffer that can hold the whole dataset, even though it's not recommended.

Like shuffling, you need to load index or filenames (which will be similar to Iter-style) into memory and then shuffle, the memory overhead is still same as Iter-style.

For the iter-style, one needs a pool the same size as the dataset to do global shuffling, while the map-style would only need a function producing a permutation of the indices or directory. An image is several hundred KBs, while a directory entry is only a couple of bytes. I don't think they have the same memory overhead.

Especially for the large-scale dataset, you can't really load everything into memory, even operations over indexes may cause OOM.

I agree. But I think most datasets do not belong to this category... I think it's reasonable to assume most users are facing datasets at the scale of GBs to TBs instead of gigantic ones that would take GBs just for the index or directory. Even in that case, the map-style can still fall back to having a pool of indices, which can hold a lot more records than the data pool of iter-style.

Thank you for your time on this looong comment~

@npuichigo

@npuichigo

@VitalyFedyunin Another question about sharding. Since we may use ListDirFilesIterDataPipe to yield shards, sometimes we don't know the actual number of shards. So how to deal with this situation.

I am actually quite curious about why ListDirFilesIterDataPipe prevents you having actual number of shards?

@ejguan ListDirFilesIterDataPipe uses a generator to yield shard file names, and __len__ is not implemented unless we pass in the length.

    def __iter__(self) -> Iterator[str] :
        yield from get_file_pathnames_from_root(self.root, self.masks, self.recursive, self.abspath)

    def __len__(self):
        if self.length == -1:
            raise NotImplementedError
        return self.length

So my question is how sharding works if all the IterDataPipes used are of unknown length. Or does it mean we may need to pre-compute the number of (shard) files and then inject that into the shard settings?

@nairbv
Collaborator

nairbv commented Mar 10, 2021

Re: Uneven end of data handling

A common solution is an extra layer of indirection between "shards" and workers (which I guess would be instances of NumbersIteratableDataset() in the example). E.g., instead of:

# First
number_of_shards = 3
shard = NumbersIteratableDataset()
shard.sharding_settings(number_of_shards, 0)
# Second
number_of_shards = 3
shard_0 = NumbersIteratableDataset()
shard_0.sharding_settings(number_of_shards, 1)
shard_1 = NumbersIteratableDataset()
shard_1.sharding_settings(number_of_shards, 2)

you might have something like:

number_of_shards = 100
worker_one_shards = list(range(1,33))
worker1 = NumbersIterableDataSet()
worker1.sharding_settings...
worker_two_shards = list(range(33,66))
...

I.e. you might have 100,000 rows, 100 shards, and 3 "workers." Each worker handles about 33 shards, which is ~33,000 rows. Then you just have some extra communication overhead: if one worker is idle, you need to move some shards from a busy node to that worker. You can add a rebalancing mechanism so that shards end up being moved to faster workers until they're allocated according to the speed of the hardware.

Going back to the original example, the same thing can probably be achieved just by having a larger number of shards per machine.

@ejguan
Contributor

ejguan commented Mar 10, 2021

@zhuzilin We really appreciate your feedback.
Just want to mention the fact that PyTorch turns a map-style dataset into iter-style in the DataLoader in the end. All the benefits of MapDataset are the operations over indices in the Sampler. If we have already replaced the sampler by directly doing operations over indices within the pipeline, I don't really see how random access to the dataset is prevented by an iter-style DataPipe.

For sharding, what's your opinion about the following pipeline.

listfiles_dp = ListFiles(root=data_dir)
shard_dp = Shard(listfiles_dp, num_workers, worker_index)
loadfiles_dp = LoadFiles(shard_dp)
decoded_images_dp = DecodeImages(loadfiles_dp)
...

It definitely requires users to have more knowledge about the architecture to add a sharding pipe at the front of the pipeline. We probably need to provide factory pipelines for users so they can construct them without worrying about where the sharding pipe should be attached.

And that's the reason I said the memory overhead would be similar for iter and map styles in this case. For map-style, even if you do a permutation over indices, you still have a dict mapping index to filepath or whatever.

This work around would fail if the user have uneven-sized tfrecord files or maybe only one large file, which are common in practice.

It's correct, but it relies heavily on the metadata of the TFRecord file to do random access. It's not a common case for users when they are using their own file formats like tar, etc.

In our scenario, the user would save some large dataset on remote file systems like ceph and the latency fluctuation of the FS may affect the training speed a lot. Therefore, we need to cache part of the dataset to local disk to make IO more stable. However, it's hard to implement caching if we only need to store part of the data (The whole dataset is too large (TBs) to be hold in the local disk (GBs)), because we cannot check if a data entry is cached or not. Even if we can, we cannot get the entries that are not cached, since we cannot have random access to the dataset. Therefore, we only implement the caching functionality for pytorch, but not tf at the moment.

I am sorry, I am not sure that I understand the exact need. Could you elaborate on why an iter-style DataPipe doesn't work in this case?

@ejguan
Contributor

ejguan commented Mar 10, 2021

@ejguan ListDirFilesIterDataPipe uses generator to yield shard file names, and __len__ is not implemented only if we pass in the length.

    def __iter__(self) -> Iterator[str] :
        yield from get_file_pathnames_from_root(self.root, self.masks, self.recursive, self.abspath)

    def __len__(self):
        if self.length == -1:
            raise NotImplementedError
        return self.length

So my question is how sharding works if all the IterDataPipe used are of unknown length. Or does it mean we may need pre-compute the number of (shard) files and then inject that to shard settings?

@npuichigo After ListDirFilesIterDataPipe yields all file names under a directory, users can choose basically whatever strategy they want to shard these files.
Ohhh, if you are talking about using a DataPipe with DistributedSampler in DataLoader, which requires len to generate a random permutation of indexes, you will need to wait for some time. We will probably update DataLoader later to make it forward compatible with the new DataPipes, so you won't need to do sampling for it. All the functionality of the sampler will be replaced by other DataPipes without requiring __len__.
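
As a sketch of why __len__ isn't needed for sharding itself (this is not the actual implementation), a sharding pipe can keep every num_of_instances-th element of a stream of unknown length purely by position:

class ShardByPosition:
    def __init__(self, source, num_of_instances, instance_id):
        self.source = source  # any iterable, e.g. a stream of file names
        self.num_of_instances = num_of_instances
        self.instance_id = instance_id

    def __iter__(self):
        for position, item in enumerate(self.source):
            if position % self.num_of_instances == self.instance_id:
                yield item

# Consumer 1 of 3 keeps every third file name as it streams by:
files = (f"shard_{i:05d}.tar" for i in range(10))  # stand-in for a file-listing pipe
print(list(ShardByPosition(files, 3, 1)))
# ['shard_00001.tar', 'shard_00004.tar', 'shard_00007.tar']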

@zhuzilin
Copy link
Contributor

@ejguan
Sorry that I didn't explain the cache part clearly. Let me try to start with a simpler version. We could accelerate the IO pipeline by caching the data in memory.

In tf.data, the user will:

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.cache()
dataset = dataset.repeat()

The inner logic of cache is similar to:

class CacheDataset(Dataset):
    def __init__(self, input_dataset):
        self.cache = Cache()
        self.input_dataset = input_dataset
        self.input_iter = iter(input_dataset)
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index < len(self.input_dataset):
            # First pass: pull from the input dataset and fill the cache.
            data = next(self.input_iter)
            self.cache[self.index] = data
        else:
            # Later passes (driven by repeat): serve from the cache.
            data = self.cache[self.index % len(self.input_dataset)]
        self.index += 1
        return data

It will start reusing the cache only after the whole dataset has been cached into memory. And that is why we need to put the repeat after the cache; otherwise the cache will keep on saving until OOM but never use the cached data.

However, we cannot cache the whole dataset in most cases, so we hope to cache only part of it; for the entries that are not cached, we could still read from the input_dataset. Something like this:

class CacheDataset(Dataset):
    def __init__(self, input_dataset, limit):
        # Add a size limit to the cache.
        self.cache = Cache(limit=limit)
        self.input_dataset = input_dataset
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        key = self.index % len(self.input_dataset)
        if key in self.cache:
            data = self.cache[key]
        else:
            # Here we need random access to the input_dataset.
            data = self.input_dataset[key]
            if not self.cache.full():
                self.cache[key] = data
        self.index += 1
        return data

As the comment in the code snippet above notes, when trying to get data directly from the input dataset, we need to rely on random access. (I just realized it is still about skipping some records...)

And in the original comment, our team was trying to add one more level of hierarchy to the caching mechanism. To deal with the remote storage system, we would have:

remote storage --|cache part|--> local disk --|cache part|--> memory

And this will also need map-style access.
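
A rough, self-contained sketch of that two-level idea (all names here are made up for illustration, not part of the RFC): a map-style wrapper falls back from a small memory cache to a larger disk-like cache to the remote store, which is exactly where __getitem__ is needed.

class LimitedCache(dict):
    # Toy size-limited cache standing in for a real memory/disk cache backend.
    def __init__(self, limit):
        super().__init__()
        self.limit = limit

    def full(self):
        return len(self) >= self.limit

class TieredDataset:
    def __init__(self, remote, disk_cache, mem_cache):
        self.remote = remote    # e.g. a dataset reading from ceph
        self.disk = disk_cache  # larger, slower cache (local disk)
        self.mem = mem_cache    # smaller, faster cache (memory)

    def __len__(self):
        return len(self.remote)

    def __getitem__(self, index):
        if index in self.mem:
            return self.mem[index]
        if index in self.disk:
            data = self.disk[index]
        else:
            data = self.remote[index]  # random access to the remote store
            if not self.disk.full():
                self.disk[index] = data
        if not self.mem.full():
            self.mem[index] = data
        return data

remote = list(range(1000))  # stand-in for the remote dataset
tiered = TieredDataset(remote, LimitedCache(100), LimitedCache(10))
print(tiered[5], tiered[5])  # second access is served from the memory cache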

@zhuzilin
Copy link
Contributor

@ejguan

If we have already replaced the sampler by doing operations over indexes directly within the pipeline, I don't really see how random access to the dataset is prevented by an iter-style datapipe.

I strongly agree with this, and that's why I asked about the MapDataPipe in the first place. Currently, we only have IterDataPipe in master, and it does not have a __getitem__ method for doing operations over indices. I'm curious how the index-related part will be merged into the current design. For instance, shall we have a separate MapDataPipe version of operations like Map, Concat, Filter? How will the new DataLoader deal with pipe graph nodes with and without index operations?
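
To make the question concrete, here is a sketch (not the actual MapDataPipe API, just an illustration) of what an index-preserving Map over a map-style pipe could look like:

class MapOverMapDataPipe:
    # Sketch only: applies fn lazily while keeping index-based access intact.
    def __init__(self, source_datapipe, fn):
        self.source = source_datapipe
        self.fn = fn

    def __getitem__(self, index):
        return self.fn(self.source[index])

    def __len__(self):
        return len(self.source)

squares = MapOverMapDataPipe(list(range(10)), lambda x: x * x)
print(len(squares), squares[3])  # 10 9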

@VitalyFedyunin
Copy link
Contributor Author

The work on MapDataPipe is in progress now. We will also support map and concat, but not filter, for obvious reasons (as indexing becomes unclear or slow).
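
A sketch of why concat stays cheap for map-style while filter does not (illustration only, not the planned API): concatenation only needs an index offset, whereas a filtered pipe cannot answer "what is element i?" without scanning or pre-computing the surviving indices.

class ConcatMapSketch:
    def __init__(self, first, second):
        self.first = first
        self.second = second

    def __len__(self):
        return len(self.first) + len(self.second)

    def __getitem__(self, index):
        # O(1): just shift the index into the second source when needed.
        if index < len(self.first):
            return self.first[index]
        return self.second[index - len(self.first)]

cat = ConcatMapSketch([0, 1, 2], [10, 11])
print([cat[i] for i in range(len(cat))])  # [0, 1, 2, 10, 11]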

@VitalyFedyunin
Copy link
Contributor Author

About a dataset where sharding results in multiple workers having an uneven amount of data, which can result in "incomplete last batches": currently, data from different shards is not merged; will this change?

Related issue: #44108

Yes, it will change, as we are going to support mini-batches and will be able to concatenate data from all 'streams'.

@VitalyFedyunin
Copy link
Contributor Author

Answering the sharding question. Let's say we have a dataset with access to some database of N elements. We need to shard it and yield only 1/X of the elements from this database, non-overlapping with the other shards.

class IterateableDataset:
    pass

class NumbersIteratableDataset(IterateableDataset):
    def __init__(self, range = 100):
        self.range = range

    def is_shardable(self):
        return True

    def sharding_settings(self, total_num_of_shards, id_of_shard, seed = 0):
        # Called before __iter__
        self.total_num_of_shards = total_num_of_shards
        self.id_of_shard = id_of_shard
        self.seed = seed

    def __iter__(self):
        for i in range(self.id_of_shard, self.range, self.total_num_of_shards):
            yield (i + self.seed) % self.range

Now let's assume X == 3, so we have three consumers (and it doesn't matter how they are distributed across physical/logical nodes). For example, it can be one shard inside one process and two shards on a different machine.

# First machine: one shard
number_of_shards = 3
shard = NumbersIteratableDataset()
shard.sharding_settings(number_of_shards, 0)

# Second machine: two shards
number_of_shards = 3
shard_0 = NumbersIteratableDataset()
shard_0.sharding_settings(number_of_shards, 1)
shard_1 = NumbersIteratableDataset()
shard_1.sharding_settings(number_of_shards, 2)

So we basically only need to share the total number of shards and some cross-system incremental id between machines.

for i in iter(shard_1):
    print(i)
# 2, 5, ...., 98

It is likely that we will need access to different data on the next epoch; for this we can pass a seed or epoch number to sharding_settings.

# Next epoch: re-seed every shard (here all_shards collects the shards from all machines)
all_shards = [shard, shard_0, shard_1]
for shard_id, s in enumerate(all_shards):
    s.sharding_settings(number_of_shards, shard_id, seed = 777 * epoch_number)

for i in iter(shard_1):  # the shard with id 2
    print(i)
# 79, 82 .... 97, 0, 3, ... 75

@VitalyFedyunin Another question about sharding. Since we may use ListDirFilesIterDataPipe to yield shards, sometimes we don't know the actual number of shards. How should we deal with this situation?

Oh, I see my wording is confusing and needs to be corrected in the document. When I wrote 'shards' I meant the number of consumers, not the number of partitions in the source dataset. For sure, having the number of shards divisible by the number of consumers is preferred for even load, but it isn't necessary.

@VitalyFedyunin
Copy link
Contributor Author

Re: Uneven end of data handling

A common solution is an extra layer of indirection between "shards" and workers (which I guess would be instances of NumbersIteratableDataset() in the example). E.g., instead of:

...

Going back to the original example, the same thing can probably be achieved just by having a larger number of shards per machine.

Thanks, this is exactly where my documentation is incomplete and requires more details to separate shards and workers (and make the API cleaner).

@munael
Copy link

munael commented Mar 16, 2021

Just curious what you guys think of InfiniBatch. Does it apply here at all? What's different or missing compared to your vision (I admit, couldn't follow the discussion and idea doc all that well 😅).

https://github.com/microsoft/infinibatch

@VitalyFedyunin
Copy link
Contributor Author

Just curious what you guys think of InfiniBatch. Does it apply here at all? What's different or missing compared to your vision (I admit, couldn't follow the discussion and idea doc all that well 😅).

https://github.com/microsoft/infinibatch

InfiniBatch looks interesting, but it doesn't fit our criteria of full backward compatibility and support for map-style datasets.

LuckerYi pushed a commit to LuckerYi/Fastspeech that referenced this issue Apr 14, 2021
# Add new data pipeline

* Add writers for zip/tar/chunk based shard format to be compatible with webdataset and TorchSpeech
* Abstract shard writer to support sharding when combined with zip/tar/chunk writers
* Make **FeatureInfo** clearer so it behaves as a feature encoder/decoder before writing to or decoding from shards.
* Redesign data pipeline to be compatible with PyTorch's latest RFC (pytorch/pytorch#49440)
* Abstract distributed shard sampling to **SamplerIterDataPipe** with **DistributedPartitionSampler**.

Related work items: #3234727