In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as lines
import pydicom

from glob import glob

plt.style.use('ggplot')

In [None]:
def level(mean, std, n_std=1.7):
    """Return the brightness threshold ``mean + n_std * std``.

    Parameters
    ----------
    mean, std : float
        Mean and standard deviation of the pixel population being thresholded.
    n_std : float, optional
        How many standard deviations above the mean the threshold sits.
        Defaults to 1.7, the value that used to be hard-coded here.

    Returns
    -------
    float
        The threshold value.
    """
    return mean + n_std * std

def read_dicom_files(cohort, case, mpMRI):
    """Load every DICOM file of one series, ordered by image number.

    Parameters
    ----------
    cohort : str
        Sub-directory of the dataset (the notebook uses 'train' and 'test').
    case : str
        Case directory name, e.g. '00386'.
    mpMRI : str
        Series directory name (the notebook uses 'FLAIR', 'T1w', 'T1wCE', 'T2w').

    Returns
    -------
    list
        One pydicom dataset per ``*.dcm`` file, sorted by the integer N in the
        ``Image-N.dcm`` file name. Empty if nothing matches the glob.
    """
    PATH = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
    files_glob = f'{PATH}/{cohort}/{case}/{mpMRI}/*.dcm'

    def image_number(path):
        # Split on the LAST 'Image-' so a parent directory containing that
        # text cannot shift which piece is parsed as the number.
        return int(path.rsplit('Image-', 1)[1].split('.')[0])

    # Numeric sort: a plain string sort would put Image-10 before Image-2.
    sorted_files = sorted(glob(files_glob), key=image_number)
    # `pydicom.read_file` is a deprecated alias of `dcmread` and is gone in
    # pydicom 3.0; `dcmread` works on both old and new releases.
    return [pydicom.dcmread(f) for f in sorted_files]

def image_orientation(dicom):
    """Classify a slice as 'coronal', 'axial' or 'sagittal'.

    The six values of ``dicom.ImageOrientationPatient`` are rounded to
    integers and only the x/y components of the two direction vectors
    (values 0, 1, 3 and 4) are compared against three known patterns.

    Parameters
    ----------
    dicom : object
        Anything exposing a 6-element ``ImageOrientationPatient`` attribute.

    Returns
    -------
    str
        'coronal', 'axial' or 'sagittal'.

    Raises
    ------
    ValueError
        If the rounded orientation matches none of the three patterns
        (this includes flipped / negative direction cosines).
    """
    # https://www.kaggle.com/davidbroberts/determining-mr-image-planes
    planes = {
        (1, 0, 0, 0): 'coronal',
        (1, 0, 0, 1): 'axial',
        (0, 1, 0, 0): 'sagittal',
    }
    (x1, y1, _, x2, y2, _) = [round(v) for v in dicom.ImageOrientationPatient]

    plane = planes.get((x1, y1, x2, y2))
    if plane is None:
        raise ValueError(f'unknown ImageOrientationPatient: {dicom.ImageOrientationPatient}')

    return plane

def stats_image(image):
    """Return ``(mean, std)`` computed over the non-zero pixels of ``image``.

    Zero-valued pixels are ignored. When the image has no non-zero pixel
    the result is ``(0, 0)``.
    """
    foreground = image[np.nonzero(image)]
    # Guard first: nothing to average when every pixel is zero.
    if foreground.size == 0:
        return (0, 0)
    return (np.mean(foreground), np.std(foreground))

def calc_idx(image):
    """Count the pixels of ``image`` that exceed its brightness threshold.

    The threshold is ``level(mean, std)`` of the image's non-zero pixels.
    """
    threshold = level(*stats_image(image))
    return np.count_nonzero(image > threshold)
    
def top_brilliant_image(images):
    """Return the index of the image with the highest ``calc_idx`` score."""
    scores = [calc_idx(image) for image in images]
    # Ascending argsort: the last entry is the position of the largest score.
    ranking = np.argsort(scores)
    return ranking[-1]

def top_brilliant_line(image, axis):
    """Return the index of the line holding the most above-threshold pixels.

    Above-threshold pixels are counted along ``axis`` (so, for a 2-D image,
    ``axis=1`` gives one count per row and ``axis=0`` one count per column)
    and the position of the largest count is returned.
    """
    threshold = level(*stats_image(image))
    bright_mask = image > threshold
    counts = np.count_nonzero(bright_mask, axis=axis)
    # Ascending argsort: the last entry is the position of the largest count.
    return np.argsort(counts)[-1]

def normalize_image(image):
    """Standardise ``image`` using the statistics of its non-zero pixels.

    Returns ``(image - mean) / std`` where ``mean`` and ``std`` come from
    ``stats_image``. When ``std`` is zero (blank image, or all non-zero
    pixels share one value) the division is skipped and the mean-centred
    image is returned as floats, instead of producing NaN/inf values.
    """
    (mean, std) = stats_image(image)
    centered = image - mean
    if std == 0:
        # Nothing to scale by; dividing would fill the result with NaN/inf.
        return np.asarray(centered, dtype=float)
    return centered / std

def cropped_image(image):
    """Crop ``image`` to the bounding box of its non-zero pixels.

    Only the first two axes are cropped. An image without any non-zero
    pixel is returned unchanged.
    """
    nonzero_idx = np.array(np.nonzero(image))
    if nonzero_idx.shape[1] == 0:
        return image
    lower = nonzero_idx.min(axis=1)
    # Slice ends are exclusive: +1 keeps the last non-zero row and column
    # inside the crop (without it a single-pixel image crops to nothing).
    upper = nonzero_idx.max(axis=1) + 1
    return image[lower[0]:upper[0], lower[1]:upper[1]]

def cropped_images(images):
    """Crop a stack of images to the bounding box of its non-zero voxels.

    The first three axes are cropped, so all-zero leading/trailing slices
    along axis 0 are dropped as well. A stack without any non-zero voxel is
    returned unchanged.
    """
    nonzero_idx = np.array(np.nonzero(images))
    if nonzero_idx.shape[1] == 0:
        # min()/max() of an empty index set would raise; there is nothing to crop.
        return images
    lower = nonzero_idx.min(axis=1)
    # Slice ends are exclusive: +1 keeps the last non-zero index on every axis.
    upper = nonzero_idx.max(axis=1) + 1
    return images[lower[0]:upper[0], lower[1]:upper[1], lower[2]:upper[2]]

def calc_center(dicom_file, r, c):
    """Convert pixel (row ``r``, column ``c``) of a slice to patient coordinates.

    Parameters
    ----------
    dicom_file : pydicom dataset
        Must expose ``ImagePositionPatient``, ``ImageOrientationPatient`` and
        ``PixelSpacing``.
    r, c : int
        Row and column index inside the slice.

    Returns
    -------
    list
        ``[x, y, z]`` in the units of ``ImagePositionPatient``.

    Raises
    ------
    ValueError
        If the orientation is not coronal, sagittal or axial.

    Notes
    -----
    DICOM defines ``PixelSpacing`` as ``[row spacing, column spacing]``: the
    first value is the distance between adjacent rows, so it scales the row
    index ``r``; the second scales the column index ``c``.

    The signs are fixed per plane rather than read from
    ``ImageOrientationPatient``: for coronal and sagittal slices the rows are
    assumed to advance toward decreasing z.
    """
    orientation = image_orientation(dicom_file)
    position = dicom_file.ImagePositionPatient
    row_spacing = dicom_file.PixelSpacing[0]
    col_spacing = dicom_file.PixelSpacing[1]

    if orientation == 'coronal':
        # Columns run along x, rows along -z; y is constant within the slice.
        return [position[0] + col_spacing * c,
                position[1],
                position[2] - row_spacing * r]

    if orientation == 'sagittal':
        # Columns run along y, rows along -z; x is constant within the slice.
        return [position[0],
                position[1] + col_spacing * c,
                position[2] - row_spacing * r]

    if orientation == 'axial':
        # Columns run along x, rows along y; z is constant within the slice.
        return [position[0] + col_spacing * c,
                position[1] + row_spacing * r,
                position[2]]

    raise ValueError(f'unsupported orientation: {orientation}')

def find_nearest_scan(dicom_files, center):
    """Return the index of the slice whose position is closest to ``center``.

    The series orientation is taken from the first file. Distance is measured
    only along the axis on which the slices are stacked: x for sagittal, y for
    coronal, z for axial.

    Parameters
    ----------
    dicom_files : sequence of pydicom datasets
        One series; must be non-empty.
    center : sequence of 3 numbers
        Target ``[x, y, z]`` position.

    Returns
    -------
    int
        Index into ``dicom_files``. On ties the first closest slice wins.
    """
    axis_move = {'sagittal': 0, 'coronal': 1, 'axial': 2}
    axis = axis_move[image_orientation(dicom_files[0])]
    # Only the stacking axis matters, so compare that single coordinate.
    positions = np.array([f.ImagePositionPatient[axis] for f in dicom_files], dtype=float)
    # argmin finds the closest slice without sorting the whole series.
    return np.argmin(np.abs(positions - float(center[axis])))

def plot_image_hist(image):
    """Show ``image`` next to the histogram of its standardised non-zero pixels.

    The non-zero pixels are z-scored, their histogram (200 bins, x limited
    to [-5, 5]) is drawn, and three vertical markers are added: the mean
    (solid), mean + std (dotted) and ``level(mean, std)`` (dashed). The
    figure title reports how many z-scored pixels exceed that level.
    """
    flat = image.ravel()
    foreground = flat[np.nonzero(flat)]

    # Standardise the non-zero pixels ...
    (raw_mean, raw_std) = stats_image(foreground)
    zscores = (foreground - raw_mean) / raw_std
    # ... then recompute the statistics on the standardised values.
    (mean, std) = stats_image(zscores)
    threshold = level(mean, std)
    over_threshold = np.count_nonzero(zscores > threshold)

    fig, (axi, axh) = plt.subplots(
        1, 2,
        figsize=(20, 3),
        gridspec_kw={'width_ratios': [1, 4]},
    )
    fig.suptitle(f'scan # ({over_threshold})')

    axh.hist(zscores, 200)
    axh.set_xlim(-5, 5)

    # Capture the y-range once so all three markers span the same height.
    (y_low, y_high) = axh.get_ylim()
    markers = (
        (mean, {}),
        (mean + std, {'linestyles': 'dotted'}),
        (threshold, {'linestyles': 'dashed'}),
    )
    for x_pos, line_style in markers:
        axh.vlines(x_pos, ymin=y_low, ymax=y_high, colors='b', **line_style)

    axi.imshow(image, cmap=plt.cm.gray)
    axi.grid(False)
    axi.axis('off')
    plt.show()



### Example DICOM File

In [None]:
# Example case used throughout the following sections.
cohort, case = 'train', '00386'

# Load the four series of that case, in the same order as the names on the left.
(flair_dicom_files,
 t1w_dicom_files,
 t1wce_dicom_files,
 t2w_dicom_files) = (
    read_dicom_files(cohort, case, series)
    for series in ('FLAIR', 'T1w', 'T1wCE', 'T2w')
)

### Basic Information

In [None]:
# Orientation (read from the first file) and number of scans of every series.
flair_orientation, flair_nscans = image_orientation(flair_dicom_files[0]), len(flair_dicom_files)
t1w_orientation, t1w_nscans = image_orientation(t1w_dicom_files[0]), len(t1w_dicom_files)
t1wce_orientation, t1wce_nscans = image_orientation(t1wce_dicom_files[0]), len(t1wce_dicom_files)
t2w_orientation, t2w_nscans = image_orientation(t2w_dicom_files[0]), len(t2w_dicom_files)

print(f"FLAIR: {flair_orientation}, {flair_nscans} scans")
print(f"T1w: {t1w_orientation}, {t1w_nscans} scans")
print(f"T1wce: {t1wce_orientation}, {t1wce_nscans} scans")
print(f"T2w: {t2w_orientation}, {t2w_nscans} scans")

In [None]:
# Sanity check: the first file of every series must carry the same PatientID
# as the first FLAIR file.
assert all(
    flair_dicom_files[0].PatientID == series[0].PatientID
    for series in (t1w_dicom_files, t1wce_dicom_files, t2w_dicom_files)
)

### Images

In [None]:
# Stack each series into one array and crop it to its non-zero bounding box.
(flair_images,
 t1wce_images,
 t1w_images,
 t2w_images) = (
    cropped_images(np.array([s.pixel_array for s in series]))
    for series in (flair_dicom_files, t1wce_dicom_files, t1w_dicom_files, t2w_dicom_files)
)

In [None]:
import plotly.express as px

# Animated viewer for the cropped FLAIR volume; the slider steps along axis 0.
fig = px.imshow(
    flair_images,
    animation_frame=0,
    binary_string=True,
    labels={"x": "FLAIR Images", "animation_frame": "scan"},
    height=800,
)
fig.show()

In [None]:
# Animated viewer for the cropped T1w volume; the slider steps along axis 0.
fig = px.imshow(
    t1w_images,
    animation_frame=0,
    binary_string=True,
    labels={"x": "T1w Images", "animation_frame": "scan"},
    height=800,
)
fig.show()

In [None]:
# Animated viewer for the cropped T1wCE volume; the slider steps along axis 0.
fig = px.imshow(
    t1wce_images,
    animation_frame=0,
    binary_string=True,
    labels={"x": "T1wCE Images", "animation_frame": "scan"},
    height=800,
)
fig.show()

In [None]:
# Animated viewer for the cropped T2w volume; the slider steps along axis 0.
fig = px.imshow(
    t2w_images,
    animation_frame=0,
    binary_string=True,
    labels={"x": "T2w Images", "animation_frame": "scan"},
    height=800,
)
fig.show()

### FLAIR Images (Histogram)

In [None]:
# Draw one image + histogram figure per scan of the cropped FLAIR volume.
for img in flair_images:
    plot_image_hist(img)

### Hypothesis

The tumor appears as a bright region in FLAIR images.

### Brightest image in the FLAIR series

In [None]:
# NOTE: this rebinds `flair_images` from the cropped volume to the full,
# uncropped FLAIR stack, so position i here is the pixel data of
# flair_dicom_files[i].
flair_images = np.array([s.pixel_array for s in flair_dicom_files])

# Index of the FLAIR scan with the most above-threshold pixels.
top = top_brilliant_image(flair_images)
top

In [None]:
from pydicom.pixel_data_handlers.util import apply_voi_lut

# Side-by-side view of the brightest FLAIR scan: z-score normalisation (left)
# versus the VOI LUT transform applied by pydicom (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (20,10))
fig.suptitle('normalized image vs VOI LUT image')

# The figure title promises a normalized image, so normalize before plotting
# instead of showing the raw pixel values.
image = normalize_image(flair_images[top])
ax1.imshow(image, cmap = plt.cm.gray)
ax1.set_title('normalized')
ax1.grid(False)  # the ggplot style would otherwise draw grid lines over the image

im = apply_voi_lut(flair_dicom_files[top].pixel_array, flair_dicom_files[top])
ax2.imshow(im, cmap = plt.cm.gray)
ax2.set_title('VOI LUT')
ax2.grid(False)

plt.show()

### Center of the brightest image

In [None]:
# Display the selected scan index together with that scan's ImagePositionPatient.
(top,flair_dicom_files[top].ImagePositionPatient)

In [None]:
# Within the brightest scan, find the row (counts taken along axis=1) and the
# column (counts taken along axis=0) holding the most above-threshold pixels.
rtop = top_brilliant_line(flair_images[top], axis=1)
ctop = top_brilliant_line(flair_images[top], axis=0)

(rtop,ctop)

In [None]:
# Convert the (row, column) found above into patient coordinates using the
# position and pixel spacing of the brightest FLAIR scan.
center = calc_center(flair_dicom_files[top], rtop, ctop)
center

### Equivalent scan in other series

In [None]:
# T1wCE scan whose position is closest to `center` along that series' slice axis.
scan_t1wce = find_nearest_scan(t1wce_dicom_files, center)
(t1wce_orientation, scan_t1wce, t1wce_dicom_files[scan_t1wce].ImagePositionPatient)

In [None]:
# T1w scan whose position is closest to `center` along that series' slice axis.
scan_t1w = find_nearest_scan(t1w_dicom_files, center)
(t1w_orientation, scan_t1w, t1w_dicom_files[scan_t1w].ImagePositionPatient)

In [None]:
# T2w scan whose position is closest to `center` along that series' slice axis.
scan_t2w = find_nearest_scan(t2w_dicom_files, center)
(t2w_orientation, scan_t2w, t2w_dicom_files[scan_t2w].ImagePositionPatient)

In [None]:
# 2x2 overview: the brightest FLAIR scan and the matching scan of each other
# series, every one cropped to its non-zero area and then normalized.
fig, ((ax1,ax2),(ax3,ax4)) = plt.subplots(2, 2, figsize = (20,20))
fig.suptitle('images')

for _ax, _dicom, _title in (
    (ax1, flair_dicom_files[top], f'FLAIR #scan {top}'),
    (ax2, t1w_dicom_files[scan_t1w], f'T1w #scan {scan_t1w}'),
    (ax3, t1wce_dicom_files[scan_t1wce], f'T1wCE #scan {scan_t1wce}'),
    (ax4, t2w_dicom_files[scan_t2w], f'T2w #scan {scan_t2w}'),
):
    im = normalize_image(cropped_image(_dicom.pixel_array))
    _ax.imshow(im, cmap = plt.cm.gray)
    _ax.set_title(_title)

plt.show()

### Test: train cohort

In [None]:
def process_case_and_plot(cohort, case):
    """Plot the brightest FLAIR scan of a case next to the matching scans of the other series.

    Steps: load the four series, pick the FLAIR scan with the most
    above-threshold pixels, locate its brightest row/column, convert that
    point to patient coordinates, find the nearest scan in T1w, T1wCE and
    T2w, then show the four cropped and normalized images in one row.

    Parameters
    ----------
    cohort : str
        Dataset sub-directory passed to ``read_dicom_files``.
    case : str
        Case identifier; also used in the figure title.
    """
    flair_dicom_files, t1w_dicom_files, t1wce_dicom_files, t2w_dicom_files = [
        read_dicom_files(cohort, case, series)
        for series in ('FLAIR', 'T1w', 'T1wCE', 'T2w')
    ]

    # Brightest FLAIR scan and the brightest row / column inside it.
    flair_volume = np.array([s.pixel_array for s in flair_dicom_files])
    top = top_brilliant_image(flair_volume)
    brightest = flair_volume[top]
    rtop = top_brilliant_line(brightest, axis=1)
    ctop = top_brilliant_line(brightest, axis=0)

    center = calc_center(flair_dicom_files[top], rtop, ctop)

    # (label, series, selected scan index) for each panel, left to right.
    selections = [('FLAIR', flair_dicom_files, top)]
    for label, series in (
        ('T1w', t1w_dicom_files),
        ('T1wCE', t1wce_dicom_files),
        ('T2w', t2w_dicom_files),
    ):
        selections.append((label, series, find_nearest_scan(series, center)))

    panels = [
        normalize_image(cropped_image(series[scan].pixel_array))
        for (_, series, scan) in selections
    ]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f'Case {case}')

    for ax, (label, _, scan), panel in zip(axes, selections, panels):
        ax.imshow(panel, cmap=plt.cm.gray)
        ax.set_title(f'{label} #scan {scan}')
        ax.grid(False)

    plt.show()

In [None]:
# Training labels. BraTS21ID is read as str so the ids keep the exact text
# used for the case directories (e.g. '00386').
train = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv', converters = {'BraTS21ID': str,})

In [None]:
cohort = 'train'
# Fixed seed: the same 10 cases (and therefore the same figures) are drawn on
# every Restart & Run All.
for case in train.sample(10, random_state=42).BraTS21ID:
    process_case_and_plot(cohort, case)

### Test: test cohort

In [None]:
# Test-cohort case ids, taken from the sample submission file. BraTS21ID is
# read as str so the ids keep their exact text.
test = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv', converters = {'BraTS21ID': str,})

In [None]:
cohort = 'test'
# Fixed seed: the same 10 cases (and therefore the same figures) are drawn on
# every Restart & Run All.
for case in test.sample(10, random_state=42).BraTS21ID:
    process_case_and_plot(cohort, case)