In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import pydicom
import cv2
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import ipywidgets as widgets

import os
from pathlib import Path

plt.style.use('ggplot')

## Resources Path
Beforehand, we organized all of the dataset information into a well-structured pandas DataFrame.<br>
Just for making life easier!

In [None]:
# Kaggle input root and the competition dataset directory beneath it.
base_dir = Path("/kaggle/input")
ds_dir = base_dir / "rsna-miccai-brain-tumor-radiogenomic-classification"
labels = ds_dir / "train_labels.csv"
train_dir = ds_dir / "train"
test_dir = ds_dir / "test"
clean_tain_dir = Path("/kaggle/working/brain")

# Names of the four studies referenced throughout the notebook.
IMCAT = ["FLAIR", "T1w", "T1wCE", "T2w"]

# Entries of the test directory, converted to integers and sorted ascending.
test_pre = sorted(int(pre) for pre in os.listdir(test_dir))

train_labels_df = pd.read_csv(labels)
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
train_df = pd.read_pickle(str(base_dir / "bt-in-nutshell/aggregate.pkl"))
train_df.head()

## Preprocessing

### Normalization method
#### Z-score

$$
\frac{im_i - \mu_i}{\sigma_i}
$$

In [None]:
def z_score(im):
    """
    Z-score normalization based on foreground statistics.

    Pixels strictly greater than the mean of the whole image are treated as
    foreground. Their mean and standard deviation are used to standardize
    every pixel of ``im`` (foreground and background alike).
    """
    foreground = im[im > im.mean()]
    return (im - foreground.mean()) / foreground.std()

#### Nyul

In [None]:
def train_nyul(images,i_min=1,i_max=99,i_s_min = 1,i_s_max=100,l_perc=10,u_perc = 90,n_land = 10):
    """
    Learn an average standard intensity scale from a collection of images.

    For each image, the pixels above the image mean are taken as foreground,
    their intensities at the landmark percentiles are computed, and those
    intensities are mapped linearly from ``[p(i_min), p(i_max)]`` onto
    ``[i_s_min, i_s_max]``. The mapped landmarks are averaged over all images.

    Parameters
    ----------
    images : sequence of arrays
        Must support ``len()``; each element is indexed with a boolean mask.
    i_min, i_max : percentiles used as the ends of the intensity range.
    i_s_min, i_s_max : ends of the standard scale.
    l_perc, u_perc : first and last of the intermediate percentiles.
    n_land : step between intermediate percentiles (passed as the step of
        ``np.arange``), not a landmark count.

    Returns
    -------
    (standard_scale, percs) : the averaged landmarks and the percentiles
        they correspond to.
    """
    percs = np.concatenate(([i_min], np.arange(l_perc, u_perc + 1, n_land), [i_max]))
    accumulated = np.zeros(len(percs))

    for volume in images:
        foreground = volume[volume > np.mean(volume)]
        raw_landmarks = np.percentile(foreground, percs)
        lower = np.percentile(foreground, i_min)
        upper = np.percentile(foreground, i_max)
        # Linear map from this image's intensity range onto the standard range.
        to_standard = interp1d([lower, upper], [i_s_min, i_s_max])
        accumulated += np.array(to_standard(raw_landmarks))

    return accumulated / len(images), percs

def go_on_hist(image,standard_scale,landmark_percs):
    """
    Map an image onto a previously learned standard scale.

    The intensities of the foreground (pixels above the image mean) at
    ``landmark_percs`` are matched to ``standard_scale`` with a piecewise
    linear function, extrapolated outside the landmark range, and that
    function is applied to every pixel of ``image``.
    """
    foreground = image[image > image.mean()]
    image_landmarks = np.percentile(foreground, landmark_percs)
    to_standard = interp1d(image_landmarks, standard_scale, fill_value='extrapolate')
    return to_standard(image)

### Tool function

In [None]:
def image_state(im):
    """
    Return ``(mean, std)`` computed over the non-zero pixels of ``im``.
    """
    foreground = im[np.nonzero(im)]
    return np.mean(foreground), np.std(foreground)

def level(mu,s,scale_factor=1.7):
    """
    Threshold used for masking: ``scale_factor`` standard deviations
    (``s``) above the mean (``mu``).
    """
    offset = scale_factor * s
    return mu + offset

def calc_idx(image):
    """
    Score an image by the number of pixels brighter than the threshold
    returned by ``level`` for the image's non-zero mean and std.
    """
    mu, sigma = image_state(image)
    bright = image > level(mu, sigma)
    return np.count_nonzero(bright)

def top_valuable(images):
    """
    Return the position, within ``images``, of the image with the highest
    ``calc_idx`` score (the last entry of the ascending ranking).
    """
    scores = [calc_idx(img) for img in images]
    ranking = np.argsort(scores)
    return ranking[-1]

def top_valuable_line(image,axis):
    """
    Count the above-threshold pixels along ``axis`` and return the index of
    the line with the largest count (the last entry of the ascending ranking).
    """
    mu, sigma = image_state(image)
    bright_per_line = np.count_nonzero(image > level(mu, sigma), axis=axis)
    return np.argsort(bright_per_line)[-1]

def calc_center(center_pos,r,c,t):
    """
    Shift a reference position according to the slice orientation.

    Parameters
    ----------
    center_pos : sequence of three numbers
        Reference position, read as ``center_pos[0]``, ``[1]`` and ``[2]``.
    r, c : number
        Multipliers applied to components of ``center_pos`` when computing
        the shifted coordinates.
    t : str
        Orientation name: ``"Axial"``, ``"Coronal"`` or ``"Saggital"`` (the
        spelling used by this notebook). The correct spelling ``"Sagittal"``
        is accepted as well.

    Returns
    -------
    list
        The three shifted coordinates.

    Raises
    ------
    ValueError
        If ``t`` is not one of the supported orientation names (previously
        this surfaced as an ``UnboundLocalError`` on ``return``).
    """
    # NOTE(review): the formulas below are kept exactly as they were; e.g. the
    # Coronal branch scales center_pos[0] (not center_pos[1]) by r - confirm
    # this is intended.
    if t == "Axial":
        return [
            center_pos[0] + center_pos[0] * c,
            center_pos[1],
            center_pos[2] - center_pos[1] * r,
        ]
    if t in ("Saggital", "Sagittal"):
        return [
            center_pos[0],
            center_pos[1] + center_pos[0] * c,
            center_pos[2] - center_pos[1] * r,
        ]
    if t == "Coronal":
        return [
            center_pos[0] + center_pos[0] * c,
            center_pos[1] + center_pos[0] * r,
            center_pos[2],
        ]
    raise ValueError(f"unknown orientation: {t!r}")

def find_nearest(_centeres,_center,_ori):
    """
    Find the row of ``_centeres`` closest to ``_center`` along one coordinate.

    Parameters
    ----------
    _centeres : array, one row per candidate position.
    _center : position subtracted (broadcast) from every row.
    _ori : str
        Orientation name selecting the coordinate column:
        ``'Sagittal'`` -> 0, ``'Coronal'`` -> 1, ``'Axial'`` -> 2.
        The ``'Saggital'`` spelling used elsewhere in this notebook is
        accepted as an alias of ``'Sagittal'``.

    Returns
    -------
    Index of the row whose selected coordinate has the smallest absolute
    difference from ``_center``.

    Raises
    ------
    KeyError
        If ``_ori`` is not a known orientation name.
    """
    axis_move = {'Sagittal': 0, 'Saggital': 0, 'Coronal': 1, 'Axial': 2}
    column = axis_move[_ori]
    # Row 0 of the column-wise argsort holds, for each coordinate, the index
    # of the closest candidate.
    nearest_per_column = np.argsort(np.abs(_centeres - _center), axis=0)[0]
    return nearest_per_column[column]

def read_dicom_images(dframe,case,study,orientation):
    """
    Load the non-empty DICOM slices of one patient, ordered by image number.

    Parameters
    ----------
    dframe : pandas.DataFrame
        Must contain the columns ``patient_id``, ``study``, ``orientation``
        and ``path``.
    case : value compared against ``patient_id``.
    study : value compared against ``study``; ``None`` disables the filter.
    orientation : value compared against ``orientation``; ``None`` disables
        the filter.

    Returns
    -------
    (images, paths) : two parallel lists holding the pixel arrays and the
        paths they were read from. Slices whose pixels are all zero are
        skipped.
    """
    selection = dframe.loc[dframe["patient_id"] == case]
    # Build each mask from the already-filtered frame rather than from the
    # full ``dframe``, so the mask and the frame it indexes always share the
    # same index.
    if study is not None:
        selection = selection.loc[selection["study"] == study]
    if orientation is not None:
        selection = selection.loc[selection["orientation"] == orientation]

    # Paths are expected to look like ".../Image-<n>.<ext>"; order by <n>.
    ordered_paths = sorted(
        selection["path"].to_list(),
        key=lambda x: int(x.split("Image-")[1].split(".")[0]),
    )

    images = []
    kept_paths = []
    for path in ordered_paths:
        # dcmread replaces read_file, which is deprecated in pydicom.
        pixels = pydicom.dcmread(path).pixel_array
        if np.all(pixels == 0):
            continue
        images.append(pixels)
        kept_paths.append(path)
    return images, kept_paths

def clean_zero_images(row):
    """
    Decide whether the image referenced by ``row["path"]`` should be kept.

    Returns ``False`` when the image is entirely zero, or when its non-zero
    pixels make up less than 10% of ``shape[0] * shape[1]``; returns ``True``
    otherwise.

    NOTE(review): ``Image`` is neither defined nor imported in the visible
    part of this notebook. Presumably it wraps a file path and exposes the
    pixel array as ``.image`` - confirm where it comes from, otherwise this
    function raises ``NameError`` on a fresh kernel.
    """
    im_obj = Image(row["path"])
    # Completely empty slice.
    if np.all((im_obj.image==0)):
        return False
    # Fraction of non-zero pixels relative to the first two dimensions.
    if np.count_nonzero(im_obj.image)/(im_obj.image.shape[0]*im_obj.image.shape[1])<0.1:
        return False
    return True

def get_bounding_box(image):
    """
    Return ``(min0, min1, max0, max1)``: the smallest and largest indices of
    the non-zero pixels along the first two axes. Both ends are inclusive.
    """
    coords = np.asarray(np.nonzero(image))
    lows = coords.min(axis=1)
    highs = coords.max(axis=1)
    return lows[0], lows[1], highs[0], highs[1]

def extract_bounding_boxes(images):
    """
    Compute ``get_bounding_box`` for every image and stack the results into
    one array with a row per image.
    """
    return np.array([list(get_bounding_box(img)) for img in images])

def extract_stuff(images,bb):
    """
    Crop every image to its bounding box.

    Parameters
    ----------
    images : sequence of 2-D arrays.
    bb : sequence of ``(x_min, y_min, x_max, y_max)`` rows, one per image, as
        produced by ``extract_bounding_boxes``. The maxima are the indices of
        the last non-zero row/column, i.e. they are inclusive.

    Returns
    -------
    list
        One cropped array per image.
    """
    cropped = []
    for idx, im in enumerate(images):
        x_min, y_min, x_max, y_max = bb[idx]
        # The maxima are inclusive, so extend the slice by one; otherwise the
        # last foreground row and column would be cut off.
        cropped.append(im[x_min:x_max + 1, y_min:y_max + 1])
    return cropped

In [None]:
def _draw_stat_lines(ax, mean, std):
    """Draw vertical lines at mean, mean+std and the ``level`` threshold over the current y-range of ``ax``."""
    y_low, y_high = ax.get_ylim()
    ax.vlines(mean, ymin=y_low, ymax=y_high, colors='b', label = "mean")
    ax.vlines(mean+std, ymin=y_low, ymax=y_high, colors='b', linestyles='dotted',label="mean+std")
    ax.vlines(level(mean,std), ymin=y_low, ymax=y_high, colors='g', linestyles='dashed',label="threshold")


def plot_observation(im):
    """
    Show an image next to the histograms of its non-zero pixels.

    Three panels are drawn: the image itself, the histogram of the original
    non-zero intensities, and the histogram of the same pixels after
    ``z_score``. Each histogram is annotated with its mean, mean+std and the
    ``level`` threshold; the figure title reports how many pixels exceed the
    threshold before and after normalization.
    """
    pixels = im.ravel()
    non_z_pix = pixels[np.nonzero(pixels)]
    mean,std = image_state(im)
    threshold = np.count_nonzero(non_z_pix>level(mean,std))

    normal_non_z = z_score(non_z_pix)
    normal_non_mean,normal_non_std = image_state(normal_non_z)
    normal_threshold = np.count_nonzero(normal_non_z>level(normal_non_mean,normal_non_std))

    fig, (axi, axh,axk) = plt.subplots(1, 3, figsize = (20,4), gridspec_kw={'width_ratios': [1, 4, 4]})
    fig.suptitle(f"# over threshold normal-{threshold} & zero_score-{normal_threshold}")

    # Histogram of the z-scored pixels.
    axk.hist(normal_non_z, 200)
    axk.set_title("Zero Score")
    _draw_stat_lines(axk, normal_non_mean, normal_non_std)
    axk.set_xlim(-6,6)
    axk.legend(loc="upper left")
    axk.grid(False)

    # Histogram of the original pixels; the legend makes the line labels
    # visible here as well, matching the normalized panel.
    axh.hist(non_z_pix, 200)
    axh.set_title("Original")
    _draw_stat_lines(axh, mean, std)
    axh.legend(loc="upper left")
    axh.grid(False)

    axi.imshow(im, cmap = plt.cm.gray)
    axi.grid(False)
    axi.axis('off')
    plt.show()

### Some MRI samples
We have **four** studies (MRI sequence types) for each observation:
* **FLAIR**
* **T1w**
* **T1wCE**
* **T2w**

<p>As we can see, each study consists of a series of images. After plotting them, we noticed several points worth mentioning:</p>

* **Contrast**: <p> each study has its own specific contrast, and we can use techniques to enhance it</p>
* **Some empty images**: <p> there were some empty images in folder, and for having a cleaner dataset we can explicitly remove and put them away.</p>
* **Series**: <p> we have series of images for each study, but a subset of these images are valuable. As [mentioned](https://www.kaggle.com/josecarmona/btrc-eda-final), we can score images based on non zero voxels, and histogram.</p>
* **Various shape**: <p>some images in **256x256** and **512x512**</p>
* **Planes**:<p>the sampled MRIs have different positions and planes</p>
    1. **Sagittal**
    2. **Coronal**
    3. **Axial**




#### Some plot

In [None]:
cols_name = ["FLAIR","T1w","T1wCE","T2w"]
rows_name = ["Saggital","Coronal","Axial"]

sample1 = "00386"
sample_flair_images,_ = read_dicom_images(train_df,sample1,"FLAIR",None)
sample_tw1_images,_ = read_dicom_images(train_df,sample1,"T1w",None)
sample_t1wce_image,_ = read_dicom_images(train_df,sample1,"T1wCE",None)
sample_t2w_images,_ = read_dicom_images(train_df,sample1,"T2w",None)

# Listed in the same order as cols_name so that every title sits above the
# study it names (the T1wCE and T2w panels were previously swapped).
sample_series = [
    sample_flair_images,
    sample_tw1_images,
    sample_t1wce_image,
    sample_t2w_images,
]
# Slice shown for every study; clamped below for series shorter than this.
slice_idx = 20

fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12, 8))

for ax, col, series in zip(axes, cols_name, sample_series):
    ax.set_title(col)
    ax.imshow(series[min(slice_idx, len(series) - 1)], cmap = plt.cm.gray)
    ax.grid(False)
    ax.axis('off')

fig.tight_layout()
plt.show()

### Compare
#### Z-score

In [None]:
# Plot every non-empty axial FLAIR slice of one patient, each with its raw
# and z-scored histograms (all-zero slices are dropped by read_dicom_images).
patient_id = "00386"
flair_images,_ = read_dicom_images(train_df,patient_id,"FLAIR","Axial")
for im in flair_images:
    plot_observation(im)

### Nyul 
> We use a cleaner subset of the data to learn the standard landmarks for the Nyul algorithm.

In [None]:
# Load the cleaned frame used to train the Nyul landmarks.
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
c_train_df = pd.read_pickle(str(base_dir.joinpath("bt-in-nutshell/cleaned_data.pkl")))
c_train_df.head()

In [None]:
# Rows of the cleaned frame belonging to the FLAIR study.
flaired_images = c_train_df.loc[c_train_df["study"]=="FLAIR"]
# Select rows and the "path" column in a single .loc call instead of chained
# indexing, and keep the paths under their own name.
flaired_axial_paths = flaired_images.loc[flaired_images["orientation"]=="Axial", "path"]
# dcmread replaces read_file, which is deprecated in pydicom.
flaired_axial_images = [pydicom.dcmread(p).pixel_array for p in flaired_axial_paths]

#### nyul hyperparameter

In [None]:
def _percentile_slider(value, lower, upper, step, description):
    """Build a horizontal integer slider that only reports its value on release."""
    return widgets.IntSlider(
        value=value,
        min=lower,
        max=upper,
        step=step,
        description=description,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )


i_min = _percentile_slider(1, 1, 99, 5, 'Minimum Percentil:')
i_max = _percentile_slider(99, 1, 99, 4, 'Maximum Percentil:')
i_s_min = _percentile_slider(1, 1, 100, 5, 'Minimum standard Percentil:')
i_s_max = _percentile_slider(100, 1, 100, 5, 'Maximum Standard Percentil:')

# Each of these gets its own name only. The previous chained assignments
# (`l_perc = i_s_max = ...`, `u_perc = i_s_max = ...`) rebound i_s_max to the
# upper-percentile slider, so the standard-scale maximum was silently 90 and
# that slider appeared twice in the container.
l_perc = _percentile_slider(10, 1, 100, 5, 'Low Middle Percentil:')
u_perc = _percentile_slider(90, 1, 100, 5, 'Upper Middle Percentil:')

# train_nyul passes this value as the step of np.arange(l_perc, u_perc + 1, ...),
# i.e. it is the spacing between intermediate percentiles.
n_land = widgets.IntText(
    value=10,
    description='Number of landmarks:',
    disabled=False
)

container = widgets.VBox([i_min,i_max,i_s_min,i_s_max,l_perc,u_perc,n_land])
container

In [None]:
# Learn the standard scale from the FLAIR axial images using the widget values.
# `landmarks` holds the percentiles returned by train_nyul.
standard_scale,landmarks = train_nyul(flaired_axial_images,i_min.value,i_max.value,i_s_min.value,i_s_max.value,l_perc.value,u_perc.value,n_land.value)

fig,ax = plt.subplots(1,1,figsize = (15,4))
# Landmarks go on the x axis and the standard scale on the y axis, so the
# curve agrees with the axis labels and with the vertical lines drawn at each
# landmark below.
ax.plot(landmarks,standard_scale)
ax_limits = ax.get_ylim()
ax.grid(False)
ax.set_xlabel("landmark")
for land in landmarks:
    ax.vlines(land, ymin=ax_limits[0], ymax=ax_limits[1], colors='b', linestyles='dotted')
ax.set_ylabel("standard")
# plt.show() is used as in the other plotting cells; fig.show() warns on
# non-GUI (inline) backends.
plt.show()

In [None]:
# Map every image onto the learned standard scale.
normal_flaired_axial_images = [go_on_hist(im,standard_scale,landmarks) for im in flaired_axial_images]
# Bounding boxes are computed from the original images and reused for the
# normalized ones, so both crops cover the same region.
bboxes = extract_bounding_boxes(flaired_axial_images)
cropped_flaired_axial_images = extract_stuff(flaired_axial_images,bboxes)
cropped_normal_flaired_axial_images = extract_stuff(normal_flaired_axial_images,bboxes)

In [None]:
# Number of original/normalized pairs shown side by side.
n_preview = 10
preview_pairs = list(zip(
    cropped_flaired_axial_images[:n_preview],
    cropped_normal_flaired_axial_images[:n_preview],
))

fig, axes = plt.subplots(n_preview, 2, figsize = (20,100), gridspec_kw={'width_ratios': [50, 50]})
axes[0][0].set_title("Original")
axes[0][1].set_title("Nyul Normalized")

# Strip grid and ticks from every panel first, so rows without an image
# (fewer than n_preview images available) stay blank.
for row_axes in axes:
    for ax in row_axes:
        ax.grid(False)
        ax.axis("off")

for (ax_orig, ax_norm), (im, norm_im) in zip(axes, preview_pairs):
    ax_orig.imshow(im,cmap="gray")
    ax_norm.imshow(norm_im,cmap="gray")

# plt.show() is used as in the other plotting cells; fig.show() warns on
# non-GUI (inline) backends.
plt.show()