In [5]:
import os
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)
from numpy.random import seed
seed(42)
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)

import nibabel as nib
import nipy as ni
import numpy as np
import datetime
import shutil
import yaml
import csv
import json
import warnings

from hashlib import md5

warnings.filterwarnings('ignore')

In [24]:
def create_samples(config):
    """"""
    subject_dir = os.path.join("subjects", config["subject"])
    
    trk_path = os.path.join(subject_dir, config["trk_path"])
    dwi_path = os.path.join(subject_dir, config["dwi_path"])
    
    hasher = md5()
    hasher.update(config["sample_type"].encode())
    hasher.update(open(dwi_path, "rb").read())
    hasher.update(open(trk_path, "rb").read())
    hasher.update(str(config["reverse_samples"]).encode())
    hasher.update(str(config["max_n_samples"]).encode())
    
    save_dir = os.path.join(subject_dir, "samples", hasher.hexdigest())
    if os.path.exists(save_dir):
        print("Samples with this config have been created already:\n{}".format(save_dir))
        return
    
    tracts = nib.streamlines.load(trk_path).tractogram # coordinates in mmRAS+ space
    assert tracts.data_per_point is not None
    assert "t" in tracts.data_per_point
    
    dwi_img = ni.load_image(dwi_path)
    xyz2ijk = lambda r: dwi_img.coordmap.inverse()([r[0], r[1], r[2], 0]).round().astype(int)
    dwi = dwi_img.get_data()

    n_fibers = len(tracts)
    fiber_lengths = [len(f) for f in tracts]
    n_samples = np.sum(fiber_lengths) - 2 * n_fibers
    if config["reverse_samples"]:
        n_samples *= 2 
    n_samples = min(n_samples, config["max_n_samples"])
    
    np.random.seed(42)
    perm = np.random.permutation(len(tracts))
    tracts = tracts[perm]
    
    inputs = []
    outputs = []
    total_samples = 0
    done=False
    for fi, f in enumerate(tracts):
        fib_len = (len(f)-2) * 2 if config["reverse_samples"] else len(f)-2
        split_num = fib_len // config['seq_len']
        max_samples = split_num // 2 if config["reverse_samples"] else split_num
        n_samples = n_samples - (fib_len % config['seq_len'])
        
        tract_input = np.zeros((split_num, config['seq_len'], 18))
        tract_output = np.zeros((split_num, config['seq_len'], 3))
        for i, r in enumerate(f.streamline[1:-1]): # Exclude end points        
            try:
                idx = xyz2ijk(r) # anchor idx
                d = dwi[idx[0], idx[1], idx[2], :]
            except IndexError:
                n_samples -= (2 if config["reverse_samples"] else 1)
                print("Index error at r={}, idx={}, fiber_idx={}\n".format(r,idx,perm[fi]) +
                      "Maybe wrong reference frame, or resampling failed."
                     )
                continue
                
            vout = f.data_for_points["t"][i+1].astype("float32")
            vin = f.data_for_points["t"][i].astype("float32")

            i_pos = i // config['seq_len']
            j_pos = i % config['seq_len']
            
            if (i_pos == max_samples):
                break
            tract_input[i_pos, j_pos,...] = np.hstack([vin, d]).astype("float32")
            tract_output[i_pos, j_pos ,...] = vout
            total_samples = total_samples + 1

            if config["reverse_samples"]:
                i_pos = (split_num // 2) + (i // config['seq_len'])
                j_pos = i % config['seq_len']
                
                tract_input[i_pos, j_pos,...] = np.hstack([-vout, d]).astype("float32")
                tract_output[i_pos, j_pos ,...] = -vin
                total_samples = total_samples + 1
      
            if total_samples == n_samples:
                done = True
                break
                
        
        inputs.append(tract_input)
        outputs.append(tract_output)
        
#         split_num = int(np.ceil(len(tract_inputs) / config['seq_len']))
#         inputs_sequences = np.array_split(tract_inputs, split_num, axis=0)[0:-1]
#         output_sequences = np.array_split(tract_outputs, split_num, axis=0)[0:-1]
        
        print("Finished {:3.0f}%".format(100*total_samples/n_samples), end="\r")
        
        if done:
            break
    
    print(f'Split data into sequences of {config["seq_len"]}')
          

#     os.makedirs(save_dir)

#     save_path = os.path.join(save_dir, "samples.npz")
    
#     print("Saving {}".format(save_path))
#     np.savez_compressed(save_path, inputs=inputs, outputs=outputs)
    
#     config["n_samples"] = int(n_samples)
#     config_path = os.path.join(save_dir, "config" + ".yml")
#     print("Saving {}".format(config_path))
#     with open(config_path, "w") as file:
#             yaml.dump(config, file, default_flow_style=False)
            
    return inputs, outputs

In [25]:
config = dict(
    subject="992774",
    sample_type="rnn",
    seq_len=3,
    dwi_path = "fod.nii.gz",
    trk_path = "resampled_fibers/CC_smooth=5_npts=auto.trk",
    reverse_samples = True,
    max_n_samples = np.inf,
)

In [26]:
inputs, outputs = create_samples(config)

Split data into sequences of 3


In [18]:
np.concatenate(inputs, axis=0).shape

(1583022, 3, 18)

In [19]:
np.concatenate(outputs, axis=0).shape

(1583022, 3, 3)

In [20]:
outputs

[array([[[ 6.00002892e-02,  2.09600300e-01, -9.75944519e-01],
         [-8.93968791e-02,  8.96645768e-04, -9.95995700e-01],
         [-7.28954226e-02, -1.28514752e-01, -9.89024878e-01]],
 
        [[ 9.60098878e-02, -2.41497517e-01, -9.65640247e-01],
         [ 4.82037097e-01, -3.44285280e-01, -8.05672288e-01],
         [ 7.45795071e-01, -3.53477180e-01, -5.64662397e-01]],
 
        [[ 8.18350792e-01, -3.12358648e-01, -4.82425183e-01],
         [ 7.71273375e-01, -2.44507104e-01, -5.87668002e-01],
         [ 6.11484289e-01, -1.42057583e-01, -7.78400064e-01]],
 
        [[ 4.86085534e-01, -4.89455797e-02, -8.72539520e-01],
         [ 4.32275623e-01,  2.39724759e-02, -9.01422799e-01],
         [ 4.54422951e-01,  8.05414692e-02, -8.87137473e-01]],
 
        [[ 5.53156316e-01,  1.24342039e-01, -8.23745787e-01],
         [ 6.74955189e-01,  1.49166316e-01, -7.22623646e-01],
         [ 7.68626630e-01,  1.49086058e-01, -6.22082412e-01]],
 
        [[ 8.36543977e-01,  1.25589639e-01, -5.33311725