Skip to content

Commit

Permalink
Merge pull request #513 from 36000/plotly_viz_fixes
Browse files Browse the repository at this point in the history
FIX: Plotly viz bug fixes, and update to custom bundles
  • Loading branch information
36000 committed Oct 6, 2020
2 parents 138aed0 + 4235145 commit fcd3c4d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
6 changes: 5 additions & 1 deletion AFQ/viz/fury_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def visualize_bundles(sft, affine=None, n_points=None, bundle_dict=None,
sft : Stateful Tractogram, str
A Stateful Tractogram containing streamline information
or a path to a trk file
In order to visualize individual bundles, the Stateful Tractogram
must contain a bundle key in it's data_per_streamline which is a list
of bundle `'uid'`.
affine : ndarray, optional
An affine transformation to apply to the streamlines before
Expand All @@ -69,7 +72,8 @@ def visualize_bundles(sft, affine=None, n_points=None, bundle_dict=None,
colors : dict or list
If this is a dict, keys are bundle names and values are RGB tuples.
If this is a list, each item is an RGB tuple. Defaults to a list
with Tableau 20 RGB values
with Tableau 20 RGB values if bundle_dict is None, or dict from
bundles to Tableau 20 RGB values if bundle_dict is not None.
background : tuple, optional
RGB values for the background. Default: (1, 1, 1), which is white
Expand Down
8 changes: 6 additions & 2 deletions AFQ/viz/plotly_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def visualize_bundles(sft, affine=None, n_points=None, bundle_dict=None,
----------
sft : Stateful Tractogram, str
A Stateful Tractogram containing streamline information
or a path to a trk file
or a path to a trk file.
In order to visualize individual bundles, the Stateful Tractogram
must contain a bundle key in it's data_per_streamline which is a list
of bundle `'uid'`.
affine : ndarray, optional
An affine transformation to apply to the streamlines before
Expand All @@ -175,7 +178,8 @@ def visualize_bundles(sft, affine=None, n_points=None, bundle_dict=None,
colors : dict or list
If this is a dict, keys are bundle names and values are RGB tuples.
If this is a list, each item is an RGB tuple. Defaults to a list
with Tableau 20 RGB values
with Tableau 20 RGB values if bundle_dict is None, or dict from
bundles to Tableau 20 RGB values if bundle_dict is not None.
background : tuple, optional
RGB values for the background. Default: (1, 1, 1), which is white
Expand Down
59 changes: 49 additions & 10 deletions AFQ/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,41 @@
{'dti_md': 0.001}


def gen_color_dict(bundles):
"""
Helper function.
Generate a color dict given a list of bundles.
"""
def incr_color_idx(color_idx):
return (color_idx + 1) % 20
custom_color_dict = {}
color_idx = 0
for bundle in bundles:
if bundle not in custom_color_dict.keys():
if bundle in COLOR_DICT.keys():
custom_color_dict[bundle] = COLOR_DICT[bundle]
else:
other_bundle = list(bundle)
if bundle[-2:] == '_L':
other_bundle[-2:] = '_R'
elif bundle[-2:] == '_R':
other_bundle[-2:] = '_L'
other_bundle = str(other_bundle)

if other_bundle == bundle: # lone bundle
custom_color_dict[bundle] = tableau_20_sns[color_idx]
color_idx = incr_color_idx(color_idx)
else: # right left pair
if color_idx % 2 != 0:
color_idx = incr_color_idx(color_idx)
custom_color_dict[bundle] =\
tableau_20_sns[color_idx]
custom_color_dict[other_bundle] =\
tableau_20_sns[color_idx + 1]
color_idx = incr_color_idx(incr_color_idx(color_idx))
return custom_color_dict


def viz_import_msg_error(module):
"""Alerts user to install the appropriate viz module """
msg = f"To use {module.upper()} visualizations in pyAFQ, you will need "
Expand Down Expand Up @@ -182,8 +217,11 @@ def tract_generator(sft, affine, bundle, bundle_dict, colors, n_points,
-------
Statefule Tractogram streamlines, RGB numpy array, str
"""
bundle_dict = bundle_dict.copy()
bundle_dict.pop('whole_brain', None)
if colors is None:
if bundle_dict is None:
colors = tableau_20_sns
else:
colors = gen_color_dict(bundle_dict.keys())

if isinstance(sft, str):
viz_logger.info("Loading Stateful Tractogram...")
Expand All @@ -199,18 +237,18 @@ def tract_generator(sft, affine, bundle, bundle_dict, colors, n_points,
streamlines = sft.streamlines
viz_logger.info("Generating colorful lines from tractography...")

if colors is None:
# Use the color dict provided
colors = COLOR_DICT

if list(sft.data_per_streamline.keys()) == []:
# There are no bundles in here:
if n_points is not None:
streamlines = dps.set_number_of_points(streamlines, n_points)
yield streamlines, [0.5, 0.5, 0.5], "all_bundles"
yield streamlines, colors[0], "all_bundles"

else:
# There are bundles:
if bundle_dict is not None:
bundle_dict = bundle_dict.copy()
bundle_dict.pop('whole_brain', None)

if bundle is None:
# No selection: visualize all of them:

Expand Down Expand Up @@ -674,6 +712,7 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False,
self.bundles.sort()
else:
self.bundles = bundles
self.color_dict = gen_color_dict(self.bundles)

def _threshold_scalar(self, bound, threshold, val):
"""
Expand Down Expand Up @@ -842,7 +881,7 @@ def tract_profiles(self, names=None, scalar="dti_fa",
{
"dashes": [(2**i, 2**i)],
"hue": "tractID",
"palette": [COLOR_DICT[bundle]]},
"palette": [self.color_dict[bundle]]},
plot_subject_lines=plot_subject_lines)
if j == 0:
line = Line2D(
Expand Down Expand Up @@ -938,7 +977,7 @@ def contrast_index(self, names=None, scalar="dti_fa",
[bundle], names, scalar)
ba.plot_line(
bundle, "nodeID", "diff", ci_df, "C.I. * 2", (-1, 1),
n_boot, 1.0, {"color": COLOR_DICT[bundle]},
n_boot, 1.0, {"color": self.color_dict[bundle]},
plot_subject_lines=plot_subject_lines)
ci_all_df[j] = ci_df["diff"]
ba.fig.legend([scalar], loc='center', fontsize=medium_font)
Expand Down Expand Up @@ -1004,7 +1043,7 @@ def lateral_contrast_index(self, name, scalar="dti_fa",
[bundle, other_bundle], [name], scalar)
ba.plot_line(
bundle, "nodeID", "diff", ci_df, "C.I. * 2", (-1, 1),
n_boot, 1.0, {"color": COLOR_DICT[bundle]},
n_boot, 1.0, {"color": self.color_dict[bundle]},
plot_subject_lines=plot_subject_lines)

ba.fig.legend([scalar], loc='center', fontsize=medium_font)
Expand Down

0 comments on commit fcd3c4d

Please sign in to comment.