In [1]:
import os
# NOTE: PYTHONHASHSEED is read when a Python interpreter starts, so assigning
# it here does not change string hashing in this already-running kernel; it is
# only seen by interpreters launched afterwards with this environment.
os.environ['PYTHONHASHSEED'] = '0'

# Seed Python's built-in random module.
import random as rn
rn.seed(12345)

# Seed NumPy's legacy global random generator.
from numpy.random import seed
seed(42)

# Seed TensorFlow's graph-level random generator (TF1 compatibility API).
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

import nibabel as nib
import nipy as ni

import numpy as np
import tensorflow as tf

from matplotlib import pyplot as plt

from tensorflow.keras.layers import (Input, Reshape, Dropout, BatchNormalization, Lambda, Dense)
from tensorflow.keras import backend as K

import warnings
warnings.filterwarnings('ignore')

In [2]:
from multiprocessing import cpu_count
import glob

from GPUtil import getFirstAvailable
import datetime
import shutil

import yaml
import csv

import pandas as pd

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

import tensorflow_probability as tfp

tfd = tfp.distributions

In [3]:
# Restrict this process to a single GPU: the first device ID returned by
# GPUtil's getFirstAvailable, queried in "load" order with the given maxLoad /
# maxMemory thresholds.
# NOTE(review): presumably getFirstAvailable raises when no GPU satisfies the
# thresholds, which would abort the notebook here - confirm against GPUtil docs.
# NOTE(review): TensorFlow is imported in an earlier cell; this assumes the
# CUDA devices have not been initialised yet at this point - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = str(getFirstAvailable(order="load", maxLoad=10**-6, maxMemory=10**-1)[0])

In [4]:
class input_sequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving (input, target) batches built from streamlines.

    One sample is produced for every streamline point ``i`` that has at least
    ``n_incoming`` preceding segments:

    * target: the outgoing segment ``f[i] - f[i-1]`` scaled to unit length;
    * input: the ``n_incoming`` preceding segments (most recent first), all
      divided by the length of the most recent one, followed by the image
      values stored at the voxel that contains the anchor point ``f[i-1]``.

    Generated samples are cached in the current working directory as
    ``<dwi basename>_<tck basename>_inputs.npy`` / ``..._outputs.npy``.  The
    file name does not encode ``n_incoming``; a cache whose shape does not
    match the requested ``n_incoming`` is ignored and overwritten.

    Attributes:
        inputs: float32 array of shape ``(n_samples, 3 * n_incoming + 15)``.
        outputs: float32 array of shape ``(n_samples, 3)``.
        n_samples: number of samples.
        n_fibers, fiber_lengths: streamline statistics; only available when
            the samples were generated in this run, ``None`` when they were
            loaded from the cache.
    """

    # Number of image values read per voxel, i.e. the size of the last image
    # axis expected by the shape checks.
    # NOTE(review): presumably FOD coefficients, judging by the "_fod" file
    # name used in the notebook - confirm.
    _N_DWI_FEATURES = 15

    def __init__(self, dwi_path, tck_path, batch_size, n_incoming=1, istraining=True):
        """Load the samples from the cache, or generate and cache them.

        Args:
            dwi_path: path of the 4-D image read with ``nipy.load_image``.
            tck_path: path of the tractogram read with ``nibabel.streamlines``.
            batch_size: number of samples per batch.
            n_incoming: number of preceding segments in each input (>= 1).
            istraining: if true, ``len()`` drops the final partial batch.

        Raises:
            ValueError: if ``n_incoming`` < 1 or no sample could be generated.
        """
        if n_incoming < 1:
            raise ValueError("n_incoming must be >= 1, got {}".format(n_incoming))

        self.batch_size = batch_size
        self.dwi_path = dwi_path
        self.tck_path = tck_path
        self.n_incoming = n_incoming
        self.istraining = istraining

        # Only filled in when the tractogram is actually parsed.
        self.n_fibers = None
        self.fiber_lengths = None

        cache_path = (
            os.path.basename(self.dwi_path).split(".")[0] + "_" +
            os.path.basename(self.tck_path).split(".")[0]
        )
        # np.save appends ".npy" when it is missing, so these are the same
        # files the cache has always been written to.
        inputs_file = cache_path + "_inputs.npy"
        outputs_file = cache_path + "_outputs.npy"

        cached = self._load_cache(inputs_file, outputs_file)
        if cached is not None:
            self.inputs, self.outputs = cached
        else:
            self.inputs, self.outputs = self._generate_samples()

        # TODO: Randomize samples

        self.n_samples = len(self.inputs)

        # Validate before writing the cache so a malformed result is never
        # persisted.
        assert self.n_samples == len(self.outputs)
        assert self.inputs[0].shape == (self._n_features(), )
        assert self.outputs[0].shape == (3, )

        if cached is None:
            np.save(inputs_file, self.inputs)
            np.save(outputs_file, self.outputs)

    def _n_features(self):
        """Length of one input vector: 3 per incoming segment plus image values."""
        return 3 * self.n_incoming + self._N_DWI_FEATURES

    def _load_cache(self, inputs_file, outputs_file):
        """Return cached ``(inputs, outputs)``, or ``None`` if absent or incompatible."""
        if not (os.path.exists(inputs_file) and os.path.exists(outputs_file)):
            return None

        inputs = np.load(inputs_file)
        outputs = np.load(outputs_file)

        # A cache written for a different n_incoming has a different width.
        compatible = (
            inputs.ndim == 2
            and len(inputs) > 0
            and inputs.shape[1] == self._n_features()
            and outputs.shape == (len(inputs), 3)
        )
        return (inputs, outputs) if compatible else None

    def _generate_samples(self):
        """Build ``(inputs, outputs)`` float32 arrays from the image and tractogram."""
        dwi_img = ni.load_image(self.dwi_path)
        # Build the world -> voxel mapping once instead of once per sample.
        world_to_voxel = dwi_img.coordmap.inverse()
        dwi = dwi_img.get_data()

        tck = nib.streamlines.load(self.tck_path)
        fibers = tck.tractogram.streamlines

        self.n_fibers = len(fibers)
        self.fiber_lengths = [len(f) for f in fibers]

        inputs = []
        outputs = []
        for fi, f in enumerate(fibers):
            print("Finished {:3.0f}%".format(100*fi/self.n_fibers), end="\r")
            # Fibers with fewer than n_incoming + 2 points yield no samples.
            for i in range(self.n_incoming + 1, len(f)):
                vout = f[i, :] - f[i-1, :]
                vin = [f[i-j-1, :] - f[i-j-2, :] for j in range(self.n_incoming)]

                vout_norm = np.linalg.norm(vout)
                # All incoming segments are scaled by the length of the most
                # recent one, so relative lengths are preserved.
                vin_norm = np.linalg.norm(vin[0])
                if vout_norm == 0 or vin_norm == 0:
                    # Repeated points would produce NaN/inf samples; skip them.
                    continue

                anchor = f[i-1, :]
                idx = world_to_voxel([anchor[0], anchor[1], anchor[2], 0]).round().astype(int)
                # NOTE(review): a point outside the volume is not detected
                # here; negative indices would wrap around - confirm that
                # all streamlines lie inside the image.
                d = dwi[idx[0], idx[1], idx[2], :]

                outputs.append((vout / vout_norm).astype("float32"))
                inputs.append(np.hstack([np.hstack(vin) / vin_norm, d]).astype("float32"))

        if not inputs:
            raise ValueError(
                "No samples could be generated from {!r} with n_incoming={}".format(
                    self.tck_path, self.n_incoming
                )
            )

        return np.asarray(inputs, dtype="float32"), np.asarray(outputs, dtype="float32")

    def __len__(self):
        """Number of batches: floor when training (drop remainder), else ceil."""
        if self.istraining:
            return self.n_samples // self.batch_size
        # Integer ceiling division; returns a plain Python int.
        return -(-self.n_samples // self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as an ``(x_batch, y_batch)`` pair of 2-D arrays."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.inputs[start:stop], self.outputs[start:stop]

In [20]:
# Run configuration; it is also dumped next to the saved model by the training cell.
# NOTE(review): dwi_path is an absolute, machine-specific path and this cell's
# execution count (In [20]) is out of sequence with the rest of the notebook.
config = {
    "dwi_path": "/local/home/vwegmayr/ijcv19/992774_fod.nii.gz",  # 4-D input image
    "tck_path": "CC.tck",                                         # tractogram
    "batch_size": 128,
    "epochs": 5,
    "n_incoming": 3,     # preceding segments fed to the model per sample
    "save_model": True,  # write config.yml and model.h5 after training
}

In [6]:
# Input width: 15 per-voxel image values plus 3 components for each of the
# n_incoming preceding segments (the sample layout asserted in input_sequence).
features = Input(shape=(15 + 3 * config["n_incoming"], ),
                 name="features")

def detrack_fn(features):
    """Map the feature tensor to an L2-normalised 3-vector.

    Architecture: two 1024-unit ReLU dense layers, a linear 3-unit dense
    layer, and a final layer named "vout" that normalises along the last axis.
    """
    hidden = features
    for _ in range(2):
        hidden = Dense(1024, activation="relu")(hidden)

    raw_direction = Dense(3, activation="linear")(hidden)

    to_unit_length = Lambda(lambda t: K.l2_normalize(t, axis=-1), name="vout")
    return to_unit_length(raw_direction)

# Wire the input tensor to the detrack_fn output as a Keras model named "detrack".
model = tf.keras.Model(features, detrack_fn(features), name="detrack")

In [7]:
# Print the layer-by-layer overview (output shapes and parameter counts).
model.summary()

Model: "detrack"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
features (InputLayer)        [(None, 24)]              0         
_________________________________________________________________
dense (Dense)                (None, 1024)              25600     
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 3075      
_________________________________________________________________
vout (Lambda)                (None, 3)                 0         
Total params: 1,078,275
Trainable params: 1,078,275
Non-trainable params: 0
_________________________________________________________________


In [8]:
# Training data: istraining is left at its default (True), so the final
# partial batch is dropped.
train_seq = input_sequence(
    dwi_path=config["dwi_path"],
    tck_path=config["tck_path"],
    batch_size=config["batch_size"],
    n_incoming=config["n_incoming"],
)

Finished 100%

In [9]:
def cosine_loss(y_true, y_pred):
    """Negative batch mean of the row-wise dot product of targets and predictions.

    This equals the negative mean cosine similarity only when both inputs
    hold unit-length rows.
    """
    per_sample_dot = K.sum(y_true * y_pred, axis=1)
    return -K.mean(per_sample_dot)

In [10]:
# Train the model inside a fresh, timestamped directory.
#   * success           -> config.yml and model.h5 are written to model_dir
#   * KeyboardInterrupt -> directory is renamed to "<model_dir>_stopped" and
#                          the partially trained model is still saved
#   * any other error   -> directory is removed and the error is re-raised
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
model_dir = os.path.join("models/detrack", timestamp)

print(model_dir + "\n")

# Cleared only when training fails with a regular exception; checked in
# `finally` to decide whether anything should be saved.
no_exception = True
try:
    os.makedirs(model_dir)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=cosine_loss
    )
    # NOTE(review): fit_generator is deprecated in later TensorFlow releases
    # in favour of fit(); kept because this notebook targets TF 1.x.
    train_history = model.fit_generator(
        train_seq,
        epochs=config["epochs"],
        #validation_data=eval_sequence,
        callbacks=[
            TensorBoard(log_dir=model_dir, write_graph=False),
            #ModelCheckpoint("weights.{epoch:02d}-{val_loss:.2f}.hdf5", save_freq="epoch"),
            # EarlyStopping(min_delta=0.05, patience=10, restore_best_weights=True, verbose=1)
        ],
        #validation_steps=1,
        max_queue_size=2*config["batch_size"],
        use_multiprocessing=True,
        workers=cpu_count()
    )
except KeyboardInterrupt:
    # Manual stop: keep the results but mark the run as interrupted.
    os.rename(model_dir, model_dir + "_stopped")
    model_dir = model_dir + "_stopped"
except Exception:
    # Clear the flag first so `finally` never tries to save into a removed
    # directory, and ignore cleanup errors (e.g. the directory was never
    # created) so they cannot mask the original exception.
    no_exception = False
    shutil.rmtree(model_dir, ignore_errors=True)
    raise
finally:
    if no_exception and config["save_model"]:

        with open(os.path.join(model_dir, "config.yml"), "w") as file:
            yaml.dump(config, file, default_flow_style=False)

        model.save(os.path.join(model_dir, "model.h5"))

models/detrack/2019-10-07-15:45:55

Epoch 1/5


W1007 15:45:56.082983 139701351245568 deprecation.py:323] From /local/home/vwegmayr/miniconda2/envs/thesis/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
