Rebase SparkSubmitTask onto ExternalProgramTask
The public interface of `SparkSubmitTask` remains as-is.

However, there will be subtle changes to the `stdout` output and the
logs (e.g. 'Program failed[...]' with this patch vs. 'Spark job
failed[...]' before). It will also raise an `ExternalProgramRunError`
on execution errors instead of a `SparkJobError` as before.
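
For code that catches the old exception, the change looks roughly like this (a minimal sketch; `MySparkTask` stands in for any existing `SparkSubmitTask` subclass and is not part of this commit):

    from luigi.contrib.external_program import ExternalProgramRunError

    task = MySparkTask()
    try:
        task.run()
    except ExternalProgramRunError as e:  # previously: except SparkJobError
        # The captured stderr of the failed spark-submit call is still exposed as e.err
        print(e.err)
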
ehdr committed Jan 26, 2016
1 parent b8c9c17 commit 002d95b
Showing 2 changed files with 46 additions and 81 deletions.
80 changes: 14 additions & 66 deletions luigi/contrib/spark.py
@@ -17,8 +17,6 @@

import logging
import os
import signal
import subprocess
import sys
import tempfile
import shutil
@@ -30,54 +28,13 @@
import pickle

from luigi import six
import luigi
import luigi.format
import luigi.contrib.hdfs
from luigi.contrib.external_program import ExternalProgramTask
from luigi import configuration

logger = logging.getLogger('luigi-interface')


class SparkRunContext(object):

    def __init__(self, proc):
        self.proc = proc

    def __enter__(self):
        self.__old_signal = signal.getsignal(signal.SIGTERM)
        signal.signal(signal.SIGTERM, self.kill_job)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is KeyboardInterrupt:
            self.kill_job()
        signal.signal(signal.SIGTERM, self.__old_signal)

    def kill_job(self, captured_signal=None, stack_frame=None):
        self.proc.kill()
        if captured_signal is not None:
            # adding 128 gives the exit code corresponding to a signal
            sys.exit(128 + captured_signal)


class SparkJobError(RuntimeError):

    def __init__(self, message, out=None, err=None):
        super(SparkJobError, self).__init__(message, out, err)
        self.message = message
        self.out = out
        self.err = err

    def __str__(self):
        info = self.message
        if self.out:
            info += "\nSTDOUT: " + str(self.out)
        if self.err:
            info += "\nSTDERR: " + str(self.err)
        return info


class SparkSubmitTask(luigi.Task):
class SparkSubmitTask(ExternalProgramTask):
"""
Template task for running a Spark job
@@ -93,6 +50,9 @@ class SparkSubmitTask(luigi.Task):
    entry_class = None
    app = None

    # Only log stderr if spark fails (since stderr is normally quite verbose)
    always_log_stderr = False

    def app_options(self):
        """
        Subclass this method to map your task parameters to the app's arguments
@@ -195,6 +155,12 @@ def get_environment(self):
            env['HADOOP_CONF_DIR'] = hadoop_conf_dir
        return env

    def program_environment(self):
        return self.get_environment()

    def program_args(self):
        return self.spark_command() + self.app_command()

    def spark_command(self):
        command = [self.spark_submit]
        command += self._text_arg('--master', self.master)
@@ -226,27 +192,6 @@ def app_command(self):
            raise NotImplementedError("subclass should define an app (.jar or .py file)")
        return [self.app] + self.app_options()

    def run(self):
        args = list(map(str, self.spark_command() + self.app_command()))
        logger.info('Running: {}'.format(subprocess.list2cmdline(args)))
        tmp_stdout, tmp_stderr = tempfile.TemporaryFile(), tempfile.TemporaryFile()
        proc = subprocess.Popen(args, stdout=tmp_stdout, stderr=tmp_stderr,
                                env=self.get_environment(), close_fds=True,
                                universal_newlines=True)
        try:
            with SparkRunContext(proc):
                proc.wait()
            tmp_stdout.seek(0)
            stdout = "".join(map(lambda s: s.decode('utf-8'), tmp_stdout.readlines()))
            logger.info("Spark job stdout:\n{0}".format(stdout))
            if proc.returncode != 0:
                tmp_stderr.seek(0)
                stderr = "".join(map(lambda s: s.decode('utf-8'), tmp_stderr.readlines()))
                raise SparkJobError('Spark job failed {0}'.format(repr(args)), out=stdout, err=stderr)
        finally:
            tmp_stderr.close()
            tmp_stdout.close()

    def _list_config(self, config):
        if config and isinstance(config, six.string_types):
            return list(map(lambda x: x.strip(), config.split(',')))
@@ -323,6 +268,9 @@ def main(self, sc, *args):
"""
raise NotImplementedError("subclass should define a main method")

    def program_args(self):
        return self.spark_command() + self.app_command()

    def app_command(self):
        return [self.app, self.run_pickle] + self.app_options()

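Since the public interface is unchanged by the diff above, a subclass still only points at its application and maps task parameters in `app_options()`; process handling now comes from `ExternalProgramTask.run()`, which builds the command line from `program_args()` and the environment from `program_environment()`. A minimal sketch (class name, jar, entry class, and parameter are illustrative only, not part of this commit):

    import luigi
    from luigi.contrib.spark import SparkSubmitTask

    class WordCount(SparkSubmitTask):
        # Illustrative values; any .jar or .py application works here.
        app = 'wordcount.jar'
        entry_class = 'com.example.WordCount'
        input_path = luigi.Parameter()

        def app_options(self):
            # Task parameters become the application's command-line arguments.
            return [self.input_path]
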
47 changes: 32 additions & 15 deletions test/contrib/spark_test.py
@@ -22,8 +22,9 @@
from luigi import six
from luigi.mock import MockTarget
from helpers import with_config
from luigi.contrib.spark import SparkJobError, SparkSubmitTask, PySparkTask
from mock import patch, MagicMock
from luigi.contrib.external_program import ExternalProgramRunError
from luigi.contrib.spark import SparkSubmitTask, PySparkTask
from mock import patch, call, MagicMock

BytesIO = six.BytesIO

@@ -94,7 +95,7 @@ class SparkSubmitTaskTest(unittest.TestCase):
    ss = 'ss-stub'

    @with_config({'spark': {'spark-submit': ss, 'master': "yarn-client", 'hadoop-conf-dir': 'path'}})
    @patch('luigi.contrib.spark.subprocess.Popen')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_run(self, proc):
        setup_run_process(proc)
        job = TestSparkSubmitTask()
@@ -121,7 +122,7 @@ def test_environment_is_set_correctly(self, proc):

    @with_config({'spark': {'spark-submit': ss, 'master': 'spark://host:7077', 'conf': 'prop1=val1', 'jars': 'jar1.jar,jar2.jar',
                            'files': 'file1,file2', 'py-files': 'file1.py,file2.py', 'archives': 'archive1'}})
    @patch('luigi.contrib.spark.subprocess.Popen')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_defaults(self, proc):
        proc.return_value.returncode = 0
        job = TestDefaultSparkSubmitTask()
@@ -131,27 +132,43 @@ def test_defaults(self, proc):
                          '--py-files', 'file1.py,file2.py', '--files', 'file1,file2', '--archives', 'archive1',
                          '--conf', 'prop1=val1', 'test.py'])

    @patch('luigi.contrib.spark.tempfile.TemporaryFile')
    @patch('luigi.contrib.spark.subprocess.Popen')
    def test_handle_failed_job(self, proc, file):
    @patch('luigi.contrib.external_program.logger')
    @patch('luigi.contrib.external_program.tempfile.TemporaryFile')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_handle_failed_job(self, proc, file, logger):
        proc.return_value.returncode = 1
        file.return_value = BytesIO(b'stderr')
        file.return_value = BytesIO(b'spark test error')
        try:
            job = TestSparkSubmitTask()
            job.run()
        except SparkJobError as e:
            self.assertEqual(e.err, 'stderr')
            self.assertTrue('STDERR: stderr' in six.text_type(e))
        except ExternalProgramRunError as e:
            self.assertEqual(e.err, 'spark test error')
            self.assertIn('spark test error', six.text_type(e))
            self.assertIn(call.info('Program stderr:\nspark test error'),
                          logger.mock_calls)
        else:
            self.fail("Should have thrown SparkJobError")
            self.fail("Should have thrown ExternalProgramRunError")

    @patch('luigi.contrib.external_program.logger')
    @patch('luigi.contrib.external_program.tempfile.TemporaryFile')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_dont_log_stderr_on_success(self, proc, file, logger):
        proc.return_value.returncode = 0
        file.return_value = BytesIO(b'spark normal error output')
        job = TestSparkSubmitTask()
        job.run()

        self.assertNotIn(call.info(
            'Program stderr:\nspark normal error output'),
            logger.mock_calls)

    @patch('luigi.contrib.spark.subprocess.Popen')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_app_must_be_set(self, proc):
        with self.assertRaises(NotImplementedError):
            job = SparkSubmitTask()
            job.run()

    @patch('luigi.contrib.spark.subprocess.Popen')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_app_interruption(self, proc):

        def interrupt():
@@ -170,7 +187,7 @@ class PySparkTaskTest(unittest.TestCase):
    ss = 'ss-stub'

    @with_config({'spark': {'spark-submit': ss, 'master': "spark://host:7077"}})
    @patch('luigi.contrib.spark.subprocess.Popen')
    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_run(self, proc):
        setup_run_process(proc)
        job = TestPySparkTask()
