# RSNA-MICCAI Brain Tumor Radiogenomic Classification

In [None]:
import os
import json
import glob
import random
import collections

import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import random
from skimage.transform import resize


# Training labels; the BraTS21ID and MGMT_value columns are used further down.
train_df = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")
# The sample submission doubles as the list of test patient ids (its BraTS21ID column).
test_df = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv')


# Image types (sub-folder names) loaded for every patient; only FLAIR is used here.
TYPES = ["FLAIR",]
# Grey level a pixel must exceed to count as "white" when measuring slice area.
WHITE_THRESHOLD = 10 # out of 255
# Patient ids skipped when the dataset is assembled.
# NOTE(review): presumably the cases reported as faulty by the competition hosts -- confirm.
EXCLUDE = [109, 123, 709]

In [None]:
def load_dicom(path, size = 64):
    '''
    Reads a DICOM image, divides it by its maximum so that the brightest pixel maps to 255,
    converts it to uint8 and resizes it to `size` x `size` pixels.

    path: location of the DICOM file.
    size: edge length, in pixels, of the square image that is returned.

    Returns the resized uint8 image. An all-zero image is returned as all zeros.

    The minimum is not subtracted, so the scaling assumes non-negative pixel data.
    Not super sure if this kind of scaling is appropriate, but everyone seems to do it.
    '''
    # dcmread is the supported reader; the old read_file alias was deprecated and
    # is no longer available in current pydicom releases.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    peak = np.max(data)
    # Skip the division for a blank slice to avoid dividing by zero.
    if peak != 0:
        data = data / peak
    data = (data * 255).astype(np.uint8)
    return cv2.resize(data, (size, size))


def get_all_image_paths(brats21id, image_type, folder='train'):
    '''
    Returns the paths of all images of one type for one patient ID, ordered by the
    number that ends each file name.

    brats21id: patient id; it is zero-padded to five characters to form the folder name.
    image_type: one of the entries of TYPES (the sub-folder to list).
    folder: dataset split folder to look in.
    '''
    assert image_type in TYPES

    def trailing_number(path):
        # Drop the last four characters (the extension) and read the integer
        # that follows the final "-" of the path.
        stem = path[:-4]
        return int(stem.split("-")[-1])

    split_root = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/%s/" % folder
    patient_dir = os.path.join(split_root, str(brats21id).zfill(5))
    matches = glob.glob(os.path.join(patient_dir, image_type, "*"))
    matches.sort(key=trailing_number)
    return matches
    
def get_white_area(image, threshold):
    '''
    Computes the fraction of pixels whose greyscale value is strictly greater than threshold.

    image: array whose first two dimensions are the image height and width.
    threshold: grey level a pixel has to exceed to be counted.
    '''
    height, width = image.shape[0], image.shape[1]
    bright_pixels = (image > threshold).sum()
    return bright_pixels / height / width

def filter_images(images, quantiles):
    '''
    Picks the indices of the images to keep, based on the quantiles given.

    For example, if quantiles = [0.33, 0.66], then we look for the images with 0.33 and 0.66
    of the maximum white area.

    Note that the images get bigger (up to the maximum) and then get smaller. Thus, we would
    pick 5 images corresponding to [0.33, 0.66, 1, 0.66, 0.33].

    images: sequence of greyscale images in acquisition order.
    quantiles: fractions of the maximum white area; the argument itself is not modified.

    Returns a list of 2 * len(quantiles) + 1 indices into `images`. The index of the image
    with the largest white area is always at position len(quantiles). When one side does
    not contain enough matching images, the rising side is padded with the index of the
    largest image and the falling side with the index of the last image.

    Raises ValueError if `images` is empty.
    '''
    # Work on a sorted copy so the caller's list keeps its order.
    quantiles = sorted(quantiles)
    white_areas = [get_white_area(im, WHITE_THRESHOLD) for im in images]
    if not white_areas:
        raise ValueError("filter_images() needs at least one image")
    max_white_area = max(white_areas)

    middle_index = white_areas.index(max_white_area)

    # Rising side: walk towards the largest image and take, for each quantile in
    # ascending order, the first image that reaches it.
    rising = []
    for i in range(middle_index):
        if len(rising) == len(quantiles):
            break
        if white_areas[i] >= quantiles[len(rising)] * max_white_area:
            rising.append(i)
    # The quantiles that were not reached are the ones closest to the maximum,
    # so the largest image stands in for them.
    rising += [middle_index] * (len(quantiles) - len(rising))

    # Falling side: walk away from the largest image and take, for each quantile
    # in descending order, the first image that has dropped to it.
    descending = quantiles[::-1]
    falling = []
    for i in range(middle_index, len(white_areas)):
        if len(falling) == len(descending):
            break
        if white_areas[i] <= descending[len(falling)] * max_white_area:
            falling.append(i)
    # The series ended before shrinking far enough: repeat the last image.
    falling += [len(white_areas) - 1] * (len(quantiles) - len(falling))

    return rising + [middle_index] + falling

# Plotting Functions

def plot_image_white(brats21id, image_type):
    '''
    Plots the white area of every image of one type for one patient against the image number.

    Returns the largest white area found.
    '''
    slices = [load_dicom(p) for p in get_all_image_paths(brats21id, image_type)]
    plt.xlabel('Image Number')
    plt.ylabel('White Area')
    white_fractions = [get_white_area(s, WHITE_THRESHOLD) for s in slices]
    plt.plot(range(len(slices)), white_fractions)
    plt.show()
    return max(white_fractions)

def plot_images_at_quantiles(images, quantile_labels):
    '''
    Shows the images side by side in one row, each titled with its quantile label.

    images: the images to draw.
    quantile_labels: one label per image, in the same order.
    '''
    panel_count = len(images)
    assert panel_count == len(quantile_labels)

    figure_width = (30 // len(quantile_labels)) * panel_count
    plt.figure(figsize=(figure_width, 10))

    for position, (picture, label) in enumerate(zip(images, quantile_labels), start=1):
        plt.subplot(1, panel_count, position)
        plt.imshow(picture, cmap="gray")
        plt.title(f"{label}", fontsize=16)
        plt.axis("off")
    plt.show()

def center_images_for_patient(images, quantiles, size):
    '''
    Crops every image to the bounding box of the main image and resizes the crops
    to `size` x `size` pixels.

    images: 2-D greyscale images laid out as described in filter_images, i.e. the image
            with the most white space sits at position len(quantiles).
    quantiles: the quantiles that were used to select `images`.
    size: edge length, in pixels, of the returned square images.

    Returns a list with one resized crop per input image; pixel values keep the
    range of the input images.
    '''
    #select the image with the most white space
    image_main = images[len(quantiles)]

    #find bounding box of the image: first/last row and column with a non-zero pixel
    rows = np.flatnonzero(np.sum(image_main, axis=1) > 0)
    cols = np.flatnonzero(np.sum(image_main, axis=0) > 0)
    if rows.size and cols.size:
        # +1 because the slice end is exclusive and the last non-zero row/column belongs to the box
        y1, y2 = rows[0], rows[-1] + 1
        x1, x2 = cols[0], cols[-1] + 1
    else:
        # A completely black main image has no bounding box; keep the whole frame.
        y1, x1 = 0, 0
        y2, x2 = image_main.shape[0], image_main.shape[1]

    return [
        resize(im[y1:y2, x1:x2], (size, size), anti_aliasing=True, preserve_range=True)
        for im in images
    ]

In [None]:
# image_paths = np.array(get_all_image_paths(100, 'FLAIR', 'train'))
# images = np.array([load_dicom(im) for im in image_paths])

# quantiles = [0.33, 0.67]
# quantiles.sort()
# quantile_labels = quantiles + [1] + quantiles[::-1]

# images_filtered = images[filter_images(images, quantiles)]

# plot_images_at_quantiles(images_filtered, quantile_labels)

In [None]:
def filter_images_for_patient(brats21id, quantiles, folder='train'):
    '''
    Loads, selects, crops and resizes the images of one patient.

    For every type in TYPES the patient's images are loaded, thinned out, reduced to the
    ones chosen by filter_images, cropped to the bounding box of the main image and
    resized to 64 x 64.

    brats21id: patient id.
    quantiles: white-area quantiles handed to filter_images.
    folder: dataset split folder to read from.

    Returns a dict mapping each type to a list of 64 x 64 float images whose values keep
    the 0..255 range produced by load_dicom.
    '''
    images_filtered = {}
    for t in TYPES:
        image_paths = np.array(get_all_image_paths(brats21id, t, folder))

        # Long series are thinned with an integer step so that fewer than 40 images get loaded.
        if len(image_paths) > 20:
            image_paths = image_paths[:: len(image_paths) // 20]
        images = np.array([load_dicom(im) for im in image_paths])
        images = images[filter_images(images, quantiles)]

        # filter_images documents the largest-area image as the middle entry of its selection.
        image_main = images[len(quantiles)]

        # Bounding box of the main image: first/last row and column with a non-zero pixel.
        rows = np.flatnonzero(np.sum(image_main, axis=1) > 0)
        cols = np.flatnonzero(np.sum(image_main, axis=0) > 0)
        if rows.size and cols.size:
            # +1 because the slice end is exclusive and the last non-zero row/column belongs to the box
            y1, y2 = rows[0], rows[-1] + 1
            x1, x2 = cols[0], cols[-1] + 1
        else:
            # A completely black main image has no bounding box; keep the whole frame.
            y1, x1 = 0, 0
            y2, x2 = image_main.shape[0], image_main.shape[1]

        # preserve_range keeps the 0..255 scale; by default skimage rescales uint8
        # input to 0..1, and the model divides by 255 again afterwards.
        images_filtered[t] = [
            resize(im[y1:y2, x1:x2], (64, 64), anti_aliasing=True, preserve_range=True)
            for im in images
        ]

    return images_filtered

In [None]:
def get_data_for_patients(patient_ids, folder='train'):
    '''
    Builds the filtered image dictionaries for a collection of patients.

    patient_ids: patient ids to process; ids listed in EXCLUDE are skipped.
    folder: dataset split folder to read from.

    Returns a 1-D object array holding one filter_images_for_patient result per
    processed patient, in the order of `patient_ids`.
    '''
    # Collect in a list: np.append would copy the whole array on every iteration.
    per_patient = []
    for patient_id in patient_ids:
        if patient_id in EXCLUDE:
            continue
        per_patient.append(filter_images_for_patient(patient_id, [0.33, 0.66], folder))
        count = len(per_patient)
        # Report progress after every 100 processed patients.
        if count % 100 == 0:
            print('Done with %d out of %d' % (count, len(patient_ids)))

    return np.array(per_patient, dtype=object)

def flatten_data_for_individual(data):
    '''
    Concatenates the image lists of all TYPES of one patient into a single array.

    data: dict mapping each entry of TYPES to a list of images.
    '''
    stacked = []
    for image_type in TYPES:
        stacked.extend(data[image_type])
    return np.array(stacked)

def format_data_for_keras(all_data):
    '''
    Stacks the flattened data of every patient and moves axis 1 (the per-patient image
    axis) behind the two image axes, so the images end up as the last axis.

    all_data: iterable of per-patient dicts as accepted by flatten_data_for_individual.
    '''
    per_patient = [flatten_data_for_individual(data) for data in all_data]
    images_first = np.array(per_patient)
    return np.moveaxis(images_first, 1, 3)



In [None]:
# Per-patient image dictionaries for the labelled patients (default 'train' folder) ...
training_data_raw = get_data_for_patients(train_df.BraTS21ID)
# ... and for the patients listed in the sample submission, read from the 'test' folder.
testing_data_raw = get_data_for_patients(test_df.BraTS21ID, folder='test')

In [None]:
# Stack the per-patient images into the array layout that is fed to the model.
X = format_data_for_keras(training_data_raw)
# Drop the labels of the excluded patients so that y lines up with the rows of X
# (get_data_for_patients skips the same ids).
y = train_df.MGMT_value[~train_df.BraTS21ID.isin(EXCLUDE)]
X_test = format_data_for_keras(testing_data_raw)

# Notebook-style sanity check of the resulting shapes.
X.shape, X_test.shape, y.shape 


In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pydicom
import ast
import cv2
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Hold out 25% of the training patients for validation; the fixed seed keeps the split reproducible.
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.25, random_state=42)

In [None]:
# Small CNN: two convolution + max-pooling stages followed by one sigmoid output unit.
# The input shape is taken from the training array without its first (sample) axis.
inpt = keras.Input(shape=X_train.shape[1:])
# Divides the inputs by 255.
# NOTE(review): this assumes the arrays hold values in 0..255 -- confirm that the
# preprocessing upstream really produces that range.
h = keras.layers.experimental.preprocessing.Rescaling(1./255)(inpt)

# first convolutional layer: 32 filters of 4x4 with ReLU
h = keras.layers.Conv2D(32, kernel_size=(4, 4), activation="relu", name="Conv_1")(h) 
# pooling layer (default pooling window)
h = keras.layers.MaxPool2D()(h) 

# second convolutional layer: 32 filters of 4x4 with ReLU
h = keras.layers.Conv2D(32, kernel_size=(4, 4), activation="relu", name="Conv_2")(h) 
# pooling layer (default pooling window)
h = keras.layers.MaxPool2D()(h)

# Flatten the feature maps and map them to a single sigmoid output.
h = keras.layers.Flatten()(h)   
output = keras.layers.Dense(1, activation="sigmoid")(h)

model = keras.Model(inpt, output)

In [None]:
# Binary classification set-up: cross-entropy loss, Adam optimizer, accuracy reported per epoch.
model.compile(loss='binary_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])
# Train for 20 epochs while monitoring the held-out validation split.
history = model.fit(x=X_train, y = y_train, 
                    epochs=20,
                    validation_data= (X_valid, y_valid))
# Score the validation split and flatten the predictions to one value per sample.
y_pred = model.predict(X_valid)
y_pred = np.reshape(y_pred, (y_pred.shape[0], ))
# Bare expression left over from interactive use; it has no effect here.
y_pred.shape
# Validation ROC AUC.
print(roc_auc_score(y_valid, y_pred))

In [None]:
# Predict on the test images; the result is written to the submission file in the last cell.
predictions = model.predict(X_test)
print(predictions)

In [None]:
# Re-read the sample submission with the ids parsed as strings, fill in the model
# output and write the submission file.
submission = pd.read_csv(
    '../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv',
    converters={'BraTS21ID': str},
)
submission['MGMT_value'] = predictions

submission.to_csv('submission.csv', index=False)