Skip to content

Commit

Permalink
Added docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
frheault committed May 16, 2024
1 parent 11f218b commit e771d98
Show file tree
Hide file tree
Showing 4 changed files with 524 additions and 4 deletions.
364 changes: 361 additions & 3 deletions scilpy/tractograms/tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import logging
import random

from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.io.stateful_tractogram import set_sft_logger_level, \
StatefulTractogram, Space
from dipy.io.utils import get_reference_info, is_header_compatible
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.streamline import transform_streamlines
Expand All @@ -24,9 +25,14 @@
from scipy.ndimage import map_coordinates
from scipy.spatial import cKDTree

from scilpy.tractanalysis.bundle_operations import uniformize_bundle_sft
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
from scilpy.tractograms.streamline_operations import smooth_line_gaussian, \
resample_streamlines_step_size, parallel_transport_streamline, \
cut_invalid_streamlines
cut_invalid_streamlines, compress_sft, cut_invalid_streamlines, \
remove_overlapping_points_streamlines, remove_single_point_streamlines
from scilpy.tractograms.streamline_and_mask_operations import cut_outside_of_mask_streamlines
from scilpy.utils.spatial import generate_rotation_matrix

MIN_NB_POINTS = 10
KEY_INDEX = np.concatenate((range(5), range(-1, -6, -1)))
Expand Down Expand Up @@ -693,7 +699,8 @@ def upsample_tractogram(sft, nb, point_wise_std=None, tube_radius=None,
polynomial = Polynomial(poly_coeffs[::-1])
noise_factor = polynomial(x)

vec = s - new_s
tmp_s = s - new_s
vec = tmp_s if not np.any(tmp_s == 0) else np.ones(tmp_s.shape)
vec /= np.linalg.norm(vec, axis=0)
new_s += vec * np.expand_dims(noise_factor, axis=1)

Expand Down Expand Up @@ -879,6 +886,357 @@ def split_sft_randomly_per_cluster(orig_sft, chunk_sizes, seed, thresholds):
return final_sfts


def subsample_streamlines_alter(sft, min_dice=0.90, epsilon=0.01,
baseline_sft=None):
"""
Function to subsample streamlines based on a dice similarity metric.
The function will keep removing streamlines until the dice similarity
between the original and the subsampled tractogram is close to min_dice.
Parameters
----------
sft: StatefulTractogram
The tractogram to subsample.
min_dice: float
The minimum dice similarity to reach before stopping the subsampling.
epsilon: float
Stopping criteria for convergence. The maximum difference between the
dice similarity and min_dice.
baseline_sft: StatefulTractogram
The tractogram to use as a reference for the dice similarity. If None,
the original tractogram will be used.
Returns
-------
new_sft: StatefulTractogram
The tractogram with a subset of streamlines in the same space as the
input tractogram.
"""
# Import in function to avoid circular import error
from scilpy.tractanalysis.reproducibility_measures import compute_dice_voxel
set_sft_logger_level(logging.ERROR)
space = sft.space
origin = sft.origin

sft.to_vox()
sft.to_corner()
if baseline_sft is None:
original_density_map = compute_tract_counts_map(sft.streamlines,
sft.dimensions)
else:
baseline_sft.to_vox()
baseline_sft.to_corner()
original_density_map = compute_tract_counts_map(baseline_sft.streamlines,
sft.dimensions)
dice = 1.0
init_pick_min = 0
init_pick_max = len(sft)
previous_to_pick = None
while dice > min_dice or np.abs(dice - min_dice) > epsilon:
to_pick = init_pick_min + (init_pick_max - init_pick_min) // 2
if to_pick == previous_to_pick:
logging.warning('No more streamlines to pick, not converging.')
break
previous_to_pick = to_pick

indices = np.random.choice(len(sft), to_pick, replace=False)
streamlines = sft.streamlines[indices]
curr_density_map = compute_tract_counts_map(streamlines,
sft.dimensions)
dice, _ = compute_dice_voxel(original_density_map, curr_density_map)
logging.debug(f'Subsampled {to_pick} streamlines, dice: {dice}')

if dice < min_dice:
init_pick_min = to_pick
else:
init_pick_max = to_pick

new_sft = StatefulTractogram.from_sft(streamlines, sft)
new_sft.to_space(space)
new_sft.to_origin(origin)
return new_sft


def cut_streamlines_alter(sft, min_dice=0.90, epsilon=0.01, from_start=True):
"""
Cut streamlines based on a dice similarity metric.
The function will keep removing points from the streamlines until the dice
similarity between the original and the cut tractogram is close to min_dice.
Parameters
----------
sft: StatefulTractogram
The tractogram to cut.
min_dice: float
The minimum dice similarity to reach before stopping the cutting.
epsilon: float
Stopping criteria for convergence. The maximum difference between the
dice similarity and min_dice.
from_start: bool
If True, cut from the start of the streamlines. If False, cut from the
end of the streamlines.
Returns
-------
new_sft: StatefulTractogram
The tractogram with cut streamlines in the same space as the input
tractogram.
"""
# Import in function to avoid circular import error
from scilpy.tractanalysis.reproducibility_measures import compute_dice_voxel
set_sft_logger_level(logging.ERROR)
space = sft.space
origin = sft.origin

# Uniformize endpoints to cut consistently from one end only
uniformize_bundle_sft(sft, swap=not from_start)
sft = resample_streamlines_step_size(sft, 0.5)
sft.to_vox()
sft.to_corner()
original_density_map = compute_tract_counts_map(sft.streamlines,
sft.dimensions)

# Initialize the dice value and the cut percentage for dichotomic search
dice = 1.0
init_cut_min = 0
init_cut_max = 1.0
previous_to_pick = None
while dice > min_dice or np.abs(dice - min_dice) > epsilon:
to_pick = init_cut_min + (init_cut_max - init_cut_min) / 2
if to_pick == previous_to_pick:
logging.warning('No more points to pick, not converging.')
break
previous_to_pick = to_pick

streamlines = []
for streamline in sft.streamlines:
pos_to_pick = int(len(streamline) * to_pick)
streamline = streamline[:pos_to_pick]
streamlines.append(streamline)
curr_density_map = compute_tract_counts_map(streamlines,
sft.dimensions)
dice, _ = compute_dice_voxel(original_density_map, curr_density_map)
logging.debug(f'Cut {to_pick * 100}% of the streamlines, dice: {dice}')

if dice < min_dice:
init_cut_min = to_pick
else:
init_cut_max = to_pick

new_sft = StatefulTractogram.from_sft(streamlines, sft)
new_sft.to_space(space)
new_sft.to_origin(origin)
return compress_sft(new_sft)


def replace_streamlines_alter(sft, min_dice=0.90, epsilon=0.01):
"""
Replace streamlines based on a dice similarity metric.
The function will upsamples the streamlines (with parallel transport),
then downsample them until the dice similarity is close to min_dice.
This effectively replaces the streamlines with new ones.
Parameters
----------
sft: StatefulTractogram
The tractogram to replace streamlines from.
min_dice: float
The minimum dice similarity to reach before stopping the replacement.
epsilon: float
Stopping criteria for convergence. The maximum difference between the
dice similarity and min_dice.
Returns
-------
new_sft: StatefulTractogram
The tractogram with replaced streamlines in the same space as the input
tractogram.
"""
# Import in function to avoid circular import error
from scilpy.tractanalysis.reproducibility_measures import compute_dice_voxel
set_sft_logger_level(logging.ERROR)

logging.debug('Upsampling the streamlines by a factor 2x to then '
'downsample.')
upsampled_sft = upsample_tractogram(sft, len(sft) * 2, point_wise_std=0.5,
tube_radius=1.0, gaussian=None,
error_rate=0.1, seed=1234)
return subsample_streamlines_alter(upsampled_sft, min_dice, epsilon,
baseline_sft=sft)


def trim_streamlines_alter(sft, min_dice=0.90, epsilon=0.01):
# Import in function to avoid circular import error
from scilpy.tractanalysis.reproducibility_measures import compute_dice_voxel
set_sft_logger_level(logging.ERROR)
space = sft.space
origin = sft.origin

sft.to_vox()
sft.to_corner()
original_density_map = compute_tract_counts_map(sft.streamlines,
sft.dimensions).astype(np.uint64)
thr_density = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
thr_pos = 0
voxels_to_remove = np.where(
(original_density_map <= thr_density[thr_pos]) &
(original_density_map > 0))

# Initialize the dice value and the number of voxels to pick
dice = 1.0
previous_dice = 0.0
init_trim_min = 0
init_trim_max = np.count_nonzero(voxels_to_remove[0])
previous_to_pick = None

while dice > min_dice or np.abs(dice - previous_dice) > epsilon:
to_pick = init_trim_min + (init_trim_max - init_trim_min) // 2
if to_pick == previous_to_pick or \
np.abs(dice - previous_dice) < epsilon:
# If too few voxels are picked, increase the threshold
# and reinitialize the picking

if np.abs(dice - min_dice) > epsilon and \
thr_pos < len(thr_density) - 1:
thr_pos += 1
logging.debug(f'Increasing threshold density to '
f'{thr_density[thr_pos]}.')

voxels_to_remove = np.where(
(original_density_map <= thr_density[thr_pos]) &
(original_density_map > 0))
init_trim_min = 0
init_trim_max = np.count_nonzero(voxels_to_remove[0])
dice = 1.0
previous_dice = 0.0
previous_to_pick = None
continue
else:
break
previous_to_pick = to_pick

voxel_to_remove = np.where(
(original_density_map <= thr_density[thr_pos]) &
(original_density_map > 0))
indices = np.random.choice(np.count_nonzero(voxel_to_remove[0]),
to_pick, replace=False)
voxel_to_remove = tuple(np.array(voxel_to_remove).T[indices].T)
mask = original_density_map.copy()
mask[voxel_to_remove] = 0

# set logger level to ERROR to avoid logging from cut_outside_of_mask
log_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
new_sft = cut_outside_of_mask_streamlines(sft, mask, min_len=10)
# reset logger level
logging.getLogger().setLevel(log_level)

curr_density_map = compute_tract_counts_map(new_sft.streamlines,
sft.dimensions)
previous_dice = dice
dice, _ = compute_dice_voxel(original_density_map, curr_density_map)
logging.debug(f'Trimmed {to_pick} voxels at density '
f'{thr_density[thr_pos]}, dice: {dice}')

if dice < min_dice:
init_trim_max = to_pick
else:
init_trim_min = to_pick

new_sft.to_space(space)
new_sft.to_origin(origin)
return new_sft


def transform_streamlines_alter(sft, min_dice=0.90, epsilon=0.01):
"""
The function will apply random rotations to the streamlines until the dice
similarity between the original and the transformed tractogram is close to
min_dice.
Start with a large XYZ rotation, then reduce the rotation step by half one
axis at a time until the dice similarity is close to min_dice.
Parameters
----------
sft: StatefulTractogram
The tractogram to transform.
min_dice: float
The minimum dice similarity to reach before stopping the transformation.
epsilon: float
Stopping criteria for convergence. The maximum difference between the
dice similarity and min_dice.
Returns
-------
new_sft: StatefulTractogram
The tractogram with transformed streamlines in the same space as the
input tractogram.
"""
# Import in function to avoid circular import error
from scilpy.tractanalysis.reproducibility_measures import compute_dice_voxel
set_sft_logger_level(logging.ERROR)
space = sft.space
origin = sft.origin

sft.to_vox()
sft.to_corner()
original_density_map = compute_tract_counts_map(sft.streamlines,
sft.dimensions)

# Initialize the dice value and angles to pick
dice = 1.0
angle_min = [0.0, 0.0, 0.0]
angle_max = [0.1, 0.1, 0.1]
previous_dice = None
last_pick = np.array([0.0, 0.0, 0.0])
rand_val = np.random.rand(3) * angle_max[0]
axis_choices = np.random.choice(3, 3, replace=False)
axis = 0
while dice > min_dice or np.abs(dice - min_dice) > epsilon:
init_angle_min = angle_min[axis]
init_angle_max = angle_max[axis]
to_pick = init_angle_min + (init_angle_max - init_angle_min) / 2

# Generate a 4x4 matrix from random euler angles
rand_val = np.array(angle_max)
rand_val[axis] = to_pick

angles = rand_val * 2 * np.pi
rot_mat = generate_rotation_matrix(angles)
streamlines = transform_streamlines(sft.streamlines, rot_mat)

# Remove invalid streamlines to avoid numerical issues
curr_sft = StatefulTractogram.from_sft(streamlines, sft)
curr_sft, _ = cut_invalid_streamlines(curr_sft)
curr_sft = remove_single_point_streamlines(curr_sft)
curr_sft = remove_overlapping_points_streamlines(curr_sft)

curr_density_map = compute_tract_counts_map(curr_sft.streamlines,
sft.dimensions)
dice, _ = compute_dice_voxel(original_density_map, curr_density_map)
logging.debug(f'Transformed {to_pick*360} degree on axis {axis}, '
f'dice: {dice}')
last_pick[axis] = to_pick

if dice < min_dice:
angle_max[axis] = to_pick
else:
angle_min[axis] = to_pick

if (previous_dice is not None) \
and np.abs(dice - previous_dice) < epsilon / 2:
logging.debug('Not converging, switching axis.\n')
axis_choices = np.roll(axis_choices, 1)
axis = axis_choices[0]
previous_dice = dice

logging.debug(f'\nFinal angles: {last_pick * 360} at dice: {dice}')
curr_sft.to_space(space)
curr_sft.to_origin(origin)
return curr_sft


OPERATIONS = {
'difference_robust': difference_robust,
'intersection_robust': intersection_robust,
Expand Down
Loading

0 comments on commit e771d98

Please sign in to comment.