Skip to content

Commit

Permalink
more versatile mAFQ conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
36000 committed Dec 22, 2020
1 parent 54f6fe8 commit e88266a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 16 deletions.
13 changes: 12 additions & 1 deletion AFQ/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,27 @@

BUNDLE_MAT_2_PYTHON = \
{'Right Corticospinal': 'CST_R', 'Left Corticospinal': 'CST_L',
'RightCorticospinal': 'CST_R', 'LeftCorticospinal': 'CST_L',
'Right Uncinate': 'UNC_R', 'Left Uncinate': 'UNC_L',
'RightUncinate': 'UNC_R', 'LeftUncinate': 'UNC_L',
'Left IFOF': 'IFO_L', 'Right IFOF': 'IFO_R',
'LeftIFOF': 'IFO_L', 'RightIFOF': 'IFO_R',
'Right Arcuate': 'ARC_R', 'Left Arcuate': 'ARC_L',
'RightArcuate': 'ARC_R', 'LeftArcuate': 'ARC_L',
'Right Thalamic Radiation': 'ATR_R', 'Left Thalamic Radiation': 'ATR_L',
'RightThalamicRadiation': 'ATR_R', 'LeftThalamicRadiation': 'ATR_L',
'Right Cingulum Cingulate': 'CGC_R', 'Left Cingulum Cingulate': 'CGC_L',
'RightCingulumCingulate': 'CGC_R', 'LeftCingulumCingulate': 'CGC_L',
'Right Cingulum Hippocampus': 'HCC_R',
'Left Cingulum Hippocampus': 'HCC_L',
'RightCingulumHippocampus': 'HCC_R',
'LeftCingulumHippocampus': 'HCC_L',
'Callosum Forceps Major': 'FP', 'Callosum Forceps Minor': 'FA',
'CallosumForcepsMajor': 'FP', 'CallosumForcepsMinor': 'FA',
'Right ILF': 'ILF_R', 'Left ILF': 'ILF_L',
'Right SLF': 'SLF_R', 'Left SLF': 'SLF_L'}
'RightILF': 'ILF_R', 'LeftILF': 'ILF_L',
'Right SLF': 'SLF_R', 'Left SLF': 'SLF_L',
'RightSLF': 'SLF_R', 'LeftSLF': 'SLF_L'}

afq_home = op.join(op.expanduser('~'), 'AFQ_data')

Expand Down
54 changes: 39 additions & 15 deletions AFQ/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def visualize_tract_profiles(tract_profiles, scalar="dti_fa", ylim=None,
None,
[tract_profiles],
["my_tract_profiles"],
remove_model=False,
scalar_bounds={'lb': {}, 'ub': {}})

df = csv_comparison.tract_profiles(
Expand Down Expand Up @@ -933,7 +934,7 @@ def masked_corr(self, arr):
return np.nan
return np.corrcoef(arr[:, mask])[0][1]

def tract_profiles(self, names=None, scalar="dti_fa",
def tract_profiles(self, names=None, scalar="FA",
ylim=[0.0, 1.0],
show_plots=False,
positions=POSITIONS,
Expand All @@ -953,7 +954,7 @@ def tract_profiles(self, names=None, scalar="dti_fa",
Default: None
scalar : string, optional
Scalar to use in plots. Default: "dti_fa".
Scalar to use in plots. Default: "FA".
ylim : list of 2 floats, optional
Minimum and maximum value used for y-axis bounds.
Expand Down Expand Up @@ -1060,7 +1061,7 @@ def _contrast_index_df_maker(self, bundles, names, scalar):
ignore_index=True)
return ci_df

def contrast_index(self, names=None, scalar="dti_fa",
def contrast_index(self, names=None, scalar="FA",
show_plots=False, n_boot=1000,
show_legend=False,
positions=POSITIONS, plot_subject_lines=True):
Expand All @@ -1076,7 +1077,7 @@ def contrast_index(self, names=None, scalar="dti_fa",
Default: None
scalar : string, optional
Scalar to use for the contrast index. Default: "dti_fa".
Scalar to use for the contrast index. Default: "FA".
show_plots : bool, optional
Whether to show plots if in an interactive environment.
Expand Down Expand Up @@ -1135,7 +1136,7 @@ def contrast_index(self, names=None, scalar="dti_fa",
plt.ion()
return ci_all_df

def lateral_contrast_index(self, name, scalar="dti_fa",
def lateral_contrast_index(self, name, scalar="FA",
show_plots=False, n_boot=1000,
positions=POSITIONS, plot_subject_lines=True):
"""
Expand All @@ -1148,7 +1149,7 @@ def lateral_contrast_index(self, name, scalar="dti_fa",
Names of dataset to plot profiles of.
scalar : string, optional
Scalar to use for the contrast index. Default: "dti_fa".
Scalar to use for the contrast index. Default: "FA".
show_plots : bool, optional
Whether to show plots if in an interactive environment.
Expand Down Expand Up @@ -1207,7 +1208,7 @@ def lateral_contrast_index(self, name, scalar="dti_fa",
plt.ion()

def reliability_plots(self, names=None,
scalars=["dti_fa", "dti_md"],
scalars=["FA", "MD"],
ylims=[0.0, 1.0],
show_plots=False,
only_plot_above_thr=None,
Expand All @@ -1226,7 +1227,7 @@ def reliability_plots(self, names=None,
Default: None
scalars : list of strings, optional
Scalars to correlate. Default: ["dti_fa", "dti_md"].
Scalars to correlate. Default: ["FA", "MD"].
ylims : 2-tuple of floats, optional
Limits of the y-axis. Useful to synchronize axes across graphs.
Expand Down Expand Up @@ -1665,7 +1666,8 @@ def compare_reliability(self, reliability1, reliability2,
errors1=None, errors2=None,
scalars=["FA", "MD"],
rtype="Subject Reliability",
show_plots=False):
show_plots=False,
show_legend=True):
"""
Plot a comparison of scan-rescan reliability between two analyses.
Expand Down Expand Up @@ -1703,17 +1705,22 @@ def compare_reliability(self, reliability1, reliability2,
Whether to show plots if in an interactive environment.
Default: False
show_legend : bool, optional
Show legend for the plot, off to the right hand side.
Default: True
Returns
-------
Returns a Matplotlib figure and axes.
"""
show_error = ((errors1 is not None) and (errors2 is not None))
fig, ax = plt.subplots()
legend_labels = []
for i, scalar in enumerate(scalars):
marker = self.scalar_markers[i]
if marker == "x":
marker = marker.upper()
for j, bundle in enumerate(bundles):
marker = self.scalar_markers[i]
if marker == "x":
marker = marker.upper()
ax.scatter(
reliability1[i, j],
reliability2[i, j],
Expand All @@ -1740,6 +1747,17 @@ def compare_reliability(self, reliability1, reliability2,
alpha=0.5,
fmt="none"
)
if i == 0:
legend_labels.append(Patch(
facecolor=self.color_dict[bundle],
label=bundle))
legend_labels.append(Line2D(
[0], [0],
marker=marker,
color='k',
lw=0,
markersize=10,
label=scalar))

ax.set_xlabel(f"{analysis_label1} {rtype}",
fontsize=medium_font)
Expand All @@ -1751,9 +1769,15 @@ def compare_reliability(self, reliability1, reliability2,
axis='y', which='major', labelsize=medium_font)
ax.set_ylim(0.2, 1)
ax.set_xlim(0.2, 1)
plt.legend(loc="lower left", ncol=7, bbox_to_anchor=(0., -0.8))
ax.plot([[0, 0], [1, 1]], [[0, 0], [1, 1]], '--')

ax.plot([[0, 0], [1, 1]], [[0, 0], [1, 1]], '--', color='red')
legend_labels.append(Line2D(
[0], [0], linewidth=3, linestyle='--', color='red', label='X=Y'))
if show_legend:
fig.legend(
handles=legend_labels,
fontsize=small_font - 6,
bbox_to_anchor=(1.5, 2.0))
fig.tight_layout()
return fig, ax


Expand Down

0 comments on commit e88266a

Please sign in to comment.