Skip to content

Commit

Permalink
handling of errors
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Oct 3, 2022
1 parent bcd7cbd commit 5bd236f
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 56 deletions.
40 changes: 22 additions & 18 deletions bin/astra
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,7 @@ def execute():
"""Execute a task or bundle."""
pass


@execute.command(hidden=True)
@click.argument("task_ids", type=int, nargs=-1, required=True)
def tasks(task_ids):
"""Execute multiple tasks"""
def _execute_tasks(task_ids):
from astra import log
from astra.utils import deserialize
from astra.database.astradb import Task
Expand All @@ -261,23 +257,20 @@ def tasks(task_ids):
return None


# Hidden subcommand of the ``execute`` group that accepts one or more task ids.
@execute.command(hidden=True)
@click.argument("task_ids", type=int, nargs=-1, required=True)
def tasks(task_ids):
    """Execute multiple tasks"""
    # With ``nargs=-1`` click supplies ``task_ids`` as a tuple of ints. The
    # actual work lives in the undecorated helper so that the single-task
    # command can reuse it without going through click again.
    return _execute_tasks(task_ids)

@execute.command()
@click.argument("task_id", type=int, nargs=1, required=True)
def task(task_id):
    """Execute a single task"""
    # Delegate to the plain helper rather than to ``tasks``: ``tasks`` is a
    # click command object, so calling it with a list would re-enter click's
    # command-line parsing instead of simply running the task.
    return _execute_tasks([task_id])


@execute.command(hidden=True)
@click.option(
"--only-incomplete",
"only_incomplete",
is_flag=True,
help="Only execute the incomplete tasks in the bundle"
)
@click.argument("bundle_ids", type=int, nargs=-1, required=True)
def bundles(bundle_ids, only_incomplete):
"""Execute bundles"""
def _execute_bundles(bundle_ids, only_incomplete):
from astra import log
from astra.utils import deserialize
from astra.database.astradb import Bundle
Expand All @@ -293,16 +286,27 @@ def bundles(bundle_ids, only_incomplete):
return None


# Hidden subcommand of the ``execute`` group that accepts one or more bundle ids.
@execute.command(hidden=True)
@click.option(
    "--only-incomplete",
    is_flag=True,
    help="Only execute the incomplete tasks in the bundle"
)
@click.argument("bundle_ids", type=int, nargs=-1, required=True)
def bundles(bundle_ids, only_incomplete):
    """Execute bundles"""
    # ``bundle_ids`` arrives as a tuple of ints (``nargs=-1``). Return the
    # helper's result so this command behaves like the sibling ``tasks`` and
    # ``bundle`` commands, which also return what their helper gives back.
    return _execute_bundles(bundle_ids, only_incomplete)


@execute.command()
@click.option(
    "--only-incomplete",
    "only_incomplete",
    is_flag=True,
    help="Only execute the incomplete tasks in the bundle"
)
@click.argument("bundle_id", type=int, nargs=-1, required=True)
def bundle(bundle_id, only_incomplete):
    """Execute a bundle"""
    # ``nargs=-1`` makes click pass ``bundle_id`` as a tuple of ints. Convert
    # it to a flat list so the helper receives the same kind of flat sequence
    # of ids that ``bundles`` passes; wrapping the tuple in another list would
    # hand the helper a tuple in place of each identifier.
    #
    # Delegate to the plain helper rather than to ``bundles``: ``bundles`` is
    # a click command object, so calling it would re-enter click's
    # command-line parsing instead of simply running the bundle.
    return _execute_bundles(list(bundle_id), only_incomplete)



Expand Down
60 changes: 38 additions & 22 deletions python/astra/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,11 @@ def from_bundle(cls, bundle, only_incomplete=False, strict=True):
log.warn(f"Restricting to tasks in bundle {bundle} that are incomplete")
q = q.where(Task.status != Status.get(description="completed"))

tasks = list(tasks)
tasks = list(q)
bundle_size = len(tasks)
if bundle_size == 0:
raise EmptyBundleException(f"No tasks found for bundle {bundle} with only_incomplete={only_incomplete}")

context = {
"input_data_products": [task.input_data_products for task in tasks],
"tasks": tasks,
Expand Down Expand Up @@ -480,7 +483,7 @@ def parse_parameters(cls, **kwargs):
parsed[name] = (parameter, value, default, length, indexed)
return (bundle_size, parsed)

def iterable(self, stage=None):
def iterable(self, stage=None, debug=True):
"""
Iterate over the tasks in the bundle.
Expand Down Expand Up @@ -509,27 +512,36 @@ def iterable(self, stage=None):
t_init = time()
for i, item in enumerate(self.context["iterable"]):

yield item

if stage is not None:
t_iterable = time() - t_init
try:
self.context["timing"][key][i] += t_iterable
except IndexError: # array is not long enough, we're in append mode
self.context["timing"][key].append(t_iterable)
finally:
t_init = time()

# Update status for this task.
print(f"update timing??")
'''
try:
task, idp, parameters = item
task.status_id = 5
try:
yield item
except:
task, idp, parameters = item
log.exception(f"Exception in executing {i}th task in bundle ({task}):")
if stage is not None:
status = Status.get(description=f"failed-{stage.replace('_', '-')}")
task.status_id = status.id
task.save()
except ValueError:
log.warning(f"Couldn't update task in status {task} {stage}")
'''
raise
continue
else:
if stage is not None:
t_iterable = time() - t_init
try:
self.context["timing"][key][i] += t_iterable
except IndexError: # array is not long enough, we're in append mode
self.context["timing"][key].append(t_iterable)
finally:
t_init = time()

# Update the status for this task.
# TODO: Get the correct status ID
try:
task, idp, parameters = item
task.status_id = 5 # TODO
task.completed = datetime.datetime.now()
task.save()
except ValueError:
log.warning(f"Couldn't update task in status {task} {stage}")

# fin

Expand Down Expand Up @@ -750,3 +762,7 @@ def get_or_create_data_products(iterable):
else:
dps.append(dp)
return dps


class EmptyBundleException(Exception):
    """Raised when a bundle resolves to zero tasks to execute."""
71 changes: 57 additions & 14 deletions python/astra/sdss/operators/mwm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SourceDataProduct,
MWMSourceStatus
)
from peewee import fn
from airflow.exceptions import AirflowSkipException
from astra.utils import flatten
from astra import log, __version__ as astra_version
Expand Down Expand Up @@ -246,20 +247,18 @@ class MWMVisitStarFactory(BaseOperator):
def __init__(
self,
*,
product_release: Optional[str] = None,
apred_release: Optional[str] = None,
release: Optional[str] = None,
#apred_release: Optional[str] = None,
apred: Optional[str] = None,
run2d_release: Optional[str] = None,
#run2d_release: Optional[str] = None,
run2d: Optional[str] = None,
num_bundles: Optional[int] = 1,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.product_release = product_release
self.release = release
self.apred = apred
self.apred_release = apred_release
self.run2d = run2d
self.run2d_release = run2d_release
self.num_bundles = num_bundles

if num_bundles < 1:
Expand All @@ -270,21 +269,57 @@ def execute(self, context):

ti, task = (context["ti"], context["task"])

#if catalogids is None:
# catalogids = map(
# int, tuple(set(flatten(ti.xcom_pull(task_ids=task.upstream_task_ids))))
# )
#else:
catalogids = flatten(Source.select(Source.catalogid).tuples())
# Get the data products identifiers from upstream operators
data_product_ids = map(int, tuple(set(flatten(ti.xcom_pull(task_ids=task.upstream_task_ids)))))
log.info(f"Found {len(data_product_ids)} data product identifiers")

# TODO: we could get run2d/apred from upstream, but should we?
'''
if self.run2d is None:
# Get run2d from the input data products
q = flatten(
DataProduct
.select(fn.DISTINCT(DataProduct.kwargs["run2d"]))
.where(
(DataProduct.filetype == "specFull")
& (DataProduct.id.in_(data_product_ids))
)
.tuples()
)
if len(q) > 1:
raise ValueError(f"No run2d given, and multiple run2d values found in input data products: {q}")
self.run2d, = q
log.info(f"Setting run2d={self.run2d}")
if self.apred is None:
# Get apred from the input data products.
q = flatten(
DataProduct
.select(fn.DISTINCT(DataProduct.kwargs["apred"]))
.where(
(DataProduct.filetype == "apVisit")
& (DataProduct.id.in_(data_product_ids))
)
.tuples()
)
if len(q) > 1:
raise ValueError(f"No apred given, and multiple apred values found in input data products: {q}")
self.apred, = q
log.info(f"Setting apred={self.apred}")
'''
parameters = dict(
release=self.product_release,
release=self.release,
run2d=self.run2d,
apred=self.apred,
)

expression = DataProduct.filetype.in_(("apVisit", "specFull"))
expression = DataProduct.id.in_(data_product_ids)
# DataProduct.filetype.in_(("apVisit", "specFull"))
#&
#)

'''
if self.apred_release is not None and self.apred is not None:
sub = DataProduct.filetype == "apVisit"
if self.apred_release is not None:
Expand All @@ -302,6 +337,14 @@ def execute(self, context):
sub &= DataProduct.kwargs["run2d"] == self.run2d
expression |= sub
'''
catalogids = flatten(
SourceDataProduct
.select(SourceDataProduct.source_id)
.distinct()
.where(SourceDataProduct.data_product_id.in_(data_product_ids))
.tuples()
)

# Create the tasks first in bulk.
# TODO: use insert_many instead
Expand Down
4 changes: 2 additions & 2 deletions python/astra/tools/continuum/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class Scalar(Continuum):
"""Represent the stellar continuum with a scalar."""

available_methods = {
"mean": np.mean,
"median": np.median,
"mean": lambda a: np.nanmean(a, axis=1),
"median": lambda a: np.nanmedian(a, axis=1)
}

def __init__(
Expand Down

0 comments on commit 5bd236f

Please sign in to comment.