Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Easy tests reconst #982

Merged
merged 4 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions scilpy/reconst/fodf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,23 @@
cvx, have_cvxpy, _ = optional_package("cvxpy")


def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
fa_threshold, md_threshold, is_legacy=True):
def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis,
fa_threshold, md_threshold,
small_dims=False, is_legacy=True):
"""
Compute mean maximal fodf value in ventricules. Given heuristics thresholds
on FA and MD values, finds the voxels of the ventricules or CSF and
computes a mean fODF value. This is described in
Dell'Acqua et al. HBM 2013.

Ventricles are searched in a window in the middle of the data to increase
speed. No need to scan the whole image.

Parameters
----------
data: ndarray (x, y, z, ncoeffs)
Input fODF file in spherical harmonics coefficients.
Input fODF file in spherical harmonics coefficients. Uses sphere
'repulsion100' to convert to SF values.
fa: ndarray (x, y, z)
FA (Fractional Anisotropy) volume from DTI
md: ndarray (x, y, z)
Expand All @@ -38,15 +43,16 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
Either 'tournier07' or 'descoteaux07'
small_dims: bool
If set, takes the full range of data to search the max fodf amplitude
in ventricles. Useful when the data has small dimensions.
in ventricles, rather than a window center in the data. Useful when the
data has small dimensions.
fa_threshold: float
Maximal threshold of FA (voxels under that threshold are considered
for evaluation).
for evaluation). Suggested value: 0.1.
md_threshold: float
Minimal threshold of MD in mm2/s (voxels above that threshold are
considered for evaluation).
considered for evaluation). Suggested value: 0.003.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
Whether the SH basis is in its legacy form.

Returns
-------
Expand All @@ -57,26 +63,17 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
order = find_order_from_nb_coeff(data)
sphere = get_sphere('repulsion100')
b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=is_legacy)
sum_of_max = 0
count = 0

mask = np.zeros(data.shape[:-1])

if np.min(data.shape[:-1]) > 40:
step = 20
else:
if np.min(data.shape[:-1]) > 20:
step = 10
else:
step = 5

# 1000 works well at 2x2x2 = 8 mm3
# Hence, we multiply by the volume of a voxel
vol = (zoom[0] * zoom[1] * zoom[2])
if vol != 0:
max_number_of_voxels = 1000 * 8 // vol
else:
max_number_of_voxels = 1000
logging.debug("Searching for ventricle voxels, up to a maximum of {} "
"voxels.".format(max_number_of_voxels))

# In the case of 2D-like data (3D data with one dimension size of 1), or
# a small 3D dataset, the full range of data is scanned.
Expand All @@ -85,14 +82,27 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
all_j = list(range(0, data.shape[1]))
all_k = list(range(0, data.shape[2]))
# In the case of a normal 3D dataset, a window is created in the middle of
# the image to capture the ventricules. No need to scan the whole image.
# the image to capture the ventricles. No need to scan the whole image.
# (Automatic definition of window's radius based on the shape of the data.)
else:
all_i = list(range(int(data.shape[0]/2) - step,
int(data.shape[0]/2) + step))
all_j = list(range(int(data.shape[1]/2) - step,
int(data.shape[1]/2) + step))
all_k = list(range(int(data.shape[2]/2) - step,
int(data.shape[2]/2) + step))
if np.min(data.shape[:-1]) > 40:
radius = 20
else:
if np.min(data.shape[:-1]) > 20:
radius = 10
else:
radius = 5

all_i = list(range(int(data.shape[0]/2) - radius,
int(data.shape[0]/2) + radius))
all_j = list(range(int(data.shape[1]/2) - radius,
int(data.shape[1]/2) + radius))
all_k = list(range(int(data.shape[2]/2) - radius,
int(data.shape[2]/2) + radius))

# Ok. Now find ventricle voxels.
sum_of_max = 0
count = 0
for i in all_i:
for j in all_j:
for k in all_k:
Expand Down
52 changes: 34 additions & 18 deletions scilpy/reconst/frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
def compute_ssst_frf(data, bvals, bvecs, b0_threshold=DEFAULT_B0_THRESHOLD,
mask=None, mask_wm=None, fa_thresh=0.7, min_fa_thresh=0.5,
min_nvox=300, roi_radii=10, roi_center=None):
"""Compute a single-shell (under b=1500), single-tissue single Fiber
Response Function from a DWI volume.
A DTI fit is made, and voxels containing a single fiber population are
found using a threshold on the FA.
"""
Computes a single-shell (under b=1500), single-tissue single Fiber
Response Function from a DWI volume. A DTI fit is made, and voxels
containing a single fiber population are found using either a threshold on
the FA, inside a white matter mask.

Parameters
----------
Expand All @@ -43,7 +44,7 @@ def compute_ssst_frf(data, bvals, bvecs, b0_threshold=DEFAULT_B0_THRESHOLD,
3D mask with shape (X,Y,Z)
Binary white matter mask. Only the data inside this mask and above the
threshold defined by fa_thresh will be used to estimate the fiber
response function.
response function. If not given, all voxels inside `mask` will be used.
fa_thresh : float, optional
Use this threshold as the initial threshold to select single fiber
voxels. Defaults to 0.7
Expand All @@ -63,7 +64,7 @@ def compute_ssst_frf(data, bvals, bvecs, b0_threshold=DEFAULT_B0_THRESHOLD,

Returns
-------
full_reponse : ndarray
full_response : ndarray
Fiber Response Function, with shape (4,)

Raises
Expand Down Expand Up @@ -139,10 +140,11 @@ def compute_msmt_frf(data, bvals, bvecs, btens=None, data_dti=None,
fa_thr_wm=0.7, fa_thr_gm=0.2, fa_thr_csf=0.1,
md_thr_gm=0.0007, md_thr_csf=0.003, min_nvox=300,
roi_radii=10, roi_center=None, tol=20):
"""Compute a multi-shell, multi-tissue single Fiber
Response Function from a DWI volume.
A DTI fit is made, and voxels containing a single fiber population are
found using a threshold on the FA and MD.
"""
Computes a multi-shell, multi-tissue single Fiber Response Function from a
DWI volume. A DTI fit is made, and voxels containing a single fiber
population are found using a threshold on the FA and MD, inside a mask of
each tissue type.

Parameters
----------
Expand Down Expand Up @@ -304,33 +306,47 @@ def compute_msmt_frf(data, bvals, bvecs, btens=None, data_dti=None,
return responses, frf_masks


def replace_frf(old_frf, new_frf, no_factor):
def replace_frf(old_frf, new_frf, no_factor=False):
"""
Replace old_frf with new_frf
Replaces the 3 first values of old_frf with new_frf. Formats the new_frf
from a string value and verifies that the number of shells corresponds.

Parameters
----------
old_frf: np.ndarray
A loaded frf file, of shape (n, 4).
new_frf: tuple
The new frf, to be interpreted with a 10**-4 factor. Ex: (15,4,4)
A loaded frf file, of shape (N, 4), where N is the number of shells.
new_frf: str
The new frf, to be interpreted with a 10**-4 factor. Ex: 15,4,4. With
multishell: all values, concatenated into one string.
Ex: 15,4,4,13,5,5,12,5,5.
no_factor: bool
If true, the fiber response function is evaluated without the
10**-4 factor.

Returns
-------
response: np.ndarray
Formatted new frf, of shape (n, 4)
"""
old_frf = old_frf.T
new_frf = np.array(literal_eval(new_frf), dtype=np.float64)
if len(old_frf.shape) == 1: # When loading from one shell, we get (4, )
old_frf = old_frf[None, :]
old_nb_shells = old_frf.shape[0]
b0_mean = old_frf[:, 3]

new_frf = np.array(literal_eval(new_frf), dtype=np.float64)
if not no_factor:
new_frf *= 10 ** -4
b0_mean = old_frf[3]

if new_frf.shape[0] % 3 != 0:
raise ValueError('Inputed new frf is not valid. There should be '
'three values per shell, and thus the total number '
'of values should be a multiple of three.')

nb_shells = int(new_frf.shape[0] / 3)
if nb_shells != old_nb_shells:
raise ValueError("The old frf contained {} shell(s). Cannot replace "
"with {} shell(s).".format(old_nb_shells, nb_shells))

new_frf = new_frf.reshape((nb_shells, 3))

response = np.empty((nb_shells, 4))
Expand Down
36 changes: 34 additions & 2 deletions scilpy/reconst/tests/test_fodf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
# -*- coding: utf-8 -*-
import numpy as np
from dipy.data import get_sphere
from dipy.reconst.shm import sh_to_sf_matrix

from scilpy.reconst.fodf import get_ventricles_max_fodf
from scilpy.reconst.utils import find_order_from_nb_coeff
from scilpy.tests.arrays import fodf_3x3_order8_descoteaux07


def test_get_ventricles_max_fodf():
# toDO
pass
fake_fa = np.ones((3, 3, 1)) # High FA
fake_fa[1:3, 0:2, 0] = 0 # Low in ventricles
fake_md = np.zeros((3, 3, 1)) # Low MD
fake_md[0:2, 0:2, 0] = 1 # High in ventricles
zoom = [1, 1, 1]
fa_threshold = 0.5
md_threshold = 0.5
sh_basis = 'descoteaux07'

# Should find that the only 2 ventricle voxels are at [1, 0:2, 0]
mean, mask = get_ventricles_max_fodf(
fodf_3x3_order8_descoteaux07, fake_fa, fake_md, zoom, sh_basis,
fa_threshold, md_threshold, small_dims=True)

expected_mask = np.logical_and(~fake_fa.astype(bool), fake_md)
assert np.count_nonzero(mask) == 2
assert np.array_equal(mask.astype(bool), expected_mask)

# Reconstruct SF values same as in method.
order = find_order_from_nb_coeff(fodf_3x3_order8_descoteaux07)
sphere = get_sphere('repulsion100')
b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=True)

sf1 = np.dot(fodf_3x3_order8_descoteaux07[1, 0, 0], b_matrix)
sf2 = np.dot(fodf_3x3_order8_descoteaux07[1, 1, 0], b_matrix)

assert mean == np.mean([np.max(sf1), np.max(sf2)])


def test_fit_from_model():
Expand Down
75 changes: 69 additions & 6 deletions scilpy/reconst/tests/test_frf.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,79 @@
# -*- coding: utf-8 -*-
import os
import tempfile

import nibabel as nib
import numpy as np
from dipy.io import read_bvals_bvecs

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.reconst.frf import compute_ssst_frf, compute_msmt_frf, replace_frf

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['processing.zip'])
tmp_dir = tempfile.TemporaryDirectory()
in_dwi = os.path.join(SCILPY_HOME, 'processing', 'dwi_crop.nii.gz')
in_bval = os.path.join(SCILPY_HOME, 'processing', 'dwi.bval')
in_bvec = os.path.join(SCILPY_HOME, 'processing', 'dwi.bvec')


def test_compute_ssst_frf():
# toDO
pass
# Uses data from our test data.
# To use a smaller subset, we need to ensure that it has at least one
# voxel with FA higher than 0.7. Quite fast as is, so, ok.
dwi = nib.load(in_dwi).get_fdata() # Shape: 57, 67, 56, 64
bvals, bvecs = read_bvals_bvecs(in_bval, in_bvec)

result = compute_ssst_frf(dwi, bvals, bvecs)

# Value with current data at the date of test creation:
expected_result = [1.03068237e-03, 2.44994949e-04,
2.44994949e-04, 3.26903486e+03]
assert np.allclose(result, expected_result)


def test_compute_msmt_frf():
# toDO
pass
# Uses data from our test data.
# To use a smaller subset, we need to ensure that it has at least one
# voxel with each tissue type.
dwi = nib.load(in_dwi).get_fdata() # Shape: 57, 67, 56, 64
bvals, bvecs = read_bvals_bvecs(in_bval, in_bvec)

responses, masks = compute_msmt_frf(dwi, bvals, bvecs)

# Value with current data at the date of test creation:
expected_result_wm = [[1.56925332e-03, 4.68706503e-04,
4.68706503e-04, 3.26903486e+03],
[1.15181122e-03, 3.75303294e-04,
3.75303294e-04, 3.26903486e+03],
[8.61299793e-04, 3.14541494e-04,
3.14541494e-04, 3.26903486e+03]]
expected_result_gm = [[9.74471606e-04, 8.34628732e-04,
8.34628732e-04, 3.42007686e+03],
[7.76991313e-04, 6.89550835e-04,
6.89550835e-04, 3.42007686e+03],
[6.26617550e-04, 5.73389066e-04,
5.73389066e-04, 3.42007686e+03]]
expected_result_csf = [[9.33140592e-04, 8.31445917e-04,
8.31445917e-04, 3.62805637e+03],
[7.69894406e-04, 7.07255607e-04,
7.07255607e-04, 3.62805637e+03],
[6.34735398e-04, 5.96451860e-04,
5.96451860e-04, 3.62805637e+03]]
assert np.allclose(responses[0], expected_result_wm)
assert np.allclose(responses[1], expected_result_gm)
assert np.allclose(responses[2], expected_result_csf)

assert np.count_nonzero(masks[0]) == 845 # wm
assert np.count_nonzero(masks[1]) == 1779 # gm
assert np.count_nonzero(masks[2]) == 449 # csf


def test_replace_frf():
# toDo
pass
old_frf = np.random.rand(4)
new_frf = "15,4,4"
result = replace_frf(old_frf, new_frf, no_factor=True)

# Rounds to float64
assert np.allclose(result, [15, 4, 4, old_frf[-1]])
4 changes: 2 additions & 2 deletions scripts/scil_fodf_max_in_ventricles.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def main():
sh_basis, is_legacy = parse_sh_basis_arg(args)

value, mask = get_ventricles_max_fodf(fodf, fa, md, zoom, sh_basis,
args.small_dims, args.fa_threshold,
args.md_threshold,
args.fa_threshold, args.md_threshold,
small_dims=args.small_dims,
is_legacy=is_legacy)

if args.mask_output:
Expand Down
9 changes: 3 additions & 6 deletions scripts/tests/test_frf_set_diffusivities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,23 @@ def test_help_option(script_runner):

def test_execution_processing_ssst(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(SCILPY_HOME, 'processing',
'frf.txt')
in_frf = os.path.join(SCILPY_HOME, 'processing', 'frf.txt')
ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
'15,4,4', 'new_frf.txt', '-f')
assert ret.success


def test_execution_processing_msmt(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(SCILPY_HOME, 'commit_amico',
'wm_frf.txt')
in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
'15,4,4,13,4,4,12,5,5', 'new_frf.txt', '-f')
assert ret.success


def test_execution_processing__wrong_input(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(SCILPY_HOME, 'commit_amico',
'wm_frf.txt')
in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
'15,4,4,13,4,4', 'new_frf.txt', '-f')
assert not ret.success
Loading