In [None]:
# Two ways to install torchio; keep exactly one install line active.
# Upstream PyPI release: its plots include labels (disabled here).
# !pip install --quiet torchio

# Fork on the `plot_animation` branch: its plots have no labels (used below).
# NOTE(review): a branch ref is unpinned and may change under you; consider
# pinning a specific commit hash for reproducible runs.
!pip install git+https://github.com/laynr/torchio.git@plot_animation

In [None]:
import pandas as pd
import torchio as tio
from pathlib import Path
import multiprocessing as mp
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

# Competition input: patient folders live under <data_dir>/train.
data_dir     = Path('/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification')
training_dir = data_dir.joinpath('train')
# Output root for the rendered PNGs, relative to the working directory.
out_dir      = Path.cwd().joinpath("dataset")

# Default size for the matplotlib figures drawn below.
plt.rcParams.update({"figure.figsize": (12, 10)})

In [None]:
# get patients
# Cases the competition host said to exclude:
# https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/262046
EXCLUDED_PATIENTS = ('00109', '00123', '00709')


def get_patients(patients_dir, demo=True, n_demo=5, exclude=EXCLUDED_PATIENTS):
    """Return the sorted patient IDs (sub-directory names) found in ``patients_dir``.

    Args:
        patients_dir: directory (str or Path) containing one sub-directory per
            patient. Plain files are ignored.
        demo: if True, keep only the first ``n_demo`` patients for a quick run.
        n_demo: number of patients kept when ``demo`` is True.
        exclude: patient IDs to drop. Defaults to the host-flagged cases.

    Returns:
        list[str] of patient IDs, sorted so results are reproducible across
        filesystems.

    Raises:
        FileNotFoundError: if ``patients_dir`` does not exist.
    """
    excluded = set(exclude)
    # Exclusions are applied before the demo cut, so demo mode always returns
    # n_demo usable patients when enough exist.
    patients = sorted(
        entry.name
        for entry in Path(patients_dir).iterdir()
        if entry.is_dir() and entry.name not in excluded
    )
    if demo:
        patients = patients[:n_demo]
    return patients

# Full run over every training patient (demo=True would truncate the list for a smoke test).
patients = get_patients(training_dir, demo=False)
# Fail fast on a missing/empty data directory instead of silently building an empty dataset.
assert patients, f'no patient directories found under {training_dir}'

In [None]:
def data_preparation(patients, labels_csv=None):
    """Build a tio.SubjectsDataset for the given patient IDs.

    Each subject holds:
    - the patient ID (``BraTS21ID``);
    - its ``MGMT_value`` label;
    - the four MRI series FLAIR, T1w, T1wCE and T2w, each loaded as a
      ``tio.ScalarImage`` from ``training_dir / <patient> / <series>``.

    Args:
        patients: iterable of patient-ID strings (directory names under
            ``training_dir``). Each must parse with ``int()`` and appear in the
            labels CSV index.
        labels_csv: path to the labels CSV. Defaults to
            ``data_dir / 'train_labels.csv'``.

    Returns:
        tio.SubjectsDataset that applies the preprocessing pipeline below to
        each subject when it is accessed.

    Raises:
        KeyError: if a patient ID is missing from the labels CSV.
    """
    if labels_csv is None:
        labels_csv = data_dir / 'train_labels.csv'
    # The first CSV column becomes the index (presumably the integer BraTS21ID),
    # hence the int() on the zero-padded directory name below.
    labels_df = pd.read_csv(labels_csv, index_col=0)

    subjects = []
    for patient in patients:
        # Public scalar accessor instead of the private DataFrame._get_value.
        label = labels_df.at[int(patient), 'MGMT_value']
        patient_dir = training_dir / patient
        subjects.append(tio.Subject(
            BraTS21ID=patient,
            MGMT_value=label,
            FLAIR=tio.ScalarImage(patient_dir / 'FLAIR'),
            T1w=tio.ScalarImage(patient_dir / 'T1w'),
            T1wCE=tio.ScalarImage(patient_dir / 'T1wCE'),
            T2w=tio.ScalarImage(patient_dir / 'T2w'),
        ))

    preprocessing_transforms = tio.Compose([
        # Reorient every image to canonical (RAS+) orientation.
        tio.ToCanonical(),
        # Resample to 1 mm isotropic spacing.
        tio.Resample(1, image_interpolation='bspline'),
        # Put all four series on the T1w image's grid so they align voxel-wise.
        # NOTE(review): nearest-neighbour interpolation on intensity images; confirm intended.
        tio.Resample('T1w', image_interpolation='nearest'),
        # Fixed output shape so every subject has the same spatial size.
        tio.CropOrPad((280, 280, 264)),
        # Alternatives tried earlier: tio.RescaleIntensity((-1, 1)),
        # tio.CropOrPad((128, 128, 64)), tio.OneHot().
    ])

    # Create the dataset from the subjects; transforms run per subject on access.
    dataset = tio.SubjectsDataset(subjects, transform=preprocessing_transforms)
    print(f'patients :{len(dataset)}')
    return dataset

# Build one preprocessed dataset from all selected patients (no train/validation split here).
dataset = data_preparation(patients=patients)

In [None]:
# create png
def preprocess_dataset(dataset, out_dir, parallel=True, num_workers=None):
    """Render every subject of ``dataset`` to a PNG, grouped by MGMT label.

    Files are written to ``out_dir / <MGMT_value> / <BraTS21ID>.png``.

    Args:
        dataset: indexable collection of tio.Subject (e.g. the SubjectsDataset
            from data_preparation) whose subjects carry 'BraTS21ID' and
            'MGMT_value'.
        out_dir: root output directory (str or Path); class folders are
            created as needed.
        parallel: if True, load and preprocess subjects in DataLoader worker
            processes; otherwise iterate ``dataset`` directly.
        num_workers: worker count when ``parallel``. Defaults to mp.cpu_count().

    Raises:
        ValueError: if a subject's MGMT_value is not 0 or 1.
    """
    out_dir = Path(out_dir)
    if parallel:
        iterable = DataLoader(
            dataset,
            num_workers=mp.cpu_count() if num_workers is None else num_workers,
            # batch_size defaults to 1, so each batch is a one-element list;
            # unwrap it to get the tio.Subject back.
            # NOTE(review): a lambda cannot be pickled under the 'spawn' start
            # method (macOS/Windows default); it works with Linux 'fork'.
            collate_fn=lambda batch: batch[0],
        )
    else:
        iterable = dataset

    for subject in tqdm(iterable):
        label = subject["MGMT_value"]
        # Reject unexpected labels explicitly rather than writing into an
        # undefined or stale class folder.
        if label not in (0, 1):
            raise ValueError(
                f'unexpected MGMT_value {label!r} for patient {subject["BraTS21ID"]}'
            )
        class_dir = out_dir / str(int(label))
        class_dir.mkdir(parents=True, exist_ok=True)
        filename = class_dir / f'{subject["BraTS21ID"]}.png'
        subject.plot(reorient=False, output_path=filename, show=False)
        # subject.plot(show=False) may leave its figure open; close it so
        # hundreds of subjects don't accumulate figures in memory.
        plt.close('all')

# Render every preprocessed subject as out_dir/<MGMT_value>/<BraTS21ID>.png
preprocess_dataset(dataset=dataset, out_dir=out_dir)
