In [1]:
import os
# Reproducibility: seed Python, NumPy and TensorFlow RNGs up front.
# NOTE(review): assigning PYTHONHASHSEED inside an already-running interpreter
# does not change this process's hash randomization; it only affects child
# processes. Set it before launching the kernel if str-hash determinism matters.
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed
seed(42)  # NumPy legacy global RNG
from tensorflow.compat.v1 import set_random_seed
# TF global seed; presumably what makes `.sample()` in predict() repeatable -- confirm.
set_random_seed(42)

import yaml
import datetime  # NOTE(review): appears unused in the visible cells

import nipy as ni
import nibabel as nib
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import load_model
from GPUtil import getFirstAvailable
from tensorflow.keras import backend as K
from nibabel.streamlines.tck import TckFile  # NOTE(review): appears unused (output is saved as .trk)
from nibabel.streamlines.array_sequence import ArraySequence, concatenate
from nibabel.streamlines.tractogram import Tractogram
from hashlib import md5
from sklearn.preprocessing import normalize

import tensorflow_probability as tfp

import warnings
# NOTE(review): this silences *all* warnings, including deprecation warnings
# (e.g. for `get_data()`); consider narrowing to specific categories/modules.
warnings.filterwarnings('ignore')

In [83]:
def predict(config):
    """Track fibers from seed points with a trained direction model and save them.

    Every seed is tracked twice, starting along +prior and -prior direction;
    the two halves are stitched into one streamline per seed.

    Parameters
    ----------
    config : dict
        Run configuration. Keys read here: "subject", "dwi_path", "model_dir",
        "terminator" ([scalar_image_path, threshold]), "seeds" (path to a .npy
        of seed points, presumably [[x1, y1, z1], ..., [xn, yn, zn]] in the DWI
        image's world space -- TODO confirm), "prior" (peak image path),
        "predict_fn" ("mean", anything else samples), "step_size", "max_steps".
        All values, in insertion order, are hashed to name the output directory.

    Returns
    -------
    Tractogram or None
        The predicted tractogram, or None if the output directory for this
        exact config already exists (nothing is recomputed or loaded then).

    Raises
    ------
    NotImplementedError
        For non-NIfTI terminators and for .h5 (model-based) priors.
    ValueError
        If config["prior"] names neither a NIfTI nor an .h5 file.
    """
    # Load Data

    subject_dir = os.path.join("..", "subjects", config["subject"])

    dwi_path = os.path.join(subject_dir, config["dwi_path"])
    model_path = os.path.join(config["model_dir"], "model.h5")

    # Output directory is an MD5 over the config *values* (keys are not hashed,
    # so insertion order matters). Kept as-is so existing results are found.
    hasher = md5()
    for v in config.values():
        hasher.update(str(v).encode())

    save_dir = os.path.join(subject_dir, "predicted_fibers", hasher.hexdigest())
    if os.path.exists(save_dir):
        print("Predictions with this config have been created already:\n{}".format(save_dir))
        return

    dwi_img = ni.load_image(dwi_path)
    dwi = dwi_img.get_data()
    # Largest valid voxel index along each spatial axis.
    max_ijk = np.array(dwi.shape[:3]) - 1

    def ijk2xyz(indices):
        """Map voxel indices (n, 3) to world coordinates (n, 3) via the coordmap."""
        indices = np.hstack([indices, np.zeros((len(indices), 1))])
        return dwi_img.coordmap(indices)[:, :3]

    def xyz2ijk(coords):
        """Map world coordinates (n, 3) to the nearest in-bounds voxel indices (n, 3)."""
        # The coordmap takes 4 coordinates; pad the 4th (presumably time) with 0.
        coords = np.hstack([coords, np.zeros((len(coords), 1))])
        ijk = dwi_img.coordmap.inverse()(coords).round().astype(int)[:, :3]
        # Clamp so points that leave the volume still index a valid voxel.
        return np.clip(ijk, 0, max_ijk)

    # Define Fiber Termination

    class Terminator(object):
        """Flags fibers whose current voxel is below a scalar-map threshold."""

        def __init__(self):
            if ".nii" in config["terminator"][0]:
                scalar_img = ni.load_image(os.path.join(subject_dir, config["terminator"][0]))
                self.scalar = scalar_img.get_data()
            else:
                raise NotImplementedError  # TODO: Implement termination model
            self.threshold = config["terminator"][1]

        def __call__(self, ijk):
            """Return the row indices of ``ijk`` whose scalar value is below threshold."""
            if hasattr(self, "scalar"):
                return np.where(self.scalar[ijk[:, 0], ijk[:, 1], ijk[:, 2]] < self.threshold)[0]
            raise NotImplementedError

    terminator = Terminator()

    seeds = np.load(os.path.join(subject_dir, config["seeds"]))
    # Duplicate seeds for positive and negative starting direction
    seeds = np.vstack([seeds, seeds])

    # Define Prior for First Fiber Direction

    class Prior(object):
        """Supplies the initial direction for every (duplicated) seed."""

        def __init__(self):
            if ".nii" in config["prior"]:
                peak_img = ni.load_image(os.path.join(subject_dir, config["prior"]))
                self.peak = peak_img.get_data()
            elif ".h5" in config["prior"]:
                raise NotImplementedError  # TODO: Implement prior model
            else:
                raise ValueError("Unsupported prior: {}".format(config["prior"]))

        def __call__(self, xyz):
            """Return peak directions at ``xyz``; the second half of rows is negated."""
            if hasattr(self, "peak"):
                ijk = xyz2ijk(xyz)
                # Seeds were duplicated (first half / second half): the first
                # copy starts along +peak, the second along -peak.
                half = len(ijk) // 2
                signs = np.repeat([[1], [-1]], half, axis=0)
                return self.peak[ijk[:, 0], ijk[:, 1], ijk[:, 2]] * signs
            raise NotImplementedError  # TODO: Implement prior model

    prior = Prior()

    # Load Model

    def negative_log_likelihood(observed_y, predicted_distribution):
        return -K.mean(predicted_distribution.log_prob(observed_y))

    model = load_model(model_path,
                       custom_objects={"negative_log_likelihood": negative_log_likelihood,
                                       "DistributionLambda": tfp.layers.DistributionLambda})

    def add_fiber_end(fibers, gidx, end):
        """Store one tracked half of fiber ``gidx``; stitch it to the other half if present."""
        if not fibers[gidx]:
            # Other end not yet added.
            fibers[gidx].append(end)
        else:
            # Other end already added: reverse this half and drop its first
            # point (the seed, already the first point of the other half).
            other_end = fibers[gidx][0]
            fibers[gidx] = [np.vstack([np.flip(end[1:], axis=0), other_end])]

    # Begin Iterative Prediction

    ijk = xyz2ijk(seeds)
    xyz = seeds.reshape(-1, 1, 3)  # [fiber, pt, coo]
    n_fibers = len(seeds) // 2
    # Row r of xyz belongs to fiber fiber_idx[r]; both copies of a seed share it.
    fiber_idx = np.hstack([np.arange(n_fibers), np.arange(n_fibers)])
    fibers = [[] for _ in range(n_fibers)]
    for i in range(config["max_steps"]):

        d = dwi[ijk[:, 0], ijk[:, 1], ijk[:, 2], :]

        if i == 0:
            vin = prior(xyz[:, 0, :])
        else:
            vin = vout.copy()

        if config["predict_fn"] == "mean":
            vout = model(np.hstack([vin, d])).mean().numpy()
            vout = normalize(vout)  # Careful, the FvM mean is not a unit vector!
        else:
            vout = model(np.hstack([vin, d])).sample().numpy()

        rout = (xyz[:, -1, :] + config["step_size"] * vout).reshape(-1, 1, 3)
        xyz = np.concatenate([xyz, rout], axis=1)
        ijk = xyz2ijk(xyz[:, -1, :])

        terminal_indices = terminator(ijk)
        for idx in terminal_indices:
            add_fiber_end(fibers, fiber_idx[idx], xyz[idx])

        # Drop terminated rows. xyz2ijk is row-wise, so deleting the same rows
        # from ijk equals recomputing it from the remaining xyz.
        xyz = np.delete(xyz, terminal_indices, axis=0)
        vout = np.delete(vout, terminal_indices, axis=0)
        ijk = np.delete(ijk, terminal_indices, axis=0)
        fiber_idx = np.delete(fiber_idx, terminal_indices)

        print("Step {:4d}/{}, finished {:5d}/{:5d} ({:3.0f}%) of all seeds.".format(
            (i+1), config["max_steps"], len(seeds)-len(fiber_idx), len(seeds),
            100*(1-len(fiber_idx)/len(seeds))), end="\r")

        if len(fiber_idx) == 0:
            break

    # Include unfinished fibers (not terminated within max_steps):

    for idx, gidx in enumerate(fiber_idx):
        add_fiber_end(fibers, gidx, xyz[idx])

    # Save Result

    fibers = [f[0] for f in fibers]

    tractogram = Tractogram(
        streamlines=ArraySequence(fibers),
        affine_to_rasmm=np.eye(4)
    )

    os.makedirs(save_dir, exist_ok=True)

    fiber_path = os.path.join(save_dir, "fibers.trk")
    print("\nSaving {}".format(fiber_path))
    nib.streamlines.save(tractogram, fiber_path)

    config_path = os.path.join(save_dir, "config.yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    return tractogram

In [3]:
# Pin this process to the first GPU that GPUtil reports as available under
# very strict load/memory thresholds (ordered by load).
# NOTE(review): TensorFlow is already imported above; CUDA_VISIBLE_DEVICES only
# takes effect if TF has not yet initialized its GPUs -- confirm, or move this
# before `import tensorflow`.
os.environ.update(
    CUDA_VISIBLE_DEVICES=str(
        getFirstAvailable(order="load", maxLoad=10**-6, maxMemory=10**-1)[0]
    )
)

In [87]:
# Run configuration for predict().
# NOTE: predict() hashes these *values* in insertion order to name the output
# directory, so reordering keys changes where results are stored.
config = {
    "subject": "992774",
    "model_dir": "../models/entrack_conditional/8d5593b08d4548286cc8564373e82e11",
    "dwi_path": "fod.nii.gz",
    "prior": "peak.nii.gz",
    "seeds": "seeds/bcc2e734e49c597e25ece8a3d499d060/seeds.npy",
    # [scalar image, threshold]: tracking stops where the map falls below it.
    "terminator": ["fa.nii.gz", 0.21],
    # "mean" uses the normalized distribution mean; any other value samples.
    "predict_fn": "mean",
    # Distance moved per step along the predicted direction (units of the
    # image's world space, presumably mm -- TODO confirm).
    "step_size": 0.2,
    "max_steps": 500,
}

In [88]:
# Track, stitch and save fibers for `config`. Returns None (and writes nothing)
# if predictions for this exact config already exist on disk.
tractogram = predict(config)

Step  500/500, finished 360701/364160 ( 99%) of all seeds.
Saving ../subjects/992774/predicted_fibers/f13ac4341f570118dd1de8dd3b424303/fibers.trk
Saving ../subjects/992774/predicted_fibers/f13ac4341f570118dd1de8dd3b424303/config.yml
