# Intro

This notebook is an attempt to localize tumor areas on the FLAIR scan series and use only these areas to train the model. Localization is based on the pixel brightness level, since tumor areas are highlighted on FLAIR scans. That said, the notebook should be considered just a way to play with the data. <br>

One of the main tasks is to reduce the brightness of the brain edges, since those bright pixels affect the measurements. This can be done with tools from the `scipy.ndimage` library such as `gaussian_filter`, `convolve` or `grey_erosion`, which give some results. But we will try to implement a filter from scratch to make the brain boundaries less bright and highlight the tumor area even more. The brightest area will be located with the `center_of_mass` function from `scipy.ndimage`.

Some code and ideas from these works are used:
* [https://www.kaggle.com/sreevishnudamodaran/tpu-rsna-keras-3d-cnn-voxel-train](https://www.kaggle.com/sreevishnudamodaran/tpu-rsna-keras-3d-cnn-voxel-train)
* [https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling](https://www.kaggle.com/ihelon/brain-tumor-eda-with-animations-and-modeling)
* [https://www.kaggle.com/smoschou55/dicom-to-2d-resized-axial-pngs-256x256-x36](https://www.kaggle.com/smoschou55/dicom-to-2d-resized-axial-pngs-256x256-x36)



In [None]:
import numpy as np
import pandas as pd
import os
import glob
import random
import gc
from multiprocessing import Pool
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from IPython.display import HTML, Image
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import pydicom
import PIL 
from scipy import ndimage
import cv2
from sklearn.model_selection import KFold
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Paths
KAGGLE_DIR = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/'
IMG_PATH_TRAIN = KAGGLE_DIR + 'train/'
IMG_PATH_TEST = KAGGLE_DIR + 'test/'
TRAIN_CSV_PATH = KAGGLE_DIR + 'train_labels.csv'
TEST_CSV_PATH = KAGGLE_DIR + 'sample_submission.csv'

# Names of the case folders found in the train and test directories
train_images = os.listdir(IMG_PATH_TRAIN)
test_images = os.listdir(IMG_PATH_TEST)

train = pd.read_csv(TRAIN_CSV_PATH)
test = pd.read_csv(TEST_CSV_PATH)

# write patient id to df
# The id is derived from the row's own BraTS21ID instead of being assigned
# positionally from a sorted directory listing, so a row can never be paired
# with another case's folder. Case folders use 5-digit zero-padded names
# (e.g. "00109"); presumably these are the BraTS21ID values - the check below
# fails loudly if that assumption does not hold.
train['patient_id'] = train['BraTS21ID'].astype(str).str.zfill(5)
test['patient_id'] = test['BraTS21ID'].astype(str).str.zfill(5)

for split_name, split_df, split_folders in (('train', train, train_images),
                                            ('test', test, test_images)):
    missing = sorted(set(split_df['patient_id']) - set(split_folders))
    if missing:
        raise FileNotFoundError(
            "No {} folder for case(s): {}".format(split_name, ", ".join(missing[:10])))

# drop problem cases
# .copy() makes `train` an independent frame, so the column assignments done
# later do not trigger pandas' SettingWithCopyWarning.
problem_cases = ["00109", "00123", "00709"]
train = train[~train['patient_id'].isin(problem_cases)].copy()


In [None]:
# obtaining file paths and image characteristics

# One folder per MRI modality inside every case folder
mods = ['FLAIR', 'T1w', 'T1wCE', 'T2w']
for mod in mods:
    train[mod+'_path'] = IMG_PATH_TRAIN + train['patient_id'] + "/" + mod

# For every FLAIR series: number of files and the pixel shape of one of them
count = []
sizes = []
for path in train['FLAIR_path']:
    files = glob.glob(os.path.join(path, "*"))
    if not files:
        # Fail with the offending folder instead of an IndexError on files[0]
        raise FileNotFoundError("No files found in FLAIR folder: {}".format(path))
    count.append(len(files))
    # dcmread is the current name of the deprecated pydicom.read_file
    dicom = pydicom.dcmread(files[0])
    sizes.append(dicom.pixel_array.shape)

train['Flair_count'] = count
train['Resolution'] = [str(x[0]) + ' x ' + str(x[1]) for x in sizes]
train['Pixel_count'] = [x[0]*x[1] for x in sizes]
# Text label used as the colour category in the charts
train['Tumor_type'] = train['MGMT_value'].map(lambda x: 'MGMT: 1' if x == 1 else 'MGMT: 0')

train.head()

In [None]:
# Folder of every MRI modality for each test case
for mod in mods:
    test['{}_path'.format(mod)] = IMG_PATH_TEST + test['patient_id'] + "/" + mod
test.head()

In [None]:
# Settings shared by data loading, model construction and training
config = dict(
    depth=24,              # number of frames kept per volume
    img_size=96,           # height and width of every frame after resizing
    nfolds=4,              # number of cross-validation folds
    batch_size=16,
    learning_rate=0.0008,
    num_epochs=10,
)

def load_dicom(path):
    """Read one DICOM file and return its pixels min-max scaled to [0, 1].

    Args:
        path: path of the DICOM file.

    Returns:
        float32 array with the shape of the file's pixel array. A constant
        image (max equal to min) is returned as all zeros.
    """
    # dcmread is the current name of the deprecated pydicom.read_file
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:
        # Skip the division for constant images to avoid dividing by zero
        data = data / peak
    return data.astype(np.float32)

def load_dicom_line(path):
    """Load every file of a series folder as one stacked array.

    Files are ordered by the integer that follows the last "-" of their path
    once the final four characters (presumably the ".dcm" suffix) are removed.

    Args:
        path: folder holding the files of one series.

    Returns:
        Array built from the per-file results of `load_dicom`, in that order.
    """
    def slice_number(filename):
        return int(filename[:-4].split("-")[-1])

    ordered = sorted(glob.glob(os.path.join(path, "*")), key=slice_number)
    return np.array([load_dicom(filename) for filename in ordered])

def crop_image(image):
    """Remove empty planes of a 3D volume along each of its axes.

    A plane is kept when its mean is greater than zero. The last axis is
    trimmed first, then the middle one, then the first; each pass works on the
    result of the previous one.

    Args:
        image: 3D array.

    Returns:
        A copy of the volume restricted to the kept planes.
    """
    # (axis being trimmed, axes averaged to decide which planes to keep)
    passes = ((2, (0, 1)), (1, (0, 2)), (0, (1, 2)))
    for axis, reduced in passes:
        has_signal = image.mean(axis=reduced) > 0
        image = np.compress(has_signal, image, axis=axis)
    return image

def zoom_img(image, height, width, depth):
    """Resize a (depth, height, width) volume with linear interpolation.

    Args:
        image: 3D array ordered as (depth, height, width).
        height: requested size of axis 1.
        width: requested size of axis 2.
        depth: requested size of axis 0.

    Returns:
        The volume resampled by `scipy.ndimage.zoom` with `order=1`.
    """
    # Requested sizes listed in the array's own axis order
    requested = (depth, height, width)
    factors = tuple(1 / (image.shape[axis] / target)
                    for axis, target in enumerate(requested))
    return ndimage.zoom(image, factors, order=1)


# Exploratory Data Analysis

The graphs below show some scan statistics: 
* The count of pixels per scan frame (resolution);
* The total number of images per patient;
* The distribution of resolutions in the train folder.

In [None]:
# Colours shared by the charts that are split by tumor type
tumor_palette = ["crimson", "gray"]

# Pixels per frame for every case, coloured by resolution
fig = px.bar(
    train,
    y='Pixel_count',
    color='Resolution',
    title="The count of pixels per frame (resolution)",
)
fig.show()

# Number of FLAIR files for every case, coloured by tumor type
fig = px.bar(
    train,
    y='Flair_count',
    color='Tumor_type',
    title="The number of scans by the case",
    color_discrete_sequence=tumor_palette,
)
fig.show()

# How often each resolution occurs, split by tumor type
fig = px.histogram(
    train,
    y='Resolution',
    color='Tumor_type',
    title="The scans resolution counts by the tumor type",
    color_discrete_sequence=tumor_palette,
)
fig = fig.update_yaxes(categoryorder='total ascending', title='Scan resolution')

fig.show()

### Let's visualise projections of some cases and look at the brightness statistics along the axes. 
Here and below, the axes of the array and the dimensions are designated as follows: (0, 1, 2) or (z, y, x) or (depth, height, width). The stats include:
* Mean value of brightness along the axes;
* Third quartile (0.75 quantile) of brightness along the axes;
* Third quartile (0.75 quantile) of brightness along the axes for non-zero pixels only.

In [None]:
def _nonzero_q75(values):
    """Return the 0.75 quantile of the non-zero entries of `values` (NaN when there are none)."""
    nonzero = values[values != 0]
    return np.quantile(nonzero, 0.75) if nonzero.size else np.nan


def plot_case_stats(path):
    """Plot two projections of a series with brightness statistics per axis.

    The series is loaded, cropped and resized to 150 x 150 x 150. For each
    axis the mean, the 0.75 quantile and the 0.75 quantile of the non-zero
    pixels are drawn next to the summed projection of the volume.

    Args:
        path: series folder; its second-to-last "/"-separated component is
            shown in the title as the case name.
    """
    images = load_dicom_line(path)
    images = crop_image(images)
    images = zoom_img(images, 150, 150, 150)

    # np.trim_zeros only strips leading/trailing zeros of a flattened plane,
    # so a boolean mask is used to really restrict to non-zero pixels.

    # statistics per position along the width (axis 2)
    x_mean = np.mean(images, axis=(1, 0))
    x_q75 = np.quantile(images, 0.75, axis=(1, 0))
    nz_x_q75 = [_nonzero_q75(images[:, :, i]) for i in range(images.shape[2])]

    # statistics per position along the height (axis 1)
    y_mean = np.mean(images, axis=(2, 0))
    y_q75 = np.quantile(images, 0.75, axis=(2, 0))
    nz_y_q75 = [_nonzero_q75(images[:, i, :]) for i in range(images.shape[1])]

    # statistics per position along the depth (axis 0)
    z_mean = np.mean(images, axis=(1, 2))
    z_q75 = np.quantile(images, 0.75, axis=(1, 2))
    nz_z_q75 = [_nonzero_q75(frame) for frame in images]

    def add_height_traces(fig):
        # Height statistics go on the x-axis of the third panel; the y-axis is
        # reversed so that it runs top-down like the image rows.
        fig.add_trace(go.Scatter(x=y_mean, name='y_Mean'), row=1, col=3)
        fig.add_trace(go.Scatter(x=y_q75, name='y_q75'), row=1, col=3)
        fig.add_trace(go.Scatter(x=nz_y_q75, name='nonzero_y_q75'), row=1, col=3)
        fig.update_yaxes(row=1, col=3, autorange='reversed')

    # plot Width and Height projections
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Width (X-axis)",
                "Axes (X,Y) projection", "Height (Y-axis)"))

    fig.add_trace(go.Scatter(y=x_mean, name='x_Mean'), row=1, col=1)
    fig.add_trace(go.Scatter(y=x_q75, name='x_q75'), row=1, col=1)
    fig.add_trace(go.Scatter(y=nz_x_q75, name='nonzero_x_q75'), row=1, col=1)

    fig.add_trace(px.imshow(np.sum(images, axis=0), binary_string=True).data[0], row=1, col=2)

    add_height_traces(fig)

    fig.update_layout(height=300, margin=dict(l=5, r=5, t=70, b=25),
                      title_text='Projections and pixel brightness statistics for the case: {}'.format(path.split("/")[-2]))
    fig.show()

    # plot Depth and Height projections
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Depth (Z-axis)", "Axes (Y,Z) projection", "Height (Y-axis)"))

    fig.add_trace(go.Scatter(y=z_mean, name='z_Mean'), row=1, col=1)
    fig.add_trace(go.Scatter(y=z_q75, name='z_q75'), row=1, col=1)
    fig.add_trace(go.Scatter(y=nz_z_q75, name='nonzero_z_q75'), row=1, col=1)

    # Summing over the width gives a (depth, height) image; it is transposed
    # before display.
    fig.add_trace(px.imshow(np.transpose(np.sum(images, axis=2)), binary_string=True).data[0], row=1, col=2)

    add_height_traces(fig)

    fig.update_layout(height=300, margin=dict(l=5, r=5, t=25, b=50))
    fig.show()
    

In [None]:
# plot statistics for 3 random cases
# The earlier draw restricted to train.Flair_count > 70 was immediately
# overwritten by the unrestricted one, so only the effective draw is kept.
sample = train.sample(3)
for path in sample['FLAIR_path']:
    plot_case_stats(path)

# Filtering steps

The goal is to make inner bright pixels of the tumor brighter and outer brain pixels darker. The operation includes following steps:
* Keep only the scans whose number of non-zero pixels exceeds a given threshold;
* Calculate the 3rd quartile of brightness along the width axis for the original image. Multiply each frame of the image along the width axis by its 3rd quartile value;
* Calculate the 3rd quartile of brightness along the height axis for the original image. Multiply each frame of the image filtered in the previous step along the height axis by 3rd quartile values of the original image;
* Calculate the interquartile range (IQR) of brightness along the depth axis of the image filtered in the previous step. Divide each frame of the filtered image along the depth axis by its IQR value.

Multiplying along the height and the width reduces the brightness of the edges. You can see how the tumor location affects the value of the 3rd quartile. But here is the first con: small tumors at the edges can be darkened. To normalize the brightness, each frame is divided by the brightness IQR along the depth. As shown in the "Step 2 projection" below, this reduces the brightness difference between the areas, as brighter frames are divided by higher values than darker frames. And here is the second con: the threshold at the first step should be picked carefully, since the edge frames have fewer pixels and their IQR value tends to zero. At the same time, we may lose some valuable information, like small tumors located at the beginning or the end of the MRI series.

In [None]:
image = load_dicom_line("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00777/FLAIR")
image = crop_image(image)
image = zoom_img(image, 128, 128, 128)
# keep only scans with 'min_pixels'*100 % non-zero pixels:
# a frame whose (1 - min_pixels) quantile is zero has less than 40% of
# meaningful pixels and is dropped
min_pixels = 0.4
depth_quantile = np.quantile(image, 1-min_pixels, axis=(1, 2))
keep = (depth_quantile > 0)
image = image[keep, :, :]

# Each step below stacks the scaled planes of the LAST axis into a new first
# axis, so the axis order rotates: (Z, Y, X) -> (X, Z, Y) -> (Y, X, Z) -> (Z, Y, X).

# step 0 projection
step0 = np.sum(image, axis=0)
# filter for initial 2-axis: scale every x-plane by its 0.75 quantile
qX = np.quantile(image, 0.75, axis=(1, 0))
filtered = np.array([image[:, :, i] * qX[i] for i in range(image.shape[2])])

# step 1 projection
step1 = np.sum(filtered, axis=0)
# filter for initial 1-axis (now the last axis): scale every y-plane by the
# 0.75 quantile of the ORIGINAL image
qY = np.quantile(image, 0.75, axis=(0, 2))
filtered = np.array([filtered[:, :, i] * qY[i] for i in range(filtered.shape[2])])

# step 2 projection
step2 = np.sum(filtered, axis=0)
# filter for initial 0-axis (now the last axis): divide every z-plane by the
# IQR of the filtered data
q75Z = np.quantile(filtered, 0.75, axis=(0, 1))
q25Z = np.quantile(filtered, 0.25, axis=(0, 1))
iqrZ = q75Z - q25Z
filtered = np.array([filtered[:, :, i] / iqrZ[i] for i in range(filtered.shape[2])])
filtered = filtered/np.max(filtered)

# step 3 projection
step3 = np.sum(filtered, axis=0)

mean_z = np.mean(filtered, axis=(1, 2))
mean_y = np.mean(filtered, axis=(0, 2))
mean_x = np.mean(filtered, axis=(1, 0))

# vectors of non-zero pixels
# (np.trim_zeros would only strip the leading/trailing zeros of the flattened
# volume, leaving the interior zeros in the histograms)
image_flatten = image[image != 0]
filtered_flatten = filtered[filtered != 0]

In [None]:
# Left column: projection after each step; right column: the statistic used by that step
panel_titles = (
    "Original projection", "Brightness 0.75 quantile along X-axis",
    "Step 1 projection", "Brightness 0.75 quantile along Y-axis",
    "Step 2 projection", "Brightness IQR along Z-axis",
    "Filtered projection", "Original and Filtered images histograms",
)
fig = make_subplots(rows=4, cols=2, column_widths=[0.4, 0.6], subplot_titles=panel_titles)

# Rows 1-3: (projection, per-plane statistic, legend label)
stages = (
    (step0, qX, 'X-axis brightness 0.75 quantile'),
    (step1, qY, 'Y-axis brightness 0.75 quantile'),
    (step2, iqrZ, 'IQR brightness'),
)
for row, (projection, statistic, label) in enumerate(stages, start=1):
    fig.add_trace(px.imshow(projection, binary_string=True).data[0], row=row, col=1)
    fig.add_trace(go.Scatter(y=statistic, name=label), row=row, col=2)

# Row 4: final projection next to the brightness histograms before/after filtering
fig.add_trace(px.imshow(step3, binary_string=True).data[0], row=4, col=1)
for values, label in ((image_flatten, 'Original image hist'),
                      (filtered_flatten, 'Filtered image hist')):
    fig.add_trace(go.Histogram(x=values, nbinsx=30, name=label), row=4, col=2)

fig.update_layout(height=1000, title_text='Filtering steps and values')
fig.show()

### Define the filter function and get few cases data

In [None]:
def q_filter(image, q1=0.25, q2=0.75, min_pixels=0.4, filter_mode = True):
    """Drop sparse frames of a (Z, Y, X) volume and optionally apply the quantile filter.

    Frames (axis 0) whose ``1 - min_pixels`` quantile is zero are removed.
    With ``filter_mode`` the remaining volume is multiplied along X and Y by
    the ``q2`` quantile of the corresponding planes of the cropped image, then
    every frame is divided by the ``q2 - q1`` quantile range of the filtered
    data, and the result is scaled by its maximum.

    Args:
        image: 3D array ordered as (Z, Y, X).
        q1, q2: quantiles in [0, 1]; they are sorted, so the order is free.
        min_pixels: minimal fraction of non-zero pixels a frame needs to be kept.
        filter_mode: when False only the frame selection is applied.

    Returns:
        The filtered volume, or the frame-cropped volume when ``filter_mode``
        is False.

    Raises:
        ValueError: if ``q1 < 1 - min_pixels < q2`` does not hold (previously a
            message was printed and ``None`` returned, which made callers fail
            later with an unrelated error).
    """
    q1, q2 = sorted([q1, q2])
    threshold = 1-min_pixels
    if not q1 < threshold < q2:
        raise ValueError(
            "Wrong 'min_pixels' value, should be: q1 < (1-min_pixels) < q2, "
            "and given q1={}, q2={}, (1-min_pixels)={}".format(q1, q2, threshold))

    # keep only scans with 'min_pixels'*100 % non-zero pixels
    depth_quantile = np.quantile(image, threshold, axis=(1, 2))
    image = image[depth_quantile > 0, :, :]

    # only cropped by depth
    if not filter_mode:
        return image

    # Every step stacks the scaled planes of the last axis into a new first
    # axis, so the axes rotate (Z, Y, X) -> (X, Z, Y) -> (Y, X, Z) -> (Z, Y, X).

    # filter for initial 2-axis
    qX = np.quantile(image, q2, axis=(1, 0))
    filtered = np.array([image[:, :, i] * qX[i] for i in range(image.shape[2])])

    # filter for initial 1-axis (now the last axis), quantiles of the cropped image
    qY = np.quantile(image, q2, axis=(0, 2))
    filtered = np.array([filtered[:, :, i] * qY[i] for i in range(filtered.shape[2])])

    # filter for initial 0-axis (now the last axis), quantile range of the filtered data
    # NOTE: a frame with a zero quantile range yields inf/nan here; the
    # min_pixels threshold is what is meant to keep such frames out.
    iqrZ = np.quantile(filtered, q2, axis=(0, 1)) - np.quantile(filtered, q1, axis=(0, 1))
    filtered = np.array([filtered[:, :, i] / iqrZ[i] for i in range(filtered.shape[2])])

    return filtered/np.max(filtered)

In [None]:
# get 3D scans for random cases
sample = train.sample(8)
images, filtered = [], []
im_filenames, filt_filenames = [], []

# Every case is stored twice, both resized to 48 frames of 160 x 160:
# the cropped original and its quantile-filtered version
for case in sample['FLAIR_path']:
    volume = crop_image(load_dicom_line(case))
    volume_filt = q_filter(volume, q1=0.25, q2=0.75, min_pixels=0.4)
    images.append(zoom_img(volume, 160, 160, 48))
    filtered.append(zoom_img(volume_filt, 160, 160, 48))

# Write one animated GIF per volume and remember the <img> tag that shows it
gif_jobs = (
    ("image", images, im_filenames),
    ("filt", filtered, filt_filenames),
)
for prefix, volumes, tags in gif_jobs:
    for i, vol in enumerate(volumes):
        filename = prefix + str(i) + ".gif"
        tags.append('<img src='+filename+'>')
        frames = [PIL.Image.fromarray(frame) for frame in vol*255]
        frames[0].save(filename, save_all=True, append_images=frames[1:], duration=150, loop=0)

np.array(images).shape

### Define the cropping function

To locate the center we use the `scipy.ndimage.center_of_mass` function with a 3D mask as input. The mask contains only the pixels whose brightness is higher than the 0.995 quantile of the array values. Unfortunately, the performed manipulations could not eliminate all highlighted edge pixels, but for some cases the difference between the masks is visible.

After obtaining center of mass coordinates, we crop the area around it. 

In [None]:
original_masks = []
filtered_masks = []

# plot masks for the samples
# A mask keeps the pixels above the volume's 0.995 quantile and zeroes the rest
mask_jobs = (
    ("mask", images, original_masks),
    ("mask_f", filtered, filtered_masks),
)
for prefix, volumes, tags in mask_jobs:
    for i, vol in enumerate(volumes):
        filename = prefix + str(i) + ".gif"
        bright = (vol > np.quantile(vol, 0.995)) * vol
        tags.append('<img src='+filename+'>')
        frames = [PIL.Image.fromarray(frame) for frame in bright*255]
        frames[0].save(filename, save_all=True, append_images=frames[1:], duration=150, loop=0)

# One table row per mask kind, one column per sample
mask_rows = {
    'Original images masks': original_masks,
    'Filtered images masks': filtered_masks,
}
masks = pd.DataFrame.from_dict(mask_rows, orient='index')

HTML(masks.to_html(escape=False))

In [None]:
def _crop_window(center, extent, limit):
    """Return (start, stop) of `extent` samples around `center`, kept inside [0, limit]."""
    extent = min(extent, limit)
    start = min(max(center - extent // 2, 0), limit - extent)
    return start, start + extent


#function to locate center of mass and leave the area around it
def crop_3d(image, size_ratio, depth_ratio, filter_image = True):
    """Crop a (Z, Y, X) volume around the brightest area of its filtered version.

    Frames with less than 40% non-zero pixels are dropped (`q_filter`
    defaults). The center is the center of mass of the filtered voxels that are
    brighter than the 0.995 quantile. The crop is a square of
    ``max(height, width) * size_ratio`` pixels and ``depth * depth_ratio``
    frames; it is shifted to stay inside the volume and limited to the axis
    length when the requested size is larger.

    Args:
        image: 3D array ordered as (Z, Y, X).
        size_ratio: side of the crop relative to the larger in-plane size.
        depth_ratio: number of frames of the crop relative to the depth.
        filter_image: crop the filtered volume when True, otherwise the
            frame-cropped original.

    Returns:
        The cropped volume.
    """
    # The center is always located on the filtered volume. Filtering the
    # frame-cropped original gives the same result as filtering the input,
    # because the frame selection is repeated inside q_filter.
    filtered = q_filter(image, filter_mode = True)
    if filter_image:
        image = filtered
    else:
        image = q_filter(image, filter_mode = False)

    masked = (filtered > np.quantile(filtered, 0.995)) * filtered
    center = np.array(ndimage.center_of_mass(masked))
    if not np.all(np.isfinite(center)):
        # Empty mask (nothing above the quantile): fall back to the middle of the volume
        center = (np.array(filtered.shape) - 1) / 2
    center = center.astype(int)

    current_depth, current_height, current_width = image.shape

    #sizes of the crop: height and width
    size = int(max(current_height, current_width)*size_ratio)
    #the depth of the crop
    new_depth = int(current_depth*depth_ratio)

    # Clamping the window avoids the negative start index (and hence a
    # silently truncated crop) that occurred when `size` exceeded an axis.
    z1, z2 = _crop_window(center[0], new_depth, current_depth)
    y1, y2 = _crop_window(center[1], size, current_height)
    x1, x2 = _crop_window(center[2], size, current_width)

    return image[z1:z2, y1:y2, x1:x2]

In [None]:
# areas of the original voxels cropped around filtered center of mass
crop_75 = []
crop_66 = []

# (crop side ratio, file name prefix, list collecting the <img> tags)
crop_jobs = (
    (0.75, "crop75_", crop_75),
    (0.66, "crop66_", crop_66),
)
for ratio, prefix, tags in crop_jobs:
    for i, img in enumerate(images):
        cropped = crop_3d(img, size_ratio=ratio, depth_ratio=1, filter_image = False)
        # resize in-plane only; the number of frames is kept
        cropped = zoom_img(cropped, 160, 160, cropped.shape[0])
        filename = prefix + str(i) + ".gif"
        tags.append('<img src='+filename+'>')

        frames = [PIL.Image.fromarray(frame) for frame in cropped*255]
        frames[0].save(filename, save_all=True, append_images=frames[1:], duration=200, loop=0)

In [None]:
# areas of the filtered voxels cropped around the center of mass
crop_filt_75 = []
crop_filt_66 = []

# (crop side ratio, file name prefix, list collecting the <img> tags)
filt_crop_jobs = (
    (0.75, "crop_filt75_", crop_filt_75),
    (0.66, "crop_filt66_", crop_filt_66),
)
for ratio, prefix, tags in filt_crop_jobs:
    for i, img in enumerate(images):
        cropped = crop_3d(img,  size_ratio=ratio, depth_ratio=1, filter_image = True)
        # resize in-plane only; the number of frames is kept
        cropped = zoom_img(cropped, 160, 160, cropped.shape[0])
        filename = prefix + str(i) + ".gif"
        tags.append('<img src='+filename+'>')

        frames = [PIL.Image.fromarray(frame) for frame in cropped*255]
        frames[0].save(filename, save_all=True, append_images=frames[1:], duration=150, loop=0)

### Comparison of original and filtered voxels

In [None]:
# Case details followed by one column of GIF tags per processing variant
data = sample[['BraTS21ID', 'MGMT_value', 'Resolution', 'Flair_count']].copy()
gif_columns = {
    'Original': im_filenames,
    'Filtered': filt_filenames,
    'Cropped 75%': crop_75,
    'Filt. and Cropped 75%': crop_filt_75,
    'Cropped 66%': crop_66,
    'Filt. and Cropped 66%': crop_filt_66,
}
for column, tags in gif_columns.items():
    data[column] = tags

# escape=False lets the <img> tags render as images
HTML(data.to_html(escape=False))

As shown above, the tumor is localized correctly in most cases. However, unfortunate cropping outcomes are possible. 

# Building the model

Now let's try to use the FLAIR slices for the model training. Further steps:
* Data loading and preparation;
* Defining and training the model;
* Results evaluation.

The `seresnet50` model from [https://github.com/ZFTurbo/classification_models_3D](https://github.com/ZFTurbo/classification_models_3D) will be used for training.

In [None]:
#function to load voxels
#try unfiltered slices with square sizes equals to 66% of max(height, width)
def read_img(path):
    """Load one series as a (img_size, img_size, depth) model input.

    The series is cropped to its non-empty part, cut to 66% of its larger
    in-plane size around the bright area (unfiltered pixels), resized in-plane
    to `config['img_size']` and reduced to `config['depth']` evenly spaced
    frames.

    Args:
        path: folder of the series.

    Returns:
        Array with the frames stacked along the last axis.
    """
    volume = crop_image(load_dicom_line(path))
    volume = crop_3d(volume, size_ratio=0.66, depth_ratio=1, filter_image = False)

    # Frames are sampled instead of zooming along the depth for less quality
    # loss, so only the in-plane size is resized here.
    side = config['img_size']
    volume = zoom_img(volume, side, side, volume.shape[0])
    picks = np.linspace(0, volume.shape[0]-1, config['depth']).astype(int)
    volume = volume[picks, :, :]

    # frames move from the first to the last axis
    return np.stack(volume, axis=-1)

def read_modality(paths, workers=8):
    """Load the volumes of all given series and add a trailing channel axis.

    Args:
        paths: iterable of series folders.
        workers: number of worker processes used when run as the main module.

    Returns:
        Tensor of the stacked `read_img` results with one extra last axis.
    """
    paths = list(paths)
    if __name__ == '__main__':
        with Pool(workers) as p:
            images = list(tqdm(p.imap(read_img, paths), total=len(paths)))
    else:
        # Outside the main module no pool is started; load serially so that
        # `images` is always defined (it used to be unbound on this path).
        images = [read_img(path) for path in tqdm(paths)]

    return tf.expand_dims(images, -1)

In [None]:
# Load the preprocessed FLAIR volumes of the training cases and show the tensor shape
flair_train=read_modality(train.FLAIR_path)
flair_train.shape 

In [None]:
# Load the preprocessed FLAIR volumes of the test cases and show the tensor shape
flair_test=read_modality(test.FLAIR_path)
flair_test.shape

In [None]:
# Labels as a float32 column vector, one row per training case
y = train['MGMT_value'].to_numpy().astype('float32').reshape(-1, 1)

In [None]:
# Notebook shell commands: install the package that provides
# classification_models_3D (imported below) and keras_applications
!pip install classification-models-3D
!pip install keras_applications

In [None]:
#in case of TPU run
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # Any failure while setting up the TPU means "run without it". Unlike a
    # bare `except:`, this lets KeyboardInterrupt/SystemExit through and
    # reports why the fallback was taken.
    print('TPU not used ({}: {}), falling back to the default strategy'.format(type(err).__name__, err))
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
#https://www.kaggle.com/sreevishnudamodaran/tpu-rsna-keras-3d-cnn-voxel-train

from classification_models_3D.tfkeras import Classifiers

# Backbone requested from classification_models_3D
model_arch = 'seresnet50'


def create_model(input_shape, num_classes):
    """Build and compile the 3D classifier.

    A 3x3x3 convolution turns the single input channel into the three channels
    the backbone is built for; the backbone output is average-pooled, passed
    through dropout and a sigmoid dense layer.

    Args:
        input_shape: spatial shape of one volume, without the channel axis.
        num_classes: number of sigmoid outputs.

    Returns:
        The model compiled with binary cross-entropy, SGD and the AUC metric.
    """
    conv3d = tf.keras.layers.Conv3D
    inputs = tf.keras.layers.Input((*input_shape, 1), name='inputs')
    three_channel = conv3d(3, (3, 3, 3), strides=(1, 1, 1),
                           padding='same', use_bias=True)(inputs)

    # The preprocessing function returned with the backbone is not used
    backbone_builder, _ = Classifiers.get(model_arch)
    backbone = backbone_builder(input_shape=(*input_shape, 3), include_top=False,
                                weights='imagenet')
    features = backbone(three_channel)

    pooled = tf.keras.layers.GlobalAveragePooling3D()(features)
    regularized = tf.keras.layers.Dropout(rate=0.5)(pooled)

    # Cast output to float32 for numerical stability
    outputs = tf.keras.layers.Dense(num_classes, activation='sigmoid',
                                    dtype='float32')(regularized)
    model = tf.keras.Model(inputs, outputs)

    sgd = keras.optimizers.SGD(learning_rate=config['learning_rate'])
    model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['AUC'])
    return model


# Build the model once to print its layer summary
volume_shape = (config['img_size'], config['img_size'], config['depth'])
create_model(volume_shape, 1).summary()

In [None]:
#https://keras.io/examples/vision/3D_image_classification/

#3d data augmentation
@tf.function
def rotate_shift(voxel):
    """Randomly change the contrast of a voxel volume and rotate it.

    The values are raised to a random power between 0.85 and 1.15, the volume
    is rotated by 0 or 180 degrees in a random plane and the result is clipped
    to [0, 1]. No shift is applied; the name is kept for the callers.

    Args:
        voxel: float32 tensor holding one volume.

    Returns:
        float32 tensor with the static shape of the input.
    """
    def scipy_rotate_shift(voxel):

        # define the values for np.power(array)
        degrees = np.round(np.linspace(0.85, 1.15, 7), 2)
        degree = random.choice(degrees)
        # power array
        voxel = np.power(voxel, degree)

        # define some rotation angles
        angle = random.choice([0, 180])
        # pick rotation axes
        ax = random.choice([(0, 1), (0, 2), (1, 2)])
        # rotate volume
        voxel = ndimage.rotate(voxel, angle, axes = ax, reshape=False)

        # Interpolation may overshoot the [0, 1] range. The explicit cast
        # guarantees the dtype promised to tf.numpy_function even when NumPy
        # promotes the float64 exponent to a float64 result.
        return np.clip(voxel, 0, 1).astype(np.float32)

    augmented = tf.numpy_function(scipy_rotate_shift, [voxel], tf.float32)
    # tf.numpy_function returns a tensor of unknown shape; the augmentation
    # keeps the shape (reshape=False), so the static shape is restored.
    augmented.set_shape(voxel.shape)
    return augmented

def train_preprocessing(voxel, label):
    """Return the training pair with the volume passed through `rotate_shift`; the label is untouched."""
    return rotate_shift(voxel), label

In [None]:
# Multiply the learning rate by 0.25 when val_loss has not improved for 3 epochs
plateau = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.25,
                                               patience=3, verbose=1, mode='min', min_lr=1e-07)

def evaluate_model(X, y, data_test=flair_test, n_folds=config['nfolds']):
    """Train one model per cross-validation fold and predict the test data with each.

    Args:
        X: training volumes, indexable along the first axis.
        y: labels aligned with `X`.
        data_test: volumes to predict after every fold. The default is the
            `flair_test` object that exists when this function is defined.
        n_folds: number of folds of the shuffled KFold split (random_state=1).

    Returns:
        (histories, preds): the Keras training histories, one per fold, and a
        DataFrame with one column of test predictions per fold.
    """
    preds = pd.DataFrame(columns=list(range(n_folds)))
    histories = []
    kfold = KFold(n_folds, shuffle=True, random_state=1)
    input_shape = (config['img_size'], config['img_size'], config['depth'])

    for fold, (ind_train, ind_test) in enumerate(kfold.split(X)):

        print('CV {}/{}'.format(fold + 1, n_folds))

        with strategy.scope():
            model = create_model(input_shape, 1)

        # One gather per split instead of stacking the samples one by one
        X_train = tf.gather(X, ind_train)
        X_test = tf.gather(X, ind_test)
        y_train, y_test = y[ind_train], y[ind_test]

        print('Train shape:', X_train.shape, 'Test shape:', X_test.shape)

        train_loader = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        validation_loader = tf.data.Dataset.from_tensor_slices((X_test, y_test))

        train_dataset = (
            train_loader
            .shuffle(len(X_train), reshuffle_each_iteration=True)
            .map(train_preprocessing)
            .batch(config['batch_size'])
            .prefetch(2))

        # Validation data is not shuffled: the order does not change the
        # metrics computed over the whole set.
        validation_dataset = validation_loader.batch(config['batch_size']).prefetch(2)

        # `callbacks` is documented as a list of callbacks
        history = model.fit(train_dataset, epochs=config['num_epochs'],
                            validation_data=validation_dataset, verbose=1, callbacks=[plateau])

        histories.append(history)

        pred = model.predict(data_test)
        preds[fold] = pred.reshape(pred.shape[0])

    return histories, preds

In [None]:
# Seed TensorFlow's random generator, then run the cross-validated training
tf.random.set_seed(1)
histories, predictions = evaluate_model(flair_train, y)

In [None]:
def _history_column(df, base):
    """Return column `base` of a history frame, accepting a numeric suffix such as 'auc_1'."""
    if base in df.columns:
        return df[base]
    # NOTE(review): Keras may append "_<n>" to a metric name when several
    # models are built in one session - confirm for the TF version in use.
    matches = [c for c in df.columns
               if c.startswith(base + '_') and c[len(base) + 1:].isdigit()]
    if len(matches) != 1:
        raise KeyError(base)
    return df[matches[0]]


# One column per fold for every tracked curve
losses = {}
val_losses = {}
aucs = {}
val_aucs = {}
for n, history in enumerate(histories, start=1):
    df = pd.DataFrame(history.history)
    val_losses["val_loss, CV{}".format(n)] = _history_column(df, 'val_loss')
    losses["loss, CV{}".format(n)] = _history_column(df, 'loss')
    val_aucs["val_auc, CV{}".format(n)] = _history_column(df, 'val_auc')
    aucs["auc, CV{}".format(n)] = _history_column(df, 'auc')

losses = pd.DataFrame(losses)
val_losses = pd.DataFrame(val_losses)
aucs = pd.DataFrame(aucs)
val_aucs = pd.DataFrame(val_aucs)

In [None]:
# One subplot per fold showing the training and validation loss by epoch
titles = ['Losses, CV{}'.format(i+1)+"/{}".format(config['nfolds'])
          for i in range(config['nfolds'])]

fig = make_subplots(rows=config['nfolds'], cols=1, x_title = "Epoch", subplot_titles=tuple(titles))
# training curves are added first, validation curves second
for curves in (losses, val_losses):
    for row, column in enumerate(curves, start=1):
        fig.add_trace(go.Scatter(y=curves[column], name=column), row=row, col=1)

fig.update_layout(title_text='Losses values')
fig.update_layout(height=1000)
fig.show()

In [None]:
# One subplot per fold showing the training and validation AUC by epoch
titles = ['AUC values, CV{}'.format(i+1)+"/{}".format(config['nfolds'])
          for i in range(config['nfolds'])]

fig = make_subplots(rows=config['nfolds'], cols=1, x_title = "Epoch", subplot_titles=tuple(titles))
# training curves are added first, validation curves second
for curves in (aucs, val_aucs):
    for row, column in enumerate(curves, start=1):
        fig.add_trace(go.Scatter(y=curves[column], name=column), row=row, col=1)

fig.update_layout(title_text='AUC values')
fig.update_layout(height=1000)
fig.show()

In [None]:
# Final prediction per test case: the mean of the per-fold predictions
test['MGMT_value'] = predictions.mean(axis=1)
test.head(10)

In [None]:
# Write the case ids and averaged predictions to submission.csv without the index column
test[['BraTS21ID', 'MGMT_value']].to_csv('submission.csv',index = False)

In [None]:
# Left: distribution of the predicted values; right: the same values rounded to labels
probabilities = test['MGMT_value']
fig = make_subplots(rows=1, cols=2, column_widths=[0.7, 0.3])

panels = (
    (probabilities, 'Probability of MGMT'),
    (round(probabilities), 'Predicted Labels'),
)
for col, (values, label) in enumerate(panels, start=1):
    fig.add_trace(go.Histogram(x=values, name = label), row=1, col=col)

fig.update_layout(title_text='Predicted probabilities and labels')
fig.show()