In [3]:
import os
from functools import partial

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras

import SimpleITK as sitk # to read nii files
from dltk.io.preprocessing import whitening # for data normalization

In [4]:
# Shared Conv3D factory used throughout the notebook: 3x3x3 kernel, unit
# strides along all three spatial axes, "SAME" padding and no bias term
# (presumably because the surrounding BatchNormalization layers supply the
# learned offset). Any keyword can still be overridden per call.
DefaultConv3D = partial(
    keras.layers.Conv3D,
    kernel_size=3,
    strides=(1, 1, 1),
    padding="SAME",
    use_bias=False,
)

In [None]:
class ResidualUnit(keras.layers.Layer):
    """3D residual unit with BN -> activation -> conv ordering and a projection shortcut.

    All sub-layers are created once in ``__init__`` (construction) and only
    applied in ``call`` (execution), so their weights persist across calls.

    Args:
        filters: number of output channels of both 3x3x3 convolutions and of
            the 1x1x1 shortcut projection.
        strides: strides of the first convolution and of the shortcut. Be
            aware of the shape: pass an int or a 3-tuple, as accepted by
            ``keras.layers.Conv3D``.
        activation: activation name or callable, resolved with
            ``keras.activations.get``.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, strides=(1,)*3, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)

        # Main path: a list of layers applied in order by call().
        self.main_layers = [
                keras.layers.BatchNormalization(),
                self.activation,
                DefaultConv3D(filters, strides=strides),
                keras.layers.BatchNormalization(),
                self.activation,
                DefaultConv3D(filters, strides=(1,)*3),
                ]
        # Shortcut path: a 1x1x1 projection with the same strides, so the
        # shortcut has the same channel count and (with "SAME" padding) the
        # same spatial size as the main path before the addition. Built here,
        # not in call(), so its weights are created once and trained.
        self.skip_layers = [
                DefaultConv3D(filters, kernel_size=1, strides=strides),
                ]

    def call(self, inputs):
        x = inputs
        for layer in self.main_layers:
            x = layer(x)

        skip_x = inputs
        for layer in self.skip_layers:
            skip_x = layer(skip_x)

        return skip_x + x

In [None]:
# One residual unit per (filters, stride) pair.
filters = (16, 32, 64, 128)
strides = (1, 2, 2, 2)

model = keras.models.Sequential()
# Stem: 3x3x3 conv followed by a stride-2 max-pool.
# NOTE(review): input_shape is (96, 96, 48, 1), but sitk.GetArrayFromImage
# returns arrays in (z, y, x) order — confirm this matches the loaded volumes.
model.add(DefaultConv3D(filters[0], kernel_size=3, strides=(1,)*3,
        input_shape=[96, 96, 48, 1]))
model.add(keras.layers.MaxPool3D(pool_size=(3,)*3, strides=(2,)*3, padding="SAME"))

for unit_filters, unit_stride in zip(filters, strides):
    # Pass this unit's own channel count and a per-axis stride tuple
    # (not the whole `filters` / `strides` tuples).
    model.add(ResidualUnit(unit_filters, strides=(unit_stride,)*3))

model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
# GlobalAvgPool3D already outputs (batch, channels), so no Flatten is needed.
model.add(keras.layers.GlobalAvgPool3D())
model.add(keras.layers.Dense(2, activation="softmax"))

# Integer class labels (0/1) are expected by sparse_categorical_crossentropy.
model.compile(loss="sparse_categorical_crossentropy",
        optimizer="sgd",
        metrics=["accuracy"])

In [6]:
# Mini-batch size: number of samples per training step.
mini_batch_size = 8

In [None]:
# Directory with one sub-folder per subject, each holding a T1_2mm.nii.gz volume.
data_path = '../../../data/IXI_HH/2mm'
# CSV listing one subject per row; the subject ID is taken from the first column.
# TODO: set this path — it was left empty in the original notebook.
file_list_csv = ''

if not file_list_csv:
    raise ValueError('Set file_list_csv to the CSV file listing the subjects to load.')

# dtype=str keeps subject IDs exactly as written in the CSV (e.g. leading
# zeros are not lost to integer parsing), since they are used to build paths.
file_references = pd.read_csv(file_list_csv, dtype=str).to_numpy()

volumes = []
for f in file_references:
    subject_id = f[0]

    # Read the image nii with sitk
    t1_fn = os.path.join(data_path, '{}/T1_2mm.nii.gz'.format(subject_id))
    t1 = sitk.GetArrayFromImage(sitk.ReadImage(t1_fn))

    # Normalise volume image
    t1 = whitening(t1)

    # Append a trailing channel axis (Conv3D defaults to channels-last input).
    volumes.append(np.expand_dims(t1, axis=-1).astype(np.float32))

if not volumes:
    raise ValueError('No subjects listed in {}'.format(file_list_csv))

# Collect every subject (not just the last one) into a single array
# of shape (n_subjects, *volume_shape, 1).
images = np.stack(volumes)

In [None]:
# Train for 10 epochs; steps are derived from the notebook's mini_batch_size.
# NOTE(review): train_set, valid_set, X_train and X_valid are not defined in
# any visible cell — they need to be built (e.g. from `images` plus labels)
# before this cell can run.
model.fit(train_set, steps_per_epoch=len(X_train) // mini_batch_size, epochs=10,
        validation_data=valid_set,
        validation_steps=len(X_valid) // mini_batch_size)