In [1]:
import os
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed
seed(42)
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

In [20]:
import yaml
import os

import nipy as ni
import nibabel as nib
import numpy as np

from tensorflow.keras.models import load_model
from nibabel.streamlines.array_sequence import ArraySequence, concatenate
from nibabel.streamlines.tractogram import Tractogram
from GPUtil import getFirstAvailable

from tensorflow.keras import backend as K

import warnings
warnings.filterwarnings('ignore')

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = str(getFirstAvailable(order="load", maxLoad=10**-6, maxMemory=10**-1)[0])

In [4]:
model_dir = "models/detrack/2019-10-07-15:45:55"

In [10]:
def cosine_loss(y_true, y_pred):
    return -K.mean(K.sum(y_true * y_pred, axis=1))

In [11]:
config = yaml.load(open(os.path.join(model_dir, "config.yml")))
model = load_model(os.path.join(model_dir, "model.h5"),
                   custom_objects={"cosine_loss": cosine_loss})

W1008 09:20:44.088097 140150231160576 deprecation.py:323] From /local/home/vwegmayr/miniconda2/envs/thesis/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [14]:
dwi_img = ni.load_image(config["dwi_path"])
dwi = dwi_img.get_data()

wm_mask = ni.load_image("scoring_data/masks/wm.nii.gz").get_data()

In [17]:
x = dwi[73,95,74]
vin = np.array([1,1,0])
vin = vin / np.linalg.norm(vin)
model.predict(np.hstack([vin, vin, vin, x]).reshape(1,-1))

array([[0.99875474, 0.02187484, 0.0448388 ]], dtype=float32)

In [18]:
def seeds2fibers(seeds, step_size, max_iter, dwi_path, wm_path, model, config, save_path="tractogram.tck"):
    """
    seeds: [[x1,y1,z1],...,[xn,yn,zn]]
    """
    # Load data
    dwi_img = ni.load_image(dwi_path)
    dwi = dwi_img.get_data()
    
    wm_img = ni.load_image(wm_path)
    wm = wm_img.get_data()
    
    def ijk2xyz(indices):
        indices = np.hstack([indices, np.zeros((len(indices),1))])
        fn = lambda idx: dwi_img.coordmap(idx)[:3]
        return np.array(list(map(fn, indices)))
    
    def xyz2ijk(coords):
        coords = np.hstack([coords, np.zeros((len(coords),1))])
        fn = lambda coord: dwi_img.coordmap.inverse()(coord).round().astype(int)[:3]
        return np.array(list(map(fn, coords)))
    
    # Begin iteration
    
    ijk = seeds.copy()
    xyz = ijk2xyz(ijk).reshape(-1, 1, 3) # [fiber, pt, coo]
    fibers = []
    for i in range(max_iter):
        
        d = dwi[ijk[:,0], ijk[:,1], ijk[:,2], :]
        
        vin = np.zeros((len(xyz), 3*config["n_incoming"]))
        for j in range(min(config["n_incoming"], len(xyz[0,:,:])-1)):
            vin[:,3*j:3*(j+1)] = xyz[:,-j-1,:] - xyz[:,-j-2,:]

        vout = model.predict(np.hstack([vin,d]))
        rout = (xyz[:,-1,:] + step_size * vout).reshape(-1, 1, 3)
        
        xyz = np.concatenate((xyz, rout), axis=1)
        
        ijk = xyz2ijk(xyz[:,-1,:])
        isterminal = np.where(wm[ijk[:,0],ijk[:,1],ijk[:,2]] < 0.5)[0]

        for idx in isterminal:
            fibers.append(xyz[idx])
        
        xyz = np.delete(xyz, isterminal, axis=0)
        ijk = xyz2ijk(xyz[:,-1,:])
        
        print("Iteration {}".format(i+1), end="\r")
    
        if len(fibers) == len(seeds):
            break
    
    tractogram = Tractogram(
        streamlines=ArraySequence(fibers),
        affine_to_rasmm=np.eye(4) # Fiber coords already in correct space
    )
    
    tck_file = nib.streamlines.tck.TckFile(tractogram=tractogram)
    
    tck_file.save(save_path)
    
    return tractogram

In [21]:
seeds2fibers(np.array([[73,95,74],[74,95,74]]),
             step_size=1,
             max_iter=20,
             dwi_path=config["dwi_path"],
             wm_path="scoring_data/masks/wm.nii.gz",
             model=model,
             config=config)

Iteration 1Iteration 2Iteration 3Iteration 4Iteration 5Iteration 6Iteration 7Iteration 8

<nibabel.streamlines.tractogram.Tractogram at 0x7f76541e4ac8>