integrate recobundles

36000 committed Dec 2, 2020
1 parent 7e77126 commit 514d079
Showing 2 changed files with 79 additions and 43 deletions.
2 changes: 2 additions & 0 deletions AFQ/data.py
@@ -68,6 +68,8 @@
"UF_L": "UNC_L", "UF_R": "UNC_R",
"IFOF_L": "IFO_L", "IFOF_R": "IFO_R",
"CST_L": "CST_L", "CST_R": "CST_R",
"ILF_L": "ILF_L", "ILF_R": "ILF_R",
"SLF_L": "SLF_L", "SLF_R": "SLF_R"
}
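# Hedged annotation: given the bundle_converter=BUNDLE_RECO_2_AFQ default added
# in AFQ/viz/utils.py below, this appears to be the BUNDLE_RECO_2_AFQ mapping,
# translating RecoBundles/atlas bundle names (keys) into pyAFQ names (values).
# Illustrative use, mirroring the call added in GroupCSVComparison.__init__:
#   profile.replace({"tractID": BUNDLE_RECO_2_AFQ}, inplace=True)
#   # e.g. "IFOF_L" -> "IFO_L", "UF_R" -> "UNC_R"; identical names pass through.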

BUNDLE_MAT_2_PYTHON = \
120 changes: 77 additions & 43 deletions AFQ/viz/utils.py
@@ -80,6 +80,8 @@
SCALAR_REMOVE_MODEL = \
{'dti_md': 'MD', 'dki_md': 'MD', 'dki_fa': 'FA', 'dti_fa': 'FA'}

RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"]
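# Hedged annotation: these appear to be bundles whose RecoBundles profiles run
# in the opposite node order from pyAFQ's convention; for rows with these
# tractIDs, nodeID is remapped to 99 - nodeID in GroupCSVComparison.__init__
# below, which implicitly assumes 100 nodes (0-99) per profile.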

def gen_color_dict(bundles):
"""
Helper function.
@@ -668,7 +670,7 @@ class GroupCSVComparison():
scan-rescan reliability using Pearson's r.
"""

def __init__(self, out_folder, csv_fnames, names, is_mats=False,
def __init__(self, out_folder, csv_fnames, names, is_special="",
subjects=None,
scalar_bounds={'lb': {'dti_fa': 0.2},
'ub': {'dti_md': 0.002}},
@@ -678,7 +680,8 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False,
remove_model=False,
mat_bundle_converter=BUNDLE_MAT_2_PYTHON,
mat_column_converter=CSV_MAT_2_PYTHON,
mat_scale_converter=SCALE_MAT_2_PYTHON):
mat_scale_converter=SCALE_MAT_2_PYTHON,
bundle_converter=BUNDLE_RECO_2_AFQ):
"""
Load in csv files, converting from matlab if necessary.
@@ -694,9 +697,11 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False,
names : list of strings
Name to use to identify each CSV dataset.
is_mats : bool or list of bools, optional
Whether or not the csv was generated from Matlab AFQ or pyAFQ.
Default: False
is_special : str or list of strs, optional
Whether or not the csv needs special attention.
Can be "", "mat" if the csv was generated using mAFQ,
or "reco" if the csv was generated using Recobundles.
Default: ""
subjects : list of num, optional
List of subjects to consider.
@@ -748,13 +753,18 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False,
Dictionary that maps scalar names to how they should be scaled
to match pyAFQ's scale for that scalar.
Default: SCALE_MAT_2_PYTHON
bundle_converter : dictionary, optional
Dictionary that maps bundle names to more standard bundle names.
Unlike mat_bundle_converter, this converter is applied to all CSVs.
Default: BUNDLE_RECO_2_AFQ
"""
self.logger = logging.getLogger('AFQ.csv')
self.out_folder = out_folder
self.percent_nan_tol = percent_nan_tol

if isinstance(is_mats, bool):
is_mats = [is_mats] * len(csv_fnames)
if not isinstance(is_special, list):
is_special = [is_special] * len(csv_fnames)

self.profile_dict = {}
for i, fname in enumerate(csv_fnames):
@@ -768,7 +778,7 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False,
else:
profile['subjectID'] = 0

if is_mats[i]:
if is_special[i] == "mat":
profile.rename(
columns=mat_column_converter, inplace=True)
profile['tractID'] = \
@@ -777,6 +787,14 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False,
for scalar, scale in mat_scale_converter.items():
profile[scalar] = \
profile[scalar].apply(lambda x: x * scale)
profile.replace({"tractID": bundle_converter}, inplace=True)
if is_special[i] == "reco":
def reco_flip(df):
if df.tractID in RECO_FLIP:
return 99 - df.nodeID
else:
return df.nodeID
profile["nodeID"] = profile.apply(reco_flip, axis=1)
if remove_model:
profile.rename(
columns=SCALAR_REMOVE_MODEL, inplace=True)
@@ -844,15 +862,15 @@ def _get_fname(self, folder, f_name):
os.makedirs(f_folder, exist_ok=True)
return op.join(f_folder, f_name)

def _get_profile(self, name, bundle, subject, scalar, repl_nan=True):
def _get_profile(self, name, bundle, subject, scalar):
"""
Get a single profile, then handle not found / NaNs
"""
profile = self.profile_dict[name]
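# The sort_values("nodeID") added below ensures the extracted profile is in
# node order even if the CSV rows are not (RecoBundles profiles, after the
# nodeID flip above, are no longer guaranteed to be sorted).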
single_profile = profile[
(profile['subjectID'] == subject)
& (profile['tractID'] == bundle)
][scalar].to_numpy()
].sort_values("nodeID")[scalar].to_numpy()
nans = np.isnan(single_profile)
percent_nan = (np.sum(nans) * 100) // self.prof_len
if len(single_profile) < 1:
@@ -912,7 +930,7 @@ def masked_corr(self, arr):
np.isnan(arr[0, ...]),
np.isnan(arr[1, ...])))
if np.sum(mask) < 2:
return 0
return np.nan
return np.corrcoef(arr[:, mask])[0][1]
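# Returning np.nan (rather than 0) when fewer than two paired non-NaN samples
# remain leaves the correlation undefined instead of biasing it toward zero;
# the nan-aware aggregation later in reliability_plots (np.nanmean, np.nanmax,
# np.nanmin, and the explicit ~np.isnan filtering) then skips these entries.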

def tract_profiles(self, names=None, scalar="dti_fa",
@@ -1263,7 +1281,6 @@ def reliability_plots(self, names=None,

# extract relevant statistics / data from profiles
N = len(self.subjects)
err_prime = 1.95/np.sqrt(N-3)
all_sub_coef = np.zeros((len(scalars), len(self.bundles)))
all_sub_coef_err = np.zeros((len(scalars), len(self.bundles), 2))
all_sub_means = np.zeros(
@@ -1281,19 +1298,30 @@
for j, name in enumerate(names):
for i, subject in enumerate(self.subjects):
single_profile = self._get_profile(
name, bundle, subject, scalar, repl_nan=False)
name, bundle, subject, scalar)
if single_profile is None:
bundle_profiles[j, i] = np.nan
miss_counts.at[bundle, f"miss_count{name}"] =\
miss_counts.at[bundle, f"miss_count{name}"] + 1
miss_counts.at[
bundle, f"miss_count{name}"]+ 1
else:
bundle_profiles[j, i] = single_profile

all_sub_means[m, k] = np.nanmean(bundle_profiles, axis=2)
N_not_nan = N - np.sum(np.isnan(
all_sub_means[m, k, 0] + all_sub_means[m, k, 1]))
all_sub_coef[m, k] = self.masked_corr(all_sub_means[m, k])
corr_prime = np.arctanh(all_sub_coef[m, k]) # uses Fisher Transformation
all_sub_coef_err[m, k, 0] = np.tanh(corr_prime + err_prime) - all_sub_coef[m, k]
all_sub_coef_err[m, k, 1] = all_sub_coef[m, k] - np.tanh(corr_prime - err_prime)
if np.isnan(all_sub_coef[m, k]).all() or N_not_nan < 4:
raise ValueError((
f"Not enough non-nan profiles"
f"for scalar {scalar} for bundle {bundle}"))
# use Fisher Transformation
err_prime = 1.95/np.sqrt(N_not_nan-3)
corr_prime = np.arctanh(all_sub_coef[m, k])
all_sub_coef_err[m, k, 1] = np.tanh(corr_prime + err_prime)\
- all_sub_coef[m, k]
all_sub_coef_err[m, k, 0] = all_sub_coef[m, k]\
- np.tanh(corr_prime - err_prime)
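# Worked example of the interval computed above (illustrative numbers only):
# with r = 0.8 and N_not_nan = 28,
#   corr_prime  = arctanh(0.8)          ~= 1.099
#   err_prime   = 1.95 / sqrt(28 - 3)   =  0.390
#   upper bound = tanh(1.099 + 0.390)   ~= 0.903  (error of ~0.103 above r)
#   lower bound = tanh(1.099 - 0.390)   ~= 0.610  (error of ~0.190 below r)
# The interval is symmetric in z-space but asymmetric in r-space, which is why
# all_sub_coef_err stores separate upper and lower errors.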

bundle_coefs = np.zeros(N)
for i in range(N):
@@ -1307,14 +1335,16 @@
all_node_coef[m, k] = node_coefs

# plot histograms of subject pearson r's
maxi = all_profile_coef.max()
mini = all_profile_coef.min()
maxi = np.nanmax(all_profile_coef)
mini = np.nanmin(all_profile_coef)
bins = np.linspace(mini, maxi, 10)
ba = BrainAxes(positions=positions)
for k, bundle in enumerate(self.bundles):
ax = ba.get_axis(bundle)
for m, scalar in enumerate(scalars):
bundle_coefs = all_profile_coef[m, k]
bundle_coefs = bundle_coefs[~np.isnan(bundle_coefs)]
print(bundle_coefs)
sns.set(style="whitegrid")
sns.histplot(
data=bundle_coefs,
@@ -1632,8 +1662,8 @@ def reliability_plots(self, names=None,

def compare_reliability(self, reliability1, reliability2,
analysis_label1, analysis_label2,
errors1, errors2,
bundles,
errors1=None, errors2=None,
scalars=["FA", "MD"],
rtype="Subject Reliability",
show_plots=False):
@@ -1651,15 +1681,17 @@ def compare_reliability(self, reliability1, reliability2,
Names of the analyses used to obtain each dataset.
Used to label the x and y axes.
errors1, errors2 : numpy arrays
Numpy arrays describing the errors.
Typically, each of this will be outputs of separate calls
to reliability_plots.
bundles : list of str
List of bundles that correspond to the second dimension of the
reliability arrays.
errors1, errors2 : numpy arrays or None
Numpy arrays describing the errors.
Typically, each of these will be an output of a separate call
to reliability_plots.
If None, errors are not shown.
Default is None.
scalars : list of str, optional
List of scalars that correspond to the first dimension of the
reliability arrays. Default: ["FA", "MD"]
@@ -1676,34 +1708,36 @@
-------
Returns a Matplotlib figure and axes.
"""
show_error = ((errors1 is not None) and (errors2 is not None))
fig, ax = plt.subplots()
for i, scalar in enumerate(scalars):
for j, bundle in enumerate(bundles):
if len(errors1.shape) > 2:
xerr = errors1[:, j, i].reshape((2, 1))
else:
xerr = errors1[i, j]
if len(errors2.shape) > 2:
yerr = errors2[:, j, i].reshape((2, 1))
else:
yerr = errors2[i, j]

ax.scatter(
reliability1[i, j],
reliability2[i, j],
s=marker_size,
c=[self.color_dict[bundle]],
marker=self.scalar_markers[i]
)
ax.errorbar(
reliability1[i, j],
reliability2[i, j],
xerr=xerr,
yerr=yerr,
c=[self.color_dict[bundle]],
alpha=0.5,
fmt="none"
)
if show_error:
if len(errors1.shape) > 2:
xerr = errors1[:, j, i].reshape((2, 1))
else:
xerr = errors1[i, j]
if len(errors2.shape) > 2:
yerr = errors2[:, j, i].reshape((2, 1))
else:
yerr = errors2[i, j]

ax.errorbar(
reliability1[i, j],
reliability2[i, j],
xerr=xerr,
yerr=yerr,
c=[self.color_dict[bundle]],
alpha=0.5,
fmt="none"
)

ax.set_xlabel(f"{analysis_label1} {rtype}",
fontsize=medium_font)

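A hedged usage sketch of the updated API (not part of the commit): it constructs GroupCSVComparison with the new is_special flag and compares two analyses with the now-optional error arrays omitted. The output folder, CSV file names, and the rel_pyafq / rel_reco arrays are illustrative placeholders, assumed to come from earlier calls to reliability_plots.

    from AFQ.viz.utils import GroupCSVComparison

    # hypothetical output folder, CSV files, and dataset names
    comparison = GroupCSVComparison(
        "comparison_output",
        ["pyafq_profiles.csv", "reco_profiles.csv"],
        ["pyAFQ", "Reco"],
        is_special=["", "reco"])  # flip node order for RecoBundles bundles

    # rel_pyafq / rel_reco: (n_scalars, n_bundles) reliability arrays;
    # errors1 and errors2 now default to None, so no error bars are drawn.
    # fig, ax = comparison.compare_reliability(
    #     rel_pyafq, rel_reco, "pyAFQ", "Reco", comparison.bundles)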