# How does reading from shelve files in a DataLoader compare to the alternatives?

One of the things that will slow us down in this competition is reading the .dicom files. One standard alternative is to save the X-ray images in an image format like jpg or png, but could it be faster to save all the data in a single `pkl`-type file that can be accessed by an index?

In [an EDA notebook](https://www.kaggle.com/bjoernholzhauer/eda-dicom-reading-vinbigdata-chest-x-ray) I created a prototype for putting the training data into a `shelve` file (like pickle, but it allows parallel reading and access via dictionary keys: a persistent dictionary, really; see [the documentation](https://docs.python.org/3/library/shelve.html)). I resized the images so that the shortest side is 600 pixels, i.e. the images are probably still larger than you'd use as a model input, but the resulting file is small enough for our purposes, primarily because the image itself is saved as uint8 (which hopefully does not lose meaningful information).
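For readers who haven't used `shelve` before, here is a minimal, self-contained sketch of writing one record and reading it back read-only. The record layout (`image`, `rad_id`, `bboxes`, `class_labels`) mirrors the keys used by the loaders below; the key `toy_image_id`, the file location and all values are made up for illustration.

```python
import os
import shelve
import tempfile

import numpy as np

# Hypothetical toy record with the same keys as the records used by the loaders below
record = dict(
    image=np.zeros((600, 720), dtype=np.uint8),                  # grayscale image, smallest side 600
    rad_id=np.array([8, 9], dtype=np.int8),                       # which radiologist gave each annotation
    bboxes=np.array([[10, 20, 110, 220, 3],
                     [12, 22, 108, 218, 3]], dtype=np.float32),   # pascal_voc box + class_id
    class_labels=np.array([3, 3], dtype=np.int8),
)

path = os.path.join(tempfile.mkdtemp(), "toy_training_data")  # the dbm backend may add a suffix

# Write: shelve pickles each value and stores it under a string key
with shelve.open(path, flag="n") as db:  # 'n' = always create a new, empty database
    db["toy_image_id"] = record

# Read: 'r' opens the file read-only, as the loaders in this notebook do
with shelve.open(path, flag="r") as db:
    loaded = db["toy_image_id"]

print(loaded["image"].shape, loaded["bboxes"].dtype)  # (600, 720) float32
```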

For now, I retained the bounding boxes and class labels from all radiologists, together with which radiologist (`rad_id`) gave each annotation. That way, one can decide later what to do: whether you want to somehow aggregate the annotations (e.g. like in [this notebook](https://www.kaggle.com/sreevishnudamodaran/vinbigdata-fusing-bboxes-coco-dataset)) or use the different annotations as a form of data augmentation, e.g. by using different ones in different epochs.
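Picking one radiologist's annotations per image (which the loaders below do with `random.sample` and a list comprehension) can also be written with a NumPy boolean mask. This is a sketch of an equivalent formulation, assuming `rad_id`, `bboxes` and `class_labels` are aligned arrays with one row per annotation, as in the shelve records; `pick_one_reader` is a hypothetical helper and the toy values are made up.

```python
import numpy as np

def pick_one_reader(rad_id, bboxes, class_labels, rng):
    """Keep only the annotations made by one randomly chosen radiologist."""
    chosen = rng.choice(np.unique(rad_id))  # one of the radiologists who annotated this image
    mask = rad_id == chosen                 # True for that radiologist's rows
    return chosen, bboxes[mask], class_labels[mask]

rng = np.random.default_rng(0)
rad_id = np.array([8, 9, 8, 10], dtype=np.int8)
bboxes = np.array([[0, 0, 5, 5], [1, 1, 6, 6], [2, 2, 7, 7], [3, 3, 8, 8]], dtype=np.float32)
class_labels = np.array([3, 3, 0, 3], dtype=np.int8)

chosen, kept_boxes, kept_labels = pick_one_reader(rad_id, bboxes, class_labels, rng)
print(chosen, kept_boxes.shape[0])
```

Drawing a new choice every time an image is loaded gives the "different annotations in different epochs" behaviour described above.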

In this notebook I compare the different approaches in terms of speed. I fetch 8 batches of data so that parallelized approaches are not penalized too much for their initial set-up (as they would be if we e.g. looked at just one batch). It turns out the `shelve` approach is the fastest of the approaches I tried; let me know if I missed obvious alternatives.

In [None]:
import os
import re
import random
import shelve
import itertools

import numpy as np
import pandas as pd
import torch
import albumentations
import matplotlib.pyplot as plt
from PIL import Image
from fastcore.parallel import ProcessPoolExecutor

In [None]:
train = pd.read_csv('../input/vinbigdata-chest-xray-abnormalities-detection/train.csv')
list_of_images = np.sort(np.unique(train['image_id'].values))

# Define augmentations
We will need some augmentations to apply to the images we load. We'll use some light augmentation; omitting it could make our comparisons unrealistic, because augmentation is part of the per-batch CPU work during training.

In [None]:
INPUT_SHAPE = 224

augs = albumentations.Compose([
            #albumentations.SmallestMaxSize(max_size=600), # Not necessary, if we are reading the shelve file (already done there)
            albumentations.RandomResizedCrop(INPUT_SHAPE, INPUT_SHAPE, scale=(0.9, 1.0)),            
            albumentations.ShiftScaleRotate(rotate_limit=10, p=0.5),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1,0.1), 
                contrast_limit=(-0.1, 0.1), 
                p=0.5
            ),
            albumentations.ISONoise()], 
    bbox_params=albumentations.BboxParams(format='pascal_voc'))

# Define dataset and dataloader

This illustrates how slow the approach relying on `__getitem__` is; skip to the next section for a much better approach.

In [None]:
class DetectionDataset:
    def __init__(self, image_ids, augmentations=None):
        self.image_ids, self.augmentations = image_ids, augmentations
    def __len__(self):
        return len(self.image_ids)
    def __getitem__(self, item):
        # Opening the shelve file for every single item is what makes this approach slow
        with shelve.open('../input/eda-dicom-reading-vinbigdata-chest-x-ray/training_data.db', flag='r', writeback=False) as myshelf:
            tmpdict = myshelf[self.image_ids[item]]
        # Keep the annotations of one randomly chosen radiologist
        rad_id = tmpdict['rad_id']
        which_rad_id = random.sample( list(np.unique(rad_id)), k=1)[0]
        which_indices = [idx for idx, val in enumerate(rad_id) if val==which_rad_id]

        # Grayscale image -> 3-channel (height, width, channel) array for albumentations
        image = np.stack([tmpdict['image']]*3).transpose(1,2,0)
        bboxes = tmpdict['bboxes'][which_indices]
        class_labels = tmpdict['class_labels'][which_indices]
        transformed = self.augmentations(image=image, bboxes=bboxes, class_labels=class_labels)

        return {'image': transformed['image'], 'bboxes': transformed['bboxes'], 'class_labels': transformed['class_labels']}

detdatset = DetectionDataset(image_ids=list_of_images, augmentations=augs)
# collate_fn=lambda x: x returns each batch as a plain list of dicts (images have different numbers of boxes)
example_loader = torch.utils.data.DataLoader(detdatset, batch_size=64, shuffle=True, num_workers=0, collate_fn=lambda x: x)

# How fast is the DataLoader?
Let's test this by getting 8 batches, repeatedly.

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(example_loader, 8)):
    len(batch)

That's not great and a lot worse than reading individual image files (see further down), so we can do better.

# Define our own DataLoader that loads a whole batch
One of the stupid things in what I did above is that I keep re-opening the `shelve` file for every single item, which is pretty inefficient. We can speed things up a lot by reading a whole batch at once! For the same reason, this also results in a speedup compared with loading individual image files.

To do better, let's use a DataLoader that loads a whole batch with a single opening of the file. To keep things modular and easy to tinker with, I'll first define a batched DataSet class and then use that in my DataLoader.
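To see the effect of opening the file once per item versus once per batch in isolation, here is a small self-contained sketch on a toy `shelve` file (64 tiny made-up images; the file name and sizes are arbitrary). `read_per_item` mimics the `__getitem__` approach and `read_per_batch` mimics the batched one; both are hypothetical helpers, and the absolute timings depend on the dbm backend and the disk, so they won't match the Kaggle numbers.

```python
import os
import shelve
import tempfile
import timeit

import numpy as np

# Toy shelve file with 64 small "images" (keys and sizes are made up for illustration)
path = os.path.join(tempfile.mkdtemp(), "toy_batch_demo")
keys = [f"img_{k}" for k in range(64)]
with shelve.open(path, flag="n") as db:
    for k, key in enumerate(keys):
        db[key] = {"image": np.full((64, 64), k, dtype=np.uint8)}

def read_per_item(keys):
    """Like a __getitem__-based loader: open the file once per key."""
    out = {}
    for key in keys:
        with shelve.open(path, flag="r") as db:
            out[key] = db[key]
    return out

def read_per_batch(keys):
    """Like getbatch(): open the file once and read every key of the batch."""
    with shelve.open(path, flag="r") as db:
        return {key: db[key] for key in keys}

print("per item :", timeit.timeit(lambda: read_per_item(keys), number=5))
print("per batch:", timeit.timeit(lambda: read_per_batch(keys), number=5))
```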

In [None]:
class BatchedDataSet:
    
    def __init__(self, image_ids, augmentations=None):
        self.image_ids, self.augmentations = image_ids, augmentations
        
    def __len__(self):
        return len(self.image_ids)
    
    def getbatch(self, items):
        what_images = [self.image_ids[item] for item in items]
        with shelve.open('../input/eda-dicom-reading-vinbigdata-chest-x-ray/training_data.db', flag='r', writeback=False) as myshelf:
            tmpdict =  { key: myshelf[key] for key in what_images }
        
        for what_image in what_images:
            rad_id = tmpdict[what_image]['rad_id']        
            which_rad_id = random.sample( list(np.unique(rad_id)), k=1)[0]
            which_indices = [idx for idx, val in enumerate(rad_id) if val==which_rad_id]        

            image = np.stack([tmpdict[what_image]['image']]*3).transpose(1,2,0)        
            bboxes = tmpdict[what_image]['bboxes'][which_indices]
            class_labels = tmpdict[what_image]['class_labels'][which_indices]        
            transformed = self.augmentations(image=image, bboxes=bboxes, class_labels=class_labels)
            tmpdict[what_image] = {'image': transformed['image'], 
                                   'bboxes': transformed['bboxes'], 
                                   'class_labels': transformed['class_labels']}        
        return tmpdict

Now let's get to the actual DataLoader. Below is a very basic DataLoader-like class that loads data in batches from the `shelve` file (it does not subclass `torch.utils.data.DataLoader`; it simply implements Python's iterator protocol). The class needs the following methods:
* `__init__`: Initializes an instance of the class. Here we provide basic information such as the dataset (which holds the list of `image_id` values we select from), the batch size and so on.
* `__iter__`: Determines what happens at the start of each epoch, i.e. each time we start going through the whole data again. The main thing we need to do here is to shuffle the order of the items if requested.
* `__next__`: Generate a batch of data (or the **next** batch of data), this part is what makes this DataLoader different from your basic PyTorch DataLoader that would call `__getitem__` on the DataSet `batch_size` times per batch. Here, we instead get the whole batch at once, which leads to a major speedup, if (as is the case here) we can load a batch very efficiently all at once from a single file.
* `__len__`: When we loop over the DataLoader for a training epoch, it's useful to know how many batches it returns.
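The batch bookkeeping in `__init__` and `__next__` boils down to a bit of integer arithmetic. Pulled out into a standalone function (the name `batch_bounds` is hypothetical, not part of the notebook), it is easy to check the `drop_last` cases by hand:

```python
def batch_bounds(n_items, batch_size, drop_last=False):
    """Return (start, stop) index pairs for each batch, like the loaders in this notebook."""
    n_full, remainder = divmod(n_items, batch_size)
    # An incomplete last batch is kept unless drop_last is set
    n_batches = n_full + (1 if remainder and not drop_last else 0)
    return [(b * batch_size, min((b + 1) * batch_size, n_items)) for b in range(n_batches)]

print(batch_bounds(10, 4))                  # [(0, 4), (4, 8), (8, 10)]
print(batch_bounds(10, 4, drop_last=True))  # [(0, 4), (4, 8)]
```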

In [None]:
class BatchedDataLoader:

    def __init__(self, dataset, batch_size: int=32, shuffle: bool=False, drop_last: bool=False):
        self.dataset, self.dataset_len, self.batch_size, self.shuffle, self.drop_last = dataset, len(dataset), batch_size, shuffle, drop_last

        # Calculate the number of batches and how many items are used per epoch
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0 and not drop_last:
            n_batches += 1                              # keep the last, smaller batch
            self.n_items = self.dataset_len
        else:
            self.n_items = n_batches * self.batch_size  # drop the incomplete last batch (if any)
        self.n_batches = n_batches
        self.batch_list = list(range(self.n_items))

    def __iter__(self):
        if self.shuffle:
            # New random order at the start of every epoch (as plain Python ints)
            self.batch_list = torch.randperm(self.dataset_len)[:self.n_items].tolist()

        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.n_items:
            raise StopIteration        
        batch = self.dataset.getbatch( items = self.batch_list[self.i:min(self.n_items, self.i+self.batch_size)] )
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

In [None]:
batcheddataset = BatchedDataSet(image_ids=list_of_images, augmentations=augs)
train_batches = BatchedDataLoader(dataset=batcheddataset, batch_size=64, shuffle=True)

# How fast is our own BatchLoader?

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(train_batches, 8)):
    len(batch)

Great! That's a massive speedup compared with loading each item one at a time via the `__getitem__` of the dataset we had used with a standard PyTorch DataLoader. It's also faster than loading individual images without parallelization (see below), but there's not much in it compared with loading images with 2 or 4 workers. Clearly, we should parallelize this to outperform parallel image loading.

# What does our BatchLoader return?

Let's get a batch of data and see what the DataLoader produced.

In [None]:
example_batch = next(iter(train_batches))

Here's what the first two entries of a batch read from the shelve file look like:

In [None]:
example_keys = list(example_batch.keys())
for key in example_keys[0:2]:
    print(example_batch[key])

In [None]:
plt.imshow(example_batch[example_keys[0]]['image']);

In [None]:
example_batch[example_keys[0]]['bboxes']

In [None]:
example_batch[example_keys[0]]['class_labels']

# Let's add parallel processing in our DataLoader


As mentioned in the previous section, one obvious feature to add to our batched DataLoader is parallel loading (i.e. multiple workers that each take care of one batch at a time), controlled by an `n_workers` option. Getting the data is easily parallelizable, and because every worker opens the `shelve` file read-only, there's no problem with multiple workers reading from it at the same time (the `shelve` documentation notes that multiple simultaneous read accesses are safe).

Now, we'll use the `ProcessPoolExecutor` from `fastcore.parallel` - as described in Chapter 19 __A fastai Learner from Scratch__ of the great **Deep Learning for Coders with fastai & PyTorch** [book](https://www.amazon.com/Deep-Learning-Coders-fastai-PyTorch/dp/1492045527) - in order to parallelize our DataLoader.

In [None]:
def getbatch(items, augmentations):    
    with shelve.open('../input/eda-dicom-reading-vinbigdata-chest-x-ray/training_data.db', flag='r', writeback=False) as myshelf:
        tmpdict =  { key: myshelf[key] for key in items }
        
    for what_image in items:
        rad_id = tmpdict[what_image]['rad_id']        
        which_rad_id = random.sample( list(np.unique(rad_id)), k=1)[0]
        which_indices = [idx for idx, val in enumerate(rad_id) if val==which_rad_id]        

        image = np.stack([tmpdict[what_image]['image']]*3).transpose(1,2,0)        
        bboxes = tmpdict[what_image]['bboxes'][which_indices]
        class_labels = tmpdict[what_image]['class_labels'][which_indices]        
        transformed = augmentations(image=image, bboxes=bboxes, class_labels=class_labels)
        tmpdict[what_image] = {'image': transformed['image'], 
                               'bboxes': transformed['bboxes'], 
                               'class_labels': transformed['class_labels']}        
    return tmpdict
    
class BatchedParallelDataLoader:

    def __init__(self, image_ids, augmentations, batch_size: int=32, shuffle: bool=False, drop_last: bool=False, n_workers: int=1):
        self.dataset_len, self.batch_size, self.shuffle, self.drop_last, self.n_workers = len(image_ids), batch_size, shuffle, drop_last, n_workers
        self.image_ids, self.augmentations = image_ids, augmentations

        # Calculate the number of batches and how many items are used per epoch
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0 and not drop_last:
            n_batches += 1
            self.n_items = self.dataset_len
        else:
            self.n_items = n_batches * self.batch_size
        self.n_batches = n_batches
        self.batch_list = list(range(self.n_items))

    def __iter__(self):
        if self.shuffle:
            self.batch_list = torch.randperm(self.dataset_len)[:self.n_items].tolist()

        # One chunk of image_ids per batch; each worker process handles whole batches
        chunks = [list(self.image_ids[self.batch_list[i:min(self.n_items, i + self.batch_size)]])
                  for i in range(0, self.n_items, self.batch_size)]
        with ProcessPoolExecutor(self.n_workers) as ex:
            yield from ex.map(getbatch, chunks, augmentations=self.augmentations)

    def __len__(self):
        return self.n_batches
    
bpdl = BatchedParallelDataLoader(n_workers=4, image_ids=list_of_images, augmentations=augs)
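One thing worth checking with this design (a caveat reasoned from the code, not something measured here): when worker processes are started by forking (the traditional default on Linux), each worker starts with a copy of the parent's Python `random` and NumPy RNG states. Since `getbatch` calls `random.sample(...)` inside the workers, and `__iter__` creates a fresh pool on every pass, different workers (and the same worker in different epochs) may replay identical random draws. A minimal sketch of one remedy is to draw a seed per batch in the parent and reseed inside the worker. `make_batch_seeds` and `getbatch_seeded` are hypothetical names, and whether reseeding `random`/`np.random` also controls the augmentations depends on which RNG your albumentations version uses.

```python
import random

import numpy as np

def make_batch_seeds(n_batches, rng=None):
    """Draw one seed per batch in the parent process, so every pass gets fresh seeds."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.integers(0, 2**31 - 1, size=n_batches).tolist()

def getbatch_seeded(chunk_and_seed):
    """Hypothetical variant of getbatch(): reseed the worker's global RNGs per batch."""
    items, seed = chunk_and_seed
    random.seed(seed)     # used by random.sample(...) when picking a rad_id
    np.random.seed(seed)  # NumPy's legacy global RNG
    # ... the shelve reading and augmentation from getbatch() would go here ...
    # Stand-in for the per-image random rad_id choice, so the effect of seeding is visible:
    return [random.choice([8, 9, 10]) for _ in items]

seeds = make_batch_seeds(3, rng=np.random.default_rng(0))
chunks = [["a", "b"], ["c", "d"], ["e", "f"]]
results = [getbatch_seeded(pair) for pair in zip(chunks, seeds)]
```

Inside `__iter__`, the chunks and seeds would then be paired into a single list of items, e.g. `ex.map(getbatch_seeded, list(zip(chunks, make_batch_seeds(len(chunks)))))`, with the real `getbatch` logic moved into `getbatch_seeded`.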

# How fast is the parallelized version of our DataLoader?

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(bpdl, 8)):
    len(batch)

Woohoo! That clearly beats loading images in parallel with the standard PyTorch parallel DataLoader!

# What if we only have 2 workers?
This is obviously an important question, because [currently](https://www.kaggle.com/docs/notebooks) a Kaggle GPU notebook gets 2 CPUs (although TPU notebooks get 4).
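Rather than hard-coding `n_workers`, one option is to ask the operating system how many CPUs the notebook may use. A minimal sketch (the helper `available_cpus` is hypothetical); note that `os.sched_getaffinity` reflects CPU affinity but not necessarily container CPU quotas, so on some platforms it can report more CPUs than you can really use.

```python
import os

def available_cpus():
    """CPUs this process is allowed to run on; falls back to os.cpu_count()."""
    try:
        return len(os.sched_getaffinity(0))  # Linux: respects the CPU affinity mask
    except AttributeError:                    # e.g. macOS and Windows lack sched_getaffinity
        return os.cpu_count() or 1

n_workers = available_cpus()
print(n_workers)
```

This could then be passed on as `BatchedParallelDataLoader(n_workers=available_cpus(), image_ids=list_of_images, augmentations=augs)`.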

In [None]:
bpdl = BatchedParallelDataLoader(n_workers=2, image_ids=list_of_images, augmentations=augs)

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(bpdl, 8)):
    len(batch)

Wow! We did not actually gain that much by going from 2 to 4 workers. My guess is that this is partly due to how few batches we generate (8), and that with a lot more batches we would gain something from more workers.
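The guess above can be made a bit more concrete with a rough cost model. This is an assumption, not something measured in this notebook: say starting the process pool costs $S(w)$ for $w$ workers (paid on every pass, because `__iter__` creates a new `ProcessPoolExecutor` each time it is called), each batch takes roughly the same time $t_b$ in a worker, and the workers don't slow each other down by competing for the disk. Fetching $N$ batches then takes roughly

$$T(N, w) \approx S(w) + \left\lceil \frac{N}{w} \right\rceil t_b .$$

With $N = 8$, going from $w = 2$ to $w = 4$ cuts the second term from $4t_b$ to $2t_b$, and part of that saving is eaten up by the extra start-up cost $S(4) - S(2)$. Over a full epoch with hundreds of batches, the start-up cost is amortized and the $\lceil N/w \rceil t_b$ term dominates, so more workers should help more, unless reading from disk becomes the bottleneck.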

# How does that compare to reading .dicom files on-the-fly?

Here I'll follow the approach I also took in the [notebook](https://www.kaggle.com/bjoernholzhauer/eda-dicom-reading-vinbigdata-chest-x-ray#7.-Creating-fast-to-read-shelve-file) that created the shelve file.

In [None]:
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from fastcore.parallel import parallel

# Using function from another great notebook: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way
def read_xray(path, voi_lut = True, fix_monochrome = True):
    dicom = pydicom.dcmread(path)  # read_file is a deprecated alias (removed in pydicom 3.0)
    
    # VOI LUT (if available by DICOM device) is used to transform raw DICOM data to "human-friendly" view
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)        
    else:
        data = dicom.pixel_array
               
    # depending on this value, X-ray may look inverted - fix that:
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
        
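    # Scale to 0..255 and store as uint8
    data = data - np.min(data)
    if np.max(data) > 0:  # guard against division by zero for a constant image
        data = data / np.max(data)
    data = (data * 255).astype(np.uint8)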
    data = np.stack([data]*3).transpose(1,2,0)
        
    return data

# This function reads a .dicom file, resizes it so that the smallest side is 600 pixels, and returns the image together with its annotations in a dictionary
def get_and_save(x):
    
    idx=x[0] 
    image_id=x[1]
    
    transform = albumentations.Compose([albumentations.SmallestMaxSize(max_size=600, always_apply=True)], bbox_params=albumentations.BboxParams(format='pascal_voc')) 
    img = read_xray(path='../input/vinbigdata-chest-xray-abnormalities-detection/train/' + image_id + '.dicom')
    rad_id = np.array([int(re.findall(r'\d+', rad_id)[0]) for rad_id in train.loc[train['image_id']==image_id, 'rad_id'].values], dtype=np.int8)
    class_labels = train.loc[train['image_id']==image_id, 'class_id'].values    
    bboxes = [list(row) for rowid, row in train.loc[train['image_id']==image_id, ['x_min', 'y_min', 'x_max', 'y_max', 'class_id']].fillna({'x_min':0, 'y_min':0, 'x_max':1, 'y_max':1}).astype(np.int16).iterrows() ]
    
    transformed = transform(image=img,
                            bboxes=bboxes, 
                            class_labels=class_labels)
    
    return dict(image_id=image_id,
                image=transformed['image'][:,:,0],
                rad_id=rad_id,
                bboxes=np.array(transformed['bboxes'], dtype=np.float32), 
                class_labels=transformed['class_labels'].astype(np.int8))

In [None]:
%%timeit
for batchno in range(8):
    example_batch3 = parallel(get_and_save, [(idx, image_id) for idx, image_id in enumerate(list_of_images[0:64])], n_workers=4, progress=False)

As we can see, this is way slower than either reading image files (see below) or reading the shelve file.

# How does that compare to reading images?

This is not quite a perfect comparison, because [the Kaggle dataset I used](https://www.kaggle.com/awsaf49/vinbigdata-512-image-dataset) has 512 by 512 images (instead of images whose smallest side is 600 pixels with the aspect ratio preserved), and I could not figure out how to get correct bounding boxes for these images. So some of the augmentation work gets skipped, which makes loading individual images look more favorable than it really is.
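One possible way to get boxes for such images, under two assumptions: the available boxes are in the original DICOM pixel coordinates, and the original height and width of each image are known (for example from the DICOM `Rows` and `Columns` attributes). For a plain, non-aspect-preserving resize to 512 by 512, x and y coordinates of `pascal_voc` boxes can simply be scaled separately. `rescale_boxes` is a hypothetical helper that handles the four coordinate columns only (not the extra `class_id` column used elsewhere in this notebook).

```python
import numpy as np

def rescale_boxes(bboxes, orig_hw, new_hw=(512, 512)):
    """Scale pascal_voc boxes (x_min, y_min, x_max, y_max) from the original to the new image size."""
    orig_h, orig_w = orig_hw
    new_h, new_w = new_hw
    # x coordinates scale with the width ratio, y coordinates with the height ratio
    scale = np.array([new_w / orig_w, new_h / orig_h, new_w / orig_w, new_h / orig_h], dtype=np.float32)
    return np.asarray(bboxes, dtype=np.float32) * scale

# A 2048 (high) x 1024 (wide) image: x is scaled by 0.5, y by 0.25
scaled = rescale_boxes([[100, 200, 300, 400]], orig_hw=(2048, 1024))
print(scaled)
```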

In [None]:
INPUT_SHAPE = 224

augs = albumentations.Compose([            
            albumentations.RandomResizedCrop(INPUT_SHAPE, INPUT_SHAPE, scale=(0.9, 1.0)),            
            albumentations.ShiftScaleRotate(rotate_limit=10, p=0.5),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1,0.1), 
                contrast_limit=(-0.1, 0.1), 
                p=0.5
            ),
            albumentations.ISONoise()] #,bbox_params=albumentations.BboxParams(format='pascal_voc')
)

class ImageDataSet:
    def __init__(self, image_ids, train, augmentations=None):
        self.image_ids, self.train, self.augmentations = image_ids, train, augmentations
    def __len__(self):
        return len(self.image_ids)
    def __getitem__(self, item):
        image = np.array(Image.open('../input/vinbigdata-512-image-dataset/vinbigdata/train/' + self.image_ids[item] + '.png').convert('RGB'))
        
        rad_id = self.train.loc[self.train['image_id']==self.image_ids[item], 'rad_id'].values
        class_labels = self.train.loc[self.train['image_id']==self.image_ids[item], 'class_id'].values
        bboxes = [list(row) for rowid, row in self.train.loc[self.train['image_id']==self.image_ids[item], ['x_min', 'y_min', 'x_max', 'y_max', 'class_id']].fillna({'x_min':0, 'y_min':0, 'x_max':1, 'y_max':1}).astype(np.int16).iterrows() ]
                
        which_rad_id = random.sample( list(np.unique(rad_id)), k=1)[0]        
        which_indices = [idx for idx, val in enumerate(rad_id) if val==which_rad_id]
        bboxes = np.array(bboxes)[which_indices]
        class_labels = np.array(class_labels)[which_indices]
        # Boxes are not passed to the augmentations (or returned), because correct boxes for the
        # 512x512 images were not available; the look-up above only mimics the annotation work
        transformed = self.augmentations(image=image)

        return np.transpose(transformed['image'], (2,0,1))

imageds = ImageDataSet(image_ids=list_of_images, 
                       train=pd.read_csv('../input/vinbigdata-512-image-dataset/vinbigdata/train.csv'), 
                       augmentations=augs)
example_loader2 = torch.utils.data.DataLoader(imageds, batch_size=64, shuffle=True, num_workers=4, collate_fn=lambda x: x)

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(example_loader2, 8)):
    len(batch)

In [None]:
example_loader2a = torch.utils.data.DataLoader(imageds, batch_size=64, shuffle=True, num_workers=1, collate_fn=lambda x: x)

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(example_loader2a, 8)):
    len(batch)

In [None]:
example_loader2b = torch.utils.data.DataLoader(imageds, batch_size=64, shuffle=True, num_workers=0, collate_fn=lambda x: x)

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(example_loader2b, 8)):
    len(batch)

In [None]:
example_loader2c = torch.utils.data.DataLoader(imageds, batch_size=64, shuffle=True, num_workers=2, collate_fn=lambda x: x)

In [None]:
%%timeit
for batchno, batch in enumerate(itertools.islice(example_loader2c, 8)):
    len(batch)

# How can we use it with a model?