In [1]:
import os
import re
import csv
import pickle
import random
import math
import dicom
import numpy as np
import scipy
from tqdm import tqdm
from natsort import natsorted
from skimage import transform
from sklearn.externals import joblib

import matplotlib.pyplot as plt
%matplotlib inline

#### helper functions

In [2]:
def get_filepaths():
    with open('filepaths_train.pkl', 'rb') as f:
        filepaths_train = pickle.load(f)
    with open('filepaths_val.pkl', 'rb') as f:
        filepaths_val = pickle.load(f)
    return filepaths_train, filepaths_val

In [3]:
def get_training_labels():
    systole_labels = {}
    diastole_labels = {}
    with open('../data/train.csv', 'r') as f:
        for _id, systole, diastole in csv.reader(f):
            if _id == 'Id':
                continue
            systole_labels[int(_id)] = float(systole)
            diastole_labels[int(_id)] = float(diastole)
    return systole_labels, diastole_labels

In [4]:
def apply_window(arr, window_center, window_width):
    return np.clip(arr, window_center - window_width/2, window_center + window_width/2)

In [5]:
def apply_per_slice_norm(arr):
    mean = np.mean(arr.ravel())
    std = np.std(arr.ravel())
    if std == 0:
        return np.zeros(arr.shape)
    return (arr - mean) / std

In [6]:
def crop_to_square(arr, size):
    x_len, y_len = arr.shape
    shorter_len = min(x_len, y_len)
    x_start = (arr.shape[0] - shorter_len) // 2
    x_end = x_start + shorter_len
    y_start = (arr.shape[1] - shorter_len) // 2
    y_end = y_start + shorter_len
    return transform.resize(arr[x_start:x_end, y_start:y_end], 
                            (size, size), order=1, clip=True, preserve_range=True)

In [7]:
def create_image_stack(series_filepaths, img_size=196, stack_size=30):
    series_filepaths_all = []
    if len(series_filepaths) == stack_size:
        series_filepaths_all.append(series_filepaths)
    elif len(series_filepaths) < stack_size:
        series_filepaths_all.append(series_filepaths[:] + series_filepaths[:(stack_size - len(series_filepaths))])
    else:
        if re.match(r'^\w+-\d+-\d+-\d+.dcm$', series_filepaths[0][0]) is not None:
            series_filepaths_split = []
            slices_list = []
            series_filepaths_sort_by_slice = sorted(series_filepaths[:], key=lambda x: '{}-{}'.format(x[0].split('-')[-1].split('.')[0], x[0].split('-')[-2]))
            for fname, fpath in series_filepaths_sort_by_slice:
                nslice = fname.split('-')[-1].split('.')[0]
                tframe = fname.split('-')[-2]
                if nslice not in slices_list:
                    if len(series_filepaths_split) == stack_size:
                        series_filepaths_all.append(series_filepaths_split)
                    elif len(series_filepaths_split) < stack_size and len(series_filepaths_split) > 0:
                        series_filepaths_all.append(series_filepaths_split[:] + series_filepaths_split[:(stack_size - len(series_filepaths_split))])
                    series_filepaths_split = []
                    series_filepaths_split.append((fname, fpath))
                    slices_list.append(nslice)
                else:
                    series_filepaths_split.append((fname, fpath))
            
    image_stacks = []
    for series_filepaths_mod in series_filepaths_all:
        image_stack = []
        for i, (fname, fpath) in enumerate(series_filepaths_mod):
            df = dicom.read_file(fpath)
            slice_arr = apply_per_slice_norm(crop_to_square(df.pixel_array, img_size))
            image_stack.append(slice_arr)
        image_stacks.append(np.array(image_stack).astype(np.float32))
        
    return image_stacks

In [8]:
systole_labels, diastole_labels = get_training_labels()

def create_label(pt, mode='systole'):
    if mode == 'systole':
        return systole_labels[pt] < np.arange(600)
    elif mode == 'diastole':
        return diastole_labels[pt] < np.arange(600)
    else:
        raise

#### preprocess data

In [9]:
filepaths_train, filepaths_val = get_filepaths()

In [10]:
train_val_split = 0.05
pts_train_val = random.sample(list(range(1, 501)), int(500 * train_val_split))
pts_train = list(set(range(1, 501)) - set(pts_train_val))

In [11]:
# training
data_train = []
label_sys_train = []
label_dia_train = []

# training local validation
data_train_val = []
label_sys_train_val = []
label_dia_train_val = []

# validation
data_val = []
data_val_pt_index = []

for pt in tqdm(filepaths_train.keys()):
    for series in filepaths_train[pt].keys():
        if series.startswith('sax'):
            imgs3d = create_image_stack(filepaths_train[pt][series], img_size=250, stack_size=30)
            labels_sys = [create_label(pt, mode='systole')] * len(imgs3d)
            labels_dia = [create_label(pt, mode='diastole')] * len(imgs3d)
            if pt in pts_train:
                data_train.extend(imgs3d)
                label_sys_train.extend(labels_sys)
                label_dia_train.extend(labels_dia)
            elif pt in pts_train_val:
                data_train_val.extend(imgs3d)
                label_sys_train_val.extend(labels_sys)
                label_dia_train_val.extend(labels_dia)

for pt in tqdm(filepaths_val.keys()):
    for series in filepaths_val[pt].keys():
        if series.startswith('sax'):
            imgs3d = create_image_stack(filepaths_val[pt][series], img_size=250, stack_size=30)
            data_val.extend(imgs3d)
            data_val_pt_index.extend([int(pt)] * len(imgs3d))
            
data_train = np.array(data_train).astype(np.float32)
label_sys_train = np.array(label_sys_train).astype(np.bool)
label_dia_train = np.array(label_dia_train).astype(np.bool)
print(data_train.shape, label_sys_train.shape, label_dia_train.shape)

data_train_val = np.array(data_train_val).astype(np.float32)
label_sys_train_val = np.array(label_sys_train_val).astype(np.bool)
label_dia_train_val = np.array(label_dia_train_val).astype(np.bool)
print(data_train_val.shape, label_sys_train_val.shape, label_dia_train_val.shape)

data_val = np.array(data_val).astype(np.float32)
data_val_pt_index = np.array(data_val_pt_index).astype(np.uint16)
print(data_val.shape, data_val_pt_index.shape)

                                                 

(5113, 30, 250, 250) (5113, 600) (5113, 600)
(259, 30, 250, 250) (259, 600) (259, 600)
(2169, 30, 250, 250) (2169,)




In [12]:
joblib.dump((pts_train, pts_train_val, 
             data_train, label_sys_train, label_dia_train, 
             data_train_val, label_sys_train_val, label_dia_train_val, 
             data_val, data_val_pt_index), 
            '../data_proc/2-data_processed.pkl')

['../data_proc/2-data_processed.pkl',
 '../data_proc/2-data_processed.pkl_01.npy',
 '../data_proc/2-data_processed.pkl_02.npy',
 '../data_proc/2-data_processed.pkl_03.npy',
 '../data_proc/2-data_processed.pkl_04.npy',
 '../data_proc/2-data_processed.pkl_05.npy',
 '../data_proc/2-data_processed.pkl_06.npy',
 '../data_proc/2-data_processed.pkl_07.npy',
 '../data_proc/2-data_processed.pkl_08.npy']

In [13]:
print(np.min(data_train), np.max(data_train))
print(np.min(data_train_val), np.max(data_train_val))
print(np.min(data_val), np.max(data_val))

-1.62978 29.2818
-1.52743 23.5037
-1.68505 31.3315
