Skip to content

Commit

Permalink
JP-3499: Handle too much masked data for NIRISS SOSS extractions in e…
Browse files Browse the repository at this point in the history
…xtract_1d (#8265)

Co-authored-by: James Davies <jdavies@mpia.de>
Co-authored-by: Maria <penaguerrero@users.noreply.github.com>
Co-authored-by: David Law <dlaw@stsci.edu>
Co-authored-by: Howard Bushouse <bushouse@stsci.edu>
  • Loading branch information
5 people committed Feb 15, 2024
1 parent 8f723b7 commit f9ee282
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ extract_1d
- Added a hook to bypass the ``extract_1d`` step for NIRISS SOSS data in
the FULL subarray with warning. [#8225]

- Added a trap in the NIRISS SOSS ATOCA algorithm for cases where nearly all
pixels in the 2nd-order spectrum are flagged and would cause the step
to fail. [#8265]

extract_2d
----------

Expand Down
18 changes: 15 additions & 3 deletions jwst/extract_1d/soss_extract/atoca.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
log.setLevel(logging.DEBUG)


class MaskOverlapError(Exception):

def __init__(self, message):
self.message = message
super().__init__(self.message)


class _BaseOverlap:
"""Base class for the ATOCA algorithm (Darveau-Bernier 2021, in prep).
Used to perform an overlapping extraction of the form:
Expand Down Expand Up @@ -181,6 +188,12 @@ def __init__(self, wave_map, trace_profile, throughput, kernels,
# First estimate of a global mask and masks for each orders
self.mask, self.mask_ord = self._get_masks(global_mask)

# Ensure there are adequate good pixels left in each order
good_pixels_in_order = np.sum(np.sum(~self.mask_ord, axis=-1), axis=-1)
min_good_pixels = 25 # hard-code to qualitatively reasonable value
if np.any(good_pixels_in_order < min_good_pixels):
raise MaskOverlapError('At least one order has no valid pixels (mask_trace_profile and mask_wave do not overlap)')

# Correct i_bounds if it was not specified
self.i_bounds = self._get_i_bnds(wave_bounds)

Expand Down Expand Up @@ -408,9 +421,9 @@ def _get_masks(self, global_mask):
"""

# Get needed attributes
args = ('threshold', 'n_orders', 'throughput', 'mask_trace_profile', 'wave_map', 'trace_profile')
args = ('threshold', 'n_orders', 'mask_trace_profile', 'trace_profile')
needed_attr = self.get_attributes(*args)
threshold, n_orders, throughput, mask_trace_profile, wave_map, trace_profile = needed_attr
threshold, n_orders, mask_trace_profile, trace_profile = needed_attr

# Convert list to array (easier for coding)
mask_trace_profile = np.array(mask_trace_profile)
Expand Down Expand Up @@ -1647,7 +1660,6 @@ def get_w(self, i_order):

# Use the convolved grid (depends on the order)
wave_grid = wave_grid[i_bnds[0]:i_bnds[1]]

# Compute the wavelength coverage of the grid
d_grid = np.diff(wave_grid)

Expand Down
47 changes: 42 additions & 5 deletions jwst/extract_1d/soss_extract/soss_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .soss_syscor import make_background_mask, soss_background
from .soss_solver import solve_transform, transform_wavemap, transform_profile, transform_coords
from .atoca import ExtractionEngine
from .atoca import ExtractionEngine, MaskOverlapError
from .atoca_utils import (ThroughputSOSS, WebbKernel, grid_from_map, mask_bad_dispersion_direction,
make_combined_adaptive_grid, get_wave_p_or_m, oversample_grid)
from .soss_boxextract import get_box_weights, box_extract, estim_error_nearest_data
Expand Down Expand Up @@ -463,6 +463,35 @@ def _build_tracemodel_order(engine, ref_file_args, f_k, i_order, mask, ref_files
return tracemodel_ord, spec_ord


def _build_null_spec_table(wave_grid):
"""
Build a SpecModel of entirely bad values
Parameters
----------
wave_grid : np.array
Input wavelengths
Returns
-------
spec : SpecModel
Null SpecModel. Flux values are NaN, DQ flags are 1,
but note that DQ gets overwritten at end of run_extract1d
"""
wave_grid_cut = wave_grid[wave_grid > 0.58] # same cutoff applied for valid data
spec = datamodels.SpecModel()
spec.spectral_order = 2
spec.meta.soss_extract1d.type = 'OBSERVATION'
spec.meta.soss_extract1d.factor = np.nan
spec.spec_table = np.zeros((wave_grid_cut.size,), dtype=datamodels.SpecModel().spec_table.dtype)
spec.spec_table['WAVELENGTH'] = wave_grid_cut
spec.spec_table['FLUX'] = np.empty(wave_grid_cut.size) * np.nan
spec.spec_table['DQ'] = np.ones(wave_grid_cut.size)
spec.validate()

return spec


def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, subarray, transform=None,
tikfac=None, threshold=1e-4, n_os=2, wave_grid=None,
estimate=None, rtol=1e-3, max_grid_size=1000000):
Expand Down Expand Up @@ -689,10 +718,16 @@ def model_image(scidata_bkg, scierr, scimask, refmask, ref_files, box_weights, s
tikfac_log_range = np.log10(tikfac) + np.array([-2, 8])

# Model the remaining part of order 2 with atoca
model, spec_ord = model_single_order(scidata_bkg, scierr, ref_file_order,
mask_fit, global_mask, order,
pixel_wave_grid, valid_cols, save_tiktests,
tikfac_log_range=tikfac_log_range)
try:
model, spec_ord = model_single_order(scidata_bkg, scierr, ref_file_order,
mask_fit, global_mask, order,
pixel_wave_grid, valid_cols, save_tiktests,
tikfac_log_range=tikfac_log_range)

except MaskOverlapError:
log.error('Not enough unmasked pixels to model the remaining part of order 2. Model and spectrum will be NaN in that spectral region.')
spec_ord = [_build_null_spec_table(pixel_wave_grid)]
model = np.nan * np.ones_like(scidata_bkg)

# Keep only pixels from which order 2 contribution
# is not already modeled.
Expand Down Expand Up @@ -1032,6 +1067,8 @@ def run_extract1d(input_model, spectrace_ref_name, wavemap_ref_name,
transform = soss_kwargs.pop('transform')
if transform is None:
transform = [None, None, None]
else:
transform = [float(val) for val in transform]
# Save names for logging
param_name = np.array(['theta', 'x-offset', 'y-offset'])

Expand Down
31 changes: 31 additions & 0 deletions jwst/regtest/test_niriss_soss.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,34 @@ def test_niriss_soss_extras(rtdata_module, run_atoca_extras, fitsdiff_default_kw

diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs)
assert diff.identical, diff.report()


@pytest.mark.bigdata
@pytest.fixture(scope='module')
def run_extract1d_null_order2(jail, rtdata_module):
"""
Test coverage for fix to error thrown when all of the pixels
in order 2 are flagged as bad. Ensure graceful handling of the
MaskOverlapError exception raise.
Pin tikfac and transform for faster runtime
"""
rtdata = rtdata_module
rtdata.get_data("niriss/soss/jw01201008001_04101_00001-seg003_nis_int72.fits")
args = ["extract_1d", rtdata.input,
"--soss_tikfac=4.290665733550672e-17",
"--soss_transform=0.0794900761418923, -1.3197790951056494, -0.796875809148081",
]
Step.from_cmdline(args)


@pytest.mark.bigdata
def test_extract1d_null_order2(rtdata_module, run_extract1d_null_order2, fitsdiff_default_kwargs):
rtdata = rtdata_module

output = "jw01201008001_04101_00001-seg003_nis_int72_extract1dstep.fits"
rtdata.output = output

rtdata.get_truth(f"truth/test_niriss_soss_stages/{output}")

diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs)
assert diff.identical, diff.report()

0 comments on commit f9ee282

Please sign in to comment.