While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.
DataLoader is an iterable that abstracts this complexity for us in an easy API.

torch.utils.data.DataLoader is recommended for PyTorch users (a tutorial is available at https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html). It works with a map-style dataset, which implements the __getitem__() and __len__() protocols and represents a map from indices/keys to data samples. It also works with an iterable-style dataset, provided the shuffle argument is False.
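As a minimal sketch of the map-style protocol described above, the toy dataset below implements __len__() and __getitem__() and is wrapped in a DataLoader with batching and shuffling. The class name SquaresDataset and its contents are hypothetical and serve only as an illustration; they are not part of the tutorial's dataset.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: index i maps to the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; DataLoader uses this to build index batches.
        return self.n

    def __getitem__(self, idx):
        # Return one sample for the given integer index.
        x = torch.tensor(float(idx))
        return x, x ** 2


dataset = SquaresDataset(10)
# shuffle=True draws a new random order of indices at every epoch.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = []
seen = []
for xs, ys in loader:
    batch_shapes.append(tuple(xs.shape))
    seen.extend(xs.tolist())
```

With 10 samples and a batch size of 4, one epoch yields two full batches and a final batch of 2, because drop_last defaults to False.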

collate_fn receives a raw batch

Before a batch is sent to the model, collate_fn operates on the samples that DataLoader draws from the dataset. The input to collate_fn is a list of samples whose length is the batch size set in DataLoader, and collate_fn processes them according to the data processing pipelines declared previously. Make sure that collate_fn is declared as a top-level def, so that the function is available in each worker process.
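The following sketch shows one way such a collate_fn could look, assuming each sample is a (label, token ids) pair of varying length. The toy samples, the name collate_batch, and the choice of 0 as the padding id are assumptions made for illustration; the tutorial's own pipelines may differ.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical toy data: (label, token ids) pairs of different lengths.
# A plain list is a valid map-style dataset (it has __getitem__ and __len__).
samples = [
    (0, [1, 2, 3]),
    (1, [4, 5]),
    (0, [6]),
    (1, [7, 8, 9, 10]),
]


def collate_batch(batch):
    """Defined at module top level so worker processes can import it."""
    labels = torch.tensor([label for label, _ in batch], dtype=torch.long)
    sequences = [torch.tensor(tokens, dtype=torch.long) for _, tokens in batch]
    # Pad to the longest sequence in this batch; 0 is assumed to be the pad id.
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return labels, padded


loader = DataLoader(samples, batch_size=2, shuffle=False, collate_fn=collate_batch)
batches = list(loader)
```

Each call to collate_batch receives a list of batch_size raw samples and returns tensors, so padding is applied per batch rather than to the longest sequence in the whole dataset.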

The last torch-specific feature we’ll use is the DataLoader, which is easy to use since it takes the dataset as its first argument. As the docs say: DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
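For the iterable-style case mentioned above, a minimal sketch might look like the following. The class CountingStream is hypothetical; it stands in for any data source that is read sequentially rather than indexed.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class CountingStream(IterableDataset):
    """Hypothetical iterable-style dataset that yields integers in order."""

    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        # Samples are produced by iteration; there is no __getitem__ or __len__.
        return iter(range(self.start, self.end))


stream = CountingStream(0, 5)
# shuffle is left unset because the loading order is defined by the stream itself.
loader = DataLoader(stream, batch_size=2)
batches = [batch.tolist() for batch in loader]
```

Automatic batching still applies here: the loader groups consecutive items from the stream into batches of two, and the last batch holds the remaining item.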

Pay particular attention to the optional collate_fn argument, which merges a list of samples to form a mini-batch of Tensor(s). It is used with batched loading from a map-style dataset.

In [None]:
# Hold out part of the training data for validation; num_train is the training split size.
from torch.utils.data import random_split

split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)
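To illustrate how the two splits can then be fed to DataLoader, here is a self-contained sketch on a toy TensorDataset. The data, the 80/20 ratio, the batch size, and the seeded generator are all assumptions chosen for the example, not values from the tutorial.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical toy dataset: 20 feature rows with alternating 0/1 labels.
features = torch.arange(20).float().unsqueeze(1)
labels = torch.arange(20) % 2
full_dataset = TensorDataset(features, labels)

# Assumed 80/20 train/validation split.
num_train = int(len(full_dataset) * 0.8)
train_split, valid_split = random_split(
    full_dataset,
    [num_train, len(full_dataset) - num_train],
    generator=torch.Generator().manual_seed(0),  # seeded for reproducibility
)

# Shuffle only the training data; keep validation order fixed.
train_loader = DataLoader(train_split, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_split, batch_size=4, shuffle=False)
```

random_split returns Subset objects that index into the original dataset, so the two splits are disjoint and no data is copied.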