In [1]:
from os import path
import pickle

import nibabel as nib
from nibabel.streamlines import save as save_trk
from nibabel.streamlines import Tractogram
import numpy as np

from dipy.core.gradients import gradient_table
from dipy.data import (
    fetch_stanford_hardi, read_stanford_hardi, get_sphere)
from dipy.direction import peaks_from_model
from dipy.io import read_bvals_bvecs
from dipy.io.image import save_nifti
from dipy.reconst.dti import TensorModel
from dipy.reconst.csdeconv import (
    ConstrainedSphericalDeconvModel,auto_response)
from dipy.segment.mask import median_otsu
from dipy.tracking.local import (
    LocalTracking, ThresholdTissueClassifier)
from dipy.tracking.streamline import Streamlines
from dipy.tracking.utils import random_seeds_from_mask
from dipy.viz import actor, window

## Load Diffusion Data

In [2]:
# Load the diffusion-weighted volume and its gradient scheme.
dwi_img_filename = "dwi.nii.gz"
dwi_img = nib.load(dwi_img_filename)
# `get_data()` is deprecated (removed in nibabel 5.0); np.asanyarray on the
# data proxy is nibabel's documented drop-in that keeps the on-disk dtype.
dwi = np.asanyarray(dwi_img.dataobj)
print("Data Shape: " + str(dwi.shape))

fbval = "dwi.bval"
fbvec = "dwi.bvec"
bvals, bvecs = read_bvals_bvecs(fbval, fbvec)
gtab = gradient_table(bvals, bvecs)
print("B-Values: \n" + str(gtab.bvals))

# When False, the cells below reuse their cached output files if present.
# Any cell that has to recompute flips this to True, which forces every
# downstream cell to recompute as well.
recompute = False

Data Shape: (140, 140, 96, 69)
B-Values: 
[    0.     0.     0.     0.     0.  1000.  1000.  1000.  1000.  1000.
  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.
  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.
  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.
  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.
  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.
  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.  1000.]


## Compute Brain Mask

In [3]:
# Brain mask: reuse the cached NIfTI when available, otherwise derive it with
# median_otsu(dwi, 4, 1) and cache it as uint8.
# NOTE(review): the positional 4, 1 are presumably median_radius=4 and
# numpass=1 - confirm against the installed dipy signature.
brain_mask_filename = "brain_mask.nii.gz"
if path.exists(brain_mask_filename) and not recompute:
    # The file is stored as uint8; cast to bool so the cached mask matches
    # the dtype of a freshly computed one.
    brain_mask_img = np.asanyarray(
        nib.load(brain_mask_filename).dataobj).astype(bool)
else:
    # Recomputing here invalidates every downstream cache.
    recompute = True
    _, brain_mask_img = median_otsu(dwi, 4, 1)
    save_nifti(brain_mask_filename, brain_mask_img.astype("uint8"),
               dwi_img.affine)

## Compute DTI

In [4]:
# Fractional anisotropy from a tensor fit ('WLS' fit method) restricted to the
# brain mask; cached to fa.nii.gz.
fa_filename = "fa.nii.gz"
if path.exists(fa_filename) and not recompute:
    # np.asanyarray(img.dataobj) replaces the deprecated/removed get_data().
    fa = np.asanyarray(nib.load(fa_filename).dataobj)
else:
    # Recomputing here invalidates every downstream cache.
    recompute = True
    tensor_model = TensorModel(gtab, fit_method='WLS')
    tensor_fit = tensor_model.fit(dwi, brain_mask_img)
    fa = tensor_fit.fa
    save_nifti(fa_filename, fa, dwi_img.affine)

## Compute CSD

In [5]:
# CSD peaks: reuse the pickled result when available, otherwise fit a
# constrained spherical deconvolution model and extract peak directions.
csd_peaks_filename = "csd_peaks.pkl"
if path.exists(csd_peaks_filename) and not recompute:
    # NOTE: pickle.load can execute arbitrary code - only load a cache file
    # this notebook wrote itself, never one obtained from elsewhere.
    with open(csd_peaks_filename, "rb") as pkl_file:
        csd_peaks = pickle.load(pkl_file)
else:
    # Recomputing here invalidates every downstream cache.
    recompute = True
    response, ratio = auto_response(
        gtab, dwi, roi_radius=10, fa_thr=0.7)
    csd_model = ConstrainedSphericalDeconvModel(gtab, response)
    sphere = get_sphere('symmetric724')
    csd_peaks = peaks_from_model(
        model=csd_model, data=dwi, sphere=sphere, mask=brain_mask_img,
        relative_peak_threshold=.5, min_separation_angle=25,
        parallel=True)
    # The context manager guarantees the pickle is flushed and the file is
    # closed, so the next run never reads a truncated cache.
    with open(csd_peaks_filename, "wb") as pkl_file:
        pickle.dump(csd_peaks, pkl_file)

## CSD Visualization

In [6]:
# True opens an interactive window; False renders straight to a PNG instead
# (use False for unattended "Run All" execution).
interactive = True

ren = window.Renderer()
peaks_actor = actor.peak_slicer(
    csd_peaks.peak_dirs, csd_peaks.peak_values, colors=None)
ren.add(peaks_actor)

if not interactive:
    window.record(ren, out_path='csd_direction_field.png', size=(900, 900))
else:
    window.show(ren, size=(900, 900))

  orient = np.abs(orient / np.linalg.norm(orient))


## Compute Streamlines

In [7]:
# Deterministic CSD tractography, cached to a .trk file.
tractogram_filename = "det_csd_streamlines.trk"
if path.exists(tractogram_filename) and not recompute:
    # The saved tractogram was declared with affine_to_rasmm=dwi_img.affine,
    # so nibabel loads it back in RAS+ world coordinates (mm). Map it back to
    # voxel space so it matches freshly tracked streamlines (tracked with
    # affine=np.eye(4)) and the voxel-space peak visualization.
    loaded_tractogram = nib.streamlines.load(tractogram_filename).tractogram
    loaded_tractogram.apply_affine(np.linalg.inv(dwi_img.affine))
    streamlines = loaded_tractogram.streamlines
else:
    # Stop tracking where FA drops below 0.1.
    tissue_classifier = ThresholdTissueClassifier(fa, 0.1)
    # One seed per voxel with FA > 0.5. Seeds are presumably placed at random
    # within each voxel; no RNG seed is fixed here, so reruns may differ.
    seeds = random_seeds_from_mask(fa > 0.5, seeds_count=1)
    streamline_generator = LocalTracking(
        csd_peaks, tissue_classifier, seeds, affine=np.eye(4),
        step_size=0.5)

    streamlines = Streamlines(streamline_generator)
    # Streamlines are in voxel coordinates; dwi_img.affine maps them to RAS+ mm.
    save_trk(
        Tractogram(streamlines, affine_to_rasmm=dwi_img.affine),
        tractogram_filename)
print("Number of streamlines: " + str(len(streamlines)))

Number of streamlines: 88155


## Streamlines Visualization

In [8]:
# Reuse the renderer from the CSD visualization: drop the peak glyphs and
# show the tracked streamlines instead.
ren.clear()
streamlines_actor = actor.line(streamlines)
ren.add(streamlines_actor)

if not interactive:
    print('Saving illustration as det_streamlines.png')
    window.record(ren, out_path='det_streamlines.png', size=(900, 900))
else:
    window.show(ren, size=(900, 900))

  orient = np.abs(orient / np.linalg.norm(orient))
