diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index eed558918..dec9c6b91 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -509,6 +509,10 @@ def visualize_tract_profiles(tract_profiles, scalar="dti_fa", ylim=None, return df +def display_scalar(scalar_name): + return scalar_name.replace("_", " ").upper() + + class BrainAxes(): ''' Helper class. @@ -970,7 +974,7 @@ def tract_profiles(self, names=None, scalar="dti_fa", profile = profile[profile['tractID'] == bundle] ba.plot_line( bundle, "nodeID", scalar, profile, - scalar, ylim, n_boot, self._alpha(0.6 + 0.2 * i), + display_scalar(scalar), ylim, n_boot, self._alpha(0.6 + 0.2 * i), { "dashes": [(2**i, 2**i)], "hue": "tractID", @@ -1219,11 +1223,14 @@ def reliability_plots(self, names=None, return None # extract relevant statistics / data from profiles + N = len(self.subjects) + err_prime = 1.95/np.sqrt(N-3) all_sub_coef = np.zeros((len(scalars), len(self.bundles))) + all_sub_coef_err = np.zeros((len(scalars), len(self.bundles), 2)) all_sub_means = np.zeros( - (len(scalars), len(self.bundles), 2, len(self.subjects))) + (len(scalars), len(self.bundles), 2, N)) all_profile_coef = \ - np.zeros((len(scalars), len(self.bundles), len(self.subjects))) + np.zeros((len(scalars), len(self.bundles), N)) all_node_coef = np.zeros( (len(scalars), len(self.bundles), self.prof_len)) miss_counts = pd.DataFrame(0, index=self.bundles, columns=[ @@ -1231,7 +1238,7 @@ def reliability_plots(self, names=None, for m, scalar in enumerate(scalars): for k, bundle in enumerate(self.bundles): bundle_profiles =\ - np.zeros((2, len(self.subjects), self.prof_len)) + np.zeros((2, N, self.prof_len)) for j, name in enumerate(names): for i, subject in enumerate(self.subjects): single_profile = self._get_profile( @@ -1245,9 +1252,12 @@ def reliability_plots(self, names=None, all_sub_means[m, k] = np.nanmean(bundle_profiles, axis=2) all_sub_coef[m, k] = self.masked_corr(all_sub_means[m, k]) + corr_prime = np.arctanh(all_sub_coef[m, k]) # uses Fisher Transformation + all_sub_coef_err[m, k, 0] = np.tanh(corr_prime + err_prime) - all_sub_coef[m, k] + all_sub_coef_err[m, k, 1] = all_sub_coef[m, k] - np.tanh(corr_prime - err_prime) - bundle_coefs = np.zeros(len(self.subjects)) - for i in range(len(self.subjects)): + bundle_coefs = np.zeros(N) + for i in range(N): bundle_coefs[i] = \ self.masked_corr(bundle_profiles[:, i, :]) all_profile_coef[m, k] = bundle_coefs @@ -1403,7 +1413,7 @@ def reliability_plots(self, names=None, fig, axes = plt.subplots(2, 1) fig.set_size_inches((8, 8)) bundle_prof_means = np.nanmean(all_profile_coef, axis=2) - bundle_prof_stds = sem(all_profile_coef, axis=2, nan_policy='omit') + bundle_prof_stds = 1.95*sem(all_profile_coef, axis=2, nan_policy='omit') if ylims is None: maxi = np.maximum(bundle_prof_means.max(), all_sub_coef.max()) mini = np.minimum(bundle_prof_means.min(), all_sub_coef.min()) @@ -1428,10 +1438,15 @@ def reliability_plots(self, names=None, all_sub_coef, removal_idx, axis=1) + all_sub_coef_err_removed = np.delete( + all_sub_coef_err, + removal_idx, + axis=1) else: is_removed_bundle = [False]*len(self.bundles) bundle_prof_means_removed = bundle_prof_means bundle_prof_stds_removed = bundle_prof_stds + all_sub_coef_err_removed = all_sub_coef_err all_sub_coef_removed = all_sub_coef df_bundle_prof_means = pd.DataFrame( @@ -1468,6 +1483,7 @@ def reliability_plots(self, names=None, ignore_index=True) sns.set(style="whitegrid") + print(np.transpose([*all_sub_coef_err_removed[m], np.asarray([0, 0])])) sns.barplot( data=df_bundle_prof_means, x='tractID', y='value', hue='scalar', palette=tableau_20_sns[:len(scalars) * 2 + 2:2], @@ -1477,6 +1493,7 @@ def reliability_plots(self, names=None, sns.barplot( data=df_all_sub_coef, x='tractID', y='value', hue='scalar', palette=tableau_20_sns[:len(scalars) * 2 + 2:2], + yerr=np.transpose([*all_sub_coef_err_removed[m], np.asarray([0, 0])]), ax=axes[1]) axes[1].legend_.remove() @@ -1591,9 +1608,17 @@ def compare_reliability(self, reliability_df1, reliability_df2, hue='tractID', palette=self.color_dict, ax=ax) - g.set(ylim=(0, 1)) - g.set(xlim=(0, 1)) - plt.legend(bbox_to_anchor=(1.1, 1),borderaxespad=0) + ax.set_xlabel(f"{analysis_label1} {rtype}", + fontsize=medium_font) + ax.set_ylabel(f"{analysis_label2} {rtype}", + fontsize=medium_font) + ax.tick_params( + axis='x', which='major', labelsize=medium_font) + ax.tick_params( + axis='y', which='major', labelsize=medium_font) + g.set(ylim=(0.2, 1)) + g.set(xlim=(0.2, 1)) + plt.legend(loc="lower left", ncol=7, bbox_to_anchor=(0., -0.8)) ax.plot([[0, 0], [1, 1]], [[0, 0], [1, 1]], '--') return fig, ax