# Infinite dataloader
This demonstrates an inifinte dataloader in pytorch, that loads data from a custom dataset.
Th custom dataset is just for demo purpose. The torchvision.datasets.ImageFolder can be instead 
used to create the same.

The folder structure is:

/path/to/data/

|---train/\<class_folders\>/\<image_file\>

|---test/\<class_folders\>/\<image_file\>

|---val/\<class_folders\>/\<image_file\>
                

In [1]:
import glob
import math
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
from PIL import Image

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

In [2]:
# setups the dataset paths
dataset_path = Path('../datasets/intel_image_classification/')

# Setup train and test paths
train_dir = dataset_path / "seg_train/seg_train"
test_dir = dataset_path / "seg_test/seg_test"

print('train path {} \ntest path {}'.format(train_dir, test_dir))

os.path.isdir(train_dir), os.path.isdir(test_dir)

train path ../datasets/intel_image_classification/seg_train/seg_train 
test path ../datasets/intel_image_classification/seg_test/seg_test


(True, True)

## Stanford format for preparing Dataset Class

Here the stanford format for preparing custom dataset. A function is used to prepare two dictionaries segregating train-test set and mapping sample ids to labels

\>>> partition

{'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}

\>>> labels

{'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}

In [3]:
def dataset_preprocess(PATH) -> list:
    """
    Given the following folder structure:
    /path/to/data/
    |---train/\<class_folders\>/\<image_file\>
    |---test/\<class_folders\>/\<image_file\>
    |---val/\<class_folders\>/\<image_file\>
    
    the function returns three dicts:
    
    1) partition: contains two dicts 'train' and 'validation' containing respective sample ids
    2) maping of ids to labels
    3) class name to ids
    
    Args:
        param1 (str): 'path/to/data'
    Returns:
        list: [classes_to_idx, partition, labels, list of dicts]    
    """
    
    labels = {}
    classes_to_idx = {}
    partition = {'train': [], 'test': []}
    
    # print('PATH', PATH)
    for file in glob.glob(os.path.join(PATH, '*/*/*/*.jpg')):
        
        # print('file: ', file)
        
        class_name = file.split('/')[-2]
        # print('class_name', class_name)
        if not class_name in classes_to_idx.keys():
            classes_to_idx[class_name] = len(classes_to_idx.keys())
        # print('classes_to_idx', classes_to_idx)
        
        ID = file
        labels[ID] = classes_to_idx[class_name]

        if 'seg_train' in file:
            partition['train'].append(ID)
        elif 'seg_test' in file:
            partition['test'].append(ID)
                
        # print('partition', partition)
        # print('labels', labels)
        # break
        
    # print('classes_to_idx', classes_to_idx)
    # print('partition train test len', len(partition['train']), len(partition['test']))
    # print('labels', len(labels))
    
    return classes_to_idx, partition, labels
    
            
_, _, _ = dataset_preprocess(dataset_path)

In [4]:
# Extention of standard pytorch Dataset class. Returns X,y: images, labels after tranforms
class MyDataset(Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, list_IDs, labels, classes_to_idx):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs
        self.classes_to_idx = classes_to_idx
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))])
    
    def find_classes(self) -> dict:
        'returns class names and indexes'
        return self.classes_to_idx
            
    @staticmethod
    def pil_loader(path: str) -> Image.Image:
        'loads an image file'
        with open(path, "rb") as f:
            img = Image.open(f)
            return img.convert("RGB")

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)
        
    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        # Load data and get label
        # print('ID', ID)
        X = self.pil_loader(ID)
        X = self.transform(X)
        y = self.labels[ID]

        return X, y

In [5]:
# test the standard pytroch dataloader which is expected to throw StopIteration exception
classes_to_idx, partition, labels = dataset_preprocess(dataset_path)
test_dataset = MyDataset(partition['test'], labels, classes_to_idx)

print('class names', test_dataset.classes_to_idx)
print('dataset length', len(test_dataset))

BATCH_SIZE = 8
train_dataloader = DataLoader(dataset=test_dataset,
                              batch_size=BATCH_SIZE,
                              num_workers=1,
                              shuffle=True)

print('expected iterations', math.ceil(len(test_dataset)/BATCH_SIZE))

iter_obj = iter(train_dataloader)
idx = 0
while True:
    print(next(iter_obj)[0].shape, idx, '\r', end='')
    idx+=1

class names {'sea': 0, 'mountain': 1, 'buildings': 2, 'glacier': 3, 'street': 4, 'forest': 5}
dataset length 3000
expected iterations 375
torch.Size([8, 3, 224, 224]) 374 

StopIteration: 

## Infinite Dataloader
Create dataloader that runs infinitely by capturing the stop iteration error and intitating a fresh iterator

In [6]:
# extends pytorch dataloader with infinite iteration
class InfiniteDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize an iterator over the dataset.
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        'Expects StopIteration and initiates new iterator'
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:           
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch

In [7]:
# test the infinite dataloader
classes_to_idx, partition, labels = dataset_preprocess(dataset_path)
test_dataset = MyDataset(partition['test'], labels, classes_to_idx)

print('dataset length', len(test_dataset))

BATCH_SIZE = 8
train_dataloader = InfiniteDataLoader(dataset=test_dataset,
                              batch_size=BATCH_SIZE,
                              num_workers=1,
                              shuffle=True)
print('expected iterations from nomral loader', math.ceil(len(test_dataset)/BATCH_SIZE))

iter_obj = iter(train_dataloader)
idx = 0
while True:
    print(next(iter_obj)[0].shape, idx, '\r', end='')
    idx+=1

dataset length 3000
expected iterations from nomral loader 375
torch.Size([8, 3, 224, 224]) 1387 

KeyboardInterrupt: 