Skip to content

Commit

Permalink
TST: Improve coverage, use tmp_path
Browse files Browse the repository at this point in the history
instead of tempfile or tmpdir
  • Loading branch information
pllim committed Mar 15, 2024
1 parent c685801 commit 7e2fbb0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 47 deletions.
2 changes: 1 addition & 1 deletion synphot/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,7 +1944,7 @@ def from_file(cls, filename, **kwargs):
if is_fits("", filename, None):
if 'flux_col' not in kwargs:
kwargs['flux_col'] = 'THROUGHPUT'
elif 'flux_unit' not in kwargs:
elif 'flux_unit' not in kwargs: # pragma: no cover
kwargs['flux_unit'] = cls._internal_flux_unit

header, wavelengths, throughput = specio.read_spec(filename, **kwargs)
Expand Down
52 changes: 22 additions & 30 deletions synphot/tests/test_reddening.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

# STDLIB
import os
import shutil
import tempfile

# THIRD-PARTY
import numpy as np
Expand Down Expand Up @@ -156,32 +154,26 @@ def test_redlaw_from_model_exception():
ReddeningLaw.from_extinction_model('foo')


class TestWriteReddeningLaw:
@pytest.mark.parametrize('ext_hdr', [None, {'foo': 'foo'}])
def test_write_reddening_law(tmp_path, ext_hdr):
"""Test ReddeningLaw ``to_fits()`` method."""
def setup_class(self):
self.outdir = tempfile.mkdtemp()
self.x = np.linspace(1000, 5000, 5)
self.y = np.linspace(1, 5, 5) * 0.1
self.redlaw = ReddeningLaw(
Empirical1D, points=self.x, lookup_table=self.y)

@pytest.mark.parametrize('ext_hdr', [None, {'foo': 'foo'}])
def test_write(self, ext_hdr):
outfile = os.path.join(self.outdir, 'outredlaw.fits')

if ext_hdr is None:
self.redlaw.to_fits(outfile, overwrite=True)
else:
self.redlaw.to_fits(outfile, overwrite=True, ext_header=ext_hdr)

# Read it back in and check
redlaw2 = ReddeningLaw.from_file(outfile)
np.testing.assert_allclose(redlaw2.waveset.value, self.x)
np.testing.assert_allclose(redlaw2(self.x).value, self.y)

if ext_hdr is not None:
hdr = fits.getheader(outfile, 1)
assert 'foo' in hdr

def teardown_class(self):
shutil.rmtree(self.outdir)
x = np.linspace(1000, 5000, 5)
y = np.linspace(1, 5, 5) * 0.1
redlaw = ReddeningLaw(
Empirical1D, points=x, lookup_table=y, meta={"expr": "ebv(test)"})

outfile = str(tmp_path / 'outredlaw.fits')

if ext_hdr is None:
redlaw.to_fits(outfile, overwrite=True)
else:
redlaw.to_fits(outfile, overwrite=True, ext_header=ext_hdr)

# Read it back in and check
redlaw2 = ReddeningLaw.from_file(outfile)
np.testing.assert_allclose(redlaw2.waveset.value, x)
np.testing.assert_allclose(redlaw2(x).value, y)

if ext_hdr is not None:
hdr = fits.getheader(outfile, 1)
assert 'foo' in hdr
13 changes: 2 additions & 11 deletions synphot/tests/test_spectrum_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
"""Test spectrum.py module and related functionalities that are not covered
by ``test_spectrum_source.py`` nor ``test_spectrum_bandpass.py``."""

# STDLIB
import os
import shutil
import tempfile

# THIRD-PARTY
import numpy as np
import pytest
Expand Down Expand Up @@ -390,7 +385,6 @@ class DummyObject:
class TestWriteSpec:
"""Test spectrum to_fits() method."""
def setup_class(self):
self.outdir = tempfile.mkdtemp()
self.sp = SourceSpectrum(
Empirical1D, points=_wave, lookup_table=_flux_photlam,
meta={'expr': 'Test source'})
Expand All @@ -404,8 +398,8 @@ def setup_class(self):
(True, {'foo': 'foo'}),
(False, None),
(False, {'foo': 'foo'})])
def test_write(self, is_sp, ext_hdr):
outfile = os.path.join(self.outdir, 'outspec.fits')
def test_write(self, tmp_path, is_sp, ext_hdr):
outfile = str(tmp_path / 'outspec.fits')

if is_sp:
sp1 = self.sp
Expand All @@ -426,6 +420,3 @@ def test_write(self, is_sp, ext_hdr):
assert 'expr' in hdr
if ext_hdr is not None:
assert 'foo' in hdr

def teardown_class(self):
shutil.rmtree(self.outdir)
12 changes: 7 additions & 5 deletions synphot/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def test_merge_same(self):
np.testing.assert_array_equal(wave, self.wave)


def test_download_bad_root(tmpdir):
def test_download_bad_root(tmp_path):
"""Test data download helper when input dir is invalid."""
ptr = tmpdir.join('bad_cdbs')
ptr.write('content')
ptr = tmp_path / 'bad_cdbs'
ptr.mkdir()
f = ptr / "content"
f.write_text("something")
cdbs_root = str(ptr)

with pytest.raises(OSError):
Expand All @@ -129,12 +131,12 @@ def test_download_bad_root(tmpdir):
utils.download_data('', verbose=False)


def test_download_data(tmpdir):
def test_download_data(tmp_path):
"""Test data download helper in dry run mode."""
from synphot.config import conf

# Use case where user downloads all data into new dir.
cdbs_root = os.path.join(tmpdir.strpath, 'cdbs')
cdbs_root = str(tmp_path / 'cdbs')
file_list_1 = utils.download_data(cdbs_root, verbose=False, dry_run=True)
filename = file_list_1[0]
assert len(file_list_1) == 21
Expand Down

0 comments on commit 7e2fbb0

Please sign in to comment.