In [15]:
import os
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed
seed(42)
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

import nibabel as nib
import nipy as ni
import numpy as np
import datetime
import shutil
import yaml
import csv
import json
import warnings
import time

from hashlib import md5
from scipy.interpolate import RegularGridInterpolator

warnings.filterwarnings('ignore')

In [19]:
def create_samples(config):
    """"""
    subject_dir = os.path.join("../subjects", config["subject"])
    
    trk_path = os.path.join(subject_dir, config["trk_path"])
    dwi_path = os.path.join(subject_dir, config["dwi_path"])
    
    hasher = md5()
    hasher.update(config["sample_type"].encode())
    hasher.update(open(dwi_path, "rb").read())
    hasher.update(open(trk_path, "rb").read())
    hasher.update(str(config["reverse_samples"]).encode())
    hasher.update(str(config["max_n_samples"]).encode())
    
    save_dir = os.path.join(subject_dir, "samples", hasher.hexdigest())
    if os.path.exists(save_dir):
        print("Samples with this config have been created already:\n{}".format(save_dir))
        return
    
    tracts = nib.streamlines.load(trk_path).tractogram
    assert tracts.data_per_point is not None
    assert "t" in tracts.data_per_point
    
    dwi_img = ni.load_image(dwi_path)
    xyz2ijk = lambda r: dwi_img.coordmap.inverse()([r[0], r[1], r[2], 0])[:3]
    dwi = dwi_img.get_data()

    n_fibers = len(tracts)
    fiber_lengths = [len(f) for f in tracts]
    n_samples = np.sum(fiber_lengths) - 2 * n_fibers
    if config["reverse_samples"]:
        n_samples *= 2 
    n_samples = min(n_samples, config["max_n_samples"])
    
    np.random.seed(42)
    perm = np.random.permutation(len(tracts))
    tracts = tracts[perm]
    
    inputs = []
    outputs = []
    total_samples = 0
    done=False
    for fi, f in enumerate(tracts):
        start = time.time()
        fib_len = (len(f)-2) * 2 if config["reverse_samples"] else len(f)-2
        split_num = fib_len // config['seq_len']
        max_samples = split_num // 2 if config["reverse_samples"] else split_num
        n_samples = n_samples - (fib_len % config['seq_len'])
        
        tract_input = np.zeros((split_num, config['seq_len'], 408))
        tract_output = np.zeros((split_num, config['seq_len'], 3))
        for i, r in enumerate(f.streamline[1:-1]): # Exclude end points        
            try:
                idx = xyz2ijk(r) # anchor idx
                IDX = np.round(idx).astype(int)
                
                values = np.zeros([3, 3, 3,
                                   config["block_size"], config["block_size"], config["block_size"],
                                   dwi.shape[-1]])
                
                for x in range(config["block_size"]):
                    for y in range(config["block_size"]):
                        for z in range(config["block_size"]):
                            values[x, y, z,:] = dwi[
                                IDX[0] + x - 2 * (config["block_size"] // 2) : IDX[0] + x + 1,
                                IDX[1] + y - 2 * (config["block_size"] // 2) : IDX[1] + y + 1,
                                IDX[2] + z - 2 * (config["block_size"] // 2) : IDX[2] + z + 1,
                                :]
                fn = RegularGridInterpolator(([-1,0,1],[-1,0,1],[-1,0,1]), values)
                
                d = fn([idx[0]-IDX[0], idx[1]-IDX[1], idx[2]-IDX[2]])[0]
                d = d.flatten() # to get back the spatial order: reshape(bs, bs, bs, dwi.shape[-1])
                
            except IndexError:
                n_samples -= (2 if config["reverse_samples"] else 1)
                print("Index error at r={}, idx={}, fiber_idx={}\n".format(r,idx,perm[fi]) +
                      "Maybe wrong reference frame, or resampling failed."
                     )
                continue
                
            vout = f.data_for_points["t"][i+1].astype("float32")
            vin = f.data_for_points["t"][i].astype("float32")

            i_pos = i // config['seq_len']
            j_pos = i % config['seq_len']
            
            if (i_pos == max_samples):
                break
            tract_input[i_pos, j_pos,...] = np.hstack([vin, d]).astype("float32")
            tract_output[i_pos, j_pos ,...] = vout
            total_samples = total_samples + 1

            if config["reverse_samples"]:
                i_pos = (split_num // 2) + (i // config['seq_len'])
                j_pos = i % config['seq_len']
                
                tract_input[i_pos, j_pos,...] = np.hstack([-vout, d]).astype("float32")
                tract_output[i_pos, j_pos ,...] = -vin
                total_samples = total_samples + 1
      
            if total_samples == n_samples:
                done = True
                break
                
        
        inputs.append(tract_input)
        outputs.append(tract_output)
        print("Finished {:3.0f}% in {}".format(100*total_samples/n_samples, time.time() - start), end="\r")
        
        if done:
            break         

    os.makedirs(save_dir)

    inputs = np.concatenate(inputs, axis=0)
    outputs = np.concatenate(outputs, axis=0)
    save_path = os.path.join(save_dir, "samples.npz")
    
    print("Saving {}".format(save_path))
    np.savez_compressed(save_path, inputs=inputs, outputs=outputs)
    
    config["n_samples"] = int(n_samples)
    config_path = os.path.join(save_dir, "config" + ".yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
            yaml.dump(config, file, default_flow_style=False)
            
    return inputs, outputs

In [20]:
config = dict(
    subject="992774",
    sample_type="rnn",
    seq_len=3,
    dwi_path = "fod.nii.gz",
    trk_path = "resampled_fibers/merged_w20_smooth=6_npts=auto.trk",
    reverse_samples = True,
    block_size = 3,
    max_n_samples = 1000,
)

In [21]:
inputs, outputs = create_samples(config)

Saving ../subjects/992774/samples/e09d97e5447d8a380d51369e4e4f692f/samples.npz
Saving ../subjects/992774/samples/e09d97e5447d8a380d51369e4e4f692f/config.yml


In [24]:
outputs.shape

(346, 3, 3)