Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds option for cutout output to asdf file #116

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 103 additions & 4 deletions astrocut/asdf_cutouts.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""This module implements cutout functionality similar to fitscut, but for the ASDF file format."""
from typing import Union
import copy
import pathlib
from typing import Union, Tuple

import asdf
import astropy
import gwcs
import numpy as np

from astropy.coordinates import SkyCoord
from astropy.modeling import models


def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
Expand Down Expand Up @@ -55,7 +58,8 @@ def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:

def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
write_file: bool = True, fill_value: Union[int, float] = np.nan,
gwcsobj: gwcs.wcs.WCS = None) -> astropy.nddata.Cutout2D:
""" Get a Roman image cutout

Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
Expand All @@ -78,6 +82,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
Flag to write the cutout to a file or not
fill_value: int | float, by default np.nan
The fill value for pixels outside the original image.
gwcsobj : gwcs.wcs.WCS, Optional
the original gwcs object for the full image, needed only when writing cutout as asdf file

Returns
-------
Expand All @@ -90,6 +96,8 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk
when a wcs is not present when coords is a SkyCoord object
RuntimeError:
when the requested cutout does not overlap with the original image
ValueError:
when no gwcs object is provided when writing to an asdf file
"""

# check for correct inputs
Expand All @@ -112,11 +120,102 @@ def get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, Sk

# write the cutout to the output file
if write_file:
astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(), overwrite=True)
# check the output file type
out = pathlib.Path(outfile)
write_as = out.suffix or '.fits'
outfile = outfile if out.suffix else str(out) + write_as

# write out the file
if write_as == '.fits':
_write_fits(cutout, outfile)
elif write_as == '.asdf':
if not gwcsobj:
raise ValueError('The original gwcs object is needed when writing to asdf file.')
_write_asdf(cutout, gwcsobj, outfile)

return cutout


def _write_fits(cutout: astropy.nddata.Cutout2D, outfile: str = "example_roman_cutout.fits"):
""" Write cutout as FITS file

Parameters
----------
cutout : astropy.nddata.Cutout2D
the 2d cutout
outfile : str, optional
the name of the output cutout file, by default "example_roman_cutout.fits"
"""
# check if the data is a quantity and get the array data
if isinstance(cutout.data, astropy.units.Quantity):
data = cutout.data.value
else:
data = cutout.data

astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(relax=True), overwrite=True)


def _slice_gwcs(gwcsobj: gwcs.wcs.WCS, slices: Tuple[slice, slice]) -> gwcs.wcs.WCS:
""" Slice the original gwcs object

"Slices" the original gwcs object down to the cutout shape. This is a hack
until proper gwcs slicing is in place a la fits WCS slicing. The ``slices``
keyword input is a tuple with the x, y cutout boundaries in the original image
array, e.g. ``cutout.slices_original``. Astropy Cutout2D slices are in the form
((ymin, ymax, None), (xmin, xmax, None))

Parameters
----------
gwcsobj : gwcs.wcs.WCS
the original gwcs from the input image
slices : Tuple[slice, slice]
the cutout x, y slices as ((ymin, ymax), (xmin, xmax))

Returns
-------
gwcs.wcs.WCS
The sliced gwcs object
"""
tmp = copy.deepcopy(gwcsobj)

# get the cutout array bounds and create a new shift transform to the cutout
# add the new transform to the gwcs
xmin, xmax = slices[1].start, slices[1].stop
ymin, ymax = slices[0].start, slices[0].stop
shape = (ymax - ymin, xmax - xmin)
offsets = models.Shift(xmin, name='cutout_offset1') & models.Shift(ymin, name='cutout_offset2')
tmp.insert_transform('detector', offsets, after=True)

# modify the gwcs bounding box to the cutout shape
tmp.bounding_box = ((0, shape[0] - 1), (0, shape[1] - 1))
tmp.pixel_shape = shape[::-1]
tmp.array_shape = shape
return tmp


def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile: str = "example_roman_cutout.asdf"):
""" Write cutout as ASDF file

Parameters
----------
cutout : astropy.nddata.Cutout2D
the 2d cutout
gwcsobj : gwcs.wcs.WCS
the original gwcs object for the full image
outfile : str, optional
the name of the output cutout file, by default "example_roman_cutout.asdf"
"""
# slice the origial gwcs to the cutout
sliced_gwcs = _slice_gwcs(gwcsobj, cutout.slices_original)

# create the asdf tree
tree = {'roman': {'meta': {'wcs': sliced_gwcs}, 'data': cutout.data}}
af = asdf.AsdfFile(tree)

# Write the data to a new file
af.write_to(outfile)


def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,
output_file: str = "example_roman_cutout.fits",
write_file: bool = True, fill_value: Union[int, float] = np.nan) -> astropy.nddata.Cutout2D:
Expand Down Expand Up @@ -158,4 +257,4 @@ def asdf_cut(input_file: str, ra: float, dec: float, cutout_size: int = 20,

# create the 2d image cutout
return get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
write_file=write_file, fill_value=fill_value)
write_file=write_file, fill_value=fill_value, gwcsobj=gwcsobj)
102 changes: 87 additions & 15 deletions astrocut/tests/test_asdf_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from astropy.wcs.utils import pixel_to_skycoord
from gwcs import wcs
from gwcs import coordinate_frames as cf
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut
from astrocut.asdf_cutouts import get_center_pixel, get_cutout, asdf_cut, _slice_gwcs


def make_wcs(xsize, ysize, ra=30., dec=45.):
Expand Down Expand Up @@ -70,8 +70,8 @@ def _make_fake(nx, ny, ra, dec, zero=False, asint=False):
def fakedata(makefake):
""" fixture to create fake data and wcs """
# set up initial parameters
nx = 100
ny = 100
nx = 1000
ny = 1000
ra = 30.
dec = 45.

Expand Down Expand Up @@ -105,23 +105,26 @@ def test_get_center_pixel(fakedata):
__, gwcs = fakedata

pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.)
assert np.allclose(pixel_coordinates, (np.array(50.), np.array(50.)))
assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.)))
assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.]))


@pytest.fixture()
def output_file(tmp_path):
def output(tmp_path):
""" fixture to create the output path """
# create output fits path
out = tmp_path / "roman"
out.mkdir(exist_ok=True, parents=True)
output_file = out / "test_output_cutout.fits"
yield output_file
def _output_file(ext='fits'):
# create output fits path
out = tmp_path / "roman"
out.mkdir(exist_ok=True, parents=True)
output_file = out / f"test_output_cutout.{ext}" if ext else "test_output_cutout"
return output_file
yield _output_file


@pytest.mark.parametrize('quantity', [True, False], ids=['quantity', 'array'])
def test_get_cutout(output_file, fakedata, quantity):
def test_get_cutout(output, fakedata, quantity):
""" test we can create a cutout """
output_file = output('fits')

# get the input wcs
data, gwcs = fakedata
Expand All @@ -141,11 +144,12 @@ def test_get_cutout(output_file, fakedata, quantity):
with fits.open(output_file) as hdulist:
data = hdulist[0].data
assert data.shape == (10, 10)
assert data[5, 5] == 2525
assert data[5, 5] == 25025


def test_asdf_cutout(make_file, output_file):
def test_asdf_cutout(make_file, output):
""" test we can make a cutout """
output_file = output('fits')
# make cutout
ra, dec = (29.99901792, 44.99930555)
asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file)
Expand All @@ -154,11 +158,38 @@ def test_asdf_cutout(make_file, output_file):
with fits.open(output_file) as hdulist:
data = hdulist[0].data
assert data.shape == (10, 10)
assert data[5, 5] == 2526
assert data[5, 5] == 475476


def test_cutout_nofile(make_file, output_file):
@pytest.mark.parametrize('suffix', ['fits', 'asdf', None])
def test_write_file(make_file, suffix, output):
""" test we can write an different file types """
output_file = output(suffix)

# make cutout
ra, dec = (29.99901792, 44.99930555)
asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file)

# if no suffix provided, check that the default output is fits
if not suffix:
output_file += ".fits"

assert pathlib.Path(output_file).exists()


def test_fail_write_asdf(fakedata, output):
""" test we fail to write an asdf if no gwcs given """
with pytest.raises(ValueError, match='The original gwcs object is needed when writing to asdf file.'):
output_file = output('asdf')
data, gwcs = fakedata
skycoord = gwcs(25, 25, with_units=True)
wcs = WCS(gwcs.to_fits_sip())
get_cutout(data, skycoord, wcs, size=10, outfile=output_file)


def test_cutout_nofile(make_file, output):
""" test we can make a cutout with no file output """
output_file = output()
# make cutout
ra, dec = (29.99901792, 44.99930555)
cutout = asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file, write_file=False)
Expand Down Expand Up @@ -237,6 +268,47 @@ def test_bad_fill(makefake):
get_cutout(data, cc, wcs, size=50, write_file=False)


def test_cutout_raedge(makefake):
""" test we can make cutouts around ra=0 """
# make fake zero data around the ra edge
ra, dec = 0.0, 10.0
data, gg = makefake(2000, 2000, ra, dec, zero=True)

# check central pixel is correct
ss = gg(1001, 1001)
assert pytest.approx(ss, abs=1e-3) == (ra, dec)

# set input cutout coord
cc = coord.SkyCoord(0.001, 9.999, unit=u.degree)
wcs = WCS(gg.to_fits_sip())

# get cutout
cutout = get_cutout(data, cc, wcs, size=100, write_file=False)
assert_same_coord(5, 10, cutout, wcs)

# assert the RA cutout bounds are > 359 and < 0
bounds = gg(*cutout.bbox_original, with_units=True)
assert bounds[0].ra.value > 359
assert bounds[1].ra.value < 0.1


def test_slice_gwcs(fakedata):
""" test we can slice a gwcs object """
data, gwcsobj = fakedata
skycoord = gwcsobj(250, 250)
wcs = WCS(gwcsobj.to_fits_sip())

cutout = get_cutout(data, skycoord, wcs, size=50, write_file=False)

sliced = _slice_gwcs(gwcsobj, cutout.slices_original)

# check coords between slice and original gwcs
assert cutout.center_cutout == (24.5, 24.5)
assert sliced.array_shape == (50, 50)
assert sliced(*cutout.input_position_cutout) == gwcsobj(*cutout.input_position_original)
assert gwcsobj(*cutout.center_original) == sliced(*cutout.center_cutout)

# assert same sky footprint between slice and original
# gwcs footprint/bounding_box expects ((x0, x1), (y0, y1)) but cutout.bbox is in ((y0, y1), (x0, x1))
assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all()