In [1]:
import os
import csv
import pickle
import random
import math
import dicom
import numpy as np
import scipy
from tqdm import tqdm
from tinydb import TinyDB, Query
from natsort import natsorted
from skimage import transform
from sklearn.externals import joblib

import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image

#### helper functions

In [2]:
def get_filepaths():
    """Load the pickled train and validation file-path mappings.

    Reads ``filepaths_train.pkl`` and ``filepaths_val.pkl`` from the current
    working directory.

    Returns:
        tuple: ``(filepaths_train, filepaths_val)`` exactly as unpickled.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use
    # pickles produced by a trusted earlier step of this pipeline.
    loaded = []
    for pkl_name in ('filepaths_train.pkl', 'filepaths_val.pkl'):
        with open(pkl_name, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return loaded[0], loaded[1]

In [3]:
def get_training_labels():
    """Read the per-patient systole / diastole targets from ``../data/train.csv``.

    Each data row is expected to hold three fields: id, systole, diastole.
    The header row (first field equal to ``'Id'``) and completely empty rows
    are skipped.

    Returns:
        tuple: ``(systole_labels, diastole_labels)``, two dicts mapping the
        integer id to the float value of the corresponding column.

    Raises:
        ValueError: if a non-empty row does not have exactly three fields or
            its fields cannot be converted to ``int`` / ``float``.
    """
    systole_labels = {}
    diastole_labels = {}
    # newline='' is what the csv module expects so that it can handle line
    # endings (and embedded newlines) itself.
    with open('../data/train.csv', 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                # csv.reader yields [] for blank lines; unpacking would crash.
                continue
            _id, systole, diastole = row
            if _id == 'Id':
                continue
            systole_labels[int(_id)] = float(systole)
            diastole_labels[int(_id)] = float(diastole)
    return systole_labels, diastole_labels

In [4]:
def apply_window(arr, window_center, window_width):
    """Clip ``arr`` to the intensity window centred on ``window_center``.

    Values below ``window_center - window_width/2`` are raised to that bound
    and values above ``window_center + window_width/2`` are lowered to it.

    Returns:
        The clipped array as produced by ``np.clip``.
    """
    half_width = window_width/2
    lower = window_center - half_width
    upper = window_center + half_width
    return np.clip(arr, lower, upper)

In [5]:
def apply_per_slice_norm(arr):
    """Standardise ``arr`` to zero mean and unit standard deviation.

    The mean and standard deviation are taken over every element of ``arr``.
    When the standard deviation is exactly zero (a constant array) an
    all-zero array of the same shape is returned instead of dividing by zero.
    """
    flat = arr.ravel()
    center = np.mean(flat)
    spread = np.std(flat)
    if spread != 0:
        return (arr - center) / spread
    return np.zeros(arr.shape)

In [6]:
def crop_to_square(arr, size):
    """Centre-crop a 2-D array to a square and resize it to ``size`` x ``size``.

    The square's side is the shorter of the two dimensions; the longer
    dimension is trimmed symmetrically (the extra element, if the surplus is
    odd, is dropped from the far end). The crop is then resampled with
    ``transform.resize`` using bilinear interpolation (``order=1``).

    NOTE: ``preserve_range=False`` is passed through, so scikit-image converts
    the input following its ``img_as_float`` conventions; per the scikit-image
    docs integer inputs are rescaled and the result is not in the input's
    original intensity units.
    """
    n_rows, n_cols = arr.shape
    side = min(n_rows, n_cols)
    row_start = (n_rows - side) // 2
    col_start = (n_cols - side) // 2
    square = arr[row_start:row_start + side, col_start:col_start + side]
    return transform.resize(square, (size, size),
                            order=1, clip=True, preserve_range=False)

In [12]:
def _first_window_value(value):
    """Return a DICOM window attribute as a float, using the first entry if it is multi-valued."""
    try:
        return float(value)
    except TypeError:
        return float(value[0])


def create_image_stack(series_filepaths, img_size=196, stack_size=30):
    """Build a fixed-length, normalised image stack (plus frame diffs) for one series.

    Args:
        series_filepaths: sequence of ``(fname, fpath)`` pairs; only ``fpath``
            is used, and it is passed to ``dicom.read_file``.
        img_size: side length of the square each frame is resized to.
        stack_size: number of frames in the returned stack. Longer series are
            truncated to their first ``stack_size`` frames; shorter series are
            padded by cycling through their frames from the start until
            ``stack_size`` frames are available.

    Returns:
        tuple: ``(image_stack, image_stack_diffs)`` as float32 arrays.
        ``image_stack`` holds ``stack_size`` per-slice-normalised frames and
        ``image_stack_diffs`` holds the ``stack_size - 1`` per-slice-normalised
        differences between consecutive (un-normalised) frames. Both are
        empty when ``series_filepaths`` is empty.
    """
    n_available = len(series_filepaths)
    if n_available == 0:
        selected = []
    else:
        # Cycling (rather than appending a single extra partial pass) also
        # handles series with fewer than stack_size / 2 frames, which would
        # otherwise yield a short, ragged stack.
        selected = [series_filepaths[i % n_available] for i in range(stack_size)]

    image_stack = []
    for _fname, fpath in selected:
        df = dicom.read_file(fpath)
        # Window on the raw pixel values *before* resizing: crop_to_square
        # resizes with preserve_range=False, which rescales integer inputs, so
        # windowing afterwards would compare rescaled intensities against
        # bounds expressed in raw units. float64 input is passed through the
        # resize without rescaling.
        raw = df.pixel_array.astype(np.float64)
        windowed = apply_window(raw,
                                _first_window_value(df.WindowCenter),
                                _first_window_value(df.WindowWidth))
        image_stack.append(crop_to_square(windowed, img_size).astype(np.float32))

    # Diffs are taken on the windowed (not yet normalised) frames.
    image_stack_diffs = [apply_per_slice_norm(nxt - cur)
                         for cur, nxt in zip(image_stack, image_stack[1:])]
    image_stack = [apply_per_slice_norm(img) for img in image_stack]
    return (np.array(image_stack).astype(np.float32), np.array(image_stack_diffs).astype(np.float32))

In [13]:
# Load the id -> systole / id -> diastole dicts once at module level;
# create_label() below looks patients up in these two globals.
systole_labels, diastole_labels = get_training_labels()

def create_label(pt, mode='systole', n_bins=600):
    """Build the boolean step-function label for one patient.

    Args:
        pt: key into the module-level ``systole_labels`` / ``diastole_labels``
            dicts.
        mode: ``'systole'`` or ``'diastole'``; selects which dict is read.
        n_bins: length of the label vector (defaults to 600).

    Returns:
        numpy.ndarray: boolean array of length ``n_bins`` whose element ``i``
        is ``True`` exactly when the patient's value is strictly less than
        ``i``.

    Raises:
        ValueError: if ``mode`` is neither ``'systole'`` nor ``'diastole'``.
        KeyError: if ``pt`` is not present in the selected dict.
    """
    if mode == 'systole':
        volume = systole_labels[pt]
    elif mode == 'diastole':
        volume = diastole_labels[pt]
    else:
        # A bare ``raise`` outside an except block only produces an unrelated
        # "No active exception to reraise" RuntimeError; say what is wrong.
        raise ValueError("mode must be 'systole' or 'diastole', got {!r}".format(mode))
    return volume < np.arange(n_bins)

#### preprocess data

In [14]:
# Load the train / validation path mappings. The loops below index them as
# mapping[patient][series] and hand each entry to create_image_stack(), which
# unpacks it as a sequence of (fname, fpath) pairs.
filepaths_train, filepaths_val = get_filepaths()

In [15]:
# Fraction of the patient ids held out as a local validation set.
train_val_split = 0.05

# The split is drawn from patient ids 1..N_TRAIN_PATIENTS inclusive.
N_TRAIN_PATIENTS = 500
# Fixed seed so that Restart & Run All reproduces the same split; a dedicated
# Random instance is used so the global `random` state is left untouched.
SPLIT_SEED = 0

_all_pts = list(range(1, N_TRAIN_PATIENTS + 1))
pts_train_val = sorted(
    random.Random(SPLIT_SEED).sample(_all_pts, int(N_TRAIN_PATIENTS * train_val_split))
)
# Sorted so the list order does not depend on set iteration order.
pts_train = sorted(set(_all_pts) - set(pts_train_val))

In [16]:
# Accumulators for the training split (image stacks, frame diffs, labels).
data_train, data_diffs_train = [], []
label_sys_train, label_dia_train = [], []

# Accumulators for the local validation split held out of the training data.
data_train_val, data_diffs_train_val = [], []
label_sys_train_val, label_dia_train_val = [], []

# Accumulators for the validation set; no labels are built for it here, only
# the patient id of every stack.
data_val, data_diffs_val, data_val_pt_index = [], [], []

# Every series whose name starts with 'sax' becomes one sample. Patients that
# appear in neither pts_train nor pts_train_val are processed but not stored.
for pt in tqdm(filepaths_train.keys()):
    for series in filepaths_train[pt].keys():
        if not series.startswith('sax'):
            continue
        img3d, img3d_diffs = create_image_stack(filepaths_train[pt][series], img_size=196, stack_size=30)
        label_sys = create_label(pt, mode='systole')
        label_dia = create_label(pt, mode='diastole')
        if pt in pts_train:
            data_train.append(img3d)
            data_diffs_train.append(img3d_diffs)
            label_sys_train.append(label_sys)
            label_dia_train.append(label_dia)
        elif pt in pts_train_val:
            data_train_val.append(img3d)
            data_diffs_train_val.append(img3d_diffs)
            label_sys_train_val.append(label_sys)
            label_dia_train_val.append(label_dia)

for pt in tqdm(filepaths_val.keys()):
    for series in filepaths_val[pt].keys():
        if not series.startswith('sax'):
            continue
        img3d, img3d_diffs = create_image_stack(filepaths_val[pt][series], img_size=196, stack_size=30)
        data_val.append(img3d)
        data_diffs_val.append(img3d_diffs)
        data_val_pt_index.append(int(pt))
            
# Convert the accumulated Python lists into dense arrays.
# * np.asarray(..., dtype=...) converts in one step instead of building an
#   array and then copying it again with .astype().
# * The builtin `bool` replaces the `np.bool` alias, which NumPy deprecated in
#   1.20 and removed in 1.24; the resulting dtype is the same.
data_train = np.asarray(data_train, dtype=np.float32)
data_diffs_train = np.asarray(data_diffs_train, dtype=np.float32)
label_sys_train = np.asarray(label_sys_train, dtype=bool)
label_dia_train = np.asarray(label_dia_train, dtype=bool)
print(data_train.shape, data_diffs_train.shape, label_sys_train.shape, label_dia_train.shape)

data_train_val = np.asarray(data_train_val, dtype=np.float32)
data_diffs_train_val = np.asarray(data_diffs_train_val, dtype=np.float32)
label_sys_train_val = np.asarray(label_sys_train_val, dtype=bool)
label_dia_train_val = np.asarray(label_dia_train_val, dtype=bool)
print(data_train_val.shape, data_diffs_train_val.shape, label_sys_train_val.shape, label_dia_train_val.shape)

data_val = np.asarray(data_val, dtype=np.float32)
data_diffs_val = np.asarray(data_diffs_val, dtype=np.float32)
# Patient ids are stored as uint16, so ids must stay below 65536.
data_val_pt_index = np.asarray(data_val_pt_index, dtype=np.uint16)
print(data_val.shape, data_diffs_val.shape, data_val_pt_index.shape)

                                                 

(5065, 30, 196, 196) (5065, 29, 196, 196) (5065, 600) (5065, 600)
(266, 30, 196, 196) (266, 29, 196, 196) (266, 600) (266, 600)
(2144, 30, 196, 196) (2144, 29, 196, 196) (2144,)




In [17]:
# Persist everything later steps need as a single tuple. The order of the
# tuple is the contract with whoever loads the file, so it must not change.
joblib.dump(
    value=(
        # patient ids of the two labelled splits
        pts_train, pts_train_val,
        # training split
        data_train, data_diffs_train, label_sys_train, label_dia_train,
        # local validation split
        data_train_val, data_diffs_train_val, label_sys_train_val, label_dia_train_val,
        # validation set
        data_val, data_diffs_val, data_val_pt_index,
    ),
    filename='../data_proc/0-data_processed.pkl',
)

['../data_proc/0-data_processed.pkl',
 '../data_proc/0-data_processed.pkl_01.npy',
 '../data_proc/0-data_processed.pkl_02.npy',
 '../data_proc/0-data_processed.pkl_03.npy',
 '../data_proc/0-data_processed.pkl_04.npy',
 '../data_proc/0-data_processed.pkl_05.npy',
 '../data_proc/0-data_processed.pkl_06.npy',
 '../data_proc/0-data_processed.pkl_07.npy',
 '../data_proc/0-data_processed.pkl_08.npy',
 '../data_proc/0-data_processed.pkl_09.npy',
 '../data_proc/0-data_processed.pkl_10.npy',
 '../data_proc/0-data_processed.pkl_11.npy']