# RepeatDataLoader examples

This notebook shows examples where iterating over a DataLoader spawns new worker processes every epoch. It also shows examples that avoid this.

See more information in the SkiNet documentation under Dataloaders.

# DataLoader - new workers each epoch

- The dataset is initialised in the main process.

- From the main process, `SillyDataset.__init__` prints a greeting together with its PID, e.g. `SillyDataset __init__ says hello! PID is [98805]` (the PID differs between runs).

- Then we create a DataLoader with `num_workers=2`. Under the hood, the DataLoader uses Python's `multiprocessing` library to run its workers.

- Once we start iterating over the loader, **PyTorch spawns separate subprocesses (with their own PIDs)** to load the data.

- At the beginning of each epoch we run `for batch in loader:`, and under the hood this creates a new iterator:

```
iterator = iter(loader)  # calls loader.__iter__()
batch = next(iterator)   # repeated by the loop until StopIteration
```
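The desugaring of the `for` loop is plain Python and can be checked without PyTorch. The sketch below uses a small hypothetical `CountingIterable` class (not part of PyTorch or SkiNet) that counts how often `iter()` is called on it: each epoch's `for` statement triggers exactly one call, which is the point where a DataLoader builds its iterator and, with `num_workers > 0`, its workers.

```python
class CountingIterable:
    """Hypothetical iterable that records how often iter() is called on it."""

    def __init__(self, items):
        self.items = list(items)
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return iter(self.items)


data = CountingIterable([10, 20, 30])

for epoch in range(3):
    for batch in data:  # each `for` statement calls iter(data) exactly once
        pass

print(data.iter_calls)  # → 3

# What a single `for batch in data:` loop does, spelled out by hand:
iterator = iter(data)
seen = []
while True:
    try:
        batch = next(iterator)
    except StopIteration:  # raised once the iterator is exhausted
        break
    seen.append(batch)

print(seen, data.iter_calls)  # → [10, 20, 30] 4
```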


- **PyTorch creates the worker processes when `__iter__()` is called.** Every `for batch in loader` calls `__iter__()`, so the workers are re-created at the start of every epoch unless `persistent_workers=True`, as the code below shows.
- Here is the `__iter__` method of the `DataLoader` class (from `torch/utils/data/dataloader.py`):

```
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> "_BaseDataLoaderIter":
    # When using a single worker the returned iterator should be
    # created everytime to avoid resetting its state
    # However, in the case of a multiple workers iterator
    # the iterator is only created once in the lifetime of the
    # DataLoader object so that workers can be reused
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()
```

where the iterator is obtained like this:

```
def _get_iterator(self) -> "_BaseDataLoaderIter":
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
```
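To make the branching in `__iter__` concrete, here is a minimal toy model. `ToyLoader`, `ToyWorkerIter` and `count_worker_starts` are hypothetical stand-ins, not PyTorch classes: they copy the `persistent_workers` logic quoted above but only count simulated worker starts instead of creating processes, and they do not distinguish the single-process iterator used when `num_workers == 0`.

```python
class ToyWorkerIter:
    """Stand-in for _MultiProcessingDataLoaderIter; no real processes are started."""

    def __init__(self, loader):
        # The real __init__ starts one Process per worker; here we only count them
        loader.workers_started += loader.num_workers
        self._reset(loader)

    def _reset(self, loader):
        # Rewind to the first batch while keeping the (simulated) workers alive
        self._batches = iter(range(loader.num_batches))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._batches)


class ToyLoader:
    """Mimics the branching of DataLoader.__iter__ quoted above."""

    def __init__(self, num_batches, num_workers, persistent_workers):
        self.num_batches = num_batches
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self.workers_started = 0
        self._iterator = None

    def _get_iterator(self):
        return ToyWorkerIter(self)

    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        return self._get_iterator()


def count_worker_starts(persistent, epochs=5, num_workers=2):
    loader = ToyLoader(num_batches=5, num_workers=num_workers, persistent_workers=persistent)
    for _ in range(epochs):
        for _batch in loader:
            pass
    return loader.workers_started


print(count_worker_starts(persistent=False))  # → 10 (2 workers x 5 epochs)
print(count_worker_starts(persistent=True))   # → 2 (started once, then reset)
```

Without persistent workers every epoch pays the cost of starting `num_workers` processes again; with them, the same iterator is `_reset` and reused.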

- The workers are created in the `__init__` method of the `_MultiProcessingDataLoaderIter` class, which starts one `multiprocessing_context.Process` per worker:
```
for i in range(self._num_workers):
    # No certainty which module multiprocessing_context is
    index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
    # Need to `cancel_join_thread` here!
    index_queue.cancel_join_thread()
    w = multiprocessing_context.Process(
        target=_utils.worker._worker_loop,
        args=(
            self._dataset_kind,
            self._dataset,
            index_queue,
            self._worker_result_queue,
            self._workers_done_event,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
            self._base_seed,
            self._worker_init_fn,
            i,
            self._num_workers,
            self._persistent_workers,
            self._shared_seed,
        ),
    )
    w.daemon = True
    w.start()
    self._index_queues.append(index_queue)
    self._workers.append(w)
```



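- **The workers are shut down once the sampler is exhausted** (unless `persistent_workers=True`), as can be seen in the `_MultiProcessingDataLoaderIter._next_data` method. So at the end of each epoch PyTorch shuts down the old worker processes, and at the beginning of the next epoch it spawns new ones, which call the Dataset's `__getitem__()` in parallel.

- Each worker receives lists of batch indices through its own `index_queue` (filled by the main process) and sends the loaded batches back through the shared `_worker_result_queue`.

- Each worker has its own copy of the Dataset. With the `spawn` or `forkserver` start methods it is pickled in the main process and unpickled in the worker; with `fork` the worker inherits a copy of the main process's memory. Either way, `__init__` is not called again in the workers.

- As a result, the worker PIDs change every epoch.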
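The mechanics described above (one index queue per worker, a shared result queue, and shutdown once all indices have been handed out) can be sketched with the standard `multiprocessing` module alone. This is a simplified illustration, not PyTorch's `_worker_loop`: `SquaresDataset`, `worker_loop` and `run_epoch` are hypothetical names, there is no prefetching, pinning or error handling, and it assumes a POSIX system where the `fork` start method is available.

```python
import multiprocessing
import os

# Assumption: a POSIX system with the "fork" start method (e.g. Linux). With fork the
# workers inherit the dataset from the parent instead of receiving a pickled copy.
CTX = multiprocessing.get_context("fork")


class SquaresDataset:
    """Hypothetical map-style dataset: item i is i * i."""

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        return idx * idx


def worker_loop(dataset, index_queue, result_queue):
    """Simplified stand-in for PyTorch's worker loop (no prefetching or error handling)."""
    while True:
        task = index_queue.get()
        if task is None:  # sentinel from the main process: stop this worker
            break
        batch_id, indices = task
        batch = [dataset[i] for i in indices]
        result_queue.put((batch_id, batch, os.getpid()))


def run_epoch(dataset, batch_size=4, num_workers=2):
    """Start workers, load one epoch, shut the workers down; return (batches, worker PIDs)."""
    index_queues = [CTX.Queue() for _ in range(num_workers)]  # one queue per worker
    result_queue = CTX.Queue()  # shared by all workers
    workers = [
        CTX.Process(target=worker_loop, args=(dataset, q, result_queue), daemon=True)
        for q in index_queues
    ]
    for w in workers:
        w.start()

    indices = list(range(len(dataset)))
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    for batch_id, batch_indices in enumerate(batches):
        # Round-robin over the workers' index queues
        index_queues[batch_id % num_workers].put((batch_id, batch_indices))

    results, pids = {}, set()
    for _ in batches:  # results may arrive out of order
        batch_id, batch, pid = result_queue.get()
        results[batch_id] = batch
        pids.add(pid)

    # Every index has been handed out and collected: send one sentinel per worker and wait
    for q in index_queues:
        q.put(None)
    for w in workers:
        w.join()
    for q in index_queues:
        q.close()
        q.join_thread()

    return [results[i] for i in range(len(batches))], pids


if __name__ == "__main__":
    for epoch in range(2):
        batches, pids = run_epoch(SquaresDataset())
        print(f"epoch {epoch}: worker PIDs {sorted(pids)}, first batch {batches[0]}")
```

Each call starts fresh processes, so calling `run_epoch` once per epoch yields new worker PIDs every time, just like a DataLoader with `persistent_workers=False`.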


In [1]:
```
import torch
from torch.utils.data import Dataset, DataLoader
import os, time

class SillyDataset(Dataset):
    def __init__(self):
        self.pid = os.getpid()
        print(f"This method is being called in the main process! SillyDataset __init__ says hello! PID is [{self.pid}] ")

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # This print happens inside worker processes when fetching data
        print(f"Fetching index {idx} in __getitem__ of SillyDataset PID is [{os.getpid()}] ")
        return idx

def run_default_dataloader():
    dataset = SillyDataset()
    print("data set initialisation completed")
    loader = DataLoader(dataset, batch_size=4, num_workers=2, shuffle=True)
    print("data loader initialisation completed")

    for epoch in range(5):
        print(f"\n🌍 Epoch {epoch}")
        
        for batch in loader:
            print(f"🔥 Main process [{os.getpid()}] got batch: {batch}")
            time.sleep(0.1)
            #break

run_default_dataloader()
```


```
This method is being called in the main process! SillyDataset __init__ says hello! PID is [33883] 
data set initialisation completed
data loader initialisation completed

🌍 Epoch 0
Fetching index 10 in __getitem__ of SillyDataset PID is [34151] Fetching index 14 in __getitem__ of SillyDataset PID is [34150] 

Fetching index 3 in __getitem__ of SillyDataset PID is [34151] Fetching index 6 in __getitem__ of SillyDataset PID is [34150] 

Fetching index 7 in __getitem__ of SillyDataset PID is [34151] Fetching index 12 in __getitem__ of SillyDataset PID is [34150] 

Fetching index 0 in __getitem__ of SillyDataset PID is [34150] Fetching index 16 in __getitem__ of SillyDataset PID is [34151] 

Fetching index 11 in __getitem__ of SillyDataset PID is [34150] Fetching index 13 in __getitem__ of SillyDataset PID is [34151] 

Fetching index 2 in __getitem__ of SillyDataset PID is [34151] Fetching index 17 in __getitem__ of SillyDataset PID is [34150] 

Fetching index 1 in __getitem__ of SillyData
```

# DataLoader with persistent workers

- The same worker PIDs are reused in every epoch.
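A related constraint: persistent workers only make sense when there are worker processes to keep alive. In the PyTorch versions we are aware of, `DataLoader` rejects `persistent_workers=True` together with `num_workers=0` by raising a `ValueError`; the small check below relies only on that behaviour (the exact error message may vary between versions).

```python
from torch.utils.data import DataLoader

dataset = list(range(20))  # a plain list works as a map-style dataset here

try:
    DataLoader(dataset, batch_size=4, num_workers=0, persistent_workers=True)
except ValueError as err:
    print("rejected:", err)

# Accepted: there are worker processes to keep alive. Creating the loader
# does not start them yet; that happens on the first iteration.
loader = DataLoader(dataset, batch_size=4, num_workers=2, persistent_workers=True)
print(loader.persistent_workers)  # → True
```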

In [2]:
```
import torch
from torch.utils.data import Dataset, DataLoader
import os, time

class SillyDataset(Dataset):
    def __init__(self):
        self.pid = os.getpid()
        print(f"SillyDataset __init__ says hello! PID is [{self.pid}] ")

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # This print happens inside worker processes when fetching data
        print(f"Fetching index {idx} in __getitem__ of SillyDataset PID is [{os.getpid()}] ")
        return idx

def run_dataloader_persistent_workers():
    dataset = SillyDataset()
    
    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=2,
        persistent_workers=True,
        shuffle=True
    )

    for l in loader:
        print("all batches from loader:", l)

    for epoch in range(7):
        print(f"\n🌍 Epoch {epoch}")
        for batch in loader:
            print(f"🔥 Main process [{os.getpid()}] got batch: {batch}")
            time.sleep(0.1)
            #break

run_dataloader_persistent_workers()
```



```
SillyDataset __init__ says hello! PID is [33883] 
Fetching index 17 in __getitem__ of SillyDataset PID is [34853] 
Fetching index 19 in __getitem__ of SillyDataset PID is [34854] 
Fetching index 11 in __getitem__ of SillyDataset PID is [34853] Fetching index 10 in __getitem__ of SillyDataset PID is [34854] 

Fetching index 13 in __getitem__ of SillyDataset PID is [34853] 
Fetching index 7 in __getitem__ of SillyDataset PID is [34854] Fetching index 15 in __getitem__ of SillyDataset PID is [34853] 

Fetching index 4 in __getitem__ of SillyDataset PID is [34854] 
Fetching index 12 in __getitem__ of SillyDataset PID is [34853] Fetching index 8 in __getitem__ of SillyDataset PID is [34854] 

Fetching index 6 in __getitem__ of SillyDataset PID is [34853] Fetching index 18 in __getitem__ of SillyDataset PID is [34854] 

Fetching index 3 in __getitem__ of SillyDataset PID is [34853] 
Fetching index 2 in __getitem__ of SillyDataset PID is [34854] Fetching index 5 in __getitem__ of SillyDataset
```

In [3]:
```
#remove_cell
import sys

#sourcepath = '/Users/Pavel/Documents/repos/SkiNet/'
sourcepath = '/workplace/SkiNet/'

sys.path.insert(0,sourcepath)
#sys.path.insert(0,datapath)

#automatically track changes in the source code
%load_ext autoreload
%autoreload 2


%matplotlib inline
```

In [4]:
```
from SkiNet.ML.dataloaders.dataloaders import RepeatDataLoader
```

In [5]:
```
import torch
from torch.utils.data import Dataset, DataLoader
import os, time

# set up logging
from SkiNet.Utils.loggers import stdout_logging, file_logging
import logging

stdout_logging(logging.DEBUG)
file_logging()

class SillyDataset(Dataset):
    def __init__(self):
        self.pid = os.getpid()
        print(f"SillyDataset __init__ says hello! PID is [{self.pid}] ")

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # This print happens inside worker processes when fetching data
        print(f"Fetching index {idx} in __getitem__ of SillyDataset PID is [{os.getpid()}] ")
        return idx

def run_RepeatDataLoader():
    dataset = SillyDataset()
    
    loader = RepeatDataLoader(
        dataset,
        batch_size=4,
        num_workers=2,
        persistent_workers=False,
        max_num_to_repeat=10,
        shuffle=True,
        drop_last=False
    )

    for l in loader:
        print("all batches from loader:", l)

    
    for epoch in range(7):
        print(f"\n🌍 Epoch {epoch}")
        for batch in loader:
            print(f"🔥 Main process [{os.getpid()}] got batch: {batch}")
            time.sleep(0.5)
            #break

run_RepeatDataLoader()
```

```
Setting up stdout_logging
Setting logging level to 10 for stdout logging
Setting logging level to 10 to log into /workplace/SkiNet/skinet_logs.log
SillyDataset __init__ says hello! PID is [33883] 
2025-06-10 16:11:22 - DEBUG [__init__ (85) in SkiNet.ML.dataloaders.dataloaders] - BatchSampler repeatdl_batch_sampler length 5
2025-06-10 16:11:22 - DEBUG [__init__ (86) in SkiNet.ML.dataloaders.dataloaders] - BatchSampler repeatdl_batch_sampler <torch.utils.data.sampler.BatchSampler object at 0x7fe184d91d50>
2025-06-10 16:11:22 - DEBUG [__init__ (90) in SkiNet.ML.dataloaders.dataloaders] - BatchSampler repeat_sampler length 2
2025-06-10 16:11:22 - DEBUG [__init__ (91) in SkiNet.ML.dataloaders.dataloaders] - BatchSampler repeat_sampler <SkiNet.ML.dataloaders.dataloaders._RepeatSampler object at 0x7fe1854f5990>
2025-06-10 16:11:22 - DEBUG [__init__ (99) in SkiNet.ML.dataloaders.dataloaders] - Iterator in RepeatDataloder.__init__ None
2025-06-1
```

# RepeatDataLoader
- The same worker PIDs are reused in every epoch.
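The source of SkiNet's `RepeatDataLoader` is not shown in this notebook, although its log output mentions a `_RepeatSampler`. A common way to keep the same workers across epochs, used for example by the `InfiniteDataLoader` in the YOLOv5 repository, is to wrap the batch sampler in a sampler that never runs out, create a single iterator in `__init__`, and define one epoch as `len(batch_sampler)` batches. The sketch below follows that pattern under those assumptions. `RepeatingDataLoaderSketch` and `_RepeatSamplerSketch` are hypothetical names, this is not SkiNet's implementation, and it does not model `max_num_to_repeat`.

```python
from torch.utils.data import DataLoader


class _RepeatSamplerSketch:
    """Hypothetical helper: yields the wrapped batch sampler's batches forever."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)  # start a new pass when one is exhausted


class RepeatingDataLoaderSketch(DataLoader):
    """Hypothetical sketch, not SkiNet's RepeatDataLoader.

    The batch sampler never runs out, so the single underlying iterator (and,
    with num_workers > 0, its worker processes) is created once and reused.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader forbids assigning batch_sampler after __init__, hence object.__setattr__
        object.__setattr__(self, "batch_sampler", _RepeatSamplerSketch(self.batch_sampler))
        self.iterator = super().__iter__()  # created once for the loader's lifetime

    def __len__(self):
        # Number of batches in one pass of the wrapped (finite) batch sampler
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # One "epoch" = len(self) batches pulled from the never-ending iterator
        for _ in range(len(self)):
            yield next(self.iterator)


if __name__ == "__main__":
    loader = RepeatingDataLoaderSketch(list(range(20)), batch_size=4, shuffle=True)
    for epoch in range(3):
        seen = sorted(int(x) for batch in loader for x in batch)
        print(epoch, len(loader), seen == list(range(20)))
```

Because the underlying iterator is never exhausted, `_MultiProcessingDataLoaderIter._next_data` never reaches its shutdown branch, so with `num_workers > 0` the same worker PIDs serve every epoch. One side effect of this design: if an epoch loop `break`s early, the next epoch continues from wherever the previous one stopped.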