In [2]:
# Reproducibility set-up: fix the seed of each random number generator.
import os
# Ask for deterministic hashing via the PYTHONHASHSEED environment variable.
# NOTE(review): this variable is normally read at interpreter start-up, so setting
# it here presumably only affects child processes - confirm.
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)  # seed of Python's built-in `random` module
from numpy.random import seed
seed(42)  # seed of NumPy's global random state
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)  # TensorFlow seed, set through the v1 compatibility API

import warnings
# Suppress every warning from here on (this also hides deprecation notices).
warnings.filterwarnings('ignore')

import nibabel as nib
import nipy as ni
import numpy as np
import datetime
import shutil
import yaml
import csv
import json

from hashlib import md5
from scipy.interpolate import RegularGridInterpolator

In [59]:
def _interpolate_block(dwi, idx, block_size):
    """Linearly interpolate a block of voxels around a continuous voxel position.

    For each of the 3 x 3 x 3 grid nodes surrounding the voxel nearest to ``idx``
    a ``block_size``**3 block of ``dwi`` centred on that node is collected; the
    blocks are then interpolated at the sub-voxel offset of ``idx``.

    Parameters
    ----------
    dwi : array-like with 4 axes (three spatial axes, then the channel axis)
    idx : array of 3 floats, position in voxel coordinates
    block_size : odd int, edge length of the block in voxels

    Returns
    -------
    1-D array of length ``block_size**3 * dwi.shape[-1]``. To get back the
    spatial order use ``reshape(bs, bs, bs, dwi.shape[-1])``.

    Raises
    ------
    IndexError
        If the required neighbourhood is not completely inside the volume.
    """
    center = np.round(idx).astype(int)  # anchor voxel
    half = block_size // 2
    lo = center - half - 1  # first voxel needed along each axis
    hi = center + half + 2  # one past the last voxel needed along each axis
    # NumPy slices never raise on out-of-range bounds (they truncate or wrap),
    # so the neighbourhood has to be validated explicitly.
    if np.any(lo < 0) or np.any(hi > np.asarray(dwi.shape[:3])):
        raise IndexError("block around voxel {} exceeds the volume".format(center))

    values = np.zeros([3, 3, 3,
                       block_size, block_size, block_size,
                       dwi.shape[-1]])
    # One block per interpolation node; node k along an axis sits at offset k - 1.
    for x in range(3):
        for y in range(3):
            for z in range(3):
                x0, y0, z0 = lo[0] + x, lo[1] + y, lo[2] + z
                values[x, y, z] = dwi[x0:x0 + block_size,
                                      y0:y0 + block_size,
                                      z0:z0 + block_size,
                                      :]
    fn = RegularGridInterpolator(([-1, 0, 1], [-1, 0, 1], [-1, 0, 1]), values)
    return fn(np.asarray(idx) - center)[0].flatten()


def create_samples(config):
    """Create training samples from a tractogram and a 4-D image and save them.

    Every interior point of every fiber (end points are excluded) yields one
    sample: the input is the per-point vector "t" of the previous point stacked
    with the interpolated image block at the current point, the output is "t"
    of the current point. With ``reverse_samples`` the mirrored sample
    (``-t_current`` + block -> ``-t_previous``) is added as well.

    Parameters
    ----------
    config : dict with the keys
        subject          sub-directory of "subjects" holding the data
        trk_path         tractogram file, relative to the subject directory; must
                         provide per-point data named "t"
        dwi_path         4-D image file, relative to the subject directory
        block_size       odd edge length (in voxels) of the image block
        reverse_samples  whether mirrored samples are added
        max_n_samples    upper bound on the number of samples (must be even if
                         reverse_samples is True)
        All values (in insertion order) determine the name of the output
        directory. On success the key "n_samples" is added to ``config``; it is
        ignored when the directory name is computed, so calling the function
        again with the same dict finds the existing directory.

    Returns
    -------
    (inputs, outputs) : two lists of float32 arrays, or ``(None, None)`` if the
        output directory for this configuration already exists.

    Raises
    ------
    ValueError
        For an even block_size, an odd max_n_samples combined with
        reverse_samples, or when no sample can be created at all.
    """
    block_size = config["block_size"]
    reverse = config["reverse_samples"]

    if block_size % 2 == 0:
        raise ValueError("block_size must be an odd number (1, 3, 5,...)")

    if reverse and config["max_n_samples"] % 2 != 0:
        raise ValueError("max_n_samples can not be an odd number for reverse_samples == True.")

    # Must be known before save_dir can be built.
    subject_dir = os.path.join("subjects", config["subject"])

    hasher = md5()
    for key, value in config.items():
        if key == "n_samples":
            # Written by a previous call; not part of the requested configuration.
            continue
        hasher.update(str(value).encode())

    save_dir = os.path.join(subject_dir, "samples", hasher.hexdigest())
    if os.path.exists(save_dir):
        print("Samples with this config have been created already:\n{}".format(save_dir))
        return None, None

    trk_path = os.path.join(subject_dir, config["trk_path"])
    trk_file = nib.streamlines.load(trk_path)
    assert trk_file.tractogram.data_per_point is not None, "tractogram has no per-point data"
    assert "t" in trk_file.tractogram.data_per_point, "tractogram has no per-point data 't'"

    #=================================================

    dwi_path = os.path.join(subject_dir, config["dwi_path"])
    dwi_img = nib.load(dwi_path)
    dwi_img = nib.funcs.as_closest_canonical(dwi_img)
    dwi_affi = np.linalg.inv(dwi_img.affine)  # maps xyz coordinates to voxel indices
    # Replacement for the deprecated get_data() that keeps its return type.
    dwi = np.asanyarray(dwi_img.dataobj)

    tracts = trk_file.tractogram # fiber coordinates in rasmm

    per_point = 2 if reverse else 1
    # Fibers with fewer than 3 points have no interior point and contribute nothing.
    n_available = per_point * int(sum(max(len(f) - 2, 0) for f in tracts))
    n_target = min(n_available, config["max_n_samples"])
    if n_target <= 0:
        raise ValueError("No samples can be created from {}".format(trk_path))

    # Fixed seed so that the fiber order (and thus the sample subset) is reproducible.
    np.random.seed(42)
    perm = np.random.permutation(len(tracts))
    tracts = tracts[perm]

    #=================================================

    inputs = []
    outputs = []
    for fi, f in enumerate(tracts):
        t = f.data_for_points["t"]
        for i, r in enumerate(f.streamline[1:-1]): # Exclude end points for conditional model
            idx = dwi_affi.dot([r[0], r[1], r[2], 1])[:3] # anchor idx
            try:
                d = _interpolate_block(dwi, idx, block_size)
            except IndexError:
                # Skip this point; later points fill up the quota if available.
                print("Index error at r={}, idx={}, fiber_idx={}\n".format(r,idx,perm[fi]) +
                      "Maybe wrong reference frame, or resampling failed."
                     )
                continue

            # i enumerates streamline[1:-1], so the current point has index i + 1.
            vout = t[i+1].astype("float32")
            vin = t[i].astype("float32")

            outputs.append(vout)
            inputs.append(np.hstack([vin, d]).astype("float32"))

            if reverse:
                inputs.append(np.hstack([-vout, d]).astype("float32"))
                outputs.append(-vin)

            if len(inputs) >= n_target:
                break

        print("Finished {:3.0f}%".format(100*len(inputs)/n_target), end="\r")

        if len(inputs) >= n_target:
            break

    n_samples = len(inputs)
    if n_samples == 0:
        raise ValueError("No valid samples found, all points were outside of the image.")
    assert n_samples == len(outputs)
    assert inputs[0].shape == (3 + dwi_img.shape[-1] * block_size**3, )
    assert outputs[0].shape == (3, )

    os.makedirs(save_dir)

    save_path = os.path.join(save_dir, "samples.npz")

    print("\nSaving {}".format(save_path))
    np.savez_compressed(save_path, inputs=inputs, outputs=outputs)

    config["n_samples"] = int(n_samples)
    config_path = os.path.join(save_dir, "config" + ".yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    return inputs, outputs

In [64]:
# Settings for one sample-generation run. The values, in this order, determine
# the name of the output directory chosen by create_samples.
# NOTE(review): max_n_samples is odd while reverse_samples is True; create_samples
# rejects that combination with a ValueError - pick an even number to run it.
config = {
    "subject": "917255",
    "sample_type": "conditional_t",
    "dwi_path": "fod_norm.nii.gz",
    "trk_path": "resampled_fibers/merged_w1_smooth=5_npts=auto.trk",
    "block_size": 1,
    "reverse_samples": True,
    "max_n_samples": 511,
}

In [65]:
# Generate the samples for `config`. Both results are None when the output
# directory for this configuration already exists.
inputs, outputs = create_samples(config)

ValueError: max_n_samples can not be an odd number for reverse_samples == True.