Skip to content

Commit

Permalink
hotfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Oct 10, 2022
1 parent e236d11 commit cc644cf
Show file tree
Hide file tree
Showing 9 changed files with 726 additions and 521 deletions.
264 changes: 189 additions & 75 deletions python/astra/contrib/aspcap/base.py

Large diffs are not rendered by default.

257 changes: 120 additions & 137 deletions python/astra/contrib/aspcap/continuum.py
@@ -1,3 +1,4 @@
from re import A
import numpy as np
from scipy.ndimage.filters import median_filter

Expand All @@ -7,164 +8,143 @@
from astra.contrib.ferre import bitmask
from astra.contrib.ferre.utils import read_ferre_headers
from astra.database.astradb import Task
from astra.tools.continuum.base import NormalizationBase
from astra.tools.continuum.base import Continuum
from astra.tools.spectrum import Spectrum1D
from astropy.io import fits

from typing import Optional, List, Tuple, Union
from astra.tools.spectrum import SpectralAxis

class MedianNormalizationWithErrorInflation(NormalizationBase):
class MedianFilter(Continuum):

"""
Continuum-normalize an input spectrum by the median flux value,
and inflate the errors due to skylines and bad pixels.
"""

parameter_names = ()
"""Use a median filter to represent the stellar continuum."""

def __init__(
self,
spectrum,
axis=1,
ivar_multiplier_for_sig_skyline=1e-4,
ivar_min=0,
ivar_max=40_000,
bad_pixel_flux=1e-4,
bad_pixel_ivar=1e-20,
upstream_task_id: int,
median_filter_width: Optional[int] = 151,
bad_minimum_flux: Optional[float] = 0.01,
non_finite_err_value: Optional[float] = 1e10,
valid_continuum_correction_range: Optional[Tuple[float]] = (0.1, 10.0),
mode: Optional[str] = "constant",
spectral_axis: Optional[SpectralAxis] = None,
regions: Optional[List[Tuple[float, float]]] = None,
mask: Optional[Union[str, np.array]] = None,
fill_value: Optional[Union[int, float]] = np.nan,
**kwargs
) -> None:
super().__init__(spectrum)
self.axis = axis
self.ivar_multiplier_for_sig_skyline = ivar_multiplier_for_sig_skyline
self.ivar_min = ivar_min
self.ivar_max = ivar_max
self.bad_pixel_flux = bad_pixel_flux
self.bad_pixel_ivar = bad_pixel_ivar
(
"""
:param median_filter_width: [optional]
The width (int) for the median filter (default: 151).
def __call__(self):
:param bad_minimum_flux: [optional]
The value at which to set pixels as bad and median filter over them. This should be a float,
or `None` to set no low-flux filtering (default: 0.01).
pixel_bit_mask = bitmask.PixelBitMask()
:param non_finite_err_value: [optional]
The error value to set for pixels with non-finite fluxes (default: 1e10).
# Normalize.
continuum = np.nanmedian(self.spectrum.flux.value, axis=self.axis).reshape(
(-1, 1)
:param valid_continuum_correction_range: [optional]
A (min, max) tuple of the bounds that the final correction can have. Values outside this range will be set
as 1.
"""
+ Continuum.__init__.__doc__
)
self.spectrum._data /= continuum
self.spectrum._uncertainty.array *= continuum**2

# Increase the error around significant skylines.
skyline_mask = (
self.spectrum.meta["bitmask"] & pixel_bit_mask.get_value("SIG_SKYLINE")
) > 0
self.spectrum._uncertainty.array[
skyline_mask
] *= self.ivar_multiplier_for_sig_skyline

# Set bad pixels to have no useful data.
bad = (
~np.isfinite(self.spectrum.flux.value)
| ~np.isfinite(self.spectrum.uncertainty.array)
| (self.spectrum.flux.value < 0)
| (self.spectrum.uncertainty.array < 0)
| ((self.spectrum.meta["bitmask"] & pixel_bit_mask.get_level_value(1)) > 0)
super(MedianFilter, self).__init__(
spectral_axis=spectral_axis,
regions=regions,
mask=mask,
fill_value=fill_value,
**kwargs,
)

self.spectrum._data[bad] = self.bad_pixel_flux
self.spectrum._uncertainty.array[bad] = self.bad_pixel_ivar

# Ensure a minimum error.
# TODO: This seems like a pretty bad idea!
self.spectrum._uncertainty.array = np.clip(
self.spectrum._uncertainty.array, self.ivar_min, self.ivar_max
) # sigma = 5e-3

return self.spectrum


class MedianFilterNormalizationWithErrorInflation(
MedianNormalizationWithErrorInflation
):

parameter_names = ()

def __init__(
self,
spectrum,
median_filter_from_task,
segment_indices=None,
median_filter_width=151,
bad_minimum_flux=0.01,
non_finite_err_value=1e10,
valid_continuum_correction_range=(0.1, 10.0),
**kwargs,
) -> None:
super().__init__(spectrum, **kwargs)
self.median_filter_from_task = median_filter_from_task
self.segment_indices = segment_indices
self.upstream_task_id = upstream_task_id
self.median_filter_width = median_filter_width
self.bad_minimum_flux = bad_minimum_flux
self.non_finite_err_value = non_finite_err_value
self.valid_continuum_correction_range = valid_continuum_correction_range
self.mode = mode
return None

def __call__(self):

# Do standard median normalization.
spectrum = super().__call__()
def _initialize(self, spectrum, task=None):
try:
self._initialized_args
except AttributeError:
if self.regions is None:
# Get the regions from the model wavelength segments.
if task is None:
task = Task.get(self.upstream_task_id)

regions = []
for header in read_ferre_headers(expand_path(task.parameters["header_path"]))[1:]:
crval, cdelt = header["WAVE"]
npix = header["NPIX"]
regions.append((10**crval, 10**(crval + cdelt * npix)))
self.regions = regions

self._initialized_args = super(MedianFilter, self)._initialize(spectrum)
finally:
return self._initialized_args


def fit(self, spectrum: Spectrum1D, hdu=3):

task = Task.get(self.upstream_task_id)
region_slices, region_masks = self._initialize(spectrum, task)

#flux/continuum and model_flux
# before:
# (ferre_flux / continuum) and (model_flux)
# now:
# ferre_flux and (model_flux / continuum)
# ratio = continuum * (ferre_flux / model_flux)

# This is an astraStar-FERRE product, but let's just use fits.open
with fits.open(task.output_data_products[0].path) as image:
# How do we decide on the HDU for this product just from the spectrum?
continuum = image[hdu].data["CONTINUUM"]
flux = image[hdu].data["FERRE_FLUX"]
rectified_model_flux = image[hdu].data["MODEL_FLUX"] / continuum

N, P = flux.shape
self._continuum = np.nan * np.ones((N, P))
for i in range(N):
for region_mask in region_masks:
flux_region, model_flux_region = (flux[i, region_mask].copy(), rectified_model_flux[i, region_mask].copy())

# TODO: It's a little counter-intuitive how this is documented, so we should fix that.
# Or allow for a MAGIC number 5 instead.
median = median_filter(flux[i, region_mask], [self.median_filter_width], mode=self.mode)

bad = np.where(
(flux_region < self.bad_minimum_flux) | (flux_region > (np.nanmedian(flux_region) + 3 * np.nanstd(flux_region)))
)[0]
flux_region[bad] = median[bad]

ratio_region = flux_region / model_flux_region
self._continuum[i, region_mask] = median_filter(
ratio_region,
[self.median_filter_width],
mode=self.mode,
cval=1.0
)

if not isinstance(self.median_filter_from_task, Task):
median_filter_from_task = Task.get_by_id(int(self.median_filter_from_task))
else:
median_filter_from_task = self.median_filter_from_task
scalars = np.nanmedian(spectrum.flux.value / self._continuum, axis=1)
self._continuum *= scalars
return None


def __call__(
self,
spectrum: Spectrum1D,
theta: Optional[Union[List, np.array, Tuple]] = None,
**kwargs
) -> np.ndarray:
if theta is not None:
log.warning(f"Continuum coefficients ignored here")
return self._continuum

# Need number of pixels from header
n_pixels = np.array(
[
header["NPIX"]
for header in read_ferre_headers(
expand_path(median_filter_from_task.parameters["header_path"])
)
][1:]
)
_ = n_pixels.cumsum()
segment_indices = np.sort(np.hstack([[0], _, _]))[:-1].reshape((-1, 2))

continuum = []
for output_data_product in median_filter_from_task.output_data_products:
with open(output_data_product.path, "rb") as fp:
output = pickle.load(fp)

kwds = dict(
wavelength=output["data"]["wavelength"],
normalised_observed_flux=output["data"]["flux"]
/ output["data"]["continuum"],
normalised_observed_flux_err=output["data"]["flux_sigma"]
/ output["data"]["continuum"],
normalised_model_flux=output["data"]["model_flux"],
segment_indices=segment_indices,
median_filter_width=self.median_filter_width,
bad_minimum_flux=self.bad_minimum_flux,
non_finite_err_value=self.non_finite_err_value,
valid_continuum_correction_range=self.valid_continuum_correction_range,
)
continuum.append(median_filtered_correction(**kwds))

continuum = np.array(continuum)

# Construct mask to match FERRE model grid.
N, P = spectrum.flux.shape
mask = np.zeros(P, dtype=bool)
for si, ei in segment_indices:
# TODO: Building wavelength mask off just the last wavelength array.
# We are assuming all have the same wavelength array.
s_index, e_index = spectrum.wavelength.value.searchsorted(
output["data"]["wavelength"][si:ei][[0, -1]]
)
mask[s_index : e_index + 1] = True

continuum_unmasked = np.nan * np.ones((N, P))
continuum_unmasked[:, mask] = np.array(continuum)

spectrum._data /= continuum_unmasked
spectrum._uncertainty.array *= continuum_unmasked * continuum_unmasked

return spectrum


def median_filtered_correction(
Expand Down Expand Up @@ -302,3 +282,6 @@ def median_filtered_correction(
continuum[bad] = 1

return continuum


MedianFilterNormalizationWithErrorInflation = MedianFilter

0 comments on commit cc644cf

Please sign in to comment.