Skip to content

Commit

Permalink
add custom_catalog tests
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Apr 18, 2024
1 parent 7c1b004 commit b6f5959
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 24 deletions.
17 changes: 8 additions & 9 deletions jwst/datamodels/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(self, init=None, asn_exptypes=None, asn_n_members=None,
self.asn_table = {}
self.asn_table_name = None
self.asn_pool_name = None
self.asn_file_path = None

self._memmap = kwargs.get("memmap", False)
self._return_open = kwargs.get('return_open', True)
Expand Down Expand Up @@ -196,7 +197,8 @@ def __init__(self, init=None, asn_exptypes=None, asn_n_members=None,
self.from_asn(init)
elif isinstance(init, str):
init_from_asn = self.read_asn(init)
self.from_asn(init_from_asn, asn_file_path=init)
self.asn_file_path = init
self.from_asn(init_from_asn)
else:
raise TypeError('Input {0!r} is not a list of JwstDataModels or '
'an ASN file'.format(init))
Expand Down Expand Up @@ -275,17 +277,14 @@ def read_asn(filepath):
raise IOError("Cannot read ASN file.") from e
return asn_data

def from_asn(self, asn_data, asn_file_path=None):
def from_asn(self, asn_data):
"""
Load fits files from a JWST association file.
Parameters
----------
asn_data : ~jwst.associations.Association
An association dictionary
asn_file_path: str
Filepath of the association, if known.
"""
# match the asn_exptypes to the exptype in the association and retain
# only those files that match, as a list, if asn_exptypes is set to none
Expand All @@ -303,8 +302,8 @@ def from_asn(self, asn_data, asn_file_path=None):
infiles = [member for member
in asn_data['products'][0]['members']]

if asn_file_path:
asn_dir = op.dirname(asn_file_path)
if self.asn_file_path:
asn_dir = op.dirname(self.asn_file_path)
else:
asn_dir = ''

Expand Down Expand Up @@ -348,8 +347,8 @@ def from_asn(self, asn_data, asn_file_path=None):
self.meta.asn_table._instance, asn_data
)

if asn_file_path is not None:
self.asn_table_name = op.basename(asn_file_path)
if self.asn_file_path is not None:
self.asn_table_name = op.basename(self.asn_file_path)
self.asn_pool_name = asn_data['asn_pool']
for model in self:
try:
Expand Down
165 changes: 157 additions & 8 deletions jwst/tweakreg/tests/test_tweakreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from jwst.datamodels import ModelContainer


# Number of synthetic sources added to the image built by the example_input fixture.
N_EXAMPLE_SOURCES = 21
# Number of rows written to the catalog file by the custom_catalog_path fixture.
# It differs from N_EXAMPLE_SOURCES so a test can tell from the catalog length
# alone which catalog was used.
N_CUSTOM_SOURCES = 15


@pytest.fixture
def dummy_source_catalog():

Expand Down Expand Up @@ -119,7 +123,7 @@ def example_input(example_wcs):

# and a few 'sources'
m0.data[:] = 0.001
n_sources = 21 # a few more than default minobj
n_sources = N_EXAMPLE_SOURCES # a few more than default minobj
rng = np.random.default_rng(26)
xs = rng.choice(50, n_sources, replace=False) * 8 + 10
ys = rng.choice(50, n_sources, replace=False) * 8 + 10
Expand All @@ -132,13 +136,13 @@ def example_input(example_wcs):

m1 = m0.copy()
# give each a unique filename
m0.meta.filename = 'some_file_0'
m1.meta.filename = 'some_file_1'
m0.meta.filename = 'some_file_0.fits'
m1.meta.filename = 'some_file_1.fits'
c = ModelContainer([m0, m1])
return c


def test_run_tweakreg(example_input):
def test_tweakreg_step(example_input):
# shift 9 pixels
example_input[1].data = np.roll(example_input[1].data, 9, axis=0)

Expand All @@ -161,18 +165,163 @@ def test_abs_refcat():
pass


def test_custom_catalog():
@pytest.fixture()
def custom_catalog_path(tmp_path):
    """
    Write a custom source catalog into ``tmp_path`` and return its path.

    The catalog has ``N_CUSTOM_SOURCES`` rows with float columns ``x`` and
    ``y``. The positions are drawn with a different random seed than the one
    used in ``example_input`` so that they do not match the sources added
    there, while still following the same position pattern (and therefore the
    same input shape, wcs, etc).
    """
    generator = np.random.default_rng(42)
    # Draw x before y: both come from the same generator, so the order of
    # the two calls determines the values.
    col_x = generator.choice(50, N_CUSTOM_SOURCES, replace=False) * 8 + 10
    col_y = generator.choice(50, N_CUSTOM_SOURCES, replace=False) * 8 + 10

    # One row per source, first column x and second column y.
    positions = np.column_stack((col_x, col_y))

    catalog_file = tmp_path / "custom_catalog.ecsv"
    Table(positions, names=['x', 'y'], dtype=[float, float]).write(catalog_file)
    return catalog_file


@pytest.mark.parametrize(
    "catfile",
    ["skip", "valid", "invalid", "empty"],
    ids=["catfile_skip", "catfile_valid", "catfile_invalid", "catfile_empty"],
)
@pytest.mark.parametrize(
    "asn",
    ["skip", "valid", "empty"],
    ids=["asn_skip", "asn_valid", "asn_empty"],
)
@pytest.mark.parametrize(
    "meta",
    ["skip", "valid", "empty"],
    ids=["meta_skip", "meta_valid", "meta_empty"],
)
@pytest.mark.parametrize("custom", [True, False])
@pytest.mark.slow
def test_custom_catalog(
    custom_catalog_path, example_input, catfile, asn, meta, custom, monkeypatch
):
    """
    Check which catalog TweakRegStep passes to ``_construct_wcs_corrector``.

    Only the first model (group 'a') is ever pointed at the custom catalog;
    the second model (group 'b') is always expected to yield
    ``N_EXAMPLE_SOURCES`` sources.

    Expected behavior encoded by this test:

        if use_custom_catalogs is False, don't use a custom catalog
        if use_custom_catalogs is True...
            if catfile is defined...
                if catfile lists a catalog for the model -> use it
                    (ignore the asn table and model.meta.tweakreg_catalog)
                if catfile lists the model without a catalog -> no custom catalog
                if catfile is "invalid" -> custom catalogs are disabled
            if catfile is not defined...
                if the asn member has a tweakreg_catalog entry...
                    if the entry is non-empty -> use it
                    if the entry is "" -> no custom catalog
                if the asn member has no tweakreg_catalog entry...
                    if model.meta.tweakreg_catalog is non-empty -> use it
                    otherwise -> no custom catalog

    Parameters
    ----------
    custom_catalog_path : pathlib.Path
        Fixture: path of a catalog with ``N_CUSTOM_SOURCES`` rows.
    example_input : ModelContainer
        Fixture: two models, each with ``N_EXAMPLE_SOURCES`` sources.
    catfile : {"skip", "valid", "invalid", "empty"}
        Whether and how a ``catfile`` is passed to the step.
    asn : {"skip", "valid", "empty"}
        Whether and how ``tweakreg_catalog`` is set on the first asn member.
    meta : {"skip", "valid", "empty"}
        Whether and how ``meta.tweakreg_catalog`` is set on the first model.
    custom : bool
        Value passed as ``use_custom_catalogs``.
    monkeypatch : pytest.MonkeyPatch
        Used to replace ``tweakreg_step._construct_wcs_corrector``.
    """
    import json

    work_dir = custom_catalog_path.parent

    example_input[0].meta.group_id = 'a'
    example_input[1].meta.group_id = 'b'

    # Lowest priority source of a custom catalog: the model metadata.
    if meta == "valid":
        example_input[0].meta.tweakreg_catalog = str(custom_catalog_path)
    elif meta == "empty":
        example_input[0].meta.tweakreg_catalog = ""

    # Write out the ModelContainer and an association that references it, so
    # the step loads its input (and the association table) from disk.
    example_input.save(dir_path=str(work_dir))
    asn_data = {
        'asn_id': 'foo',
        'asn_pool': 'bar',
        'products': [
            {
                'members': [
                    {'expname': m.meta.filename, 'exptype': 'science'}
                    for m in example_input
                ],
            },
        ],
    }

    first_member = asn_data['products'][0]['members'][0]
    if asn == "empty":
        first_member['tweakreg_catalog'] = ''
    elif asn == "valid":
        # Only the file name is stored: the catalog sits next to the
        # association file.
        first_member['tweakreg_catalog'] = custom_catalog_path.name

    asn_path = work_dir / 'example_input.json'
    with open(asn_path, 'w') as f:
        json.dump(asn_data, f)

    # Highest priority source of a custom catalog: the catfile.
    if catfile != "skip":
        catfile_path = work_dir / 'catfile.txt'
        with open(catfile_path, 'w') as f:
            if catfile == "valid":
                f.write(f"{example_input[0].meta.filename} {custom_catalog_path.name}")
            elif catfile == "empty":
                # list the model without a catalog
                f.write(f"{example_input[0].meta.filename}")
            # catfile == "invalid": the file is intentionally left empty

    # Figure out how many sources to expect for the model in group 'a'.
    # The sources of a custom catalog are consulted in priority order
    # (catfile, then asn member, then model metadata); the first one that is
    # not "skip" decides, and only a "valid" entry selects the custom catalog.
    if not custom:
        expect_custom = False
    elif catfile != "skip":
        expect_custom = catfile == "valid"
    elif asn != "skip":
        expect_custom = asn == "valid"
    else:
        expect_custom = meta == "valid"

    n_sources = N_EXAMPLE_SOURCES
    n_custom_sources = N_CUSTOM_SOURCES if expect_custom else n_sources

    kwargs = {'use_custom_catalogs': custom}
    if catfile != "skip":
        kwargs["catfile"] = str(catfile_path)
    step = tweakreg_step.TweakRegStep(**kwargs)

    # Models seen so far by the patched function. This is kept in the
    # enclosing scope rather than in a mutable default argument.
    seen = []

    # patch _construct_wcs_corrector to check the correct catalog was loaded
    def patched_construct_wcs_corrector(model, catalog):
        if model.meta.group_id == 'a':
            assert len(catalog) == n_custom_sources
        elif model.meta.group_id == 'b':
            assert len(catalog) == n_sources
        seen.append(model)
        # Once both models have been checked there is no need to continue
        # with the rest of the step, so abort with a recognizable error.
        if len(seen) == 2:
            raise ValueError("done testing")
        return None

    monkeypatch.setattr(tweakreg_step, "_construct_wcs_corrector", patched_construct_wcs_corrector)

    with pytest.raises(ValueError, match="done testing"):
        step(str(asn_path))
21 changes: 14 additions & 7 deletions jwst/tweakreg/tweakreg_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,15 @@ def process(self, input):
)
use_custom_catalogs = False
# else, load from association
elif hasattr(images.meta, "asn_table"):
elif hasattr(images.meta, "asn_table") and getattr(images, "asn_file_path", None) is not None:
catdict = {}
asn_dir = path.dirname(images.asn_file_path)
for member in images.meta.asn_table.products[0].members:
if "tweakreg_catalog" in member:
catdict[member.expname] = member.tweakreg_catalog
else:
# no custom catalogs were found, so don't check
use_custom_catalogs = False
if hasattr(member, "tweakreg_catalog"):
if member.tweakreg_catalog is None or not member.tweakreg_catalog.strip():
catdict[member.expname] = None
else:
catdict[member.expname] = path.join(asn_dir, member.tweakreg_catalog)

if self.abs_refcat is not None and self.abs_refcat.strip():
align_to_abs_refcat = True
Expand Down Expand Up @@ -183,7 +184,12 @@ def process(self, input):

# Build the catalog and corrector for each input images
for (model_index, image_model) in enumerate(images):
if use_custom_catalogs and image_model.meta.filename in catdict:
# now that the model is open, check its metadata for a custom catalog
# only if it's not listed in the catdict
if use_custom_catalogs and image_model.meta.filename not in catdict:
if (image_model.meta.tweakreg_catalog is not None and image_model.meta.tweakreg_catalog.strip()):
catdict[image_model.meta.filename] = image_model.meta.tweakreg_catalog
if use_custom_catalogs and catdict.get(image_model.meta.filename, None) is not None:
# FIXME this modifies the input_model
image_model.meta.tweakreg_catalog = catdict[image_model.meta.filename]
# use user-supplied catalog:
Expand Down Expand Up @@ -580,6 +586,7 @@ def _parse_catfile(catfile):
if len(catalog) == 1:
catdict[data_model] = path.join(catfile_dir, catalog[0])
elif len(catalog) == 0:
# set this to None so its custom catalog is skipped
catdict[data_model] = None
else:
raise ValueError("'catfile' can contain at most two columns.")
Expand Down

0 comments on commit b6f5959

Please sign in to comment.