Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Measure CSA based on distance from pontomedullary junction (PMJ) #3478

Merged
merged 68 commits into from Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
5aca8ef
add args for pmj, distance and extent
sandrinebedard Jun 17, 2021
07d36fa
add script to get mask to compute csa from distance
sandrinebedard Jun 17, 2021
8f98b58
use slices from distance to pmj to compute csa
sandrinebedard Jun 18, 2021
f1be717
get mask of where to compute csa
sandrinebedard Jun 18, 2021
38d7be3
add option to aggregate with slices from pmj and add label
sandrinebedard Jun 18, 2021
dd96f91
add pmj_slices when calling func aggregate
sandrinebedard Jun 18, 2021
433ff7b
adapt index for right lenght
sandrinebedard Jun 18, 2021
641ef77
add comments, clean up
sandrinebedard Jun 18, 2021
a1ad26b
add comments
sandrinebedard Jun 21, 2021
e976786
add comments + change function name
sandrinebedard Jun 21, 2021
c3a18f2
add check is distance out of bound + comments
sandrinebedard Jun 21, 2021
122c615
Merge branch 'master' into sb/3063-csa-measure-from-pmj
sandrinebedard Jun 21, 2021
3e63ff3
add new column for DistancePMJ in tests
sandrinebedard Jun 21, 2021
5be29bd
remove whitespace in line
sandrinebedard Jun 21, 2021
d58abed
change position due to new column in .csv file
sandrinebedard Jun 21, 2021
0696bb0
pass slices and distance from PMJ in API aggregate_slicewise
sandrinebedard Jun 21, 2021
ce7c876
add units of param distance
sandrinebedard Jun 22, 2021
02b2040
typo
sandrinebedard Jun 22, 2021
6bccae4
test if distance is smaller than min distance form PMJ
sandrinebedard Jun 22, 2021
2155bf2
change back mask to native orientation
sandrinebedard Jun 22, 2021
8903cf3
remove a TODO
sandrinebedard Jun 22, 2021
b77e714
add parser error if either distance or pmj not specified
sandrinebedard Jun 22, 2021
df7cbc9
extrapolate centerline and compute distance along
sandrinebedard Jun 22, 2021
0057c50
remove unused variable
sandrinebedard Jun 22, 2021
41e8dbc
set distancePMJ to None, not [None] to pass test
sandrinebedard Jun 23, 2021
db93a02
add unit testing for computing distance from PMJ
sandrinebedard Jun 23, 2021
65a7968
use param minmax from get_centerline to extrapolate
sandrinebedard Jun 27, 2021
9be1b5b
change save of centerline from csa_pmj.py to sct_process_segmentation.py
sandrinebedard Jun 27, 2021
410bb75
add comment in -h
sandrinebedard Jun 27, 2021
387ee69
add warning if extent is out of bound for segmentation
sandrinebedard Jun 27, 2021
2891a11
Tried getting PMJ distance using Centerline class
jcohenadad Jun 28, 2021
753b047
Using interpolation between PMJ and segmentation to compute length
jcohenadad Jun 28, 2021
e943df8
Refactored computation of distance from PMJ using Centerline class
jcohenadad Jun 29, 2021
7bd14cf
Removed unused code
jcohenadad Jun 29, 2021
0597628
Merge branch 'master' into jca/sb/3063-csa-measure-from-pmj
jcohenadad Jun 29, 2021
83b7370
get right index according to z max and min
sandrinebedard Jun 30, 2021
6e99be3
add value error if distance out of range or negative
sandrinebedard Jul 5, 2021
f869536
remove comment
sandrinebedard Jul 5, 2021
3529c2c
fix typo
sandrinebedard Jul 5, 2021
23eb839
fix typo in centerline output name
sandrinebedard Jul 5, 2021
4b43b63
resolve conflict with master
sandrinebedard Jul 6, 2021
9eb1c87
remmove unused imports
sandrinebedard Jul 6, 2021
79752f8
save centerline coordinates in .csv file
sandrinebedard Jul 13, 2021
f00bd13
Merge branch 'master' into jca/sb/3063-csa-measure-from-pmj
sandrinebedard Jul 14, 2021
26106c6
fix conflict with unit tests
sandrinebedard Jul 14, 2021
ce9c111
remove f string
sandrinebedard Jul 19, 2021
39a0037
remove unused import
sandrinebedard Jul 19, 2021
0c96721
remove whitespace in blank line
sandrinebedard Jul 19, 2021
38b7a49
remove test csa pmj since cannot apply here
sandrinebedard Jul 19, 2021
755b112
Merge branch 'master' into jca/sb/3063-csa-measure-from-pmj
sandrinebedard Jul 19, 2021
ed3af4b
Merge branch 'master' into jca/sb/3063-csa-measure-from-pmj
sandrinebedard Jul 21, 2021
797c396
Add QC report for `sct_process_segmentation` for PMJ-based CSA (#3465)
sandrinebedard Jul 25, 2021
d686b9d
update function description and add comments
sandrinebedard Jul 29, 2021
9d3b0ba
save .csv file of the centerline in sct_process_segmentation.py
sandrinebedard Jul 29, 2021
eecb6eb
remove whitespace from line
sandrinebedard Jul 29, 2021
56009cf
Merge branch 'master' into jca/sb/3063-csa-measure-from-pmj
sandrinebedard Jul 29, 2021
3dd2374
add help description to indicate what the correct usage
sandrinebedard Aug 2, 2021
d3d6c3e
change flag -distance for -pmj-distance and -extent for -pmj-extent
sandrinebedard Aug 2, 2021
32e3140
address todo, remove resquires -p since flag doesn't exist anymore
sandrinebedard Aug 2, 2021
e175929
move svaing in qc and to tmp if not verbose 2
sandrinebedard Aug 2, 2021
cd420b4
add test aggregate for slices and pmj distance
sandrinebedard Aug 2, 2021
d151a92
add test to check .csv file with PMJ CSA
sandrinebedard Aug 2, 2021
d44ff06
change flag -distance for -pmj-distance in parser error
sandrinebedard Aug 2, 2021
0e32c43
add test for PMJ-based method with sct_process_segmentation
sandrinebedard Aug 2, 2021
f46999e
Merge branch 'master' into jca/sb/3063-csa-measure-from-pmj
sandrinebedard Aug 2, 2021
4f8a457
adjust sct_process_segmentation usage description
sandrinebedard Aug 2, 2021
bf6c727
Update sct_process_segmentation description
sandrinebedard Aug 2, 2021
8263121
change function name
sandrinebedard Aug 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 11 additions & 5 deletions spinalcordtoolbox/aggregate_slicewise.py
Expand Up @@ -232,15 +232,16 @@ def func_wa(data, mask=None, map_clusters=None):
return np.average(data, weights=mask), None


def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], perslice=None, perlevel=False,
vert_level=None, group_funcs=(('MEAN', func_wa),), map_clusters=None):
def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], distance_pmj=None, perslice=None,
perlevel=False, vert_level=None, group_funcs=(('MEAN', func_wa),), map_clusters=None):
"""
The aggregation will be performed along the last dimension of 'metric' ndarray.

:param metric: Class Metric(): data to aggregate.
:param mask: Class Metric(): mask to use for aggregating the data. Optional.
:param slices: List[int]: Slices to aggregate metric from. If empty, select all slices.
:param levels: List[int]: Vertebral levels to aggregate metric from. It has priority over "slices".
:param distance_pmj: float: Distance from Ponto-Medullary Junction (PMJ) in mm.
:param Bool perslice: Aggregate per slice (True) or across slices (False)
:param Bool perlevel: Aggregate per level (True) or across levels (False). Has priority over "perslice".
:param vert_level: Vertebral level. Could be either an Image or a file name.
Expand Down Expand Up @@ -273,6 +274,7 @@ def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], persli
slices = range(metric.data.shape[ndim-1])

# aggregation based on levels
vertgroups = None
if levels:
im_vert_level = Image(vert_level).change_orientation('RPI')
# slicegroups = [(0, 1, 2), (3, 4, 5), (6, 7, 8)]
Expand All @@ -293,17 +295,20 @@ def aggregate_per_slice_or_level(metric, mask=None, slices=[], levels=[], persli
vertgroups = [tuple([level for level in levels])]
# aggregation based on slices
else:
vertgroups = None
if perslice:
# slicegroups = [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)]
slicegroups = [tuple([slice]) for slice in slices]
else:
# slicegroups = [(0, 1, 2, 3, 4, 5, 6, 7, 8)]
slicegroups = [tuple(slices)]
agg_metric = dict((slicegroup, dict()) for slicegroup in slicegroups)

# loop across slice group
for slicegroup in slicegroups:
# add distance from PMJ info
if distance_pmj is not None:
agg_metric[slicegroup]['DistancePMJ'] = [distance_pmj]
joshuacwnewton marked this conversation as resolved.
Show resolved Hide resolved
else:
agg_metric[slicegroup]['DistancePMJ'] = None
# add level info
if vertgroups is None:
agg_metric[slicegroup]['VertLevel'] = None
Expand Down Expand Up @@ -527,7 +532,7 @@ def save_as_csv(agg_metric, fname_out, fname_in=None, append=False):
if not append or not os.path.isfile(fname_out):
with open(fname_out, 'w') as csvfile:
# spamwriter = csv.writer(csvfile, delimiter=',')
header = ['Timestamp', 'SCT Version', 'Filename', 'Slice (I->S)', 'VertLevel']
header = ['Timestamp', 'SCT Version', 'Filename', 'Slice (I->S)', 'VertLevel', 'DistancePMJ']
agg_metric_key = [v for i, (k, v) in enumerate(agg_metric.items())][0]
for item in list_item:
for key in agg_metric_key:
Expand All @@ -547,6 +552,7 @@ def save_as_csv(agg_metric, fname_out, fname_in=None, append=False):
line.append(fname_in) # file name associated with the results
line.append(parse_num_list_inv(slicegroup)) # list all slices in slicegroup
line.append(parse_num_list_inv(agg_metric[slicegroup]['VertLevel'])) # list vertebral levels
line.append(parse_num_list_inv(agg_metric[slicegroup]['DistancePMJ'])) # list distance from PMJ
agg_metric_key = [v for i, (k, v) in enumerate(agg_metric.items())][0]
for item in list_item:
for key in agg_metric_key:
Expand Down
86 changes: 86 additions & 0 deletions spinalcordtoolbox/csa_pmj.py
@@ -0,0 +1,86 @@
#!/usr/bin/env python
# -*- coding: utf-8
# Functions to get distance from PMJ for processing segmentation data
# Author: Sandrine Bédard
import logging

import numpy as np
from spinalcordtoolbox.image import Image
from spinalcordtoolbox.centerline.core import get_centerline

logger = logging.getLogger(__name__)

NEAR_ZERO_THRESHOLD = 1e-6


def get_slices_for_pmj_distance(segmentation, pmj, distance, extent, param_centerline=None, verbose=1):
"""
Interpolate centerline with pontomedullary junction (PMJ) label and compute distance from PMJ along the centerline.
Generate a mask from segmentation of the slices used to process segmentation data corresponding to a distance from PMJ.
:param segmentation: input segmentation. Could be either an Image or a file name.
:param pmj: label of PMJ.
:param distance: float: Distance from PMJ in mm.
:param extent: extent of the coverage mask in mm.
:param param_centerline: see centerline.core.ParamCenterline()
:param verbose:
:return im_ctl:
:return mask:
:return slices:

"""
im_seg = Image(segmentation).change_orientation('RPI')
native_orientation = im_seg.orientation
im_seg.change_orientation('RPI')
im_pmj = Image(pmj).change_orientation('RPI')
if not im_seg.data.shape == im_pmj.data.shape:
raise RuntimeError("segmentation and pmj should be in the same space coordinate.")
# Add PMJ label to the segmentation and then extrapolate to obtain a Centerline object defined between the PMJ
# and the lower end of the centerline.
im_seg_with_pmj = im_seg.copy()
im_seg_with_pmj.data = im_seg_with_pmj.data + im_pmj.data

# Get max and min index of the segmentation with pmj
_, _, Z = (im_seg_with_pmj.data > NEAR_ZERO_THRESHOLD).nonzero()
min_z_index, max_z_index = min(Z), max(Z)

from spinalcordtoolbox.straightening import _get_centerline
# Linear interpolation (vs. bspline) ensures strong robustness towards defective segmentations at the top slices.
param_centerline.algo_fitting = 'linear'
# On top of the linear interpolation we add some smoothing to remove discontinuities.
param_centerline.smooth = 50
param_centerline.minmax = True
# Compute spinalcordtoolbox.types.Centerline class
ctl_seg_with_pmj = _get_centerline(im_seg_with_pmj, param_centerline, verbose=verbose)
# Also get the image centerline (because it is a required output)
# TODO: merge _get_centerline into get_centerline
im_ctl_seg_with_pmj, arr_ctl, _, _ = get_centerline(im_seg_with_pmj, param_centerline, verbose=verbose)
# Compute the incremental distance from the PMJ along each point in the centerline
length_from_pmj = ctl_seg_with_pmj.incremental_length_inverse[::-1]
# From this incremental distance, find the indices corresponding to the requested distance +/- extent/2 from the PMJ
# Get the z index corresponding to the segmentation since the centerline only includes slices of the segmentation.
z_ref = np.array(range(min_z_index.astype(int), max_z_index.max().astype(int) + 1))
zmin = z_ref[np.argmin(np.array([np.abs(i - distance - extent/2) for i in length_from_pmj]))]
zmax = z_ref[np.argmin(np.array([np.abs(i - distance + extent/2) for i in length_from_pmj]))]

# Check if distance is out of bounds
if distance > length_from_pmj[0]:
raise ValueError("Input distance of " + str(distance) + " mm is out of bounds for maximum distance of " + str(length_from_pmj[0]) + " mm")

if distance < length_from_pmj[-1]:
raise ValueError("Input distance of " + str(distance) + " mm is out of bounds for minimum distance of " + str(length_from_pmj[-1]) + " mm")

# Check if the range of selected slices are covered by the segmentation
if not all(np.any(im_seg.data[:, :, z]) for z in range(zmin, zmax)):
raise ValueError(f"The requested distances from the PMJ are not fully covered by the segmentation.\n"
f"The range of slices are: [{zmin}, {zmax}]")

# Create mask from segmentation centered on distance from PMJ and with extent length on z axis.
mask = im_seg.copy()
mask.data[:, :, 0:zmin] = 0
mask.data[:, :, zmax:] = 0
mask.change_orientation(native_orientation)

# Get corresponding slices
slices = "{}:{}".format(zmin, zmax-1) # -1 since the last slice is included to compute CSA after.

return im_ctl_seg_with_pmj.change_orientation(native_orientation), mask, slices, arr_ctl
52 changes: 36 additions & 16 deletions spinalcordtoolbox/reports/qc.py
Expand Up @@ -62,7 +62,7 @@ class QcImage(object):
"#7d0434", "#fb1849", "#14aab4",
"#a22abd", "#d58240", "#ac2aff"]
_seg_colormap = ["#4d0000", "#ff0000"]

_ctl_colormap = ["#ff000099", '#ffff00']

def __init__(self, qc_report, interpolation, action_list, process, stretch_contrast=True,
stretch_contrast_method='contrast_stretching', angle_line=None, fps=None):
Expand Down Expand Up @@ -264,6 +264,19 @@ def grid(self, mask, ax):
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

def smooth_centerline(self, mask, ax):
"""Display smoothed centerline"""
mask = mask/mask.max()
mask[mask < 0.05] = 0 # Apply 0.5 threshold
img = np.ma.masked_equal(mask, 0)
ax.imshow(img,
cmap=color.LinearSegmentedColormap.from_list("", self._ctl_colormap),
norm=color.Normalize(vmin=0, vmax=1),
interpolation=self.interpolation,
aspect=float(self.aspect_mask))
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

# def colorbar(self):
# fig = plt.figure(figsize=(9, 1.5))
# ax = fig.add_axes([0.05, 0.80, 0.9, 0.15])
Expand Down Expand Up @@ -294,10 +307,9 @@ def wrapped_f(sct_slice, *args):
[images_after_moco, images_before_moco], centermass = func(sct_slice, *args)
self._centermass = centermass
self._make_QC_image_for_4d_volumes(images_after_moco, images_before_moco)

else:
if self._angle_line is None:
img, mask = func(sct_slice, *args)
img, *mask = func(sct_slice, *args)
else:
[img, mask], centermass = func(sct_slice, *args)
self._centermass = centermass
Expand All @@ -311,8 +323,8 @@ def _make_QC_image_for_3d_volumes(self, img, mask, slice_orientation):
Create overlay and background images for all processes that deal with 3d volumes
(all except sct_fmri_moco and sct_dmri_moco)

:param img: list of mosaic images after motion correction
:param mask: list of mosaic images before motion correction
:param img: The base image to display underneath the overlays (typically anatomical)
:param mask: A list of images to be processed and overlaid on top of `img`
:return:
"""

Expand All @@ -337,17 +349,17 @@ def _make_QC_image_for_3d_volumes(self, img, mask, slice_orientation):
logger.info(self.qc_report.qc_params.abs_bkg_img_path())
self._save(fig, self.qc_report.qc_params.abs_bkg_img_path(), dpi=self.qc_report.qc_params.dpi)

for action in self.action_list:
fig = Figure()
fig.set_size_inches(size_fig[0], size_fig[1], forward=True)
FigureCanvas(fig)
for i, action in enumerate(self.action_list):
logger.debug('Action List %s', action.__name__)
if self._stretch_contrast and action.__name__ in ("no_seg_seg",):
print("Mask type %s" % mask.dtype)
mask = self._func_stretch_contrast(mask)
fig = Figure()
fig.set_size_inches(size_fig[0], size_fig[1], forward=True)
FigureCanvas(fig)
ax = fig.add_axes((0, 0, 1, 1))
action(self, mask, ax)
self._save(fig, self.qc_report.qc_params.abs_overlay_img_path(), dpi=self.qc_report.qc_params.dpi)
print("Mask type %s" % mask[i].dtype)
mask[i] = self._func_stretch_contrast(mask[i])
ax = fig.add_axes((0, 0, 1, 1), label=str(i))
action(self, mask[i], ax)
self._save(fig, self.qc_report.qc_params.abs_overlay_img_path(), dpi=self.qc_report.qc_params.dpi)

self.qc_report.update_description_file(img.shape)

Expand Down Expand Up @@ -836,7 +848,15 @@ def qcslice_layout(x): return x.single()
def qcslice_layout(x): return x.single()
# Metric outputs (only graphs)
elif process in ['sct_process_segmentation']:
assert os.path.isfile(path_img)
plane = 'Sagittal'
dpi = 100 # bigger picture is needed for this special case, hence reduce dpi
fname_list = [fname_in1]
# fname_seg should be a list of 4 images: 3 for each of the `qcslice_operations`, plus an extra
# centerline image, which is needed to make `Sagittal.get_center_spit` work correctly
fname_list.extend(fname_seg)
qcslice_type = qcslice.Sagittal([Image(fname) for fname in fname_list], p_resample=None)
qcslice_operations = [QcImage.smooth_centerline, QcImage.highlight_pmj, QcImage.listed_seg]
def qcslice_layout(x): return x.single()
else:
raise ValueError("Unrecognized process: {}".format(process))

Expand All @@ -855,7 +875,7 @@ def qcslice_layout(x): return x.single()
qcslice_layout=qcslice_layout,
stretch_contrast_method='equalized',
angle_line=angle_line,
fps=fps,
fps=fps
)


Expand Down