In [2]:
# Seed every random number generator used below (Python, NumPy, TensorFlow)
# before anything else is imported, so repeated runs draw the same numbers.
import os
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up; assigning it
# here presumably only affects child processes, not this kernel - confirm.
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed
seed(42)
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

import yaml
import datetime

import nipy as ni
import nibabel as nib
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import load_model
from GPUtil import getFirstAvailable
from tensorflow.keras import backend as K
from nibabel.streamlines.tck import TckFile
from nibabel.streamlines.array_sequence import ArraySequence, concatenate
from nibabel.streamlines.tractogram import Tractogram
from hashlib import md5
from sklearn.preprocessing import normalize
from pathos.multiprocessing import Pool
from scipy.interpolate import RegularGridInterpolator
from time import time

import tensorflow_probability as tfp

import warnings
# Suppress every Python warning for the rest of the session (this also hides
# deprecation notices from the imaging libraries used below).
warnings.filterwarnings('ignore')

In [5]:
def predict(config):
    """Track fibers from seed points with a trained direction model and save them.

    Every seed is tracked in two opposite starting directions. At each step the
    previous direction and the local DWI block around the current point are fed
    to the model, the fiber is advanced by ``step_size`` along the predicted
    direction, and fibers whose newest point falls below the terminator
    threshold are retired. The two halves belonging to one seed are stitched
    into a single streamline. Results are written to
    ``../subjects/<subject>/predicted_fibers/<md5 of config values>/`` as
    ``fibers.trk`` plus a copy of the config (``config.yml``).

    Parameters
    ----------
    config : dict
        Keys read by this function:

        ``subject``          folder name below ``../subjects``.
        ``model_dir``        directory containing ``model.h5``.
        ``dwi_path``         image file, relative to the subject folder.
        ``prior``            ``.nii`` peak image (relative to the subject
                             folder) giving the first direction at each seed.
        ``seeds``            ``.npy`` file (relative to the subject folder)
                             holding an ``(n, 4)`` array of homogeneous
                             coordinates ``(x, y, z, 1)``; assumed to be in
                             rasmm.
        ``terminator``       ``[scalar .nii path, threshold]``; a fiber stops
                             where the scalar value is below the threshold.
        ``predict_fn``       ``"mean"`` uses the re-normalised distribution
                             mean, anything else draws a sample.
        ``interpolate_dwi``  if true, linearly interpolate the DWI block at the
                             continuous position instead of snapping to the
                             nearest voxel.
        ``step_size``        length of one tracking step.
        ``max_steps``        maximum number of tracking iterations.

    Returns
    -------
    nibabel.streamlines.tractogram.Tractogram or None
        The saved tractogram, or ``None`` when the output directory for this
        exact config already exists (nothing is recomputed in that case).

    Raises
    ------
    NotImplementedError
        For terminator / prior variants that are not ``.nii`` based.
    ValueError
        If ``config["prior"]`` names neither a ``.nii`` nor a ``.h5`` file.
    """
    # Load Data

    subject_dir = os.path.join("..", "subjects", config["subject"])

    # The output folder is keyed by a hash over the config *values* in dict
    # iteration order, so the same settings always map to the same folder.
    hasher = md5()
    for v in config.values():
        hasher.update(str(v).encode())

    save_dir = os.path.join(subject_dir, "predicted_fibers", hasher.hexdigest())
    if os.path.exists(save_dir):
        print("Predictions with this config have been created already:\n{}".format(save_dir))
        return

    print("Loading DWI...")

    dwi_path = os.path.join(subject_dir, config["dwi_path"])

    dwi_img = nib.load(dwi_path)

    affine_original = dwi_img.affine

    dwi_img = nib.funcs.as_closest_canonical(dwi_img)

    affine_canonical = dwi_img.affine
    affine_canonical_inv = np.linalg.inv(affine_canonical)

    # np.asanyarray(img.dataobj) is the documented drop-in replacement for the
    # deprecated (and, in recent nibabel, removed) img.get_data().
    dwi = np.asanyarray(dwi_img.dataobj)

    print("Loading Model...")

    model_path = os.path.join(config["model_dir"], "model.h5")

    def negative_log_likelihood(observed_y, predicted_distribution):
        """Loss the model was saved with; needed only so load_model can resolve it."""
        return -K.mean(predicted_distribution.log_prob(observed_y))

    model = load_model(model_path,
                       custom_objects={"negative_log_likelihood": negative_log_likelihood,
                                       "DistributionLambda": tfp.layers.DistributionLambda})

    # Define coordinate transforms

    # The edge length of the cubic DWI block is recovered from the model's input
    # width; int() truncates, which presumably absorbs the few extra input
    # columns used for the incoming direction - TODO confirm.
    input_shape = model.layers[0].get_output_at(0).get_shape().as_list()[-1]
    block_size = int(np.cbrt(input_shape / dwi.shape[-1]))
    half = block_size // 2

    def ijk2xyz(indices):
        """Map homogeneous voxel indices (n, 4) to world coordinates (n, 4)."""
        #indices = np.hstack([indices, np.ones((len(indices), 1))])
        return affine_canonical.dot(indices.T).T

    def xyz2ijk(coords, snap=False):
        """Map homogeneous world coordinates (n, 4) to voxel indices (n, 4).

        With ``snap=True`` the result is rounded to the nearest integer voxel.
        """
        #coords = np.hstack([coords, np.ones((len(coords), 1))])
        ijk = affine_canonical_inv.dot(coords.T).T
        if snap:
            ijk = np.round(ijk).astype(int)
        return ijk

    # Define Fiber Termination

    class Terminator(object):
        """Flags fibers whose newest point lies in a voxel below a scalar threshold."""

        def __init__(self):
            if ".nii" in config["terminator"][0]:
                scalar_img = nib.load(os.path.join(subject_dir, config["terminator"][0]))
                self.scalar = np.asanyarray(scalar_img.dataobj)
                # The scalar image brings its own affine, so world points are
                # mapped into *its* voxel grid, not the DWI grid.
                aff_inv = np.linalg.inv(scalar_img.affine)
                self.xyz2ijk = lambda xyz: np.round(aff_inv.dot(xyz.T).T).astype(int)
            else:
                raise NotImplementedError # TODO: Implement termination model
            self.threshold = config["terminator"][1]

        def __call__(self, xyz):
            """Return the row indices of ``xyz`` (n, 4) that should terminate."""
            if hasattr(self, "scalar"):
                ijk = self.xyz2ijk(xyz)
                # NOTE(review): no bounds check - points outside the volume
                # either wrap around (negative index) or raise IndexError.
                return np.where(self.scalar[ijk[:,0], ijk[:,1], ijk[:,2]] < self.threshold)[0]
            else:
                raise NotImplementedError

    terminator = Terminator()

    print("Loading Seeds...")

    seeds = np.load(os.path.join(subject_dir, config["seeds"]))
    # Duplicate seeds for positive and negative starting direction
    seeds = np.vstack([seeds, seeds])

    # Define Prior for First Fiber Direction

    class Prior(object):
        """Supplies the initial incoming direction for every (duplicated) seed."""

        def __init__(self):
            if ".nii" in config["prior"]:
                peak_img = nib.load(os.path.join(subject_dir, config["prior"]))
                self.peak = np.asanyarray(peak_img.dataobj)
                aff_inv = np.linalg.inv(peak_img.affine)
                self.xyz2ijk = lambda xyz: np.round(aff_inv.dot(xyz.T).T).astype(int)
            elif ".h5" in config["prior"]:
                raise NotImplementedError # TODO: Implement prior model
            else:
                # Fail here rather than returning None from __call__ later on.
                raise ValueError("Unsupported prior: {}".format(config["prior"]))

        def __call__(self, xyz):
            """Look up the peak direction at each point; the second half is negated."""
            if hasattr(self, "peak"):
                ijk = self.xyz2ijk(xyz)
                # Assuming that seeds have been duplicated
                # (fancy indexing returns a copy, so the image data is not modified)
                peaks = self.peak[ijk[:,0], ijk[:,1], ijk[:,2]]
                peaks[len(ijk)//2:, :3] *= -1
                return peaks
            elif hasattr(self, "model"):
                raise NotImplementedError # TODO: Implement prior model

    prior = Prior()

    # Define Interpolation

    def interpolate(ijk):
        """Linearly interpolate a DWI block at each continuous voxel position.

        For every row of ``ijk`` the blocks centred on the 3x3x3 neighbourhood
        of the nearest voxel are gathered and interpolated at the sub-voxel
        offset. Returns a list with one flattened block per row.
        """

        def inpol_fn(idx):
            IDX = np.round(idx).astype(int)

            values = np.zeros([3, 3, 3,
                               block_size, block_size, block_size,
                               dwi.shape[-1]])

            # x, y, z enumerate the neighbour offsets -1, 0, +1 (the leading
            # axes of `values` always have length 3, independent of
            # block_size); each entry is the block centred on that neighbour.
            for x in range(3):
                for y in range(3):
                    for z in range(3):
                        values[x, y, z,:] = dwi[
                            IDX[0] + x - 1 - half : IDX[0] + x + half,
                            IDX[1] + y - 1 - half : IDX[1] + y + half,
                            IDX[2] + z - 1 - half : IDX[2] + z + half,
                            :]

            fn = RegularGridInterpolator(([-1,0,1],[-1,0,1],[-1,0,1]), values)

            d = fn([idx[0]-IDX[0], idx[1]-IDX[1], idx[2]-IDX[2]])[0]

            return d.flatten()

        with Pool(processes=20) as pool:
            return pool.map(inpol_fn, ijk)

    print("Initialize Fibers...")

    assert seeds.shape[-1] == 4   # (x, y, z, 1)
    xyz = seeds.reshape(-1, 1, 4) # (fiber, segment, coord) we assume seeds to be in rasmm!

    # fiber_idx[row] is the seed a row of `xyz` belongs to; both halves of the
    # duplicated seeds share one entry in `fibers`.
    fiber_idx = np.hstack([np.arange(len(seeds)//2), np.arange(len(seeds)//2)])
    fibers = [[] for _ in range(len(seeds)//2)]

    def store_end(gidx, this_end):
        """Store one finished half of fiber ``gidx``, stitching if the other exists."""
        # Other end not yet added
        if not fibers[gidx]:
            fibers[gidx].append(this_end)
        # Other end already added
        else:
            other_end = fibers[gidx][0]
            # Both halves start at the seed: reverse this half and drop its
            # seed point so the seed is not duplicated in the merged fiber.
            merged_fiber = np.vstack([np.flip(this_end[1:], axis=0), other_end]) # stitch ends together
            fibers[gidx] = [merged_fiber]

    print("Start Iteration...")

    for i in range(config["max_steps"]):
        t0 = time()

        last_points = xyz[:, -1, :] # Get coords of latest segment for each fiber

        if config["interpolate_dwi"]:
            # Continuous voxel coordinates; previously `ijk` was never assigned
            # on this branch, which raised an UnboundLocalError.
            ijk = xyz2ijk(last_points)
            d = np.asarray(interpolate(ijk))
        else:
            ijk = xyz2ijk(last_points, snap=True)

            d = np.zeros([len(ijk), block_size, block_size, block_size, dwi.shape[-1]])

            for ii, idx in enumerate(ijk):
                d[ii] = dwi[idx[0] - half : idx[0] + half + 1,
                            idx[1] - half : idx[1] + half + 1,
                            idx[2] - half : idx[2] + half + 1,
                            :]

            d = d.reshape(-1, dwi.shape[-1] * block_size**3)

        if i == 0:
            vin = prior(xyz[:,0,:])
        else:
            vin = vout.copy()

        if config["predict_fn"] == "mean":
            vout = model(np.hstack([vin,d])).mean().numpy()
            vout = normalize(vout) # Careful, the FvM mean is not a unit vector!
        else:
            vout = model(np.hstack([vin,d])).sample().numpy() # Samples are unit length, though!

        rout = (xyz[:, -1, :3] + config["step_size"] * vout)
        rout = np.hstack([rout, np.ones((len(rout), 1))]).reshape(-1, 1, 4)

        xyz = np.concatenate([xyz, rout], axis=1)

        terminal_indices = terminator(xyz[:, -1, :]) # Check latest points for termination

        for idx in terminal_indices:
            store_end(fiber_idx[idx], xyz[idx, :, :3])

        # Drop retired fibers from all per-row state so the rows stay aligned.
        xyz = np.delete(xyz, terminal_indices, axis=0)
        vout = np.delete(vout, terminal_indices, axis=0)
        fiber_idx = np.delete(fiber_idx, terminal_indices)

        # end="\r" belongs to print(), not to str.format(), so that the
        # progress line is overwritten in place.
        print("Iter {:4d}/{}, finished {:5d}/{:5d} ({:3.0f}%) of all seeds with {:6.0f} steps/sec".format(
            (i+1), config["max_steps"], len(seeds)-len(fiber_idx), len(seeds),
            100*(1-len(fiber_idx)/len(seeds)), len(vin) / (time() - t0)),
            end="\r"
        )
        if len(fiber_idx) == 0:
            break

    # Include unfinished fibers:

    for idx, gidx in enumerate(fiber_idx):
        store_end(gidx, xyz[idx, :, :3])

    # Save Result

    fibers = [f[0] for f in fibers]

    tractogram = Tractogram(
        streamlines=ArraySequence(fibers),
        affine_to_rasmm=np.eye(4)
    )

    # NOTE(review): this product is the identity only when the image on disk is
    # already in canonical orientation; otherwise the world-space points get
    # re-mapped. Confirm this is the intended output space.
    tractogram.apply_affine(affine_original.dot(affine_canonical_inv))

    os.makedirs(save_dir, exist_ok=True)

    fiber_path = os.path.join(save_dir, "fibers.trk")
    print("\nSaving {}".format(fiber_path))
    nib.streamlines.save(tractogram, fiber_path)

    config_path = os.path.join(save_dir, "config.yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    return tractogram

In [6]:
# Expose only one GPU to this process: the first one GPUtil reports as
# available under the given load / memory limits (ordered by load).
os.environ["CUDA_VISIBLE_DEVICES"] = str(
    getFirstAvailable(
        order="load",
        maxLoad=10**-6,
        maxMemory=10**-1,
    )[0]
)

In [7]:
# Tracking configuration consumed by predict(). File names are resolved
# relative to ../subjects/<subject>/; model_dir must contain model.h5.
config = {
    "subject": "ismrm_rpe",
    "model_dir": "../models/entrack_conditional/07371985349211c04c0aee0a48c84a9e",
    "dwi_path": "fod_norm_125.nii.gz",
    "prior": "peaks_125.nii.gz",
    "seeds": "wm_seeds.npy",
    # [scalar image, threshold]: fibers stop where the scalar is below 0.3
    "terminator": ["fa_125.nii.gz", 0.3],
    "predict_fn": "mean",
    "interpolate_dwi": False,
    "step_size": 0.5,
    "max_steps": 100,
}

In [8]:
# Run the tracking; predict() returns None (and recomputes nothing) when the
# output folder for this exact config already exists.
tractogram = predict(config)

Loading DWI...
Loading Model...


W1017 22:18:51.955837 140433845106432 deprecation.py:323] From /local/home/vwegmayr/miniconda2/envs/thesis/lib/python3.6/site-packages/tensorflow_probability/python/distributions/von_mises_fisher.py:312: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Loading Seeds...
Initialize Fibers...
Start Iteration...
Iter    1/100, finished    24/ 9136 (  0%) of all seeds with  13019 steps/sec
Iter    2/100, finished    53/ 9136 (  1%) of all seeds with  61476 steps/sec
Iter    3/100, finished    71/ 9136 (  1%) of all seeds with  62060 steps/sec
Iter    4/100, finished    92/ 9136 (  1%) of all seeds with  39239 steps/sec
Iter    5/100, finished   123/ 9136 (  1%) of all seeds with  35166 steps/sec
Iter    6/100, finished   132/ 9136 (  1%) of all seeds with  63949 steps/sec
Iter    7/100, finished   174/ 9136 (  2%) of all seeds with  49932 steps/sec
Iter    8/100, finished   208/ 9136 (  2%) of all seeds with  64406 steps/sec
Iter    9/100, finished   228/ 9136 (  2%) of all seeds with  42967 steps/sec
Iter   10/100, finished   309/ 9136 (  3%) of all seeds with  36693 steps/sec
Iter   11/100, finished   357/ 9136 (  4%) of all seeds with  66160 steps/sec
Iter   12/100, finished   438/ 9136 (  5%) of all seeds with  62363 steps/sec
Iter   

In [9]:
# Load the same image predict() uses, via nibabel, to inspect its affine.
dwi_image = nib.load("../subjects/ismrm_rpe/fod_norm_125.nii.gz")

In [10]:
# Load the file again through nipy to compare its coordinate map with nibabel's affine.
dwi_ni = ni.load_image("../subjects/ismrm_rpe/fod_norm_125.nii.gz")

In [11]:
# Push the probe point [1, 2, 3, 1] through nipy's coordinate map.
dwi_ni.coordmap([1,2,3,1])

array([-0.875, -2.   ,  3.375,  1.   ])

In [21]:
# Voxel-to-world affine of the image as stored on disk.
dwi_image.affine

array([[-1.25 ,  0.   ,  0.   ,  0.375],
       [ 0.   , -1.25 ,  0.   ,  0.5  ],
       [ 0.   ,  0.   ,  1.25 , -0.375],
       [ 0.   ,  0.   ,  0.   ,  1.   ]])

In [13]:
# Reorient to the closest canonical axis order, as predict() does before tracking.
dwi_canonical = nib.funcs.as_closest_canonical(dwi_image)

In [22]:
# Voxel-to-world affine after reorientation.
dwi_canonical.affine

array([[   1.25 ,    0.   ,    0.   , -178.375],
       [   0.   ,    1.25 ,    0.   , -214.5  ],
       [   0.   ,    0.   ,    1.25 ,   -0.375],
       [   0.   ,    0.   ,    0.   ,    1.   ]])

In [18]:
# The same matrix product predict() applies to the tractogram before saving.
dwi_image.affine.dot(np.linalg.inv(dwi_canonical.affine))

array([[  -1.,    0.,    0., -178.],
       [   0.,   -1.,    0., -214.],
       [   0.,    0.,    1.,    0.],
       [   0.,    0.,    0.,    1.]])

In [20]:
# Array shape of the image; the last axis is the channel axis predict() reads as dwi.shape[-1].
dwi_image.shape

(144, 173, 144, 15)