Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/283.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add new method, `~ndcube.NDCube.axis_world_coord_values`, to return world coords for all pixels for all axes in WCS as quantity objects.
109 changes: 105 additions & 4 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import abc
import warnings
import textwrap
import numbers
from collections import namedtuple

import numpy as np
import astropy.nddata
import astropy.units as u

from ndcube import utils
from ndcube.ndcube_sequence import NDCubeSequence
from ndcube.utils.wcs import wcs_ivoa_mapping, reduced_correlation_matrix_and_world_physical_types
from ndcube.utils import wcs as wcs_utils
from ndcube.utils.cube import _pixel_centers_or_edges, _get_dimension_for_pixel
from ndcube.mixins import NDCubeSlicingMixin, NDCubePlotMixin

Expand Down Expand Up @@ -237,10 +239,11 @@ def world_axis_physical_types(self):
if not axis:
# Find keys in wcs_ivoa_mapping dict that represent start of CTYPE.
# Ensure CTYPE is capitalized.
keys = list(filter(lambda key: ctype[i].upper().startswith(key), wcs_ivoa_mapping))
keys = list(filter(lambda key: ctype[i].upper().startswith(key),
wcs_utils.wcs_ivoa_mapping))
# Assuming CTYPE is supported by wcs_ivoa_mapping, use its corresponding axis name.
if len(keys) == 1:
axis_name = wcs_ivoa_mapping.get(keys[0])
axis_name = wcs_utils.wcs_ivoa_mapping.get(keys[0])
# If CTYPE not supported, raise a warning and set the axis name to CTYPE.
elif len(keys) == 0:
warnings.warn("CTYPE not recognized by ndcube. "
Expand Down Expand Up @@ -270,7 +273,7 @@ def array_axis_physical_types(self):
multiple array axes, the same physical type string can appear in multiple tuples.
"""
axis_correlation_matrix, world_axis_physical_types = \
reduced_correlation_matrix_and_world_physical_types(
wcs_utils.reduced_correlation_matrix_and_world_physical_types(
self.wcs.axis_correlation_matrix, self.wcs.world_axis_physical_types,
self.missing_axes)
return [tuple(world_axis_physical_types[axis_correlation_matrix[:, i]])
Expand Down Expand Up @@ -447,6 +450,104 @@ def axis_world_coords(self, *axes, edges=False):
else:
return tuple(axes_coords)

def axis_world_coord_values(self, *axes, edges=False):
"""
Returns WCS coordinate values of all pixels for desired axes.

Parameters
----------
axes: `int` or `str`, or multiple `int` or `str`
Axis number in numpy ordering or unique substring of
`~ndcube.NDCube.wcs.world_axis_physical_types`
of axes for which real world coordinates are desired.
axes=None implies all axes will be returned.

edges: `bool`
If True, the coords at the edges of the pixels are returned
rather than the coords at the center of the pixels.
Note that there are n+1 edges for n pixels which is reflected
in the returned coords.
Default=False, i.e. pixel centers are returned.

Returns
-------
coord_values: `collections.namedtuple`
Real world coords labeled with their real world physical types
for the axes requested by the user.
Returned in same order as axis_names.

Example
-------
>>> NDCube.all_world_coords_values(('lat', 'lon')) # doctest: +SKIP
>>> NDCube.all_world_coords_values(2) # doctest: +SKIP
"""
# Create meshgrid of all pixel coordinates.
# If user, wants edges, set pixel values to pixel edges.
# Else make pixel centers.
wcs_shape = self.data.shape[::-1]
# Insert length-1 axes for missing axes.
for i in np.arange(len(self.missing_axes))[self.missing_axes]:
wcs_shape = np.insert(wcs_shape, i, 1)
if edges:
wcs_shape = tuple(np.array(wcs_shape) + 1)
pixel_inputs = np.meshgrid(*[np.arange(i) - 0.5 for i in wcs_shape],
indexing='ij', sparse=True)
else:
pixel_inputs = np.meshgrid(*[np.arange(i) for i in wcs_shape],
indexing='ij', sparse=True)

# Get world coords for all axes and all pixels.
axes_coords = list(self.wcs.pixel_to_world_values(*pixel_inputs))

# Reduce duplication across independent dimensions for each coord
# and transpose to make dimensions mimic numpy array order rather than WCS order.
# Add units to coords
for i, axis_coord in enumerate(axes_coords):
slices = np.array([slice(None)] * self.wcs.world_n_dim)
slices[np.invert(self.wcs.axis_correlation_matrix[i])] = 0
axes_coords[i] = axis_coord[tuple(slices)].T
axes_coords[i] *= u.Unit(self.wcs.world_axis_units[i])

world_axis_physical_types = self.wcs.world_axis_physical_types
# If user has supplied axes, extract only the
# world coords that correspond to those axes.
if axes:
# Convert input axes to WCS world axis indices.
world_indices = set()
for axis in axes:
if isinstance(axis, numbers.Integral):
# If axis is int, it is a numpy order array axis.
# Convert to pixel axis in WCS order.
axis = wcs_utils.convert_between_array_and_pixel_axes(
np.array([axis]), self.wcs.pixel_n_dim)[0]
# Get WCS world axis indices that correspond to the WCS pixel axis
# and add to list of indices of WCS world axes whose coords will be returned.
world_indices.update(wcs_utils.pixel_axis_to_world_axes(
axis, self.wcs.axis_correlation_matrix))
elif isinstance(axis, str):
# If axis is str, it is a physical type or substring of a physical type.
world_indices.update({wcs_utils.physical_type_to_world_axis(
axis, world_axis_physical_types)})
else:
raise TypeError(f"Unrecognized axis type: {axis, type(axis)}. "
"Must be of type (numbers.Integral, str)")
# Use inferred world axes to extract the desired coord value
# and corresponding physical types.
world_indices = np.array(list(world_indices), dtype=int)
axes_coords = np.array(axes_coords)[world_indices]
world_axis_physical_types = tuple(np.array(world_axis_physical_types)[world_indices])

# Return in array order.
# First replace characters in physical types forbidden for namedtuple identifiers.
identifiers = []
for physical_type in world_axis_physical_types[::-1]:
identifier = physical_type.replace(":", "_")
identifier = identifier.replace(".", "_")
identifier = identifier.replace("-", "__")
identifiers.append(identifier)
CoordValues = namedtuple("CoordValues", identifiers)
return CoordValues(*axes_coords[::-1])

@property
def extra_coords(self):
"""
Expand Down
6 changes: 3 additions & 3 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,6 @@ def test_array_axis_physical_types():
('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat'),
('custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat'),
('em.wl',), ('time',)]
output = cube.array_axis_physical_types
for i in range(len(expected)):
assert all([physical_type in expected[i] for physical_type in output[i]])
output = cube_disordered.array_axis_physical_types
for i, expected_i in enumerate(expected):
assert all([physical_type in expected_i for physical_type in output[i]])
50 changes: 50 additions & 0 deletions ndcube/tests/test_utils_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,34 @@
wm_reindexed_102 = utils.wcs.WCS(header=hm_reindexed_102, naxis=3)


@pytest.fixture
def axis_correlation_matrix():
return _axis_correlation_matrix()


def _axis_correlation_matrix():
shape = (4, 4)
acm = np.zeros(shape, dtype=bool)
for i in range(min(shape)):
acm[i, i] = True
acm[0, 1] = True
acm[1, 0] = True
acm[-1, 0] = True
return acm


@pytest.fixture
def test_wcs():
return TestWCS()


class TestWCS():
def __init__(self):
self.world_axis_physical_types = [
'custom:pos.helioprojective.lon', 'custom:pos.helioprojective.lat', 'em.wl', 'time']
self.axis_correlation_matrix = _axis_correlation_matrix()


@pytest.mark.parametrize("test_input,expected", [(ht, True), (hm, False)])
def test_wcs_needs_augmenting(test_input, expected):
assert utils.wcs.WCS._needs_augmenting(test_input) is expected
Expand Down Expand Up @@ -115,3 +143,25 @@ def test_get_dependent_wcs_axes(test_input, expected):
])
def test_axis_correlation_matrix(test_input, expected):
assert (utils.wcs.axis_correlation_matrix(test_input) == expected).all()


def test_convert_between_array_and_pixel_axes():
test_input = np.array([1, 4, -2])
naxes = 5
expected = np.array([3, 0, 1])
output = utils.wcs.convert_between_array_and_pixel_axes(test_input, naxes)
assert all(output == expected)


def test_pixel_axis_to_world_axes(axis_correlation_matrix):
output = utils.wcs.pixel_axis_to_world_axes(0, axis_correlation_matrix)
expected = np.array([0, 1, 3])
assert all(output == expected)


@pytest.mark.parametrize("test_input,expected", [('wl', 2), ('em.wl', 2)])
def test_physical_type_to_world_axis(test_input, expected):
world_axis_physical_types = ['custom:pos.helioprojective.lon',
'custom:pos.helioprojective.lat', 'em.wl', 'time']
output = utils.wcs.physical_type_to_world_axis(test_input, world_axis_physical_types)
assert output == expected
93 changes: 93 additions & 0 deletions ndcube/utils/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,99 @@ def append_sequence_axis_to_wcs(wcs_object):
return WCS(wcs_header)


def convert_between_array_and_pixel_axes(axis, naxes):
"""Reflects axis index about center of number of axes.

This is used to convert between array axes in numpy order and pixel axes in WCS order.
Works in both directions.

Parameters
----------
axis: `numpy.ndarray` of `int`
The axis number(s) before reflection.

naxes: `int`
The number of array axes.

Returns
-------
reflected_axis: `numpy.ndarray` of `int`
The axis number(s) after reflection.
"""
# Check type of input.
if not isinstance(axis, np.ndarray):
raise TypeError("input must be of array type. Got type: {type(axis)}")
if axis.dtype.char not in np.typecodes['AllInteger']:
raise TypeError("input dtype must be of int type. Got dtype: {axis.dtype})")
# Convert negative indices to positive equivalents.
axis[axis < 0] += naxes
if any(axis > naxes - 1):
raise IndexError("Axis out of range. "
f"Number of axes = {naxes}; Axis numbers requested = {axes}")
# Reflect axis about center of number of axes.
reflected_axis = naxes - 1 - axis

return reflected_axis


def pixel_axis_to_world_axes(pixel_axis, axis_correlation_matrix):
"""
Retrieves the indices of the world axis physical types corresponding to a pixel axis.

Parameters
----------
pixel_axis: `int`
The pixel axis index/indices for which the world axes are desired.

axis_correlation_matrix: `numpy.ndarray` of `bool`
2D boolean correlation matrix defining the dependence between the pixel and world axes.
Format same as `astropy.wcs.BaseLowLevelWCS.axis_correlation_matrix`.

Returns
-------
world_axes: `numpy.ndarray`
The world axis indices corresponding to the pixel axis.
"""
return np.arange(axis_correlation_matrix.shape[0])[axis_correlation_matrix[:, pixel_axis]]


def physical_type_to_world_axis(physical_type, world_axis_physical_types):
"""
Returns world axis index of a physical type based on WCS world_axis_physical_types.

Input can be a substring of a physical type, so long as it is unique.

Parameters
----------
physical_type: `str`
The physical type or a substring unique to a physical type.

world_axis_physical_types: sequence of `str`
All available physical types. Ordering must be same as
`astropy.wcs.BaseLowLevelWCS.world_axis_physical_types`

Returns
-------
world_axis: `numbers.Integral`
The world axis index of the physical type.
"""
# Find world axis index described by physical type.
widx = np.where(world_axis_physical_types == physical_type)[0]
# If physical type does not correspond to entry in world_axis_physical_types,
# check if it is a substring of any physical types.
if len(widx) == 0:
widx = [physical_type in world_axis_physical_type
for world_axis_physical_type in world_axis_physical_types]
widx = np.arange(len(world_axis_physical_types))[widx]
if len(widx) != 1:
raise ValueError(
"Input does not uniquely correspond to a physical type."
f" Expected unique substring of one of {world_axis_physical_types}."
f" Got: {physical_type}")
# Return axes with duplicates removed.
return widx[0]


def reduced_axis_correlation_matrix(axis_correlation_matrix, missing_axes,
return_world_indices=False):
"""
Expand Down