From 514d0792a369cca0d88a19964ae1e97aadaaa48d Mon Sep 17 00:00:00 2001 From: 36000 Date: Wed, 2 Dec 2020 12:02:51 -0800 Subject: [PATCH] integrate recobundles --- AFQ/data.py | 2 + AFQ/viz/utils.py | 120 ++++++++++++++++++++++++++++++----------------- 2 files changed, 79 insertions(+), 43 deletions(-) diff --git a/AFQ/data.py b/AFQ/data.py index 2048ad169..52bb59a70 100644 --- a/AFQ/data.py +++ b/AFQ/data.py @@ -68,6 +68,8 @@ "UF_L": "UNC_L", "UF_R": "UNC_R", "IFOF_L": "IFO_L", "IFOF_R": "IFO_R", "CST_L": "CST_L", "CST_R": "CST_R", + "ILF_L": "ILF_L", "ILF_R": "ILF_R", + "SLF_L": "SLF_L", "SLF_R": "SLF_R" } BUNDLE_MAT_2_PYTHON = \ diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index e60653694..b2719e65a 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -80,6 +80,8 @@ SCALAR_REMOVE_MODEL = \ {'dti_md': 'MD', 'dki_md': 'MD', 'dki_fa': 'FA', 'dti_fa': 'FA'} +RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"] + def gen_color_dict(bundles): """ Helper function. @@ -668,7 +670,7 @@ class GroupCSVComparison(): scan-rescan reliability using Pearson's r. """ - def __init__(self, out_folder, csv_fnames, names, is_mats=False, + def __init__(self, out_folder, csv_fnames, names, is_special="", subjects=None, scalar_bounds={'lb': {'dti_fa': 0.2}, 'ub': {'dti_md': 0.002}}, @@ -678,7 +680,8 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False, remove_model=False, mat_bundle_converter=BUNDLE_MAT_2_PYTHON, mat_column_converter=CSV_MAT_2_PYTHON, - mat_scale_converter=SCALE_MAT_2_PYTHON): + mat_scale_converter=SCALE_MAT_2_PYTHON, + bundle_converter=BUNDLE_RECO_2_AFQ): """ Load in csv files, converting from matlab if necessary. @@ -694,9 +697,11 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False, names : list of strings Name to use to identify each CSV dataset. - is_mats : bool or list of bools, optional - Whether or not the csv was generated from Matlab AFQ or pyAFQ. - Default: False + is_special : str or list of strs, optional + Whether or not the csv needs special attention. + Can be "", "mat" if the csv was generated using mAFQ, + or "reco" if the csv was generated using Recobundles. + Default: "" subjects : list of num, optional List of subjects to consider. @@ -748,13 +753,18 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False, Dictionary that maps scalar names to how they should be scaled to match pyAFQ's scale for that scalar. Default: SCALE_MAT_2_PYTHON + + bundle_converter : dictionary, optional + Dictionary that maps bundle names to more standard bundle names. + Unlike mat_bundle_converter, this converter is applied to all CSVs + Default: BUNDLE_RECO_2_AFQ """ self.logger = logging.getLogger('AFQ.csv') self.out_folder = out_folder self.percent_nan_tol = percent_nan_tol - if isinstance(is_mats, bool): - is_mats = [is_mats] * len(csv_fnames) + if not isinstance(is_special, list): + is_special = [is_special] * len(csv_fnames) self.profile_dict = {} for i, fname in enumerate(csv_fnames): @@ -768,7 +778,7 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False, else: profile['subjectID'] = 0 - if is_mats[i]: + if is_special[i] == "mat": profile.rename( columns=mat_column_converter, inplace=True) profile['tractID'] = \ @@ -777,6 +787,14 @@ def __init__(self, out_folder, csv_fnames, names, is_mats=False, for scalar, scale in mat_scale_converter.items(): profile[scalar] = \ profile[scalar].apply(lambda x: x * scale) + profile.replace({"tractID": bundle_converter}, inplace=True) + if is_special[i] == "reco": + def reco_flip(df): + if df.tractID in RECO_FLIP: + return 99 - df.nodeID + else: + return df.nodeID + profile["nodeID"] = profile.apply(reco_flip, axis=1) if remove_model: profile.rename( columns=SCALAR_REMOVE_MODEL, inplace=True) @@ -844,7 +862,7 @@ def _get_fname(self, folder, f_name): os.makedirs(f_folder, exist_ok=True) return op.join(f_folder, f_name) - def _get_profile(self, name, bundle, subject, scalar, repl_nan=True): + def _get_profile(self, name, bundle, subject, scalar): """ Get a single profile, then handle not found / NaNs """ @@ -852,7 +870,7 @@ def _get_profile(self, name, bundle, subject, scalar, repl_nan=True): single_profile = profile[ (profile['subjectID'] == subject) & (profile['tractID'] == bundle) - ][scalar].to_numpy() + ].sort_values("nodeID")[scalar].to_numpy() nans = np.isnan(single_profile) percent_nan = (np.sum(nans) * 100) // self.prof_len if len(single_profile) < 1: @@ -912,7 +930,7 @@ def masked_corr(self, arr): np.isnan(arr[0, ...]), np.isnan(arr[1, ...]))) if np.sum(mask) < 2: - return 0 + return np.nan return np.corrcoef(arr[:, mask])[0][1] def tract_profiles(self, names=None, scalar="dti_fa", @@ -1263,7 +1281,6 @@ def reliability_plots(self, names=None, # extract relevant statistics / data from profiles N = len(self.subjects) - err_prime = 1.95/np.sqrt(N-3) all_sub_coef = np.zeros((len(scalars), len(self.bundles))) all_sub_coef_err = np.zeros((len(scalars), len(self.bundles), 2)) all_sub_means = np.zeros( @@ -1281,19 +1298,30 @@ def reliability_plots(self, names=None, for j, name in enumerate(names): for i, subject in enumerate(self.subjects): single_profile = self._get_profile( - name, bundle, subject, scalar, repl_nan=False) + name, bundle, subject, scalar) if single_profile is None: bundle_profiles[j, i] = np.nan miss_counts.at[bundle, f"miss_count{name}"] =\ - miss_counts.at[bundle, f"miss_count{name}"] + 1 + miss_counts.at[ + bundle, f"miss_count{name}"]+ 1 else: bundle_profiles[j, i] = single_profile all_sub_means[m, k] = np.nanmean(bundle_profiles, axis=2) + N_not_nan = N - np.sum(np.isnan( + all_sub_means[m, k, 0] + all_sub_means[m, k, 1])) all_sub_coef[m, k] = self.masked_corr(all_sub_means[m, k]) - corr_prime = np.arctanh(all_sub_coef[m, k]) # uses Fisher Transformation - all_sub_coef_err[m, k, 0] = np.tanh(corr_prime + err_prime) - all_sub_coef[m, k] - all_sub_coef_err[m, k, 1] = all_sub_coef[m, k] - np.tanh(corr_prime - err_prime) + if np.isnan(all_sub_coef[m, k]).all() or N_not_nan < 4: + raise ValueError(( + f"Not enough non-nan profiles" + f"for scalar {scalar} for bundle {bundle}")) + # use Fisher Transformation + err_prime = 1.95/np.sqrt(N_not_nan-3) + corr_prime = np.arctanh(all_sub_coef[m, k]) + all_sub_coef_err[m, k, 1] = np.tanh(corr_prime + err_prime)\ + - all_sub_coef[m, k] + all_sub_coef_err[m, k, 0] = all_sub_coef[m, k]\ + - np.tanh(corr_prime - err_prime) bundle_coefs = np.zeros(N) for i in range(N): @@ -1307,14 +1335,16 @@ def reliability_plots(self, names=None, all_node_coef[m, k] = node_coefs # plot histograms of subject pearson r's - maxi = all_profile_coef.max() - mini = all_profile_coef.min() + maxi = np.nanmax(all_profile_coef) + mini = np.nanmin(all_profile_coef) bins = np.linspace(mini, maxi, 10) ba = BrainAxes(positions=positions) for k, bundle in enumerate(self.bundles): ax = ba.get_axis(bundle) for m, scalar in enumerate(scalars): bundle_coefs = all_profile_coef[m, k] + bundle_coefs = bundle_coefs[~np.isnan(bundle_coefs)] + print(bundle_coefs) sns.set(style="whitegrid") sns.histplot( data=bundle_coefs, @@ -1632,8 +1662,8 @@ def reliability_plots(self, names=None, def compare_reliability(self, reliability1, reliability2, analysis_label1, analysis_label2, - errors1, errors2, bundles, + errors1=None, errors2=None, scalars=["FA", "MD"], rtype="Subject Reliability", show_plots=False): @@ -1651,15 +1681,17 @@ def compare_reliability(self, reliability1, reliability2, Names of the analyses used to obtain each dataset. Used to label the x and y axes. - errors1, errors2 : numpy arrays - Numpy arrays describing the errors. - Typically, each of this will be outputs of separate calls - to reliability_plots. - bundles : list of str List of bundles that correspond to the second dimension of the reliability arrays. + errors1, errors2 : numpy arrays or None + Numpy arrays describing the errors. + Typically, each of this will be outputs of separate calls + to reliability_plots. + If None, errors are not shown. + Default is None. + scalars : list of str, optional Lsit of scalars that correspond to the first dimension of the reliability arrays. Default: ["FA", "MD"] @@ -1676,18 +1708,10 @@ def compare_reliability(self, reliability1, reliability2, ------- Returns a Matplotlib figure and axes. """ + show_error = ((errors1 is not None) and (errors2 is not None)) fig, ax = plt.subplots() for i, scalar in enumerate(scalars): for j, bundle in enumerate(bundles): - if len(errors1.shape) > 2: - xerr = errors1[:, j, i].reshape((2, 1)) - else: - xerr = errors1[i, j] - if len(errors2.shape) > 2: - yerr = errors2[:, j, i].reshape((2, 1)) - else: - yerr = errors2[i, j] - ax.scatter( reliability1[i, j], reliability2[i, j], @@ -1695,15 +1719,25 @@ def compare_reliability(self, reliability1, reliability2, c=[self.color_dict[bundle]], marker=self.scalar_markers[i] ) - ax.errorbar( - reliability1[i, j], - reliability2[i, j], - xerr=xerr, - yerr=yerr, - c=[self.color_dict[bundle]], - alpha=0.5, - fmt="none" - ) + if show_error: + if len(errors1.shape) > 2: + xerr = errors1[:, j, i].reshape((2, 1)) + else: + xerr = errors1[i, j] + if len(errors2.shape) > 2: + yerr = errors2[:, j, i].reshape((2, 1)) + else: + yerr = errors2[i, j] + + ax.errorbar( + reliability1[i, j], + reliability2[i, j], + xerr=xerr, + yerr=yerr, + c=[self.color_dict[bundle]], + alpha=0.5, + fmt="none" + ) ax.set_xlabel(f"{analysis_label1} {rtype}", fontsize=medium_font)