Skip to content

Commit

Permalink
Merge pull request #718 from EmmaRenauld/color_per_angle
Browse files Browse the repository at this point in the history
Add angle coloring option + fix bugs
  • Loading branch information
arnaudbore committed May 17, 2023
2 parents 127dc51 + b0e838a commit 233bb50
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 51 deletions.
83 changes: 74 additions & 9 deletions scilpy/utils/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans

from scilpy.io.utils import load_matrix_in_any_format
from scilpy.tractanalysis.features import get_streamlines_centroid

from scilpy.viz.utils import get_colormap
Expand Down Expand Up @@ -133,31 +134,95 @@ def uniformize_bundle_sft_using_mask(sft, mask, swap=False):
sft.to_space(old_space)
sft.to_origin(old_origin)

def get_color_streamlines_along_length(sft, colormap='jet'):

def clip_and_normalize_data_for_cmap(args, data):
if args.LUT:
LUT = load_matrix_in_any_format(args.LUT)
for i, val in enumerate(LUT):
data[data == i+1] = val

if args.min_range is not None or args.max_range is not None:
data = np.clip(data, args.min_range, args.max_range)

# get data values range
if args.min_cmap is not None:
lbound = args.min_cmap
else:
lbound = np.min(data)
if args.max_cmap is not None:
ubound = args.max_cmap
else:
ubound = np.max(data)

if args.log:
data[data > 0] = np.log10(data[data > 0])

# normalize data between 0 and 1
data -= lbound
data = data / ubound if ubound > 0 else data
return data, lbound, ubound


def get_color_streamlines_from_angle(sft, args):
"""Color streamlines according to their length.
Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be colored
colormap: str
The colormap to use.
args: NameSpace
The colormap options.
Returns
-------
color: np.ndarray
An array of shape (nb_streamlines, 3) containing the RGB values of
streamlines
lbound: float
Minimal value
ubound: float
Maximal value
"""
cmap = get_colormap(colormap)
color_dpp = copy.deepcopy(sft.streamlines)
angles = []
for i in range(len(sft.streamlines)):
dirs = np.diff(sft.streamlines[i], axis=0)
dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)
cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1)
# Resolve numerical instability
cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0)
line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0]
angles.extend(line_angles)

angles = np.rad2deg(angles)

return clip_and_normalize_data_for_cmap(args, angles)


def get_color_streamlines_along_length(sft, args):
"""Color streamlines according to their length.
Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be colored
args: NameSpace
The colormap options.
Returns
-------
color: np.ndarray
An array of shape (nb_streamlines, 3) containing the RGB values of
streamlines
lbound: int
Minimal value
ubound: int
Maximal value
"""
positions = []
for i in range(len(sft.streamlines)):
color_dpp[i] = cmap(np.linspace(0, 1, len(sft.streamlines[i])))[
:, 0:3] * 255
positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i]))))

return color_dpp._data
return clip_and_normalize_data_for_cmap(args, positions)


def filter_tractogram_data(tractogram, streamline_ids):
Expand Down
76 changes: 34 additions & 42 deletions scripts/scil_assign_custom_color_to_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@
add_overwrite_arg,
add_reference_arg,
load_matrix_in_any_format)
from scilpy.utils.streamlines import get_color_streamlines_along_length
from scilpy.utils.streamlines import get_color_streamlines_along_length, \
get_color_streamlines_from_angle, clip_and_normalize_data_for_cmap
from scilpy.viz.utils import get_colormap

COLORBAR_NB_VALUES = 255
NB_TICKS = 10


def _build_arg_parser():
Expand All @@ -71,6 +73,9 @@ def _build_arg_parser():
cbar_g.add_argument('--out_colorbar',
help='Optional output colorbar (.png, .jpg or any '
'format supported by matplotlib).')
cbar_g.add_argument('--show_colorbar', action='store_true',
help="Will show the colorbar. Must be used with "
"--out_colorbar to be effective.")
cbar_g.add_argument('--horizontal_cbar', action='store_true',
help='Draw horizontal colorbar (vertical by default).')

Expand All @@ -92,6 +97,10 @@ def _build_arg_parser():
p1.add_argument('--along_profile', action='store_true',
help='Color streamlines according to each point position'
'along its length.\nMust be uniformized head/tail.')
p1.add_argument('--local_angle', action='store_true',
help="Color streamlines according to the angle between "
"each segment (in degree). \nAngles at first and "
"last points are set to 0.")

g2 = p.add_argument_group(title='Coloring Options')
g2.add_argument('--colormap', default='jet',
Expand Down Expand Up @@ -120,34 +129,6 @@ def _build_arg_parser():
return p


def transform_data(args, data):
if args.LUT:
LUT = load_matrix_in_any_format(args.LUT)
for i, val in enumerate(LUT):
data[data == i+1] = val

if args.min_range is not None or args.max_range is not None:
data = np.clip(data, args.min_range, args.max_range)

# get data values range
if args.min_cmap is not None:
lbound = args.min_cmap
else:
lbound = np.min(data)
if args.max_cmap is not None:
ubound = args.max_cmap
else:
ubound = np.max(data)

if args.log:
data[data > 0] = np.log10(data[data > 0])

# normalize data between 0 and 1
data -= lbound
data = data / ubound if ubound > 0 else data
return data, lbound, ubound


def save_colorbar(cmap, lbound, ubound, args):
gradient = cmap(np.linspace(0, 1, COLORBAR_NB_VALUES))[:, 0:3]

Expand All @@ -160,22 +141,25 @@ def save_colorbar(cmap, lbound, ubound, args):
_, ax = plt.subplots(1, 1)
ax.imshow(gradient, origin='lower')

ticks_labels = ['{0:.3f}'.format(lbound), '{0:.3f}'.format(ubound)]
ticks_labels = ['{0:.3f}'.format(i) for i in
np.linspace(lbound, ubound, NB_TICKS)]

if args.log:
ticks_labels[0] = 'log(' + ticks_labels[0] + ')'
ticks_labels[1] = 'log(' + ticks_labels[1] + ')'
ticks_labels = ['log(' + t + ')' for t in ticks_labels]

ticks = np.linspace(0, COLORBAR_NB_VALUES - 1, NB_TICKS)
if not args.horizontal_cbar:
ax.set_yticks([0, COLORBAR_NB_VALUES - 1])
ax.set_yticks(ticks)
ax.set_yticklabels(ticks_labels)
ax.set_xticks([])
else:
ax.set_xticks([0, COLORBAR_NB_VALUES - 1])
ax.set_xticks(ticks)
ax.set_xticklabels(ticks_labels)
ax.set_yticks([])

plt.savefig(args.out_colorbar, bbox_inches='tight')
if args.show_colorbar:
plt.show()


def main():
Expand Down Expand Up @@ -216,26 +200,28 @@ def main():
data = np.squeeze(load_matrix_in_any_format(args.load_dps))
if len(data) != len(sft):
parser.error('Wrong dps size!')
elif args.load_dpp:
else: # args.load_dpp
data = np.squeeze(load_matrix_in_any_format(args.load_dpp))
if len(data) != len(sft.streamlines._data):
parser.error('Wrong dpp size!')
data, lbound, ubound = transform_data(args, data)
color = cmap(data)[:, 0:3] * 255
values, lbound, ubound = clip_and_normalize_data_for_cmap(args, data)
elif args.from_anatomy:
data = nib.load(args.from_anatomy).get_fdata()
data, lbound, ubound = transform_data(args, data)
data, lbound, ubound = clip_and_normalize_data_for_cmap(args, data)

sft.to_vox()
values = map_coordinates(data, sft.streamlines._data.T,
order=0)
color = cmap(values)[:, 0:3] * 255
values = map_coordinates(data, sft.streamlines._data.T, order=0)
sft.to_rasmm()
elif args.along_profile:
color = get_color_streamlines_along_length(sft, args.colormap)
values, lbound, ubound = get_color_streamlines_along_length(
sft, args)
elif args.local_angle:
values, lbound, ubound = get_color_streamlines_from_angle(
sft, args)
else:
parser.error('No coloring method specified.')

color = cmap(values)[:, 0:3] * 255
if len(color) == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],
(len(sft.streamlines[i]), 1))
Expand All @@ -244,6 +230,12 @@ def main():
elif len(color) == len(sft.streamlines._data):
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. (this is our fault). Expecting either one"
"color per streamline ({}) or one per point ({}) but "
"got {}.".format(len(sft), len(sft.streamlines._data),
len(color)))
save_tractogram(sft, args.out_tractogram)

# output colormap
Expand Down

0 comments on commit 233bb50

Please sign in to comment.