Fix blocking worker queue when scheduling in parallel #3013

Open · wants to merge 5 commits into base: master
114 changes: 82 additions & 32 deletions luigi/worker.py
@@ -39,6 +39,7 @@
import subprocess
import sys
import contextlib
+import warnings

import queue as Queue
import random
@@ -342,21 +343,65 @@ def respond(self, response):
        self._scheduler.add_scheduler_message_response(self._task_id, self._message_id, response)


+class SyncResult(object):
+    """
+    Synchronous implementation of ``multiprocessing.pool.AsyncResult`` that immediately calls *func*
+    with *args* and *kwargs*. Its methods :py:meth:`get`, :py:meth:`wait`, :py:meth:`ready` and
+    :py:meth:`successful` work in a similar fashion, depending on the result of the function call.
+    """
+
+    def __init__(self, func, args=None, kwargs=None):
+        super(SyncResult, self).__init__()
Collaborator: luigi isn't officially supporting py 2.7 anymore so this could just be super().__init__(), but totally optional

+        # store function and arguments
+        self._func = func
+        self._args = args or ()
+        self._kwargs = kwargs or {}
+
+        # store return value and potential exceptions
+        self._return_value = None
+        self._exception = None
+
+        # immediately call
+        self._call()
+
+    def _call(self):
+        try:
+            self._return_value = self._func(*self._args, **self._kwargs)
+        except BaseException as e:
+            self._exception = e
+
+    def get(self, timeout=None):
+        if self._exception:
+            raise self._exception
+        else:
+            return self._return_value
+
+    def wait(self, timeout=None):
+        return
+
+    def ready(self):
+        return True
+
+    def successful(self):
+        return self._exception is None


class SingleProcessPool:
    """
    Dummy process pool for using a single processor.

    Imitates the api of multiprocessing.Pool using single-processor equivalents.
    """

-    def apply_async(self, function, args):
-        return function(*args)
+    def apply_async(self, function, args=None, kwargs=None):
+        return SyncResult(function, args=args, kwargs=kwargs)

    def close(self):
-        pass
+        return

    def join(self):
-        pass
+        return
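
For reference, a minimal usage sketch (assuming this branch) showing that SingleProcessPool and multiprocessing.Pool can now be driven through the same apply_async/get contract; the square function is a made-up example:

    import multiprocessing

    from luigi.worker import SingleProcessPool

    def square(x):  # must be a top-level function so multiprocessing can pickle it
        return x * x

    if __name__ == "__main__":
        # single-process flavor: runs inline and returns a SyncResult
        sync_result = SingleProcessPool().apply_async(square, (3,))
        assert sync_result.ready() and sync_result.successful()
        assert sync_result.get() == 9

        # multi-process flavor: same call shape, returns a real AsyncResult
        pool = multiprocessing.Pool(processes=2)
        async_result = pool.apply_async(square, (3,))
        assert async_result.get(timeout=10) == 9
        pool.close()
        pool.join()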


class DequeQueue(collections.deque):
@@ -380,6 +425,8 @@ class AsyncCompletionException(Exception):
    """

    def __init__(self, trace):
+        warnings.warn("{} is deprecated and will be removed in a future release".format(
+            self.__class__.__name__), DeprecationWarning)
        self.trace = trace


@@ -389,19 +436,17 @@ class TracebackWrapper:
    """

    def __init__(self, trace):
+        warnings.warn("{} is deprecated and will be removed in a future release".format(
+            self.__class__.__name__), DeprecationWarning)
        self.trace = trace


-def check_complete(task, out_queue):
+def check_complete(task):
    """
-    Checks if task is complete, puts the result to out_queue.
+    Checks if task is complete.
    """
    logger.debug("Checking if %s is complete", task)
-    try:
-        is_complete = task.complete()
-    except Exception:
-        is_complete = TracebackWrapper(traceback.format_exc())
-    out_queue.put((task, is_complete))
+    return task.complete()
Comment on lines -400 to +449
Collaborator: What's the potential for hitting an exception here? Was the previous catch never caught and thus pointless? Or are we introducing the opportunity here for an unhandled exception?
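
For context on the question above: the try/except has effectively moved to the caller. Both SyncResult and multiprocessing's AsyncResult capture an exception raised inside check_complete and re-raise it from get(), which add() below turns back into a value via is_complete = e. A standalone sketch of that round trip (the failing_complete function is hypothetical, not part of the PR):

    import multiprocessing

    def failing_complete():
        raise ValueError("complete() blew up")

    if __name__ == "__main__":
        pool = multiprocessing.Pool(processes=1)
        result = pool.apply_async(failing_complete)
        try:
            result.get(timeout=10)
        except Exception as e:
            # the worker-side exception is pickled back and re-raised here,
            # mirroring what Worker.add() now does with is_complete = e
            print(type(e).__name__, e)  # ValueError complete() blew up
        pool.close()
        pool.join()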



class worker(Config):
@@ -727,7 +772,7 @@ def _handle_task_load_error(self, exception, task_ids):
        )
        notifications.send_error_email(subject, error_message)

-    def add(self, task, multiprocess=False, processes=0):
+    def add(self, task, multiprocess=False, processes=0, wait_interval=0.01):
        """
        Add a Task for the worker to check and possibly schedule and run.

@@ -737,28 +782,35 @@ def add(self, task, multiprocess=False, processes=0):
        self._first_task = task.task_id
        self.add_succeeded = True
        if multiprocess:
-            queue = multiprocessing.Manager().Queue()
            pool = multiprocessing.Pool(processes=processes if processes > 0 else None)
        else:
-            queue = DequeQueue()
            pool = SingleProcessPool()
        self._validate_task(task)
-        pool.apply_async(check_complete, [task, queue])
+        results = [(task, pool.apply_async(check_complete, (task,)))]

-        # we track queue size ourselves because len(queue) won't work for multiprocessing
-        queue_size = 1
        try:
            seen = {task.task_id}
-            while queue_size:
-                current = queue.get()
-                queue_size -= 1
-                item, is_complete = current
-                for next in self._add(item, is_complete):
-                    if next.task_id not in seen:
-                        self._validate_task(next)
-                        seen.add(next.task_id)
-                        pool.apply_async(check_complete, [next, queue])
-                        queue_size += 1
+            while results:
+                # fetch the first done result
+                for i, (task, result) in enumerate(list(results)):
+                    if result.ready():
+                        results.pop(i)
+                        break
+                else:
+                    time.sleep(wait_interval)
+                    continue
Contributor: The for/else + continue makes the flow unclear. How about packing up the rest of the loop into a new method and calling it before break?
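
A possible shape for that refactor, as a sketch (the _pop_ready_result name is hypothetical, not part of this PR):

    def _pop_ready_result(self, results, wait_interval):
        """Block until one of the (task, result) pairs is ready, then pop and return it."""
        while True:
            for i, (task, result) in enumerate(results):
                if result.ready():
                    return results.pop(i)
            time.sleep(wait_interval)

The loop body in add() would then start with task, result = self._pop_ready_result(results, wait_interval) and lose the for/else/continue entirely.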


+                # get the result or error
+                try:
+                    is_complete = result.get()
+                except Exception as e:
+                    is_complete = e
+
+                for dep in self._add(task, is_complete):
+                    if dep.task_id not in seen:
+                        self._validate_task(dep)
+                        seen.add(dep.task_id)
+                        results.append((dep, pool.apply_async(check_complete, (dep,))))
        except (KeyboardInterrupt, TaskException):
            raise
        except Exception as ex:
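
A usage sketch of the extended signature (MyTask is a hypothetical task; wait_interval only matters on the multiprocess path, where it sets the polling sleep between ready() checks):

    import luigi
    from luigi.scheduler import Scheduler
    from luigi.worker import Worker

    class MyTask(luigi.Task):  # hypothetical example task
        def complete(self):
            return False

        def run(self):
            pass

    if __name__ == "__main__":
        sch = Scheduler()
        with Worker(scheduler=sch) as w:
            # check completeness of MyTask and its dependencies with 4 processes,
            # polling pending results every 50 ms
            w.add(MyTask(), multiprocess=True, processes=4, wait_interval=0.05)
            w.run()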
@@ -800,8 +852,6 @@ def _add(self, task, is_complete):
            self._check_complete_value(is_complete)
        except KeyboardInterrupt:
            raise
-        except AsyncCompletionException as ex:
-            formatted_traceback = ex.trace
        except BaseException:
            formatted_traceback = traceback.format_exc()

@@ -881,9 +931,9 @@ def _validate_dependency(self, dependency):
            raise Exception('requires() must return Task objects but {} is a {}'.format(dependency, type(dependency)))

    def _check_complete_value(self, is_complete):
-        if is_complete not in (True, False):
-            if isinstance(is_complete, TracebackWrapper):
-                raise AsyncCompletionException(is_complete.trace)
+        if isinstance(is_complete, BaseException):
+            raise is_complete
+        elif not isinstance(is_complete, bool):
            raise Exception("Return value of Task.complete() must be boolean (was %r)" % is_complete)

    def _add_worker(self):
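The new _check_complete_value contract in isolation, re-implemented standalone for illustration (a sketch, not the module itself): exception instances are re-raised as-is, anything non-boolean is rejected, and plain booleans pass through.

    def check_complete_value(is_complete):
        # mirrors Worker._check_complete_value on this branch
        if isinstance(is_complete, BaseException):
            raise is_complete
        elif not isinstance(is_complete, bool):
            raise Exception("Return value of Task.complete() must be boolean (was %r)" % is_complete)

    check_complete_value(True)  # passes silently
    try:
        check_complete_value(RuntimeError("boom"))
    except RuntimeError:
        pass  # the captured completeness error is re-raised unchanged
    try:
        check_complete_value("yes")
    except Exception:
        pass  # non-boolean return values from complete() are rejected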
2 changes: 0 additions & 2 deletions test/scheduler_test.py
@@ -23,9 +23,7 @@
from multiprocessing import Process
from helpers import unittest

import luigi.scheduler
import luigi.server
import luigi.configuration
from helpers import with_config
from luigi.target import FileAlreadyExists

59 changes: 58 additions & 1 deletion test/worker_parallel_scheduling_test.py
@@ -20,6 +20,7 @@
import os
import pickle
import time
+import warnings
from helpers import unittest

import mock
@@ -28,6 +29,7 @@
import luigi
from luigi.worker import Worker
from luigi.task_status import UNKNOWN
+from helpers import RunOnceTask


def running_children():
@@ -95,6 +97,10 @@ class UnpicklableException(Exception):
        raise UnpicklableException()


+class PicklableTask(RunOnceTask):
+    i = luigi.IntParameter()


class ParallelSchedulingTest(unittest.TestCase):

    def setUp(self):
@@ -175,11 +181,62 @@ def test_raise_unpicklable_exception_in_complete(self, send):
        send.check_called_once()
        self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]['status'])
        self.assertFalse(self.sch.add_task.call_args[1]['runnable'])
-        self.assertTrue('raise UnpicklableException()' in send.call_args[0][1])
+        self.assertTrue("Can't pickle local object 'UnpicklableExceptionTask" in send.call_args[0][1])

    @mock.patch('luigi.notifications.send_error_email')
    def test_raise_exception_in_requires(self, send):
        self.w.add(ExceptionRequiresTask(), multiprocess=True)
        send.check_called_once()
        self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]['status'])
        self.assertFalse(self.sch.add_task.call_args[1]['runnable'])

+    def test_parallel_scheduling_with_picklable_tasks(self):
+        tasks = [PicklableTask(i=i) for i in range(5)]
+        success = luigi.interface.build(tasks, local_scheduler=True, parallel_scheduling=True,
+                                        parallel_scheduling_processes=2)
+        self.assertTrue(success)
+
+    def test_parallel_scheduling_with_unpicklable_tasks(self):
+        class UnpicklableTask(RunOnceTask):
+            i = luigi.IntParameter()
+
+        tasks = [UnpicklableTask(i=i) for i in range(5)]
+        success = luigi.interface.build(tasks, local_scheduler=True, parallel_scheduling=True,
+                                        parallel_scheduling_processes=2)
+        self.assertFalse(success)

+    def test_sync_result(self):
+        def func1(a, b):
+            return a + b
+
+        def func2(a, b):
+            raise Exception("unknown")
+
+        def func3(a):
+            raise Exception("never called")
+
+        r = luigi.worker.SyncResult(func1, (1, 2))
+        self.assertIsNone(r.wait())
+        self.assertTrue(r.ready())
+        self.assertTrue(r.successful())
+        self.assertEqual(r.get(), 3)
+
+        r = luigi.worker.SyncResult(func2, (1, 2))
+        self.assertIsNone(r.wait())
+        self.assertTrue(r.ready())
+        self.assertFalse(r.successful())
+        with self.assertRaises(Exception):
+            r.get()
+
+        r = luigi.worker.SyncResult(func3, (1, 2))
+        self.assertIsNone(r.wait())
+        self.assertTrue(r.ready())
+        self.assertFalse(r.successful())
+        with self.assertRaises(TypeError):
+            r.get()

+    def test_deprecations(self):
+        with warnings.catch_warnings(record=True) as w:
+            luigi.worker.AsyncCompletionException("foo")
+            luigi.worker.TracebackWrapper("foo")
+        self.assertEqual(len(w), 2)