diff --git a/spinalcordtoolbox/aggregate_slicewise.py b/spinalcordtoolbox/aggregate_slicewise.py index dd4d4c6904..ff64caf530 100644 --- a/spinalcordtoolbox/aggregate_slicewise.py +++ b/spinalcordtoolbox/aggregate_slicewise.py @@ -232,8 +232,8 @@ def func_wa(data, mask=None, map_clusters=None): return np.average(data, weights=mask), None -def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], perslice=None, perlevel=False, - vert_level=None, group_funcs=(('MEAN', func_wa),), map_clusters=None): +def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], distance_pmj=None, perslice=None, + perlevel=False, vert_level=None, group_funcs=(('MEAN', func_wa),), map_clusters=None): """ The aggregation will be performed along the last dimension of 'metric' ndarray. @@ -241,6 +241,7 @@ def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], persli :param mask: Class Metric(): mask to use for aggregating the data. Optional. :param slices: List[int]: Slices to aggregate metric from. If empty, select all slices. :param levels: List[int]: Vertebral levels to aggregate metric from. It has priority over "slices". + :param distance_pmj: float: Distance from Ponto-Medullary Junction (PMJ) in mm. :param Bool perslice: Aggregate per slice (True) or across slices (False) :param Bool perlevel: Aggregate per level (True) or across levels (False). Has priority over "perslice". :param vert_level: Vertebral level. Could be either an Image or a file name. @@ -273,6 +274,7 @@ def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], persli slices = range(metric.data.shape[ndim-1]) # aggregation based on levels + vertgroups = None if levels: im_vert_level = Image(vert_level).change_orientation('RPI') # slicegroups = [(0, 1, 2), (3, 4, 5), (6, 7, 8)] @@ -293,7 +295,6 @@ def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], persli vertgroups = [tuple([level for level in levels])] # aggregation based on slices else: - vertgroups = None if perslice: # slicegroups = [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)] slicegroups = [tuple([slice]) for slice in slices] @@ -301,9 +302,13 @@ def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], persli # slicegroups = [(0, 1, 2, 3, 4, 5, 6, 7, 8)] slicegroups = [tuple(slices)] agg_metric = dict((slicegroup, dict()) for slicegroup in slicegroups) - # loop across slice group for slicegroup in slicegroups: + # add distance from PMJ info + if distance_pmj is not None: + agg_metric[slicegroup]['DistancePMJ'] = [distance_pmj] + else: + agg_metric[slicegroup]['DistancePMJ'] = None # add level info if vertgroups is None: agg_metric[slicegroup]['VertLevel'] = None @@ -527,7 +532,7 @@ def save_as_csv(agg_metric, fname_out, fname_in=None, append=False): if not append or not os.path.isfile(fname_out): with open(fname_out, 'w') as csvfile: # spamwriter = csv.writer(csvfile, delimiter=',') - header = ['Timestamp', 'SCT Version', 'Filename', 'Slice (I->S)', 'VertLevel'] + header = ['Timestamp', 'SCT Version', 'Filename', 'Slice (I->S)', 'VertLevel', 'DistancePMJ'] agg_metric_key = [v for i, (k, v) in enumerate(agg_metric.items())][0] for item in list_item: for key in agg_metric_key: @@ -547,6 +552,7 @@ def save_as_csv(agg_metric, fname_out, fname_in=None, append=False): line.append(fname_in) # file name associated with the results line.append(parse_num_list_inv(slicegroup)) # list all slices in slicegroup line.append(parse_num_list_inv(agg_metric[slicegroup]['VertLevel'])) # list vertebral levels + line.append(parse_num_list_inv(agg_metric[slicegroup]['DistancePMJ'])) # list distance from PMJ agg_metric_key = [v for i, (k, v) in enumerate(agg_metric.items())][0] for item in list_item: for key in agg_metric_key: diff --git a/spinalcordtoolbox/csa_pmj.py b/spinalcordtoolbox/csa_pmj.py new file mode 100644 index 0000000000..7a854d025c --- /dev/null +++ b/spinalcordtoolbox/csa_pmj.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 +# Functions to get distance from PMJ for processing segmentation data +# Author: Sandrine Bédard +import logging + +import numpy as np +from spinalcordtoolbox.image import Image +from spinalcordtoolbox.centerline.core import get_centerline + +logger = logging.getLogger(__name__) + +NEAR_ZERO_THRESHOLD = 1e-6 + + +def get_slices_for_pmj_distance(segmentation, pmj, distance, extent, param_centerline=None, verbose=1): + """ + Interpolate centerline with pontomedullary junction (PMJ) label and compute distance from PMJ along the centerline. + Generate a mask from segmentation of the slices used to process segmentation data corresponding to a distance from PMJ. + :param segmentation: input segmentation. Could be either an Image or a file name. + :param pmj: label of PMJ. + :param distance: float: Distance from PMJ in mm. + :param extent: extent of the coverage mask in mm. + :param param_centerline: see centerline.core.ParamCenterline() + :param verbose: + :return im_ctl: + :return mask: + :return slices: + + """ + im_seg = Image(segmentation).change_orientation('RPI') + native_orientation = im_seg.orientation + im_seg.change_orientation('RPI') + im_pmj = Image(pmj).change_orientation('RPI') + if not im_seg.data.shape == im_pmj.data.shape: + raise RuntimeError("segmentation and pmj should be in the same space coordinate.") + # Add PMJ label to the segmentation and then extrapolate to obtain a Centerline object defined between the PMJ + # and the lower end of the centerline. + im_seg_with_pmj = im_seg.copy() + im_seg_with_pmj.data = im_seg_with_pmj.data + im_pmj.data + + # Get max and min index of the segmentation with pmj + _, _, Z = (im_seg_with_pmj.data > NEAR_ZERO_THRESHOLD).nonzero() + min_z_index, max_z_index = min(Z), max(Z) + + from spinalcordtoolbox.straightening import _get_centerline + # Linear interpolation (vs. bspline) ensures strong robustness towards defective segmentations at the top slices. + param_centerline.algo_fitting = 'linear' + # On top of the linear interpolation we add some smoothing to remove discontinuities. + param_centerline.smooth = 50 + param_centerline.minmax = True + # Compute spinalcordtoolbox.types.Centerline class + ctl_seg_with_pmj = _get_centerline(im_seg_with_pmj, param_centerline, verbose=verbose) + # Also get the image centerline (because it is a required output) + # TODO: merge _get_centerline into get_centerline + im_ctl_seg_with_pmj, arr_ctl, _, _ = get_centerline(im_seg_with_pmj, param_centerline, verbose=verbose) + # Compute the incremental distance from the PMJ along each point in the centerline + length_from_pmj = ctl_seg_with_pmj.incremental_length_inverse[::-1] + # From this incremental distance, find the indices corresponding to the requested distance +/- extent/2 from the PMJ + # Get the z index corresponding to the segmentation since the centerline only includes slices of the segmentation. + z_ref = np.array(range(min_z_index.astype(int), max_z_index.max().astype(int) + 1)) + zmin = z_ref[np.argmin(np.array([np.abs(i - distance - extent/2) for i in length_from_pmj]))] + zmax = z_ref[np.argmin(np.array([np.abs(i - distance + extent/2) for i in length_from_pmj]))] + + # Check if distance is out of bounds + if distance > length_from_pmj[0]: + raise ValueError("Input distance of " + str(distance) + " mm is out of bounds for maximum distance of " + str(length_from_pmj[0]) + " mm") + + if distance < length_from_pmj[-1]: + raise ValueError("Input distance of " + str(distance) + " mm is out of bounds for minimum distance of " + str(length_from_pmj[-1]) + " mm") + + # Check if the range of selected slices are covered by the segmentation + if not all(np.any(im_seg.data[:, :, z]) for z in range(zmin, zmax)): + raise ValueError(f"The requested distances from the PMJ are not fully covered by the segmentation.\n" + f"The range of slices are: [{zmin}, {zmax}]") + + # Create mask from segmentation centered on distance from PMJ and with extent length on z axis. + mask = im_seg.copy() + mask.data[:, :, 0:zmin] = 0 + mask.data[:, :, zmax:] = 0 + mask.change_orientation(native_orientation) + + # Get corresponding slices + slices = "{}:{}".format(zmin, zmax-1) # -1 since the last slice is included to compute CSA after. + + return im_ctl_seg_with_pmj.change_orientation(native_orientation), mask, slices, arr_ctl diff --git a/spinalcordtoolbox/reports/qc.py b/spinalcordtoolbox/reports/qc.py index bab15878da..abef972a75 100644 --- a/spinalcordtoolbox/reports/qc.py +++ b/spinalcordtoolbox/reports/qc.py @@ -62,7 +62,7 @@ class QcImage(object): "#7d0434", "#fb1849", "#14aab4", "#a22abd", "#d58240", "#ac2aff"] _seg_colormap = ["#4d0000", "#ff0000"] - + _ctl_colormap = ["#ff000099", '#ffff00'] def __init__(self, qc_report, interpolation, action_list, process, stretch_contrast=True, stretch_contrast_method='contrast_stretching', angle_line=None, fps=None): @@ -264,6 +264,19 @@ def grid(self, mask, ax): ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) + def smooth_centerline(self, mask, ax): + """Display smoothed centerline""" + mask = mask/mask.max() + mask[mask < 0.05] = 0 # Apply 0.5 threshold + img = np.ma.masked_equal(mask, 0) + ax.imshow(img, + cmap=color.LinearSegmentedColormap.from_list("", self._ctl_colormap), + norm=color.Normalize(vmin=0, vmax=1), + interpolation=self.interpolation, + aspect=float(self.aspect_mask)) + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + # def colorbar(self): # fig = plt.figure(figsize=(9, 1.5)) # ax = fig.add_axes([0.05, 0.80, 0.9, 0.15]) @@ -294,10 +307,9 @@ def wrapped_f(sct_slice, *args): [images_after_moco, images_before_moco], centermass = func(sct_slice, *args) self._centermass = centermass self._make_QC_image_for_4d_volumes(images_after_moco, images_before_moco) - else: if self._angle_line is None: - img, mask = func(sct_slice, *args) + img, *mask = func(sct_slice, *args) else: [img, mask], centermass = func(sct_slice, *args) self._centermass = centermass @@ -311,8 +323,8 @@ def _make_QC_image_for_3d_volumes(self, img, mask, slice_orientation): Create overlay and background images for all processes that deal with 3d volumes (all except sct_fmri_moco and sct_dmri_moco) - :param img: list of mosaic images after motion correction - :param mask: list of mosaic images before motion correction + :param img: The base image to display underneath the overlays (typically anatomical) + :param mask: A list of images to be processed and overlaid on top of `img` :return: """ @@ -337,17 +349,17 @@ def _make_QC_image_for_3d_volumes(self, img, mask, slice_orientation): logger.info(self.qc_report.qc_params.abs_bkg_img_path()) self._save(fig, self.qc_report.qc_params.abs_bkg_img_path(), dpi=self.qc_report.qc_params.dpi) - for action in self.action_list: + fig = Figure() + fig.set_size_inches(size_fig[0], size_fig[1], forward=True) + FigureCanvas(fig) + for i, action in enumerate(self.action_list): logger.debug('Action List %s', action.__name__) if self._stretch_contrast and action.__name__ in ("no_seg_seg",): - print("Mask type %s" % mask.dtype) - mask = self._func_stretch_contrast(mask) - fig = Figure() - fig.set_size_inches(size_fig[0], size_fig[1], forward=True) - FigureCanvas(fig) - ax = fig.add_axes((0, 0, 1, 1)) - action(self, mask, ax) - self._save(fig, self.qc_report.qc_params.abs_overlay_img_path(), dpi=self.qc_report.qc_params.dpi) + print("Mask type %s" % mask[i].dtype) + mask[i] = self._func_stretch_contrast(mask[i]) + ax = fig.add_axes((0, 0, 1, 1), label=str(i)) + action(self, mask[i], ax) + self._save(fig, self.qc_report.qc_params.abs_overlay_img_path(), dpi=self.qc_report.qc_params.dpi) self.qc_report.update_description_file(img.shape) @@ -836,7 +848,15 @@ def qcslice_layout(x): return x.single() def qcslice_layout(x): return x.single() # Metric outputs (only graphs) elif process in ['sct_process_segmentation']: - assert os.path.isfile(path_img) + plane = 'Sagittal' + dpi = 100 # bigger picture is needed for this special case, hence reduce dpi + fname_list = [fname_in1] + # fname_seg should be a list of 4 images: 3 for each of the `qcslice_operations`, plus an extra + # centerline image, which is needed to make `Sagittal.get_center_spit` work correctly + fname_list.extend(fname_seg) + qcslice_type = qcslice.Sagittal([Image(fname) for fname in fname_list], p_resample=None) + qcslice_operations = [QcImage.smooth_centerline, QcImage.highlight_pmj, QcImage.listed_seg] + def qcslice_layout(x): return x.single() else: raise ValueError("Unrecognized process: {}".format(process)) @@ -855,7 +875,7 @@ def qcslice_layout(x): return x.single() qcslice_layout=qcslice_layout, stretch_contrast_method='equalized', angle_line=angle_line, - fps=fps, + fps=fps ) diff --git a/spinalcordtoolbox/scripts/sct_process_segmentation.py b/spinalcordtoolbox/scripts/sct_process_segmentation.py index ca2b3abd4d..1c7725a664 100755 --- a/spinalcordtoolbox/scripts/sct_process_segmentation.py +++ b/spinalcordtoolbox/scripts/sct_process_segmentation.py @@ -18,6 +18,7 @@ import sys import os +import logging import numpy as np from matplotlib.ticker import MaxNLocator @@ -25,12 +26,17 @@ from spinalcordtoolbox.aggregate_slicewise import aggregate_per_slice_or_level, save_as_csv, func_wa, func_std, \ func_sum, merge_dict from spinalcordtoolbox.process_seg import compute_shape +from spinalcordtoolbox.scripts import sct_maths +from spinalcordtoolbox.csa_pmj import get_slices_for_pmj_distance from spinalcordtoolbox.centerline.core import ParamCenterline +from spinalcordtoolbox.image import add_suffix, splitext from spinalcordtoolbox.reports.qc import generate_qc from spinalcordtoolbox.utils.shell import SCTArgumentParser, Metavar, ActionCreateFolder, parse_num_list, display_open from spinalcordtoolbox.utils.sys import init_sct, set_loglevel from spinalcordtoolbox.utils.fs import get_absolute_path +logger = logging.getLogger(__name__) + def get_parser(): """ @@ -54,6 +60,16 @@ def get_parser(): "metric is interesting for detecting non-convex shape (e.g., in case of strong compression)\n" " - length: Length of the segmentation, computed by summing the slice thickness (corrected for the " "centerline angle at each slice) across the specified superior-inferior region.\n" + "\n" + "To select the region to compute metrics over, choose one of the following arguments:\n" + " 1. '-z': Select axial slices based on slice index.\n" + " 2. '-pmj' + '-pmj-distance' + '-pmj-extent': Select axial slices based on distance from pontomedullary " + "junction.\n" + " (For options 1 and 2, you can also add '-perslice' to compute metrics for each axial slice, rather " + "than averaging.)\n" + " 3. '-vert' + '-vertfile': Select a region based on vertebral labels instead of individual slices.\n" + " (For option 3, you can also add '-perlevel' to compute metrics for each vertebral level, rather " + "than averaging.)" ) ) @@ -90,7 +106,7 @@ def get_parser(): '-z', metavar=Metavar.str, type=str, - help="Slice range to compute the metrics across (requires '-p csa'). Example: 5:23" + help="Slice range to compute the metrics across. Example: 5:23" ) optional.add_argument( '-perslice', @@ -111,7 +127,7 @@ def get_parser(): '-vertfile', metavar=Metavar.str, default='./label/template/PAM50_levels.nii.gz', - help="R|Vertebral labeling file. Only use with flag -vert.\n" + help="R|Vertebral labeling file. Only use with flag -vert.\n" "The input and the vertebral labelling file must in the same voxel coordinate system " "and must match the dimensions between each other. " ) @@ -156,11 +172,40 @@ def get_parser(): default=30, help="Degree of smoothing for centerline fitting. Only use with -centerline-algo {bspline, linear}." ) + optional.add_argument( + '-pmj', + metavar=Metavar.file, + help="Ponto-Medullary Junction (PMJ) label file. " + "Example: pmj.nii.gz" + ) + optional.add_argument( + '-pmj-distance', + type=float, + metavar=Metavar.float, + help="Distance (mm) from Ponto-Medullary Junction (PMJ) to the center of the mask used to compute morphometric " + "measures. (To be used with flag '-pmj'.)" + ) + optional.add_argument( + '-pmj-extent', + type=float, + metavar=Metavar.float, + default=20, + help="Extent (in mm) for the mask used to compute morphometric measures. Each slice covered by the mask is " + "included in the calculation. (To be used with flag '-pmj' and '-pmj-distance'.)" + ) optional.add_argument( '-qc', metavar=Metavar.folder, action=ActionCreateFolder, help="The path where the quality control generated content will be saved." + " The QC report is only available for PMJ-based CSA (with flag '-pmj')." + ) + optional.add_argument( + '-qc-image', + metavar=Metavar.str, + help="Input image to display in QC report. Typically, it would be the " + "source anatomical image used to generate the spinal cord " + "segmentation. This flag is mandatory if using flag '-qc'." ) optional.add_argument( '-qc-dataset', @@ -307,10 +352,24 @@ def main(argv=None): algo_fitting=arguments.centerline_algo, smooth=arguments.centerline_smooth, minmax=True) + if arguments.pmj is not None: + fname_pmj = get_absolute_path(arguments.pmj) + else: + fname_pmj = None + if arguments.pmj_distance is not None: + distance_pmj = arguments.pmj_distance + else: + distance_pmj = None + extent_mask = arguments.pmj_extent path_qc = arguments.qc qc_dataset = arguments.qc_dataset qc_subject = arguments.qc_subject + mutually_inclusive_args = (fname_pmj, distance_pmj) + is_pmj_none, is_distance_none = [arg is None for arg in mutually_inclusive_args] + if not (is_pmj_none == is_distance_none): + raise parser.error("Both '-pmj' and '-pmj-distance' are required in order to process segmentation from PMJ.") + # update fields metrics_agg = {} if not file_out: @@ -320,27 +379,69 @@ def main(argv=None): angle_correction=angle_correction, param_centerline=param_centerline, verbose=verbose) + if fname_pmj is not None: + im_ctl, mask, slices, centerline = get_slices_for_pmj_distance(fname_segmentation, fname_pmj, + distance_pmj, extent_mask, + param_centerline=param_centerline, + verbose=verbose) + + # Save array of the centerline in a .csv file if verbose == 2 + if verbose == 2: + fname_ctl_csv, _ = splitext(add_suffix(arguments.i, '_centerline_extrapolated')) + np.savetxt(fname_ctl_csv + '.csv', centerline, delimiter=",") + for key in metrics: if key == 'length': # For computing cord length, slice-wise length needs to be summed across slices metrics_agg[key] = aggregate_per_slice_or_level(metrics[key], slices=parse_num_list(slices), - levels=parse_num_list(vert_levels), perslice=perslice, + levels=parse_num_list(vert_levels), + distance_pmj=distance_pmj, perslice=perslice, perlevel=perlevel, vert_level=fname_vert_levels, group_funcs=(('SUM', func_sum),)) else: # For other metrics, we compute the average and standard deviation across slices metrics_agg[key] = aggregate_per_slice_or_level(metrics[key], slices=parse_num_list(slices), - levels=parse_num_list(vert_levels), perslice=perslice, + levels=parse_num_list(vert_levels), + distance_pmj=distance_pmj, perslice=perslice, perlevel=perlevel, vert_level=fname_vert_levels, group_funcs=group_funcs) metrics_agg_merged = merge_dict(metrics_agg) save_as_csv(metrics_agg_merged, file_out, fname_in=fname_segmentation, append=append) - - # QC report (only show CSA for clarity) + # QC report (only for PMJ-based CSA) if path_qc is not None: - generate_qc(fname_segmentation, args=arguments, path_qc=os.path.abspath(path_qc), dataset=qc_dataset, - subject=qc_subject, path_img=_make_figure(metrics_agg_merged, fit_results), - process='sct_process_segmentation') + if fname_pmj is not None: + if arguments.qc_image is not None: + fname_mask_out = add_suffix(arguments.i, '_mask_csa') + fname_ctl = add_suffix(arguments.i, '_centerline_extrapolated') + fname_ctl_smooth = add_suffix(fname_ctl, '_smooth') + if verbose != 2: + from spinalcordtoolbox.utils.fs import tmp_create + path_tmp = tmp_create() + fname_mask_out = os.path.join(path_tmp, fname_mask_out) + fname_ctl = os.path.join(path_tmp, fname_ctl) + fname_ctl_smooth = os.path.join(path_tmp, fname_ctl_smooth) + # Save mask + mask.save(fname_mask_out) + # Save extrapolated centerline + im_ctl.save(fname_ctl) + # Generated centerline smoothed in RL direction for visualization (and QC report) + sct_maths.main(['-i', fname_ctl, '-smooth', '10,1,1', '-o', fname_ctl_smooth]) + + generate_qc(fname_in1=get_absolute_path(arguments.qc_image), + # NB: For this QC figure, the centerline has to be first in the list in order for the centerline + # to be properly layered underneath the PMJ + mask. However, Sagittal.get_center_spit + # is called during QC, and it uses `fname_seg[-1]` to center the slices. `fname_mask_out` + # doesn't work for this, so we have to repeat `fname_ctl_smooth` at the end of the list. + fname_seg=[fname_ctl_smooth, fname_pmj, fname_mask_out, fname_ctl_smooth], + args=sys.argv[1:], + path_qc=os.path.abspath(path_qc), + dataset=qc_dataset, + subject=qc_subject, + process='sct_process_segmentation') + else: + raise parser.error('-qc-image is required to display QC report.') + else: + logger.warning('QC report only available for PMJ-based CSA. QC report not generated.') display_open(file_out) diff --git a/testing/api/test_aggregate_slicewise.py b/testing/api/test_aggregate_slicewise.py index 31fb239b34..0d0e881609 100644 --- a/testing/api/test_aggregate_slicewise.py +++ b/testing/api/test_aggregate_slicewise.py @@ -125,7 +125,7 @@ def test_aggregate_across_levels(dummy_metrics, dummy_vert_level): perslice=False, perlevel=False, vert_level=dummy_vert_level, group_funcs=(('WA', aggregate_slicewise.func_wa),)) - assert agg_metric[(0, 1, 2, 3)] == {'VertLevel': (2, 3), 'WA()': 35.0} + assert agg_metric[(0, 1, 2, 3)] == {'VertLevel': (2, 3), 'DistancePMJ': None, 'WA()': 35.0} # noinspection 801,PyShadowingNames @@ -135,8 +135,8 @@ def test_aggregate_across_levels_perslice(dummy_metrics, dummy_vert_level): perslice=True, perlevel=False, vert_level=dummy_vert_level, group_funcs=(('WA', aggregate_slicewise.func_wa),)) - assert agg_metric[(0,)] == {'VertLevel': (2,), 'WA()': 29.0} - assert agg_metric[(2,)] == {'VertLevel': (3,), 'WA()': 39.0} + assert agg_metric[(0,)] == {'VertLevel': (2,), 'DistancePMJ': None, 'WA()': 29.0} + assert agg_metric[(2,)] == {'VertLevel': (3,), 'DistancePMJ': None, 'WA()': 39.0} # noinspection 801,PyShadowingNames @@ -145,8 +145,17 @@ def test_aggregate_per_level(dummy_metrics, dummy_vert_level): agg_metric = aggregate_slicewise.aggregate_per_slice_or_level(dummy_metrics['with float'], levels=[2, 3], perlevel=True, vert_level=dummy_vert_level, group_funcs=(('WA', aggregate_slicewise.func_wa),)) - assert agg_metric[(0, 1)] == {'VertLevel': (2,), 'WA()': 30.0} - assert agg_metric[(2, 3)] == {'VertLevel': (3,), 'WA()': 40.0} + assert agg_metric[(0, 1)] == {'VertLevel': (2,), 'DistancePMJ': None, 'WA()': 30.0} + assert agg_metric[(2, 3)] == {'VertLevel': (3,), 'DistancePMJ': None, 'WA()': 40.0} + + +# noinspection 801,PyShadowingNames +def test_aggregate_slices_pmj(dummy_metrics): + """Test extraction of metrics aggregation within selected slices at a PMJ distance""" + agg_metric = aggregate_slicewise.aggregate_per_slice_or_level(dummy_metrics['with float'], slices=[2, 3, 4, 5], + distance_pmj=64, perslice=False, perlevel=False, + group_funcs=(('WA', aggregate_slicewise.func_wa),)) + assert agg_metric[(2, 3, 4, 5)] == {'VertLevel': None, 'DistancePMJ': [64], 'WA()': 45.25} # noinspection 801,PyShadowingNames @@ -226,15 +235,15 @@ def test_save_as_csv(tmp_path, dummy_metrics): with open(path_out, 'r') as csvfile: spamreader = csv.reader(csvfile, delimiter=',') next(spamreader) # skip header - assert next(spamreader)[1:] == [__version__, 'FakeFile.txt', '3:4', '', '45.5', '4.5'] + assert next(spamreader)[1:] == [__version__, 'FakeFile.txt', '3:4', '', '', '45.5', '4.5'] # with appending aggregate_slicewise.save_as_csv(agg_metric, path_out) aggregate_slicewise.save_as_csv(agg_metric, path_out, append=True) with open(path_out, 'r') as csvfile: spamreader = csv.reader(csvfile, delimiter=',') next(spamreader) # skip header - assert next(spamreader)[1:] == [__version__, '', '3:4', '', '45.5', '4.5'] - assert next(spamreader)[1:] == [__version__, '', '3:4', '', '45.5', '4.5'] + assert next(spamreader)[1:] == [__version__, '', '3:4', '', '', '45.5', '4.5'] + assert next(spamreader)[1:] == [__version__, '', '3:4', '', '', '45.5', '4.5'] # noinspection 801,PyShadowingNames @@ -304,6 +313,21 @@ def test_save_as_csv_sorting(tmp_path, dummy_metrics): assert [row['Slice (I->S)'] for row in spamreader] == ['0', '1', '2', '3', '4', '5', '6', '7', '8'] +def test_save_as_csv_pmj(tmp_path, dummy_metrics): + """Test writing of output metric csv file with distance from PMJ method""" + path_out = str(tmp_path / 'tmp_file_out.csv') + agg_metric = aggregate_slicewise.aggregate_per_slice_or_level(dummy_metrics['with float'], slices=[2, 3, 4, 5], + distance_pmj=64.0, perslice=False, perlevel=False, + group_funcs=(('WA', aggregate_slicewise.func_wa),)) + aggregate_slicewise.save_as_csv(agg_metric, path_out) + with open(path_out, 'r') as csvfile: + reader = csv.DictReader(csvfile, delimiter=',') + row = next(reader) + assert row['Slice (I->S)'] == '2:5' + assert row['DistancePMJ'] == '64.0' + assert row['VertLevel'] == '' + + # noinspection 801,PyShadowingNames def test_save_as_csv_extract_metric(tmp_path, dummy_data_and_labels): """Test file output with extract_metric()""" @@ -316,7 +340,7 @@ def test_save_as_csv_extract_metric(tmp_path, dummy_data_and_labels): with open(path_out, 'r') as csvfile: spamreader = csv.reader(csvfile, delimiter=',') next(spamreader) # skip header - assert next(spamreader)[1:-1] == [__version__, '', '0:4', '', 'label_0', '2.5', '38.0'] + assert next(spamreader)[1:-1] == [__version__, '', '0:4', '', '', 'label_0', '2.5', '38.0'] def test_dimension_mismatch_between_metric_and_vertfile(dummy_metrics, dummy_vert_level): diff --git a/testing/cli/test_cli_sct_process_segmentation.py b/testing/cli/test_cli_sct_process_segmentation.py index 36fb0149d3..ed21f83796 100644 --- a/testing/cli/test_cli_sct_process_segmentation.py +++ b/testing/cli/test_cli_sct_process_segmentation.py @@ -1,11 +1,57 @@ import pytest import logging +import numpy as np +import tempfile +import nibabel +import csv from spinalcordtoolbox.scripts import sct_process_segmentation logger = logging.getLogger(__name__) +@pytest.fixture(scope="session") +def dummy_3d_mask_nib(): + data = np.zeros([32, 32, 32], dtype=np.uint8) + data[9:24, 9:24, 9:24] = 1 + nii = nibabel.nifti1.Nifti1Image(data, np.eye(4)) + filename = tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False).name + nibabel.save(nii, filename) + return filename + + +@pytest.fixture(scope="session") +def dummy_3d_pmj_label(): + data = np.zeros([32, 32, 32], dtype=np.uint8) + data[15, 15, 28] = 1 + nii = nibabel.nifti1.Nifti1Image(data, np.eye(4)) + filename = tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False).name + nibabel.save(nii, filename) + return filename + + +def test_sct_process_segmentation_check_pmj(dummy_3d_mask_nib, dummy_3d_pmj_label, tmp_path): + """ Run sct_process_segmentation with -pmj, -pmj-distance and -pmj-extent and check the results""" + filename = str(tmp_path / 'tmp_file_out.csv') + sct_process_segmentation.main(argv=['-i', dummy_3d_mask_nib, '-pmj', dummy_3d_pmj_label, + '-pmj-distance', '8', '-pmj-extent', '4', '-o', filename]) + with open(filename, "r") as csvfile: + reader = csv.DictReader(csvfile, delimiter=',') + row = next(reader) + assert row['Slice (I->S)'] == '18:21' + assert row['DistancePMJ'] == '8.0' + assert row['VertLevel'] == '' + assert row['SUM(length)'] == '4.0' + + +def test_sct_process_segmentation_missing_pmj_args(dummy_3d_mask_nib, dummy_3d_pmj_label): + """ Run sct_process_segmentation with PMJ method when missing -pmj or -pmj-distance """ + for args in [['-i', dummy_3d_mask_nib, '-pmj', dummy_3d_pmj_label], ['-i', dummy_3d_mask_nib, '-pmj-distance', '4']]: + with pytest.raises(SystemExit) as e: + sct_process_segmentation.main(argv=args) + assert e.value.code == 2 + + @pytest.mark.sct_testing @pytest.mark.usefixtures("run_in_sct_testing_data_dir") def test_sct_process_segmentation_no_checks():