Enables running of multiple tasks in batches #1538

Closed
Wants to merge 11 commits.
luigi/parameter.py: 3 additions & 1 deletion
@@ -116,7 +116,7 @@ def run(self):
_counter = 0 # non-atomically increasing counter used for ordering parameters.

def __init__(self, default=_no_value, is_global=False, significant=True, description=None,
-                 config_path=None, positional=True, always_in_help=False):
+                 config_path=None, positional=True, always_in_help=False, batch_method=None):
"""
:param default: the default value for this parameter. This should match the type of the
Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for
@@ -139,6 +139,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip
``positional=False`` for abstract base classes and similar cases.
:param bool always_in_help: For the --help option in the command line
parsing. Set true to always show in --help.
:param str batch_method: How multiple
Contributor: For such a complicated parameter I hoped for more docs. ^^

Contributor Author: Oops! Not sure how that happened.

"""
self._default = default
if is_global:
@@ -151,6 +152,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip

self.description = description
self.always_in_help = always_in_help
self.batch_method = batch_method

if config_path is not None and ('section' not in config_path or 'name' not in config_path):
raise ParameterException('config_path must be a hash containing entries for section and name')
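To make the new keyword concrete, here is a hypothetical usage sketch (not code from this PR; the SendLogs task and the 'range' value are illustrative assumptions, with the accepted method names defined by the scheduler's AGGREGATE_FUNCTIONS further down):

import luigi

class SendLogs(luigi.Task):
    # Assumption: batch_method names one of the scheduler's aggregate
    # functions ('csv', 'max', 'min', 'range'). With 'range', several
    # pending dates would collapse into a single min-max range value.
    date = luigi.Parameter(batch_method='range')

    def run(self):
        print('sending logs for', self.date)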
luigi/rpc.py: 10 additions & 1 deletion
@@ -156,7 +156,7 @@ def ping(self, worker):
def add_task(self, worker, task_id, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None, priority=0,
family='', module=None, params=None, assistant=False,
-                 tracking_url=None):
+                 tracking_url=None, batchable=None):
self._request('/api/add_task', {
'task_id': task_id,
'worker': worker,
@@ -172,6 +172,15 @@ def add_task(self, worker, task_id, status=PENDING, runnable=True,
'params': params,
'assistant': assistant,
'tracking_url': tracking_url,
'batchable': batchable,
})

def add_task_batcher(self, worker, family, batcher_aggregate_args, max_batch_size=None):
self._request('/api/add_task_batcher', {
'worker': worker,
'family': family,
'batcher_aggregate_args': batcher_aggregate_args,
'max_batch_size': max_batch_size,
})

def get_work(self, worker, host=None, assistant=False, current_tasks=None):
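A sketch of how a worker might use these two endpoints together (hedged: the no-argument RemoteScheduler constructor, the worker name, and the call sequence are assumptions; only the method signatures come from this diff):

from luigi.rpc import RemoteScheduler

sched = RemoteScheduler()  # assumes a scheduler on the default localhost port

# Register an aggregation rule once per task family: SendLogs tasks from
# this worker may be merged, combining their 'date' params with 'range'.
sched.add_task_batcher(
    worker='worker-1',
    family='SendLogs',
    batcher_aggregate_args={'date': 'range'},
    max_batch_size=10,
)

# Each individual task is then flagged as batchable when it is added.
sched.add_task(
    worker='worker-1',
    task_id='SendLogs(date=2016-01-01)',
    family='SendLogs',
    params={'date': '2016-01-01'},
    batchable=True,
)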
luigi/scheduler.py: 155 additions & 17 deletions
@@ -39,8 +39,8 @@
from luigi import notifications
from luigi import parameter
from luigi import task_history as history
-from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN
-from luigi.task import Config
+from luigi.task_status import BATCH_RUNNING, DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN
+from luigi.task import Config, task_id_str

logger = logging.getLogger(__name__)

@@ -71,12 +71,25 @@ class Scheduler(object):
STATUS_TO_UPSTREAM_MAP = {
FAILED: UPSTREAM_FAILED,
RUNNING: UPSTREAM_RUNNING,
BATCH_RUNNING: UPSTREAM_RUNNING,
PENDING: UPSTREAM_MISSING_INPUT,
DISABLED: UPSTREAM_DISABLED,
}

TASK_FAMILY_RE = re.compile(r'([^(_]+)[(_]')

TASKS = 'tasks'
ACTIVE_WORKERS = 'active_workers'
BATCH_TASKS = 'batch_tasks'
RUNNING_BATCHES = 'running_batches'

AGGREGATE_FUNCTIONS = {
'csv': ','.join,
'max': max,
'min': min,
'range': lambda args: '%s-%s' % (min(args), max(args)),
}
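A quick plain-Python illustration of what each aggregate function yields. Note that the values appear to be the serialized string form of the parameters (see task.params below), so 'min'/'max' compare lexicographically here:

vals = ['2016-01-03', '2016-01-01', '2016-01-02']
','.join(vals)                    # 'csv'   -> '2016-01-03,2016-01-01,2016-01-02'
max(vals)                         # 'max'   -> '2016-01-03'
min(vals)                         # 'min'   -> '2016-01-01'
'%s-%s' % (min(vals), max(vals))  # 'range' -> '2016-01-01-2016-01-03'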


class scheduler(Config):
# TODO(erikbern): the config_path is needed for backwards compatibility. We
@@ -162,7 +175,7 @@ class Task(object):

def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None,
params=None, disable_failures=None, disable_window=None, disable_hard_timeout=None,
-                 tracking_url=None, status_message=None):
+                 tracking_url=None, status_message=None, is_batch=False):
self.id = task_id
self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active)
self.workers = set() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active
@@ -190,6 +203,8 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='',
self.status_message = status_message
self.scheduler_disable_time = None
self.runnable = False
self.is_batch = is_batch
self.batchable = False

def __repr__(self):
return "Task(%r)" % vars(self)
@@ -214,6 +229,31 @@ def pretty_id(self):
return '{}({})'.format(self.family, param_str)


class TaskBatcher(object):
def __init__(self, aggregate_args, max_batch_size):
self.aggregates = aggregate_args
self.max_batch_size = max_batch_size

def task_id(self, tasks):
if self.max_batch_size is not None and len(tasks) > self.max_batch_size:
return None, None
params = {}
for arg in tasks[0].params:
raw_vals = [task.params[arg] for task in tasks]
agg_function = self.aggregates.get(arg)
if agg_function is not None:
arg_val = AGGREGATE_FUNCTIONS[agg_function](raw_vals)
elif any(v != raw_vals[0] for v in raw_vals):
return None, None
else:
arg_val = raw_vals[0]
params[arg] = arg_val

family = tasks[0].family
assert all(task.family == family for task in tasks)
return task_id_str(family, params), params
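A minimal sketch of what task_id computes, using SimpleNamespace as a stand-in for the scheduler's Task objects (values are made up; it assumes the TaskBatcher class and the task_id_str import above are in scope):

from types import SimpleNamespace

t1 = SimpleNamespace(family='SendLogs', params={'date': '2016-01-01', 'env': 'prod'})
t2 = SimpleNamespace(family='SendLogs', params={'date': '2016-01-05', 'env': 'prod'})

batcher = TaskBatcher(aggregate_args={'date': 'range'}, max_batch_size=None)
batch_id, params = batcher.task_id([t1, t2])
# params == {'date': '2016-01-01-2016-01-05', 'env': 'prod'}
# 'date' was combined by the 'range' rule; 'env' was identical across the
# tasks, so it passed through unchanged. A differing param with no aggregate
# rule would make task_id return (None, None), i.e. the tasks cannot batch.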


class Worker(object):
"""
Structure for tracking worker activity and keeping their references.
@@ -287,14 +327,28 @@ class SimpleTaskState(object):
def __init__(self, state_path):
self._state_path = state_path
self._tasks = {} # map from id to a Task object
self._batch_tasks = {} # map from family to TaskBatch object
self._running_batches = {} # map from id to list of batched task ids
self._status_tasks = collections.defaultdict(dict)
self._active_workers = {} # map from id to a Worker object

def get_state(self):
-        return self._tasks, self._active_workers
+        return {
+            TASKS: self._tasks,
+            ACTIVE_WORKERS: self._active_workers,
+            BATCH_TASKS: self._batch_tasks,
+            RUNNING_BATCHES: self._running_batches,
+        }

def set_state(self, state):
-        self._tasks, self._active_workers = state
+        if isinstance(state, dict):
+            self._tasks = state.get(TASKS, {})
+            self._active_workers = state.get(ACTIVE_WORKERS, {})
+            self._batch_tasks = state.get(BATCH_TASKS, {})
+            self._running_batches = state.get(RUNNING_BATCHES, {})
+        else:
+            self._tasks, self._active_workers = state
Contributor: Can you add a note that this is only for backward compatibility and this codepath will be removed?

Contributor Author: done

self._batch_tasks = {}
Contributor: Also add self._running_batches = {}?

Contributor Author: done


def dump(self):
try:
@@ -338,6 +392,37 @@ def get_pending_tasks(self):
return itertools.chain.from_iterable(six.itervalues(self._status_tasks[status])
for status in [PENDING, RUNNING])

def get_batch(self, worker_id, tasks):
Contributor: Isn't this name confusing? Maybe create_batch_task?

Contributor Author: done

if len(tasks) == 1:
return tasks[0]
families = set(task.family for task in tasks)
if len(families) != 1:
return None
family = families.pop()
batch_task = self.get_batcher(worker_id, family)
if batch_task is None:
return None
task_id, params = batch_task.task_id(tasks)
if task_id is None:
return None
priority = max(task.priority for task in tasks)
resource_keys = functools.reduce(set.union, (task.resources.keys() for task in tasks), set())
resources = {key: max(task.resources.get(key, 0) for task in tasks) for key in resource_keys}
deps = functools.reduce(set.union, (task.deps for task in tasks), set())
batch_task_obj = Task(
task_id=task_id,
status=PENDING,
deps=deps,
priority=priority,
resources=resources,
family=family,
params=params,
is_batch=True,
)
batch_task_obj.stakeholders.add(worker_id)
batch_task_obj.workers.add(worker_id)
return batch_task_obj
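The merged resources take the per-key maximum rather than the sum, since the whole batch runs as one task. A tiny illustration with made-up values:

task_resources = [{'db': 1}, {'db': 2, 'gpu': 1}]
keys = set().union(*(r.keys() for r in task_resources))
{k: max(r.get(k, 0) for r in task_resources) for k in keys}
# -> {'db': 2, 'gpu': 1}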

def num_pending_tasks(self):
"""
Return how many tasks are PENDING + RUNNING. O(1).
@@ -352,6 +437,14 @@ def get_task(self, task_id, default=None, setdefault=None):
else:
return self._tasks.get(task_id, default)

def get_batcher(self, worker, family):
Contributor: Can you be explicit and write task_family?

Contributor Author: done

return self._batch_tasks.get((worker, family))

def set_batcher(self, worker, family, batcher_aggregate_args, max_batch_size):
batcher = TaskBatcher(batcher_aggregate_args, max_batch_size)
self._batch_tasks[(worker, family)] = batcher
return batcher

def has_task(self, task_id):
return task_id in self._tasks

@@ -362,11 +455,11 @@ def re_enable(self, task, config=None):
self.set_status(task, FAILED, config)
task.failures.clear()

-    def set_status(self, task, new_status, config=None):
+    def set_status(self, task, new_status, config=None, batch=None):
if new_status == FAILED:
assert config is not None

-        if new_status == DISABLED and task.status == RUNNING:
+        if new_status == DISABLED and task.status in (RUNNING, BATCH_RUNNING):
return

if task.status == DISABLED:
@@ -377,7 +470,7 @@ def set_status(self, task, new_status, config=None):
elif task.scheduler_disable_time is not None and new_status != DISABLED:
return

-        if new_status == FAILED and task.status != DISABLED:
+        if new_status == FAILED and task.status != DISABLED and not task.is_batch:
task.add_failure()
if task.has_excessive_failures():
task.scheduler_disable_time = time.time()
@@ -394,6 +487,17 @@
elif new_status == DISABLED:
task.scheduler_disable_time = None

if new_status == RUNNING and batch:
self._running_batches[task.id] = set(batch)
elif task.id in self._running_batches and new_status != RUNNING:
subtask_status = FAILED if new_status == DISABLED else new_status
for subtask_id in self._running_batches.pop(task.id):
subtask = self.get_task(subtask_id)
if subtask_status == FAILED:
subtask.retry = time.time() + config.retry_delay
self.set_status(subtask, subtask_status, config)
subtask.expl = task.expl

if new_status != task.status:
self._status_tasks[task.status].pop(task.id)
self._status_tasks[new_status][task.id] = task
@@ -402,7 +506,7 @@

def fail_dead_worker_task(self, task, config, assistants):
# If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic
-        if task.status == RUNNING and task.worker_running and task.worker_running not in task.stakeholders | assistants:
+        if task.status in (RUNNING, BATCH_RUNNING) and task.worker_running and task.worker_running not in task.stakeholders | assistants:
logger.info("Task %r is marked as running by disconnected worker %r -> marking as "
"FAILED with retry delay of %rs", task.id, task.worker_running,
config.retry_delay)
@@ -428,7 +532,10 @@ def update_status(self, task, config):
self.set_status(task, PENDING, config)

def may_prune(self, task):
-        return task.remove and time.time() > task.remove
+        return (
+            (task.remove and time.time() > task.remove) or
+            (task.is_batch and task.status != RUNNING)
+        )

def inactivate_tasks(self, delete_tasks):
# The terminology is a bit confusing: we used to "delete" tasks when they became inactive,
@@ -483,6 +590,13 @@ def get_necessary_tasks(self):
necessary_tasks.add(task.id)
return necessary_tasks

def update_tracking_url(self, task, tracking_url):
task.tracking_url = tracking_url
for batch_task_id in self._running_batches.get(task.id, []):
batch_task = self.get_task(batch_task_id)
if batch_task:
batch_task.tracking_url = tracking_url


class CentralPlannerScheduler(Scheduler):
"""
@@ -575,10 +689,14 @@ def _update_priority(self, task, prio, worker):
if t is not None and prio > t.priority:
self._update_priority(t, prio, worker)

def add_task_batcher(self, worker, family, batcher_aggregate_args, max_batch_size=None):
self._state.set_batcher(worker, family, batcher_aggregate_args, max_batch_size)

def add_task(self, task_id=None, status=PENDING, runnable=True,
deps=None, new_deps=None, expl=None, resources=None,
priority=0, family='', module=None, params=None,
-                 assistant=False, tracking_url=None, **kwargs):
+                 assistant=False, tracking_url=None, batchable=None,
+                 **kwargs):
"""
* add task identified by task_id if it doesn't exist
* if deps is not None, update dependency list
@@ -611,15 +729,18 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
task.params = _get_default(params, {})

if tracking_url is not None or task.status != RUNNING:
-            task.tracking_url = tracking_url
+            self._state.update_tracking_url(task, tracking_url)

if task.remove is not None:
task.remove = None # unmark task for removal so it isn't removed after being added

if expl is not None:
task.expl = expl

-        if not (task.status == RUNNING and status == PENDING) or new_deps:
+        if batchable is not None:
+            task.batchable = batchable
+
+        if not (task.status in (RUNNING, BATCH_RUNNING) and status == PENDING) or new_deps:
# don't allow re-scheduling of task while it is running, it must either fail or succeed first
if status == PENDING or status != task.status:
# Update the DB only if there was an actual change, to prevent noise.
@@ -735,6 +856,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
for task in sorted(self._state.get_running_tasks(), key=self._rank):
if task.worker_running == worker_id and task.id not in ct_set:
best_task = task
best_tasks = [] if best_task is None else [best_task]

locally_pending_tasks = 0
running_tasks = []
@@ -776,7 +898,13 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):
if len(task.workers) == 1 and not assistant:
n_unique_pending += 1

-                if best_task:
+                if (in_workers and best_tasks and task.batchable and
+                        self._state.get_batch(worker_id, best_tasks + [task]) is not None and
+                        self._schedulable(task) and
+                        self._has_resources(task.resources, greedy_resources)):
+                    best_tasks.append(task)

Contributor: Maybe add a comment above this line # Batch as many tasks as possible

Contributor Author: done

if best_tasks:
continue

if task.status == RUNNING and (task.worker_running in greedy_workers):
Expand All @@ -786,7 +914,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, **kwargs):

if self._schedulable(task) and self._has_resources(task.resources, greedy_resources):
if in_workers and self._has_resources(task.resources, used_resources):
-                        best_task = task
+                        best_tasks = [task]
else:
workers = itertools.chain(task.workers, [worker_id]) if assistant else task.workers
for task_worker in workers:
@@ -805,8 +933,18 @@
'task_id': None,
'n_unique_pending': n_unique_pending}

-        if best_task:
-            self._state.set_status(best_task, RUNNING, self._config)
+        if best_tasks:
+            best_batch = self._state.get_batch(worker_id, best_tasks)
+            best_task = self._state.get_task(best_batch.id, setdefault=best_batch)
+            for task in best_tasks:
+                if task == best_task:
+                    continue
+                self._state.set_status(task, BATCH_RUNNING, self._config)
+                task.worker_running = worker_id
+                task.time_running = time.time()
+                self._update_task_history(task, BATCH_RUNNING, host=host)
+            batch_tasks = [task.id for task in best_tasks if task != best_task]
+            self._state.set_status(best_task, RUNNING, self._config, batch=batch_tasks)
best_task.worker_running = worker_id
best_task.time_running = time.time()
self._update_task_history(best_task, RUNNING, host=host)
luigi/server.py: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def initialize(self, scheduler):
def get(self, method):
if method not in [
'add_task',
'add_task_batcher',
'add_worker',
'dep_graph',
'disable_worker',