diff --git a/CHANGES.rst b/CHANGES.rst
index 706a0fb0897..a53c0b1e315 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -103,6 +103,8 @@ outlier_detection
   to detect outliers in TSO and coronagraphic data, with user-defined
   rolling window width via the ``n_ints`` parameter. [#8473]
 
+- Added tests for changes made in #8464. [#8481]
+
 photom
 ------
 
@@ -151,11 +153,16 @@ tweakreg
 
 - Improved how an image group name is determined. [#8426]
 
+- Refactor step to work towards performance improvements. [#8424]
+
 - Changed default settings for ``abs_separation`` parameter for the ``tweakreg``
   step to have a value compatible with the ``abs_tolerance`` parameter. [#8445]
 
 - Improve error handling in the absolute alignment. [#8450, #8477]
 
+- Change code default to use IRAF StarFinder instead of
+  DAO StarFinder. [#8487]
+
 wfss_contam
 -----------
 
diff --git a/docs/jwst/tweakreg/README.rst b/docs/jwst/tweakreg/README.rst
index 30e39ba0b51..5e707dbbb62 100644
--- a/docs/jwst/tweakreg/README.rst
+++ b/docs/jwst/tweakreg/README.rst
@@ -218,7 +218,7 @@ The ``tweakreg`` step has the following optional arguments:
   in pixels. (Default=400)
 
 * ``starfinder``: A `str` indicating the source detection algorithm to use.
-  Allowed values: `'iraf'`, `'dao'`, `'segmentation'`. (Default= `'dao'`)
+  Allowed values: `'iraf'`, `'dao'`, `'segmentation'`. (Default= `'iraf'`)
 
 * ``snr_threshold``: A `float` value indicating SNR threshold above the
   background. Required for all star finders. (Default=10.0)
 
diff --git a/jwst/assign_wcs/util.py b/jwst/assign_wcs/util.py
index 50c8ba31d3e..f9f90918b4d 100644
--- a/jwst/assign_wcs/util.py
+++ b/jwst/assign_wcs/util.py
@@ -199,7 +199,7 @@ def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) ->
 
 def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None,
                         pscale_ratio=None, pscale=None, rotation=None,
-                        shape=None, crpix=None, crval=None):
+                        shape=None, crpix=None, crval=None, wcslist=None):
     """
     Create a WCS from a list of input data models.
@@ -259,7 +259,8 @@ def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=Non
     """
     bb = bounding_box
-    wcslist = [im.meta.wcs for im in dmodels]
+    if wcslist is None:
+        wcslist = [im.meta.wcs for im in dmodels]
 
     if not isiterable(wcslist):
         raise ValueError("Expected 'wcslist' to be an iterable of WCS objects.")
diff --git a/jwst/coron/tests/test_coron.py b/jwst/coron/tests/test_coron.py
index 0d82220341f..377486d97fc 100644
--- a/jwst/coron/tests/test_coron.py
+++ b/jwst/coron/tests/test_coron.py
@@ -141,8 +141,8 @@ def test_align_array():
         ]
     )
 
-    npt.assert_allclose(aligned, truth_aligned, atol=1e-6)
-    npt.assert_allclose(shifts, truth_shifts, atol=1e-6)
+    npt.assert_allclose(aligned, truth_aligned, atol=1e-5)
+    npt.assert_allclose(shifts, truth_shifts, atol=1e-5)
 
 
 def test_align_models():
diff --git a/jwst/datamodels/container.py b/jwst/datamodels/container.py
index 61521562249..45440bcb170 100644
--- a/jwst/datamodels/container.py
+++ b/jwst/datamodels/container.py
@@ -160,6 +160,7 @@ def __init__(self, init=None, asn_exptypes=None, asn_n_members=None,
         self.asn_table = {}
         self.asn_table_name = None
         self.asn_pool_name = None
+        self.asn_file_path = None
 
         self._memmap = kwargs.get("memmap", False)
         self._return_open = kwargs.get('return_open', True)
@@ -196,7 +197,8 @@ def __init__(self, init=None, asn_exptypes=None, asn_n_members=None,
                 self.from_asn(init)
         elif isinstance(init, str):
             init_from_asn = self.read_asn(init)
-            self.from_asn(init_from_asn, asn_file_path=init)
+            self.asn_file_path = init
+            self.from_asn(init_from_asn)
         else:
             raise TypeError('Input {0!r} is not a list of JwstDataModels or '
                             'an ASN file'.format(init))
@@ -275,7 +277,7 @@ def read_asn(filepath):
             raise IOError("Cannot read ASN file.") from e
         return asn_data
 
-    def from_asn(self, asn_data, asn_file_path=None):
+    def from_asn(self, asn_data):
         """
         Load fits files from a JWST association file.
 
@@ -283,9 +285,6 @@ def from_asn(self, asn_data, asn_file_path=None):
         ----------
         asn_data : ~jwst.associations.Association
             An association dictionary
-
-        asn_file_path: str
-            Filepath of the association, if known.
""" # match the asn_exptypes to the exptype in the association and retain # only those file that match, as a list, if asn_exptypes is set to none @@ -303,8 +302,8 @@ def from_asn(self, asn_data, asn_file_path=None): infiles = [member for member in asn_data['products'][0]['members']] - if asn_file_path: - asn_dir = op.dirname(asn_file_path) + if self.asn_file_path: + asn_dir = op.dirname(self.asn_file_path) else: asn_dir = '' @@ -348,8 +347,8 @@ def from_asn(self, asn_data, asn_file_path=None): self.meta.asn_table._instance, asn_data ) - if asn_file_path is not None: - self.asn_table_name = op.basename(asn_file_path) + if self.asn_file_path is not None: + self.asn_table_name = op.basename(self.asn_file_path) self.asn_pool_name = asn_data['asn_pool'] for model in self: try: diff --git a/jwst/outlier_detection/outlier_detection.py b/jwst/outlier_detection/outlier_detection.py index 576a35d6491..a4e2d2c3f7a 100644 --- a/jwst/outlier_detection/outlier_detection.py +++ b/jwst/outlier_detection/outlier_detection.py @@ -387,7 +387,7 @@ def detect_outliers(self, blot_models): def _remove_file(fn): if isinstance(fn, str) and os.path.isfile(fn): os.remove(fn) - log.debug(f" {fn}") + log.info(f"Removing file {fn}") def flag_cr(sci_image, blot_image, snr="5.0 4.0", scale="1.2 0.7", backg=0, diff --git a/jwst/outlier_detection/outlier_detection_spec.py b/jwst/outlier_detection/outlier_detection_spec.py index 03f32ac1eaf..d4f4b799c5c 100644 --- a/jwst/outlier_detection/outlier_detection_spec.py +++ b/jwst/outlier_detection/outlier_detection_spec.py @@ -110,6 +110,7 @@ def do_detection(self): if not pars['in_memory']: for fn in drizzled_models._models: _remove_file(fn) + log.info(f"Removing file {fn}") if pars['resample_data'] is True: # Blot the median image back to recreate each input image specified @@ -130,4 +131,5 @@ def do_detection(self): if not pars['save_intermediate_results']: for fn in blot_models._models: _remove_file(fn) + log.info(f"Removing file {fn}") del median_model, blot_models diff --git a/jwst/outlier_detection/outlier_detection_step.py b/jwst/outlier_detection/outlier_detection_step.py index a86d8f64112..980e72b79f5 100644 --- a/jwst/outlier_detection/outlier_detection_step.py +++ b/jwst/outlier_detection/outlier_detection_step.py @@ -174,8 +174,6 @@ def process(self, input_data): state = 'COMPLETE' if self.input_container: - if not self.save_intermediate_results: - self.log.debug("The following files will be deleted since save_intermediate_results=False:") for model in self.input_models: model.meta.cal_step.outlier_detection = state else: diff --git a/jwst/outlier_detection/tests/test_outlier_detection.py b/jwst/outlier_detection/tests/test_outlier_detection.py index d77c6cf4778..50655240936 100644 --- a/jwst/outlier_detection/tests/test_outlier_detection.py +++ b/jwst/outlier_detection/tests/test_outlier_detection.py @@ -1,6 +1,8 @@ import pytest import numpy as np from scipy.ndimage import gaussian_filter +from glob import glob +import os from stdatamodels.jwst import datamodels @@ -15,7 +17,6 @@ ) from jwst.assign_wcs.pointing import create_fitswcs - OUTLIER_DO_NOT_USE = np.bitwise_or( datamodels.dqflags.pixel["DO_NOT_USE"], datamodels.dqflags.pixel["OUTLIER"] ) @@ -184,6 +185,15 @@ def test_outlier_step(we_three_sci, tmp_cwd): # Drop a CR on the science array container[0].data[12, 12] += 1 + # Verify that intermediary files are removed + OutlierDetectionStep.call(container) + i2d_files = glob(os.path.join(tmp_cwd, '*i2d.fits')) + median_files = glob(os.path.join(tmp_cwd, 
'*median.fits')) + blot_files = glob(os.path.join(tmp_cwd, '*blot.fits')) + assert len(i2d_files) == 0 + assert len(median_files) == 0 + assert len(blot_files) == 0 + result = OutlierDetectionStep.call( container, save_results=True, save_intermediate_results=True ) @@ -199,6 +209,14 @@ def test_outlier_step(we_three_sci, tmp_cwd): # Verify CR is flagged assert result[0].dq[12, 12] == OUTLIER_DO_NOT_USE + # Verify that intermediary files are saved at the specified location + i2d_files = glob(os.path.join(tmp_cwd, '*i2d.fits')) + median_files = glob(os.path.join(tmp_cwd, '*median.fits')) + blot_files = glob(os.path.join(tmp_cwd, '*blot.fits')) + assert len(i2d_files) != 0 + assert len(median_files) != 0 + assert len(blot_files) != 0 + def test_outlier_step_on_disk(we_three_sci, tmp_cwd): """Test whole step with an outlier including saving intermediate and results files""" diff --git a/jwst/regtest/test_miri_lrs_slit_spec3.py b/jwst/regtest/test_miri_lrs_slit_spec3.py index 513cc5740b0..d8766db97e1 100644 --- a/jwst/regtest/test_miri_lrs_slit_spec3.py +++ b/jwst/regtest/test_miri_lrs_slit_spec3.py @@ -79,7 +79,7 @@ def test_miri_lrs_slit_spec3(run_pipeline, rtdata_module, fitsdiff_default_kwarg diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs) assert diff.identical, diff.report() - if "s2d" in output: + if output == "s2d": # Compare the calculated wavelengths tolerance = 1e-03 dmt = datamodels.open(rtdata.truth) diff --git a/jwst/regtest/test_niriss_image.py b/jwst/regtest/test_niriss_image.py index eb2a0ad70af..d117b3fd104 100644 --- a/jwst/regtest/test_niriss_image.py +++ b/jwst/regtest/test_niriss_image.py @@ -54,17 +54,26 @@ def test_niriss_tweakreg_no_sources(rtdata, fitsdiff_default_kwargs): rtdata.input = "niriss/imaging/jw01537-o003_20240406t164421_image3_00004_asn.json" rtdata.get_asn("niriss/imaging/jw01537-o003_20240406t164421_image3_00004_asn.json") - args = ["jwst.tweakreg.TweakRegStep", rtdata.input, "--abs_refcat='GAIADR3'"] + args = [ + "jwst.tweakreg.TweakRegStep", + rtdata.input, + "--abs_refcat='GAIADR3'", + "--save_results=True", + ] result = Step.from_cmdline(args) # Check that the step is skipped assert result.skip # Check the status of the step is set correctly in the files. 
-    result = TweakRegStep.call(rtdata.input)
+    mc = datamodels.ModelContainer(rtdata.input)
 
-    for fi in result._models:
-        with datamodels.open(fi) as model:
-            assert model.meta.cal_step.tweakreg == 'SKIPPED'
+    for model in mc:
+        assert model.meta.cal_step.tweakreg != 'SKIPPED'
+
+    result = TweakRegStep.call(mc)
+
+    for model in result:
+        assert model.meta.cal_step.tweakreg == 'SKIPPED'
 
     result.close()
diff --git a/jwst/regtest/test_nirspec_fs_spec3.py b/jwst/regtest/test_nirspec_fs_spec3.py
index ac15e6a1fd1..869d6c39499 100644
--- a/jwst/regtest/test_nirspec_fs_spec3.py
+++ b/jwst/regtest/test_nirspec_fs_spec3.py
@@ -46,7 +46,7 @@ def test_nirspec_fs_spec3(run_pipeline, rtdata_module, fitsdiff_default_kwargs,
     diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs)
     assert diff.identical, diff.report()
 
-    if "s2d" in output:
+    if output == "s2d":
         # Compare the calculated wavelengths
         tolerance = 1e-03
         dmt = datamodels.open(rtdata.truth)
diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py
index 1884e39864f..70641182f55 100644
--- a/jwst/resample/resample.py
+++ b/jwst/resample/resample.py
@@ -72,7 +72,7 @@ def __init__(self, input_models, output=None, single=False, blendheaders=True,
         self.input_models = input_models
         self.output_dir = None
         self.output_filename = output
-        if output is not None and '.fits' not in output:
+        if output is not None and '.fits' not in str(output):
             self.output_dir = output
             self.output_filename = None
diff --git a/jwst/resample/resample_spec.py b/jwst/resample/resample_spec.py
index 162ec403b47..fcd49a93029 100644
--- a/jwst/resample/resample_spec.py
+++ b/jwst/resample/resample_spec.py
@@ -64,7 +64,7 @@ def __init__(self, input_models, output=None, single=False, blendheaders=False,
 
         self.output_filename = output
         self.output_dir = None
-        if output is not None and '.fits' not in output:
+        if output is not None and '.fits' not in str(output):
             self.output_dir = output
             self.output_filename = None
         self.pscale_ratio = pscale_ratio
diff --git a/jwst/tweakreg/tests/test_tweakreg.py b/jwst/tweakreg/tests/test_tweakreg.py
index cb5baea3f91..c6f6e2b5968 100644
--- a/jwst/tweakreg/tests/test_tweakreg.py
+++ b/jwst/tweakreg/tests/test_tweakreg.py
@@ -1,14 +1,23 @@
 from copy import deepcopy
+import json
 import os
 
 import asdf
 from astropy.modeling.models import Shift
 from astropy.table import Table
+import numpy as np
 import pytest
 
 from jwst.tweakreg import tweakreg_step
 from jwst.tweakreg import tweakreg_catalog
+from jwst.tweakreg.utils import _wcsinfo_from_wcs_transform
 from stdatamodels.jwst.datamodels import ImageModel
+from jwst.datamodels import ModelContainer
+
+
+BKG_LEVEL = 0.001
+N_EXAMPLE_SOURCES = 21
+N_CUSTOM_SOURCES = 15
 
 
 @pytest.fixture
@@ -21,8 +30,55 @@ def dummy_source_catalog():
     return catalog
 
 
+@pytest.mark.parametrize("inplace", [True, False])
+def test_rename_catalog_columns(dummy_source_catalog, inplace):
+    """
+    Test that a catalog with 'xcentroid' and 'ycentroid' columns
+    passed to _rename_catalog_columns successfully renames those columns
+    to 'x' and 'y' (and does so "inplace", modifying the input catalog)
+    """
+    renamed_catalog = tweakreg_step._rename_catalog_columns(dummy_source_catalog)
+
+    # if testing inplace, check the input catalog
+    if inplace:
+        catalog = dummy_source_catalog
+    else:
+        catalog = renamed_catalog
+
+    assert 'xcentroid' not in catalog.colnames
+    assert 'ycentroid' not in catalog.colnames
+    assert 'x' in catalog.colnames
+    assert 'y' in catalog.colnames
+
+
+@pytest.mark.parametrize("missing", ["x", "y", "xcentroid", "ycentroid"])
+def test_rename_catalog_columns_invalid(dummy_source_catalog, missing):
+    """
+    Test that passing a catalog that is missing either "x" or "y"
+    (or "xcentroid" and "ycentroid", which are renamed to "x" or "y")
+    results in an exception indicating that a required column is missing
+    """
+    # if the column we want to remove is not in the table, first run
+    # rename to rename the columns; this should add the column we want to remove
+    if missing not in dummy_source_catalog.colnames:
+        tweakreg_step._rename_catalog_columns(dummy_source_catalog)
+    dummy_source_catalog.remove_column(missing)
+    with pytest.raises(ValueError, match="catalogs must contain"):
+        tweakreg_step._rename_catalog_columns(dummy_source_catalog)
+
+
 @pytest.mark.parametrize("offset, is_good", [(1 / 3600, True), (11 / 3600, False)])
 def test_is_wcs_correction_small(offset, is_good):
+    """
+    Test that the _is_wcs_correction_small method returns True for a small
+    wcs correction and False for a "large" wcs correction. The values in this
+    test are selected based on the current step default parameters:
+        - use2dhist
+        - searchrad
+        - tolerance
+    Changes to the defaults for these parameters will likely require updating the
+    values used for parametrizing this test.
+    """
     path = os.path.join(os.path.dirname(__file__), "mosaic_long_i2d_gwcs.asdf")
     with asdf.open(path) as af:
         wcs = af.tree["wcs"]
@@ -35,7 +91,18 @@ def test_is_wcs_correction_small(offset, is_good):
 
     step = tweakreg_step.TweakRegStep()
 
-    assert step._is_wcs_correction_small(wcs, twcs) == is_good
+    class FakeCorrector:
+        def __init__(self, wcs, original_skycoord):
+            self.wcs = wcs
+            self._original_skycoord = original_skycoord
+
+        @property
+        def meta(self):
+            return {'original_skycoord': self._original_skycoord}
+
+    correctors = [FakeCorrector(twcs, tweakreg_step._wcs_to_skycoord(wcs))]
+
+    assert step._is_wcs_correction_small(correctors) == is_good
 
 
 def test_expected_failure_bad_starfinder():
@@ -51,11 +118,206 @@ def test_write_catalog(dummy_source_catalog, tmp_cwd):
     '''
     OUTDIR = 'outdir'
-    model = ImageModel()
     step = tweakreg_step.TweakRegStep()
     os.mkdir(OUTDIR)
     step.output_dir = OUTDIR
     expected_outfile = os.path.join(OUTDIR, 'catalog.ecsv')
-    step._write_catalog(model, dummy_source_catalog, 'catalog.ecsv')
+    step._write_catalog(dummy_source_catalog, 'catalog.ecsv')
+
+    assert os.path.exists(expected_outfile)
+
+
+@pytest.fixture()
+def example_wcs():
+    path = os.path.join(
+        os.path.dirname(__file__),
+        "data",
+        "nrcb1-wcs.asdf")
+    with asdf.open(path, lazy_load=False) as af:
+        return af.tree["wcs"]
+
+
+@pytest.fixture()
+def example_input(example_wcs):
+    m0 = ImageModel((512, 512))
+
+    # add a wcs and wcsinfo
+    m0.meta.wcs = example_wcs
+    m0.meta.wcsinfo = _wcsinfo_from_wcs_transform(example_wcs)
+
+    # and a few 'sources'
+    m0.data[:] = BKG_LEVEL
+    n_sources = N_EXAMPLE_SOURCES  # a few more than default minobj
+    rng = np.random.default_rng(26)
+    xs = rng.choice(50, n_sources, replace=False) * 8 + 10
+    ys = rng.choice(50, n_sources, replace=False) * 8 + 10
+    for y, x in zip(ys, xs):
+        m0.data[y-1:y+2, x-1:x+2] = [
+            [0.1, 0.6, 0.1],
+            [0.6, 0.8, 0.6],
+            [0.1, 0.6, 0.1],
+        ]
+
+    m1 = m0.copy()
+    # give each a unique filename
+    m0.meta.filename = 'some_file_0.fits'
+    m1.meta.filename = 'some_file_1.fits'
+    c = ModelContainer([m0, m1])
+    return c
+
+
+@pytest.mark.parametrize("with_shift", [True, False])
+def test_tweakreg_step(example_input, with_shift):
+    """
+    A simplified unit test for basic operation of the TweakRegStep
+    when run with or without a small shift in the input image sources.
+    """
+    if with_shift:
+        # shift 9 pixels so that the sources in one of the 2 images
+        # appear at different locations (resulting in a correct wcs update)
+        example_input[1].data[:-9] = example_input[1].data[9:]
+        example_input[1].data[-9:] = BKG_LEVEL
+
+    # assign images to different groups (so they are aligned to each other)
+    example_input[0].meta.group_id = 'a'
+    example_input[1].meta.group_id = 'b'
+
+    # make the step with default arguments
+    step = tweakreg_step.TweakRegStep()
+
+    # run the step on the example input modified above
+    result = step(example_input)
+
+    # check that step completed
+    for model in result:
+        assert model.meta.cal_step.tweakreg == 'COMPLETE'
+
+    # and that the wcses differ by a small amount due to the shift above
+    # by projecting one point through each wcs and comparing the difference
+    abs_delta = abs(result[1].meta.wcs(0, 0)[0] - result[0].meta.wcs(0, 0)[0])
+    if with_shift:
+        assert abs_delta > 1E-5
+    else:
+        assert abs_delta < 1E-12
+
+
+@pytest.fixture()
+def custom_catalog_path(tmp_path):
+    fn = tmp_path / "custom_catalog.ecsv"
+
+    # it's important that the sources here don't match
+    # those added in example_input but conform to the input
+    # shape, wcs, etc used in example_input
+    rng = np.random.default_rng(42)
+    n_sources = N_CUSTOM_SOURCES
+    xs = rng.choice(50, n_sources, replace=False) * 8 + 10
+    ys = rng.choice(50, n_sources, replace=False) * 8 + 10
+    catalog = Table(np.vstack((xs, ys)).T, names=['x', 'y'], dtype=[float, float])
+    catalog.write(fn)
+    return fn
+
+
+@pytest.mark.parametrize(
+    "catfile",
+    ["no_catfile", "valid_catfile", "invalid_catfile", "empty_catfile_row"],
+)
+@pytest.mark.parametrize(
+    "asn",
+    ["no_cat_in_asn", "cat_in_asn", "empty_asn_entry"],
+)
+@pytest.mark.parametrize(
+    "meta",
+    ["no_meta", "cat_in_meta", "empty_meta"],
+)
+@pytest.mark.parametrize("custom", [True, False])
+@pytest.mark.slow
+def test_custom_catalog(custom_catalog_path, example_input, catfile, asn, meta, custom, monkeypatch):
+    """
+    Test that TweakRegStep uses a custom catalog provided by the user
+    when the correct set of options is provided. The combinations here can be confusing
+    and this test attempts to cover all likely combinations of:
+        - a catalog in a `catfile`
+        - a catalog in the asn
+        - a catalog in the metadata
+    combined with step options:
+        - `use_custom_catalogs` (True/False)
+        - a "valid" file passed as `catfile`
+    """
+    example_input[0].meta.group_id = 'a'
+    example_input[1].meta.group_id = 'b'
+
+    # this works because if use_custom_catalogs is True but
+    # catfile is blank, tweakreg still uses custom catalogs,
+    # which in this case are defined in model.meta.tweakreg_catalog
+    if meta == "cat_in_meta":
+        example_input[0].meta.tweakreg_catalog = str(custom_catalog_path)
+    elif meta == "empty_meta":
+        example_input[0].meta.tweakreg_catalog = ""
+
+    # write out the ModelContainer and association (so the association table will be loaded)
+    example_input.save(dir_path=str(custom_catalog_path.parent))
+    asn_data = {
+        'asn_id': 'foo',
+        'asn_pool': 'bar',
+        'products': [
+            {
+                'members': [{'expname': m.meta.filename, 'exptype': 'science'} for m in example_input],
+            },
+        ],
+    }
+
+    if asn == "empty_asn_entry":
+        asn_data['products'][0]['members'][0]['tweakreg_catalog'] = ''
+    elif asn == "cat_in_asn":
+        asn_data['products'][0]['members'][0]['tweakreg_catalog'] = str(custom_catalog_path.name)
+
+    asn_path = custom_catalog_path.parent / 'example_input.json'
+    with open(asn_path, 'w') as f:
+        json.dump(asn_data, f)
+
+    # write out a catfile
+    if catfile != "no_catfile":
+        catfile_path = custom_catalog_path.parent / 'catfile.txt'
+        with open(catfile_path, 'w') as f:
+            if catfile == "valid_catfile":
+                f.write(f"{example_input[0].meta.filename} {custom_catalog_path.name}")
+            elif catfile == "empty_catfile_row":
+                f.write(f"{example_input[0].meta.filename}")
+            elif catfile == "invalid_catfile":
+                pass
+
+    # figure out how many sources to expect for the model in group 'a'
+    n_custom_sources = N_EXAMPLE_SOURCES
+    if custom:
+        if catfile == "valid_catfile":
+            # for a 'valid' catfile, expect the custom number
+            n_custom_sources = N_CUSTOM_SOURCES
+        elif catfile == "no_catfile":
+            # since catfile is not defined, now look at the asn
+            if asn == "cat_in_asn":
+                # for a 'valid' asn entry, expect the custom number
+                n_custom_sources = N_CUSTOM_SOURCES
+            elif asn == "no_cat_in_asn" and meta == "cat_in_meta":
+                n_custom_sources = N_CUSTOM_SOURCES
+
+    kwargs = {'use_custom_catalogs': custom}
+    if catfile != "no_catfile":
+        kwargs["catfile"] = str(catfile_path)
+    step = tweakreg_step.TweakRegStep(**kwargs)
+
+    # patch _construct_wcs_corrector to check the correct catalog was loaded
+    def patched_construct_wcs_corrector(model, catalog, _seen=[]):
+        if model.meta.group_id == 'a':
+            assert len(catalog) == n_custom_sources
+        elif model.meta.group_id == 'b':
+            assert len(catalog) == N_EXAMPLE_SOURCES
+        _seen.append(model)
+        # once both models have been checked we don't need to continue
+        if len(_seen) == 2:
+            raise ValueError("done testing")
+        return None
+
+    monkeypatch.setattr(tweakreg_step, "_construct_wcs_corrector", patched_construct_wcs_corrector)
 
-    assert os.path.exists(expected_outfile)
\ No newline at end of file
+    with pytest.raises(ValueError, match="done testing"):
+        step(str(asn_path))
diff --git a/jwst/tweakreg/tweakreg_catalog.py b/jwst/tweakreg/tweakreg_catalog.py
index fa6835c8b85..bdf9158c143 100644
--- a/jwst/tweakreg/tweakreg_catalog.py
+++ b/jwst/tweakreg/tweakreg_catalog.py
@@ -148,7 +148,7 @@ def _DaoStarFinderWrapper(data, threshold, mask=None, **kwargs):
     return sources
 
 
-def make_tweakreg_catalog(model, snr_threshold, bkg_boxsize=400, starfinder='dao',
-                          starfinder_kwargs={}):
+def make_tweakreg_catalog(model, snr_threshold, bkg_boxsize=400, starfinder='iraf',
+                          starfinder_kwargs={}):
     """
     Create a catalog of point-like sources to be used for image alignment in tweakreg.
diff --git a/jwst/tweakreg/tweakreg_step.py b/jwst/tweakreg/tweakreg_step.py
index 519a200ab28..287128bcc9d 100644
--- a/jwst/tweakreg/tweakreg_step.py
+++ b/jwst/tweakreg/tweakreg_step.py
@@ -6,22 +6,21 @@
 """
 from os import path
 
-from astropy.table import Table
 from astropy import units as u
 from astropy.coordinates import SkyCoord
+from astropy.table import Table
+from astropy.time import Time
 
 from tweakwcs.imalign import align_wcs
 from tweakwcs.correctors import JWSTWCSCorrector
 from tweakwcs.matchutils import XYXYMatch
 
-from stdatamodels.jwst.datamodels.util import is_association
-
 from jwst.datamodels import ModelContainer
 
 # LOCAL
 from ..stpipe import Step
-from ..assign_wcs.util import update_fits_wcsinfo, update_s_region_imaging
-from . import astrometric_utils as amutils
-from . tweakreg_catalog import make_tweakreg_catalog
+from ..assign_wcs.util import update_fits_wcsinfo, update_s_region_imaging, wcs_from_footprints
+from .astrometric_utils import create_astrometric_catalog
+from .tweakreg_catalog import make_tweakreg_catalog
 
 
 def _oxford_or_str_join(str_list):
@@ -56,8 +55,12 @@ class TweakRegStep(Step):
     use_custom_catalogs = boolean(default=False) # Use custom user-provided catalogs?
     catalog_format = string(default='ecsv') # Catalog output file format
     catfile = string(default='') # Name of the file with a list of custom user-provided catalogs
-    starfinder = option('dao', 'iraf', 'segmentation', default='dao') # Star finder to use.
+    starfinder = option('dao', 'iraf', 'segmentation', default='iraf') # Star finder to use.
+
+    # general starfinder options
     snr_threshold = float(default=10.0) # SNR threshold above the bkg for star finder
+    bkg_boxsize = integer(default=400) # The background mesh box size in pixels.
+
     # kwargs for DAOStarFinder and IRAFStarFinder, only used if starfinder is 'dao' or 'iraf'
     kernel_fwhm = float(default=2.5) # Gaussian kernel FWHM in pixels
     minsep_fwhm = float(default=0.0) # Minimum separation between detected objects in FWHM
@@ -68,6 +71,7 @@ class TweakRegStep(Step):
     roundhi = float(default=1.0) # The upper bound on roundness for object detection.
     brightest = integer(default=200) # Keep top ``brightest`` objects
     peakmax = float(default=None) # Filter out objects with pixel values >= ``peakmax``
+    # kwargs for SourceCatalog and SourceFinder, only used if starfinder is 'segmentation'
     npixels = integer(default=10) # Minimum number of connected pixels
     connectivity = option(4, 8, default=8) # The connectivity defining the neighborhood of a pixel
@@ -77,90 +81,81 @@ class TweakRegStep(Step):
     localbkg_width = integer(default=0) # Width of rectangular annulus used to compute local background around each source
     apermask_method = option('correct', 'mask', 'none', default='correct') # How to handle neighboring sources
     kron_params = float_list(min=2, max=3, default=None) # Parameters defining Kron aperture
-    # continue args for rest of step
-    bkg_boxsize = integer(default=400) # The background mesh box size in pixels.
+
+    # align wcs options
     enforce_user_order = boolean(default=False) # Align images in user specified order?
     expand_refcat = boolean(default=False) # Expand reference catalog with new sources?
     minobj = integer(default=15) # Minimum number of objects acceptable for matching
+    fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift') # Fitting geometry
+    nclip = integer(min=0, default=3) # Number of clipping iterations in fit
+    sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units
+
+    # xyxymatch options
     searchrad = float(default=2.0) # The search radius in arcsec for a match
     use2dhist = boolean(default=True) # Use 2d histogram to find initial offset?
     separation = float(default=1.0) # Minimum object separation for xyxymatch in arcsec
     tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec
     xoffset = float(default=0.0) # Initial guess for X offset in arcsec
     yoffset = float(default=0.0) # Initial guess for Y offset in arcsec
-    fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift') # Fitting geometry
-    nclip = integer(min=0, default=3) # Number of clipping iterations in fit
-    sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units
+
+    # Absolute catalog options
     abs_refcat = string(default='') # Catalog file name or one of: {_SINGLE_GROUP_REFCAT_STR}, or None, or ''
     save_abs_catalog = boolean(default=False) # Write out used absolute astrometric reference catalog as a separate product
+
+    # Absolute catalog align wcs options
     abs_minobj = integer(default=15) # Minimum number of objects acceptable for matching when performing absolute astrometry
+    abs_fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift')
+    abs_nclip = integer(min=0, default=3) # Number of clipping iterations in fit when performing absolute astrometry
+    abs_sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units when performing absolute astrometry
+
+    # absolute catalog xyxymatch options
     abs_searchrad = float(default=6.0) # The search radius in arcsec for a match when performing absolute astrometry
     # We encourage setting this parameter to True. Otherwise, xoffset and yoffset will be set to zero.
     abs_use2dhist = boolean(default=True) # Use 2D histogram to find initial offset when performing absolute astrometry?
     abs_separation = float(default=1) # Minimum object separation in arcsec when performing absolute astrometry
     abs_tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec when performing absolute astrometry
-    # Fitting geometry when performing absolute astrometry
-    abs_fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift')
-    abs_nclip = integer(min=0, default=3) # Number of clipping iterations in fit when performing absolute astrometry
-    abs_sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units when performing absolute astrometry
+
+    # stpipe general options
     output_use_model = boolean(default=True) # When saving use `DataModel.meta.filename`
     """
 
     reference_file_types = []
 
     def process(self, input):
-        use_custom_catalogs = self.use_custom_catalogs
+        images = ModelContainer(input)
 
-        if use_custom_catalogs:
-            catdict = _parse_catfile(self.catfile)
-            # if user requested the use of custom catalogs and provided a
-            # valid 'catfile' file name that has no custom catalogs,
-            # turn off the use of custom catalogs:
-            if catdict is not None and not catdict:
-                self.log.warning(
-                    "'use_custom_catalogs' is set to True but 'catfile' "
-                    "contains no user catalogs. Turning on built-in catalog "
-                    "creation."
-                )
-                use_custom_catalogs = False
-
-        try:
-            if use_custom_catalogs and catdict:
-                images = ModelContainer()
-                if isinstance(input, str):
-                    asn_dir = path.dirname(input)
-                    asn_data = images.read_asn(input)
-                    for member in asn_data['products'][0]['members']:
-                        filename = member['expname']
-                        member['expname'] = path.join(asn_dir, filename)
-                        if filename in catdict:
-                            member['tweakreg_catalog'] = catdict[filename]
-                        elif 'tweakreg_catalog' in member:
-                            del member['tweakreg_catalog']
-
-                    images.from_asn(asn_data)
-
-                elif is_association(input):
-                    images.from_asn(input)
+        if len(images) == 0:
+            raise ValueError("Input must contain at least one image model.")
 
-                else:
-                    images = ModelContainer(input)
-                    for im in images:
-                        filename = im.meta.filename
-                        if filename in catdict:
-                            self.log.info(
-                                f"setting meta.tweakreg_catalog of '{filename}' to {repr(catdict[filename])}"
-                            )
-                            im.meta.tweakreg_catalog = catdict[filename]
+        # determine number of groups (used below)
+        n_groups = len(images.group_names)
 
-            else:
-                images = ModelContainer(input)
+        use_custom_catalogs = self.use_custom_catalogs
 
-        except TypeError as e:
-            e.args = ("Input to tweakreg must be a list of DataModels, an "
-                      "association, or an already open ModelContainer "
-                      "containing one or more DataModels.", ) + e.args[1:]
-            raise e
+        if self.use_custom_catalogs:
+            # first check catfile
+            if self.catfile.strip():
+                catdict = _parse_catfile(self.catfile)
+                # if user requested the use of custom catalogs and provided a
+                # valid 'catfile' file name that has no custom catalogs,
+                # turn off the use of custom catalogs:
+                if not catdict:
+                    self.log.warning(
+                        "'use_custom_catalogs' is set to True but 'catfile' "
+                        "contains no user catalogs. Turning on built-in catalog "
+                        "creation."
+                    )
+                    use_custom_catalogs = False
+            # else, load from association
+            elif hasattr(images.meta, "asn_table") and getattr(images, "asn_file_path", None) is not None:
+                catdict = {}
+                asn_dir = path.dirname(images.asn_file_path)
+                for member in images.meta.asn_table.products[0].members:
+                    if hasattr(member, "tweakreg_catalog"):
+                        if member.tweakreg_catalog is None or not member.tweakreg_catalog.strip():
+                            catdict[member.expname] = None
+                        else:
+                            catdict[member.expname] = path.join(asn_dir, member.tweakreg_catalog)
 
         if self.abs_refcat is not None and self.abs_refcat.strip():
             align_to_abs_refcat = True
@@ -170,72 +165,61 @@ def process(self, input):
         else:
             align_to_abs_refcat = False
 
-        if len(images) == 0:
-            raise ValueError("Input must contain at least one image model.")
-
-        rel_outcomes = set()
-
-        # Build the catalogs for input images
-        for image_model in images:
-            if use_custom_catalogs and image_model.meta.tweakreg_catalog:
+            # since we're not aligning to a reference catalog, if we are
+            # not saving catalogs and have only 1 group, there is nothing
+            # to do, so skip
+            if not self.save_catalogs and n_groups == 1:
+                # we need at least two exposures to perform image alignment
+                self.log.warning("At least two exposures are required for image "
+                                 "alignment.")
+                self.log.warning("Nothing to do. Skipping 'TweakRegStep'...")
+                self.skip = True
+                for model in images:
+                    model.meta.cal_step.tweakreg = "SKIPPED"
+                return input
+
+        # === start processing images ===
+
+        # pre-allocate collectors (same length and order as images)
+        correctors = [None] * len(images)
+
+        # Build the catalog and corrector for each input image
+        for (model_index, image_model) in enumerate(images):
+            # now that the model is open, check its metadata for a custom catalog
+            # only if it's not listed in the catdict
+            if use_custom_catalogs and image_model.meta.filename not in catdict:
+                if (image_model.meta.tweakreg_catalog is not None and image_model.meta.tweakreg_catalog.strip()):
+                    catdict[image_model.meta.filename] = image_model.meta.tweakreg_catalog
+            if use_custom_catalogs and catdict.get(image_model.meta.filename, None) is not None:
+                # FIXME this modifies the input_model
+                image_model.meta.tweakreg_catalog = catdict[image_model.meta.filename]
                 # use user-supplied catalog:
                 self.log.info("Using user-provided input catalog "
                               f"'{image_model.meta.tweakreg_catalog}'")
-                catalog = Table.read(image_model.meta.tweakreg_catalog)
-                new_cat = False
-
+                catalog = Table.read(
+                    image_model.meta.tweakreg_catalog,
+                )
+                save_catalog = False
             else:
                 # source finding
-                starfinder_kwargs = {
-                    'fwhm': self.kernel_fwhm,
-                    'sigma_radius': self.sigma_radius,
-                    'minsep_fwhm': self.minsep_fwhm,
-                    'sharplo': self.sharplo,
-                    'sharphi': self.sharphi,
-                    'roundlo': self.roundlo,
-                    'roundhi': self.roundhi,
-                    'peakmax': self.peakmax,
-                    'brightest': self.brightest,
-                    'npixels': self.npixels,
-                    'connectivity': int(self.connectivity),  # option returns a string, so cast to int
-                    'nlevels': self.nlevels,
-                    'contrast': self.contrast,
-                    'mode': self.multithresh_mode,
-                    'error': image_model.err,
-                    'localbkg_width': self.localbkg_width,
-                    'apermask_method': self.apermask_method,
-                    'kron_params': self.kron_params,
-                }
-
-                catalog = make_tweakreg_catalog(
-                    image_model, self.snr_threshold,
-                    starfinder=self.starfinder,
-                    bkg_boxsize=self.bkg_boxsize,
-                    starfinder_kwargs=starfinder_kwargs,
-                )
-                new_cat = True
+                catalog = self._find_sources(image_model)
 
-            for axis in ['x', 'y']:
-                if axis not in catalog.colnames:
-                    long_axis = axis + 'centroid'
-                    if long_axis in catalog.colnames:
-                        catalog.rename_column(long_axis, axis)
-                    else:
-                        raise ValueError(
-                            "'tweakreg' source catalogs must contain either "
-                            "columns 'x' and 'y' or 'xcentroid' and "
-                            "'ycentroid'."
-                        )
+                # only save if catalog was computed from _find_sources and
+                # the user requested save_catalogs
+                save_catalog = self.save_catalogs
+
+            # if needed rename xcentroid to x, ycentroid to y
+            catalog = _rename_catalog_columns(catalog)
 
-            # filter out sources outside the WCS bounding box
-            bb = image_model.meta.wcs.bounding_box
-            if bb is not None:
-                ((xmin, xmax), (ymin, ymax)) = bb
-                x = catalog['x']
-                y = catalog['y']
-                mask = (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax)
-                catalog = catalog[mask]
+            # filter all sources outside the wcs bounding box
+            catalog = _filter_catalog_by_bounding_box(
+                catalog,
+                image_model.meta.wcs.bounding_box)
 
+            # setting 'name' is important for tweakwcs logging
+            if catalog.meta.get('name') is None:
+                catalog.meta['name'] = path.splitext(image_model.meta.filename)[0].strip('_- ')
+
+            # log results of source finding (or user catalog)
             filename = image_model.meta.filename
             nsources = len(catalog)
             if nsources == 0:
@@ -244,70 +228,24 @@ def process(self, input):
                 self.log.info('Detected {} sources in {}.'
                               .format(len(catalog), filename))
 
-            if new_cat and self.save_catalogs:
-                image_model = self._write_catalog(image_model, catalog, filename)
-
-            # Temporarily attach catalog to the image model so that it follows
-            # the grouping by exposure, to be removed after use below
-            image_model.catalog = catalog
+            # save catalog (if requested)
+            if save_catalog:
+                # FIXME this modifies the input_model
+                image_model.meta.tweakreg_catalog = self._write_catalog(catalog, filename)
 
-        # group images by their "group id":
-        grp_img = list(images.models_grouped)
+            # construct the corrector since the model is open (and already has a group_id)
+            correctors[model_index] = _construct_wcs_corrector(image_model, catalog)
 
         self.log.info('')
         self.log.info("Number of image groups to be aligned: {:d}."
-                      .format(len(grp_img)))
-        self.log.info("Image groups:")
-
-        if len(grp_img) == 1 and not align_to_abs_refcat:
-            self.log.info("* Images in GROUP 1:")
-            for im in grp_img[0]:
-                self.log.info("     {}".format(im.meta.filename))
-            self.log.info('')
-
-            # we need at least two exposures to perform image alignment
-            self.log.warning("At least two exposures are required for image "
-                             "alignment.")
-            self.log.warning("Nothing to do. Skipping 'TweakRegStep'...")
-            self.skip = True
-            for model in images:
-                model.meta.cal_step.tweakreg = "SKIPPED"
-                # Remove the attached catalogs
-                del model.catalog
-            return input
-
-        elif len(grp_img) == 1 and align_to_abs_refcat:
-            # create a list of WCS-Catalog-Images Info and/or their Groups:
-            g = grp_img[0]
-            if len(g) == 0:
-                raise AssertionError("Logical error in the pipeline code.")
-            imcats = list(map(self._imodel2wcsim, g))
-            # Remove the attached catalogs
-            for model in g:
-                del model.catalog
-            self.log.info(f"* Images in GROUP '{imcats[0].meta['group_id']}':")
-            for im in imcats:
-                self.log.info(f"     {im.meta['name']}")
-
-            self.log.info('')
-
-        elif len(grp_img) > 1:
-            # create a list of WCS-Catalog-Images Info and/or their Groups:
-            imcats = []
-            for g in grp_img:
-                if len(g) == 0:
-                    raise AssertionError("Logical error in the pipeline code.")
-                else:
-                    wcsimlist = list(map(self._imodel2wcsim, g))
-                    # Remove the attached catalogs
-                    for model in g:
-                        del model.catalog
-                    self.log.info(f"* Images in GROUP '{wcsimlist[0].meta['group_id']}':")
-                    for im in wcsimlist:
-                        self.log.info(f"     {im.meta['name']}")
-                    imcats.extend(wcsimlist)
+                      .format(n_groups))
 
-            self.log.info('')
+        # keep track of whether 'local' alignment failed; even if it
+        # fails, absolute alignment might still be run (if so configured)
+        local_align_failed = False
+
+        # if we have >1 group of images, align them to each other
+        if n_groups > 1:
 
             # align images:
             xyxymatch = XYXYMatch(
@@ -321,7 +259,7 @@ def process(self, input):
 
             try:
                 align_wcs(
-                    imcats,
+                    correctors,
                     refcat=None,
                     enforce_user_order=self.enforce_user_order,
                     expand_refcat=self.expand_refcat,
@@ -346,6 +284,7 @@ def process(self, input):
                     if not align_to_abs_refcat:
                         self.skip = True
                         return images
+                    local_align_failed = True
                 else:
                     raise e
@@ -365,28 +304,20 @@ def process(self, input):
                 else:
                     raise e
 
-        for imcat in imcats:
-            model = imcat.meta['image_model']
-            rel_outcomes.add(model.meta.cal_step.tweakreg)
-            if model.meta.cal_step.tweakreg == "SKIPPED":
-                continue
-            wcs = model.meta.wcs
-            twcs = imcat.wcs
-            if not self._is_wcs_correction_small(wcs, twcs):
-                # Large corrections are typically a result of source
-                # mis-matching or poorly-conditioned fit. Skip such models.
- self.log.warning(f"WCS has been tweaked by more than {10 * self.tolerance} arcsec") - + if not local_align_failed and not self._is_wcs_correction_small(correctors): + if align_to_abs_refcat: + self.log.warning("Skipping relative alignment (stage 1)...") + else: + self.log.warning("Skipping 'TweakRegStep'...") + self.skip = True for model in images: model.meta.cal_step.tweakreg = "SKIPPED" - if align_to_abs_refcat: - self.log.warning("Skipping relative alignment (stage 1)...") - else: - self.log.warning("Skipping 'TweakRegStep'...") - self.skip = True - return images + return images if align_to_abs_refcat: + # now, align things to the reference catalog + # this can occur after alignment between groups (only if >1 group) + # Get catalog of GAIA sources for the field # # NOTE: If desired, the pipeline can write out the reference @@ -402,11 +333,6 @@ def process(self, input): else: output_name = None - rel_ok = ( - len(rel_outcomes) > 1 or - (rel_outcomes and rel_outcomes.pop() != "SKIPPED") - ) - # initial shift to be used with absolute astrometry self.abs_xoffset = 0 self.abs_yoffset = 0 @@ -415,10 +341,24 @@ def process(self, input): gaia_cat_name = self.abs_refcat.upper() if gaia_cat_name in SINGLE_GROUP_REFCAT: - ref_cat = amutils.create_astrometric_catalog( - images, + ref_model = images[0] + + epoch = Time(ref_model.meta.observation.date).decimalyear + + # combine all aligned wcs to compute a new footprint to + # filter the absolute catalog sources + combined_wcs = wcs_from_footprints( + None, + refmodel=ref_model, + wcslist=[corrector.wcs for corrector in correctors], + ) + + ref_cat = create_astrometric_catalog( + None, gaia_cat_name, - output=output_name + existing_wcs=combined_wcs, + output=output_name, + epoch=epoch, ) elif path.isfile(self.abs_refcat): @@ -457,16 +397,16 @@ def process(self, input): # easy to recognize when alignment to GAIA was being performed # as opposed to the group_id values used for relative alignment # earlier in this step. - for imcat in imcats: - imcat.meta['group_id'] = 987654 - if ('fit_info' in imcat.meta and - 'REFERENCE' in imcat.meta['fit_info']['status']): - del imcat.meta['fit_info'] + for corrector in correctors: + corrector.meta['group_id'] = 987654 + if ('fit_info' in corrector.meta and + 'REFERENCE' in corrector.meta['fit_info']['status']): + del corrector.meta['fit_info'] # Perform fit try: align_wcs( - imcats, + correctors, refcat=ref_cat, enforce_user_order=True, expand_refcat=False, @@ -487,7 +427,7 @@ def process(self, input): "to an absolute reference catalog. Alignment to an " "absolute reference catalog will not be performed." ) - if not rel_ok: + if local_align_failed or n_groups == 1: self.log.warning("Nothing to do. Skipping 'TweakRegStep'...") for model in images: model.meta.cal_step.tweakreg = "SKIPPED" @@ -508,7 +448,7 @@ def process(self, input): "Alignment to an absolute reference catalog will " "not be performed." 
                    )
-                    if not rel_ok:
+                    if local_align_failed or n_groups == 1:
                         self.log.warning("Skipping 'TweakRegStep'...")
                         self.skip = True
                         for model in images:
@@ -517,13 +457,14 @@ def process(self, input):
                 else:
                     raise e
 
-        for imcat in imcats:
-            image_model = imcat.meta['image_model']
+        # one final pass through all the models to update them based
+        # on the results of this step
+        for (image_model, corrector) in zip(images, correctors):
             image_model.meta.cal_step.tweakreg = 'COMPLETE'
 
             # retrieve fit status and update wcs if fit is successful:
-            if ('fit_info' in imcat.meta and
-                    'SUCCESS' in imcat.meta['fit_info']['status']):
+            if ('fit_info' in corrector.meta and
+                    'SUCCESS' in corrector.meta['fit_info']['status']):
 
@@ -536,9 +477,9 @@ def process(self, input):
                 # translated to the FITS WCSNAME keyword
                 # IF that is what gets recorded in the archive
                 # for end-user searches.
-                imcat.wcs.name = "FIT-LVL3-{}".format(self.abs_refcat)
+                corrector.wcs.name = "FIT-LVL3-{}".format(self.abs_refcat)
 
-                image_model.meta.wcs = imcat.wcs
+                image_model.meta.wcs = corrector.wcs
                 update_s_region_imaging(image_model)
 
                 # Also update FITS representation in input exposures for
@@ -557,15 +498,13 @@ def process(self, input):
 
         return images
 
-    def _write_catalog(self, image_model, catalog, filename):
+    def _write_catalog(self, catalog, filename):
         '''
         Determine output filename for catalog based on outfile for step
         and output dir, then write catalog to file.
 
         Parameters
         ----------
-        image_model : jwst.datamodels.ImageModel
-            Image model containing the source catalog.
         catalog : astropy.table.Table
             Table containing the source catalog.
         filename : str
@@ -573,8 +512,8 @@ def process(self, input):
 
         Returns
        -------
-        image_model : jwst.datamodels.ImageModel
-            Image model with updated catalog information.
+        catalog_filename : str
+            Filename where the catalog was saved
         '''
 
         catalog_filename = str(filename).replace(
@@ -600,67 +539,75 @@ def process(self, input):
             )
             self.log.info('Wrote source catalog: {}'
                           .format(catalog_filename))
-            image_model.meta.tweakreg_catalog = catalog_filename
-
-        return image_model
+        return catalog_filename
+
+    def _find_sources(self, image_model):
+        # source finding
+        starfinder_kwargs = {
+            'fwhm': self.kernel_fwhm,
+            'sigma_radius': self.sigma_radius,
+            'minsep_fwhm': self.minsep_fwhm,
+            'sharplo': self.sharplo,
+            'sharphi': self.sharphi,
+            'roundlo': self.roundlo,
+            'roundhi': self.roundhi,
+            'peakmax': self.peakmax,
+            'brightest': self.brightest,
+            'npixels': self.npixels,
+            'connectivity': int(self.connectivity),  # option returns a string, so cast to int
+            'nlevels': self.nlevels,
+            'contrast': self.contrast,
+            'mode': self.multithresh_mode,
+            'error': image_model.err,
+            'localbkg_width': self.localbkg_width,
+            'apermask_method': self.apermask_method,
+            'kron_params': self.kron_params,
+        }
+
+        return make_tweakreg_catalog(
+            image_model, self.snr_threshold,
+            starfinder=self.starfinder,
+            bkg_boxsize=self.bkg_boxsize,
+            starfinder_kwargs=starfinder_kwargs,
+        )
 
-    def _is_wcs_correction_small(self, wcs, twcs):
-        """Check that the newly tweaked wcs hasn't gone off the rails"""
+    def _is_wcs_correction_small(self, correctors):
+        # check that the newly tweaked wcs hasn't gone off the rails
         if self.use2dhist:
             max_corr = 2 * (self.searchrad + self.tolerance) * u.arcsec
         else:
             max_corr = 2 * (max(abs(self.xoffset), abs(self.yoffset)) + self.tolerance) * u.arcsec
-
-        ra, dec = wcs.footprint(axis_type="spatial").T
-        tra, tdec = twcs.footprint(axis_type="spatial").T
-        skycoord = SkyCoord(ra=ra, dec=dec, unit="deg")
-        tskycoord = SkyCoord(ra=tra, dec=tdec, unit="deg")
-
-        separation = skycoord.separation(tskycoord)
-
-        return (separation < max_corr).all()
-
-    def _imodel2wcsim(self, image_model):
-        # make sure that we have a catalog:
-        if hasattr(image_model, 'catalog'):
-            catalog = image_model.catalog
-        else:
-            catalog = image_model.meta.tweakreg_catalog
-
-        model_name = path.splitext(image_model.meta.filename)[0].strip('_- ')
-
-        if isinstance(catalog, Table):
-            if not catalog.meta.get('name', None):
-                catalog.meta['name'] = model_name
-
-        else:
-            try:
-                cat_name = str(catalog)
-                catalog = Table.read(catalog, format='ascii.ecsv')
-                catalog.meta['name'] = cat_name
-            except IOError:
-                self.log.error("Cannot read catalog {}".format(catalog))
-
-        # create WCSImageCatalog object:
-        refang = image_model.meta.wcsinfo.instance
-        im = JWSTWCSCorrector(
-            wcs=image_model.meta.wcs,
-            wcsinfo={'roll_ref': refang['roll_ref'],
-                     'v2_ref': refang['v2_ref'],
-                     'v3_ref': refang['v3_ref']},
-            meta={
-                'image_model': image_model,
-                'catalog': catalog,
-                'name': model_name,
-                'group_id': image_model.meta.group_id,
-            }
-        )
-
-        return im
+        for corrector in correctors:
+            aligned_skycoord = _wcs_to_skycoord(corrector.wcs)
+            original_skycoord = corrector.meta['original_skycoord']
+            separation = original_skycoord.separation(aligned_skycoord)
+            if not (separation < max_corr).all():
+                # Large corrections are typically a result of source
+                # mis-matching or poorly-conditioned fit. Skip such models.
+ self.log.warning(f"WCS has been tweaked by more than {10 * self.tolerance} arcsec") + return False + return True def _parse_catfile(catfile): + """ + Parse a text file containing at 2 whitespace-delimited columns + column 1: str, datamodel filename + column 2: str, catalog filename + into a dictionary with datamodel filename keys and catalog filename + values. The catalog filenames will become paths relative + to the current working directory. So for a catalog filename + "mycat.ecsv" if the catfile is in a subdirectory "my_data" + the catalog filename will be "my_data/mycat.ecsv". + + Returns: + - None of catfile is None (or an empty string) + - empty dict if catfile is empty + + Raises: + VaueError if catfile contains >2 columns + """ if catfile is None or not catfile.strip(): return None @@ -679,8 +626,61 @@ def _parse_catfile(catfile): if len(catalog) == 1: catdict[data_model] = path.join(catfile_dir, catalog[0]) elif len(catalog) == 0: + # set this to None so it's custom catalog is skipped catdict[data_model] = None else: raise ValueError("'catfile' can contain at most two columns.") return catdict + + +def _rename_catalog_columns(catalog): + for axis in ['x', 'y']: + if axis not in catalog.colnames: + long_axis = axis + 'centroid' + if long_axis in catalog.colnames: + catalog.rename_column(long_axis, axis) + else: + raise ValueError( + "'tweakreg' source catalogs must contain either " + "columns 'x' and 'y' or 'xcentroid' and " + "'ycentroid'." + ) + return catalog + + +def _filter_catalog_by_bounding_box(catalog, bounding_box): + if bounding_box is None: + return catalog + + # filter out sources outside the WCS bounding box + ((xmin, xmax), (ymin, ymax)) = bounding_box + x = catalog['x'] + y = catalog['y'] + mask = (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax) + return catalog[mask] + + +def _wcs_to_skycoord(wcs): + ra, dec = wcs.footprint(axis_type="spatial").T + return SkyCoord(ra=ra, dec=dec, unit="deg") + + +def _construct_wcs_corrector(image_model, catalog): + # pre-compute skycoord here so we can later use it + # to check for a small wcs correction + wcs = image_model.meta.wcs + refang = image_model.meta.wcsinfo.instance + return JWSTWCSCorrector( + wcs=image_model.meta.wcs, + wcsinfo={'roll_ref': refang['roll_ref'], + 'v2_ref': refang['v2_ref'], + 'v3_ref': refang['v3_ref']}, + # catalog and group_id are required meta + meta={ + 'catalog': catalog, + 'name': catalog.meta.get('name'), + 'group_id': image_model.meta.group_id, + 'original_skycoord': _wcs_to_skycoord(wcs), + } + )