Skip to content

Commit

Permalink
BUG/FEAT: read_fits_spec now more flexible
Browse files Browse the repository at this point in the history
but at a cost.
  • Loading branch information
pllim committed Mar 15, 2024
1 parent c60d653 commit e8a9607
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 80 deletions.
12 changes: 11 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
1.3.1 (unreleased)
1.4.0 (unreleased)
==================

- ``read_fits_spec()`` now uses ``astropy.table.QTable.read`` for parsing to
ensure that the correct ``TUNITn`` is read. As a result, ``wave_unit`` and
``flux_unit`` keywords are deprecated and no longer used in that function.
Additionally, if any ``TUNITn`` in the table is invalid, regardless whether
the column is used or not, an exception will now be raised. The inputs
for ``wave_col`` and ``flux_col`` are now case-sensitive. [#384]

- ``read_spec()`` now detects whether given filename is FITS more consistently
w.r.t. ``astropy``. [#384]

1.3.0 (2023-11-28)
==================

Expand Down
23 changes: 11 additions & 12 deletions synphot/reddening.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# THIRD-PARTY
import numpy as np
from astropy import units as u
from astropy.io.fits.connect import is_fits

# LOCAL
from synphot import exceptions, specio, units
Expand Down Expand Up @@ -137,8 +138,8 @@ def to_fits(self, filename, wavelengths=None, **kwargs):
def from_file(cls, filename, **kwargs):
"""Create a reddening law from file.
If filename has 'fits' or 'fit' suffix, it is read as FITS.
Otherwise, it is read as ASCII.
If filename is recognized by ``astropy.io.fits`` as FITS,
it is read as such. Otherwise, it is read as ASCII.
Parameters
----------
Expand All @@ -156,13 +157,12 @@ def from_file(cls, filename, **kwargs):
Empirical reddening law.
"""
if 'flux_unit' not in kwargs:
if is_fits("", filename, None):
if 'flux_col' not in kwargs:
kwargs['flux_col'] = 'Av/E(B-V)'
elif 'flux_unit' not in kwargs:

Check warning on line 163 in synphot/reddening.py

View check run for this annotation

Codecov / codecov/patch

synphot/reddening.py#L163

Added line #L163 was not covered by tests
kwargs['flux_unit'] = cls._internal_flux_unit

if ((filename.endswith('fits') or filename.endswith('fit')) and
'flux_col' not in kwargs):
kwargs['flux_col'] = 'Av/E(B-V)'

header, wavelengths, rvs = specio.read_spec(filename, **kwargs)

return cls(Empirical1D, points=wavelengths, lookup_table=rvs,
Expand Down Expand Up @@ -217,13 +217,12 @@ def from_extinction_model(cls, modelname, **kwargs):

filename = cfgitem()

if 'flux_unit' not in kwargs:
if is_fits("", filename, None):
if 'flux_col' not in kwargs:
kwargs['flux_col'] = 'Av/E(B-V)'
elif 'flux_unit' not in kwargs:

Check warning on line 223 in synphot/reddening.py

View check run for this annotation

Codecov / codecov/patch

synphot/reddening.py#L223

Added line #L223 was not covered by tests
kwargs['flux_unit'] = cls._internal_flux_unit

if ((filename.endswith('fits') or filename.endswith('fit')) and
'flux_col' not in kwargs):
kwargs['flux_col'] = 'Av/E(B-V)'

header, wavelengths, rvs = specio.read_remote_spec(filename, **kwargs)
header['filename'] = filename
header['descrip'] = cfgitem.description
Expand Down
68 changes: 32 additions & 36 deletions synphot/specio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from astropy import log
from astropy import units as u
from astropy.io import ascii, fits
from astropy.io.fits.connect import is_fits
from astropy.table import QTable
from astropy.utils.data import get_readable_fileobj
from astropy.utils.decorators import deprecated_renamed_argument
from astropy.utils.exceptions import AstropyUserWarning

# LOCAL
Expand Down Expand Up @@ -88,7 +91,7 @@ def read_spec(filename, fname='', **kwargs):
elif not fname: # pragma: no cover
raise exceptions.SynphotError('Cannot determine filename.')

if fname.endswith('fits') or fname.endswith('fit'):
if is_fits("", fname, None):
read_func = read_fits_spec
else:
read_func = read_ascii_spec
Expand Down Expand Up @@ -143,12 +146,15 @@ def read_ascii_spec(filename, wave_unit=u.AA, flux_unit=units.FLAM, **kwargs):
return header, wavelengths, fluxes


@deprecated_renamed_argument(
["wave_unit", "flux_unit"], [None, None], ["1.4", "1.4"],
alternative='TUNITn as per FITS standards')
def read_fits_spec(filename, ext=1, wave_col='WAVELENGTH', flux_col='FLUX',
wave_unit=u.AA, flux_unit=units.FLAM):
"""Read FITS spectrum.
Wavelength and flux units are extracted from ``TUNIT1`` and ``TUNIT2``
keywords, respectively, from data table (not primary) header.
Wavelength and flux units are extracted from respective ``TUNITn``
keywords, from data table (not primary) header.
If these keywords are not present, units are taken from
``wave_unit`` and ``flux_unit`` instead.
Expand All @@ -161,12 +167,14 @@ def read_fits_spec(filename, ext=1, wave_col='WAVELENGTH', flux_col='FLUX',
FITS extension with table data. Default is 1.
wave_col, flux_col : str
Wavelength and flux column names (case-insensitive).
Wavelength and flux column names (case-sensitive).
wave_unit, flux_unit : str or `~astropy.units.Unit`
Wavelength and flux units, which default to Angstrom and FLAM,
respectively. These are *only* used if ``TUNIT1`` and ``TUNIT2``
keywords are not present in table (not primary) header.
Wavelength and flux units. These are *no longer used*.
Define your units in the respective ``TUNITn``
keywords in table (not primary) header.
.. deprecated:: 1.4
Returns
-------
Expand All @@ -179,35 +187,23 @@ def read_fits_spec(filename, ext=1, wave_col='WAVELENGTH', flux_col='FLUX',
"""
try:
fs = fits.open(filename)
header = dict(fs[str('PRIMARY')].header)
wave_dat = fs[ext].data.field(wave_col).copy()
flux_dat = fs[ext].data.field(flux_col).copy()
fits_wave_unit = fs[ext].header.get('TUNIT1')
fits_flux_unit = fs[ext].header.get('TUNIT2')

if fits_wave_unit is not None:
try:
wave_unit = units.validate_unit(fits_wave_unit)
except (exceptions.SynphotError, ValueError) as e: # pragma: no cover # noqa: E501
warnings.warn(
'{0} from FITS header is not valid wavelength unit, using '
'{1}: {2}'.format(fits_wave_unit, wave_unit, e),
AstropyUserWarning)

if fits_flux_unit is not None:
try:
flux_unit = units.validate_unit(fits_flux_unit)
except (exceptions.SynphotError, ValueError) as e: # pragma: no cover # noqa: E501
warnings.warn(
'{0} from FITS header is not valid flux unit, using '
'{1}: {2}'.format(fits_flux_unit, flux_unit, e),
AstropyUserWarning)

wave_unit = units.validate_unit(wave_unit)
flux_unit = units.validate_unit(flux_unit)

wavelengths = wave_dat * wave_unit
fluxes = flux_dat * flux_unit
subhdu = fs[ext]

# Need to fix table units
for key in subhdu.header["TUNIT*"]:
val = subhdu.header[key]
if not val:
continue
newval = units.validate_unit(val)
subhdu.header[key] = newval.to_string("fits")

t = QTable.read(subhdu)
header = dict(fs["PRIMARY"].header)
t_col_wave = t[wave_col]
wavelengths = t_col_wave.value * (t_col_wave.unit or u.dimensionless_unscaled) # noqa: E501
t_col_flux = t[flux_col]
fluxes = t_col_flux.value * (t_col_flux.unit or u.dimensionless_unscaled) # noqa: E501

finally:
if isinstance(filename, str):
fs.close()
Expand Down
23 changes: 11 additions & 12 deletions synphot/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ASTROPY
from astropy import log
from astropy import units as u
from astropy.io.fits.connect import is_fits
from astropy.modeling import Model
from astropy.modeling.core import CompoundModel
from astropy.modeling.models import RedshiftScaleFactor, Scale
Expand Down Expand Up @@ -1921,8 +1922,8 @@ def to_fits(self, filename, wavelengths=None, **kwargs):
def from_file(cls, filename, **kwargs):
"""Creates a bandpass from file.
If filename has 'fits' or 'fit' suffix, it is read as FITS.
Otherwise, it is read as ASCII.
If filename is recognized by ``astropy.io.fits`` as FITS,
it is read as such. Otherwise, it is read as ASCII.
Parameters
----------
Expand All @@ -1940,13 +1941,12 @@ def from_file(cls, filename, **kwargs):
Empirical bandpass.
"""
if 'flux_unit' not in kwargs:
if is_fits("", filename, None):
if 'flux_col' not in kwargs:
kwargs['flux_col'] = 'THROUGHPUT'
elif 'flux_unit' not in kwargs:

Check warning on line 1947 in synphot/spectrum.py

View check run for this annotation

Codecov / codecov/patch

synphot/spectrum.py#L1947

Added line #L1947 was not covered by tests
kwargs['flux_unit'] = cls._internal_flux_unit

if ((filename.endswith('fits') or filename.endswith('fit')) and
'flux_col' not in kwargs):
kwargs['flux_col'] = 'THROUGHPUT'

header, wavelengths, throughput = specio.read_spec(filename, **kwargs)
return cls(Empirical1D, points=wavelengths, lookup_table=throughput,
keep_neg=True, meta={'header': header})
Expand Down Expand Up @@ -2009,13 +2009,12 @@ def from_filter(cls, filtername, **kwargs):

filename = cfgitem()

if 'flux_unit' not in kwargs:
if is_fits("", filename, None):
if 'flux_col' not in kwargs:
kwargs['flux_col'] = 'THROUGHPUT'
elif 'flux_unit' not in kwargs:

Check warning on line 2015 in synphot/spectrum.py

View check run for this annotation

Codecov / codecov/patch

synphot/spectrum.py#L2015

Added line #L2015 was not covered by tests
kwargs['flux_unit'] = cls._internal_flux_unit

if ((filename.endswith('fits') or filename.endswith('fit')) and
'flux_col' not in kwargs):
kwargs['flux_col'] = 'THROUGHPUT'

header, wavelengths, throughput = specio.read_remote_spec(
filename, **kwargs)
header['filename'] = filename
Expand Down
2 changes: 1 addition & 1 deletion synphot/tests/test_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def setup_class(self):
get_pkg_data_filename(
os.path.join('data', 'hst_acs_hrc_f555w.fits'),
package='synphot.tests'),
flux_col='THROUGHPUT', flux_unit=u.dimensionless_unscaled)
flux_col='THROUGHPUT')

# Binned data.
bins = generate_wavelengths(
Expand Down
79 changes: 65 additions & 14 deletions synphot/tests/test_specio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

# STDLIB
import os
import shutil
import tempfile

# THIRD-PARTY
import numpy as np
Expand All @@ -15,10 +13,12 @@
from astropy.io import fits
from astropy.tests.helper import assert_quantity_allclose
from astropy.utils.data import get_pkg_data_filename
from astropy.utils.exceptions import AstropyUserWarning
from astropy.utils.exceptions import (
AstropyUserWarning, AstropyDeprecationWarning)

# LOCAL
from synphot import exceptions, specio, units
from synphot.spectrum import SpectralElement


@pytest.mark.remote_data
Expand Down Expand Up @@ -55,17 +55,16 @@ class TestReadWriteFITS:
"""Test read/write FITS spectrum."""
def setup_class(self):
self.epsilon = 0.00031
self.outdir = tempfile.mkdtemp()
self.wave = np.array([1000.0, 2000.0, 2000.0 + self.epsilon, 3000.0,
4000.0, 5000.0], dtype=np.float64) * u.AA
self.flux = np.array([0.1, 100.2, 10.0, 0.0, 6.5, 1.2],
dtype=np.float64) * units.PHOTLAM
self.prihdr = {'PEDIGREE': 'DUMMY'}
self.scihdr = {'SPEC_SRC': 'RANDOM'}

def test_array_data(self):
def test_array_data(self, tmp_path):
"""Data as Numpy array."""
outfile = os.path.join(self.outdir, 'outspec1.fits')
outfile = str(tmp_path / 'outspec1.fits')

# Write it out
with pytest.warns(AstropyUserWarning, match=r'rows are thrown out'):
Expand All @@ -76,7 +75,8 @@ def test_array_data(self):
wave_unit=self.wave.unit, flux_unit=self.flux.unit)

# Read it back in and check values (flux_unit should be ignored)
hdr, wave, flux = specio.read_spec(outfile, flux_unit='foo')
with pytest.warns(AstropyDeprecationWarning, match=r"\"flux_unit\" was deprecated"): # noqa: E501
hdr, wave, flux = specio.read_spec(outfile, flux_unit='foo')

# Compare data
np.testing.assert_allclose(
Expand All @@ -95,17 +95,18 @@ def test_array_data(self):
assert sci_hdr['SPEC_SRC'] == 'RANDOM'
assert sci_hdr['TFORM2'].lower() == 'e'

def test_quantity_data(self):
def test_quantity_data(self, tmp_path):
"""Data as Quantity."""
outfile = os.path.join(self.outdir, 'outspec2.fits')
outfile = str(tmp_path / 'outspec2.fits')

# Write it out (flux_unit should be ignored)
specio.write_fits_spec(
outfile, self.wave, self.flux, pri_header=self.prihdr,
ext_header=self.scihdr, precision='double', flux_unit='foo')

# Read it back in and check values (flux_unit should be ignored)
hdr, wave, flux = specio.read_spec(outfile, flux_unit='foo')
with pytest.warns(AstropyDeprecationWarning, match=r"\"flux_unit\" was deprecated"): # noqa: E501
hdr, wave, flux = specio.read_spec(outfile, flux_unit='foo')

# Compare data (trim_zero=True, pad_zero_ends=True)
np.testing.assert_allclose(
Expand All @@ -125,9 +126,9 @@ def test_quantity_data(self):
assert sci_hdr['SPEC_SRC'] == 'RANDOM'
assert sci_hdr['TFORM2'].lower() == 'd'

def test_exceptions(self):
def test_exceptions(self, tmp_path):
"""Test for appropriate exceptions."""
outfile = os.path.join(self.outdir, 'outspec3.fits')
outfile = str(tmp_path / 'outspec3.fits')

# Shape mismatch
with pytest.raises(exceptions.SynphotError):
Expand All @@ -149,5 +150,55 @@ def test_exceptions(self):
specio.write_fits_spec(
outfile, self.wave, np.arange(6), overwrite=True)

def teardown_class(self):
shutil.rmtree(self.outdir)

def test_read_nonstandard_fits_cols_01(tmp_path):
"""See https://github.com/spacetelescope/synphot_refactor/issues/372"""
pix = np.arange(5, dtype=int) + 1
wav = (pix * 0.1) * u.micron
trace = np.array([0, 0.5, 1, 0.9, 0])
coldefs = fits.ColDefs([
fits.Column(name="X", format="I", array=pix),
fits.Column(name="WAVELENGTH", format="E",
unit=wav.unit.to_string(format="fits"), array=wav.value),
fits.Column(name="TRACE", format="E", array=trace)])
hdulist = fits.HDUList([
fits.PrimaryHDU(),
fits.BinTableHDU.from_columns(coldefs)])
outfile = str(tmp_path / "jwst_niriss_soss_trace.fits")
hdulist.writeto(outfile, overwrite=True)

tr = SpectralElement.from_file(outfile, flux_col="TRACE")
assert_quantity_allclose(tr.waveset, wav)
assert_quantity_allclose(tr(wav), trace, atol=1e-7)


def test_read_nonstandard_fits_cols_02(tmp_path):
"""See https://github.com/spacetelescope/synphot_refactor/issues/372"""

wav = (np.arange(5) + 1) * u.nm
flux_unit_str = "ph/s/m2/micron/arcsec2" # Invalid but should not matter.
flux = np.ones(5)
thru = np.array([0, 0.5, 1, 0.9, 0])
coldefs = fits.ColDefs([
fits.Column(name="lam", format="E",
unit=wav.unit.to_string(format="fits"), array=wav.value),
fits.Column(name="flux", format="E",
unit=flux_unit_str, array=flux),
fits.Column(name="dflux1", format="E",
unit=flux_unit_str, array=flux),
fits.Column(name="dflux2", format="E",
unit=flux_unit_str, array=flux),
fits.Column(name="trans", format="E", unit="1", array=thru)])
hdulist = fits.HDUList([
fits.PrimaryHDU(),
fits.BinTableHDU.from_columns(coldefs)])
outfile = str(tmp_path / "skytable.fits")
hdulist.writeto(outfile, overwrite=True)

with pytest.warns(u.UnitsWarning, match="'ph/s/m2/micron/arcsec2'"): # noqa: E501
header, wavelengths, transmission = specio.read_spec(
outfile, wave_col="lam", flux_col="trans")

assert header["SIMPLE"]
assert_quantity_allclose(wavelengths, wav)
assert_quantity_allclose(transmission, thru)

0 comments on commit e8a9607

Please sign in to comment.