diff --git a/bin/astra b/bin/astra index 08fb875..abe65ce 100755 --- a/bin/astra +++ b/bin/astra @@ -241,11 +241,7 @@ def execute(): """Execute a task or bundle.""" pass - -@execute.command(hidden=True) -@click.argument("task_ids", type=int, nargs=-1, required=True) -def tasks(task_ids): - """Execute multiple tasks""" +def _execute_tasks(task_ids): from astra import log from astra.utils import deserialize from astra.database.astradb import Task @@ -261,23 +257,20 @@ def tasks(task_ids): return None +@execute.command(hidden=True) +@click.argument("task_ids", type=int, nargs=-1, required=True) +def tasks(task_ids): + """Execute multiple tasks""" + return _execute_tasks(task_ids) + @execute.command() @click.argument("task_id", type=int, nargs=1, required=True) def task(task_id): """Execute a single task""" - return tasks([task_id]) + return _execute_tasks([task_id]) -@execute.command(hidden=True) -@click.option( - "--only-incomplete", - "only_incomplete", - is_flag=True, - help="Only execute the incomplete tasks in the bundle" -) -@click.argument("bundle_ids", type=int, nargs=-1, required=True) -def bundles(bundle_ids, only_incomplete): - """Execute bundles""" +def _execute_bundles(bundle_ids, only_incomplete): from astra import log from astra.utils import deserialize from astra.database.astradb import Bundle @@ -293,16 +286,27 @@ def bundles(bundle_ids, only_incomplete): return None +@execute.command(hidden=True) +@click.option( + "--only-incomplete", + is_flag=True, + help="Only execute the incomplete tasks in the bundle" +) +@click.argument("bundle_ids", type=int, nargs=-1, required=True) +def bundles(bundle_ids, only_incomplete): + """Execute bundles""" + _execute_bundles(bundle_ids, only_incomplete) + + @execute.command() @click.option( "--only-incomplete", - "only_incomplete", is_flag=True, help="Only execute the incomplete tasks in the bundle" ) @click.argument("bundle_id", type=int, nargs=-1, required=True) def bundle(bundle_id, only_incomplete): - return bundles([bundle_id], only_incomplete) + return _execute_bundles([bundle_id], only_incomplete) diff --git a/python/astra/base.py b/python/astra/base.py index 2c12da1..ae349fd 100644 --- a/python/astra/base.py +++ b/python/astra/base.py @@ -345,8 +345,11 @@ def from_bundle(cls, bundle, only_incomplete=False, strict=True): log.warn(f"Restricting to tasks in bundle {bundle} that are incomplete") q = q.where(Task.status != Status.get(description="completed")) - tasks = list(tasks) + tasks = list(q) bundle_size = len(tasks) + if bundle_size == 0: + raise EmptyBundleException(f"No tasks found for bundle {bundle} with only_incomplete={only_incomplete}") + context = { "input_data_products": [task.input_data_products for task in tasks], "tasks": tasks, @@ -480,7 +483,7 @@ def parse_parameters(cls, **kwargs): parsed[name] = (parameter, value, default, length, indexed) return (bundle_size, parsed) - def iterable(self, stage=None): + def iterable(self, stage=None, debug=True): """ Iterate over the tasks in the bundle. @@ -509,27 +512,36 @@ def iterable(self, stage=None): t_init = time() for i, item in enumerate(self.context["iterable"]): - yield item - - if stage is not None: - t_iterable = time() - t_init - try: - self.context["timing"][key][i] += t_iterable - except IndexError: # array is not long enough, we're in append mode - self.context["timing"][key].append(t_iterable) - finally: - t_init = time() - - # Update status for this task. - print(f"update timing??") - ''' - try: - task, idp, parameters = item - task.status_id = 5 + try: + yield item + except: + task, idp, parameters = item + log.exception(f"Exception in executing {i}th task in bundle ({task}):") + if stage is not None: + status = Status.get(description=f"failed-{stage.replace('_', '-')}") + task.status_id = status.id task.save() - except ValueError: - log.warning(f"Couldn't update task in status {task} {stage}") - ''' + raise + continue + else: + if stage is not None: + t_iterable = time() - t_init + try: + self.context["timing"][key][i] += t_iterable + except IndexError: # array is not long enough, we're in append mode + self.context["timing"][key].append(t_iterable) + finally: + t_init = time() + + # Update status for this task. if we set + # TODO: Get the correct status ID + try: + task, idp, parameters = item + task.status_id = 5 # TODO + task.completed = datetime.datetime.now() + task.save() + except ValueError: + log.warning(f"Couldn't update task in status {task} {stage}") # fin @@ -750,3 +762,7 @@ def get_or_create_data_products(iterable): else: dps.append(dp) return dps + + +class EmptyBundleException(Exception): + pass diff --git a/python/astra/sdss/operators/mwm.py b/python/astra/sdss/operators/mwm.py index b438ca2..8f2ee7c 100644 --- a/python/astra/sdss/operators/mwm.py +++ b/python/astra/sdss/operators/mwm.py @@ -11,6 +11,7 @@ SourceDataProduct, MWMSourceStatus ) +from peewee import fn from airflow.exceptions import AirflowSkipException from astra.utils import flatten from astra import log, __version__ as astra_version @@ -246,20 +247,18 @@ class MWMVisitStarFactory(BaseOperator): def __init__( self, *, - product_release: Optional[str] = None, - apred_release: Optional[str] = None, + release: Optional[str] = None, + #apred_release: Optional[str] = None, apred: Optional[str] = None, - run2d_release: Optional[str] = None, + #run2d_release: Optional[str] = None, run2d: Optional[str] = None, num_bundles: Optional[int] = 1, **kwargs, ) -> None: super().__init__(**kwargs) - self.product_release = product_release + self.release = release self.apred = apred - self.apred_release = apred_release self.run2d = run2d - self.run2d_release = run2d_release self.num_bundles = num_bundles if num_bundles < 1: @@ -270,21 +269,57 @@ def execute(self, context): ti, task = (context["ti"], context["task"]) - #if catalogids is None: - # catalogids = map( - # int, tuple(set(flatten(ti.xcom_pull(task_ids=task.upstream_task_ids)))) - # ) - #else: - catalogids = flatten(Source.select(Source.catalogid).tuples()) + # Get the data products identifiers from upstream operators + data_product_ids = map(int, tuple(set(flatten(ti.xcom_pull(task_ids=task.upstream_task_ids))))) + log.info(f"Found {len(data_product_ids)} data product identifiers") + # TODO: we could get run2d/apred from upstream, but should we? + ''' + if self.run2d is None: + # Get run2d from the input data products + q = flatten( + DataProduct + .select(fn.DISTINCT(DataProduct.kwargs["run2d"])) + .where( + (DataProduct.filetype == "specFull") + & (DataProduct.id.in_(data_product_ids)) + ) + .tuples() + ) + if len(q) > 1: + raise ValueError(f"No run2d given, and multiple run2d values found in input data products: {q}") + + self.run2d, = q + log.info(f"Setting run2d={self.run2d}") + + if self.apred is None: + # Get apred from the input data products. + q = flatten( + DataProduct + .select(fn.DISTINCT(DataProduct.kwargs["apred"])) + .where( + (DataProduct.filetype == "apVisit") + & (DataProduct.id.in_(data_product_ids)) + ) + .tuples() + ) + if len(q) > 1: + raise ValueError(f"No apred given, and multiple apred values found in input data products: {q}") + self.apred, = q + log.info(f"Setting apred={self.apred}") + ''' parameters = dict( - release=self.product_release, + release=self.release, run2d=self.run2d, apred=self.apred, ) - expression = DataProduct.filetype.in_(("apVisit", "specFull")) + expression = DataProduct.id.in_(data_product_ids) + # DataProduct.filetype.in_(("apVisit", "specFull")) + #& + #) + ''' if self.apred_release is not None and self.apred is not None: sub = DataProduct.filetype == "apVisit" if self.apred_release is not None: @@ -302,6 +337,14 @@ def execute(self, context): sub &= DataProduct.kwargs["run2d"] == self.run2d expression |= sub + ''' + catalogids = flatten( + SourceDataProduct + .select(SourceDataProduct.source_id) + .distinct() + .where(SourceDataProduct.data_product_id.in_(data_product_ids)) + .tuples() + ) # Create the tasks first in bulk. # TODO: use insert_many instead diff --git a/python/astra/tools/continuum/scalar.py b/python/astra/tools/continuum/scalar.py index 65ac348..d6fd414 100644 --- a/python/astra/tools/continuum/scalar.py +++ b/python/astra/tools/continuum/scalar.py @@ -13,8 +13,8 @@ class Scalar(Continuum): """Represent the stellar continuum with a scalar.""" available_methods = { - "mean": np.mean, - "median": np.median, + "mean": lambda a: np.nanmean(a, axis=1), + "median": lambda a: np.nanmedian(a, axis=1) } def __init__(