Skip to content

Commit

Permalink
Merge pull request #490 from bloomdt-uw/fix_sphinx-gallery-workdirs
Browse files Browse the repository at this point in the history
adding separate example output directories to avoid name collisions
  • Loading branch information
arokem committed Oct 6, 2020
2 parents fcd3c4d + 89403a9 commit aaa1db4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 29 deletions.
34 changes: 20 additions & 14 deletions examples/plot_recobundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import AFQ.segmentation as seg
import AFQ.api as api

# Target directory for this example's output files
working_dir = "./recobundles"

dpd.fetch_stanford_hardi()

hardi_dir = op.join(fetcher.dipy_home, "stanford_hardi")
Expand All @@ -44,32 +47,33 @@
img = nib.load(hardi_fdata)

print("Calculating DTI...")
if not op.exists('./dti_FA.nii.gz'):
if not op.exists(op.join(working_dir, 'dti_FA.nii.gz')):
dti_params = dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
out_dir='.')
out_dir=working_dir)
else:
dti_params = {'FA': './dti_FA.nii.gz',
'params': './dti_params.nii.gz'}
dti_params = {'FA': op.join(working_dir, 'dti_FA.nii.gz'),
'params': op.join(working_dir, 'dti_params.nii.gz')}

FA_img = nib.load(dti_params['FA'])
FA_data = FA_img.get_fdata()

print("Registering to template...")
MNI_T2_img = afd.read_mni_template()
if not op.exists('mapping.nii.gz'):
if not op.exists(op.join(working_dir, 'mapping.nii.gz')):
import dipy.core.gradients as dpg
gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)
warped_hardi, mapping = reg.syn_register_dwi(hardi_fdata, gtab,
template=MNI_T2_img)
reg.write_mapping(mapping, './mapping.nii.gz')
reg.write_mapping(mapping, op.join(working_dir, 'mapping.nii.gz'))
else:
mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)
mapping = reg.read_mapping(op.join(working_dir, 'mapping.nii.gz'), img,
MNI_T2_img)

bundle_names = ["CST", "UF", "CC_ForcepsMajor", "CC_ForcepsMinor", "OR", "VOF"]
bundles = api.make_bundle_dict(bundle_names=bundle_names, seg_algo="reco80")

print("Tracking...")
if not op.exists('dti_streamlines_reco.trk'):
if not op.exists(op.join(working_dir, 'dti_streamlines_reco.trk')):
seed_roi = np.zeros(img.shape[:-1])
for bundle in bundles:
if bundle != 'whole_brain':
Expand All @@ -86,7 +90,7 @@
np.linalg.inv(MNI_T2_img.affine)))

sft = StatefulTractogram(sl_xform, img, Space.RASMM)
save_tractogram(sft, f'./{bundle}_atlas.trk')
save_tractogram(sft, op.join(working_dir, f'{bundle}_atlas.trk'))

sl_xform = dts.Streamlines(
dtu.transform_tracking_output(sl_xform,
Expand All @@ -98,15 +102,17 @@
sl_as_idx[:, 1],
sl_as_idx[:, 2]] = 1

nib.save(nib.Nifti1Image(seed_roi, img.affine), 'seed_roi.nii.gz')
nib.save(nib.Nifti1Image(seed_roi, img.affine),
op.join(working_dir, 'seed_roi.nii.gz'))
sft = aft.track(dti_params['params'], seed_mask=seed_roi,
directions='det', stop_mask=FA_data,
stop_threshold=0.1)
print(len(sft.streamlines))
save_tractogram(sft, './dti_streamlines_reco.trk',
save_tractogram(sft, op.join(working_dir, 'dti_streamlines_reco.trk'),
bbox_valid_check=False)
else:
sft = load_tractogram('./dti_streamlines_reco.trk', img)
sft = load_tractogram(op.join(working_dir, 'dti_streamlines_reco.trk'),
img)

print("Segmenting fiber groups...")
segmentation = seg.Segmentation(seg_algo='reco80',
Expand All @@ -127,7 +133,7 @@
sft = StatefulTractogram(fiber_groups[kk].streamlines,
img,
Space.RASMM)
save_tractogram(sft, './%s_reco.trk' % kk,
save_tractogram(sft, op.join(working_dir, '%s_reco.trk' % kk),
bbox_valid_check=False)


Expand All @@ -137,7 +143,7 @@

fig, ax = plt.subplots(1)
sft = load_tractogram(
f'./{bundle}_reco.trk',
op.join(working_dir, f'{bundle}_reco.trk'),
img,
to_space=Space.VOX,
bbox_valid_check=False)
Expand Down
34 changes: 19 additions & 15 deletions examples/plot_tract_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import logging
logging.basicConfig(level=logging.INFO)

# Target directory for this example's output files
working_dir = "./tract_profile"

##########################################################################
# Get example data:
Expand All @@ -48,12 +50,12 @@
# -------------------------

print("Calculating DTI...")
if not op.exists('./dti_FA.nii.gz'):
if not op.exists(op.join(working_dir, 'dti_FA.nii.gz')):
dti_params = dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
out_dir='.')
out_dir=working_dir)
else:
dti_params = {'FA': './dti_FA.nii.gz',
'params': './dti_params.nii.gz'}
dti_params = {'FA': op.join(working_dir, 'dti_FA.nii.gz'),
'params': op.join(working_dir, 'dti_params.nii.gz')}

FA_img = nib.load(dti_params['FA'])
FA_data = FA_img.get_fdata()
Expand All @@ -75,14 +77,15 @@
#
print("Registering to template...")
MNI_T2_img = afd.read_mni_template()
if not op.exists('mapping.nii.gz'):
if not op.exists(op.join(working_dir, 'mapping.nii.gz')):
import dipy.core.gradients as dpg
gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)

warped_hardi, mapping = reg.syn_register_dwi(hardi_fdata, gtab)
reg.write_mapping(mapping, './mapping.nii.gz')
reg.write_mapping(mapping, op.join(working_dir, 'mapping.nii.gz'))
else:
mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)
mapping = reg.read_mapping(op.join(working_dir, 'mapping.nii.gz'),
img, MNI_T2_img)


##########################################################################
Expand All @@ -104,7 +107,7 @@
# algorithm. For speed, we seed only within the waypoint ROIs for each bundle.

print("Tracking...")
if not op.exists('dti_streamlines.trk'):
if not op.exists(op.join(working_dir, 'dti_streamlines.trk')):
seed_roi = np.zeros(img.shape[:-1])
for bundle in bundles:
for idx, roi in enumerate(bundles[bundle]['ROIs']):
Expand All @@ -116,17 +119,17 @@
bundle_name=bundle)

nib.save(nib.Nifti1Image(warped_roi.astype(float), img.affine),
f"{bundle}_{idx+1}.nii.gz")
op.join(working_dir, f"{bundle}_{idx+1}.nii.gz"))
# Add voxels that aren't there yet:
seed_roi = np.logical_or(seed_roi, warped_roi)
nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
'seed_roi.nii.gz')
op.join(working_dir, 'seed_roi.nii.gz'))
sft = aft.track(dti_params['params'], seed_mask=seed_roi,
stop_mask=FA_data, stop_threshold=0.1)
save_tractogram(sft, './dti_streamlines.trk',
save_tractogram(sft, op.join(working_dir, 'dti_streamlines.trk'),
bbox_valid_check=False)
else:
sft = load_tractogram('./dti_streamlines.trk', img)
sft = load_tractogram(op.join(working_dir, 'dti_streamlines.trk'), img)

sft.to_vox()

Expand Down Expand Up @@ -167,12 +170,12 @@
print(f"Afer cleaning: {len(new_fibers)} streamlines")

idx_in_global = fiber_groups[bundle]['idx'][idx_in_bundle]
np.save(f'{bundle}_idx.npy', idx_in_global)
np.save(op.join(working_dir, f'{bundle}_idx.npy'), idx_in_global)
sft = StatefulTractogram(new_fibers.streamlines,
img,
Space.VOX)
sft.to_rasmm()
save_tractogram(sft, f'./{bundle}_afq.trk',
save_tractogram(sft, op.join(working_dir, f'{bundle}_afq.trk'),
bbox_valid_check=False)


Expand All @@ -186,7 +189,8 @@

print("Extracting tract profiles...")
for bundle in bundles:
sft = load_tractogram(f'./{bundle}_afq.trk', img, to_space=Space.VOX)
sft = load_tractogram(op.join(working_dir, f'{bundle}_afq.trk'),
img, to_space=Space.VOX)
fig, ax = plt.subplots(1)
weights = gaussian_weights(sft.streamlines)
profile = afq_profile(FA_data, sft.streamlines,
Expand Down

0 comments on commit aaa1db4

Please sign in to comment.