Labels: bug, data handling, loops, ver: 2.1.x
Bug description
This bug has reappeared (see #18414).
We now call iter() twice in different places (a standalone illustration follows the links below):
- https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/fit_loop.py#L263
- https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py#L171C1-L172C1
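For context, here is a minimal standalone sketch (plain Python, no Lightning involved; the class name just mirrors the repro below) of why two iter() calls on a queue-backed dataset are a problem: each call pulls fresh items from the same shared queue, so the second iterator silently skips whatever the first one already consumed.

from queue import Queue


class QueueBackedDataset:
    def __init__(self, queue: Queue) -> None:
        self.queue = queue

    def __iter__(self):
        # Each new iterator drains items from the same shared queue.
        for _ in range(3):
            yield self.queue.get(timeout=1)


q = Queue()
for i in range(10):
    q.put(i)

ds = QueueBackedDataset(q)
print(list(iter(ds)))  # [0, 1, 2]
print(list(iter(ds)))  # [3, 4, 5] -- the second iter() does not restart at 0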
What version are you seeing the problem on?
v2.1
How to reproduce the bug
import multiprocessing as mp
from queue import Queue
from typing import Iterator
import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset
class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):
        q.put((arr, ind))

    max_epoch = 1
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader)
    trainer.save_checkpoint("model.ckpt")

    # q now holds the next 5 elements
    # when resuming training we will hit the double iter() issue
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch + 1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader, ckpt_path="model.ckpt")
Error messages and logs
The relevant logs are:
# first epoch all good
getting 0
got 0
getting 1
got 1
getting 2
got 2
getting 3
got 3
getting 4
got 4
# second epoch we start getting from the queue twice!
# from fit loop iter()
getting 0
got 5
getting 1
got 6
getting 2
got 7
# from training_epoch loop iter()
getting 0
got 8
getting 1
got 9
getting 2
Environment
lightning==2.1.4
More info
No response
Activity
awaelchli commented on Feb 11, 2024
This condition here is meant to prevent the iter() from getting called a second time, because in this case restarting should be True:
pytorch-lightning/src/lightning/pytorch/loops/training_epoch_loop.py, lines 169 to 171 in 47c8f4c
But it isn't. The problem is that the fit loop sets restarting=False even though we are resuming, due to the logic here:
pytorch-lightning/src/lightning/pytorch/loops/fit_loop.py, lines 123 to 128 in 47c8f4c
This is tricky to solve @carmocca. The logic probably needs to be lifted up into the fit loop before epoch_loop.run(), with a different condition that does not rely on restarting.
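A toy sketch of that restructuring idea (all names hypothetical; this is not the actual Lightning loop code): the outer fit-style loop creates the iterator exactly once per epoch and the inner epoch-style loop only consumes it, so there is no second call site for iter().

# Toy sketch only -- hypothetical names, not the real FitLoop/TrainingEpochLoop.
class ToyEpochLoop:
    def run(self, data_iter) -> None:
        # The epoch loop only consumes the iterator it is handed; it never
        # creates one itself, so there is no second place that calls iter().
        for batch in data_iter:
            pass  # the training step would run here


class ToyFitLoop:
    def __init__(self, dataloader) -> None:
        self.dataloader = dataloader
        self.epoch_loop = ToyEpochLoop()

    def run(self, num_epochs: int) -> None:
        for _epoch in range(num_epochs):
            # iter() is called exactly once per epoch, decided here; on a
            # restart, this would be the single place that chooses whether
            # to reuse or re-create the iterator.
            data_iter = iter(self.dataloader)
            self.epoch_loop.run(data_iter)


ToyFitLoop([1, 2, 3]).run(num_epochs=2)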
carmocca commented on Feb 12, 2024
I didn't look too deeply. Couldn't we check restarting too for the FitLoop's iter call? We have a lot of tests around this, so if a solution passes them we should be good.
ben-da6 commented on Feb 14, 2024
The problem in the restarting property is that self._iteration_based_training() is False.
ben-da6 commented on Feb 14, 2024
Also, since this has appeared twice now and it's the sort of bug that is hard to track down, could we add a test like my example?
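For illustration, a rough sketch of what such a regression test could look like (hypothetical; this is not the test proposed later in #20705, and the exact expected iter() count may need adjusting):

import torch
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset


class CountingIterableDataset(IterableDataset):
    """Counts how many times __iter__ is called so the test can assert on it."""

    def __init__(self) -> None:
        super().__init__()
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        for _ in range(5):
            yield torch.rand(1, 32)


def test_resume_creates_iterator_once(tmp_path):
    dataset = CountingIterableDataset()
    dataloader = DataLoader(dataset, batch_size=None)  # num_workers=0 so the count stays in-process

    trainer = Trainer(max_epochs=1, enable_progress_bar=False, devices=1, default_root_dir=tmp_path)
    trainer.fit(BoringModel(), dataloader)
    ckpt_path = tmp_path / "model.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Resume for one more epoch; with the bug present, the resumed epoch
    # creates the iterator twice and pulls data it should not.
    dataset.iter_calls = 0
    trainer = Trainer(max_epochs=2, enable_progress_bar=False, devices=1, default_root_dir=tmp_path)
    trainer.fit(BoringModel(), dataloader, ckpt_path=ckpt_path)

    # The exact expectation may need tuning; the point is to fail whenever
    # the resumed epoch calls iter() on the dataset more than once.
    assert dataset.iter_calls == 1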
iyilmaz24 commented on Apr 9, 2025
Hey all, I went ahead and created a test for this in #20705. Would love to hear any feedback :)
suprjinx commented on Apr 25, 2025
I've been looking at this issue a bit, and I could use some help understanding what's needed. I think "calling iter twice" isn't necessarily a problem, since any number of functions call it and the underlying data from the queue looks correct. The actual issues seem to be: 1) recognizing the empty queue as the end of the epoch, and 2) aligning the epochs with the queue data.
That is, I would say "this actually looks okay, but we should end processing smoothly when the queue is exhausted." But I could be wrong?
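To illustrate the "end processing smoothly" point, here is a sketch of a variant of the repro's dataset (hypothetical, and not a fix for the double-iter() bug itself) that ends the epoch cleanly once the queue is drained instead of raising on timeout:

from queue import Empty, Queue
from typing import Iterator

from torch.utils.data import IterableDataset


class DrainingQueueDataset(IterableDataset):
    """Like the repro's QueueDataset, but stops when the queue is exhausted."""

    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        while True:
            try:
                tensor, _index = self.queue.get(timeout=1)
            except Empty:
                return  # treat an empty queue as the end of the epoch
            yield tensor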
sudiptob2 commented on May 2, 2025
Hey all, I proposed a fix for this issue in #20775. Would appreciate it if anyone could take a look.
sudiptob2 commented on Jun 6, 2025
@carmocca @awaelchli Hey, I was wondering if you could take a look at this solution? #20775