## RSNA-MICCAI Brain Tumor Radiogenomic Classification - Inference Notebook

In this notebook, we use trained models to predict `MGMT_value` for each patient in the test set, averaging their predictions as a simple ensemble.

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import imageio
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import wandb
import re 
import glob
import cv2
from tqdm.notebook import tqdm
from pathlib import Path

# Deep learning packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import warnings
warnings.filterwarnings('ignore')
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

In [None]:
# Paths and hyper-parameters shared by the data pipeline and the inference loop.
config = {
  # Patient folders scanned by prepare_dataframe() in its default mode; in this
  # inference notebook it points at the competition's test split.
  'images_source_path' : '../input/rsna-miccai-brain-tumor-radiogenomic-classification/test',
  # Looked up by prepare_dataframe(mode='test'); without this key that call
  # raises KeyError.
  'test_images_source_path': '../input/rsna-miccai-brain-tumor-radiogenomic-classification/test',
  # Label CSV left-merged onto the patient table (on BraTS21ID).
  'csv_path': '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv',
  'data_path': '../input/rsna-miccai-brain-tumor-radiogenomic-classification',
  # Weight files are located by plain string concatenation
  # (f'{models_path}{arch}_3.h5'), so the trailing slash is required.
  'models_path': '../input/genetic-biomarker-prediction-with-crnn-train/',
  'output_path': './crnn/',
  # nfolds, global_seed, learning_rate, rnn_cells and num_epochs are not
  # referenced anywhere in this notebook (presumably kept from training).
  'nfolds': 3,
  'global_seed': 42,
  'batch_size': 4,
  # Number of slices per patient stack (centre window, zero-padded if short).
  'frames_per_seq': 24,
  # Side length each slice is resized to.
  'img_size': 224,
  'learning_rate': 0.0001,
  'rnn_cells': 16,
  'num_epochs': 10,
  # Channels per slice (grayscale replicated).
  'channels': 3,
  # Fraction of each dimension kept by the centre crop.
  'scale' : 0.75
}


# MRI modalities to run inference for; one model is loaded per modality.
mri_types = ['T2w']
# mri_types = ['FLAIR','T1w','T1wCE','T2w']

In [None]:
class BrainTumor_GeneticSequence():
    """Build a tf.data inference pipeline for a single MRI modality.

    For one ``mri_type`` (e.g. ``"T2w"``) this collects the patient folders
    under the configured source directory, turns each patient's DICOM series
    into a fixed-length stack of ``config['frames_per_seq']`` preprocessed
    slices, and runs a Keras model over the batched stacks.
    """
    # Class-level defaults; __init__ overrides both per instance.
    mri_type = "FLAIR"
    df_data = None
    # Read once, when the class is defined. Test patients are presumably absent
    # from this file, so their merged MGMT_value will be NaN (it is only used
    # as a placeholder label in load_frame) — confirm against the data.
    df_train_labels = pd.read_csv(config['csv_path'])

    # Per-modality file-count columns every patient row is created with.
    _COUNT_COLUMNS = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
    # Patient folder names that are always skipped (hard-coded exclusion list).
    _EXCLUDED_IDS = {"00109", "00123", "00709"}

    def __init__(self, mri_type):
        """Create an empty, int-typed patient table for ``mri_type``."""
        self.mri_type = mri_type
        self.df_data = pd.DataFrame(columns=['BraTS21ID'] + mri_types)
        for key in mri_types:
            self.df_data[key] = self.df_data[key].astype(int)
        self.df_data['BraTS21ID'] = self.df_data['BraTS21ID'].astype(int)

    def prepare_dataframe(self, mode='train'):
        """Populate ``self.df_data`` with one row per usable patient folder.

        A patient is kept only if ``<folder>/<mri_type>/`` contains at least one
        ``.dcm`` file; that count is stored in the ``mri_type`` column (the other
        modality columns are 0). ``folder_name``/``folder_path`` columns are then
        added and ``df_train_labels`` is left-merged on ``BraTS21ID``.

        Args:
            mode: ``'test'`` scans ``config['test_images_source_path']``; any
                other value scans ``config['images_source_path']``.
        """
        folders_path = "test_images_source_path" if mode == 'test' else "images_source_path"
        source_dir = config[folders_path]

        rows = []
        # os.listdir order is arbitrary; sort so row (and prediction) order is
        # deterministic across runs.
        for f in tqdm(sorted(os.listdir(source_dir + '/'))):
            if f in self._EXCLUDED_IDS:
                continue
            BraTS21ID = int(f)
            dcm_glob = f'{source_dir}/{format(BraTS21ID, "05d")}/{self.mri_type}/*.dcm'
            files_len = len(glob.glob(dcm_glob))
            # Patients without any DICOMs for this modality are dropped.
            if files_len == 0:
                continue
            row = {'BraTS21ID': BraTS21ID, **dict.fromkeys(self._COUNT_COLUMNS, 0)}
            row[self.mri_type] = files_len
            rows.append(row)

        # Build the table in one go: per-row DataFrame.append is quadratic and
        # was removed in pandas 2.0.
        if rows:
            new_rows = pd.DataFrame(rows)
            if self.df_data.empty:
                self.df_data = new_rows
            else:
                self.df_data = pd.concat([self.df_data, new_rows], ignore_index=True)

        self.df_data["folder_name"] = [format(x, '05d') for x in self.df_data["BraTS21ID"]]
        self.df_data["folder_path"] = [os.path.join(source_dir, x) for x in self.df_data["folder_name"]]
        self.df_data = pd.merge(self.df_data, self.df_train_labels, how='left', on='BraTS21ID')

    def rotate_image(self, image, angle):
        """Rotate ``image`` by ``angle`` degrees about its centre, keeping its size."""
        image_center = tuple(np.array(image.shape[1::-1]) / 2)
        rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
        result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=cv2.INTER_LINEAR)
        return result

    def normalize(self, image):
        """Min-max scale ``image`` to float32 values in [0, 1]."""
        result = cv2.normalize(image, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        return result

    def read_mri(self, path, voi_lut = True, fix_monochrome = True):
        """Load one DICOM slice as a float32 ``(img_size, img_size, channels)`` array.

        Based on https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way

        Args:
            path: path of the ``.dcm`` file.
            voi_lut: apply the DICOM VOI LUT / windowing to the pixel data.
            fix_monochrome: invert MONOCHROME1 images so higher means brighter.
        """
        dicom = pydicom.dcmread(path)  # read_file is a deprecated alias, removed in pydicom 3
        if voi_lut:
            data = apply_voi_lut(dicom.pixel_array, dicom)
        else:
            data = dicom.pixel_array
        if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
            data = np.amax(data) - data
        data = data - np.min(data)
        max_val = np.max(data)
        # Blank (constant) slices would otherwise divide 0 by 0 and yield NaNs.
        if max_val > 0:
            data = data / max_val
        data = (data * 255).astype(np.uint8)
        data = self.normalize(data)
        # NOTE(review): random rotation is applied at inference too, so repeated
        # runs give slightly different predictions.
        data = self.rotate_image(data, np.random.randint(0, 20))
        data = self.crop_center_square(data)
        data = cv2.resize(data, (config['img_size'], config['img_size']))
        # Replicate the single grayscale channel.
        data = np.repeat(data[..., np.newaxis], config['channels'], -1)
        return data

    def crop_center_square(self, frame, scale=config['scale']):
        """Keep the central ``scale`` fraction of each spatial dimension.

        The result is square only when the input is square.
        """
        y, x = frame.shape[0:2]
        center_x, center_y = x / 2, y / 2
        width_scaled, height_scaled = x * scale, y * scale
        left_x, right_x = center_x - width_scaled / 2, center_x + width_scaled / 2
        top_y, bottom_y = center_y - height_scaled / 2, center_y + height_scaled / 2
        return frame[int(top_y):int(bottom_y), int(left_x):int(right_x)]

    def get_img_path_3d(self, dir_path):
        """Return the centre window of a patient's series as a float32 stack.

        Called through ``tf.numpy_function``, so ``dir_path`` arrives as bytes.
        The window spans up to ``frames_per_seq`` naturally sorted slices around
        the middle of the series; short series are zero-padded at the end.

        Returns:
            np.ndarray of shape ``(frames_per_seq, img_size, img_size, channels)``.
        """
        modality_path = os.path.join(dir_path.decode('utf8'), self.mri_type)
        # Natural sort so 'Image-10.dcm' comes after 'Image-9.dcm'.
        files = sorted(glob.glob(f"{modality_path}/*.dcm"), key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
        mid_num = len(files) // 2
        half_window = config['frames_per_seq'] // 2
        start_idx = max(0, mid_num - half_window)
        end_idx = min(len(files), mid_num + half_window)

        frame_shape = (config['img_size'], config['img_size'], config['channels'])
        frames = [self.read_mri(p) for p in files[start_idx:end_idx]]
        if frames:
            mri_images = np.stack(frames).astype(np.float32)
        else:
            mri_images = np.zeros((0,) + frame_shape, dtype=np.float32)

        # Pad with blank frames up to the fixed sequence length.
        n_missing = config['frames_per_seq'] - mri_images.shape[0]
        if n_missing > 0:
            padding = np.zeros((n_missing,) + frame_shape, dtype=np.float32)
            mri_images = np.concatenate((mri_images, padding), axis=0)
        return mri_images

    def load_frame(self, df_dict):
        """Map one ``df_data`` row to a ``(frames, label)`` pair for tf.data."""
        frames = tf.numpy_function(self.get_img_path_3d, [df_dict['folder_path']], tf.float32)
        # numpy_function drops static shape; get_img_path_3d always returns this.
        frames.set_shape((config['frames_per_seq'], config['img_size'], config['img_size'], config['channels']))
        label = tf.cast(df_dict['MGMT_value'], tf.float32)
        return frames, label

    def get_loader(self):
        """Return a batched, prefetched ``tf.data.Dataset`` over ``df_data``."""
        AUTOTUNE = tf.data.AUTOTUNE
        testloader = tf.data.Dataset.from_tensor_slices(dict(self.df_data))
        testloader = (testloader
                    .map(self.load_frame, num_parallel_calls=AUTOTUNE)
                    .batch(config['batch_size'])
                    .prefetch(AUTOTUNE)
                    )
        return testloader

    def predict(self, model):
        """Run ``model.predict`` over every patient in ``df_data``, in row order."""
        proba = model.predict(self.get_loader(), verbose=1)
        return proba

In [None]:
# Submission skeleton: one row per test patient, the expected on-disk folder
# for each patient, and a zeroed accumulator for the ensemble predictions.
sample_submission_path = os.path.join(config['data_path'], 'sample_submission.csv')
sample_df = pd.read_csv(sample_submission_path)
test_df = sample_df.copy()
test_df["folder_name"] = [f"{patient_id:05d}" for patient_id in test_df.BraTS21ID]
test_df["folder_path"] = [
    os.path.join(config['data_path'], 'test', name) for name in test_df["folder_name"]
]
test_df['MGMT_pred'] = 0

# def plot_gifs(loader):
#     os.makedirs('gifs/', exist_ok=True)
#     frames = next(iter(loader))
#     for i, frame in enumerate(frames):
#         imageio.mimsave(f'gifs/out_{i}.gif', (frame).numpy().astype('uint8')) 
    
# dp = BrainTumor_GeneticSequence(mri_types[0])
# dp.prepare_dataframe()
# testloader = dp.get_loader()
# plot_gifs(testloader)

# model = tf.keras.models.load_model(f'./{model_arch}_1.h5')
# dp = BrainTumor_GeneticSequence(mri_types[0])
# dp.prepare_dataframe(mode='test')
# dp.df_data['MGMT_pred'] = 0
# proba = dp.predit(model)
# sample_submission_path = os.path.join(config['data_path'], 'sample_submission.csv')
# sample_df = pd.read_csv(sample_submission_path); 
# dp.df_data['MGMT_pred'] += proba.squeeze()    
# dp.df_data[['BraTS21ID','MGMT_pred']].head()

In [None]:
# Ensemble: for each modality in `mri_types`, run that modality's trained model
# over the test patients and average the per-patient probabilities.
model_arch = 'xception'
for m_type in mri_types:
    print(f'Predicting for....{m_type}')
    dp = BrainTumor_GeneticSequence(m_type)
    dp.prepare_dataframe()
    model_weight_path = f'{config["models_path"]}{model_arch}_3.h5'
    model = tf.keras.models.load_model(model_weight_path)
    proba = dp.predict(model)
    # dp.df_data follows the directory listing and may omit patients without
    # DICOMs for this modality, so align predictions by BraTS21ID rather than
    # by position. reshape(-1) also copes with a single-patient test set.
    pred_by_id = pd.Series(
        np.asarray(proba).reshape(-1),
        index=dp.df_data['BraTS21ID'].astype(int).to_numpy(),
    )
    # Patients with no prediction for this modality get a neutral 0.5.
    test_df['MGMT_pred'] += test_df['BraTS21ID'].map(pred_by_id).fillna(0.5)

test_df['MGMT_pred'] /= len(mri_types)
test_df['MGMT_pred'] = [round(x, 3) for x in test_df['MGMT_pred']]
sample_df['MGMT_value'] = test_df['MGMT_pred']
sample_df

In [None]:
# Write the submission file (BraTS21ID and MGMT_value columns of sample_df).
sample_df.to_csv(path_or_buf="submission.csv", index=False)

In [None]:
# Sanity check: distribution of the submitted probabilities.
sns.histplot(data=sample_df, x='MGMT_value')