-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enables running of multiple tasks in batches #1538
Changes from 1 commit
dd1b095
3802b29
cc73a92
0c90de0
188c6e8
9f73c33
0e87f50
fbba3d2
ca275a6
8007bf9
694fbbd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,8 +39,8 @@ | |
from luigi import notifications | ||
from luigi import parameter | ||
from luigi import task_history as history | ||
from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN | ||
from luigi.task import Config | ||
from luigi.task_status import BATCH_RUNNING, DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN | ||
from luigi.task import Config, task_id_str | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -71,12 +71,25 @@ class Scheduler(object): | |
STATUS_TO_UPSTREAM_MAP = { | ||
FAILED: UPSTREAM_FAILED, | ||
RUNNING: UPSTREAM_RUNNING, | ||
BATCH_RUNNING: UPSTREAM_RUNNING, | ||
PENDING: UPSTREAM_MISSING_INPUT, | ||
DISABLED: UPSTREAM_DISABLED, | ||
} | ||
|
||
TASK_FAMILY_RE = re.compile(r'([^(_]+)[(_]') | ||
|
||
TASKS = 'tasks' | ||
ACTIVE_WORKERS = 'active_workers' | ||
BATCH_TASKS = 'batch_tasks' | ||
RUNNING_BATCHES = 'running_batches' | ||
|
||
AGGREGATE_FUNCTIONS = { | ||
'csv': ','.join, | ||
'max': max, | ||
'min': min, | ||
'range': lambda args: '%s-%s' % (min(args), max(args)), | ||
} | ||
|
||
|
||
class scheduler(Config): | ||
# TODO(erikbern): the config_path is needed for backwards compatibility. We | ||
|
@@ -162,7 +175,7 @@ class Task(object): | |
|
||
def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, | ||
params=None, disable_failures=None, disable_window=None, disable_hard_timeout=None, | ||
tracking_url=None, status_message=None): | ||
tracking_url=None, status_message=None, is_batch=False): | ||
self.id = task_id | ||
self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active) | ||
self.workers = set() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active | ||
|
@@ -190,6 +203,8 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', | |
self.status_message = status_message | ||
self.scheduler_disable_time = None | ||
self.runnable = False | ||
self.is_batch = is_batch | ||
self.batchable = False | ||
|
||
def __repr__(self): | ||
return "Task(%r)" % vars(self) | ||
|
@@ -214,6 +229,31 @@ def pretty_id(self): | |
return '{}({})'.format(self.family, param_str) | ||
|
||
|
||
class TaskBatcher(object):
    """
    Works out the combined task id and parameters for a group of tasks that
    are candidates for running as a single batch.

    ``aggregate_args`` maps a parameter name to the name of an aggregate
    function (a key of ``AGGREGATE_FUNCTIONS``). ``max_batch_size`` caps the
    number of tasks in one batch; ``None`` means no cap.
    """

    def __init__(self, aggregate_args, max_batch_size):
        self.aggregates = aggregate_args
        self.max_batch_size = max_batch_size

    def task_id(self, tasks):
        """
        Return ``(task_id, params)`` for the batch formed by ``tasks``.

        ``(None, None)`` is returned when the tasks cannot be batched:

        * ``tasks`` is empty,
        * there are more tasks than ``max_batch_size`` allows,
        * the tasks do not all share one family,
        * a parameter without an aggregate function differs between tasks.

        Parameter names are taken from the first task; every other task is
        expected to carry the same parameter names.
        """
        # An empty group has no first task to take the family/params from.
        if not tasks:
            return None, None
        if self.max_batch_size is not None and len(tasks) > self.max_batch_size:
            return None, None

        # Explicit check rather than ``assert``: asserts vanish under ``-O``,
        # which would let tasks of different families be merged silently.
        family = tasks[0].family
        if any(task.family != family for task in tasks):
            return None, None

        params = {}
        for arg in tasks[0].params:
            raw_vals = [task.params[arg] for task in tasks]
            agg_function = self.aggregates.get(arg)
            if agg_function is not None:
                # Aggregated parameter: fold all values into a single one.
                params[arg] = AGGREGATE_FUNCTIONS[agg_function](raw_vals)
            elif any(v != raw_vals[0] for v in raw_vals):
                # Non-aggregated parameters must be identical across the batch.
                return None, None
            else:
                params[arg] = raw_vals[0]

        return task_id_str(family, params), params
|
||
|
||
class Worker(object): | ||
""" | ||
Structure for tracking worker activity and keeping their references. | ||
|
@@ -287,14 +327,28 @@ class SimpleTaskState(object): | |
def __init__(self, state_path): | ||
self._state_path = state_path | ||
self._tasks = {} # map from id to a Task object | ||
self._batch_tasks = {} # map from family to TaskBatch object | ||
self._running_batches = {} # map from id to list of batched task ids | ||
self._status_tasks = collections.defaultdict(dict) | ||
self._active_workers = {} # map from id to a Worker object | ||
|
||
def get_state(self): | ||
return self._tasks, self._active_workers | ||
return { | ||
TASKS: self._tasks, | ||
ACTIVE_WORKERS: self._active_workers, | ||
BATCH_TASKS: self._batch_tasks, | ||
RUNNING_BATCHES: self._running_batches, | ||
} | ||
|
||
def set_state(self, state): | ||
self._tasks, self._active_workers = state | ||
if isinstance(state, dict): | ||
self._tasks = state.get(TASKS, {}) | ||
self._active_workers = state.get(ACTIVE_WORKERS, {}) | ||
self._batch_tasks = state.get(BATCH_TASKS, {}) | ||
self._running_batches = state.get(RUNNING_BATCHES, {}) | ||
else: | ||
self._tasks, self._active_workers = state | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a note that this is only for backward compatibility and this codepath will be removed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
self._batch_tasks = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
def dump(self): | ||
try: | ||
|
@@ -338,6 +392,37 @@ def get_pending_tasks(self): | |
return itertools.chain.from_iterable(six.itervalues(self._status_tasks[status]) | ||
for status in [PENDING, RUNNING]) | ||
|
||
def get_batch(self, worker_id, tasks): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this name confusing? Maybe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if len(tasks) == 1: | ||
return tasks[0] | ||
families = set(task.family for task in tasks) | ||
if len(families) != 1: | ||
return None | ||
family = families.pop() | ||
batch_task = self.get_batcher(worker_id, family) | ||
if batch_task is None: | ||
return None | ||
task_id, params = batch_task.task_id(tasks) | ||
if task_id is None: | ||
return None | ||
priority = max(task.priority for task in tasks) | ||
resource_keys = functools.reduce(set.union, (task.resources.keys() for task in tasks), set()) | ||
resources = {key: max(task.resources.get(key, 0) for task in tasks) for key in resource_keys} | ||
deps = functools.reduce(set.union, (task.deps for task in tasks), set()) | ||
batch_task_obj = Task( | ||
task_id=task_id, | ||
status=PENDING, | ||
deps=deps, | ||
priority=priority, | ||
resources=resources, | ||
family=family, | ||
params=params, | ||
is_batch=True, | ||
) | ||
batch_task_obj.stakeholders.add(worker_id) | ||
batch_task_obj.workers.add(worker_id) | ||
return batch_task_obj | ||
|
||
def num_pending_tasks(self): | ||
""" | ||
Return how many tasks are PENDING + RUNNING. O(1). | ||
|
@@ -352,6 +437,14 @@ def get_task(self, task_id, default=None, setdefault=None): | |
else: | ||
return self._tasks.get(task_id, default) | ||
|
||
def get_batcher(self, worker, family): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you be explicit and write There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
return self._batch_tasks.get((worker, family)) | ||
|
||
def set_batcher(self, worker, family, batcher_aggregate_args, max_batch_size):
    """
    Build a TaskBatcher from the given aggregate args and size limit, store
    it under the (worker, family) key (replacing any previous entry) and
    return it.
    """
    new_batcher = self._batch_tasks[(worker, family)] = TaskBatcher(
        batcher_aggregate_args,
        max_batch_size,
    )
    return new_batcher
|
||
def has_task(self, task_id): | ||
return task_id in self._tasks | ||
|
||
|
@@ -362,11 +455,11 @@ def re_enable(self, task, config=None): | |
self.set_status(task, FAILED, config) | ||
task.failures.clear() | ||
|
||
def set_status(self, task, new_status, config=None): | ||
def set_status(self, task, new_status, config=None, batch=None): | ||
if new_status == FAILED: | ||
assert config is not None | ||
|
||
if new_status == DISABLED and task.status == RUNNING: | ||
if new_status == DISABLED and task.status in (RUNNING, BATCH_RUNNING): | ||
return | ||
|
||
if task.status == DISABLED: | ||
|
@@ -377,7 +470,7 @@ def set_status(self, task, new_status, config=None): | |
elif task.scheduler_disable_time is not None and new_status != DISABLED: | ||
return | ||
|
||
if new_status == FAILED and task.status != DISABLED: | ||
if new_status == FAILED and task.status != DISABLED and not task.is_batch: | ||
task.add_failure() | ||
if task.has_excessive_failures(): | ||
task.scheduler_disable_time = time.time() | ||
|
@@ -394,6 +487,17 @@ def set_status(self, task, new_status, config=None): | |
elif new_status == DISABLED: | ||
task.scheduler_disable_time = None | ||
|
||
if new_status == RUNNING and batch: | ||
self._running_batches[task.id] = set(batch) | ||
elif task.id in self._running_batches and new_status != RUNNING: | ||
subtask_status = FAILED if new_status == DISABLED else new_status | ||
for subtask_id in self._running_batches.pop(task.id): | ||
subtask = self.get_task(subtask_id) | ||
if subtask_status == FAILED: | ||
subtask.retry = time.time() + config.retry_delay | ||
self.set_status(subtask, subtask_status, config) | ||
subtask.expl = task.expl | ||
|
||
if new_status != task.status: | ||
self._status_tasks[task.status].pop(task.id) | ||
self._status_tasks[new_status][task.id] = task | ||
|
@@ -402,7 +506,7 @@ def set_status(self, task, new_status, config=None): | |
|
||
def fail_dead_worker_task(self, task, config, assistants): | ||
# If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic | ||
if task.status == RUNNING and task.worker_running and task.worker_running not in task.stakeholders | assistants: | ||
if task.status in (RUNNING, BATCH_RUNNING) and task.worker_running and task.worker_running not in task.stakeholders | assistants: | ||
logger.info("Task %r is marked as running by disconnected worker %r -> marking as " | ||
"FAILED with retry delay of %rs", task.id, task.worker_running, | ||
config.retry_delay) | ||
|
@@ -428,7 +532,10 @@ def update_status(self, task, config): | |
self.set_status(task, PENDING, config) | ||
|
||
def may_prune(self, task): | ||
return task.remove and time.time() > task.remove | ||
return ( | ||
(task.remove and time.time() > task.remove) or | ||
(task.is_batch and task.status != RUNNING) | ||
) | ||
|
||
def inactivate_tasks(self, delete_tasks): | ||
# The terminology is a bit confusing: we used to "delete" tasks when they became inactive, | ||
|
@@ -483,6 +590,13 @@ def get_necessary_tasks(self): | |
necessary_tasks.add(task.id) | ||
return necessary_tasks | ||
|
||
def update_tracking_url(self, task, tracking_url):
    """
    Set ``tracking_url`` on ``task`` and on every task recorded in
    ``_running_batches`` under ``task.id``.

    Ids for which ``get_task`` returns nothing are ignored.
    """
    task.tracking_url = tracking_url
    member_ids = self._running_batches.get(task.id, [])
    for member_id in member_ids:
        member = self.get_task(member_id)
        if not member:
            continue
        member.tracking_url = tracking_url
|
||
|
||
class CentralPlannerScheduler(Scheduler): | ||
""" | ||
|
@@ -575,10 +689,14 @@ def _update_priority(self, task, prio, worker): | |
if t is not None and prio > t.priority: | ||
self._update_priority(t, prio, worker) | ||
|
||
def add_task_batcher(self, worker, family, batcher_aggregate_args, max_batch_size=None):
    """
    Register a batcher for tasks of ``family`` belonging to ``worker`` by
    handing all arguments to the state's ``set_batcher``. Returns nothing.
    """
    state = self._state
    state.set_batcher(worker, family, batcher_aggregate_args, max_batch_size)
|
||
def add_task(self, task_id=None, status=PENDING, runnable=True, | ||
deps=None, new_deps=None, expl=None, resources=None, | ||
priority=0, family='', module=None, params=None, | ||
assistant=False, tracking_url=None, **kwargs): | ||
assistant=False, tracking_url=None, batchable=None, | ||
**kwargs): | ||
""" | ||
* add task identified by task_id if it doesn't exist | ||
* if deps is not None, update dependency list | ||
|
@@ -611,15 +729,18 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, | |
task.params = _get_default(params, {}) | ||
|
||
if tracking_url is not None or task.status != RUNNING: | ||
task.tracking_url = tracking_url | ||
self._state.update_tracking_url(task, tracking_url) | ||
|
||
if task.remove is not None: | ||
task.remove = None # unmark task for removal so it isn't removed after being added | ||
|
||
if expl is not None: | ||
task.expl = expl | ||
|
||
if not (task.status == RUNNING and status == PENDING) or new_deps: | ||
if batchable is not None: | ||
task.batchable = batchable | ||
|
||
if not (task.status in (RUNNING, BATCH_RUNNING) and status == PENDING) or new_deps: | ||
# don't allow re-scheduling of task while it is running, it must either fail or succeed first | ||
if status == PENDING or status != task.status: | ||
# Update the DB only if there was an actual change, to prevent noise. | ||
|
@@ -735,6 +856,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs): | |
for task in sorted(self._state.get_running_tasks(), key=self._rank): | ||
if task.worker_running == worker_id and task.id not in ct_set: | ||
best_task = task | ||
best_tasks = [] if best_task is None else [best_task] | ||
|
||
locally_pending_tasks = 0 | ||
running_tasks = [] | ||
|
@@ -776,7 +898,13 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs): | |
if len(task.workers) == 1 and not assistant: | ||
n_unique_pending += 1 | ||
|
||
if best_task: | ||
if (in_workers and best_tasks and task.batchable and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a comment above this line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
self._state.get_batch(worker_id, best_tasks + [task]) is not None and | ||
self._schedulable(task) and | ||
self._has_resources(task.resources, greedy_resources)): | ||
best_tasks.append(task) | ||
|
||
if best_tasks: | ||
continue | ||
|
||
if task.status == RUNNING and (task.worker_running in greedy_workers): | ||
|
@@ -786,7 +914,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs): | |
|
||
if self._schedulable(task) and self._has_resources(task.resources, greedy_resources): | ||
if in_workers and self._has_resources(task.resources, used_resources): | ||
best_task = task | ||
best_tasks = [task] | ||
else: | ||
workers = itertools.chain(task.workers, [worker_id]) if assistant else task.workers | ||
for task_worker in workers: | ||
|
@@ -805,8 +933,18 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs): | |
'task_id': None, | ||
'n_unique_pending': n_unique_pending} | ||
|
||
if best_task: | ||
self._state.set_status(best_task, RUNNING, self._config) | ||
if best_tasks: | ||
best_batch = self._state.get_batch(worker_id, best_tasks) | ||
best_task = self._state.get_task(best_batch.id, setdefault=best_batch) | ||
for task in best_tasks: | ||
if task == best_task: | ||
continue | ||
self._state.set_status(task, BATCH_RUNNING, self._config) | ||
task.worker_running = worker_id | ||
task.time_running = time.time() | ||
self._update_task_history(task, BATCH_RUNNING, host=host) | ||
batch_tasks = [task.id for task in best_tasks if task != best_task] | ||
self._state.set_status(best_task, RUNNING, self._config, batch=batch_tasks) | ||
best_task.worker_running = worker_id | ||
best_task.time_running = time.time() | ||
self._update_task_history(best_task, RUNNING, host=host) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For such a complicated parameter I hoped for more docs. ^^
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops! Not sure how that happened.