In [1]:
# Reproducibility: seed Python's `random`, NumPy's global RNG and TensorFlow's global
# seed before any other library work happens in this kernel.
# NOTE(review): PYTHONHASHSEED is read by the interpreter at startup, so assigning it
# here only influences child processes, not hash randomization in this kernel.
import os
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed  # `seed` is numpy's global seeding function
seed(42)
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

import yaml
import os  # duplicate of the import above (harmless)
import datetime

import nipy as ni
import nibabel as nib
import numpy as np

from tensorflow.keras.models import load_model
from GPUtil import getFirstAvailable
from tensorflow.keras import backend as K
from nibabel.streamlines.tck import TckFile
from nibabel.streamlines.array_sequence import ArraySequence, concatenate
from nibabel.streamlines.tractogram import Tractogram
from hashlib import md5

import tensorflow_probability as tfp

import warnings
# Silences every warning for the rest of the kernel session (including deprecations).
warnings.filterwarnings('ignore')

In [8]:
def predict(config):
    """Track fibers from seed points by iteratively querying a trained direction model.

    At every step, the model receives the previous step direction (the prior's direction
    for the first step) stacked with the DWI/FOD signal at each fiber's current voxel.
    It returns a direction distribution, and the distribution's mean
    (``predict_fn == "mean"``) or a sample advances each fiber by ``step_size``.
    A fiber terminates when its new point leaves the DWI volume or lands on a voxel whose
    terminator scalar is below the threshold.

    Args:
        config: dict with keys
            "subject"     -- subject folder under ../subjects,
            "dwi_path"    -- DWI/FOD image, relative to the subject folder,
            "model_dir"   -- folder containing model.h5,
            "seeds"       -- .npy of world-space seeds [[x1,y1,z1],...,[xn,yn,zn]],
            "prior"       -- first-step direction image (.nii*) relative to subject,
            "terminator"  -- [scalar image (.nii*) relative to subject, threshold],
            "predict_fn"  -- "mean" to step along the mean, anything else samples,
            "step_size", "max_steps".

    Returns:
        nibabel ``Tractogram`` of the terminated fibers. If fibers for an identical config
        were already saved, they are loaded from disk and returned instead.

    Raises:
        NotImplementedError: for model-based terminators/priors (not implemented yet).
        ValueError: if ``config["prior"]`` is neither a .nii* nor a .h5 path.
    """
    # Load Data

    subject_dir = os.path.join("..", "subjects", config["subject"])

    dwi_path = os.path.join(subject_dir, config["dwi_path"])
    model_path = os.path.join(config["model_dir"], "model.h5")

    # Hash keys *and* values in a fixed order; hashing only the keys sent every config
    # with the same fields to the same output directory.
    hasher = md5()
    for key in sorted(config, key=str):
        hasher.update("{!r}={!r};".format(key, config[key]).encode())

    save_dir = os.path.join(subject_dir, "predicted_fibers", hasher.hexdigest())
    fiber_path = os.path.join(save_dir, "fibers.trk")
    config_path = os.path.join(save_dir, "config.yml")
    # Check the fiber file, not the directory: a run that failed after creating the
    # directory must not block later runs.
    if os.path.exists(fiber_path):
        print("Predictions with this config have been created already:\n{}".format(save_dir))
        return nib.streamlines.load(fiber_path).tractogram

    dwi_img = ni.load_image(dwi_path)
    dwi = dwi_img.get_data()
    vol_shape = np.array(dwi.shape[:3])
    # Invert the coordmap once instead of once per point per step.
    world_to_voxel = dwi_img.coordmap.inverse()

    def xyz2ijk(coords):
        """Map world points (n x 3) to rounded integer voxel indices (n x 3)."""
        coords = np.asarray(coords, dtype=float).reshape(-1, 3)
        if len(coords) == 0:
            return np.zeros((0, 3), dtype=int)
        # dwi is indexed as 4-D, so pad each point with 0 for the 4th coordmap axis.
        coords = np.hstack([coords, np.zeros((len(coords), 1))])
        return np.asarray(world_to_voxel(coords)).round().astype(int)[:, :3]

    # Define Fiber Termination

    class Terminator(object):
        """Flags fibers whose current voxel has a scalar value below a threshold."""

        def __init__(self):
            if ".nii" in config["terminator"][0]:
                scalar_img = ni.load_image(os.path.join(subject_dir, config["terminator"][0]))
                # NOTE(review): assumed to share the DWI voxel grid (indexed with DWI ijk).
                self.scalar = scalar_img.get_data()
            else:
                raise NotImplementedError  # TODO: Implement termination model

            self.threshold = config["terminator"][1]

        def __call__(self, ijk):
            """Return the row indices of ``ijk`` (n x 3 voxel indices) that terminate."""
            if hasattr(self, "scalar"):
                return np.where(self.scalar[ijk[:, 0], ijk[:, 1], ijk[:, 2]] < self.threshold)[0]
            raise NotImplementedError

    terminator = Terminator()

    seeds = np.load(os.path.join(subject_dir, config["seeds"]))

    # Define First Step Prior

    class Prior(object):
        """Provides the first-step direction for each seed."""

        def __init__(self):
            if ".nii" in config["prior"]:
                peak_img = ni.load_image(os.path.join(subject_dir, config["prior"]))
                self.peak = peak_img.get_data()
            elif ".h5" in config["prior"]:
                raise NotImplementedError  # TODO: Implement prior model
            else:
                raise ValueError("Unsupported prior: {}".format(config["prior"]))

        def __call__(self, xyz):
            """Look up the prior direction at each world point in ``xyz`` (n x 3)."""
            if hasattr(self, "peak"):
                ijk = xyz2ijk(xyz)
                return self.peak[ijk[:, 0], ijk[:, 1], ijk[:, 2]]
            raise NotImplementedError  # TODO: Implement prior model

    prior = Prior()

    # Load Model

    def negative_log_likelihood(observed_y, predicted_distribution):
        return -K.mean(predicted_distribution.log_prob(observed_y))

    model = load_model(model_path,
                       custom_objects={"negative_log_likelihood": negative_log_likelihood,
                                       "DistributionLambda": tfp.layers.DistributionLambda})

    # Begin Iterative Prediction

    xyz = seeds.reshape(-1, 1, 3)  # [fiber, pt, coo]
    ijk = xyz2ijk(xyz[:, -1, :])
    vin = prior(xyz[:, 0, :])
    fibers = []
    for i in range(config["max_steps"]):

        d = dwi[ijk[:, 0], ijk[:, 1], ijk[:, 2], :]

        distribution = model(np.hstack([vin, d]))
        # NOTE(review): the recorded OOM (shape [N, N, 3] in VonMisesFisher/mean) suggests
        # the model's concentration broadcasts as [N, 1] against the [N, 3] mean
        # direction; check the model's DistributionLambda -- not fixable from here.
        # Convert to NumPy: ndarray + Tensor would yield a Tensor, which has no .reshape.
        if config["predict_fn"] == "mean":
            vout = np.asarray(distribution.mean())
        else:
            vout = np.asarray(distribution.sample())

        rout = (xyz[:, -1, :] + config["step_size"] * vout).reshape(-1, 1, 3)
        xyz = np.concatenate((xyz, rout), axis=1)
        ijk = xyz2ijk(xyz[:, -1, :])

        # Stop fibers that left the volume (indexing would raise or wrap around), then
        # ask the terminator only about the ones still inside.
        inside = np.all((ijk >= 0) & (ijk < vol_shape), axis=1)
        stop = ~inside
        inside_rows = np.flatnonzero(inside)
        stop[inside_rows[terminator(ijk[inside])]] = True
        terminal_indices = np.flatnonzero(stop)

        fibers.extend(xyz[terminal_indices])
        # Keep positions, voxel indices and next-step input directions row-aligned.
        xyz = np.delete(xyz, terminal_indices, axis=0)
        ijk = np.delete(ijk, terminal_indices, axis=0)
        vin = np.delete(vout, terminal_indices, axis=0)

        print("Finished {:3.0f}%".format(100*(i+1)/config["max_steps"]), end="\r")

        if len(xyz) == 0:
            break

    print()
    if len(xyz):
        # Unterminated fibers are not saved; report it instead of dropping them silently.
        print("Discarding {} fibers that did not terminate within {} steps".format(
            len(xyz), config["max_steps"]))

    # Save Result

    # Points are world coordinates of the DWI coordmap, declared as RAS+ mm here
    # (NOTE(review): true for NIfTI input -- confirm). Without an affine, nibabel
    # treats the streamlines as being in an unknown space and will not save them.
    tractogram = Tractogram(fibers, affine_to_rasmm=np.eye(4))

    os.makedirs(save_dir, exist_ok=True)

    print("Saving {}".format(fiber_path))
    nib.streamlines.save(tractogram, fiber_path)

    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    return tractogram

In [3]:
# Expose only one GPU to this process: GPUtil's first available device, ordered by load,
# with (near-)zero load and under 10% memory in use. GPUtil raises if none qualifies.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if TensorFlow has not yet
# initialized its GPU devices in this kernel.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": str(
        getFirstAvailable(order="load", maxLoad=10**-6, maxMemory=10**-1)[0]
    ),
})

In [4]:
# Run configuration consumed by predict(). Except for model_dir, paths are relative to
# ../subjects/<subject>.
config = {
    "subject": "992774",
    "model_dir": "../models/entrack_conditional/7b5ee612a5df9f2447a27eff48382649",
    "dwi_path": "fod.nii.gz",            # model input signal per voxel
    "prior": "peak.nii.gz",              # first-step direction per voxel
    "seeds": "seeds/bcc2e734e49c597e25ece8a3d499d060/seeds.npy",
    "terminator": ["fa.nii.gz", 0.5],    # stop where this scalar is below 0.5
    "predict_fn": "mean",                # "mean" steps along the mean; else sample
    "step_size": 0.5,
    "max_steps": 500,
}

In [9]:
# Run fiber prediction for the configuration defined above.
tractogram = predict(config=config)

ResourceExhaustedError: OOM when allocating tensor with shape[182080,182080,3] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Mul] name: entrack_conditional/distribution_lambda/VonMisesFisher/mean/mul/