## References

- https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling
- https://www.kaggle.com/ayuraj/brain-tumor-eda-and-interactive-viz-with-w-b
- https://www.kaggle.com/mikecho/rsna-miccai-monai-ensemble?scriptVersionId=74508923

In [None]:
import os
import sys 
import json
from glob import glob
import random
import collections
import time
import re

import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

In [None]:
# Root folder of the competition data; every other path is built from it.
DATA_PATH = "../input/rsna-miccai-brain-tumor-radiogenomic-classification"
# Sub-folder names searched inside each patient directory.
MRI_TYPES = ["FLAIR", "T1w", "T1wCE", "T2w"]
SIZE = 256  # default `img_size` of the DICOM loaders
NUM_IMAGES = 64  # default `num_imgs` of load_dicom_images_3d
SEED = 42  # used for set_seed and for the train/validation split

In [None]:
# Build the CSV path from DATA_PATH so the dataset location is defined in one place only.
train_df = pd.read_csv(os.path.join(DATA_PATH, 'train_labels.csv'))
train_df.head()

In [None]:
# List every entry directly under the train/ and test/ folders
# (get_dicom_data treats each such entry as one patient directory).
train_patients = glob(os.path.join(DATA_PATH, 'train/*'))
test_patients = glob(os.path.join(DATA_PATH, 'test/*'))

In [None]:
# Display one entry to inspect the path layout.
train_patients[0]

In [None]:
def natural_sort(l):
    """Return the strings of ``l`` sorted in natural (human) order.

    Runs of ASCII digits are compared by numeric value and all other text is
    compared case-insensitively, so ``'Image-2'`` sorts before ``'Image-10'``.

    Args:
        l: iterable of strings (here: file paths).

    Returns:
        A new sorted list; the input is not modified.
    """
    digit_runs = re.compile(r'([0-9]+)')

    def alphanum_key(key):
        # Splitting on a capturing group alternates text / digit-run pieces,
        # so every odd position is guaranteed to hold ASCII digits. Relying on
        # the position (instead of str.isdigit) keeps characters such as '²'
        # from being passed to int().
        return [
            int(piece) if position % 2 else piece.lower()
            for position, piece in enumerate(digit_runs.split(key))
        ]

    return sorted(l, key=alphanum_key)

# key: patient id , values: image paths for 4 types (dictionary, key: MRI type, value: image paths)
def get_dicom_data(split='train'):
    '''
    Collect the DICOM file paths of every patient in one dataset split.

    split: 'train' or 'test' (anything else fails the assertion below).

    Returns a dictionary keyed by patient id (the patient folder name), whose
    values map each MRI type to its naturally sorted list of file paths:

    dicoms = {
        '00688' (patient_id) :
            {
                'FLAIR' : [image paths ...]
                'T1w' : [image paths ...]
                'T1wCE' : [image paths ...]
                'T2w' : [image paths ...]
            }
        ...
    }

    An MRI type without files for a patient maps to an empty list.
    '''

    assert split == 'train' or split == 'test'

    dicoms = {}

    for patient in glob(os.path.join(DATA_PATH, f'{split}/*')):
        # basename copes with whichever path separator glob returns on this
        # platform, unlike splitting on a literal '/'.
        patient_id = os.path.basename(patient)

        dicoms[patient_id] = {
            t: natural_sort(glob(os.path.join(patient, f'{t}/*')))
            for t in MRI_TYPES
        }

    return dicoms

# Index the DICOM paths of both splits once; later cells look patients up in these dictionaries.
train_dicoms = get_dicom_data('train')
test_dicoms = get_dicom_data('test')

In [None]:
# Take the first patient id in the dictionary and print how many files each MRI type has.
sample_patient = list(train_dicoms.keys())[0]
print(f'patient id - {sample_patient}')
for i, v in train_dicoms[sample_patient].items():
    print(f'{i} : {len(v)}')

## Load 2D images

In [None]:
# Original from: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way
def read_mri(path, voi_lut=True, fix_monochrome=True):
    """Read one DICOM file and return its pixels as a uint8 image scaled to 0-255.

    Args:
        path: path of the DICOM file.
        voi_lut: when True, pass the raw pixels through ``apply_voi_lut``
            (the VOI LUT, if the device stored one, maps raw values to a
            "human-friendly" view).
        fix_monochrome: when True, invert images whose
            PhotometricInterpretation is "MONOCHROME1".

    Returns:
        ``np.uint8`` array; a constant image comes back as all zeros.
    """
    # dcmread is the supported reader; read_file was only a deprecated alias
    # and no longer exists in pydicom 3.0.
    dicom = pydicom.dcmread(path)

    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    # MONOCHROME1 stores intensities inverted, so flip them around the maximum.
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data

    # Min-max normalise; skip the division for constant images to avoid 0/0.
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:
        data = data / peak
    data = (data * 255).astype(np.uint8)

    return data

In [None]:
def visualize_sample(patient_id, slice_ratio=0.5):
    """Plot one slice of every MRI type of a training patient side by side.

    Args:
        patient_id: key into ``train_dicoms`` (zero-padded id string).
        slice_ratio: relative position of the slice inside each series;
            the resulting index is clamped to the valid range.

    The figure title shows the patient's ``MGMT_value`` from ``train_df``.
    A series without files gets an empty, titled panel.
    """
    dicoms = train_dicoms[patient_id]

    plt.figure(figsize=(16, 5))

    for i, t in enumerate(MRI_TYPES, 1):
        paths = dicoms[t]
        plt.subplot(1, len(MRI_TYPES), i)
        if paths:
            # Clamp: small ratios used to yield -1 (silently showing the LAST
            # slice) and ratios above 1 ran past the end of the list.
            slice_idx = int(len(paths) * slice_ratio) - 1
            slice_idx = min(max(slice_idx, 0), len(paths) - 1)
            plt.imshow(read_mri(paths[slice_idx]), cmap="gray")
        plt.title(f"{t}", fontsize=16)
        plt.axis("off")

    mgmt_value = train_df[train_df.BraTS21ID == int(patient_id)].MGMT_value.item()
    plt.suptitle(f"MGMT_value: {mgmt_value}", fontsize=16)
    plt.show()
    
# Show the default (middle) slice of each MRI type for the sample patient.
visualize_sample(sample_patient)

## Load 3D images

In [None]:
from matplotlib import animation, rc
# Set matplotlib's `animation.html` option to 'jshtml' so animations are shown as HTML/JS in the notebook.
rc('animation', html='jshtml')


def create_animation(ims):
    """Build a matplotlib animation that cycles through the frames in ``ims``.

    ``ims`` must be an indexable, non-empty sequence of 2D images; the first
    frame is drawn immediately and the others replace it at 24 frames/second
    (integer millisecond interval).
    """
    figure = plt.figure(figsize=(6, 6))
    plt.axis('off')
    artist = plt.imshow(ims[0], cmap="gray")

    def _show_frame(frame_idx):
        # Swap the pixel data of the existing image artist instead of redrawing.
        artist.set_array(ims[frame_idx])
        return [artist]

    frame_delay_ms = 1000 // 24
    return animation.FuncAnimation(
        figure,
        _show_frame,
        frames=len(ims),
        interval=frame_delay_ms,
    )

def load_dicom_line(t_paths):
    """Decode every DICOM path in ``t_paths`` with read_mri, dropping blank frames.

    A frame is considered blank when its maximum pixel value is 0.
    Returns the remaining frames as a list, in input order.
    """
    decoded = (read_mri(path) for path in t_paths)
    return [frame for frame in decoded if frame.max() != 0]

# Animate all non-blank FLAIR slices of the sample patient.
images = load_dicom_line(train_dicoms[sample_patient]['FLAIR'])
create_animation(images)

In [None]:
def load_dicom_image(path, img_size=SIZE, voi_lut=True, rotate=0):
    """Read one DICOM slice and resize it to ``img_size`` x ``img_size``.

    Args:
        path: path of the DICOM file.
        img_size: edge length of the square output.
        voi_lut: when True, apply the VOI LUT via ``apply_voi_lut``.
        rotate: 0 = no rotation, 1 = 90 degrees clockwise,
            2 = 90 degrees counter-clockwise, 3 = 180 degrees.

    Returns:
        The resized pixel array (values are NOT normalised here), or an
        all-zero ``(img_size, img_size)`` array when the slice is constant.
    """
    # dcmread is the supported reader; read_file was only a deprecated alias
    # and no longer exists in pydicom 3.0.
    dicom = pydicom.dcmread(path)
    data = apply_voi_lut(dicom.pixel_array, dicom) if voi_lut else dicom.pixel_array

    # A constant slice carries no information: return zeros so callers can
    # detect and skip it through `data.max() == 0`.
    if np.min(data) == np.max(data):
        return np.zeros((img_size, img_size))

    if rotate > 0:
        # Index 0 is only a placeholder; it is never used because rotate > 0 here.
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])

    return cv2.resize(data, (img_size, img_size))

def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train", rotate=0):
    """Load one MRI series as a min-max normalised 3D volume.

    Args:
        scan_id: patient id key (zero-padded string) of the DICOM dictionary.
        num_imgs: number of slices to sample evenly from the non-blank
            slices; ``0`` (or less) keeps every non-blank slice.
        img_size: edge length each slice is resized to.
        mri_type: which series to load (one of ``MRI_TYPES``).
        split: ``"test"`` reads from ``test_dicoms``; any other value reads
            from ``train_dicoms``.
        rotate: rotation code forwarded to ``load_dicom_image`` (0-3).

    Returns:
        Array of shape ``(1, img_size, img_size, n_slices)`` scaled to [0, 1].
        Note that ``.T`` also swaps the two in-plane axes.

    Raises:
        AssertionError: fewer non-blank slices than ``num_imgs``.
        ValueError: no non-blank slice at all when ``num_imgs <= 0``.
    """
    # Honour `split`; it used to be ignored, so test scans could not be loaded.
    dicoms = test_dicoms if split == "test" else train_dicoms
    files = dicoms[scan_id][mri_type]

    images = []
    for filename in files:
        # Forward img_size / rotate; previously both were silently dropped.
        data = load_dicom_image(filename, img_size=img_size, rotate=rotate)
        if data.max() == 0:
            # blank (or constant) slice
            continue
        images.append(data)

    if num_imgs > 0:
        assert len(images) >= num_imgs, f'len(images)({len(images)}) is less than num_imgs({num_imgs})'

        # Pick num_imgs indices spread evenly over the available slices.
        every_nth = len(images) / num_imgs
        indexes = [min(int(round(i*every_nth)), len(images)-1) for i in range(0,num_imgs)]
        selected_images = [images[i] for i in indexes]
    else:
        if not images:
            raise ValueError(f'no non-blank slices for scan {scan_id} ({mri_type})')
        selected_images = images

    img3d = np.stack(selected_images).T

    img3d = img3d - np.min(img3d)
    peak = np.max(img3d)
    if peak != 0:
        img3d = img3d / peak

    return np.expand_dims(img3d,0)

# Smoke-test the loader on the sample patient: print the shape and basic intensity statistics.
b = load_dicom_images_3d(sample_patient)
print(b.shape)
print(np.min(b), np.max(b), np.mean(b), np.median(b))

In [None]:
# Convert every slice of the volume to an 8-bit frame. The frame count comes
# from the array itself so the cell stays correct if `b` was loaded with a
# different `num_imgs`.
mri_anim = [(b[0,:,:,i] * 255).astype(np.uint8) for i in range(b.shape[-1])]
create_animation(mri_anim)

In [None]:
from IPython.display import clear_output
import time

# Simple in-cell playback: draw a frame, pause briefly, then clear the output.
# wait=True defers the clearing until the next frame is ready.
for mri_image in mri_anim:
    plt.imshow(mri_image, cmap='gray')
    plt.show()
    time.sleep(0.01)
    clear_output(wait=True)

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch random number generators.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (and to all CUDA devices when CUDA is available).

    Also exports ``PYTHONHASHSEED``. NOTE(review): setting it at runtime most
    likely only affects child processes, not the running interpreter.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Autotuning may pick different kernels between runs, so disable it
        # together with enabling deterministic algorithms.
        torch.backends.cudnn.benchmark = False

# Seed all random number generators with the notebook-wide SEED.
set_seed(SEED)

## Create GIFs with imageio and log them to Weights & Biases (W&B)

In [None]:
def convert_3d_to_list(arr):
    """Split a 4D volume into a list of 8-bit 2D frames.

    ``arr`` is indexed as ``arr[0, :, :, i]``; each such plane is multiplied
    by 255 and cast to ``np.uint8``. One frame is produced per index of the
    last (fourth) axis.
    """
    frames = []
    for frame_idx in range(arr.shape[3]):
        plane = arr[0, :, :, frame_idx]
        frames.append((plane * 255).astype(np.uint8))
    return frames

In [None]:
import imageio
import wandb
# Authenticate with Weights & Biases before any run is created.
wandb.login()

In [None]:
# Settings attached to the W&B run; IMG_SIZE is also passed as `img_size` when building the GIFs.
CONFIG = dict(
    IMG_SIZE=224,
    NUM_FRAMES=14,
    competition='rsna-miccai-brain',
)

In [None]:
# Start a W&B run for the dataset visualisation; CONFIG is stored as the run's config.
run = wandb.init(
    entity = 'monet-kaggle',
    project='brain-tumor-viz',
    config=CONFIG,
    job_type='vis-dataset-tables')

In [None]:
# Keep every training patient except the three ids that are skipped throughout this notebook.
patient_ids = [
    pid for pid in train_dicoms
    if pid not in ['00109', '00123', '00709']
]
len(patient_ids)

In [None]:
# Build a W&B table with one row per patient: id, label and one GIF per MRI type.
# Column names are derived from MRI_TYPES (the hand-written list had the typos
# 'patent_id' and 't1wCE').
data_at = wandb.Table(columns=['patient_id', 'target', *MRI_TYPES])

# The output folder only needs to be created once, not on every iteration.
os.makedirs('tables-gif/', exist_ok=True)

for patient_id in patient_ids:
    mgmt_value = train_df[train_df.BraTS21ID == int(patient_id)].MGMT_value.item()

    gif_paths = []
    for j, mri_type in enumerate(MRI_TYPES):
        # num_imgs=0 keeps every non-blank slice of the series.
        arr_3d = load_dicom_images_3d(patient_id,
                                      num_imgs=0,
                                      img_size=CONFIG['IMG_SIZE'],
                                      mri_type=mri_type,
                                      split="train")
        frames = convert_3d_to_list(arr_3d)
        gif_path = f'tables-gif/out-{patient_id}-{j}.gif'
        imageio.mimsave(gif_path, frames)
        gif_paths.append(gif_path)

    # One image cell per MRI type, in MRI_TYPES order to match the columns.
    data_at.add_data(
        patient_id,
        mgmt_value,
        *(wandb.Image(gif_path) for gif_path in gif_paths),
    )

wandb.log({'MRI Sequencing Dataset' : data_at})
wandb.finish()

## Train / validation split

In [None]:
# because of some missing MRIs, 3 samples could be excluded
samples_to_exclude = [109, 123, 709]

# Build the CSV path from DATA_PATH so the dataset location is defined in one place only.
train_df = pd.read_csv(os.path.join(DATA_PATH, 'train_labels.csv'))
print("original shape", train_df.shape)
train_df = train_df[~train_df.BraTS21ID.isin(samples_to_exclude)]
print("new shape", train_df.shape)
display(train_df)

# 80/20 split, stratified on the label so both parts keep the same class balance.
df_train, df_valid = sk_model_selection.train_test_split(
    train_df,
    test_size=0.2,
    random_state=SEED,
    stratify=train_df["MGMT_value"],
)

In [None]:
# Display the last rows of the training part of the split.
df_train.tail()

## Model and training classes

In [None]:
class BrainDataset(torch_data.Dataset):
    """Dataset that yields one 3D MRI volume per scan id.

    Args:
        paths: indexable of scan ids; each id is turned into the patient key
            with ``str(id).zfill(5)``.
        targets: indexable of labels aligned with ``paths``, or ``None`` for
            inference (items then carry ``"id"`` instead of ``"y"``).
        mri_type: indexable holding the MRI type name of every sample
            (it is indexed with the sample index).
        label_smoothing: the label becomes ``abs(target - label_smoothing)``.
        split: forwarded to ``load_dicom_images_3d`` only when ``targets`` is
            ``None``; labelled items always load with ``split="train"``.
        augment: when True, a random rotation code in ``[0, 4)`` is drawn for
            each labelled item and forwarded as ``rotate=``.
    """

    def __init__(self, paths, targets=None, mri_type=None, label_smoothing=0.01, split="train", augment=False):
        # Sample identity and labels.
        self.paths, self.targets = paths, targets
        # Loading options.
        self.mri_type, self.split = mri_type, split
        # Training-time tweaks.
        self.label_smoothing, self.augment = label_smoothing, augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]

        # Inference: no label available, return the scan id with the volume.
        if self.targets is None:
            volume = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split)
            return {"X": torch.tensor(volume).float(), "id": scan_id}

        # Labelled sample: optionally pick one of four rotation codes.
        rotation = np.random.randint(0,4) if self.augment else 0
        volume = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train", rotate=rotation)

        smoothed = abs(self.targets[index]-self.label_smoothing)
        y = torch.tensor(smoothed, dtype=torch.float)
        return {"X": torch.tensor(volume).float(), "y": y}
