In [1]:
import os
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed
seed(42)
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

import yaml
import datetime

import nipy as ni
import nibabel as nib
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import load_model
from GPUtil import getFirstAvailable
from tensorflow.keras import backend as K
from nibabel.streamlines.tck import TckFile
from nibabel.streamlines.array_sequence import ArraySequence, concatenate
from nibabel.streamlines.tractogram import Tractogram
from hashlib import md5
from sklearn.preprocessing import normalize
from pathos.multiprocessing import Pool
from scipy.interpolate import RegularGridInterpolator
from time import time

import tensorflow_probability as tfp

import warnings
warnings.filterwarnings('ignore')

In [43]:
def predict(config):
    """
    seeds: [[x1,y1,z1],...,[xn,yn,zn]]
    """
    # Load Data
    
    subject_dir = os.path.join("..", "subjects", config["subject"])
    
    hasher = md5()
    for v in config.values():
        hasher.update(str(v).encode())
    
    save_dir = os.path.join(subject_dir, "predicted_fibers", hasher.hexdigest())
    if os.path.exists(save_dir):
        print("Predictions with this config have been created already:\n{}".format(save_dir))
        return
    
    print("Loading DWI...")
    
    dwi_path = os.path.join(subject_dir, config["dwi_path"])
    dwi_img = ni.load_image(dwi_path)
    dwi = dwi_img.get_data()
    
    print("Loading Model...")
    
    model_path = os.path.join(config["model_dir"], "model.h5")
    
    def negative_log_likelihood(observed_y, predicted_distribution):
        return -K.mean(predicted_distribution.log_prob(observed_y))
    
    model = load_model(model_path,
                       custom_objects={"negative_log_likelihood": negative_log_likelihood,
                                       "DistributionLambda": tfp.layers.DistributionLambda})
    
    input_shape = model.layers[0].get_output_at(0).get_shape().as_list()[-1]
    block_size = int(np.cbrt(input_shape / dwi.shape[-1]))
    
    def ijk2xyz(indices):
        indices = np.hstack([indices, np.zeros((len(indices),1))])
        return dwi_img.coordmap(indices)[:, :3]
    
    def xyz2ijk(coords, snap=False):
        coords = np.hstack([coords, np.zeros((len(coords),1))])
        ijk = dwi_img.coordmap.inverse()(coords)[:, :3]
        if snap:
            ijk = np.round(ijk).astype(int)
        return np.clip(ijk,
                       np.tile([0,0,0], (len(ijk), 1)) + (block_size // 2),
                       np.tile(dwi.shape[:3], (len(ijk), 1)) - (block_size // 2) - 1
        )
    
    # Define Fiber Termination
    
    class Terminator(object):
        def __init__(self):
            if ".nii" in config["terminator"][0]:
                scalar_img = ni.load_image(os.path.join(subject_dir, config["terminator"][0]))
                self.scalar = scalar_img.get_data()
            else:
                raise NotImplementedError # TODO: Implement termination model
            self.threshold = config["terminator"][1]
        def __call__(self, ijk):
            if hasattr(self, "scalar"):
                ijk = np.round(ijk).astype(int)
                return np.where(self.scalar[ijk[:,0],ijk[:,1],ijk[:,2]] < self.threshold)[0]
            else:
                raise NotImplementedError

    terminator = Terminator()
    
    print("Loading Seeds...")
    
    seeds = np.load(os.path.join(subject_dir, config["seeds"]))
    # Duplicate seeds for positive and negative starting direction
    seeds = np.vstack([seeds, seeds])
    
    # Define Prior for First Fiber Direction
    
    class Prior(object):
        def __init__(self):
            if ".nii" in config["prior"]:
                peak_img = ni.load_image(os.path.join(subject_dir, config["prior"]))
                self.peak = peak_img.get_data()
            elif ".h5" in config:
                raise NotImplementedError # TODO: Implement prior model
                
        def __call__(self, ijk):
            if hasattr(self, "peak"):
                ijk = np.round(ijk).astype(int)
                # Assuming that seeds have been duplicated
                return self.peak[ijk[:,0],ijk[:,1],ijk[:,2]] * np.repeat([[1],[-1]], len(ijk)/2, axis=0)
            elif hasattr(self, "model"):
                raise NotImplementedError # TODO: Implement prior model
        
    prior = Prior()
    
    # Define Interpolation
    
    def interpolate(ijk):

        def inpol_fn(idx):
            IDX = np.round(idx).astype(int)

            values = np.zeros([3, 3, 3,
                               block_size, block_size, block_size,
                               dwi.shape[-1]])

            for x in range(block_size):
                for y in range(block_size):
                    for z in range(block_size):
                        values[x, y, z,:] = dwi[
                            IDX[0] + x - 2 * (block_size // 2) : IDX[0] + x + 1,
                            IDX[1] + y - 2 * (block_size // 2) : IDX[1] + y + 1,
                            IDX[2] + z - 2 * (block_size // 2) : IDX[2] + z + 1,
                            :]
                        
            fn = RegularGridInterpolator(([-1,0,1],[-1,0,1],[-1,0,1]), values)

            d = fn([idx[0]-IDX[0], idx[1]-IDX[1], idx[2]-IDX[2]])[0]
            
            return d.flatten()
        
        with Pool(processes=20) as pool:  
            return pool.map(inpol_fn, ijk)
    
    print("Initialize Fibers...")
    
    ijk = xyz2ijk(seeds)
    xyz = seeds.reshape(-1, 1, 3) # [fiber, pt, coo]
    fiber_idx = np.hstack([np.arange(len(seeds)//2), np.arange(len(seeds)//2)])
    fibers = [[] for _ in range(len(seeds)//2)]
    
    print("Start Iteration...")
    
    for i in range(config["max_steps"]):
        t0 = time()
        
        if config["interpolate_dwi"]:
            d = interpolate(ijk)
        else:
            ijk = np.round(ijk).astype(int)

            d = np.zeros([len(ijk), block_size, block_size, block_size, dwi.shape[-1]])
            
            for ii, idx in enumerate(ijk):
                d[ii] = dwi[idx[0] - (block_size // 2) : idx[0] + (block_size // 2) + 1,
                            idx[1] - (block_size // 2) : idx[1] + (block_size // 2) + 1,
                            idx[2] - (block_size // 2) : idx[2] + (block_size // 2) + 1,
                        :]
                
            d = d.reshape(-1, dwi.shape[-1] * block_size**3)
        
        if i == 0:
            vin = prior(xyz[:,0,:])
        else:
            vin = vout.copy()
        
        if config["predict_fn"] == "mean":
            vout = model(np.hstack([vin,d])).mean().numpy()
            vout = normalize(vout) # Careful, the FvM mean is not a unit vector!
        else:
            vout = model(np.hstack([vin,d])).sample().numpy() # Samples are unit length, though!
        
        rout = (xyz[:,-1,:] + config["step_size"] * vout).reshape(-1, 1, 3)
        
        xyz = np.concatenate([xyz, rout], axis=1)
        
        ijk = xyz2ijk(xyz[:,-1,:])
        
        terminal_indices = terminator(ijk)

        for idx in terminal_indices:
            gidx = fiber_idx[idx]
            # Other end not yet added
            if not fibers[gidx]:
                fibers[gidx].append(xyz[idx])
            # Other end already added
            else:
                this_end = xyz[idx]
                other_end = fibers[gidx][0]
                merged_fiber = np.vstack([np.flip(this_end[1:], axis=0), other_end]) # stitch ends together
                fibers[gidx] = [merged_fiber]
                
        xyz = np.delete(xyz, terminal_indices, axis=0)
        vout = np.delete(vout, terminal_indices, axis=0)
        fiber_idx = np.delete(fiber_idx, terminal_indices)
        
        ijk = xyz2ijk(xyz[:,-1,:])
        
        print("Iter {:4d}/{}, finished {:5d}/{:5d} ({:3.0f}%) of all seeds with {:6.0f} steps/sec".format(
            (i+1), config["max_steps"], len(seeds)-len(fiber_idx), len(seeds),
            100*(1-len(fiber_idx)/len(seeds)), len(vin) / (time() - t0), end="\r")
        )
        if len(fiber_idx) == 0:
            break
    
    # Include unfinished fibers:
    
    for idx, gidx in enumerate(fiber_idx):
        if not fibers[gidx]:
            fibers[gidx].append(xyz[idx])
        else:
            this_end = xyz[idx]
            other_end = fibers[gidx][0]
            merged_fiber = np.vstack([np.flip(this_end[1:], axis=0), other_end])
            fibers[gidx] = [merged_fiber]
    
    # Save Result
    
    fibers = [f[0] for f in fibers]
    
    tractogram = Tractogram(
        streamlines=ArraySequence(fibers),
        affine_to_rasmm=np.eye(4)
    )
    
    os.makedirs(save_dir, exist_ok=True)

    fiber_path = os.path.join(save_dir, "fibers.trk")
    print("\nSaving {}".format(fiber_path))
    nib.streamlines.save(tractogram, fiber_path)

    config_path = os.path.join(save_dir, "config.yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False) 
    
    return tractogram

In [44]:
os.environ["CUDA_VISIBLE_DEVICES"] = str(getFirstAvailable(order="load", maxLoad=10**-6, maxMemory=10**-1)[0])

In [45]:
config = dict(
    subject="992774",
    model_dir="../models/entrack_conditional/8d5593b08d4548286cc8564373e82e11",
    dwi_path="fod.nii.gz",
    prior="peak.nii.gz",
    seeds="seeds/bcc2e734e49c597e25ece8a3d499d060/seeds.npy",
    terminator=["fa.nii.gz", 0.21],
    predict_fn="mean",
    interpolate_dwi=False,
    step_size=0.2,
    max_steps=100
)

In [46]:
tractogram = predict(config)

Loading DWI...
Loading Model...
Loading Seeds...
Initialize Fibers...
Start Iteration...
Iter    1/100, finished 138917/364160 ( 38%) of all seeds with 105672 steps/sec
Iter    2/100, finished 149175/364160 ( 41%) of all seeds with 120591 steps/sec
Iter    3/100, finished 159289/364160 ( 44%) of all seeds with  96373 steps/sec
Iter    4/100, finished 168847/364160 ( 46%) of all seeds with 132386 steps/sec
Iter    5/100, finished 178168/364160 ( 49%) of all seeds with 128756 steps/sec
Iter    6/100, finished 186622/364160 ( 51%) of all seeds with 112993 steps/sec
Iter    7/100, finished 194479/364160 ( 53%) of all seeds with 126785 steps/sec
Iter    8/100, finished 201545/364160 ( 55%) of all seeds with 125309 steps/sec
Iter    9/100, finished 207896/364160 ( 57%) of all seeds with 121051 steps/sec
Iter   10/100, finished 213426/364160 ( 59%) of all seeds with  91243 steps/sec
Iter   11/100, finished 218077/364160 ( 60%) of all seeds with 124993 steps/sec
Iter   12/100, finished 222566/