Skip to content

Commit

Permalink
add fishers transformation; tweak more font sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
36000 committed Nov 6, 2020
1 parent 15fcf62 commit d10522f
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions AFQ/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,10 @@ def visualize_tract_profiles(tract_profiles, scalar="dti_fa", ylim=None,
return df


def display_scalar(scalar_name):
return scalar_name.replace("_", " ").upper()


class BrainAxes():
'''
Helper class.
Expand Down Expand Up @@ -970,7 +974,7 @@ def tract_profiles(self, names=None, scalar="dti_fa",
profile = profile[profile['tractID'] == bundle]
ba.plot_line(
bundle, "nodeID", scalar, profile,
scalar, ylim, n_boot, self._alpha(0.6 + 0.2 * i),
display_scalar(scalar), ylim, n_boot, self._alpha(0.6 + 0.2 * i),
{
"dashes": [(2**i, 2**i)],
"hue": "tractID",
Expand Down Expand Up @@ -1219,19 +1223,22 @@ def reliability_plots(self, names=None,
return None

# extract relevant statistics / data from profiles
N = len(self.subjects)
err_prime = 1.95/np.sqrt(N-3)
all_sub_coef = np.zeros((len(scalars), len(self.bundles)))
all_sub_coef_err = np.zeros((len(scalars), len(self.bundles), 2))
all_sub_means = np.zeros(
(len(scalars), len(self.bundles), 2, len(self.subjects)))
(len(scalars), len(self.bundles), 2, N))
all_profile_coef = \
np.zeros((len(scalars), len(self.bundles), len(self.subjects)))
np.zeros((len(scalars), len(self.bundles), N))
all_node_coef = np.zeros(
(len(scalars), len(self.bundles), self.prof_len))
miss_counts = pd.DataFrame(0, index=self.bundles, columns=[
f"miss_count{names[0]}", f"miss_count{names[1]}"])
for m, scalar in enumerate(scalars):
for k, bundle in enumerate(self.bundles):
bundle_profiles =\
np.zeros((2, len(self.subjects), self.prof_len))
np.zeros((2, N, self.prof_len))
for j, name in enumerate(names):
for i, subject in enumerate(self.subjects):
single_profile = self._get_profile(
Expand All @@ -1245,9 +1252,12 @@ def reliability_plots(self, names=None,

all_sub_means[m, k] = np.nanmean(bundle_profiles, axis=2)
all_sub_coef[m, k] = self.masked_corr(all_sub_means[m, k])
corr_prime = np.arctanh(all_sub_coef[m, k]) # uses Fisher Transformation
all_sub_coef_err[m, k, 0] = np.tanh(corr_prime + err_prime) - all_sub_coef[m, k]
all_sub_coef_err[m, k, 1] = all_sub_coef[m, k] - np.tanh(corr_prime - err_prime)

bundle_coefs = np.zeros(len(self.subjects))
for i in range(len(self.subjects)):
bundle_coefs = np.zeros(N)
for i in range(N):
bundle_coefs[i] = \
self.masked_corr(bundle_profiles[:, i, :])
all_profile_coef[m, k] = bundle_coefs
Expand Down Expand Up @@ -1403,7 +1413,7 @@ def reliability_plots(self, names=None,
fig, axes = plt.subplots(2, 1)
fig.set_size_inches((8, 8))
bundle_prof_means = np.nanmean(all_profile_coef, axis=2)
bundle_prof_stds = sem(all_profile_coef, axis=2, nan_policy='omit')
bundle_prof_stds = 1.95*sem(all_profile_coef, axis=2, nan_policy='omit')
if ylims is None:
maxi = np.maximum(bundle_prof_means.max(), all_sub_coef.max())
mini = np.minimum(bundle_prof_means.min(), all_sub_coef.min())
Expand All @@ -1428,10 +1438,15 @@ def reliability_plots(self, names=None,
all_sub_coef,
removal_idx,
axis=1)
all_sub_coef_err_removed = np.delete(
all_sub_coef_err,
removal_idx,
axis=1)
else:
is_removed_bundle = [False]*len(self.bundles)
bundle_prof_means_removed = bundle_prof_means
bundle_prof_stds_removed = bundle_prof_stds
all_sub_coef_err_removed = all_sub_coef_err
all_sub_coef_removed = all_sub_coef

df_bundle_prof_means = pd.DataFrame(
Expand Down Expand Up @@ -1468,6 +1483,7 @@ def reliability_plots(self, names=None,
ignore_index=True)

sns.set(style="whitegrid")
print(np.transpose([*all_sub_coef_err_removed[m], np.asarray([0, 0])]))
sns.barplot(
data=df_bundle_prof_means, x='tractID', y='value', hue='scalar',
palette=tableau_20_sns[:len(scalars) * 2 + 2:2],
Expand All @@ -1477,6 +1493,7 @@ def reliability_plots(self, names=None,
sns.barplot(
data=df_all_sub_coef, x='tractID', y='value', hue='scalar',
palette=tableau_20_sns[:len(scalars) * 2 + 2:2],
yerr=np.transpose([*all_sub_coef_err_removed[m], np.asarray([0, 0])]),
ax=axes[1])
axes[1].legend_.remove()

Expand Down Expand Up @@ -1591,9 +1608,17 @@ def compare_reliability(self, reliability_df1, reliability_df2,
hue='tractID',
palette=self.color_dict,
ax=ax)
g.set(ylim=(0, 1))
g.set(xlim=(0, 1))
plt.legend(bbox_to_anchor=(1.1, 1),borderaxespad=0)
ax.set_xlabel(f"{analysis_label1} {rtype}",
fontsize=medium_font)
ax.set_ylabel(f"{analysis_label2} {rtype}",
fontsize=medium_font)
ax.tick_params(
axis='x', which='major', labelsize=medium_font)
ax.tick_params(
axis='y', which='major', labelsize=medium_font)
g.set(ylim=(0.2, 1))
g.set(xlim=(0.2, 1))
plt.legend(loc="lower left", ncol=7, bbox_to_anchor=(0., -0.8))
ax.plot([[0, 0], [1, 1]], [[0, 0], [1, 1]], '--')

return fig, ax
Expand Down

0 comments on commit d10522f

Please sign in to comment.