Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NF: Allow passing an AFQDataset object as input. #113

Merged
merged 4 commits into from
May 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions afqinsight/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tqdm.auto import tqdm

from .utils import BUNDLE_MAT_2_PYTHON
from .datasets import AFQDataset


POSITIONS = OrderedDict(
Expand Down Expand Up @@ -52,8 +53,8 @@

def plot_tract_profiles(
X,
groups,
group_names,
groups=None,
group_names=None,
group_by=None,
group_by_name=None,
bins=None,
Expand All @@ -71,15 +72,18 @@ def plot_tract_profiles(

Parameters
----------
X : numpy.ndarray
matrix of tractometry features with shape (n_subjects, n_features).
X : numpy.ndarray or AFQDataset class instance
If array, this is a matrix of tractometry features with shape (n_subjects, n_features).

groups : list of numpy.ndarray
feature indices for each feature group of ``X``
groups : list of numpy.ndarray, optional
feature indices for each feature group of ``X``.
Must be provided if ``X`` is an array. Should not be provided if
``X`` is an AFQDataset.

group_names : list of tuples
the multi-indexed name for each group in ``groups``. Must be of same
length as ``groups``.
length as ``groups``. Must be provided if ``X`` is an array.
Should not be provided if ``X`` is an AFQDataset

group_by : list-like
grouping variable that will produce different bundle profiles with
Expand Down Expand Up @@ -142,6 +146,22 @@ def plot_tract_profiles(
dictionary of matplotlib figures, with keys corresponding to the
different diffusion metrics
"""
if isinstance(X, AFQDataset):
if groups is not None or group_names is not None:
raise ValueError(
"You provided an AFQDataset class instance as `X` input and also a `groups` or `group_names` input, but these are mutually exclusive."
)
# Allocate the variables needed below based on the input dataset:
group_names = X.group_names
groups = X.groups
X = X.X

else:
if groups is None or group_names is None:
raise ValueError(
"You provided an array input as `X` but did not provide both a `groups` and a `group_names` input. You must provide both of these for array input. "
)

plt_positions = subplot_positions if subplot_positions is not None else POSITIONS

if bins is not None and quantiles is not None:
Expand Down