Skip to content

Commit

Permalink
Merge pull request #732 from samaloney/feat-crop-keepdims
Browse files Browse the repository at this point in the history
Add a keepdims kwarg to crop and crop_by_value to keep length-1 dimensions
  • Loading branch information
DanRyanIrish committed Jun 21, 2024
2 parents e7a0e20 + 15122a8 commit 0107c5d
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog/732.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a ``keepdims=False`` kwarg to `~ndcube.NDCube.crop` and `~ndcube.NDCube.crop_by_values` setting to true keeps length-1 dimensions default behavior drops these dimensions.
40 changes: 40 additions & 0 deletions docs/explaining_ndcube/slicing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ This means that these world points are not used in calculating the pixel range t
>>> upper_left = [None, SkyCoord(Tx=1, Ty=1.5, unit=u.deg, frame=Helioprojective)]
>>> my_cube_roi = my_cube.crop(lower_left, upper_right, lower_right, upper_left)
By default, :meth:`~ndcube.NDCube.crop` and :meth:`~ndcube.NDCube.crop_by_values` discard length-1 dimensions to make the resulting cube more wieldy.
However, there are cases where it is preferable to keep the number of dimensions the same.
In such cases setting the :code:`keepdims=True` kwarg in either crop or crop_by_values.

>>> # Use coordinate objects to mark the lower limit of the region of interest.
>>> lower_left = [SpectralCoord(1.02e-9, unit=u.m),
... SkyCoord(Tx=1, Ty=0.5, unit=u.deg, frame=Helioprojective)]
>>> upper_right = [SpectralCoord(1.03e-9, unit=u.m),
... SkyCoord(Tx=1.5, Ty=1.5, unit=u.deg, frame=Helioprojective)]
>>> lower_right = [None, SkyCoord(Tx=1.5, Ty=0.5, unit=u.deg, frame=Helioprojective)]
>>> upper_left = [None, SkyCoord(Tx=1, Ty=1.5, unit=u.deg, frame=Helioprojective)]
>>> my_cube_roi = my_cube.crop(lower_left, upper_right, lower_right, upper_left)
>>> my_cube_roi.shape
(2, 3)
>>> my_cube_roi_keep = my_cube.crop(lower_left, upper_right, lower_right, upper_left,
... keepdims=True)
>>> my_cube_roi_keep.shape
(2, 3, 1)

One use case for :code:`keepdims=True` is when cropping leads to a cube with only one array element.
Because cropping an `~ndcube.NDCube` to a scalar is not allowed, such an operation would normally raise an error.
But if :code:`keepdims=True`, a valid NDCube is returned with N length-1 dimensions.

>>> # Use coordinate objects to mark the lower limit of the region of interest.
>>> lower_left = [SpectralCoord(1.02e-9, unit=u.m),
... SkyCoord(Tx=1.5, Ty=0.5, unit=u.deg, frame=Helioprojective)]
>>> upper_right = [SpectralCoord(1.03e-9, unit=u.m),
... SkyCoord(Tx=1.5, Ty=0.5, unit=u.deg, frame=Helioprojective)]
>>> lower_right = [None, SkyCoord(Tx=1.5, Ty=0.5, unit=u.deg, frame=Helioprojective)]
>>> upper_left = [None, SkyCoord(Tx=1.5, Ty=0.5, unit=u.deg, frame=Helioprojective)]
>>> my_cube_roi = my_cube.crop(lower_left, upper_right, lower_right, upper_left)
Traceback (most recent call last):
...
ValueError: Input points causes cube to be cropped to a single pixel. This is not supported.
>>> my_cube_roi_keep = my_cube.crop(lower_left, upper_right, lower_right, upper_left,
... keepdims=True)
>>> my_cube_roi_keep.shape
(1, 1, 1)


.. _sequence_slicing:

Slicing NDCubeSequences
Expand Down
30 changes: 20 additions & 10 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def axis_world_coords_values(self,
@abc.abstractmethod
def crop(self,
*points: Iterable[Any],
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False,
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -215,6 +216,10 @@ def crop(self,
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.
keepdims: `bool`, optional
If `False` and if cropping results in length-1 dimensions, these are sliced away in output cube.
If `True`, length-1 dimensions are kept. Default=False
Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand All @@ -231,7 +236,8 @@ def crop(self,
def crop_by_values(self,
*points: Iterable[u.Quantity | float],
units: Iterable[str | u.Unit] | None = None,
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None
wcs: BaseHighLevelWCS | ExtraCoordsABC | None = None,
keepdims: bool = False
) -> "NDCubeABC":
"""
Crop using real world coordinates.
Expand Down Expand Up @@ -264,6 +270,10 @@ def crop_by_values(self,
could be used it is expected that either the ``.wcs`` or
``.extra_coords`` properties will be used.
keepdims: `bool`, optional
If `False` and if cropping results in length-1 dimensions, these are sliced away in output cube.
If `True`, length-1 dimensions are kept. Default=False
Returns
-------
`~ndcube.ndcube.NDCubeABC`
Expand Down Expand Up @@ -554,14 +564,14 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
CoordValues = namedtuple("CoordValues", identifiers)
return CoordValues(*axes_coords[::-1])

def crop(self, *points, wcs=None):
def crop(self, *points, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_item(*points, wcs=wcs)
item = self._get_crop_item(*points, wcs=wcs, keepdims=keepdims)
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_item(self, *points, wcs=None):
def _get_crop_item(self, *points, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand All @@ -584,16 +594,16 @@ def _get_crop_item(self, *points, wcs=None):
raise TypeError(f"{type(value)} of component {j} in point {i} is "
f"incompatible with WCS component {comp[j]} "
f"{classes[j]}.")
return utils.cube.get_crop_item_from_points(points, wcs, False)
return utils.cube.get_crop_item_from_points(points, wcs, False, keepdims=keepdims)

def crop_by_values(self, *points, units=None, wcs=None):
def crop_by_values(self, *points, units=None, wcs=None, keepdims=False):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs)
item = self._get_crop_by_values_item(*points, units=units, wcs=wcs, keepdims=keepdims)
return self[item]

@utils.cube.sanitize_wcs
def _get_crop_by_values_item(self, *points, units=None, wcs=None):
def _get_crop_by_values_item(self, *points, units=None, wcs=None, keepdims=False):
# Sanitize inputs.
no_op, points, wcs = utils.cube.sanitize_crop_inputs(points, wcs)
# Quit out early if we are no-op
Expand Down Expand Up @@ -626,7 +636,7 @@ def _get_crop_by_values_item(self, *points, units=None, wcs=None):
raise UnitsError(f"Unit '{points[i][j].unit}' of coordinate object {j} in point {i} is "
f"incompatible with WCS unit '{wcs.world_axis_units[j]}'") from err

return utils.cube.get_crop_item_from_points(points, wcs, True)
return utils.cube.get_crop_item_from_points(points, wcs, True, keepdims=keepdims)

def __str__(self):
return textwrap.dedent(f"""\
Expand Down
21 changes: 21 additions & 0 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,15 @@ def test_crop_reduces_dimensionality(ndcube_4d_ln_lt_l_t):
helpers.assert_cubes_equal(output, expected)


def test_crop_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
point = (None, SpectralCoord([3e-11], unit=u.m), None)
output = cube.crop(point, keepdims=True)
expected = cube[:, :, 0:1, :]
assert output.shape == (5, 8, 1, 12)
helpers.assert_cubes_equal(output, expected)


def test_crop_scalar_valuerror(ndcube_2d_ln_lt):
cube = ndcube_2d_ln_lt
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)
Expand Down Expand Up @@ -506,6 +515,18 @@ def test_crop_by_values(ndcube_4d_ln_lt_l_t):
helpers.assert_cubes_equal(output, expected)


def test_crop_by_values_keepdims(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
intervals = list(cube.wcs.array_index_to_world_values([1, 2], [0], [0, 1], [0, 2]))
units = [u.min, u.m, u.deg, u.deg]
lower_corner = [coord[0] * unit for coord, unit in zip(intervals, units)]
upper_corner = [coord[-1] * unit for coord, unit in zip(intervals, units)]
expected = cube[1:3, 0:1, 0:2, 0:3]
output = cube.crop_by_values(lower_corner, upper_corner, keepdims=True)
assert output.shape == (2, 1, 2, 3)
helpers.assert_cubes_equal(output, expected)


def test_crop_by_values_with_units(ndcube_4d_ln_lt_l_t):
intervals = ndcube_4d_ln_lt_l_t.wcs.array_index_to_world_values([1, 2], [0, 1], [0, 1], [0, 2])
units = [u.min, u.m, u.deg, u.deg]
Expand Down
7 changes: 5 additions & 2 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def sanitize_crop_inputs(points, wcs):
return False, points, wcs


def get_crop_item_from_points(points, wcs, crop_by_values):
def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
"""
Find slice item that crops to minimum cube in array-space containing specified world points.
Expand All @@ -121,6 +121,9 @@ def get_crop_item_from_points(points, wcs, crop_by_values):
Denotes whether cropping is done using high-level objects or "values",
i.e. low-level objects.
keep_dims : `bool`
If `False`, returned item will drop length-1 dimensions otherwise, item will keep length-1 dimensions.
Returns
-------
item : `tuple` of `slice`
Expand Down Expand Up @@ -190,7 +193,7 @@ def get_crop_item_from_points(points, wcs, crop_by_values):
else:
min_idx = min(axis_indices)
max_idx = max(axis_indices) + 1
if max_idx - min_idx == 1:
if max_idx - min_idx == 1 and not keepdims:
item.append(min_idx)
else:
item.append(slice(min_idx, max_idx))
Expand Down

0 comments on commit 0107c5d

Please sign in to comment.