In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
import pydicom
import numpy as np 
import cv2
import os 
import shutil
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

CONFIG = dict(
    # Drop slices whose pixels are all zero before stacking a series into a volume.
    REMOVE_ZERO_SLICES=True,
    # Target shape used when resampling a volume (see same_size).
    INPUT_DIM=(128, 128, 128),
)

TRAIN_DIR = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train'
TEST_DIR = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/test'

In [None]:
def trim3d(arr_3d, plot=False):
    """Crop a volume to the bounding box of its largest in-plane object.

    For every slice along the last axis that has at least one pixel > 0, the
    slice is binarised (pixel > 0), external contours are extracted, and the
    area of its largest contour is compared with the best seen so far (ties go
    to the later slice). The bounding box of the winning contour is then
    applied to the first two axes of the whole volume; no slices are dropped.

    Args:
        arr_3d: 3D array indexed as ``[row, col, slice]``. ``cv2.findContours``
            needs 8-bit single-channel input, so uint8 is expected.
        plot: if True, show the selected slice before and after cropping.

    Returns:
        ``arr_3d[y:y+h, x:x+w, :]`` (a view). If no slice has a pixel > 0,
        the full volume is returned unchanged.
    """
    n_rows, n_cols, n_slices = arr_3d.shape[0], arr_3d.shape[1], arr_3d.shape[2]
    largest_area = 0
    largest_idx = 0
    # cv2.boundingRect yields (x=col, y=row, w=#cols, h=#rows); default is the
    # whole image, so w must be the column count and h the row count.
    x, y, w, h = 0, 0, n_cols, n_rows
    for i in range(n_slices):
        s = arr_3d[:, :, i]
        # Same criterion as the threshold below, so contours is never empty.
        if not np.any(s > 0):
            continue

        _, thresh = cv2.threshold(s, 0, 255, cv2.THRESH_BINARY)
        # [-2] is the contour list on OpenCV 3.x (3 return values) and 4.x (2).
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        areas = [cv2.contourArea(c) for c in contours]
        max_idx = int(np.argmax(areas))

        if areas[max_idx] >= largest_area:
            largest_area = areas[max_idx]
            largest_idx = i
            x, y, w, h = cv2.boundingRect(contours[max_idx])

    result = arr_3d[y:y + h, x:x + w, :]

    if plot:
        fig, axes = plt.subplots(nrows=1, ncols=2)
        axes[0].imshow(arr_3d[:, :, largest_idx])
        axes[1].imshow(result[:, :, largest_idx])
        plt.show()

    return result

def dicom_to_img(ds):
    """Convert a pydicom dataset's pixel data to an 8-bit grayscale image.

    Applies pydicom's ``apply_voi_lut`` to the pixel array and inverts
    MONOCHROME1 data so that larger values are brighter. It then min-max
    scales the result to 0..255 (truncating on the final uint8 cast).

    Args:
        ds: a pydicom ``Dataset`` carrying pixel data.

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``ds.pixel_array``.
        A constant-valued image maps to all zeros instead of NaN garbage.
    """
    # Work in float64 so MONOCHROME1 inversion and min-shifting cannot overflow
    # narrow signed/unsigned integer pixel dtypes.
    data = np.asarray(pydicom.pixel_data_handlers.apply_voi_lut(ds.pixel_array, ds),
                      dtype=np.float64)
    if ds.PhotometricInterpretation == "MONOCHROME1":
        data = np.amax(data) - data
    data = data - np.min(data)
    peak = np.max(data)
    # A constant image is all zeros after the shift; dividing would yield NaN,
    # whose uint8 cast is undefined.
    if peak > 0:
        data = data / peak
    return (data * 255).astype(np.uint8)

def same_size(img, size=None):
    """Resample a 3D volume to a fixed shape using trilinear interpolation.

    Args:
        img: 3D numeric array (converted to float32 for interpolation).
        size: target 3-element shape; defaults to ``CONFIG['INPUT_DIM']``.

    Returns:
        np.ndarray of dtype uint8 with shape ``size``.
    """
    if size is None:
        size = CONFIG['INPUT_DIM']
    # Contiguous float32 copy/view; torch.from_numpy then shares its memory.
    volume = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32))
    # F.interpolate with mode='trilinear' expects (batch, channel, D, H, W).
    volume = volume.unsqueeze(0).unsqueeze(0)
    volume = F.interpolate(volume, size=tuple(size), mode='trilinear', align_corners=True)
    # Round (not truncate) and clamp before narrowing so interpolated
    # intensities are not biased downward and cannot wrap around.
    volume = volume.round().clamp_(0, 255).to(torch.uint8)
    return volume.squeeze(0).squeeze(0).numpy()

### Look at Data Dimensions

In [None]:
flair_df, t1w_df, t2w_df, t1wCE_df = [],[],[],[]
# Route each modality's row to its own list; directories with any other name are
# reported and skipped rather than silently counted as T2w.
modality_rows = {'FLAIR': flair_df, 'T1w': t1w_df, 'T1wCE': t1wCE_df, 'T2w': t2w_df}

patients = sorted(os.listdir(TRAIN_DIR))
for patient in tqdm(patients):
    patient_dir = os.path.join(TRAIN_DIR, patient)
    for im_type in sorted(os.listdir(patient_dir)):
        rows = modality_rows.get(im_type)
        if rows is None:
            print(f"Skipping unknown image type {im_type} for patient {patient}")
            continue

        series_dir = os.path.join(patient_dir, im_type)  # input
        files = os.listdir(series_dir)
        if not files:
            print(f"No Images for patient {patient} - {im_type}")
            continue

        # In-plane size is read from the first file only, i.e. this assumes all
        # slices of a series share one shape. Depth is the number of files.
        ds = pydicom.dcmread(os.path.join(series_dir, files[0]))
        pixels = ds.pixel_array  # decode once
        rows.append([patient, pixels.shape[0], pixels.shape[1], len(files)])

dim_columns = ["patient", "height", "width", "depth"]
flair = pd.DataFrame(flair_df, columns=dim_columns)
t1w = pd.DataFrame(t1w_df, columns=dim_columns)
t1wCE = pd.DataFrame(t1wCE_df, columns=dim_columns)
t2w = pd.DataFrame(t2w_df, columns=dim_columns)

In [None]:
datalist = [flair, t1w, t1wCE, t2w]
names = ["FLAIR", "T1w", "T1wCE", "T2w"]
# Grid: one row per modality, one column per dimension of the raw series.
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(15,10))
for row_axes, frame, label in zip(axes, datalist, names):
    for ax, column in zip(row_axes, ("height", "width", "depth")):
        sns.histplot(data=frame[column], binwidth=10, ax=ax, multiple='stack')
        ax.set_title(f'{label} {column.capitalize()}')
plt.tight_layout()
plt.show()

## Crop Images to Object of Interest

In [None]:
# Output roots for the cropped volumes, created under the current working directory.
save_dir = os.getcwd()
save_train_dir = os.path.join(save_dir, 'train')
save_test_dir = os.path.join(save_dir, 'test')

# exist_ok removes the check-then-create race and makes the cell safe to re-run.
for output_root in (save_train_dir, save_test_dir):
    os.makedirs(output_root, exist_ok=True)

In [None]:
def _slice_position(ds):
    """Scalar position of a DICOM slice along its stacking direction.

    Projects ``ImagePositionPatient`` onto the slice normal (cross product of
    the row and column direction cosines in ``ImageOrientationPatient``), so
    axial, sagittal and coronal series all sort in spatial order. Falls back to
    the raw z coordinate (as a float, not truncated) if the orientation tag is
    missing or malformed.
    """
    position = np.asarray(ds.ImagePositionPatient, dtype=np.float64)
    orientation = getattr(ds, 'ImageOrientationPatient', None)
    if orientation is None or len(orientation) != 6:
        return float(position[2])
    orientation = np.asarray(orientation, dtype=np.float64)
    normal = np.cross(orientation[:3], orientation[3:])
    return float(np.dot(position, normal))


def combine_scans(input_dir, output_dir):
    """Stack each patient's DICOM series into a cropped 3D volume and save it.

    ``input_dir`` is expected to contain one directory per patient, each holding
    one directory of DICOM files per image type (FLAIR, T1w, T1wCE, T2w).
    Slices are sorted spatially and converted with ``dicom_to_img``. All-zero
    slices are dropped when ``CONFIG['REMOVE_ZERO_SLICES']`` is set. The stacked
    ``(rows, cols, slices)`` volume is cropped with ``trim3d`` and written to
    ``<output_dir>/<patient>/<img_type>/<patient>_<img_type>_Image.npz``.

    Args:
        input_dir: root of the raw per-patient DICOM tree.
        output_dir: root under which cropped ``.npz`` volumes are written.

    Returns:
        ``[flair, t1w, t1wCE, t2w]`` DataFrames with columns
        ``patient, height, width, depth`` describing each saved cropped volume.
    """
    modality_order = ('FLAIR', 'T1w', 'T1wCE', 'T2w')
    rows_by_type = {name: [] for name in modality_order}

    for patient in tqdm(sorted(os.listdir(input_dir))):
        img_types = sorted(os.listdir(os.path.join(input_dir, patient)))  # FLAIR -> T1w -> T1wCE -> T2w
        patient_dir = os.path.join(output_dir, patient)  # output
        os.makedirs(patient_dir, exist_ok=True)

        for img_type in img_types:
            if img_type not in rows_by_type:
                print(f"Skipping unknown image type {img_type} for patient {patient}")
                continue

            img_dir = os.path.join(input_dir, patient, img_type)  # input
            img_files = [os.path.join(img_dir, name) for name in os.listdir(img_dir)]

            ds = [pydicom.dcmread(path) for path in img_files]
            # Sort along the slice normal; int(z) truncated fractional positions
            # and ignored non-axial series.
            ds.sort(key=_slice_position)
            imgs = [dicom_to_img(x) for x in ds]
            if CONFIG['REMOVE_ZERO_SLICES']:
                imgs = [x for x in imgs if np.sum(x) > 0]

            if len(imgs) == 0:  # every slice was blank (or the series was empty)
                print(f"No Images for patient {patient} - {img_type}")
                continue

            img3d = np.stack(imgs, axis=-1)
            if not np.sum(img3d) > 0:
                print(f'All Blanks for {patient} - {img_type}')
                continue

            img3d = trim3d(img3d)
            # img3d = same_size(img3d)
            rows_by_type[img_type].append(
                [patient, img3d.shape[0], img3d.shape[1], img3d.shape[2]])

            save_img_dir = os.path.join(patient_dir, img_type)
            os.makedirs(save_img_dir, exist_ok=True)
            file_path = os.path.join(save_img_dir, f'{patient}_{img_type}_Image.npz')
            np.savez_compressed(file_path, img3d)

    columns = ["patient", "height", "width", "depth"]
    return [pd.DataFrame(rows_by_type[name], columns=columns) for name in modality_order]

datalist_train = combine_scans(TRAIN_DIR, save_train_dir)
print("Done with Train Dataset")
print("Start Test Dataset")
datalist_test = combine_scans(TEST_DIR, save_test_dir)

In [None]:
names = ["FLAIR", "T1w", "T1wCE", "T2w"]
# (DataFrame column, title suffix) for each column of the figure grid.
panels = [("height", "Height"), ("width", "Width"), ("depth", "Depth")]
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(15,10))
for row, (label, frame) in enumerate(zip(names, datalist_train)):
    for col, (column, suffix) in enumerate(panels):
        ax = axes[row, col]
        sns.histplot(data=frame[column], binwidth=10, ax=ax, multiple='stack')
        ax.set_title(f'{label} {suffix}')
plt.tight_layout()
plt.show()

In [None]:
# Disabled cleanup snippet.
# WARNING: rmtree('./') removes the entire current working directory, which is
# also where the cropped train/test volumes are saved (save_dir = os.getcwd()).
# shutil.rmtree('./')
# os.listdir('./')