
calling iter twice messes up dataloaders with queues  #19427

@ben-da6

Bug description

This bug has reappeared; see #18414.

We now call iter() twice in different places: once in the fit loop and again in the training epoch loop.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):
        q.put((arr, ind))
    max_epoch = 1
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader)
    trainer.save_checkpoint("model.ckpt")

    # q now has the next 5 elems in
    # resuming training we will hit the double iter() issue
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch + 1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader, ckpt_path="model.ckpt")
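
For context, the double drain can be shown without the Trainer at all. A minimal sketch, reusing the imports and QueueDataset from the repro above: every iter() call on the DataLoader re-runs __iter__ in the persistent worker, so each call keeps pulling from the same shared queue.

q = mp.Queue()
for ind in range(10):
    q.put((np.zeros([1, 32], dtype=np.float32), ind))

loader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
list(loader)  # first pass calls iter(): the worker prints "got 0" .. "got 4"
list(loader)  # second pass calls iter() again: the same queue continues with "got 5" .. "got 9"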

Error messages and logs

relevant logs are:

# first epoch all good
getting 0
got 0
getting 1
got 1
getting 2
got 2
getting 3
got 3
getting 4
got 4

# second epoch we start getting from the queue twice!
# from fit loop iter()
getting 0
got 5
getting 1
got 6
getting 2
got 7
# from training_epoch loop iter()
getting 0
got 8
getting 1
got 9
getting 2

Environment

lightning==2.1.4

More info

No response

cc @justusschock @awaelchli @carmocca

Activity

awaelchli (Contributor) commented on Feb 11, 2024

This condition is meant to prevent iter() from getting called a second time, because in this case restarting should be True.

# `iter()` was called once in `FitLoop.setup_data()` already
if self.trainer.current_epoch > 0 and not self.restarting:
    iter(data_fetcher)  # creates the iterator inside the fetcher

But it isn't. The problem is that the fit loop sets restarting=False even though we are resuming, due to the logic here:

def restarting(self, restarting: bool) -> None:
    # if the last epoch completely finished, we are not actually restarting
    values = self.epoch_progress.current.ready, self.epoch_progress.current.started
    epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
    restarting = restarting and epoch_unfinished or self._iteration_based_training()
    _Loop.restarting.fset(self, restarting)  # call the parent setter
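
To spell out why that setter misfires on a clean epoch-boundary resume, here is a minimal sketch of how the expression evaluates (the counter values are assumptions matching the repro, not taken from a real run):

# After a fully finished epoch, all three progress counters agree.
ready, started, processed = 1, 1, 1
epoch_unfinished = any(v != processed for v in (ready, started))  # False
iteration_based_training = False  # trained with max_epochs, not max_steps

restarting = True  # we really are resuming from a checkpoint
restarting = restarting and epoch_unfinished or iteration_based_training
print(restarting)  # False -> the guard around the second iter() never fires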

This is tricky to solve @carmocca. The logic probably needs to be lifted up into the fit loop before epoch_loop.run(), with a different condition that does not rely on restarting.

carmocca (Contributor) commented on Feb 12, 2024

I didn't look too deeply. Couldn't we check restarting too for the FitLoop's iter() call? We have a lot of tests around this, so if a solution passes them we should be good.
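
A rough sketch of that suggestion, mirroring the snippet quoted above (hypothetical placement and names, not the shipped code):

# Hypothetical guard in FitLoop.setup_data(): skip the eager iter() when
# resuming, leaving iterator creation to the epoch loop's existing check.
if not self.restarting:
    iter(data_fetcher)

As the next comment points out, though, this still hinges on restarting being set correctly in the first place.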

ben-da6 (Author) commented on Feb 14, 2024

The problem in the restarting property is that self._iteration_based_training() is False.

ben-da6 (Author) commented on Feb 14, 2024

Also, since this has appeared twice now and it's the sort of bug that is hard to track down, could we add a test like my example?

iyilmaz24 commented on Apr 9, 2025

> Also, since this has appeared twice now and it's the sort of bug that is hard to track down, could we add a test like my example?

Hey all, I went ahead and created the test for this in #20705. Would love to hear any feedback :)
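
For reference, a rough outline of what such a regression test could assert (a hypothetical shape, not the contents of #20705; it reuses QueueDataset from the repro above):

import multiprocessing as mp

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader


def test_resume_consumes_queue_once(tmp_path):
    q = mp.Queue()
    for ind in range(10):
        q.put((np.zeros([1, 32], dtype=np.float32), ind))

    ckpt = str(tmp_path / "model.ckpt")
    trainer = Trainer(max_epochs=1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True))
    trainer.save_checkpoint(ckpt)

    # With the bug, the resumed epoch double-drains the queue and times out.
    trainer = Trainer(max_epochs=2, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True), ckpt_path=ckpt)
    assert q.empty()  # only the 5 remaining items should have been consumed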

suprjinx commented on Apr 25, 2025

I've been looking at this issue a bit, and I could use some help understanding what's needed. I think "calling iter twice" isn't necessarily a problem, since any number of functions call it and the underlying data from the queue looks correct. The issues seem to be: 1) recognizing the empty queue as the end of the epoch, and 2) aligning the epochs with the queue data.

I.e., I would say "this actually looks okay, but we should end processing smoothly when the queue is exhausted." But I could be wrong?
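
To illustrate the "end smoothly" idea, a sketch of a variant of the repro's dataset that treats an exhausted queue as the end of the epoch instead of raising (an illustration of the suggestion, not a proposed Lightning change):

from queue import Empty, Queue
from typing import Iterator

from torch.utils.data import IterableDataset


class ExhaustibleQueueDataset(IterableDataset):
    """Like QueueDataset, but an empty queue ends the epoch cleanly."""

    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        while True:
            try:
                tensor, _ = self.queue.get(timeout=10)
            except Empty:
                return  # queue exhausted: stop the epoch instead of raising
            yield tensor

This would avoid the timeout crash, but the double iter() would presumably still make the resumed epoch see the wrong items.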

sudiptob2 (Contributor) commented on May 2, 2025

Hey all, I proposed a fix for this issue in #20775. Would appreciate it if anyone could take a look.

sudiptob2 (Contributor) commented on Jun 6, 2025

@carmocca @awaelchli Hey, I was wondering if you could take a look at this solution? #20775
