In [None]:
!pip install fastai==2.5.1

### Introduction
This notebook introduces how to build a custom fastai DataLoader for a task in which each item is a sequence of images (MRI images, video, ...).

With fastai's DataLoader, you can leverage many of the neat fastai functionalities such as *item transformation*, *batch transformation*, or my favorite part: *show batch*. This great feature helps you visualize how your raw data is transformed into the input of your model, which is especially helpful when debugging or presenting to others.

You can check out the official tutorial, *Using fastai on a custom new task*, here: https://docs.fast.ai/tutorial.siamese.html

In [None]:
from fastai.vision.all import *
import fastai
import random
import PIL

In [None]:
# For experimentation, only the 'T1w' scan type is used, and only sequences
# that contain more than `min_subset` images are kept.
mri_type = 'T1w'
min_subset = 16
train_root = Path('../input/rsna-miccai-png/train')
# One candidate folder per patient: <patient>/<mri_type>
candidate_folders = [patient/mri_type for patient in train_root.ls()]
patients_seq_folder = [
    folder for folder in candidate_folders
    if folder.exists() and len(folder.ls()) > min_subset
]

In [None]:
# Sanity check: display the first five selected sequence folders.
patients_seq_folder[:5]

In [None]:
# Label table; the `BraTS21ID` and `MGMT_value` columns are read later when
# building each item.
labels_csv = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv'
labels = pd.read_csv(labels_csv)
labels.head()

In [None]:
def open_image(fname, size=224):
    """Load an image file as a float tensor with pixel values divided by 255.

    Args:
        fname: Anything accepted by `PIL.Image.open` (e.g. a file path).
        size: Edge length in pixels; the image is resized to `(size, size)`,
            so the aspect ratio is not preserved.

    Returns:
        A `torch.float32` tensor holding the resized pixel values divided by
        255.0, with the array layout NumPy produces for the image.
    """
    # The context manager closes the underlying file deterministically instead
    # of leaving it to garbage collection.
    with PIL.Image.open(fname) as img:
        # `np.array` copies, so the pixel data stays valid (and writable) after
        # the file is closed; the float32 dtype is made explicit here.
        pixels = np.array(img.resize((size, size)), dtype=np.float32)
    # `torch.from_numpy` replaces the legacy `torch.Tensor(...)` constructor.
    return torch.from_numpy(pixels) / 255.0

To help fastai know how to show your batch, each item (including x and y) needs to be an instance of a class that has a *show* method. By subclassing `fastuple`, fastai will also know how to apply a transformation to each element based on its type (for example, applying Resize to the PILImage elements and not to the string).

In [None]:
class SeqImage(fastuple):
    """Tuple of images followed by a label, drawable as a single strip.

    All elements except the last are treated as images; the last element is
    the label and is used as the default plot title.
    """
    def show(self, ctx=None, **kwargs):
        """Concatenate the images along dim 2 and draw them with `show_image`.

        Args:
            ctx: Context forwarded to `show_image` as `ctx`.
            **kwargs: Extra keyword arguments forwarded to `show_image`.
                `figsize` defaults to `(20, 20)` and `title` defaults to the
                label; values supplied by the caller take precedence.

        Returns:
            Whatever `show_image` returns.
        """
        *imgs, label = self
        # Convert element by element: non-tensor images are turned into tensors
        # and permuted with (2,0,1); elements that are already tensors are left
        # untouched, so a mixed tuple is not permuted twice.
        imgs = [img if isinstance(img, Tensor) else tensor(img).permute(2,0,1) for img in imgs]
        img_cat = torch.cat(imgs, dim=2)
        # Use defaults rather than hard-coded keywords so that a caller passing
        # `figsize=` or `title=` no longer triggers a duplicate-keyword TypeError.
        kwargs.setdefault('figsize', (20,20))
        kwargs.setdefault('title', label)
        return show_image(img_cat, ctx=ctx, **kwargs)

In [None]:
# splits the data to training set and validation set
# Split the sequence folders into training and validation index lists.
# A fixed seed makes the random split repeatable across kernel restarts (for
# the same folder listing), so the validation samples shown later do not change
# from run to run.
splits = RandomSplitter(seed=42)(patients_seq_folder)

In [None]:
# Index the folder list with each split to get the two subsets.
all_folders = L(patients_seq_folder)
train_folders = all_folders[splits[0]]
valid_folders = all_folders[splits[1]]

We will try to show a SeqImage as below.

In [None]:
# Take the first 16 files listed in the first training folder as a demo sequence.
first_train_folder = train_folders[0]
files_test = first_train_folder.ls()[:16]

In [None]:
# Open every selected file as a PILImage and put a dummy label (0) last.
imgs_label = [*map(PILImage.create, files_test), 0]

In [None]:
# Build the demo item: the images followed by the label as the last element.
s = SeqImage(*imgs_label)

In [None]:
# Apply Resize, then ToTensor, to the demo item and draw the result.
tst = ToTensor()(Resize(224)(s))
tst.show();

The *encodes* method of a Transform class is what applies the transformation to each item (similar to *forward* in PyTorch modules).

In [None]:
class SeqTransform(Transform):
    """Turn a sequence folder into a `SeqImage` of `min_subset` images plus a label."""
    def encodes(self, folder):
        """Sample, order and open images from `folder`, then attach its label.

        `min_subset` files are drawn at random from the folder and ordered by
        the integer that follows the first '-' in the file stem. The label is
        the `MGMT_value` of the `labels` row whose `BraTS21ID` equals the
        integer name of the folder's parent directory.
        """
        def slice_number(path):
            # Integer found after the first '-' of the file stem
            return int(path.stem.split('-')[1])

        chosen = random.sample(folder.ls(), min_subset)
        chosen.sort(key=slice_number)
        imgs = [PILImage.create(path) for path in chosen]

        patient_id = int(folder.parent.name)
        patient_rows = labels[labels['BraTS21ID']==patient_id]
        label = patient_rows['MGMT_value'].values[0]
        return SeqImage(*imgs, label)

In [None]:
# Instantiate the folder -> SeqImage transform defined above.
tfm = SeqTransform()

In [None]:
# Build a TfmdLists over the sequence folders with `tfm` as its transform,
# passing the train/validation `splits` computed earlier.
tls = TfmdLists(patients_seq_folder, tfm, splits=splits)


In [None]:
# Display the item at index 3 of the validation subset; the trailing `;`
# suppresses the cell's textual output.
show_at(tls.valid, 3);

In [None]:
# Transforms passed as `after_item` and `after_batch` respectively.
item_tfms = [Resize(224), ToTensor]
batch_tfms = [IntToFloatTensor]
dls = tls.dataloaders(after_item=item_tfms, after_batch=batch_tfms)

Define a show_batch function as below to make show_batch work. Here x is a batch of SeqImage items.

In [None]:
@typedispatch
def show_batch(x:SeqImage, y, samples, ctxs=None, max_n=6, nrows=None, ncols=1, figsize=None, **kwargs):
    """Draw items of a `SeqImage` batch, one image strip per axis.

    `x` is the batch as a `SeqImage`: every element except the last is indexed
    per item to get that item's image, and the last element holds the labels.

    Args:
        x: Batch to display.
        y, samples: Accepted for the dispatch signature; not used here.
        ctxs: Axes to draw on. When None, a grid is created with `get_grid`
            for `min(x[0].shape[0], max_n)` items.
        max_n: Upper bound on the number of items in the created grid.
        nrows, ncols: Grid layout forwarded to `get_grid`.
        figsize: Figure size; defaults to `(ncols*6, max_n//ncols * 3)`.
        **kwargs: Accepted but not used.
    """
    if figsize is None: figsize = (ncols*6, max_n//ncols * 3)
    # Forward the caller's `nrows` instead of a literal None, which silently
    # discarded the argument.
    if ctxs is None: ctxs = get_grid(min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    # Take the image count from the tuple itself rather than from a global, so
    # this works for any sequence length.
    *img_batches, label_batch = x
    for index,ctx in enumerate(ctxs):
        imgs_ls = [img_batch[index] for img_batch in img_batches]
        label = int(label_batch[index])
        SeqImage(*imgs_ls, label).show(ctx=ctx)

In [None]:
# Fetch a single batch from `dls` for inspection.
b = dls.one_batch()

In [None]:
# Visualize a batch; `figsize` is forwarded as a keyword argument.
dls.show_batch(figsize=(20,20))