diff --git a/changelog/283.feature.rst b/changelog/283.feature.rst new file mode 100644 index 000000000..d2ff601c2 --- /dev/null +++ b/changelog/283.feature.rst @@ -0,0 +1 @@ +Add new method, `~ndcube.NDCube.axis_world_coord_values`, to return world coords for all pixels for all axes in WCS as quantity objects. diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index 4ddb28e7f..8504ecba3 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -2,6 +2,8 @@ import abc import warnings import textwrap +import numbers +from collections import namedtuple import numpy as np import astropy.nddata @@ -9,7 +11,7 @@ from ndcube import utils from ndcube.ndcube_sequence import NDCubeSequence -from ndcube.utils.wcs import wcs_ivoa_mapping, reduced_correlation_matrix_and_world_physical_types +from ndcube.utils import wcs as wcs_utils from ndcube.utils.cube import _pixel_centers_or_edges, _get_dimension_for_pixel from ndcube.mixins import NDCubeSlicingMixin, NDCubePlotMixin @@ -237,10 +239,11 @@ def world_axis_physical_types(self): if not axis: # Find keys in wcs_ivoa_mapping dict that represent start of CTYPE. # Ensure CTYPE is capitalized. - keys = list(filter(lambda key: ctype[i].upper().startswith(key), wcs_ivoa_mapping)) + keys = list(filter(lambda key: ctype[i].upper().startswith(key), + wcs_utils.wcs_ivoa_mapping)) # Assuming CTYPE is supported by wcs_ivoa_mapping, use its corresponding axis name. if len(keys) == 1: - axis_name = wcs_ivoa_mapping.get(keys[0]) + axis_name = wcs_utils.wcs_ivoa_mapping.get(keys[0]) # If CTYPE not supported, raise a warning and set the axis name to CTYPE. elif len(keys) == 0: warnings.warn("CTYPE not recognized by ndcube. " @@ -270,7 +273,7 @@ def array_axis_physical_types(self): multiple array axes, the same physical type string can appear in multiple tuples. """ axis_correlation_matrix, world_axis_physical_types = \ - reduced_correlation_matrix_and_world_physical_types( + wcs_utils.reduced_correlation_matrix_and_world_physical_types( self.wcs.axis_correlation_matrix, self.wcs.world_axis_physical_types, self.missing_axes) return [tuple(world_axis_physical_types[axis_correlation_matrix[:, i]]) @@ -447,6 +450,104 @@ def axis_world_coords(self, *axes, edges=False): else: return tuple(axes_coords) + def axis_world_coord_values(self, *axes, edges=False): + """ + Returns WCS coordinate values of all pixels for desired axes. + + Parameters + ---------- + axes: `int` or `str`, or multiple `int` or `str` + Axis number in numpy ordering or unique substring of + `~ndcube.NDCube.wcs.world_axis_physical_types` + of axes for which real world coordinates are desired. + axes=None implies all axes will be returned. + + edges: `bool` + If True, the coords at the edges of the pixels are returned + rather than the coords at the center of the pixels. + Note that there are n+1 edges for n pixels which is reflected + in the returned coords. + Default=False, i.e. pixel centers are returned. + + Returns + ------- + coord_values: `collections.namedtuple` + Real world coords labeled with their real world physical types + for the axes requested by the user. + Returned in same order as axis_names. + + Example + ------- + >>> NDCube.all_world_coords_values(('lat', 'lon')) # doctest: +SKIP + >>> NDCube.all_world_coords_values(2) # doctest: +SKIP + """ + # Create meshgrid of all pixel coordinates. + # If user, wants edges, set pixel values to pixel edges. + # Else make pixel centers. + wcs_shape = self.data.shape[::-1] + # Insert length-1 axes for missing axes. + for i in np.arange(len(self.missing_axes))[self.missing_axes]: + wcs_shape = np.insert(wcs_shape, i, 1) + if edges: + wcs_shape = tuple(np.array(wcs_shape) + 1) + pixel_inputs = np.meshgrid(*[np.arange(i) - 0.5 for i in wcs_shape], + indexing='ij', sparse=True) + else: + pixel_inputs = np.meshgrid(*[np.arange(i) for i in wcs_shape], + indexing='ij', sparse=True) + + # Get world coords for all axes and all pixels. + axes_coords = list(self.wcs.pixel_to_world_values(*pixel_inputs)) + + # Reduce duplication across independent dimensions for each coord + # and transpose to make dimensions mimic numpy array order rather than WCS order. + # Add units to coords + for i, axis_coord in enumerate(axes_coords): + slices = np.array([slice(None)] * self.wcs.world_n_dim) + slices[np.invert(self.wcs.axis_correlation_matrix[i])] = 0 + axes_coords[i] = axis_coord[tuple(slices)].T + axes_coords[i] *= u.Unit(self.wcs.world_axis_units[i]) + + world_axis_physical_types = self.wcs.world_axis_physical_types + # If user has supplied axes, extract only the + # world coords that correspond to those axes. + if axes: + # Convert input axes to WCS world axis indices. + world_indices = set() + for axis in axes: + if isinstance(axis, numbers.Integral): + # If axis is int, it is a numpy order array axis. + # Convert to pixel axis in WCS order. + axis = wcs_utils.convert_between_array_and_pixel_axes( + np.array([axis]), self.wcs.pixel_n_dim)[0] + # Get WCS world axis indices that correspond to the WCS pixel axis + # and add to list of indices of WCS world axes whose coords will be returned. + world_indices.update(wcs_utils.pixel_axis_to_world_axes( + axis, self.wcs.axis_correlation_matrix)) + elif isinstance(axis, str): + # If axis is str, it is a physical type or substring of a physical type. + world_indices.update({wcs_utils.physical_type_to_world_axis( + axis, world_axis_physical_types)}) + else: + raise TypeError(f"Unrecognized axis type: {axis, type(axis)}. " + "Must be of type (numbers.Integral, str)") + # Use inferred world axes to extract the desired coord value + # and corresponding physical types. + world_indices = np.array(list(world_indices), dtype=int) + axes_coords = np.array(axes_coords)[world_indices] + world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices]) + + # Return in array order. + # First replace characters in physical types forbidden for namedtuple identifiers. + identifiers = [] + for physical_type in world_axis_physical_types[::-1]: + identifier = physical_type.replace(":", "_") + identifier = identifier.replace(".", "_") + identifier = identifier.replace("-", "__") + identifiers.append(identifier) + CoordValues = namedtuple("CoordValues", identifiers) + return CoordValues(*axes_coords[::-1]) + @property def extra_coords(self): """ diff --git a/ndcube/tests/test_ndcube.py b/ndcube/tests/test_ndcube.py index 5ffff86f9..7b30e2478 100644 --- a/ndcube/tests/test_ndcube.py +++ b/ndcube/tests/test_ndcube.py @@ -947,6 +947,6 @@ def test_array_axis_physical_types(): ('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat'), ('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat'), ('em.wl',), ('time',)] - output = cube.array_axis_physical_types - for i in range(len(expected)): - assert all([physical_type in expected[i] for physical_type in output[i]]) + output = cube_disordered.array_axis_physical_types + for i, expected_i in enumerate(expected): + assert all([physical_type in expected_i for physical_type in output[i]]) diff --git a/ndcube/tests/test_utils_wcs.py b/ndcube/tests/test_utils_wcs.py index 6631f00b3..f04c6f0e9 100644 --- a/ndcube/tests/test_utils_wcs.py +++ b/ndcube/tests/test_utils_wcs.py @@ -37,6 +37,34 @@ wm_reindexed_102 = utils.wcs.WCS(header=hm_reindexed_102, naxis=3) +@pytest.fixture +def axis_correlation_matrix(): + return _axis_correlation_matrix() + + +def _axis_correlation_matrix(): + shape = (4, 4) + acm = np.zeros(shape, dtype=bool) + for i in range(min(shape)): + acm[i, i] = True + acm[0, 1] = True + acm[1, 0] = True + acm[-1, 0] = True + return acm + + +@pytest.fixture +def test_wcs(): + return TestWCS() + + +class TestWCS(): + def __init__(self): + self.world_axis_physical_types = [ + 'custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat', 'em.wl', 'time'] + self.axis_correlation_matrix = _axis_correlation_matrix() + + @pytest.mark.parametrize("test_input,expected", [(ht, True), (hm, False)]) def test_wcs_needs_augmenting(test_input, expected): assert utils.wcs.WCS._needs_augmenting(test_input) is expected @@ -115,3 +143,25 @@ def test_get_dependent_wcs_axes(test_input, expected): ]) def test_axis_correlation_matrix(test_input, expected): assert (utils.wcs.axis_correlation_matrix(test_input) == expected).all() + + +def test_convert_between_array_and_pixel_axes(): + test_input = np.array([1, 4, -2]) + naxes = 5 + expected = np.array([3, 0, 1]) + output = utils.wcs.convert_between_array_and_pixel_axes(test_input, naxes) + assert all(output == expected) + + +def test_pixel_axis_to_world_axes(axis_correlation_matrix): + output = utils.wcs.pixel_axis_to_world_axes(0, axis_correlation_matrix) + expected = np.array([0, 1, 3]) + assert all(output == expected) + + +@pytest.mark.parametrize("test_input,expected", [('wl', 2), ('em.wl', 2)]) +def test_physical_type_to_world_axis(test_input, expected): + world_axis_physical_types = ['custom:pos.helioprojective.lon', + 'custom:pos.helioprojective.lat', 'em.wl', 'time'] + output = utils.wcs.physical_type_to_world_axis(test_input, world_axis_physical_types) + assert output == expected diff --git a/ndcube/utils/wcs.py b/ndcube/utils/wcs.py index 0575a4abf..92ddd48f4 100644 --- a/ndcube/utils/wcs.py +++ b/ndcube/utils/wcs.py @@ -441,6 +441,99 @@ def append_sequence_axis_to_wcs(wcs_object): return WCS(wcs_header) +def convert_between_array_and_pixel_axes(axis, naxes): + """Reflects axis index about center of number of axes. + + This is used to convert between array axes in numpy order and pixel axes in WCS order. + Works in both directions. + + Parameters + ---------- + axis: `numpy.ndarray` of `int` + The axis number(s) before reflection. + + naxes: `int` + The number of array axes. + + Returns + ------- + reflected_axis: `numpy.ndarray` of `int` + The axis number(s) after reflection. + """ + # Check type of input. + if not isinstance(axis, np.ndarray): + raise TypeError("input must be of array type. Got type: {type(axis)}") + if axis.dtype.char not in np.typecodes['AllInteger']: + raise TypeError("input dtype must be of int type. Got dtype: {axis.dtype})") + # Convert negative indices to positive equivalents. + axis[axis < 0] += naxes + if any(axis > naxes - 1): + raise IndexError("Axis out of range. " + f"Number of axes = {naxes}; Axis numbers requested = {axes}") + # Reflect axis about center of number of axes. + reflected_axis = naxes - 1 - axis + + return reflected_axis + + +def pixel_axis_to_world_axes(pixel_axis, axis_correlation_matrix): + """ + Retrieves the indices of the world axis physical types corresponding to a pixel axis. + + Parameters + ---------- + pixel_axis: `int` + The pixel axis index/indices for which the world axes are desired. + + axis_correlation_matrix: `numpy.ndarray` of `bool` + 2D boolean correlation matrix defining the dependence between the pixel and world axes. + Format same as `astropy.wcs.BaseLowLevelWCS.axis_correlation_matrix`. + + Returns + ------- + world_axes: `numpy.ndarray` + The world axis indices corresponding to the pixel axis. + """ + return np.arange(axis_correlation_matrix.shape[0])[axis_correlation_matrix[:, pixel_axis]] + + +def physical_type_to_world_axis(physical_type, world_axis_physical_types): + """ + Returns world axis index of a physical type based on WCS world_axis_physical_types. + + Input can be a substring of a physical type, so long as it is unique. + + Parameters + ---------- + physical_type: `str` + The physical type or a substring unique to a physical type. + + world_axis_physical_types: sequence of `str` + All available physical types. Ordering must be same as + `astropy.wcs.BaseLowLevelWCS.world_axis_physical_types` + + Returns + ------- + world_axis: `numbers.Integral` + The world axis index of the physical type. + """ + # Find world axis index described by physical type. + widx = np.where(world_axis_physical_types == physical_type)[0] + # If physical type does not correspond to entry in world_axis_physical_types, + # check if it is a substring of any physical types. + if len(widx) == 0: + widx = [physical_type in world_axis_physical_type + for world_axis_physical_type in world_axis_physical_types] + widx = np.arange(len(world_axis_physical_types))[widx] + if len(widx) != 1: + raise ValueError( + "Input does not uniquely correspond to a physical type." + f" Expected unique substring of one of {world_axis_physical_types}." + f" Got: {physical_type}") + # Return axes with duplicates removed. + return widx[0] + + def reduced_axis_correlation_matrix(axis_correlation_matrix, missing_axes, return_world_indices=False): """