In [22]:
import os
import re
import csv
import pickle
import random
import math
import dicom
import numpy as np
import scipy
from tqdm import tqdm
from natsort import natsorted
from skimage import transform
from sklearn.externals import joblib

import matplotlib.pyplot as plt
%matplotlib inline

#### helper functions

In [23]:
def get_filepaths():
    """Load the pickled file-path indexes for the training and validation sets.

    Reads ``filepaths_train.pkl`` and ``filepaths_val.pkl`` from the current
    working directory.

    Returns:
        tuple: ``(filepaths_train, filepaths_val)`` exactly as unpickled.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only use
    # pickles produced by this project.
    loaded = []
    for pkl_name in ('filepaths_train.pkl', 'filepaths_val.pkl'):
        with open(pkl_name, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return loaded[0], loaded[1]

In [24]:
def get_training_labels():
    """Read the systole / diastole volume labels from ``../data/train.csv``.

    Every row must have exactly three columns (id, systole, diastole). Rows
    whose first column is the literal ``'Id'`` (the header) are skipped.

    Returns:
        tuple: ``(systole_labels, diastole_labels)``, two dicts mapping the
        integer id to the float value of the respective column.
    """
    systole_by_id, diastole_by_id = {}, {}
    with open('../data/train.csv', 'r') as csv_file:
        for case_id, systole_vol, diastole_vol in csv.reader(csv_file):
            if case_id != 'Id':
                key = int(case_id)
                systole_by_id[key] = float(systole_vol)
                diastole_by_id[key] = float(diastole_vol)
    return systole_by_id, diastole_by_id

In [25]:
def apply_window(arr, window_center, window_width):
    """Clip ``arr`` to the intensity window centred on ``window_center``.

    Values are limited to the closed range
    ``[window_center - window_width/2, window_center + window_width/2]``.

    Returns:
        The clipped array as produced by ``np.clip``.
    """
    lower_bound = window_center - window_width/2
    upper_bound = window_center + window_width/2
    return np.clip(arr, lower_bound, upper_bound)

In [50]:
def apply_norm(arr):
    """Scale ``arr`` by its maximum and subtract the mean of the scaled values.

    Note that the standard deviation is computed only to detect constant
    input; the result is mean-centred but NOT divided by the standard
    deviation.

    Returns:
        An all-zero float array of ``arr.shape`` when the maximum is 0 or the
        scaled values have zero standard deviation; otherwise
        ``arr / max(arr) - mean(arr / max(arr))``.
    """
    peak = np.max(arr.ravel())
    if peak == 0:
        return np.zeros(arr.shape)

    scaled = arr / peak
    flat = scaled.ravel()
    center = np.mean(flat)
    # A constant volume carries no information: return zeros explicitly.
    if np.std(flat) == 0:
        return np.zeros(arr.shape)
    return scaled - center

In [27]:
def crop_to_square(arr, size):
    """Centre-crop a 2-D array to a square and resize it to ``size`` x ``size``.

    The crop keeps the full extent of the shorter axis and the central part
    of the longer axis. Resizing uses ``skimage.transform.resize`` with
    first-order interpolation, clipping enabled and the input value range
    preserved.

    Args:
        arr: 2-D array (``arr.shape`` must unpack into two lengths).
        size: Edge length of the returned square image.

    Returns:
        The resized ``(size, size)`` array.
    """
    n_rows, n_cols = arr.shape
    side = min(n_rows, n_cols)
    row_start = (n_rows - side) // 2
    col_start = (n_cols - side) // 2
    square = arr[row_start:row_start + side, col_start:col_start + side]
    return transform.resize(square, (size, size),
                            order=1, clip=True, preserve_range=True)

In [51]:
def interp_vol4d(vol4d, z_locs, nb_slices_z):
    """Resample a 4-D volume to ``nb_slices_z`` evenly spaced depth slices.

    The first and last output slices are copies of the first and last input
    slices; every slice in between is a linear interpolation of the two input
    slices whose locations bracket the target location.

    Args:
        vol4d: Array of shape (depth, time, row, col).
        z_locs: Slice locations matching the depth axis of ``vol4d``, in
            ascending order (duplicates are allowed).
        nb_slices_z: Number of slices wanted along the depth axis.

    Returns:
        Float array of shape (nb_slices_z, time, row, col).

    Raises:
        ValueError: If ``z_locs`` is empty.
    """
    nb_src = len(z_locs)
    if nb_src == 0:
        raise ValueError('interp_vol4d needs at least one slice location')

    def find_interval(test_loc):
        # Index of the first interval [z_locs[i], z_locs[i+1]] containing test_loc.
        for i in range(nb_src - 1):
            if z_locs[i] <= test_loc <= z_locs[i + 1]:
                return i
        # Nothing brackets the target (single slice, or rounding pushed it
        # just outside the range): fall back to the nearest end interval
        # instead of always using the first one.
        if test_loc > z_locs[-1]:
            return max(nb_src - 2, 0)
        return 0

    # we reshape through interpolation to nb_slices_z along the depth axis
    vol4d_new = np.zeros((nb_slices_z, vol4d.shape[1], vol4d.shape[2], vol4d.shape[3]))
    vol4d_new[0, :, :, :] = vol4d[0, :, :, :]
    vol4d_new[-1, :, :, :] = vol4d[-1, :, :, :]

    z_first, z_last = z_locs[0], z_locs[-1]
    for z in range(1, nb_slices_z - 1):
        loc = z_first + z * (z_last - z_first) / (nb_slices_z - 1)
        idx = find_interval(loc)
        # Clamp so a single-slice volume does not index past the end.
        idx_upper = min(idx + 1, nb_src - 1)
        z_loc_lower = z_locs[idx]
        z_loc_upper = z_locs[idx_upper]
        span = z_loc_upper - z_loc_lower
        if span == 0:
            # Degenerate interval (coincident slice locations): nothing to
            # interpolate, and dividing by the span would fail.
            vol4d_new[z, :, :, :] = vol4d[idx, :, :, :]
        else:
            weight = (loc - z_loc_lower) / span
            vol4d_new[z, :, :, :] = vol4d[idx, :, :, :] + weight * (vol4d[idx_upper, :, :, :] - vol4d[idx, :, :, :])

    return vol4d_new
    

In [None]:
def create_tensors_2ch(filepaths, img_size=196):
    """Placeholder for building tensors from the ``2ch`` view.

    The notebook declares this function but never implements it; the stub
    raises so the cell is valid Python and accidental calls fail loudly.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('create_tensors_2ch is not implemented yet')

In [None]:
def create_tensors_4ch(filepaths, img_size=196):
    """Placeholder for building tensors from the ``4ch`` view.

    The notebook declares this function but never implements it; the stub
    raises so the cell is valid Python and accidental calls fail loudly.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('create_tensors_4ch is not implemented yet')

In [52]:
def _pad_series(series, t_slices):
    """Repeat ``series`` cyclically so it holds exactly ``t_slices`` entries."""
    return [series[i % len(series)] for i in range(t_slices)]


def _split_multi_slice_series(series, t_slices):
    """Split a folder that mixes several z-slices into one series per slice.

    File names are expected to look like ``<prefix>-<n>-<tframe>-<nslice>.dcm``.
    Entries are ordered by slice then time frame (string comparison, as the
    file-name fields are used verbatim) and grouped by slice. Groups shorter
    than ``t_slices`` are padded cyclically; groups longer than ``t_slices``
    are dropped.
    """
    def slice_then_frame(item):
        parts = item[0].split('-')
        return '{}-{}'.format(parts[-1].split('.')[0], parts[-2])

    groups = []
    seen_slices = set()
    for fname, fpath in sorted(series, key=slice_then_frame):
        nslice = fname.split('-')[-1].split('.')[0]
        if nslice not in seen_slices:
            seen_slices.add(nslice)
            groups.append([])
        groups[-1].append((fname, fpath))

    # Every group is considered here, including the last one, which the
    # previous implementation never flushed after its loop ended.
    return [_pad_series(group, t_slices) for group in groups if len(group) <= t_slices]


def create_tensors_sax(filepaths, img_size=196, nb_slices_z=10, t_slices=30):
    """Build a normalised 4-D short-axis volume for one patient.

    Args:
        filepaths: Dict mapping a view name to a list of ``(fname, fpath)``
            tuples; only views whose name starts with ``sax`` are used.
        img_size: Edge length each image is cropped/resized to.
        nb_slices_z: Number of slices along the depth axis after interpolation.
        t_slices: Number of time frames each series is brought to.

    Returns:
        The result of ``apply_norm`` on the interpolated volume of shape
        (nb_slices_z, t_slices, img_size, img_size).

    Raises:
        ValueError: If no usable ``sax`` series is found.
    """
    # create sax series filepaths
    # handles irregularities such as those including z-slices and t-slices in the same folder
    series_filepaths_all = []
    for view, series in filepaths.items():
        if not re.match(r'^sax', view):
            continue
        if not series:
            # An empty folder has no first file to read the location from.
            continue
        if len(series) <= t_slices:
            # Cyclic padding also covers series shorter than t_slices / 2,
            # which a single repetition would leave too short.
            series_filepaths_all.append(_pad_series(series, t_slices))
        elif re.match(r'^\w+-\d+-\d+-\d+\.dcm$', series[0][0]) is not None:
            series_filepaths_all.extend(_split_multi_slice_series(series, t_slices))
        # Over-long series with any other file-name layout are ignored.

    if not series_filepaths_all:
        raise ValueError('no usable sax series found')

    # sort series by z-locations (stable, so ties keep their discovery order)
    z_locs = [dicom.read_file(series[0][1]).SliceLocation for series in series_filepaths_all]
    order = sorted(range(len(z_locs)), key=lambda i: z_locs[i])
    series_filepaths_all_zsorted = [series_filepaths_all[i] for i in order]
    z_locs = [z_locs[i] for i in order]

    # create 4d volumes
    vol4d = []
    for series in series_filepaths_all_zsorted:
        image_stack = [crop_to_square(dicom.read_file(fpath).pixel_array, img_size)
                       for _fname, fpath in series]
        vol4d.append(np.array(image_stack).astype(np.float32))
    vol4d = interp_vol4d(np.array(vol4d).astype(np.float32), z_locs, nb_slices_z)

    return apply_norm(vol4d)

In [53]:
# Module-level label lookups (id -> value) read by create_label() below.
systole_labels, diastole_labels = get_training_labels()

def create_label(pt, mode='systole'):
    """Build the 600-element boolean step label for one patient.

    Element ``i`` of the result is ``True`` when the patient's label value is
    strictly smaller than ``i``.

    Args:
        pt: Key into the module-level ``systole_labels`` / ``diastole_labels``.
        mode: ``'systole'`` or ``'diastole'``.

    Returns:
        Boolean array of shape (600,).

    Raises:
        ValueError: If ``mode`` is neither ``'systole'`` nor ``'diastole'``.
        KeyError: If ``pt`` has no label.
    """
    if mode == 'systole':
        volume = systole_labels[pt]
    elif mode == 'diastole':
        volume = diastole_labels[pt]
    else:
        # A bare `raise` here had no active exception to re-raise.
        raise ValueError("mode must be 'systole' or 'diastole', got {!r}".format(mode))
    return volume < np.arange(600)

#### preprocess data

In [62]:
# Load the pickled file-path indexes; the cells below iterate over their keys
# and pass each value to create_tensors_sax().
filepaths_train, filepaths_val = get_filepaths()

In [63]:
# Fraction of the labelled patients (ids 1..500) held out for local validation.
train_val_split = 0.05
nb_labelled_pts = 500
# Fixed seed so Restart & Run All reproduces the same split; a private
# Random instance leaves the global `random` state untouched.
split_seed = 0

all_pts = list(range(1, nb_labelled_pts + 1))
pts_train_val = random.Random(split_seed).sample(all_pts, int(nb_labelled_pts * train_val_split))
held_out_pts = set(pts_train_val)
# Keep ascending id order instead of the arbitrary order of a set difference.
pts_train = [pt for pt in all_pts if pt not in held_out_pts]

In [64]:
# training
data_train = []
label_sys_train = []
label_dia_train = []

# training local validation
data_train_val = []
label_sys_train_val = []
label_dia_train_val = []

# validation
data_val = []
data_val_pt_index = []

# Set lookups keep the per-patient membership tests O(1).
pts_train_lookup = set(pts_train)
pts_train_val_lookup = set(pts_train_val)

for pt in tqdm(filepaths_train.keys()):
    vol4d = create_tensors_sax(filepaths_train[pt], img_size=96, nb_slices_z=20)
    labels_sys = create_label(pt, mode='systole')
    labels_dia = create_label(pt, mode='diastole')
    if pt in pts_train_lookup:
        data_out, sys_out, dia_out = data_train, label_sys_train, label_dia_train
    elif pt in pts_train_val_lookup:
        data_out, sys_out, dia_out = data_train_val, label_sys_train_val, label_dia_train_val
    else:
        continue
    # Each patient is stored twice: as-is and with axis 1 reversed.
    for vol in (vol4d, vol4d[:, ::-1, :, :]):
        data_out.append(vol)
        sys_out.append(labels_sys)
        dia_out.append(labels_dia)

for pt in tqdm(filepaths_val.keys()):
    vol4d = create_tensors_sax(filepaths_val[pt], img_size=96, nb_slices_z=20)
    data_val.append(vol4d)
    data_val_pt_index.append(int(pt))

# Build the float32 arrays directly rather than via a full-size intermediate
# copy; `bool` replaces the `np.bool` alias that NumPy has removed.
data_train = np.array(data_train, dtype=np.float32)
label_sys_train = np.array(label_sys_train).astype(bool)
label_dia_train = np.array(label_dia_train).astype(bool)
print(data_train.shape, label_sys_train.shape, label_dia_train.shape)

data_train_val = np.array(data_train_val, dtype=np.float32)
label_sys_train_val = np.array(label_sys_train_val).astype(bool)
label_dia_train_val = np.array(label_dia_train_val).astype(bool)
print(data_train_val.shape, label_sys_train_val.shape, label_dia_train_val.shape)

data_val = np.array(data_val, dtype=np.float32)
data_val_pt_index = np.array(data_val_pt_index).astype(np.uint16)
print(data_val.shape, data_val_pt_index.shape)

                                                 

(950, 20, 30, 96, 96) (950, 600) (950, 600)
(50, 20, 30, 96, 96) (50, 600) (50, 600)
(200, 20, 30, 96, 96) (200,)




In [65]:
# Sanity check: value range of each processed array, one line per array.
for processed in (data_train, data_train_val, data_val):
    print(np.min(processed), np.max(processed))

-0.263469 0.979237
-0.191686 0.969832
-0.217834 0.981337


In [66]:
# Persist the patient-id split and every processed array as one tuple in a
# single joblib dump under ../data_proc/.
# NOTE(review): `joblib` is imported from `sklearn.externals` at the top of
# this notebook; confirm the installed scikit-learn still provides it.
joblib.dump((pts_train, pts_train_val, 
             data_train, label_sys_train, label_dia_train, 
             data_train_val, label_sys_train_val, label_dia_train_val, 
             data_val, data_val_pt_index), 
            '../data_proc/3-data_processed.pkl')

['../data_proc/3-data_processed.pkl',
 '../data_proc/3-data_processed.pkl_01.npy',
 '../data_proc/3-data_processed.pkl_02.npy',
 '../data_proc/3-data_processed.pkl_03.npy',
 '../data_proc/3-data_processed.pkl_04.npy',
 '../data_proc/3-data_processed.pkl_05.npy',
 '../data_proc/3-data_processed.pkl_06.npy',
 '../data_proc/3-data_processed.pkl_07.npy',
 '../data_proc/3-data_processed.pkl_08.npy']