Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add angle coloring option + fix bugs #718

Merged
merged 1 commit into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 74 additions & 9 deletions scilpy/utils/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans

from scilpy.io.utils import load_matrix_in_any_format
from scilpy.tractanalysis.features import get_streamlines_centroid

from scilpy.viz.utils import get_colormap
Expand Down Expand Up @@ -133,31 +134,95 @@ def uniformize_bundle_sft_using_mask(sft, mask, swap=False):
sft.to_space(old_space)
sft.to_origin(old_origin)

def get_color_streamlines_along_length(sft, colormap='jet'):

def clip_and_normalize_data_for_cmap(args, data):
if args.LUT:
LUT = load_matrix_in_any_format(args.LUT)
for i, val in enumerate(LUT):
data[data == i+1] = val

if args.min_range is not None or args.max_range is not None:
data = np.clip(data, args.min_range, args.max_range)

# get data values range
if args.min_cmap is not None:
lbound = args.min_cmap
else:
lbound = np.min(data)
if args.max_cmap is not None:
ubound = args.max_cmap
else:
ubound = np.max(data)

if args.log:
data[data > 0] = np.log10(data[data > 0])

# normalize data between 0 and 1
data -= lbound
data = data / ubound if ubound > 0 else data
return data, lbound, ubound


def get_color_streamlines_from_angle(sft, args):
"""Color streamlines according to their length.

Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be colored
colormap: str
The colormap to use.
args: NameSpace
The colormap options.

Returns
-------
color: np.ndarray
An array of shape (nb_streamlines, 3) containing the RGB values of
streamlines

lbound: float
Minimal value
ubound: float
Maximal value
"""
cmap = get_colormap(colormap)
color_dpp = copy.deepcopy(sft.streamlines)
angles = []
for i in range(len(sft.streamlines)):
dirs = np.diff(sft.streamlines[i], axis=0)
dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)
cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1)
# Resolve numerical instability
cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0)
line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0]
angles.extend(line_angles)

angles = np.rad2deg(angles)

return clip_and_normalize_data_for_cmap(args, angles)


def get_color_streamlines_along_length(sft, args):
"""Color streamlines according to their length.

Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be colored
args: NameSpace
The colormap options.

Returns
-------
color: np.ndarray
An array of shape (nb_streamlines, 3) containing the RGB values of
streamlines
lbound: int
Minimal value
ubound: int
Maximal value
"""
positions = []
for i in range(len(sft.streamlines)):
color_dpp[i] = cmap(np.linspace(0, 1, len(sft.streamlines[i])))[
:, 0:3] * 255
positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i]))))

return color_dpp._data
return clip_and_normalize_data_for_cmap(args, positions)


def filter_tractogram_data(tractogram, streamline_ids):
Expand Down
76 changes: 34 additions & 42 deletions scripts/scil_assign_custom_color_to_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@
add_overwrite_arg,
add_reference_arg,
load_matrix_in_any_format)
from scilpy.utils.streamlines import get_color_streamlines_along_length
from scilpy.utils.streamlines import get_color_streamlines_along_length, \
get_color_streamlines_from_angle, clip_and_normalize_data_for_cmap
from scilpy.viz.utils import get_colormap

COLORBAR_NB_VALUES = 255
NB_TICKS = 10


def _build_arg_parser():
Expand All @@ -71,6 +73,9 @@ def _build_arg_parser():
cbar_g.add_argument('--out_colorbar',
help='Optional output colorbar (.png, .jpg or any '
'format supported by matplotlib).')
cbar_g.add_argument('--show_colorbar', action='store_true',
help="Will show the colorbar. Must be used with "
"--out_colorbar to be effective.")
cbar_g.add_argument('--horizontal_cbar', action='store_true',
help='Draw horizontal colorbar (vertical by default).')

Expand All @@ -92,6 +97,10 @@ def _build_arg_parser():
p1.add_argument('--along_profile', action='store_true',
help='Color streamlines according to each point position'
'along its length.\nMust be uniformized head/tail.')
p1.add_argument('--local_angle', action='store_true',
help="Color streamlines according to the angle between "
"each segment (in degree). \nAngles at first and "
"last points are set to 0.")

g2 = p.add_argument_group(title='Coloring Options')
g2.add_argument('--colormap', default='jet',
Expand Down Expand Up @@ -120,34 +129,6 @@ def _build_arg_parser():
return p


def transform_data(args, data):
if args.LUT:
LUT = load_matrix_in_any_format(args.LUT)
for i, val in enumerate(LUT):
data[data == i+1] = val

if args.min_range is not None or args.max_range is not None:
data = np.clip(data, args.min_range, args.max_range)

# get data values range
if args.min_cmap is not None:
lbound = args.min_cmap
else:
lbound = np.min(data)
if args.max_cmap is not None:
ubound = args.max_cmap
else:
ubound = np.max(data)

if args.log:
data[data > 0] = np.log10(data[data > 0])

# normalize data between 0 and 1
data -= lbound
data = data / ubound if ubound > 0 else data
return data, lbound, ubound


def save_colorbar(cmap, lbound, ubound, args):
gradient = cmap(np.linspace(0, 1, COLORBAR_NB_VALUES))[:, 0:3]

Expand All @@ -160,22 +141,25 @@ def save_colorbar(cmap, lbound, ubound, args):
_, ax = plt.subplots(1, 1)
ax.imshow(gradient, origin='lower')

ticks_labels = ['{0:.3f}'.format(lbound), '{0:.3f}'.format(ubound)]
ticks_labels = ['{0:.3f}'.format(i) for i in
np.linspace(lbound, ubound, NB_TICKS)]

if args.log:
ticks_labels[0] = 'log(' + ticks_labels[0] + ')'
ticks_labels[1] = 'log(' + ticks_labels[1] + ')'
ticks_labels = ['log(' + t + ')' for t in ticks_labels]

ticks = np.linspace(0, COLORBAR_NB_VALUES - 1, NB_TICKS)
if not args.horizontal_cbar:
ax.set_yticks([0, COLORBAR_NB_VALUES - 1])
ax.set_yticks(ticks)
ax.set_yticklabels(ticks_labels)
ax.set_xticks([])
else:
ax.set_xticks([0, COLORBAR_NB_VALUES - 1])
ax.set_xticks(ticks)
ax.set_xticklabels(ticks_labels)
ax.set_yticks([])

plt.savefig(args.out_colorbar, bbox_inches='tight')
if args.show_colorbar:
plt.show()


def main():
Expand Down Expand Up @@ -216,26 +200,28 @@ def main():
data = np.squeeze(load_matrix_in_any_format(args.load_dps))
if len(data) != len(sft):
parser.error('Wrong dps size!')
elif args.load_dpp:
else: # args.load_dpp
data = np.squeeze(load_matrix_in_any_format(args.load_dpp))
if len(data) != len(sft.streamlines._data):
parser.error('Wrong dpp size!')
data, lbound, ubound = transform_data(args, data)
color = cmap(data)[:, 0:3] * 255
values, lbound, ubound = clip_and_normalize_data_for_cmap(args, data)
elif args.from_anatomy:
data = nib.load(args.from_anatomy).get_fdata()
data, lbound, ubound = transform_data(args, data)
data, lbound, ubound = clip_and_normalize_data_for_cmap(args, data)

sft.to_vox()
values = map_coordinates(data, sft.streamlines._data.T,
order=0)
color = cmap(values)[:, 0:3] * 255
values = map_coordinates(data, sft.streamlines._data.T, order=0)
sft.to_rasmm()
elif args.along_profile:
color = get_color_streamlines_along_length(sft, args.colormap)
values, lbound, ubound = get_color_streamlines_along_length(
sft, args)
elif args.local_angle:
values, lbound, ubound = get_color_streamlines_from_angle(
sft, args)
else:
parser.error('No coloring method specified.')

color = cmap(values)[:, 0:3] * 255
if len(color) == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],
(len(sft.streamlines[i]), 1))
Expand All @@ -244,6 +230,12 @@ def main():
elif len(color) == len(sft.streamlines._data):
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. (this is our fault). Expecting either one"
"color per streamline ({}) or one per point ({}) but "
"got {}.".format(len(sft), len(sft.streamlines._data),
len(color)))
save_tractogram(sft, args.out_tractogram)

# output colormap
Expand Down