-
Notifications
You must be signed in to change notification settings - Fork 34
/
dti.py
142 lines (113 loc) · 4.41 KB
/
dti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import os.path as op
import numpy as np
import nibabel as nib
from dipy.core.geometry import vector_norm
from dipy.reconst import dti
import AFQ.utils.models as ut
__all__ = ["fit_dti", "predict", "tensor_odf"]
def _fit(gtab, data, mask=None):
    # Thin wrapper: build a DTI tensor model for this gradient table and
    # fit it to the data, optionally restricted to a binary mask.
    model = dti.TensorModel(gtab)
    fitted = model.fit(data, mask=mask)
    return fitted
def fit_dti(data_files, bval_files, bvec_files, mask=None,
            out_dir=None, file_prefix=None, b0_threshold=0):
    """
    Fit the DTI model using default settings, save files with derived maps.

    Parameters
    ----------
    data_files : str or list
        Files containing DWI data. If this is a str, that's the full path to a
        single file. If it's a list, each entry is a full path.
    bval_files : str or list
        Equivalent to `data_files`.
    bvec_files : str or list
        Equivalent to `data_files`.
    mask : ndarray, optional
        Binary mask, set to True or 1 in voxels to be processed.
        Default: Process all voxels.
    out_dir : str, optional
        A full path to a directory to store the maps that get computed.
        Default: maps get stored in the same directory as the last DWI file
        in `data_files`.
    file_prefix : str, optional
        String prepended to each output file-name. Default: no prefix.
    b0_threshold : float, optional
        B-values at or below this threshold are considered b0 volumes
        (passed through to data preparation). Default: 0.

    Returns
    -------
    file_paths : dict
        A dict with the derived maps that were computed and full-paths
        to the files containing these maps.

    Note
    ----
    Maps that are calculated: FA, MD, AD, RD
    """
    img, data, gtab, mask = ut.prepare_data(data_files, bval_files,
                                            bvec_files, mask=mask,
                                            b0_threshold=b0_threshold)
    # Fit restricted to the prepared mask. (Previously `mask=None` was
    # passed here, silently fitting every voxel and ignoring the
    # user-supplied mask documented above.)
    dtf = _fit(gtab, data, mask=mask)
    maps = [dtf.fa, dtf.md, dtf.ad, dtf.rd, dtf.model_params]
    names = ['FA', 'MD', 'AD', 'RD', 'params']

    if out_dir is None:
        # Default: a 'dti' subdirectory next to the (first) input file.
        if isinstance(data_files, list):
            out_dir = op.join(op.split(data_files[0])[0], 'dti')
        else:
            out_dir = op.join(op.split(data_files)[0], 'dti')

    if file_prefix is None:
        file_prefix = ''

    # exist_ok avoids the check-then-create race of op.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)

    aff = img.affine
    file_paths = {}
    for m, n in zip(maps, names):
        file_paths[n] = op.join(out_dir, file_prefix + 'dti_%s.nii.gz' % n)
        nib.save(nib.Nifti1Image(m, aff), file_paths[n])
    return file_paths
def predict(params_file, gtab, S0_file=None, out_dir=None):
    """
    Create a signal prediction from DTI params.

    Parameters
    ----------
    params_file : str
        Full path to a file with parameters saved from a DTI fit.
    gtab : GradientTable object
        The gradient table to predict for.
    S0_file : str, optional
        Full path to a nifti file that contains S0 measurements to incorporate
        into the prediction. If the file contains 4D data, the volumes that
        contain the S0 data must be the same as the gtab.b0s_mask.
        Default: use a constant S0 of 100 everywhere.
    out_dir : str, optional
        Directory in which to save the prediction.
        Default: the directory containing `params_file`.

    Returns
    -------
    fname : str
        Full path to the saved prediction file.
    """
    if out_dir is None:
        out_dir = op.join(op.split(params_file)[0])

    if S0_file is None:
        # No measured S0: use an arbitrary constant baseline signal.
        S0 = 100
    else:
        # get_fdata() replaces get_data(), which was deprecated and then
        # removed in nibabel 4.0.
        S0 = nib.load(S0_file).get_fdata()
        # If the S0 data is 4D, we assume it comes from an acquisition that
        # had B0 measurements in the same volumes described in the gtab:
        if len(S0.shape) == 4:
            S0 = np.mean(S0[..., gtab.b0s_mask], -1)
        # Otherwise, we assume that it's already a 3D volume, and do nothing

    img = nib.load(params_file)
    params = img.get_fdata()
    pred = dti.tensor_prediction(params, gtab, S0=S0)
    fname = op.join(out_dir, 'dti_prediction.nii.gz')
    nib.save(nib.Nifti1Image(pred, img.affine), fname)
    return fname
def tensor_odf(evals, evecs, sphere):
    """
    Calculate the tensor Orientation Distribution Function.

    Parameters
    ----------
    evals : array (4D)
        Eigenvalues of a tensor. Shape (x, y, z, 3).
    evecs : array (5D)
        Eigenvectors of a tensor. Shape (x, y, z, 3, 3).
    sphere : sphere object
        The ODF will be calculated in each vertex of this sphere.

    Returns
    -------
    odf : array
        Shape (x, y, z, n_vertices). Zero wherever any eigenvalue is
        non-positive.
    """
    n_vertices = sphere.vertices.shape[0]
    odf = np.zeros(evals.shape[:3] + (n_vertices,))

    # Evaluate only where the tensor is positive-definite (all three
    # eigenvalues strictly positive):
    valid = np.where((evals[..., 0] > 0)
                     & (evals[..., 1] > 0)
                     & (evals[..., 2] > 0))

    # Per-voxel normalization constant: 4 * pi * sqrt(l1 * l2 * l3).
    norm_const = 4 * np.pi * np.sqrt(np.prod(evals[valid], -1))

    # Project every sphere direction onto each voxel's eigenvector frame,
    # then scale by the square roots of the eigenvalues.
    proj = np.dot(sphere.vertices, evecs[valid])
    proj /= np.sqrt(evals[valid])

    odf[valid] = ((vector_norm(proj) ** -3) / norm_const).T
    return odf