Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert WCS wrappers to FITS WCS. #649

Merged
merged 13 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/649.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Provides `~ndcube.wcs.tools.unwrap_wcs_to_fitswcs`, a function to create a `astropy.wcs.WCS` instance equivalent to a sliced and/or resampled WCS instance.
Only valid if the underlying implementation of the wrapped WCS instance is also an `astropy.wcs.WCS` instance.
2 changes: 2 additions & 0 deletions docs/reference/wcs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ wcs (`ndcube.wcs`)
.. automodapi:: ndcube.wcs

.. automodapi:: ndcube.wcs.wrappers

.. automodapi:: ndcube.wcs.tools
41 changes: 41 additions & 0 deletions ndcube/wcs/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
from astropy.time import Time
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS
from numpy.testing import assert_array_almost_equal, assert_array_equal

from ndcube.wcs.tools import unwrap_wcs_to_fitswcs
from ndcube.wcs.wrappers import ResampledLowLevelWCS


def test_unwrap_wcs_to_fitswcs():
# Build FITS-WCS and wrap it in different operations.
time_ref = Time("2000-01-01T00:00:00", scale="utc", format="fits")
header = {
"CTYPE1": "TIME", "CTYPE2": "WAVE", "CTYPE3": "HPLT-TAN", "CTYPE4": "HPLN-TAN",
"CUNIT1": "s", "CUNIT2": "Angstrom", "CUNIT3": "deg", "CUNIT4": "deg",
"CDELT1": 600, "CDELT2": 0.2, "CDELT3": 0.5, "CDELT4": 0.4,
"CRPIX1": 0, "CRPIX2": 0, "CRPIX3": 2, "CRPIX4": 2,
"CRVAL1": 0, "CRVAL2": 10, "CRVAL3": 0.5, "CRVAL4": 1,
"CNAME1": "time", "CNAME2": "wavelength", "CNAME3": "HPC lat", "CNAME4": "HPC lon",
"NAXIS1": 5, "NAXIS2": 9, "NAXIS3": 4, "NAXIS4": 4,
"DATEREF": time_ref.fits}
orig_wcs = WCS(header)
# Slice WCS
wcs1 = SlicedLowLevelWCS(orig_wcs, (0, 0, slice(None), slice(1, None))) # numpy order
# Resample WCS
wcs2 = ResampledLowLevelWCS(wcs1, [2, 3], offset=[0.5, 1]) # WCS order
# Slice WCS again
wcs3 = SlicedLowLevelWCS(wcs2, (slice(0, 2), slice(1, 2))) # numpy order
# Reconstruct fitswcs
output_wcs, dropped_data_dimensions = unwrap_wcs_to_fitswcs(wcs3)
# Assert output_wcs is correct
assert_array_equal(dropped_data_dimensions, np.array([True, True, False, False]))
assert isinstance(output_wcs, WCS)
assert output_wcs._naxis == [1, 2, 1, 1]
assert list(output_wcs.wcs.ctype) == ['TIME', 'WAVE', 'HPLT-TAN', 'HPLN-TAN']
world_values = output_wcs.array_index_to_world_values([0], [0], [0, 1], [0])
assert_array_almost_equal(world_values[0][0], np.array([2700]))
assert_array_almost_equal(world_values[1], np.array([1.04e-09, 1.10e-09]))
assert_array_almost_equal(world_values[2][0], np.array([1.26915033e-05]))
assert_array_almost_equal(world_values[3][0], np.array([0.60002173]))
187 changes: 187 additions & 0 deletions ndcube/wcs/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from numbers import Integral

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper

from ndcube.wcs.wrappers import ResampledLowLevelWCS

__all__ = ["unwrap_wcs_to_fitswcs"]


def unwrap_wcs_to_fitswcs(wcs):
"""
Create FITS-WCS equivalent to (nested) WCS wrapper object.

Underlying WCS must be FITS-WCS.
No axes are dropped from original FITS-WCS, even if sliced by an integer.
Instead, integer-sliced axes is sliced to length-1 and marked True in the
``dropped_data_axes`` output.
Currently supported wrapper classes include `astropy.wcs.wcsapi.SlicedLowLevelWCS`
and `ndcube.wcs.wrappers.ResampledLowLevelWCS`.

Parameters
----------
wcs: `~astropy.wcs.wcsapi.BaseWCSWrapper`
The WCS Wrapper object.
Base level WCS implementation must be FITS-WCS.

Returns
-------
fitswcs: `astropy.wcs.WCS`
The equivalent FITS-WCS object.
dropped_data_axes: 1-D `numpy.ndarray`
Denotes which axes must have been dropped from the data array by slicing wrappers.
Axes are in array/numpy order, reversed compared to WCS.
"""
# If wcs is already a FITS-WCS, return it.
low_level_wrapper = wcs.low_level_wcs if hasattr(wcs, "low_level_wcs") else wcs
if isinstance(low_level_wrapper, WCS):
return low_level_wrapper, np.zeros(low_level_wrapper.naxis, dtype=bool)

Check warning on line 41 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L41

Added line #L41 was not covered by tests
# Determine chain of wrappers down to the FITS-WCS.
wrapper_chain = []
while isinstance(low_level_wrapper, BaseWCSWrapper):
wrapper_chain.append(low_level_wrapper)
low_level_wrapper = low_level_wrapper._wcs
if hasattr(low_level_wrapper, "low_level_wcs"):
low_level_wrapper = low_level_wrapper.low_level_wcs
if not isinstance(low_level_wrapper, WCS):
raise TypeError(f"Base-level WCS must be type {type(WCS)}. Found: {type(low_level_wcs)}")

Check warning on line 50 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L50

Added line #L50 was not covered by tests
fitswcs = low_level_wrapper
dropped_data_axes = np.zeros(fitswcs.naxis, dtype=bool)
# Unwrap each wrapper in reverse order and edit fitswcs.
for low_level_wrapper in wrapper_chain[::-1]:
if isinstance(low_level_wrapper, SlicedLowLevelWCS):
slice_items = np.array([slice(None)] * fitswcs.naxis)
slice_items[dropped_data_axes == False] = low_level_wrapper._slices_array # numpy order
fitswcs, dda = _slice_fitswcs(fitswcs, slice_items, numpy_order=True)
dropped_data_axes[dda] = True
elif isinstance(low_level_wrapper, ResampledLowLevelWCS):
factor = np.ones(fitswcs.naxis)
offset = np.zeros(fitswcs.naxis)
kept_wcs_axes = dropped_data_axes[::-1] == False # WCS-order
factor[kept_wcs_axes] = low_level_wrapper._factor
offset[kept_wcs_axes] = low_level_wrapper._offset
fitswcs = _resample_fitswcs(fitswcs, factor, offset)
else:
raise TypeError("Unrecognized/unsupported WCS Wrapper type: {type(low_level_wrapper)}")

Check warning on line 68 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L68

Added line #L68 was not covered by tests
return fitswcs, dropped_data_axes


def _slice_fitswcs(fitswcs, slice_items, numpy_order=True, shape=None):
"""
Slice a FITS-WCS.

If an `int` is given in ``slice_items``, the corresponding axis is not dropped.
But the new 0th pixel will correspond the index given by the `int` in the
original WCS.

Parameters
----------
fitswcs: `astropy.wcs.WCS`
The FITS-WCS object to be sliced.
slice_items: iterable of `slice` objects or `int`
The slices to by applied to each axis. If an `int` is provided, the axis
is sliced to length-1, but not dropped. However, its corresponding entry
in the ``dropped_data_axes`` output is marked True.
numpy_order: `bool`
If True, slices in ``slice_items`` are in array/numpy order, which is
reversed compared to the WCS order.
shape: sequence of `int`, optional
The length of each axis. Only used if negative indices are supplied
in ``slice_items``. If not supplied, set to ``fitswcs._naxis``.
Order defined by numpy_order kwarg.

Returns
-------
sliced_wcs: `astropy.wcs.WCS`
The sliced FITS-WCS.
dropped_data_axes: 1-D `numpy.ndarray`
Denotes which axes must have been dropped from the data array by slicing wrappers.
Order of axes (numpy or WCS) is dictated by ``numpy_order`` kwarg.
"""
def negative_index_error_msg(x): return (
f"Negative indexing not supported as {x}th axis length is 0 in "
"underlying FITS-WCS. Supply axes lengths via shape kwarg.")
naxis = fitswcs.naxis
dropped_data_axes = np.zeros(naxis, dtype=bool)
# Sanitize inputs
if shape is None:
shape = fitswcs._naxis
if numpy_order:
shape = shape[::-1]
else:
if len(shape) != naxis:
raise ValueError("shape kwarg must be same length as number of pixel axes "

Check warning on line 116 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L115-L116

Added lines #L115 - L116 were not covered by tests
f"in FITS-WCS, i.e. {naxis}")
if not all(isinstance(s, Integral) for s in shape):
raise TypeError("All elements of ``shape`` must be integers. "

Check warning on line 119 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L118-L119

Added lines #L118 - L119 were not covered by tests
f"shapes types = {[type(s) for s in shape]}")
slice_items = list(slice_items)
for i, (item, len_axis) in enumerate(zip(slice_items, shape)):
if isinstance(item, Integral):
# Mark axis corresponding to int item as dropped from data array.
dropped_data_axes[i] = True
# Convert negative indices to positive equivalent.
if item < 0:
if len_axis == 0:
raise ValueError(negative_index_error_msg(i))
item = len_axis + item

Check warning on line 130 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L128-L130

Added lines #L128 - L130 were not covered by tests
# Convert int item to slice so a FITS-WCS is returned after slicing.
slice_items[i] = slice(item, item + 1)
elif isinstance(item, slice):
# Convert negative indices inside slice item to positive equivalent.
start_neg = item.start is not None and item.start < 0
stop_neg = item.stop is not None and item.stop < 0
if start_neg or stop_neg:
if len_axis == 0:
raise ValueError(negative_index_error_msg(i))
start = len_axis + item.start if start_neg else item.start
stop = len_axis + item.stop if stop_neg else item.stop
slice_items[i] = slice(start, stop, item.step)

Check warning on line 142 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L138-L142

Added lines #L138 - L142 were not covered by tests
else:
raise TypeError("All slice_items must be a slice or an int. "

Check warning on line 144 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L144

Added line #L144 was not covered by tests
f"type(slice_items[{i}]) = {type(slice_items[i])}")
# Slice WCS
sliced_wcs = fitswcs.slice(slice_items, numpy_order=numpy_order)
return sliced_wcs, dropped_data_axes


def _resample_fitswcs(fitswcs, factor, offset=0):
"""
Resample the plate scale of a FITS-WCS by a given factor.

``factor`` and ``offset`` inputs are in pixel order.

Parameters
----------
fitswcs: `astropy.wcs.WCS`
The FITS-WCS object to be resampled.
factor: 1-D array-like or scalar
The factor by which the FITS-WCS is resampled.
Must be same length as number of axes in ``fitswcs``.
If scalar, the same factor is applied to all axes.
Factors must be given in WCS-order (opposite to data axes order).
offset: 1-D array-like or scalar
The location on the initial pixel grid which corresponds to zero on the
resampled pixel grid. If scalar, the same offset is applied to all axes.
Offsets must be given in WCS-order (opposite to data axes order).

Returns
-------
resampled_wcs: `astropy.wcs.WCS`
The resampled FITS-WCS.
"""
# Sanitize inputs.
factor = np.asarray(factor)
if len(factor) != fitswcs.naxis:
raise ValueError(f"Length of factor must equal number of dimensions {fitswcs.naxis}.")

Check warning on line 179 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L179

Added line #L179 was not covered by tests
offset = np.asarray(offset)
if len(offset) != fitswcs.naxis:
raise ValueError(f"Length of offset must equal number of dimensions {fitswcs.naxis}.")

Check warning on line 182 in ndcube/wcs/tools.py

View check run for this annotation

Codecov / codecov/patch

ndcube/wcs/tools.py#L182

Added line #L182 was not covered by tests
# Scale plate scale and shift by offset.
fitswcs.wcs.cdelt *= factor
fitswcs.wcs.crpix = (fitswcs.wcs.crpix + offset) / factor
fitswcs._naxis = list(np.round(np.array(fitswcs._naxis) / factor).astype(int))
return fitswcs