Skip to content

Commit

Permalink
adding temp gwcs slicing for asdf writes
Browse files Browse the repository at this point in the history
  • Loading branch information
havok2063 committed Feb 2, 2024
1 parent 9731e33 commit 78fcfce
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 13 deletions.
85 changes: 78 additions & 7 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""This module implements cutout functionality similar to fitscut, but for the ASDF file format."""
import copy
import pathlib
from typing import Union
from typing import Union, Tuple

import asdf
import astropy
import gwcs
import numpy as np

from astropy.coordinates import SkyCoord
from astropy.modeling import models


def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
Expand Down Expand Up @@ -56,7 +58,8 @@ def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:

def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: Union[int, float] = np.nan,
gwcsobj: gwcs.wcs.WCS = None) -> astropy.nddata.Cutout2D:
""" Get a Roman image cutout
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
Expand All @@ -79,6 +82,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
Flag to write the cutout to a file or not
fill_value: int | float, by default np.nan
The fill value for pixels outside the original image.
gwcsobj : gwcs.wcs.WCS, Optional
the original gwcs object for the full image, needed only when writing cutout as asdf file
Returns
-------
Expand All @@ -91,6 +96,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
when a wcs is not present when coords is a SkyCoord object
RuntimeError:
when the requested cutout does not overlap with the original image
ValueError:
when no gwcs object is provided when writing to an asdf file
"""

# check for correct inputs
Expand Down Expand Up @@ -122,12 +129,23 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
if write_as == '.fits':
_write_fits(cutout, outfile)
elif write_as == '.asdf':
_write_asdf(cutout, outfile)
if not gwcsobj:
raise ValueError('The original gwcs object is needed when writing to asdf file.')
_write_asdf(cutout, gwcsobj, outfile)

return cutout


def _write_fits(cutout, outfile="example_roman_cutout.fits"):
def _write_fits(cutout: astropy.nddata.Cutout2D, outfile: str = "example_roman_cutout.fits"):
""" Write cutout as FITS file
Parameters
----------
cutout : astropy.nddata.Cutout2D
the 2d cutout
outfile : str, optional
the name of the output cutout file, by default "example_roman_cutout.fits"
"""
# check if the data is a quantity and get the array data
if isinstance(cutout.data, astropy.units.Quantity):
data = cutout.data.value
Expand All @@ -137,8 +155,61 @@ def _write_fits(cutout, outfile="example_roman_cutout.fits"):
astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(relax=True), overwrite=True)


def _write_asdf(cutout, outfile="example_roman_cutout.asdf"):
tree = {'roman': {'meta': {'wcs': dict(cutout.wcs.to_header(relax=True))}, 'data': cutout.data}}
def _slice_gwcs(gwcsobj: gwcs.wcs.WCS, slices: Tuple[slice, slice]) -> gwcs.wcs.WCS:
""" Slice the original gwcs object
"Slices" the original gwcs object down to the cutout shape. This is a hack
until proper gwcs slicing is in place a la fits WCS slicing. The ``slices``
keyword input is a tuple with the x, y cutout boundaries in the original image
array, e.g. ``cutout.slices_original``. Astropy Cutout2D slices are in the form
((ymin, ymax, None), (xmin, xmax, None))
Parameters
----------
gwcsobj : gwcs.wcs.WCS
the original gwcs from the input image
slices : Tuple[slice, slice]
the cutout x, y slices as ((ymin, ymax), (xmin, xmax))
Returns
-------
gwcs.wcs.WCS
The sliced gwcs object
"""
tmp = copy.deepcopy(gwcsobj)

# get the cutout array bounds and create a new shift transform to the cutout
# add the new transform to the gwcs
xmin, xmax = slices[1].start, slices[1].stop
ymin, ymax = slices[0].start, slices[0].stop
shape = (ymax - ymin, xmax - xmin)
offsets = models.Shift(xmin, name='cutout_offset1') & models.Shift(ymin, name='cutout_offset2')
tmp.insert_transform('detector', offsets, after=True)

# modify the gwcs bounding box to the cutout shape
tmp.bounding_box = ((0, shape[0] - 1), (0, shape[1] - 1))
tmp.pixel_shape = shape[::-1]
tmp.array_shape = shape
return tmp


def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile: str = "example_roman_cutout.asdf"):
""" Write cutout as ASDF file
Parameters
----------
cutout : astropy.nddata.Cutout2D
the 2d cutout
gwcsobj : gwcs.wcs.WCS
the original gwcs object for the full image
outfile : str, optional
the name of the output cutout file, by default "example_roman_cutout.asdf"
"""
# slice the origial gwcs to the cutout
sliced_gwcs = _slice_gwcs(gwcsobj, cutout.slices_original)

# create the asdf tree
tree = {'roman': {'meta': {'wcs': sliced_gwcs, 'data': cutout.data}}}
af = asdf.AsdfFile(tree)

# Write the data to a new file
Expand Down Expand Up @@ -186,4 +257,4 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,

# create the 2d image cutout
return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
write_file=write_file, fill_value=fill_value)
write_file=write_file, fill_value=fill_value, gwcsobj=gwcsobj)
43 changes: 37 additions & 6 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from astropy.wcs.utils import pixel_to_skycoord
from gwcs import wcs
from gwcs import coordinate_frames as cf
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs


def make_wcs(xsize, ysize, ra=30., dec=45.):
Expand Down Expand Up @@ -70,8 +70,8 @@ def _make_fake(nx, ny, ra, dec, zero=False, asint=False):
def fakedata(makefake):
""" fixture to create fake data and wcs """
# set up initial parameters
nx = 100
ny = 100
nx = 1000
ny = 1000
ra = 30.
dec = 45.

Expand Down Expand Up @@ -105,7 +105,7 @@ def test_get_center_pixel(fakedata):
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(50.), np.array(50.)))
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


Expand Down Expand Up @@ -144,7 +144,7 @@ def test_get_cutout(output, fakedata, quantity):
with fits.open(output_file) as hdulist:
data = hdulist[0].data
assert data.shape == (10, 10)
assert data[5, 5] == 2525
assert data[5, 5] == 25025


def test_asdf_cutout(make_file, output):
Expand All @@ -158,7 +158,7 @@ def test_asdf_cutout(make_file, output):
with fits.open(output_file) as hdulist:
data = hdulist[0].data
assert data.shape == (10, 10)
assert data[5, 5] == 2526
assert data[5, 5] == 475476


@pytest.mark.parametrize('suffix', ['fits', 'asdf', None])
Expand All @@ -177,6 +177,16 @@ def test_write_file(make_file, suffix, output):
assert pathlib.Path(output_file).exists()


def test_fail_write_asdf(fakedata, output):
""" test we fail to write an asdf if no gwcs given """
with pytest.raises(ValueError, match='The original gwcs object is needed when writing to asdf file.'):
output_file = output('asdf')
data, gwcs = fakedata
skycoord = gwcs(25, 25, with_units=True)
wcs = WCS(gwcs.to_fits_sip())
get_cutout(data, skycoord, wcs, size=10, outfile=output_file)


def test_cutout_nofile(make_file, output):
""" test we can make a cutout with no file output """
output_file = output()
Expand Down Expand Up @@ -281,3 +291,24 @@ def test_cutout_raedge(makefake):
assert bounds[0].ra.value > 359
assert bounds[1].ra.value < 0.1


def test_slice_gwcs(fakedata):
""" test we can slice a gwcs object """
data, gwcsobj = fakedata
skycoord = gwcsobj(250, 250)
wcs = WCS(gwcsobj.to_fits_sip())

cutout = get_cutout(data, skycoord, wcs, size=50, write_file=False)

sliced = _slice_gwcs(gwcsobj, cutout.slices_original)

# check coords between slice and original gwcs
assert cutout.center_cutout == (24.5, 24.5)
assert sliced.array_shape == (50, 50)
assert sliced(*cutout.input_position_cutout) == gwcsobj(*cutout.input_position_original)
assert gwcsobj(*cutout.center_original) == sliced(*cutout.center_cutout)

# assert same sky footprint between slice and original
# gwcs footprint/bounding_box expects ((x0, x1), (y0, y1)) but cutout.bbox is in ((y0, y1), (x0, x1))
assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all()

0 comments on commit 78fcfce

Please sign in to comment.