PyTorch DataLoader improvements for Iterable Dataset #127072
Labels
module: dataloader
Related to torch.utils.data.DataLoader and Sampler
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 The feature, motivation and pitch
Currently, the PyTorch DataLoader makes it hard to design an IterableDataset that functions correctly.
The main reason is that the IterableDataset needs to know the num_workers, batch_size, and shuffle arguments provided to the DataLoader in order to work properly. One could argue for exposing these on the IterableDataset itself, but this isn't intuitive for new users and leads to more issues: users would end up providing a batch_size to both the dataset and the DataLoader, with no guarantee the two values match.
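For context, the only pattern available today (per the PyTorch docs) is to re-derive worker information inside `__iter__` via `torch.utils.data.get_worker_info()`; `batch_size` and `shuffle` never reach the dataset at all. A minimal illustrative sketch (the `RangeDataset` name and sharding scheme are mine, not part of PyTorch):

```python
import math

import torch
from torch.utils.data import IterableDataset


class RangeDataset(IterableDataset):
    """Illustrative dataset using today's pattern: worker info is only
    discoverable inside __iter__, and batch_size/shuffle never reach it."""

    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: yield the full range.
            return iter(range(self.start, self.end))
        # Shard the range across workers; num_workers is re-derived here
        # instead of being handed down by the DataLoader.
        per_worker = math.ceil((self.end - self.start) / info.num_workers)
        lo = self.start + info.id * per_worker
        hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))
```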
I am thinking a simple solution would be a protocol for IterableDataset: if the dataset defines a `setup` method, the DataLoader would call it from its `__init__` to pass down the user-provided arguments (num_workers, batch_size, shuffle). Interestingly, this is already done for DataPipes: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py#L305, but only for shuffling.
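As a rough sketch of what this protocol could look like (all names here — `setup`, `ShardedIterableDataset`, `SetupAwareDataLoader` — are hypothetical; PyTorch has no such hook today):

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedIterableDataset(IterableDataset):
    """Hypothetical dataset opting into the proposed `setup` protocol."""

    def __init__(self, items):
        self.items = list(items)
        # Sensible defaults in case `setup` is never called.
        self.num_workers, self.batch_size, self.shuffle = 1, 1, False

    def setup(self, num_workers: int, batch_size: int, shuffle: bool) -> None:
        # Proposed hook: the DataLoader calls this once from __init__,
        # before any worker processes are spawned.
        self.num_workers = max(1, num_workers)
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        # Shard using the value handed down via `setup` rather than
        # re-deriving it inside the worker.
        yield from self.items[worker_id::self.num_workers]


class SetupAwareDataLoader(DataLoader):
    """Sketch of a DataLoader honoring the protocol."""

    def __init__(self, dataset, **kwargs):
        # The stock DataLoader rejects shuffle=True for an IterableDataset,
        # so it is consumed here and forwarded to the hook instead.
        shuffle = bool(kwargs.pop("shuffle", False))
        if hasattr(dataset, "setup"):
            dataset.setup(
                num_workers=kwargs.get("num_workers", 0),
                batch_size=kwargs.get("batch_size", 1),
                shuffle=shuffle,
            )
        super().__init__(dataset, **kwargs)
```

Since `setup` runs in the parent process, each worker's pickled copy of the dataset already carries the sharding parameters: `SetupAwareDataLoader(ShardedIterableDataset(range(10)), num_workers=2, batch_size=2)` would yield non-overlapping batches without the dataset ever guessing the worker count.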
In the framework LitData (https://github.com/Lightning-AI/litdata), I had to override the PyTorch DataLoader to support this properly: https://github.com/Lightning-AI/litdata/blob/main/src/litdata/streaming/dataloader.py#L479, so that the length of the iterable dataset doesn't cause issues with DDP.
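For background on the DDP remark: each rank iterates its own shard of an iterable dataset, and if shard lengths differ, some ranks run out of batches while others are still waiting inside collective ops, which hangs training. A minimal, hypothetical sketch of the kind of length reconciliation involved (not LitData's actual code):

```python
import torch.distributed as dist


def per_rank_length(total_items: int) -> int:
    """Hypothetical helper: give every rank the same item count so no rank
    yields extra batches and stalls the others inside DDP collectives."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # Floor division drops the remainder equally on every rank.
    return total_items // world_size
```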
Alternatives
Subclass the PyTorch DataLoader
Additional context
No response
cc @andrewkho @gokulavasan @ssnl @VitalyFedyunin @dzhulgakov