Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that the straightening.cache file is output to the PWD during template registration #4156

Merged
merged 7 commits into from Jul 13, 2023
7 changes: 4 additions & 3 deletions spinalcordtoolbox/scripts/sct_label_vertebrae.py
Expand Up @@ -307,13 +307,13 @@ def main(argv: Sequence[str]):
cache_sig = cache_signature(
input_files=[fname_in, fname_seg],
)
fname_cache = "straightening.cache"
if (cache_valid(os.path.join(curdir, fname_cache), cache_sig)
if (cache_valid(os.path.join(curdir, "straightening.cache"), cache_sig)
and os.path.isfile(os.path.join(curdir, "warp_curve2straight.nii.gz"))
and os.path.isfile(os.path.join(curdir, "warp_straight2curve.nii.gz"))
and os.path.isfile(os.path.join(curdir, "straight_ref.nii.gz"))):
# if they exist, copy them into current folder
printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
copy(os.path.join(curdir, "straightening.cache"), 'straightening.cache')
copy(os.path.join(curdir, "warp_curve2straight.nii.gz"), 'warp_curve2straight.nii.gz')
copy(os.path.join(curdir, "warp_straight2curve.nii.gz"), 'warp_straight2curve.nii.gz')
copy(os.path.join(curdir, "straight_ref.nii.gz"), 'straight_ref.nii.gz')
Expand All @@ -326,7 +326,7 @@ def main(argv: Sequence[str]):
'-r', str(remove_temp_files),
'-v', '0',
])
cache_save(os.path.join(path_output, fname_cache), cache_sig)
cache_save("straightening.cache", cache_sig)

# resample to 0.5mm isotropic to match template resolution
printv('\nResample to 0.5mm isotropic...', verbose)
Expand Down Expand Up @@ -467,6 +467,7 @@ def main(argv: Sequence[str]):
generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"), fname_seg_labeled)
generate_output_file(os.path.join(path_tmp, "segmentation_labeled_disc.nii"), os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
# copy straightening files in case subsequent SCT functions need them
generate_output_file(os.path.join(path_tmp, "straightening.cache"), os.path.join(path_output, "straightening.cache"), verbose=verbose)
generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose=verbose)
generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose=verbose)
generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose=verbose)
Expand Down
10 changes: 7 additions & 3 deletions spinalcordtoolbox/scripts/sct_register_to_template.py
Expand Up @@ -529,9 +529,12 @@ def main(argv: Sequence[str]):
cache_sig = cache_signature(
input_files=cache_input_files,
)
cachefile = os.path.join(curdir, "straightening.cache")
if cache_valid(cachefile, cache_sig) and os.path.isfile(fn_warp_curve2straight) and os.path.isfile(fn_warp_straight2curve) and os.path.isfile(fn_straight_ref):
if (cache_valid(os.path.join(curdir, "straightening.cache"), cache_sig)
and os.path.isfile(fn_warp_curve2straight)
and os.path.isfile(fn_warp_straight2curve)
and os.path.isfile(fn_straight_ref)):
printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
copy(os.path.join(curdir, "straightening.cache"), 'straightening.cache')
copy(fn_warp_curve2straight, 'warp_curve2straight.nii.gz')
copy(fn_warp_straight2curve, 'warp_straight2curve.nii.gz')
copy(fn_straight_ref, 'straight_ref.nii.gz')
Expand Down Expand Up @@ -560,7 +563,7 @@ def main(argv: Sequence[str]):
sc_straight.discs_ref_filename = ftmp_template_label

sc_straight.straighten()
cache_save(cachefile, cache_sig)
cache_save("straightening.cache", cache_sig)

# N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
# re-define warping field using non-cropped space (to avoid issue #367)
Expand Down Expand Up @@ -801,6 +804,7 @@ def main(argv: Sequence[str]):
generate_output_file(os.path.join(path_tmp, "anat2template.nii.gz"), fname_anat2template, verbose=verbose)
if ref == 'template':
# copy straightening files in case subsequent SCT functions need them
generate_output_file(os.path.join(path_tmp, "straightening.cache"), os.path.join(path_output, "straightening.cache"), verbose=verbose)
generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose=verbose)
generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose=verbose)
generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose=verbose)
Expand Down
5 changes: 5 additions & 0 deletions testing/cli/test_cli_sct_label_vertebrae.py
Expand Up @@ -26,6 +26,11 @@ def test_sct_label_vertebrae_consistent_disc(tmp_path):
assert fp == []
assert fn == []

# Ensure that the straightening files are correctly generated in the output directory
for file in ["straightening.cache", "straight_ref.nii.gz",
"warp_straight2curve.nii.gz", "warp_curve2straight.nii.gz"]:
assert os.path.isfile(tmp_path/file)


@pytest.mark.sct_testing
@pytest.mark.usefixtures("run_in_sct_testing_data_dir")
Expand Down
22 changes: 14 additions & 8 deletions testing/cli/test_cli_sct_register_to_template.py
Expand Up @@ -70,38 +70,44 @@ def test_sct_register_to_template_non_rpi_data(tmp_path, template_lpi):
(os.path.join(__sct_dir__, 'data/PAM50/template/PAM50_cord.nii.gz'),
['-ldisc', 't2/labels.nii.gz', '-ref', 'subject'])
])
def test_sct_register_to_template_dice_coefficient_against_groundtruth(fname_gt, remaining_args):
def test_sct_register_to_template_dice_coefficient_against_groundtruth(fname_gt, remaining_args, tmp_path):
"""Run the CLI script and verify transformed images have expected attributes."""
fname_seg = 't2/t2_seg-manual.nii.gz'
dice_threshold = 0.9
sct_register_to_template.main(argv=['-i', 't2/t2.nii.gz', '-s', fname_seg] + remaining_args)
sct_register_to_template.main(argv=['-i', 't2/t2.nii.gz', '-s', fname_seg, '-ofolder', str(tmp_path)]
+ remaining_args)

# Straightening files are only generated for `-ref template`. They should *not* exist for `-ref subject`.
for file in ["straightening.cache", "straight_ref.nii.gz",
"warp_straight2curve.nii.gz", "warp_curve2straight.nii.gz"]:
assert os.path.isfile(tmp_path/file) == (False if 'subject' in remaining_args else True)

# apply transformation to binary mask: template --> anat
sct_apply_transfo.main(argv=[
'-i', fname_gt,
'-d', fname_seg,
'-w', 'warp_template2anat.nii.gz',
'-o', 'test_template2anat.nii.gz',
'-w', str(tmp_path/'warp_template2anat.nii.gz'),
'-o', str(tmp_path/'test_template2anat.nii.gz'),
'-x', 'nn',
'-v', '0'])

# apply transformation to binary mask: anat --> template
sct_apply_transfo.main(argv=[
'-i', fname_seg,
'-d', fname_gt,
'-w', 'warp_anat2template.nii.gz',
'-o', 'test_anat2template.nii.gz',
'-w', str(tmp_path/'warp_anat2template.nii.gz'),
'-o', str(tmp_path/'test_anat2template.nii.gz'),
'-x', 'nn',
'-v', '0'])

# compute dice coefficient between template segmentation warped to anat and segmentation from anat
im_seg = Image(fname_seg)
im_template_seg_reg = Image('test_template2anat.nii.gz')
im_template_seg_reg = Image(str(tmp_path/'test_template2anat.nii.gz'))
dice_template2anat = compute_dice(im_seg, im_template_seg_reg, mode='3d', zboundaries=True)
assert dice_template2anat > dice_threshold

# compute dice coefficient between anat segmentation warped to template and segmentation from template
im_seg_reg = Image('test_anat2template.nii.gz')
im_seg_reg = Image(str(tmp_path/'test_anat2template.nii.gz'))
im_template_seg = Image(fname_gt)
dice_anat2template = compute_dice(im_seg_reg, im_template_seg, mode='3d', zboundaries=True)
assert dice_anat2template > dice_threshold