From cc644cf4e0916dd2a15d6060408da2189cf32b47 Mon Sep 17 00:00:00 2001
From: Andy Casey
Date: Mon, 10 Oct 2022 05:09:20 -0600
Subject: [PATCH] hotfixes

---
 python/astra/contrib/aspcap/base.py       | 264 +++++++---
 python/astra/contrib/aspcap/continuum.py  | 257 +++++-----
 python/astra/contrib/ferre/base.py        | 579 +++++++++++-----------
 python/astra/contrib/ferre/bitmask.py     |  49 +-
 python/astra/contrib/slam/base.py         |  45 +-
 python/astra/database/astradb.py          |  25 +-
 python/astra/operators/slurm.py           |  10 +-
 python/astra/sdss/datamodels/pipeline.py  |   8 +-
 python/astra/tools/continuum/chebyshev.py |  10 +-
 9 files changed, 726 insertions(+), 521 deletions(-)

diff --git a/python/astra/contrib/aspcap/base.py b/python/astra/contrib/aspcap/base.py
index 7ce67e2..f054a68 100644
--- a/python/astra/contrib/aspcap/base.py
+++ b/python/astra/contrib/aspcap/base.py
@@ -1,27 +1,33 @@
+
 import os
-from re import findall
-import numpy as np
 import json
+import numpy as np
 from astra.contrib.aspcap import continuum, utils
-from astra.tools.continuum.base import NormalizationBase
+from astra.tools.continuum.base import Continuum
+from astra.tools.continuum.scalar import Scalar
 from typing import Union, List, Tuple, Optional, Callable
-from tqdm import tqdm
 from astra import log, __version__
-from astra.base import ExecutableTask
-from astra.utils import deserialize, expand_path, serialize_executable
+from astra.base import TaskInstance, TupleParameter
+from astra.utils import bundler, deserialize, expand_path, serialize_executable
 from astra.database.astradb import (
+    database,
     FerreOutput,
+    AspcapOutput,
+    Output,
+    TaskOutput,
     Source,
     Task,
+    Bundle,
+    TaskBundle,
     DataProduct,
     TaskInputDataProducts,
 )
-from astra.database.apogee_drpdb import Star
 from astra.contrib.ferre.base import Ferre
 from astra.contrib.ferre.utils import read_ferre_headers, sanitise
 from astra.contrib.ferre.bitmask import ParamBitMask
+

 FERRE_TASK_NAME = serialize_executable(Ferre)
 FERRE_DEFAULTS = Ferre.get_defaults()

@@ -29,7 +35,7 @@
 def initial_guess_doppler(
     data_product: DataProduct,
     source: Optional[Source] = None,
-    star: Optional[Star] = None,
+    star = None,
 ) -> dict:
     """
     Return an initial guess for FERRE from Doppler given a data product.
@@ -44,6 +50,7 @@
         The associated Star in the APOGEE DRP database for this data product.
     """
     if star is None:
+        from astra.database.apogee_drpdb import Star

         if source is None:
             (source,) = data_product.sources
@@ -55,9 +62,11 @@
             .order_by(Star.created.desc())
             .first()
         )
+        if star is None:
+            return None

     return dict(
-        telescope=data_product.kwargs["telescope"],
+        telescope=star.telescope,
         mean_fiber=int(star.meanfib),
         teff=np.round(star.rv_teff, 0),
         logg=np.round(star.rv_logg, 3),
@@ -69,29 +78,98 @@
         o_mg_si_s_ca_ti=0,
     )

+def initial_guess_apogeenet(data_product: DataProduct, star) -> dict:
+    """
+    Return an initial guess for FERRE from APOGEENet given a data product.
+
+    :param data_product:
+        The data product to be analyzed with FERRE.
+ """ + + from astra.database.astradb import ApogeeNetOutput, TaskInputDataProducts, Task + + q = ( + ApogeeNetOutput + .select() + .join(Task, on=(ApogeeNetOutput.task_id == Task.id)) + .join(TaskInputDataProducts) + .where(TaskInputDataProducts.data_product_id == data_product.id) + .where(Task.name == "astra.contrib.apogeenet.StellarParameters") + .order_by(ApogeeNetOutput.snr.desc()) + ) + + output = q.first() + if output is None: + return None + + teff, logg, metals = (output.teff, output.logg, output.fe_h) + + if star is None: + print(f"WARNING: TODO: Fix this andy") + return None + + return dict( + telescope=star.telescope, + mean_fiber=int(star.meanfib), + teff=np.round(teff, 0), + logg=np.round(logg, 3), + metals=np.round(metals, 3), + log10vdop=utils.approximate_log10_microturbulence(logg), + lgvsini=1.0, + c=0, + n=0, + o_mg_si_s_ca_ti=0, + ) + def initial_guesses(data_product: DataProduct) -> List[dict]: - """Return initial guesses for FERRE given a data product.""" - return [initial_guess_doppler(data_product)] + """ + Return initial guesses for FERRE given a data product. + + :param data_product: + The data product containing 1D spectra for a source. + """ + # TODO: get defaults from Star (telescope, mean_fiber, etc) in a not-so-clumsy way + + from astra.database.apogee_drpdb import Star + (source,) = data_product.sources + + # Be sure to get the record of the Star with the latest and greatest stack, + # and latest and greatest set of Doppler values. + star = ( + Star.select() + .where(Star.catalogid == source.catalogid) + .order_by(Star.created.desc()) + .first() + ) + + try: + int(star.meanfib) + except: + return [] + + # TODO: Add estimates from other pipelines? Gaia? + return [ + initial_guess_doppler(data_product, star=star), + initial_guess_apogeenet(data_product, star=star) + ] def create_initial_stellar_parameter_tasks( input_data_products, - header_paths: Union[List[str], Tuple[str], str], + header_paths: Optional[Union[List[str], Tuple[str], str]] = "$MWM_ASTRA/component_data/aspcap/synspec_dr17_marcs_header_paths.list", weight_path: Optional[str] = "$MWM_ASTRA/component_data/aspcap/global_mask_v02.txt", - normalization_method: Optional[ - Union[NormalizationBase, str] - ] = continuum.MedianNormalizationWithErrorInflation, - slice_args: Optional[List[Tuple[int]]] = [(0, 1)], + continuum_method: Optional[Union[Continuum, str]] = Scalar, + continuum_kwargs: Optional[dict] = dict(method="median"), + data_slice: Optional[List[Tuple[int]]] = [(0, 1)], initial_guess_callable: Optional[Callable] = None, - as_primary_keys: bool = False, **kwargs, -) -> List[Task]: +) -> List[Union[Task, int]]: """ - Create tasks that will use FERRE to estimate the stellar parameters given the stacked spectrum in an ApStar data product. + Create tasks that will use FERRE to estimate the stellar parameters given a data product. :param input_data_products: - The input (ApStar) data products, or primary keys for those data products. + The input data products, or primary keys for those data products. :param header_paths: A list of FERRE header path files, or a path to a file that has one FERRE header path per line. @@ -99,14 +177,14 @@ def create_initial_stellar_parameter_tasks( :param weight_path: [optional] The weights path to supply to FERRE. By default this is set to the global mask used by SDSS. - :param normalization_method: [optional] + :param continuum_method: [optional] The method to use for continuum normalization before FERRE is executed. 
        By default this is set to median scaling (`Scalar` with `method="median"`).

-    :param slice_args: [optional]
-        Slice the input spectra and only analyze those rows that meet the slice. Because this is the initial
-        round of stellar parameter determination, by default we only take the highest S/N spectrum (i.e., the
-        first spectrum in each ApStar data product).
-
+    :param data_slice: [optional]
+        Slice the input spectra and only analyze those rows that meet the slice. This is only relevant for ApStar
+        input data products, where the first spectrum represents the stacked spectrum. The parameter is ignored
+        for all other input data products.
+
     :param initial_guess_callable: [optional]
         A callable function that takes in a data product and returns a list of dictionaries of
         initial guesses. Each dictionary should contain at least the following keys:

        - teff
        - logg
        - metals
        - lgvsini
        - log10vdop
        - o_mg_si_s_ca_ti
        - c
        - n

        If the callable cannot supply an initial guess for a data product, it should return None
        instead of a dict.
-
-    :param as_primary_keys: [optional]
-        Return a list of primary keys instead of tasks.
     """

     log.debug(f"Data products {type(input_data_products)}: {input_data_products}")

     if isinstance(input_data_products, str):
         try:
             input_data_products = json.loads(input_data_products)
         except:
             pass

     input_data_products = deserialize(input_data_products, DataProduct)

     if isinstance(header_paths, str):
         if header_paths.lower().endswith(".hdr"):
             header_paths = [header_paths]
         else:
             # Load from file
             with open(os.path.expandvars(os.path.expanduser(header_paths)), "r") as fp:
                 header_paths = [line.strip() for line in fp]

-    if normalization_method is not None:
-        normalization_method = serialize_executable(normalization_method)
+    if continuum_method is not None:
+        continuum_method = serialize_executable(continuum_method)

     grid_info = utils.parse_grid_information(header_paths)
@@ -153,7 +228,7 @@
     round = lambda _, d=3: np.round(_, d).astype(float)

     # For each (data product, initial guess) permutation we need to create tasks based on suitable grids.
-    tasks = []
+    task_data_products = []
     for data_product in input_data_products:
         for initial_guess in initial_guess_callable(data_product):
             if initial_guess is None:
@@ -171,8 +246,9 @@
                 kwds = dict(
                     header_path=header_path,
                     weight_path=weight_path,
-                    normalization_method=normalization_method,
-                    slice_args=slice_args,
+                    continuum_method=continuum_method,
+                    continuum_kwargs=continuum_kwargs,
+                    data_slice=data_slice,
                     frozen_parameters=frozen_parameters,
                     initial_parameters=dict(
                         teff=round(initial_guess["teff"], 0),
@@ -193,34 +269,69 @@
                 )

                 # Create a task.
- task = Task.create( + task = Task( name=FERRE_TASK_NAME, version=__version__, parameters=parameters ) - TaskInputDataProducts.create(task=task, data_product=data_product) - tasks.append(task) + task_data_products.append((task, data_product)) - if as_primary_keys: - return [task.id for task in tasks] - return tasks + with database.atomic(): + Task.bulk_create([t for t, dp in task_data_products]) + TaskInputDataProducts.insert_many([ + { "task_id": t.id, "data_product_id": dp.id } for t, dp in task_data_products + ]).execute() + return [t for t, dp in task_data_products] -def create_stellar_parameter_tasks_from_best_initial_tasks( - initial_tasks, +def create_initial_stellar_parameter_task_bundles( + input_data_products, + header_paths: Optional[Union[List[str], Tuple[str], str]] = "$MWM_ASTRA/component_data/aspcap/synspec_dr17_marcs_header_paths.list", weight_path: Optional[str] = "$MWM_ASTRA/component_data/aspcap/global_mask_v02.txt", - normalization_method: Optional[ - Union[NormalizationBase, str] - ] = continuum.MedianFilterNormalizationWithErrorInflation, - normalization_kwds: Optional[dict] = None, - as_primary_keys: bool = False, + continuum_method: Optional[Union[Continuum, str]] = Scalar, + continuum_kwargs: Optional[dict] = dict(method="median"), + data_slice: Optional[List[Tuple[int]]] = [(0, 1)], + initial_guess_callable: Optional[Callable] = None, **kwargs, -) -> List[Task]: +) -> List[Union[Bundle, int]]: + + tasks = create_initial_stellar_parameter_tasks( + input_data_products=input_data_products, + header_paths=header_paths, + weight_path=weight_path, + continuum_method=continuum_method, + continuum_kwargs=continuum_kwargs, + data_slice=data_slice, + initial_guess_callable=initial_guess_callable, + **kwargs + ) + log.info(f"Created {len(tasks)} tasks") + + bundles = bundler(tasks) + log.info(f"Created {len(bundles)} bundles") + return [bundle.id for bundle in bundles] + + +def create_stellar_parameter_task_bundles( + initial_task_bundles, + weight_path: Optional[str] = "$MWM_ASTRA/component_data/aspcap/global_mask_v02.txt", + continuum_method: Optional[Union[Continuum, str]] = continuum.MedianFilter, + continuum_kwargs: Optional[dict] = None, + **kwargs, +): """ Create FERRE tasks to estimate stellar parameters, given the best result from the initial round of stellar parameters. """ - initial_tasks = deserialize(initial_tasks, Task) + #initial_task_bundles = deserialize(initial_task_bundles, Task) + if isinstance(initial_task_bundles, str): + initial_task_bundles = json.loads(initial_task_bundles) + q = ( + Task + .select() + .join(TaskBundle) + .where(TaskBundle.bundle_id.in_(initial_task_bundles)) + ) bitmask = ParamBitMask() bad_grid_edge = bitmask.get_value("GRIDEDGE_WARN") | bitmask.get_value( "GRIDEDGE_BAD" @@ -228,7 +339,7 @@ def create_stellar_parameter_tasks_from_best_initial_tasks( # Get all results per data product. results = {} - for task in initial_tasks: + for task in q: # TODO: Here we are assuming one data product per task, but it doesn't have to be this way. # It just makes it tricky if there are many data products + results per task, as we would # have to infer which result for which data product. 
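A minimal usage sketch of the bundling helpers above (the selection and IDs here are hypothetical, and it assumes the bundles are executed with FERRE between the two calls):

    from astra.database.astradb import DataProduct
    from astra.contrib.aspcap.base import (
        create_initial_stellar_parameter_task_bundles,
        create_stellar_parameter_task_bundles,
    )

    # Any data products containing 1D spectra; this selection is illustrative only.
    data_products = list(DataProduct.select().limit(10))

    # Round 1: initial stellar parameters, one task per (data product, initial guess, grid).
    initial_bundle_ids = create_initial_stellar_parameter_task_bundles(data_products)

    # ... execute those bundles (e.g., via Slurm), then seed round 2
    # from the best initial result per data product:
    bundle_ids = create_stellar_parameter_task_bundles(initial_bundle_ids)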
@@ -299,26 +410,26 @@
                 f"\t{i:.0f}: \chi^2 = {log_chisq_fit:.3f} for task {task} and output {output}"
             )

-    if normalization_method is not None:
-        normalization_method = serialize_executable(normalization_method)
+    if continuum_method is not None:
+        continuum_method = serialize_executable(continuum_method)

-    tasks = []
+    task_data_products = []
     for data_product_id, (result, *_) in results.items():
         log_chisq_fit, task, output = result

         # For the normalization we will do a median filter correction using the previous result.
-        if normalization_method is not None:
-            _normalization_kwds = (normalization_kwds or {}).copy()
-            _normalization_kwds.update(median_filter_from_task=task.id)
+        if continuum_method is not None:
+            _continuum_kwargs = (continuum_kwargs or {}).copy()
+            _continuum_kwargs.update(upstream_task_id=task.id)
         else:
-            _normalization_kwds = FERRE_DEFAULTS["normalization_kwds"]
+            _continuum_kwargs = FERRE_DEFAULTS["continuum_kwargs"]

         parameters = FERRE_DEFAULTS.copy()
         parameters.update(
             header_path=task.parameters["header_path"],
             weight_path=weight_path,
-            normalization_method=normalization_method,
-            normalization_kwds=_normalization_kwds,
+            continuum_method=continuum_method,
+            continuum_kwargs=_continuum_kwargs,
             initial_parameters=dict(
                 teff=output.teff,
                 logg=output.logg,
@@ -332,15 +443,26 @@
         )
         parameters.update({k: v for k, v in kwargs.items() if k in FERRE_DEFAULTS})

-        task = Task.create(
+        task = Task(
             name=FERRE_TASK_NAME, version=__version__, parameters=parameters
         )
-        TaskInputDataProducts.create(task=task, data_product_id=data_product_id)
-        tasks.append(task)
+        task_data_products.append((task, data_product_id))
+
+    with database.atomic():
+        Task.bulk_create([t for t, _ in task_data_products])
+        TaskInputDataProducts.insert_many([
+            { "task_id": t.id, "data_product_id": dp_id } for t, dp_id in task_data_products
+        ]).execute()
+
+    tasks = [t for t, _ in task_data_products]
+
+    # Create bundles
+    log.info(f"Created {len(tasks)} tasks")
+
+    bundles = bundler(tasks)
+    log.info(f"Created {len(bundles)} bundles")
+    return [bundle.id for bundle in bundles]

-    if as_primary_keys:
-        return [task.id for task in tasks]
-    return tasks


 def get_element(weight_path):
@@ -429,6 +551,7 @@
             )
             continue

+        log.warning("TODO: insert abundance tasks in bulk instead")
         abundance_task = Task.create(
             name=FERRE_TASK_NAME, version=__version__, parameters=parameters
         )
@@ -445,18 +568,9 @@
     return tasks


-from astra.database.astradb import (
-    database,
-    Task,
-    TaskInputDataProducts,
-    AspcapOutput,
-    Output,
-    TaskOutput,
-)
-from astra.base import ExecutableTask, TupleParameter

-class Aspcap(ExecutableTask):
+class Aspcap(TaskInstance):

     stellar_parameter_task_ids = TupleParameter("stellar_parameter_task_ids")
     abundance_task_ids = TupleParameter("abundance_task_ids")

diff --git a/python/astra/contrib/aspcap/continuum.py b/python/astra/contrib/aspcap/continuum.py
index 10bbb19..4b9dabb 100644
--- a/python/astra/contrib/aspcap/continuum.py
+++ b/python/astra/contrib/aspcap/continuum.py
@@ -7,164 +7,143 @@
 from astra.contrib.ferre import bitmask
 from astra.contrib.ferre.utils import read_ferre_headers
 from astra.database.astradb import Task
-from astra.tools.continuum.base import NormalizationBase
+from astra.tools.continuum.base import Continuum
+from 
astra.tools.spectrum import Spectrum1D +from astropy.io import fits +from typing import Optional, List, Tuple, Union +from astra.tools.spectrum import SpectralAxis -class MedianNormalizationWithErrorInflation(NormalizationBase): +class MedianFilter(Continuum): - """ - Continuum-normalize an input spectrum by the median flux value, - and inflate the errors due to skylines and bad pixels. - """ - - parameter_names = () + """Use a median filter to represent the stellar continuum.""" def __init__( self, - spectrum, - axis=1, - ivar_multiplier_for_sig_skyline=1e-4, - ivar_min=0, - ivar_max=40_000, - bad_pixel_flux=1e-4, - bad_pixel_ivar=1e-20, + upstream_task_id: int, + median_filter_width: Optional[int] = 151, + bad_minimum_flux: Optional[float] = 0.01, + non_finite_err_value: Optional[float] = 1e10, + valid_continuum_correction_range: Optional[Tuple[float]] = (0.1, 10.0), + mode: Optional[str] = "constant", + spectral_axis: Optional[SpectralAxis] = None, + regions: Optional[List[Tuple[float, float]]] = None, + mask: Optional[Union[str, np.array]] = None, + fill_value: Optional[Union[int, float]] = np.nan, + **kwargs ) -> None: - super().__init__(spectrum) - self.axis = axis - self.ivar_multiplier_for_sig_skyline = ivar_multiplier_for_sig_skyline - self.ivar_min = ivar_min - self.ivar_max = ivar_max - self.bad_pixel_flux = bad_pixel_flux - self.bad_pixel_ivar = bad_pixel_ivar + ( + """ + :param median_filter_width: [optional] + The width (int) for the median filter (default: 151). - def __call__(self): + :param bad_minimum_flux: [optional] + The value at which to set pixels as bad and median filter over them. This should be a float, + or `None` to set no low-flux filtering (default: 0.01). - pixel_bit_mask = bitmask.PixelBitMask() + :param non_finite_err_value: [optional] + The error value to set for pixels with non-finite fluxes (default: 1e10). - # Normalize. - continuum = np.nanmedian(self.spectrum.flux.value, axis=self.axis).reshape( - (-1, 1) + :param valid_continuum_correction_range: [optional] + A (min, max) tuple of the bounds that the final correction can have. Values outside this range will be set + as 1. + """ + + Continuum.__init__.__doc__ ) - self.spectrum._data /= continuum - self.spectrum._uncertainty.array *= continuum**2 - - # Increase the error around significant skylines. - skyline_mask = ( - self.spectrum.meta["bitmask"] & pixel_bit_mask.get_value("SIG_SKYLINE") - ) > 0 - self.spectrum._uncertainty.array[ - skyline_mask - ] *= self.ivar_multiplier_for_sig_skyline - - # Set bad pixels to have no useful data. - bad = ( - ~np.isfinite(self.spectrum.flux.value) - | ~np.isfinite(self.spectrum.uncertainty.array) - | (self.spectrum.flux.value < 0) - | (self.spectrum.uncertainty.array < 0) - | ((self.spectrum.meta["bitmask"] & pixel_bit_mask.get_level_value(1)) > 0) + super(MedianFilter, self).__init__( + spectral_axis=spectral_axis, + regions=regions, + mask=mask, + fill_value=fill_value, + **kwargs, ) - - self.spectrum._data[bad] = self.bad_pixel_flux - self.spectrum._uncertainty.array[bad] = self.bad_pixel_ivar - - # Ensure a minimum error. - # TODO: This seems like a pretty bad idea! 
-        self.spectrum._uncertainty.array = np.clip(
-            self.spectrum._uncertainty.array, self.ivar_min, self.ivar_max
-        )  # sigma = 5e-3
-
-        return self.spectrum
-
-
-class MedianFilterNormalizationWithErrorInflation(
-    MedianNormalizationWithErrorInflation
-):
-
-    parameter_names = ()
-
-    def __init__(
-        self,
-        spectrum,
-        median_filter_from_task,
-        segment_indices=None,
-        median_filter_width=151,
-        bad_minimum_flux=0.01,
-        non_finite_err_value=1e10,
-        valid_continuum_correction_range=(0.1, 10.0),
-        **kwargs,
-    ) -> None:
-        super().__init__(spectrum, **kwargs)
-        self.median_filter_from_task = median_filter_from_task
-        self.segment_indices = segment_indices
+        self.upstream_task_id = upstream_task_id
         self.median_filter_width = median_filter_width
         self.bad_minimum_flux = bad_minimum_flux
         self.non_finite_err_value = non_finite_err_value
         self.valid_continuum_correction_range = valid_continuum_correction_range
+        self.mode = mode
         return None

-    def __call__(self):
-        # Do standard median normalization.
-        spectrum = super().__call__()
+    def _initialize(self, spectrum, task=None):
+        # Initialize once and cache. (A `return` inside `finally` would silently
+        # suppress any exception raised during initialization, so avoid it.)
+        if not hasattr(self, "_initialized_args"):
+            if self.regions is None:
+                # Get the regions from the model wavelength segments.
+                if task is None:
+                    task = Task.get(self.upstream_task_id)
+
+                regions = []
+                for header in read_ferre_headers(expand_path(task.parameters["header_path"]))[1:]:
+                    crval, cdelt = header["WAVE"]
+                    npix = header["NPIX"]
+                    regions.append((10**crval, 10**(crval + cdelt * npix)))
+                self.regions = regions
+
+            self._initialized_args = super(MedianFilter, self)._initialize(spectrum)
+        return self._initialized_args
+
+
+    def fit(self, spectrum: Spectrum1D, hdu=3):
+
+        task = Task.get(self.upstream_task_id)
+        region_slices, region_masks = self._initialize(spectrum, task)
+
+        # Previously the correction took (ferre_flux / continuum) and model_flux;
+        # now it takes ferre_flux and (model_flux / continuum), so the resulting
+        # ratio is: continuum * (ferre_flux / model_flux).
+
+        # This is an astraStar-FERRE product, but let's just use fits.open
+        with fits.open(task.output_data_products[0].path) as image:
+            # How do we decide on the HDU for this product just from the spectrum?
+            continuum = image[hdu].data["CONTINUUM"]
+            flux = image[hdu].data["FERRE_FLUX"]
+            rectified_model_flux = image[hdu].data["MODEL_FLUX"] / continuum
+
+        N, P = flux.shape
+        self._continuum = np.nan * np.ones((N, P))
+        for i in range(N):
+            for region_mask in region_masks:
+                flux_region, model_flux_region = (flux[i, region_mask].copy(), rectified_model_flux[i, region_mask].copy())

+                # TODO: It's a little counter-intuitive how this is documented, so we should fix that.
+                # Or allow for a MAGIC number 5 instead.
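+                # Median-filter the raw FERRE flux first, so that low or spiked pixels
+                # can be patched below before the flux / model ratio is smoothed into
+                # a continuum estimate.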
+ median = median_filter(flux[i, region_mask], [self.median_filter_width], mode=self.mode) + + bad = np.where( + (flux_region < self.bad_minimum_flux) | (flux_region > (np.nanmedian(flux_region) + 3 * np.nanstd(flux_region))) + )[0] + flux_region[bad] = median[bad] + + ratio_region = flux_region / model_flux_region + self._continuum[i, region_mask] = median_filter( + ratio_region, + [self.median_filter_width], + mode=self.mode, + cval=1.0 + ) - if not isinstance(self.median_filter_from_task, Task): - median_filter_from_task = Task.get_by_id(int(self.median_filter_from_task)) - else: - median_filter_from_task = self.median_filter_from_task + scalars = np.nanmedian(spectrum.flux.value / self._continuum, axis=1) + self._continuum *= scalars + return None + + + def __call__( + self, + spectrum: Spectrum1D, + theta: Optional[Union[List, np.array, Tuple]] = None, + **kwargs + ) -> np.ndarray: + if theta is not None: + log.warning(f"Continuum coefficients ignored here") + return self._continuum - # Need number of pixels from header - n_pixels = np.array( - [ - header["NPIX"] - for header in read_ferre_headers( - expand_path(median_filter_from_task.parameters["header_path"]) - ) - ][1:] - ) - _ = n_pixels.cumsum() - segment_indices = np.sort(np.hstack([[0], _, _]))[:-1].reshape((-1, 2)) - - continuum = [] - for output_data_product in median_filter_from_task.output_data_products: - with open(output_data_product.path, "rb") as fp: - output = pickle.load(fp) - - kwds = dict( - wavelength=output["data"]["wavelength"], - normalised_observed_flux=output["data"]["flux"] - / output["data"]["continuum"], - normalised_observed_flux_err=output["data"]["flux_sigma"] - / output["data"]["continuum"], - normalised_model_flux=output["data"]["model_flux"], - segment_indices=segment_indices, - median_filter_width=self.median_filter_width, - bad_minimum_flux=self.bad_minimum_flux, - non_finite_err_value=self.non_finite_err_value, - valid_continuum_correction_range=self.valid_continuum_correction_range, - ) - continuum.append(median_filtered_correction(**kwds)) - - continuum = np.array(continuum) - - # Construct mask to match FERRE model grid. - N, P = spectrum.flux.shape - mask = np.zeros(P, dtype=bool) - for si, ei in segment_indices: - # TODO: Building wavelength mask off just the last wavelength array. - # We are assuming all have the same wavelength array. 
- s_index, e_index = spectrum.wavelength.value.searchsorted( - output["data"]["wavelength"][si:ei][[0, -1]] - ) - mask[s_index : e_index + 1] = True - - continuum_unmasked = np.nan * np.ones((N, P)) - continuum_unmasked[:, mask] = np.array(continuum) - - spectrum._data /= continuum_unmasked - spectrum._uncertainty.array *= continuum_unmasked * continuum_unmasked - - return spectrum def median_filtered_correction( @@ -302,3 +282,6 @@ def median_filtered_correction( continuum[bad] = 1 return continuum + + +MedianFilterNormalizationWithErrorInflation = MedianFilter diff --git a/python/astra/contrib/ferre/base.py b/python/astra/contrib/ferre/base.py index e943908..1d6dcc9 100644 --- a/python/astra/contrib/ferre/base.py +++ b/python/astra/contrib/ferre/base.py @@ -5,12 +5,14 @@ import sys import pickle from tempfile import mkdtemp - +from astropy.nddata import StdDevUncertainty +from collections import OrderedDict from astra import log, __version__ from astra.base import TaskInstance, Parameter, TupleParameter, DictParameter -from astra.tools.spectrum import Spectrum1D -#from astra.contrib.ferre import bitmask, utils -from astra.utils import flatten, executable, expand_path, nested_list +from astra.tools.spectrum import Spectrum1D, SpectrumList +from astra.tools.spectrum.utils import spectrum_overlaps +from astra.contrib.ferre import bitmask, utils +from astra.utils import dict_to_list, list_to_dict, flatten, executable, expand_path, nested_list from astra.database.astradb import ( database, DataProduct, @@ -19,58 +21,52 @@ TaskOutput, FerreOutput, ) -#from astra.operators.sdss import get_apvisit_metadata - +from astra.sdss.datamodels.pipeline import create_pipeline_product +from astra.sdss.datamodels.base import get_extname +from astra.contrib.ferre.bitmask import (PixelBitMask, ParamBitMask) class Ferre(TaskInstance): - header_path = Parameter("header_path", bundled=True) - initial_parameters = DictParameter("initial_parameters", default=None) - frozen_parameters = DictParameter("frozen_parameters", default=None, bundled=True) - interpolation_order = Parameter("interpolation_order", default=3, bundled=True) - weight_path = Parameter("weight_path", default=None, bundled=True) - lsf_shape_path = Parameter("lsf_shape_path", default=None, bundled=True) - lsf_shape_flag = Parameter("lsf_shape_flag", default=0, bundled=True) - error_algorithm_flag = Parameter("error_algorithm_flag", default=1, bundled=True) - wavelength_interpolation_flag = Parameter( - "wavelength_interpolation_flag", default=0, bundled=True - ) - optimization_algorithm_flag = Parameter( - "optimization_algorithm_flag", default=3, bundled=True - ) - continuum_flag = Parameter("continuum_flag", default=1, bundled=True) - continuum_order = Parameter("continuum_order", default=4, bundled=True) - continuum_segment = Parameter("continuum_segment", default=None, bundled=True) - continuum_reject = Parameter("continuum_reject", default=0.3, bundled=True) - continuum_observations_flag = Parameter( - "continuum_observations_flag", default=1, bundled=True - ) - full_covariance = Parameter("full_covariance", default=False, bundled=True) - pca_project = Parameter("pca_project", default=False, bundled=True) - pca_chi = Parameter("pca_chi", default=False, bundled=True) - f_access = Parameter("f_access", default=None, bundled=True) - f_format = Parameter("f_format", default=1, bundled=True) - ferre_kwds = DictParameter("ferre_kwds", default=None, bundled=True) - parent_dir = Parameter("parent_dir", default=None, bundled=True) - n_threads = 
Parameter("n_threads", default=1, bundled=True) - - # For normalization to be made before the FERRE run. - normalization_method = Parameter( - "normalization_method", default=None, bundled=False - ) - normalization_kwds = DictParameter( - "normalization_kwds", default=None, bundled=False - ) - - # For deciding what rows to use from each data product. - slice_args = TupleParameter("slice_args", default=None, bundled=False) + header_path = Parameter(bundled=True) + initial_parameters = DictParameter(default=None) + frozen_parameters = DictParameter(default=None, bundled=True) + interpolation_order = Parameter(default=3, bundled=True) + weight_path = Parameter(default=None, bundled=True) + lsf_shape_path = Parameter(default=None, bundled=True) + lsf_shape_flag = Parameter(default=0, bundled=True) + error_algorithm_flag = Parameter(default=1, bundled=True) + wavelength_interpolation_flag = Parameter(default=0, bundled=True) + optimization_algorithm_flag = Parameter(default=3, bundled=True) + continuum_flag = Parameter(default=1, bundled=True) + continuum_order = Parameter(default=4, bundled=True) + continuum_segment = Parameter(default=None, bundled=True) + continuum_reject = Parameter(default=0.3, bundled=True) + continuum_observations_flag = Parameter(default=1, bundled=True) + full_covariance = Parameter(default=False, bundled=True) + pca_project = Parameter(default=False, bundled=True) + pca_chi = Parameter(default=False, bundled=True) + f_access = Parameter(default=None, bundled=True) + f_format = Parameter(default=1, bundled=True) + ferre_kwds = DictParameter(default=None, bundled=True) + parent_dir = Parameter(default=None, bundled=True) + n_threads = Parameter(default=1, bundled=True) + + # For rectification to be made before the FERRE run. + continuum_method = Parameter(default=None, bundled=True) + continuum_kwargs = DictParameter(default=None, bundled=True) + + data_slice = TupleParameter(default=[0, 1]) # only relevant for ApStar data products + + bad_pixel_flux_value = Parameter(default=1e-4) + bad_pixel_sigma_value = Parameter(default=1e10) + skyline_sigma_multiplier = Parameter(default=100) + min_sigma_value = Parameter(default=0.05) # FERRE will sometimes hang forever if there is a spike in the data (e.g., a skyline) that # is not represented by the uncertainty array (e.g., it looks 'real'). - # An example of this on Utah is under ~/ferre-death-examples/spike/ # To self-preserve FERRE, we do a little adjustment to the uncertainty array. - spike_threshold_to_inflate_uncertainty = Parameter(default=5, bundled=True) + spike_threshold_to_inflate_uncertainty = Parameter(default=5) # Maximum timeout in seconds for FERRE timeout = Parameter(default=12 * 60 * 60, bundled=True) @@ -104,24 +100,30 @@ def estimate_relative_cost_factors(cls, parameters): # If we are not slicing the data products then the scaling should go as the size of the # data products - if parameters.get("slice_args", None) is None: + if parameters.get("data_slice", None) is None: factor_data_product_size = scale else: # Estimate the number slicing each time. 
- N = np.ptp(np.array(parameters["slice_args"]).flatten()) + N = np.ptp(np.array(parameters["data_slice"]).flatten()) factor_data_product = N * scale return np.array([factor_task, factor_data_product, factor_data_product_size]) @classmethod - def to_name(cls, i, j, k, data_product, snr, **kwargs): - obj = data_product.kwargs.get("obj", "NOOBJ") - return f"{i:.0f}_{j:.0f}_{k:.0f}_{snr:.1f}_{obj}" + def to_name(cls, i, j, k, l, data_product, snr, **kwargs): + keys = ("cat_id", "obj") + for key in keys: + if key in data_product.kwargs: + obj = data_product.kwargs[key] + break + else: + obj = "NOOBJ" + return f"{i:.0f}_{j:.0f}_{k:.0f}_{l:.0f}_{snr:.1f}_{obj}" @classmethod def from_name(cls, name): - i, j, k, snr, *obj = name.split("_") - return dict(i=int(i), j=int(j), k=int(k), snr=float(snr), obj="_".join(obj)) + i, j, k, l, snr, *obj = name.split("_") + return dict(i=int(i), j=int(j), k=int(k), l=int(l), snr=float(snr), obj="_".join(obj)) def pre_execute(self): @@ -148,7 +150,7 @@ def pre_execute(self): else: dir = mkdtemp(dir=parent_dir) else: - dir = os.path.join(parent_dir, f"bundles/{bundle.id % 100}/{bundle.id}") + dir = os.path.join(parent_dir, f"bundles/{bundle.id % 100:0>2.0f}/{bundle.id}") else: dir = mkdtemp() @@ -189,100 +191,130 @@ def pre_execute(self): with open(os.path.join(dir, "input.nml"), "w") as fp: fp.write(utils.format_ferre_control_keywords(control_kwds)) + if self.continuum_method is not None: + f_continuum = executable(self.continuum_method)(**self.continuum_kwargs) + else: + f_continuum = None + + pixel_mask = PixelBitMask() + + # Construct mask to match FERRE model grid. + model_wavelengths = tuple(map(utils.wavelength_array, segment_headers)) + # Read in the input data products. - wl, flux, sigma = ([], [], []) - names, initial_parameters_as_dicts = ([], []) - indices = [] - spectrum_metas = [] + indices, flux, sigma, names, initial_parameters_as_dicts = ([], [], [], [], []) for i, (task, data_products, parameters) in enumerate(self.iterable()): for j, data_product in enumerate(flatten(data_products)): - spectrum = Spectrum1D.read(data_product.path) - - # Get relevant spectrum metadata - spectrum_meta = get_apvisit_metadata(data_product) - - # Apply any slicing, if requested. - if parameters["slice_args"] is not None: - # TODO: Refactor this and put somewhere common. - slices = tuple([slice(*args) for args in parameters["slice_args"]]) - spectrum_meta = spectrum_meta[ - slices[0] - ] # TODO: allow for more than 1 slice? - spectrum._data = spectrum._data[slices] - spectrum._uncertainty.array = spectrum._uncertainty.array[slices] - for key in ("bitmask", "snr"): - try: - spectrum.meta[key] = np.array(spectrum.meta[key])[slices] - except: - log.exception( - f"Unable to slice '{key}' metadata with {parameters['slice_args']} on {task} {data_product}" - ) + for k, spectrum in enumerate(SpectrumList.read(data_product.path, data_slice=parameters["data_slice"])): + if not spectrum_overlaps(spectrum, np.hstack(model_wavelengths)): + continue + + N, P = spectrum.flux.shape + wl_ = spectrum.wavelength.value + flux_ = spectrum.flux.value + sigma_ = spectrum.uncertainty.represent_as(StdDevUncertainty).array + + # Perform any continuum rectification pre-processing. + if f_continuum is not None: + f_continuum.fit(spectrum) + continuum = f_continuum(spectrum) + flux_ /= continuum + sigma_ /= continuum + else: + continuum = None + + # Inflate errors around skylines, etc. 
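+                    # Pixels flagged SIG_SKYLINE in the BITMASK get their uncertainties
+                    # multiplied by skyline_sigma_multiplier (default 100), so FERRE
+                    # effectively down-weights them.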
+                    skyline_mask = (
+                        spectrum.meta["BITMASK"] & pixel_mask.get_value("SIG_SKYLINE")
+                    ) > 0
+                    sigma_[skyline_mask] *= parameters["skyline_sigma_multiplier"]
+
+                    # Set bad pixels to have no useful data.
+                    if parameters["bad_pixel_flux_value"] is not None or parameters["bad_pixel_sigma_value"] is not None:
+                        bad = (
+                            ~np.isfinite(flux_)
+                            | ~np.isfinite(sigma_)
+                            | (flux_ < 0)
+                            | (sigma_ < 0)
+                            | ((spectrum.meta["BITMASK"] & pixel_mask.get_level_value(1)) > 0)
+                        )

-            # Apply any general normalization method.
-            if parameters["normalization_method"] is not None:
-                _class = executable(parameters["normalization_method"])
-                rectifier = _class(
-                    spectrum, **(parameters["normalization_kwds"] or dict())
-                )
+                        flux_[bad] = parameters["bad_pixel_flux_value"]
+                        sigma_[bad] = parameters["bad_pixel_sigma_value"]

-                # Normalization methods for FERRE cannot be applied within the log-likelihood
-                # function, because we'd have to have it executed *within* FERRE.
-                if len(rectifier.parameter_names) > 0:
-                    raise TypeError(
-                        f"Normalization method {parameters['normalization_method']} on {self} cannot be applied within the log-likelihood function for FERRE."
-                    )
-                spectrum = rectifier()
+                    # Clip the error array. This is a pretty bad idea but I am doing what was done before!
+                    if parameters["min_sigma_value"] is not None:
+                        sigma_ = np.clip(sigma_, parameters["min_sigma_value"], np.inf)

-            N, P = spectrum.flux.shape
-            initial_parameters = parameters["initial_parameters"]
-            # Allow initital parameters to be a dict (applied to all spectra) or a list of dicts (one per spectra)
-            log.debug(
-                f"There are {N} spectra in {task} {data_product} and initial params is {len(initial_parameters)} long"
-            )
-            log.debug(f"And {set(map(type, initial_parameters))}")
+                    # Restrict to the pixels within the model wavelength grid.
+                    mask = _get_ferre_mask(wl_, model_wavelengths)

-            if len(initial_parameters) == N and all(
-                isinstance(_, dict) for _ in initial_parameters
-            ):
-                log.debug(
-                    f"Allowing different initial parameters for each {N} spectra on task {task}"
-                )
-                initial_parameters_as_dicts.extend(initial_parameters)
-            else:
-                if N > 1:
-                    log.debug(
-                        f"Using same initial parameters {initial_parameters} for all {N} spectra on task {task}"
-                    )
-                initial_parameters_as_dicts.extend([initial_parameters] * N)
+                    flux_ = flux_[:, mask]
+                    sigma_ = sigma_[:, mask]

-            if N != len(spectrum_meta):
-                log.warning(
-                    f"Number of spectra does not match expected from visit metadata: {N} != {len(spectrum_meta)}"
-                )
+                    # Sometimes FERRE will run forever.
+                    # TODO: rename to spike_threshold_for_bad_pixel
+                    if parameters["spike_threshold_to_inflate_uncertainty"] > 0:

-            for k in range(N):
-                indices.append((i, j, k))
-                names.append(
-                    self.to_name(
-                        i=i,
-                        j=j,
-                        k=k,
-                        data_product=data_product,
-                        snr=spectrum.meta["snr"][k],
+                        flux_median = np.median(flux_, axis=1).reshape((-1, 1))
+                        flux_stddev = np.std(flux_, axis=1).reshape((-1, 1))
+                        sigma_median = np.median(sigma_, axis=1).reshape((-1, 1))
+
+                        delta = (flux_ - flux_median) / flux_stddev
+                        is_spike = (delta > parameters["spike_threshold_to_inflate_uncertainty"]) * (
+                            sigma_ < (parameters["spike_threshold_to_inflate_uncertainty"] * sigma_median)
                         )
+                        if np.any(is_spike):
+                            fraction = np.sum(is_spike) / is_spike.size
+                            log.warning(
+                                f"Inflating uncertainties for {np.sum(is_spike)} pixels ({100 * fraction:.2f}%) that were identified as spikes."
+                            )
+                            for pi in range(is_spike.shape[0]):
+                                n = np.sum(is_spike[pi])
+                                if n > 0:
+                                    log.debug(f" {n} pixels on spectrum index {pi}")
+                            sigma_[is_spike] = parameters["bad_pixel_sigma_value"]
+
+                    # Parse initial parameters. Expected types:
+                    # - dictionary of single values -> apply single value to all N spectra
+                    # - dictionary of lists of length N -> different value per spectrum
+                    # - dictionary of lists of single value -> apply single value to all N spectra
+                    # TODO: Move this logic elsewhere so it's testable.
+                    initial_parameters = parameters["initial_parameters"]
+                    # Allow initial parameters to be a dict (applied to all spectra) or a list of dicts (one per spectrum)
+                    log.debug(
+                        f"There are {N} spectra in {task} {data_product} and initial params is {len(initial_parameters)} long"
+                    )
+                    log.debug(f"And {set(map(type, initial_parameters))}")

+                    if len(initial_parameters) == N and all(
+                        isinstance(_, dict) for _ in initial_parameters
+                    ):
+                        log.debug(
+                            f"Allowing different initial parameters for each {N} spectra on task {task}"
+                        )
+                        initial_parameters_as_dicts.extend(initial_parameters)
+                    else:
+                        if N > 1:
+                            log.debug(
+                                f"Using same initial parameters {initial_parameters} for all {N} spectra on task {task}"
+                            )
+                        initial_parameters_as_dicts.extend([initial_parameters] * N)
+
+                    for l in range(N):
+                        indices.append((i, j, k, l))
+                        names.append(
+                            self.to_name(
+                                i=i,
+                                j=j,
+                                k=k,
+                                l=l,
+                                data_product=data_product,
+                                snr=spectrum.meta["SNR"].flatten()[l],
+                            )
+                        )
+                        flux.append(flux_)
+                        sigma.append(sigma_)

         # Convert list of dicts of initial parameters to array.
         initial_parameters = utils.validate_initial_and_frozen_parameters(
@@ -297,54 +329,23 @@
         for name, point in zip(names, initial_parameters):
             fp.write(utils.format_ferre_input_parameters(*point, name=name))

-        # Construct mask to match FERRE model grid.
-        model_wavelengths = tuple(map(utils.wavelength_array, segment_headers))
-        mask = np.zeros(wl.shape[1], dtype=bool)
-        for model_wavelength in model_wavelengths:
-            # TODO: Building wavelength mask off just the first wavelength array.
-            # We are assuming all have the same wavelength array.
-            s_index, e_index = wl[0].searchsorted(model_wavelength[[0, -1]])
-            mask[s_index : e_index + 1] = True
-
-        # Sometimes FERRE will run forever
-        self.spike_threshold_to_inflate_uncertainty = 5
-        if self.spike_threshold_to_inflate_uncertainty > 0:
-            flux_median = np.median(flux[:, mask], axis=1).reshape((-1, 1))
-            flux_stddev = np.std(flux[:, mask], axis=1).reshape((-1, 1))
-            sigma_median = np.median(sigma[:, mask], axis=1).reshape((-1, 1))
-
-            delta = (flux - flux_median) / flux_stddev
-            is_spike = (delta > self.spike_threshold_to_inflate_uncertainty) * (
-                sigma < (self.spike_threshold_to_inflate_uncertainty * sigma_median)
-            )
-
-            if np.any(is_spike):
-                fraction = np.sum(is_spike[:, mask]) / is_spike[:, mask].size
-                log.warning(
-                    f"Inflating uncertainties for {np.sum(is_spike)} pixels ({100 * fraction:.2f}%) that were identified as spikes."
- ) - for i in range(is_spike.shape[0]): - n = np.sum(is_spike[i, mask]) - if n > 0: - log.debug(f" {n} pixels on spectrum index {i}") - sigma[is_spike] = 1e10 + indices = np.array(indices) + N, _ = indices.shape + flux = np.array(flux).reshape((N, -1)) + sigma = np.array(sigma).reshape((N, -1)) # Write data arrays. savetxt_kwds = dict(fmt="%.4e", footer="\n") np.savetxt( - os.path.join(dir, control_kwds["ffile"]), flux[:, mask], **savetxt_kwds + os.path.join(dir, control_kwds["ffile"]), flux, **savetxt_kwds ) np.savetxt( - os.path.join(dir, control_kwds["erfile"]), sigma[:, mask], **savetxt_kwds + os.path.join(dir, control_kwds["erfile"]), sigma, **savetxt_kwds ) - - # Write metadata file to pick up later. - with open(os.path.join(dir, "spectrum_meta.pkl"), "wb") as fp: - pickle.dump(spectrum_metas, fp) - context = dict(dir=dir) return context + def execute(self): """Execute FERRE""" @@ -415,6 +416,7 @@ def execute(self): with open(os.path.join(dir, "stderr"), "w") as fp: fp.write(stderr) + ''' # We actually have timings per-spectrum but we aggregate this to per-task. # We might want to store the per-data-product and per-spectrum timing elsewhere. try: @@ -440,6 +442,23 @@ def execute(self): log.debug(f"Timing information from FERRE stdout:") for key, value in timings.items(): log.debug(f"\t{key}: {value}") + ''' + + + def post_execute(self): + """ + Post-execute hook after FERRE is complete. + + Read in the output files, create rows in the database, and produce output data products. + """ + dir = self.context["pre_execute"]["dir"] + with open(os.path.join(dir, "stdout"), "r") as fp: + stdout = fp.read() + with open(os.path.join(dir, "stderr"), "r") as fp: + stderr = fp.read() + + n_done, n_error, control_kwds = utils.parse_ferre_output(dir, stdout, stderr) + timings = utils.get_processing_times(stdout) # Parse the outputs from the FERRE run. path = os.path.join(dir, control_kwds["OPFILE"]) @@ -479,7 +498,7 @@ def execute(self): path = os.path.join(dir, control_kwds["SFFILE"]) normalized_flux = np.atleast_2d(np.loadtxt(path)) except: - log.exception(f"Failed to load normalized flux from {path}") + log.exception(f"Failed to load normalized observed flux from {path}") raise else: continuum = flux / normalized_flux @@ -492,8 +511,6 @@ def execute(self): ) parameter_names = utils.sanitise(headers["LABEL"]) - wavelength = np.hstack(tuple(map(utils.wavelength_array, segment_headers))) - # Flag things. 
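+        # Set per-parameter bitmask flags below: e.g. GRIDEDGE_WARN / GRIDEDGE_BAD for
+        # parameters near a grid edge, plus a warning when FERRE returns erroneous values.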
param_bitmask = bitmask.ParamBitMask() param_bitmask_flags = np.zeros(params.shape, dtype=np.int64) @@ -523,126 +540,114 @@ def execute(self): ) log.warning(f"FERRE returned all erroneous values for an entry: {idx} {v}") - with open(os.path.join(dir, "spectrum_meta.pkl"), "rb") as fp: - spectrum_metas = pickle.load(fp) - - results_dict = {} - ijks = [] - for z, (name, param, param_err, bitmask_flag, spectrum_meta) in enumerate( - zip(names, params, param_errs, param_bitmask_flags, spectrum_metas) + model_wavelengths = tuple(map(utils.wavelength_array, segment_headers)) + label_results = {} + spectral_results = {} + for z, (name, param, param_err, bitmask_flag) in enumerate( + zip(names, params, param_errs, param_bitmask_flags) ): parsed = self.from_name(name) - result = dict( - log_chisq_fit=meta["log_chisq_fit"][z], - log_snr_sq=meta["log_snr_sq"][z], - frac_phot_data_points=meta["frac_phot_data_points"][z], - snr=parsed["snr"], - ) - result["meta"] = spectrum_meta - - try: - result.update( - ferre_time_elapsed=timings["time_per_spectrum"][z], - ferre_time_load=timings["time_load"], - ferre_n_obj=timings["n_obj"], - ferre_n_threads=timings["n_threads"], - ) - except: - log.exception( - f"Exception while trying to include FERRE timing information in the database for {self}" - ) - - result.update(dict(zip(parameter_names, param))) - result.update(dict(zip([f"u_{pn}" for pn in parameter_names], param_err))) + result = OrderedDict(zip(reversed(parameter_names), reversed(param))) + result.update(dict(zip([f"e_{pn}" for pn in reversed(parameter_names)], reversed(param_err)))) result.update( - dict(zip([f"bitmask_{pn}" for pn in parameter_names], bitmask_flag)) + dict(zip([f"bitmask_{pn}" for pn in reversed(parameter_names)], reversed(bitmask_flag))) ) - - # Add spectra. - result["data"] = dict( - wavelength=wavelength, - flux=flux[z], - flux_sigma=flux_sigma[z], - model_flux=model_flux[z], - continuum=continuum[z], - normalized_flux=normalized_flux[z], + result.update( + dict( + log_chisq_fit=meta["log_chisq_fit"][z], + log_snr_sq=meta["log_snr_sq"][z], + frac_phot_data_points=meta["frac_phot_data_points"][z], + snr=parsed["snr"], + ) ) i, j, k = (int(parsed[_]) for _ in "ijk") - ijks.append((i, j, k)) - results_dict.setdefault((i, j), []) - results_dict[(i, j)].append(result) - # List-ify. - results = nested_list(ijks) - for (i, j), value in results_dict.items(): - for k, result in enumerate(value): - results[i][j][k] = result - return results - - def post_execute(self): - """ - Post-execute hook after FERRE is complete. - - Read in the output files, create rows in the database, and produce output data products. - """ - - results = self.context["execute"] + label_results.setdefault((i, j, k), []) + label_results[(i, j, k)].append(result) + spectral_results.setdefault((i, j, k), []) + + # TODO: These need to be resampled to the observed pixels! + spectral_results[(i, j, k)].append(dict( + model_flux=model_flux[z], + continuum=continuum[z], + # FERRE_flux is what we actually gave to FERRE. Store it here just in case? + ferre_flux=flux[z], + e_ferre_flux=flux_sigma[z], + )) + if self.continuum_method is not None: + f_continuum = executable(self.continuum_method)(**self.continuum_kwargs) + else: + f_continuum = None + # Create outputs in the database. 
- with database.atomic() as txn: - for (task, data_products, _), task_results in zip(self.iterable(), results): - for (data_product, data_product_results) in zip( - flatten(data_products), task_results - ): - - for result in data_product_results: - output = Output.create() - - # Create a data product. - # TODO: This is a temporary hack until we have a data model in place. - path = expand_path( - f"$MWM_ASTRA/{__version__}/ferre/tasks/{task.id % 100}/{task.id}/output_{output.id}.pkl" - ) - os.makedirs(os.path.dirname(path), exist_ok=True) - - with open(path, "wb") as fp: - pickle.dump(result, fp) - - output_data_product = DataProduct.create( - release=data_product.release, - filetype="full", - kwargs=dict(full=path), - ) - TaskOutputDataProducts.create( - task=task, data_product=output_data_product - ) - - # Spectra don't belong in the database. - result.pop("data") - - TaskOutput.create(task=task, output=output) - FerreOutput.create(task=task, output=output, **result) - - log.info( - f"Created output {output} for task {task} and data product {data_product}" - ) - log.info( - f"New output data product: {output_data_product} at {path}" - ) + for i, (task, (data_product, ), parameters) in enumerate(self.iterable()): + hdu_results = {} + task_results = [] + header_groups = {} + for k, spectrum in enumerate(SpectrumList.read(data_product.path, data_slice=parameters["data_slice"])): + if not spectrum_overlaps(spectrum, np.hstack(model_wavelengths)): + continue + + index = (i, 0, k) + + extname = get_extname(spectrum, data_product) + + # TODO: Put in the initial parameters? + hdu_result = list_to_dict(label_results[index]) + + spectral_results_ = list_to_dict(spectral_results[index]) + mask = _get_ferre_mask(spectrum.wavelength.value, model_wavelengths) + spectral_results_ = _de_mask_values(spectral_results_, mask) + + # TODO: Store this in a meta file instead of doing it in pre_ and post_?? + if f_continuum is not None: + f_continuum.fit(spectrum) + pre_continuum = f_continuum(spectrum) + else: + pre_continuum = 1 + + spectral_results_["continuum"] *= pre_continuum + spectral_results_["model_flux"] *= spectral_results_["continuum"] + + hdu_result.update(spectral_results_) + hdu_results[extname] = hdu_result + header_groups[extname] = [ + ("TEFF", "STELLAR LABELS"), + ("BITMASK_TEFF", "BITMASK FLAGS"), + ("LOG_CHISQ_FIT", "SUMMARY STATISTICS"), + ("MODEL_FLUX", "MODEL SPECTRA") + ] + task_results.extend(label_results[index]) + + create_pipeline_product(task, data_product, hdu_results, header_groups=header_groups) + + # Add to database. 
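+            # One FerreOutput row is expected per analyzed spectrum collected in task_results.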
+            task.create_or_update_outputs(FerreOutput, task_results)

         return None


-"""
-def create_task_output(task, model, **kwargs):
-    output = Output.create()
-    task_output = TaskOutput.create(task=task, output=output)
-    result = model.create(
-        task=task,
-        output=output,
-        **kwargs
-    )
-    return (output, task_output, result)
-"""
+def _get_ferre_mask(observed_wavelength, model_wavelengths):
+    P = observed_wavelength.size
+    mask = np.zeros(P, dtype=bool)
+    for model_wavelength in model_wavelengths:
+        s_index, e_index = observed_wavelength.searchsorted(model_wavelength[[0, -1]])
+        if (e_index - s_index) != model_wavelength.size:
+            log.warn(f"Model wavelength grid does not precisely match data product ({e_index - s_index} vs {model_wavelength.size} on {model_wavelength[[0, -1]]})")
+            e_index = s_index + model_wavelength.size
+        mask[s_index:e_index] = True
+    return mask
+
+
+def _de_mask_values(spectral_dict, mask, fill_value=np.nan):
+    P = mask.size
+    updated = {}
+    for k, v in spectral_dict.items():
+        N, O = np.atleast_2d(v).shape
+        updated_v = fill_value * np.ones((N, P))
+        updated_v[:, mask] = np.array(v)
+        updated[k] = updated_v
+    return updated
\ No newline at end of file
diff --git a/python/astra/contrib/ferre/bitmask.py b/python/astra/contrib/ferre/bitmask.py
index 51e2a28..c43c277 100644
--- a/python/astra/contrib/ferre/bitmask.py
+++ b/python/astra/contrib/ferre/bitmask.py
@@ -1,8 +1,53 @@
 #from astra.utils.bitmask import BitFlagNameMap
 from collections import OrderedDict

-class BitFlagNameMap:
-    pass
+import numpy as np
+
+class BitFlagNameMap(object):
+
+    def get_names(self, value):
+        names = []
+        for k, entry in self.__class__.__dict__.items():
+            if k.startswith("_") or k != k.upper(): continue
+
+            if isinstance(entry, int):
+                v = entry
+            else:
+                v, comment = entry
+
+            if value & (2**v):
+                names.append(k)
+
+        return tuple(names)
+
+
+    def get_value(self, *names):
+        value = np.int64(0)
+
+        for name in names:
+            try:
+                entry = getattr(self, name)
+            except AttributeError:
+                # getattr raises AttributeError (not KeyError) for a missing flag name.
+                raise ValueError(f"no bit flag found named '{name}'")
+
+            if isinstance(entry, int):
+                entry = (entry, "")
+
+            v, comment = entry
+            value |= np.int64(2**v)
+
+        return value
+
+
+    def get_level_value(self, level):
+        try:
+            names = self.levels[level]
+        except KeyError:
+            raise ValueError(f"No level name '{level}' found (available: {' '.join(list(self.levels.keys()))})")
+
+        return self.get_value(*names)
+

 class ParamBitMask(BitFlagNameMap):

diff --git a/python/astra/contrib/slam/base.py b/python/astra/contrib/slam/base.py
index a7e984c..b9aaf99 100644
--- a/python/astra/contrib/slam/base.py
+++ b/python/astra/contrib/slam/base.py
@@ -191,17 +191,40 @@ def execute(self):

         # Create AstraStar product.
         model_continuum = flux_resamp / flux_norm

-        resampled_continuum = np.empty((N, P))
-        resampled_model_flux = np.empty((N, P))
-        for i in range(N):
-            assert np.all(np.isfinite(prediction)), "Prediction values not all finite?"
-            assert np.all(np.isfinite(model_continuum[i])), "Model continuum values not all finite?"
-            f = interp1d(wave_interp, prediction[i], kind="cubic", bounds_error=False, fill_value=np.nan)
-            c = interp1d(wave_interp, model_continuum[i], kind="cubic", bounds_error=False, fill_value=np.nan)
-
-            # Re-sample the predicted spectra back to the observed frame.
-            resampled_model_flux[i] = f(wave)
-            resampled_continuum[i] = c(wave)
+        resampled_continuum = np.nan * np.ones((N, P))
+        resampled_model_flux = np.nan * np.ones((N, P))
+        if not np.all(np.isfinite(prediction)):
+            log.warning(f"Prediction values not all finite!")
+        if not np.all(np.isfinite(model_continuum)):
+            log.warning(f"Not all model continuum values finite!")
+
+        try:
+
+            for i in range(N):
+                #assert np.all(np.isfinite(prediction)), "Prediction values not all finite?"
+                #assert np.all(np.isfinite(model_continuum[i])), "Model continuum values not all finite?"
+                finite_prediction = np.isfinite(prediction[i])
+                finite_model_continuum = np.isfinite(model_continuum[i])
+                f = interp1d(
+                    wave_interp[finite_prediction],
+                    prediction[i][finite_prediction],
+                    kind="cubic",
+                    bounds_error=False,
+                    fill_value=np.nan
+                )
+                c = interp1d(
+                    wave_interp[finite_model_continuum],
+                    model_continuum[i][finite_model_continuum],
+                    kind="cubic",
+                    bounds_error=False,
+                    fill_value=np.nan
+                )
+
+                # Re-sample the predicted spectra back to the observed frame.
+                resampled_model_flux[i] = f(wave)
+                resampled_continuum[i] = c(wave)
+        except:
+            log.exception(f"Exception in sampling spectra. Maybe all fluxes are negative?")

         database_results.extend(dict_to_list(results))
         results.update(
diff --git a/python/astra/database/astradb.py b/python/astra/database/astradb.py
index 5149787..334e1c9 100644
--- a/python/astra/database/astradb.py
+++ b/python/astra/database/astradb.py
@@ -102,6 +102,17 @@ def data_products(self):
         .where(Source.catalogid == self.catalogid)
     )

+    '''
+    # Should link source and data product/task to the OUTPUT table... then it'd be easier.
+    @property
+    def outputs(self):
+        outputs = []
+        o = TaskOutput.get(TaskOutput.task == self)
+        for expr, column in o.output.dependencies():
+            if column.model not in (TaskOutput, AstraOutputBaseModel):
+                outputs.extend(column.model.select().where(column.model.task == self))
+        return sorted(outputs, key=lambda x: x.output_id)
+    '''

 class DataProductKeywordsField(JSONField):
     def adapt(self, kwargs):
@@ -222,7 +233,19 @@ def path(self):
             return expand_path("$MWM_ASTRA/{v_astra}/{run2d}-{apred}/spectra/star/{catalogid_groups}/mwmStar-{v_astra}-{cat_id}.fits".format(catalogid_groups=catalogid_groups, **self.kwargs))
         else:
             return expand_path("$MWM_ASTRA/{v_astra}/{run2d}-{apred}/spectra/visit/{catalogid_groups}/mwmVisit-{v_astra}-{cat_id}.fits".format(catalogid_groups=catalogid_groups, **self.kwargs))
-
+
+        elif self.filetype.startswith("astraStar"):
+            pipeline = self.filetype[len("astraStar"):]
+            catalogid = self.kwargs['cat_id']
+            k = 100
+            #catalogid_groups = f"{(catalogid // k) % k:.0f}/{catalogid % k:.0f}"
+            catalogid_groups = f"{(catalogid // k) % k:0>2.0f}/{catalogid % k:0>2.0f}"
+
+            log.warn("hard-coding in path")
+            from astra.utils import expand_path
+            return expand_path("$MWM_ASTRA/{v_astra}/{run2d}-{apred}/results/star/{catalogid_groups}/astraStar-{pipeline}-{v_astra}-{cat_id}-{task_id}.fits".format(pipeline=pipeline, catalogid_groups=catalogid_groups, **self.kwargs))
+
+
     try:
         p = _sdss_path_instances[self.release]
     except KeyError:
diff --git a/python/astra/operators/slurm.py b/python/astra/operators/slurm.py
index 94343ce..60c4584 100644
--- a/python/astra/operators/slurm.py
+++ b/python/astra/operators/slurm.py
@@ -176,9 +176,16 @@ def execute(self, context):
         .tuples()
     )
     """
+    def estimate_relative_cost(bundle_id):
+        return (
+            TaskBundle.select()
+            .join(Task)
+            .where(TaskBundle.bundle == bundle_id)
+            .count()
+        )
     bundle_costs = np.array(
         [
-            [bundle_id, estimate_relative_cost(Bundle.get(bundle_id))]
+            [bundle_id, estimate_relative_cost(bundle_id)]
             for bundle_id in primary_keys
         ]
     )
@@ -198,6 +205,7 @@
         f"Total bundle cost: {np.sum(bundle_costs.T[1])} split across {Q_free} groups, with costs {group_costs} (max diff: {np.ptp(group_costs)})"
     )
     log.debug(f"Number per item: {list(map(len, group_bundle_ids))}")
+
     else:
         # Run all bundles in parallel.

diff --git a/python/astra/sdss/datamodels/pipeline.py b/python/astra/sdss/datamodels/pipeline.py
index eaeb439..4015c47 100644
--- a/python/astra/sdss/datamodels/pipeline.py
+++ b/python/astra/sdss/datamodels/pipeline.py
@@ -5,7 +5,7 @@
 import numpy as np
 from typing import Dict, List, Union, Optional, Tuple
 from sdss_access import SDSSPath
-from astra import log, __version__ as astra_version
+from astra import log, __version__ as v_astra
 from astra.database.astradb import Task, DataProduct, Source, TaskOutputDataProducts
 from astropy.io import fits
 from functools import partial
@@ -187,7 +187,7 @@ def create_pipeline_product(
     image = fits.HDUList(hdus)

     kwds = dict(
-        astra_version=astra_version,
+        v_astra=v_astra,
         apred=data_product.kwargs.get("apred", ""),
         run2d=data_product.kwargs.get("run2d", ""),
         # Get catalog identifier from primary HDU. Don't rely on it being in the data product kwargs.
@@ -206,11 +206,11 @@
     apred, run2d, cat_id = (kwds["apred"], kwds["run2d"], kwds["cat_id"])
     if filetype.startswith("astraVisit"):
         path = expand_path(
-            f"$MWM_ASTRA/{astra_version}/{run2d}-{apred}/results/visit/{catalogid_groups}/astraVisit-{product_name}-{astra_version}-{cat_id}-{task_id}.fits"
+            f"$MWM_ASTRA/{v_astra}/{run2d}-{apred}/results/visit/{catalogid_groups}/astraVisit-{product_name}-{v_astra}-{cat_id}-{task_id}.fits"
         )
     elif filetype.startswith("astraStar"):
         path = expand_path(
-            f"$MWM_ASTRA/{astra_version}/{run2d}-{apred}/results/star/{catalogid_groups}/astraStar-{product_name}-{astra_version}-{cat_id}-{task_id}.fits"
+            f"$MWM_ASTRA/{v_astra}/{run2d}-{apred}/results/star/{catalogid_groups}/astraStar-{product_name}-{v_astra}-{cat_id}-{task_id}.fits"
         )
     os.makedirs(os.path.dirname(path), exist_ok=True)

diff --git a/python/astra/tools/continuum/chebyshev.py b/python/astra/tools/continuum/chebyshev.py
index 52e18b9..f2938c2 100644
--- a/python/astra/tools/continuum/chebyshev.py
+++ b/python/astra/tools/continuum/chebyshev.py
@@ -52,11 +52,15 @@ def fit(self, spectrum: Spectrum1D) -> Chebyshev:
         for i in range(N):
             for j, ((lower, upper), indices) in enumerate(zip(*_initialized_args)):
                 x = np.linspace(-1, 1, upper - lower)
+                # Restrict to finite values.
+                y = flux[i, indices]
+                w = 1.0 / e_flux[i, indices]
+                finite = np.isfinite(y * w)
                 f = np.polynomial.Chebyshev.fit(
-                    x[indices - lower],
-                    flux[i, indices],
+                    x[indices - lower][finite],
+                    y[finite],
                     self.deg,
-                    w=1.0 / e_flux[i, indices],
+                    w=w[finite],
                 )
                 self.theta[i, j] = f.convert().coef
         return self
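For reference, the finite-pixel guard added to `Chebyshev.fit` behaves like this standalone sketch (synthetic data; the polynomial degree and weighting convention here are illustrative assumptions):

    import numpy as np

    # Synthetic spectrum segment with a few bad pixels.
    x = np.linspace(-1, 1, 100)
    flux = 1.0 + 0.1 * x + 0.05 * x ** 2
    e_flux = np.full_like(flux, 0.01)
    flux[[3, 40, 41]] = np.nan  # masked or unmeasured pixels

    # Restrict the fit to finite values, as the patched fit() does.
    w = 1.0 / e_flux
    finite = np.isfinite(flux * w)
    f = np.polynomial.Chebyshev.fit(x[finite], flux[finite], deg=2, w=w[finite])
    continuum = f(x)  # evaluated over all pixels, including the masked ones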