In [8]:
# Reproducibility setup: seed every random source used by the notebook before
# the heavier scientific imports below are executed.
import os
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes -- confirm.
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)  # Python's built-in RNG
from numpy.random import seed
seed(42)  # NumPy's global RNG
from tensorflow.compat.v1 import set_random_seed
set_random_seed(42)  # TensorFlow (v1-compat API) graph-level seed

import nibabel as nib
import nipy as ni
import numpy as np
import datetime
import shutil
import yaml
import csv
import json
import warnings

from hashlib import md5

# Silences *all* warnings for the rest of the session (including deprecation
# warnings raised by the imported libraries).
warnings.filterwarnings('ignore')

In [9]:
def _update_hash_from_file(hasher, path, chunk_size=1 << 20):
    """Feed the bytes of ``path`` into ``hasher`` chunk by chunk, closing the file afterwards."""
    with open(path, "rb") as data_file:
        for chunk in iter(lambda: data_file.read(chunk_size), b""):
            hasher.update(chunk)


def create_samples(config):
    """Create (input, output) training samples from a tractogram and a DWI volume.

    For every interior point of every fiber, one sample is built:

    * input  = [tangent stored at the previous point, DWI vector at the point]
    * output = tangent stored at the point itself

    If ``config["reverse_samples"]`` is truthy, the same point additionally
    yields the sample seen when walking the fiber in the opposite direction
    (input ``[-vout, d]``, output ``-vin``).

    The result is stored in
    ``subjects/<subject>/samples/<md5>/{samples.npz, config.yml}``, where the
    MD5 digest is computed from the sample type, the raw bytes of both data
    files, ``reverse_samples`` and ``max_n_samples``. If ``samples.npz``
    already exists for that digest, it is loaded and returned instead of
    being recomputed.

    Parameters
    ----------
    config : dict
        Must contain the keys ``"subject"``, ``"sample_type"``,
        ``"dwi_path"``, ``"trk_path"`` (both relative to
        ``subjects/<subject>``), ``"reverse_samples"`` and
        ``"max_n_samples"`` (use ``np.inf`` for "no limit").
        The key ``"n_samples"`` is written into this dict (in place) with the
        number of samples that were produced / loaded.

    Returns
    -------
    (inputs, outputs) : tuple of lists
        ``inputs[k]`` is a float32 vector of length ``3 + n_dwi_channels``,
        ``outputs[k]`` is a float32 vector of length 3.

    Raises
    ------
    ValueError
        If not a single sample could be created (e.g. every point maps
        outside the DWI volume, or ``max_n_samples`` is 0).
    AssertionError
        If the tractogram carries no per-point ``"t"`` data, or the produced
        vectors do not have the expected shapes.
    """
    subject_dir = os.path.join("subjects", config["subject"])

    trk_path = os.path.join(subject_dir, config["trk_path"])
    dwi_path = os.path.join(subject_dir, config["dwi_path"])

    # Digest over everything that influences the result. Feeding the files in
    # chunks yields the same digest as hashing their full contents at once,
    # so existing cache directories stay valid.
    hasher = md5()
    hasher.update(config["sample_type"].encode())
    _update_hash_from_file(hasher, dwi_path)
    _update_hash_from_file(hasher, trk_path)
    hasher.update(str(config["reverse_samples"]).encode())
    hasher.update(str(config["max_n_samples"]).encode())

    save_dir = os.path.join(subject_dir, "samples", hasher.hexdigest())
    save_path = os.path.join(save_dir, "samples.npz")

    # Cache hit: return the stored samples so that callers can always unpack
    # ``inputs, outputs``. The check is on the sample file (not the directory)
    # so that a run which crashed before saving does not poison the cache.
    if os.path.exists(save_path):
        print("Samples with this config have been created already:\n{}".format(save_dir))
        with np.load(save_path) as cached:
            inputs = list(cached["inputs"])
            outputs = list(cached["outputs"])
        config["n_samples"] = len(inputs)
        return inputs, outputs

    tracts = nib.streamlines.load(trk_path).tractogram # coordinates in mmRAS+ space
    assert tracts.data_per_point is not None
    assert "t" in tracts.data_per_point

    dwi_img = ni.load_image(dwi_path)
    # Invert the coordinate map once instead of once per point.
    world_to_voxel = dwi_img.coordmap.inverse()
    xyz2ijk = lambda r: world_to_voxel([r[0], r[1], r[2], 0]).round().astype(int)
    dwi = dwi_img.get_data()
    spatial_shape = np.asarray(dwi.shape[:3])

    per_point = 2 if config["reverse_samples"] else 1
    max_n_samples = config["max_n_samples"]
    # Upper bound on what the data can deliver; reduced whenever a point has
    # to be skipped. Fibers with fewer than 3 points contribute nothing.
    n_available = per_point * sum(max(len(f) - 2, 0) for f in tracts)

    np.random.seed(42)
    perm = np.random.permutation(len(tracts))
    tracts = tracts[perm]

    inputs = []
    outputs = []
    done = False
    for fi, f in enumerate(tracts):
        tangents = f.data_for_points["t"]
        for i, r in enumerate(f.streamline[1:-1]): # Exclude end points
            idx = xyz2ijk(r) # anchor idx

            # Explicit bounds check: negative indices would not raise an
            # IndexError but silently wrap around to the opposite side of
            # the volume.
            if np.any(idx[:3] < 0) or np.any(idx[:3] >= spatial_shape):
                n_available -= per_point
                print("Index error at r={}, idx={}, fiber_idx={}\n".format(r,idx,perm[fi]) +
                      "Maybe wrong reference frame, or resampling failed."
                     )
                continue

            d = dwi[idx[0], idx[1], idx[2], :]

            # ``i`` counts from the second point, so ``tangents[i]`` belongs
            # to the previous point and ``tangents[i+1]`` to the anchor.
            vout = tangents[i+1].astype("float32")
            vin = tangents[i].astype("float32")

            outputs.append(vout)
            inputs.append(np.hstack([vin, d]).astype("float32"))

            if config["reverse_samples"]:
                inputs.append(np.hstack([-vout, d]).astype("float32"))
                outputs.append(-vin)

            # ">=": with reversed samples the count grows in steps of two and
            # could otherwise jump over an odd limit.
            if len(inputs) >= max_n_samples:
                done = True
                break

        target = min(n_available, max_n_samples)
        if target > 0:
            print("Finished {:3.0f}%".format(100*min(len(inputs), target)/target), end="\r")

        if done:
            break

    # Drop the surplus sample that an odd limit leaves behind when samples
    # are added in pairs (never true for an infinite limit).
    if len(inputs) > max_n_samples:
        limit = int(max_n_samples)
        del inputs[limit:]
        del outputs[limit:]

    n_samples = len(inputs)
    if n_samples == 0:
        raise ValueError(
            "No samples could be created from {} and {}.".format(trk_path, dwi_path))

    assert n_samples == len(outputs)
    assert inputs[0].shape == (3 + dwi_img.shape[-1], )
    assert outputs[0].shape == (3, )

    os.makedirs(save_dir, exist_ok=True)

    print("Saving {}".format(save_path))
    # Write to a temporary name first so that an interrupted save never
    # leaves a truncated samples.npz that would later count as a cache hit.
    partial_path = os.path.join(save_dir, "samples.partial.npz")
    np.savez_compressed(partial_path, inputs=inputs, outputs=outputs)
    os.replace(partial_path, save_path)

    config["n_samples"] = int(n_samples)
    config_path = os.path.join(save_dir, "config" + ".yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    return inputs, outputs

In [10]:
# Settings consumed by create_samples(); both paths are resolved relative to
# the directory "subjects/<subject>".
config = {
    "subject": "992774",
    "sample_type": "conditional_t",
    "dwi_path": "fod.nii.gz",
    "trk_path": "resampled_fibers/CC_smooth=5_npts=auto.trk",
    # Also emit every sample as seen when walking the fiber backwards.
    "reverse_samples": True,
    # np.inf: do not cap the number of samples.
    "max_n_samples": np.inf,
}

In [11]:
# Run sample generation for `config` and unpack the (inputs, outputs) pair
# that create_samples returns.
inputs, outputs = create_samples(config)

Saving subjects/992774/samples/fa7c02604b92de5f32cd3b61dbc2f8b7/samples.npz
Saving subjects/992774/samples/fa7c02604b92de5f32cd3b61dbc2f8b7/config.yml
