In [3]:
import cv2
import numpy as np
import pydicom
import json
import os
import shutil
import sys
import random
from matplotlib import image
from scipy.ndimage import label
from zipfile import ZipFile
import re
import pandas as pd
from skimage.transform import resize

In [4]:
# this varibale should be set as where your train.zip, validate.zip, test.zip store
data_path = '.'

In [5]:
# Create a ZipFile Object and load train.zip in it
#with ZipFile(os.path.join(data_path, "train.zip"), 'r') as zipObj:
#   # Extract all the contents of zip file in different directory
#   zipObj.extractall()

In [6]:
study_train = next(os.walk(os.path.join(data_path, "train")))[1]
# load labels in 'train.csv'
# the first column means id
# the second and third columns mean the volume
labels = np.loadtxt(os.path.join(data_path, "train.csv"), delimiter=",",skiprows=1)
labels[0:10]

array([[  1. , 108.3, 246.7],
       [  2. ,  54.6, 137.2],
       [  3. ,  32.7,  99.3],
       [  4. ,  57.7, 154.5],
       [  5. ,  83.3, 235.5],
       [  6. , 225.3, 317.9],
       [  7. ,  64.9, 138. ],
       [  8. , 158.3, 305.5],
       [  9. ,  61.4, 152.2],
       [ 10. , 105.2, 219.3]])

In [7]:
class Dataset(object):
    dataset_count = 0

    def __init__(self, directory, subdir):
        # deal with any intervening directories
        while True:
            subdirs = next(os.walk(directory))[1]
            if len(subdirs) == 1:
                directory = os.path.join(directory, subdirs[0])
            else:
                break

        slices = []
        for s in subdirs:
            m = re.match("sax_(\d+)", s)
            if m is not None:
                slices.append(int(m.group(1)))

        slices_map = {}
        first = True
        times = []
        for s in slices:
            files = next(os.walk(os.path.join(directory, "sax_%d" % s)))[2]
            offset = None

            for f in files:
                m = re.match("IM-(\d{4,})-(\d{4})\.dcm", f)
                if m is not None:
                    if first:
                        times.append(int(m.group(2)))
                    if offset is None:
                        offset = int(m.group(1))

            first = False
            slices_map[s] = offset

        self.directory = directory
        self.time = sorted(times)
        self.slices = sorted(slices)
        self.slices_map = slices_map
        Dataset.dataset_count += 1
        self.name = subdir

    def _filename(self, s, t):
        return os.path.join(self.directory,"sax_%d" % s, "IM-%04d-%04d.dcm" % (self.slices_map[s], t))

    def _read_dicom_image(self, filename):
        d = pydicom.read_file(filename)
        img = d.pixel_array
        IMG_PX_SIZE = 64
        resized_img = resize(img, (IMG_PX_SIZE, IMG_PX_SIZE), anti_aliasing=True)
        return np.array(resized_img)

    def _read_all_dicom_images(self):
        f1 = self._filename(self.slices[0], self.time[0])
        d1 = pydicom.read_file(f1)
        (x, y) = d1.PixelSpacing
        (x, y) = (float(x), float(y))
        f2 = self._filename(self.slices[1], self.time[0])
        d2 = pydicom.read_file(f2)

        # try a couple of things to measure distance between slices
        try:
            dist = np.abs(d2.SliceLocation - d1.SliceLocation)
        except AttributeError:
            try:
                dist = d1.SliceThickness
            except AttributeError:
                dist = 8  # better than nothing...

        self.images = np.array([[self._read_dicom_image(self._filename(d, i))
                                 for i in self.time]
                                for d in self.slices])
        self.dist = dist
        self.area_multiplier = x * y

    def load(self):
        self._read_all_dicom_images()

In [8]:
dset = []
for i,s in enumerate(study_train):
    full_path = os.path.join(data_path, "train", s)
    dset.append(Dataset(full_path, s))
    print("Processing dataset %s..." % dset[i].name)
    p_edv = 0
    p_esv = 0
    try:
        dset[i].load()
        print("Dataset %s processing done." % dset[i].name)
    except Exception as e:
        print("ERROR: Exception %s thrown by dataset %s" % (str(e), dset[i].name))
        print("Omit index: %s" % i)

Processing dataset 135...


  warn("The default mode, 'constant', will be changed to 'reflect' in "


Dataset 135 processing done.
Processing dataset 307...
Dataset 307 processing done.
Processing dataset 61...
Dataset 61 processing done.
Processing dataset 95...
Dataset 95 processing done.
Processing dataset 338...
Dataset 338 processing done.
Processing dataset 300...
Dataset 300 processing done.
Processing dataset 132...
Dataset 132 processing done.
Processing dataset 59...
Dataset 59 processing done.
Processing dataset 92...
Dataset 92 processing done.
Processing dataset 66...
Dataset 66 processing done.
Processing dataset 336...
Dataset 336 processing done.
Processing dataset 104...
Dataset 104 processing done.
Processing dataset 309...
Dataset 309 processing done.
Processing dataset 50...
Dataset 50 processing done.
Processing dataset 68...
Dataset 68 processing done.
Processing dataset 103...
Dataset 103 processing done.
Processing dataset 331...
Dataset 331 processing done.
Processing dataset 57...
Dataset 57 processing done.
Processing dataset 168...
Dataset 168 processing don

Dataset 409 processing done.
Processing dataset 267...
Dataset 267 processing done.
Processing dataset 431...
Dataset 431 processing done.
Processing dataset 293...
Dataset 293 processing done.
Processing dataset 258...
Dataset 258 processing done.
Processing dataset 407...
Dataset 407 processing done.
Processing dataset 251...
Dataset 251 processing done.
Processing dataset 438...
Dataset 438 processing done.
Processing dataset 256...
Dataset 256 processing done.
Processing dataset 400...
Dataset 400 processing done.
Processing dataset 269...
Dataset 269 processing done.
Processing dataset 202...
Dataset 202 processing done.
Processing dataset 454...
Dataset 454 processing done.
Processing dataset 498...
Dataset 498 processing done.
Processing dataset 453...
Dataset 453 processing done.
Processing dataset 205...
Dataset 205 processing done.
Processing dataset 233...
Dataset 233 processing done.
Processing dataset 465...
Dataset 465 processing done.
Processing dataset 491...
Dataset 49

Dataset 155 processing done.
Processing dataset 393...
Dataset 393 processing done.
Processing dataset 199...
Dataset 199 processing done.
Processing dataset 39...
Dataset 39 processing done.
Processing dataset 394...
Dataset 394 processing done.
Processing dataset 152...
Dataset 152 processing done.
Processing dataset 360...
Dataset 360 processing done.
Processing dataset 334...
ERROR: Exception %d format: a number is required, not NoneType thrown by dataset 334
Omit index: 305
Processing dataset 106...
Dataset 106 processing done.
Processing dataset 99...
Dataset 99 processing done.
Processing dataset 52...
Dataset 52 processing done.
Processing dataset 139...
Dataset 139 processing done.
Processing dataset 101...
Dataset 101 processing done.
Processing dataset 333...
Dataset 333 processing done.
Processing dataset 55...
Dataset 55 processing done.
Processing dataset 137...
Dataset 137 processing done.
Processing dataset 305...
Dataset 305 processing done.
Processing dataset 97...
Da

Dataset 24 processing done.
Processing dataset 389...
Dataset 389 processing done.
Processing dataset 177...
Dataset 177 processing done.
Processing dataset 345...
Dataset 345 processing done.
Processing dataset 183...
Dataset 183 processing done.
Processing dataset 148...
Dataset 148 processing done.
Processing dataset 23...
Dataset 23 processing done.
Processing dataset 141...
Dataset 141 processing done.
Processing dataset 373...
Dataset 373 processing done.
Processing dataset 4...
Dataset 4 processing done.
Processing dataset 387...
Dataset 387 processing done.
Processing dataset 15...
Dataset 15 processing done.
Processing dataset 380...
Dataset 380 processing done.
Processing dataset 3...
Dataset 3 processing done.
Processing dataset 374...
Dataset 374 processing done.
Processing dataset 146...
Dataset 146 processing done.
Processing dataset 12...
Dataset 12 processing done.
Processing dataset 179...
Dataset 179 processing done.
Processing dataset 328...
Dataset 328 processing do

In [9]:
# note 337, 437, 463, 499, 234, 393, 334, 305, 279, 416, 41, 123 can not loaded 
# so we just remove them from our training data
omit_subject = [463, 499, 234, 334, 279, 416, 123]
omit_index = [134, 140, 169, 305, 333, 352, 433]

In [10]:
# refine the whole dataset
study_index = [int(ele) for ele in study_train]
study_index = [study_index[i] for i in range(len(study_index)) if i not in omit_index]
X = []
for ind, val in enumerate(study_train):
    try:
        new_image = dset[ind].images
        X.append(new_image)
    except Exception as e:
        print("ERROR: Exception %s" % str(e))
        print("Stop at index: %s" % val)

ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 463
ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 499
ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 234
ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 334
ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 279
ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 416
ERROR: Exception 'Dataset' object has no attribute 'images'
Stop at index: 123


In [11]:
# For example, to simplify this problem, we may just take the average of images
# for each subject
# Of course, you can consider more complicated method to obtain better performance
X_train_average = []
for i in range(len(X)):
    images = X[i]
    t, s, w, h = X[i].shape
    image_sum = np.zeros([64,64])
    for j in range(t):
        for k in range(s):
            image_sum = image_sum + images[j,k,:,:]
            
    image_average = image_sum / (t*s)  
    X_train_average.append(image_average)
X_train_average

[array([[3.97517083e-05, 8.01098650e-05, 9.76968658e-05, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [9.10259345e-05, 1.51416448e-04, 1.28527915e-04, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.19751043e-05, 1.34670650e-04, 1.33027370e-04, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [6.01166337e-05, 2.32406641e-04, 2.36045331e-04, ...,
         2.85578464e-04, 2.75288405e-04, 1.72270445e-04],
        [2.87965131e-05, 1.22796001e-04, 1.82110557e-04, ...,
         2.06622807e-04, 1.67340607e-04, 1.02039818e-04],
        [1.81347606e-05, 6.00383823e-05, 1.26415128e-04, ...,
         3.64064612e-05, 1.23637203e-05, 3.48218705e-06]]),
 array([[7.49716013e-05, 1.04846965e-04, 1.06214449e-04, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.32330688e-04, 1.80443762e-04, 1.81243801e-04, ...,
         1.33369679e-04, 7.29269983e-05, 1.08890076e-05],
        [1.37066283e-04, 

In [12]:
# Target for example: systole only
label = pd.DataFrame(labels)
label.columns = ['Id','Systole','Diastole']
actual_value = np.round(label.loc[np.array(study_index)-1,'Systole']).astype(int)
Y_train = actual_value

# one-hot encode
label_train = np.zeros([len(Y_train), 600])
for i in range(len(Y_train)):
    value = Y_train.iloc[i]
    label_train[i,value-1] = 1

In [13]:
# predictor
X_train_average = np.array(X_train_average).reshape([-1, 64, 64, 1])
X_train_average.shape

(493, 64, 64, 1)

In [23]:
X_train_average.reshape([493,64*64])
#X_train_average[0].reshape([64*64])

array([[3.97517083e-05, 8.01098650e-05, 9.76968658e-05, ...,
        3.64064612e-05, 1.23637203e-05, 3.48218705e-06],
       [7.49716013e-05, 1.04846965e-04, 1.06214449e-04, ...,
        1.20071548e-04, 8.47548596e-05, 3.52917864e-05],
       [4.76960033e-05, 6.58681112e-05, 6.73708936e-05, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       ...,
       [4.12918380e-05, 7.92544349e-05, 2.14921011e-04, ...,
        1.32499173e-04, 1.51919747e-04, 9.71491061e-05],
       [0.00000000e+00, 1.55981113e-05, 2.10800562e-05, ...,
        9.32495783e-07, 8.75980887e-07, 0.00000000e+00],
       [5.24817874e-06, 5.71057335e-05, 7.46420488e-05, ...,
        3.74886424e-05, 3.26103794e-05, 2.40098398e-05]])

In [27]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_train_average, label_train, test_size=0.33, random_state=42)
np.savetxt('X_train.csv', X_train.reshape((330,64*64)), delimiter=',')
np.savetxt('X_test.csv', X_test.reshape((163,64*64)), delimiter=',')
np.savetxt('y_train.csv',y_train, delimiter=',')
np.savetxt('y_test.csv', y_test, delimiter=',')

In [36]:
from __future__ import division, print_function, absolute_import

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression

In [37]:
network = input_data(shape=[None, 64, 64, 1], name='input')
network = conv_2d(network, 32, 3, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = conv_2d(network, 64, 3, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = fully_connected(network, 128, activation='tanh')
network = dropout(network, 0.8)
network = fully_connected(network, 256, activation='tanh')
network = dropout(network, 0.8)
network = fully_connected(network, 600, activation='softmax')

Instructions for updating:
Use tf.initializers.variance_scaling instead with distribution=uniform to get equivalent behavior.


In [38]:
network = regression(network, optimizer='adam', learning_rate=0.01,
                     loss='categorical_crossentropy', name='target')

Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [None]:
model = tflearn.DNN(network, tensorboard_verbose=0)
model.fit({'input': X_train}, {'target': y_train}, n_epoch=1,
           validation_set=({'input': X_test}, {'target': y_train}),
           snapshot_step=100, show_metric=True, run_id='convnet')

---------------------------------
Run id: convnet
Log directory: /tmp/tflearn_logs/
INFO:tensorflow:Summary name Accuracy/ (raw) is illegal; using Accuracy/__raw_ instead.
---------------------------------
Training samples: 330
Validation samples: 163
--
