Skip to content

Commit

Permalink
FITS spec I/O col lookup is case-insensitive
Browse files Browse the repository at this point in the history
again.
  • Loading branch information
pllim committed Mar 19, 2024
1 parent f968204 commit 686bada
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
3 changes: 1 addition & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
ensure that the correct ``TUNITn`` is read. As a result, ``wave_unit`` and
``flux_unit`` keywords are deprecated and no longer used in that function.
Additionally, if any ``TUNITn`` in the table is invalid, regardless whether
the column is used or not, an exception will now be raised. The inputs
for ``wave_col`` and ``flux_col`` are now case-sensitive. [#384]
the column is used or not, an exception will now be raised. [#384]

- ``read_spec()`` now detects whether given filename is FITS more consistently
w.r.t. ``astropy``. [#384]
Expand Down
12 changes: 9 additions & 3 deletions synphot/specio.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def read_fits_spec(filename, ext=1, wave_col='WAVELENGTH', flux_col='FLUX',
FITS extension with table data. Default is 1.
wave_col, flux_col : str
Wavelength and flux column names (case-sensitive).
Wavelength and flux column names (case-insensitive).
wave_unit, flux_unit : str or `~astropy.units.Unit`
Wavelength and flux units. These are *no longer used*.
Expand All @@ -185,6 +185,9 @@ def read_fits_spec(filename, ext=1, wave_col='WAVELENGTH', flux_col='FLUX',
Wavelength and flux of the spectrum.
"""
wave_col = wave_col.lower()
flux_col = flux_col.lower()

try:
fs = fits.open(filename)
subhdu = fs[ext]
Expand All @@ -199,9 +202,12 @@ def read_fits_spec(filename, ext=1, wave_col='WAVELENGTH', flux_col='FLUX',

t = QTable.read(subhdu)
header = dict(fs["PRIMARY"].header)
t_col_wave = t[wave_col]

# https://github.com/astropy/astropy/issues/16221
lower_colnames = [c.lower() for c in t.colnames]
t_col_wave = t.columns[lower_colnames.index(wave_col)]
t_col_flux = t.columns[lower_colnames.index(flux_col)]
wavelengths = t_col_wave.value * (t_col_wave.unit or u.dimensionless_unscaled) # noqa: E501
t_col_flux = t[flux_col]
fluxes = t_col_flux.value * (t_col_flux.unit or u.dimensionless_unscaled) # noqa: E501

finally:
Expand Down
16 changes: 11 additions & 5 deletions synphot/tests/test_specio.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,24 @@ def test_read_nonstandard_fits_cols_01(tmp_path):
trace = np.array([0, 0.5, 1, 0.9, 0])
coldefs = fits.ColDefs([
fits.Column(name="X", format="I", array=pix),
fits.Column(name="WAVELENGTH", format="E",
fits.Column(name="Wavelength", format="E",
unit=wav.unit.to_string(format="fits"), array=wav.value),
fits.Column(name="TRACE", format="E", array=trace)])
fits.Column(name="Trace", format="E", array=trace)])
hdulist = fits.HDUList([
fits.PrimaryHDU(),
fits.BinTableHDU.from_columns(coldefs)])
outfile = str(tmp_path / "jwst_niriss_soss_trace.fits")
hdulist.writeto(outfile, overwrite=True)

tr = SpectralElement.from_file(outfile, flux_col="TRACE")
assert_quantity_allclose(tr.waveset, wav)
assert_quantity_allclose(tr(wav), trace, atol=1e-7)
# Make sure column names are still case insensitive.
for (wave_col, flux_col) in (
("Wavelength", "Trace"),
("WAVELENGTH", "TRACE"),
("wavelength", "trace")):
tr = SpectralElement.from_file(
outfile, wave_col=wave_col, flux_col=flux_col)
assert_quantity_allclose(tr.waveset, wav)
assert_quantity_allclose(tr(wav), trace, atol=1e-7)


def test_read_nonstandard_fits_cols_02(tmp_path):
Expand Down

0 comments on commit 686bada

Please sign in to comment.