Skip to content

Commit

Permalink
Enables running of multiple tasks in batches
Browse files Browse the repository at this point in the history
Sometimes it's more efficient to run a group of tasks all at once rather than
one at a time. With luigi, it's difficult to take advantage of this because your
batch size will also be the minimum granularity you're able to compute. So if
you have a job that runs hourly, you can't combine their computation when many
of them get backlogged. When you have a task that runs daily, you can't get
hourly runs.

In order to gain efficiency when many jobs are queued up, this change allows
workers to provide details of how jobs can be batched to the scheduler. If you
have several hourly jobs of the same type in the scheduler, it can combine them
into a single job for the worker. We allow parameters to be combined in three
ways: we can combine all the arguments in a csv, take the min and max to form
a range, or just provide the min or max. The csv gives the most specificity,
but range and min/max are available for when that's all you need. In particular,
the max function provides an implementation of #570, allowing for jobs that
overwrite eachother to be grouped by just running the largest one.

In order to implement this, the scheduler will create a new task based on the
information sent by the worker. It's possible (as in the max/min case) that the
new task already exists, but if it doesn't it will be cleaned up at the end of
the run. While this new task is running, any other tasks will be marked as
BATCH_RUNNING. When the head task becomes DONE or FAILED, the BATCH_RUNNING
tasks will also be updated accordingly. They'll also have their tracking urls
updated to match the batch task.

This is a fairly big change to how the scheduler works, so there are a few
issues with it in the initial implementation:
  - newly created batch tasks don't show up in dependency graphs
  - the run summary doesn't know what happened to the batched tasks
  - we can't limit how big batches can be (how should we handle ranges?)
  - batching takes quadratic time for simplicity of implementation
  - I'm not sure what would happen if there was a yield in a batch run function

On the worker side, batching is accomplished by setting a batch_class,
batcher_args and batcher_aggregate_args. The batch class is the Python class
that runs the batched version of the job. This can be set equal to the current
class by overriding the class method get_batch_class.

The batcher_args are the arguments passed from the current class to the batch
class. These come in pairs. So if the original class has parameters machine and
filename that need to go to host and files in the batcher, you'll use
  [('machine', 'host'), ('filename', 'files')]
for batcher_args.

The final value is batcher_aggregate_args, which explains which arguments are to
be aggregated and how. So using the machine, filename example, we might want to
batch multiple files together for the same machine. For that, we could do
something like
  {'filename': 'csv'}
to combine them all as comma-separated values. Now if we have multiple machine,
filename pairs such as ('m1', 'f1'), ('m1', 'f2'), ('m2', 'f3'), ('m2', 'f4'),
we'd end up with batch jobs with host, files pairs of ('m1', 'f1,f2') and
('m2', 'f3,f4').

The worker will send the batch class, batcher args and batcher aggregate args to
the worker once per class, which is why these are class methods. It doesn't make
sense to have different ways to batch per individual task, so that's not
allowed.
  • Loading branch information
daveFNbuck committed Feb 9, 2016
1 parent f771622 commit bcde8bb
Show file tree
Hide file tree
Showing 11 changed files with 637 additions and 45 deletions.
13 changes: 12 additions & 1 deletion luigi/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def ping(self, worker):
def add_task(self, worker, task_id, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None, priority=0,
family='', module=None, params=None, assistant=False,
tracking_url=None):
tracking_url=None, batchable=None):
self._request('/api/add_task', {
'task_id': task_id,
'worker': worker,
Expand All @@ -172,6 +172,17 @@ def add_task(self, worker, task_id, status=PENDING, runnable=True,
'params': params,
'assistant': assistant,
'tracking_url': tracking_url,
'batchable': batchable,
})

def add_task_batcher(self, worker, family, batcher_family, batcher_args,
batcher_aggregate_args):
self._request('/api/add_task_batcher', {
'worker': worker,
'family': family,
'batcher_family': batcher_family,
'batcher_args': batcher_args,
'batcher_aggregate_args': batcher_aggregate_args,
})

def get_work(self, worker, host=None, assistant=False, current_tasks=None):
Expand Down
175 changes: 159 additions & 16 deletions luigi/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from luigi import notifications
from luigi import parameter
from luigi import task_history as history
from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN
from luigi.task import Config
from luigi.task_status import BATCH_RUNNING, DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN
from luigi.task import Config, task_id_str

logger = logging.getLogger("luigi.server")

Expand Down Expand Up @@ -71,10 +71,23 @@ class Scheduler(object):
STATUS_TO_UPSTREAM_MAP = {
FAILED: UPSTREAM_FAILED,
RUNNING: UPSTREAM_RUNNING,
BATCH_RUNNING: UPSTREAM_RUNNING,
PENDING: UPSTREAM_MISSING_INPUT,
DISABLED: UPSTREAM_DISABLED,
}

TASKS = 'tasks'
ACTIVE_WORKERS = 'active_workers'
BATCH_TASKS = 'batch_tasks'
RUNNING_BATCHES = 'running_batches'

AGGREGATE_FUNCTIONS = {
'csv': ','.join,
'max': max,
'min': min,
'range': lambda args: '%s-%s' % (min(args), max(args)),
}


class scheduler(Config):
# TODO(erikbern): the config_path is needed for backwards compatilibity. We should drop the compatibility
Expand Down Expand Up @@ -170,7 +183,7 @@ class Task(object):

def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None,
params=None, disable_failures=None, disable_window=None, disable_hard_timeout=None,
tracking_url=None):
tracking_url=None, is_batch=False):
self.id = task_id
self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active)
self.workers = set() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active
Expand All @@ -196,6 +209,8 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='',
self.tracking_url = tracking_url
self.scheduler_disable_time = None
self.runnable = False
self.is_batch = is_batch
self.batchable = False

def __repr__(self):
return "Task(%r)" % vars(self)
Expand All @@ -220,6 +235,27 @@ def can_disable(self):
self.disable_hard_timeout is not None)


class TaskBatcher(object):
def __init__(self, family, args, aggregate_args):
self.family = family
self.args = args
self.aggregates = aggregate_args

def task_id(self, tasks):
params = {}
for task_arg, batch_arg in self.args:
raw_vals = [task.params[task_arg] for task in tasks]
agg_function = self.aggregates.get(task_arg)
if agg_function:
arg_val = AGGREGATE_FUNCTIONS[agg_function](raw_vals)
elif any(v != raw_vals[0] for v in raw_vals):
return None, None
else:
arg_val = raw_vals[0]
params[batch_arg] = arg_val
return task_id_str(self.family, params), params


class Worker(object):
"""
Structure for tracking worker activity and keeping their references.
Expand Down Expand Up @@ -292,14 +328,28 @@ class SimpleTaskState(object):
def __init__(self, state_path):
self._state_path = state_path
self._tasks = {} # map from id to a Task object
self._batch_tasks = {} # map from family to TaskBatch object
self._running_batches = {} # map from id to list of batched task ids
self._status_tasks = collections.defaultdict(dict)
self._active_workers = {} # map from id to a Worker object

def get_state(self):
return self._tasks, self._active_workers
return {
TASKS: self._tasks,
ACTIVE_WORKERS: self._active_workers,
BATCH_TASKS: self._batch_tasks,
RUNNING_BATCHES: self._running_batches,
}

def set_state(self, state):
self._tasks, self._active_workers = state
if isinstance(state, dict):
self._tasks = state.get(TASKS, {})
self._active_workers = state.get(ACTIVE_WORKERS, {})
self._batch_tasks = state.get(BATCH_TASKS, {})
self._running_batches = state.get(RUNNING_BATCHES, {})
else:
self._tasks, self._active_workers = state
self._batch_tasks = {}

def dump(self):
try:
Expand Down Expand Up @@ -349,6 +399,13 @@ def load(self):
if any(not hasattr(t, 'disable_hard_timeout') for t in six.itervalues(self._tasks)):
for t in six.itervalues(self._tasks):
t.disable_hard_timeout = None

# Compatibility since 2016-02-03
if any(not hasattr(t, 'is_batch') for t in six.itervalues(self._tasks)):
for t in six.itervalues(self._tasks):
t.batchable = False
t.is_batch = False

else:
logger.info("No prior state file exists at %s. Starting with clean slate", self._state_path)

Expand All @@ -367,6 +424,37 @@ def get_pending_tasks(self):
return itertools.chain.from_iterable(six.itervalues(self._status_tasks[status])
for status in [PENDING, RUNNING])

def get_batch(self, worker_id, tasks):
if len(tasks) == 1:
return tasks[0]
families = set(task.family for task in tasks)
if len(families) != 1:
return None
family = families.pop()
batch_task = self.get_batcher(worker_id, family)
if batch_task is None:
return None
task_id, params = batch_task.task_id(tasks)
if task_id is None:
return None
priority = max(task.priority for task in tasks)
resource_keys = functools.reduce(set.union, (task.resources.keys() for task in tasks), set())
resources = {key: max(task.resources.get(key, 0) for task in tasks) for key in resource_keys}
deps = functools.reduce(set.union, (task.deps for task in tasks), set())
batch_task_obj = Task(
task_id=task_id,
status=PENDING,
deps=deps,
priority=priority,
resources=resources,
family=batch_task.family,
params=params,
is_batch=True,
)
batch_task_obj.stakeholders.add(worker_id)
batch_task_obj.workers.add(worker_id)
return batch_task_obj

def num_pending_tasks(self):
"""
Return how many tasks are PENDING + RUNNING. O(1).
Expand All @@ -381,6 +469,14 @@ def get_task(self, task_id, default=None, setdefault=None):
else:
return self._tasks.get(task_id, default)

def get_batcher(self, worker, family):
return self._batch_tasks.get((worker, family))

def set_batcher(self, worker, family, batcher_family, batcher_args, batcher_aggregate_args):
batcher = TaskBatcher(batcher_family, batcher_args, batcher_aggregate_args)
self._batch_tasks[(worker, family)] = batcher
return batcher

def has_task(self, task_id):
return task_id in self._tasks

Expand All @@ -391,11 +487,11 @@ def re_enable(self, task, config=None):
self.set_status(task, FAILED, config)
task.failures.clear()

def set_status(self, task, new_status, config=None):
def set_status(self, task, new_status, config=None, batch=None):
if new_status == FAILED:
assert config is not None

if new_status == DISABLED and task.status == RUNNING:
if new_status == DISABLED and task.status in (RUNNING, BATCH_RUNNING):
return

if task.status == DISABLED:
Expand All @@ -406,7 +502,7 @@ def set_status(self, task, new_status, config=None):
elif task.scheduler_disable_time is not None and new_status != DISABLED:
return

if new_status == FAILED and task.can_disable() and task.status != DISABLED:
if new_status == FAILED and task.can_disable() and task.status != DISABLED and not task.is_batch:
task.add_failure()
if task.has_excessive_failures():
task.scheduler_disable_time = time.time()
Expand All @@ -423,13 +519,25 @@ def set_status(self, task, new_status, config=None):
elif new_status == DISABLED:
task.scheduler_disable_time = None

if new_status == RUNNING and batch:
self._running_batches[task.id] = set(batch)
elif task.id in self._running_batches and new_status != RUNNING:
for subtask_id in self._running_batches.pop(task.id):
subtask = self.get_task(subtask_id)
if new_status == FAILED:
subtask.retry = time.time() + config.retry_delay
self.set_status(subtask, new_status, config)

if task.id in self._running_batches and new_status in (DONE, FAILED):
del self._running_batches[task.id]

self._status_tasks[task.status].pop(task.id)
self._status_tasks[new_status][task.id] = task
task.status = new_status

def fail_dead_worker_task(self, task, config, assistants):
# If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic
if task.status == RUNNING and task.worker_running and task.worker_running not in task.stakeholders | assistants:
if task.status in (RUNNING, BATCH_RUNNING) and task.worker_running and task.worker_running not in task.stakeholders | assistants:
logger.info("Task %r is marked as running by disconnected worker %r -> marking as "
"FAILED with retry delay of %rs", task.id, task.worker_running,
config.retry_delay)
Expand Down Expand Up @@ -457,6 +565,10 @@ def prune(self, task, config):
logger.info("Removing task %r (no connected stakeholders)", task.id)
remove = True

if task.is_batch and task.status != RUNNING:
logger.info("Removing batch task %r", task.id)
remove = True

# Reset FAILED tasks to PENDING if max timeout is reached, and retry delay is >= 0
if task.status == FAILED and config.retry_delay >= 0 and task.retry < time.time():
self.set_status(task, PENDING, config)
Expand Down Expand Up @@ -509,6 +621,13 @@ def get_necessary_tasks(self):
necessary_tasks.add(task.id)
return necessary_tasks

def update_tracking_url(self, task, tracking_url):
task.tracking_url = tracking_url
for batch_task_id in self._running_batches.get(task.id, []):
batch_task = self.get_task(batch_task_id)
if batch_task:
batch_task.tracking_url = tracking_url


class CentralPlannerScheduler(Scheduler):
"""
Expand Down Expand Up @@ -595,10 +714,14 @@ def _update_priority(self, task, prio, worker):
if t is not None and prio > t.priority:
self._update_priority(t, prio, worker)

def add_task_batcher(self, worker, family, batcher_family, batcher_args, batcher_aggregate_args):
self._state.set_batcher(worker, family, batcher_family, batcher_args, batcher_aggregate_args)

def add_task(self, task_id=None, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None,
priority=0, family='', module=None, params=None,
assistant=False, tracking_url=None, **kwargs):
assistant=False, tracking_url=None, batchable=None,
**kwargs):
"""
* add task identified by task_id if it doesn't exist
* if deps is not None, update dependency list
Expand All @@ -622,15 +745,18 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
task.params = _get_default(params, {})

if tracking_url is not None or task.status != RUNNING:
task.tracking_url = tracking_url
self._state.update_tracking_url(task, tracking_url)

if task.remove is not None:
task.remove = None # unmark task for removal so it isn't removed after being added

if expl is not None:
task.expl = expl

if not (task.status == RUNNING and status == PENDING) or new_deps:
if batchable is not None:
task.batchable = batchable

if not (task.status in (RUNNING, BATCH_RUNNING) and status == PENDING) or new_deps:
# don't allow re-scheduling of task while it is running, it must either fail or succeed first
if status == PENDING or status != task.status:
# Update the DB only if there was a acctual change, to prevent noise.
Expand Down Expand Up @@ -743,6 +869,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
for task in sorted(self._state.get_running_tasks(), key=self._rank):
if task.worker_running == worker_id and task.id not in ct_set:
best_task = task
best_tasks = [] if best_task is None else [best_task]

locally_pending_tasks = 0
running_tasks = []
Expand Down Expand Up @@ -783,7 +910,13 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
if len(task.workers) == 1 and not assistant:
n_unique_pending += 1

if best_task:
if (best_tasks and task.batchable and
self._state.get_batch(worker_id, best_tasks + [task]) is not None and
self._schedulable(task) and
self._has_resources(task.resources, greedy_resources)):
best_tasks.append(task)

if best_tasks:
continue

if task.status == RUNNING and (task.worker_running in greedy_workers):
Expand All @@ -793,7 +926,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):

if self._schedulable(task) and self._has_resources(task.resources, greedy_resources):
if in_workers and self._has_resources(task.resources, used_resources):
best_task = task
best_tasks = [task]
else:
workers = itertools.chain(task.workers, [worker_id]) if assistant else task.workers
for task_worker in workers:
Expand All @@ -812,8 +945,18 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
'task_id': None,
'n_unique_pending': n_unique_pending}

if best_task:
self._state.set_status(best_task, RUNNING, self._config)
if best_tasks:
best_batch = self._state.get_batch(worker_id, best_tasks)
best_task = self._state.get_task(best_batch.id, setdefault=best_batch)
for task in best_tasks:
if task == best_task:
continue
self._state.set_status(task, BATCH_RUNNING, self._config)
task.worker_running = worker_id
task.time_running = time.time()
self._update_task_history(task, BATCH_RUNNING, host=host)
batch_tasks = [task.id for task in best_tasks if task != best_task]
self._state.set_status(best_task, RUNNING, self._config, batch=batch_tasks)
best_task.worker_running = worker_id
best_task.time_running = time.time()
self._update_task_history(best_task, RUNNING, host=host)
Expand Down
10 changes: 10 additions & 0 deletions luigi/static/visualiser/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,16 @@ <h3 class="box-title">{{name}}</h3>
</div><!-- /.info-box -->
</div>

<div class="col-md-3 col-sm-6 col-xs-12">
<div class="info-box status-info" data-color='purple' data-category='BATCH_RUNNING' id="BATCH_RUNNING_info">
<span class="info-box-icon bg-purple"><i class="fa fa-spinner fa-pulse"></i></span>
<div class="info-box-content">
<span class="info-box-text">Batch Running Tasks</span>
<span class="info-box-number">?</span>
</div><!-- /.info-box-content -->
</div><!-- /.info-box -->
</div>

<div class="col-md-3 col-sm-6 col-xs-12">
<div class="info-box status-info" data-color='green' data-category='DONE' id="DONE_info">
<span class="info-box-icon bg-green"><i class="fa fa-spinner fa-pulse"></i></span>
Expand Down

0 comments on commit bcde8bb

Please sign in to comment.