In [None]:
# install pull request version of torchio https://github.com/fepegar/torchio/pull/683
!pip install git+https://github.com/laynr/torchio.git@681-add-plot_volume-indices-parameter

In [None]:
import imageio
import pandas as pd
import torchio as tio
from pathlib import Path
import multiprocessing as mp
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

# Where outputs are written and where the competition data is read from
out_dir = Path.cwd()
data_dir = Path('/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification')
training_dir = data_dir.joinpath('train')

# Target spatial shape handed to tio.CropOrPad; its last entry is also the
# number of slice indices exported per subject
dims = (280, 280, 264)

# Default matplotlib figure size
plt.rcParams["figure.figsize"] = (12, 10)

In [None]:
# get patients
def get_patients(patients_dir, demo=False):
    """Return the patient IDs found as sub-directories of ``patients_dir``.

    Args:
        patients_dir: Directory (``pathlib.Path`` or path string) holding one
            sub-directory per patient; the sub-directory name is the patient ID.
        demo: When true, keep only the first two patients that remain after
            the excluded cases have been removed.

    Returns:
        Sorted list of patient IDs (directory names) without the excluded cases.
    """
    # Cases the competition host said to exclude
    # https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/262046
    excluded = {'00109', '00123', '00709'}

    # Read the directory that was passed in (not the module-level training_dir).
    # Sorting makes the result, and therefore the demo subset, deterministic,
    # because glob() order depends on the filesystem.
    patients = sorted(
        entry.name
        for entry in Path(patients_dir).glob('*')
        if entry.is_dir() and entry.name not in excluded
    )

    # Truncate after filtering so an excluded case cannot shrink the demo subset.
    if demo:
        patients = patients[:2]

    return patients

# demo=True limits the run to at most two patient folders
patients = get_patients(training_dir, demo=True)

In [None]:
# create dataset with synchronized MRIs
def data_preparation(patients):
    """Build a torchio dataset holding one subject per patient.

    Each subject stores the patient ID (``BraTS21ID``), the ``MGMT_value``
    label read from ``data_dir / 'train_labels.csv'`` and four
    ``tio.ScalarImage`` entries (FLAIR, T1w, T1wCE, T2w) that point at the
    matching folders under ``training_dir / <patient>``.

    Args:
        patients: Iterable of patient IDs (folder names). Each ID must convert
            to ``int`` and be present in the label table's index.

    Returns:
        ``tio.SubjectsDataset`` that applies the preprocessing transforms to
        each subject.

    Raises:
        KeyError: If a patient has no row in the label table.
        ValueError: If a patient ID cannot be converted to ``int``.
    """
    # The first CSV column becomes the index, so labels are looked up by patient number.
    labels_df = pd.read_csv(data_dir / 'train_labels.csv', index_col=0)
    sequences = ('FLAIR', 'T1w', 'T1wCE', 'T2w')

    subjects = []
    # loop through patients
    for patient in patients:
        patient_dir = training_dir / patient
        # Public scalar accessor (replaces the private DataFrame._get_value)
        label = labels_df.at[int(patient), 'MGMT_value']
        # One image per MRI sequence, each read from its own folder
        images = {sequence: tio.ScalarImage(patient_dir / sequence) for sequence in sequences}
        subjects.append(tio.Subject(BraTS21ID=patient, MGMT_value=label, **images))

    # preprocessing transforms, applied in this order
    preprocessing_transforms = tio.Compose([
        tio.ToCanonical(),
        tio.Resample(1, image_interpolation='bspline'),
        # Resample again using the subject's 'T1w' image as the reference, so
        # all sequences end up on the same grid.
        tio.Resample('T1w', image_interpolation='nearest'),
        tio.CropOrPad(dims),
    ])

    # create dataset from the subjects; transforms are attached to the dataset
    dataset = tio.SubjectsDataset(subjects, transform=preprocessing_transforms)
    print(f'patients :{len(dataset)}')

    return dataset

# Build the dataset for the selected patients; data_preparation also prints the patient count
dataset = data_preparation(patients) 

In [None]:
# create slices
def preprocess_dataset(dataset, out_dir, parallel=True):
    """Export one PNG per slice index for every subject in ``dataset``.

    For each subject, the images are written to
    ``out_dir/<BraTS21ID>/slices/<index>_<BraTS21ID>_<MGMT_value>.png`` for
    every index in ``range(dims[2])``; the same index is used on all three
    axes of the plot.

    Args:
        dataset: Iterable of subjects (mappings with ``BraTS21ID`` and
            ``MGMT_value`` entries and a ``plot`` method).
        out_dir: ``pathlib.Path`` root under which the per-patient folders are
            created.
        parallel: When true, subjects are loaded through a ``DataLoader`` with
            one worker per CPU; otherwise ``dataset`` is iterated directly.
    """
    if parallel:
        # With the default batch size of 1 every batch is a one-element list;
        # the collate function unwraps it so the loop receives the subject itself.
        # NOTE(review): a lambda collate_fn cannot be pickled, so this appears to
        # rely on fork-started workers - confirm before running on other platforms.
        iterable = DataLoader(
            dataset,
            num_workers=mp.cpu_count(),
            collate_fn=lambda x: x[0],
        )
    else:
        iterable = dataset

    for subject in tqdm(iterable):
        # These do not change per slice, so look them up once per subject.
        patient_id = subject["BraTS21ID"]
        label = subject["MGMT_value"]

        slices_dir = out_dir / f'{patient_id}' / 'slices'
        slices_dir.mkdir(parents=True, exist_ok=True)

        for x in range(dims[2]):
            filename = slices_dir / f'{x:03d}_{patient_id}_{label}.png'
            subject.plot(reorient=False, indices=(x, x, x), output_path=filename, show=False)
            # The figure is only needed for the saved file; close it so figures
            # do not pile up in memory over every slice of every subject.
            plt.close('all')

# Write the slice PNGs under out_dir, loading subjects through DataLoader workers
preprocess_dataset(dataset, out_dir, parallel=True)

In [None]:
# create gifs
def create_gifs(patients, out_dir):
    """Assemble each patient's slice images into one animated GIF.

    For every patient, the files in ``out_dir/<patient>/slices`` are read in
    sorted path order and saved as the frames of
    ``out_dir/<patient>/gif/<patient>.gif``.

    Args:
        patients: Iterable of patient IDs (folder names under ``out_dir``).
        out_dir: ``pathlib.Path`` root that holds the per-patient folders.

    Returns:
        List with the path of every GIF written, in the order of ``patients``.
    """
    image_paths = []
    for patient in patients:
        patient_dir = out_dir / f'{patient}'
        slices_dir = patient_dir / 'slices'
        gif_dir = patient_dir / 'gif'
        gif_dir.mkdir(parents=True, exist_ok=True)

        # Sorting the paths fixes the frame order (slice files written by
        # preprocess_dataset start with a zero-padded slice index).
        frame_files = sorted(entry for entry in slices_dir.glob('*') if entry.is_file())
        frames = [imageio.imread(frame_file) for frame_file in frame_files]

        gif_path = gif_dir / f'{patient}.gif'
        imageio.mimsave(gif_path, frames)
        image_paths.append(gif_path)

    # Return after the loop: returning inside it stopped after the first patient.
    return image_paths
                
# Build the GIFs from the exported slices and keep the returned GIF paths for display
image_paths = create_gifs(patients, out_dir)

In [None]:
# display gifs
from IPython.display import Image
def display_gifs(image_paths):
    """Render every image in ``image_paths`` in the notebook output, in order.

    Args:
        image_paths: Iterable of image locations; each one is wrapped in an
            ``IPython.display.Image`` and passed to ``display``.
    """
    for gif_path in image_paths:
        gif = Image(gif_path)
        display(gif)
        
# Show the GIFs whose paths were returned by create_gifs
display_gifs(image_paths)