diff --git a/luigi/contrib/spark.py b/luigi/contrib/spark.py
index 516066f6f6..9d7317f94b 100644
--- a/luigi/contrib/spark.py
+++ b/luigi/contrib/spark.py
@@ -17,8 +17,6 @@
 
 import logging
 import os
-import signal
-import subprocess
 import sys
 import tempfile
 import shutil
@@ -30,54 +28,13 @@
 import pickle
 
 from luigi import six
-import luigi
-import luigi.format
-import luigi.contrib.hdfs
+from luigi.contrib.external_program import ExternalProgramTask
 from luigi import configuration
 
 logger = logging.getLogger('luigi-interface')
 
 
-class SparkRunContext(object):
-
-    def __init__(self, proc):
-        self.proc = proc
-
-    def __enter__(self):
-        self.__old_signal = signal.getsignal(signal.SIGTERM)
-        signal.signal(signal.SIGTERM, self.kill_job)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type is KeyboardInterrupt:
-            self.kill_job()
-        signal.signal(signal.SIGTERM, self.__old_signal)
-
-    def kill_job(self, captured_signal=None, stack_frame=None):
-        self.proc.kill()
-        if captured_signal is not None:
-            # adding 128 gives the exit code corresponding to a signal
-            sys.exit(128 + captured_signal)
-
-
-class SparkJobError(RuntimeError):
-
-    def __init__(self, message, out=None, err=None):
-        super(SparkJobError, self).__init__(message, out, err)
-        self.message = message
-        self.out = out
-        self.err = err
-
-    def __str__(self):
-        info = self.message
-        if self.out:
-            info += "\nSTDOUT: " + str(self.out)
-        if self.err:
-            info += "\nSTDERR: " + str(self.err)
-        return info
-
-
-class SparkSubmitTask(luigi.Task):
+class SparkSubmitTask(ExternalProgramTask):
     """
     Template task for running a Spark job
 
@@ -93,6 +50,9 @@ class SparkSubmitTask(luigi.Task):
     entry_class = None
     app = None
 
+    # Only log stderr if spark fails (since stderr is normally quite verbose)
+    always_log_stderr = False
+
     def app_options(self):
         """
         Subclass this method to map your task parameters to the app's arguments
@@ -195,6 +155,12 @@ def get_environment(self):
             env['HADOOP_CONF_DIR'] = hadoop_conf_dir
         return env
 
+    def program_environment(self):
+        return self.get_environment()
+
+    def program_args(self):
+        return self.spark_command() + self.app_command()
+
     def spark_command(self):
         command = [self.spark_submit]
         command += self._text_arg('--master', self.master)
@@ -226,27 +192,6 @@ def app_command(self):
             raise NotImplementedError("subclass should define an app (.jar or .py file)")
         return [self.app] + self.app_options()
 
-    def run(self):
-        args = list(map(str, self.spark_command() + self.app_command()))
-        logger.info('Running: {}'.format(subprocess.list2cmdline(args)))
-        tmp_stdout, tmp_stderr = tempfile.TemporaryFile(), tempfile.TemporaryFile()
-        proc = subprocess.Popen(args, stdout=tmp_stdout, stderr=tmp_stderr,
-                                env=self.get_environment(), close_fds=True,
-                                universal_newlines=True)
-        try:
-            with SparkRunContext(proc):
-                proc.wait()
-                tmp_stdout.seek(0)
-                stdout = "".join(map(lambda s: s.decode('utf-8'), tmp_stdout.readlines()))
-                logger.info("Spark job stdout:\n{0}".format(stdout))
-            if proc.returncode != 0:
-                tmp_stderr.seek(0)
-                stderr = "".join(map(lambda s: s.decode('utf-8'), tmp_stderr.readlines()))
-                raise SparkJobError('Spark job failed {0}'.format(repr(args)), out=stdout, err=stderr)
-        finally:
-            tmp_stderr.close()
-            tmp_stdout.close()
-
     def _list_config(self, config):
         if config and isinstance(config, six.string_types):
             return list(map(lambda x: x.strip(), config.split(',')))
@@ -323,6 +268,9 @@ def main(self, sc, *args):
         """
         raise NotImplementedError("subclass should define a main method")
 
+    def program_args(self):
+        return self.spark_command() + self.app_command()
+
     def app_command(self):
         return [self.app, self.run_pickle] + self.app_options()
 
diff --git a/test/contrib/spark_test.py b/test/contrib/spark_test.py
index 81aad45359..9b5fc0fb81 100644
--- a/test/contrib/spark_test.py
+++ b/test/contrib/spark_test.py
@@ -22,8 +22,9 @@
 from luigi import six
 from luigi.mock import MockTarget
 from helpers import with_config
-from luigi.contrib.spark import SparkJobError, SparkSubmitTask, PySparkTask
-from mock import patch, MagicMock
+from luigi.contrib.external_program import ExternalProgramRunError
+from luigi.contrib.spark import SparkSubmitTask, PySparkTask
+from mock import patch, call, MagicMock
 
 BytesIO = six.BytesIO
 
@@ -94,7 +95,7 @@ class SparkSubmitTaskTest(unittest.TestCase):
     ss = 'ss-stub'
 
     @with_config({'spark': {'spark-submit': ss, 'master': "yarn-client", 'hadoop-conf-dir': 'path'}})
-    @patch('luigi.contrib.spark.subprocess.Popen')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
     def test_run(self, proc):
         setup_run_process(proc)
         job = TestSparkSubmitTask()
@@ -121,7 +122,7 @@ def test_environment_is_set_correctly(self, proc):
 
     @with_config({'spark': {'spark-submit': ss, 'master': 'spark://host:7077', 'conf': 'prop1=val1',
                             'jars': 'jar1.jar,jar2.jar', 'files': 'file1,file2', 'py-files': 'file1.py,file2.py', 'archives': 'archive1'}})
-    @patch('luigi.contrib.spark.subprocess.Popen')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
    def test_defaults(self, proc):
         proc.return_value.returncode = 0
         job = TestDefaultSparkSubmitTask()
@@ -131,27 +132,43 @@ def test_defaults(self, proc):
         job.run()
         self.assertEqual(proc.call_args[0][0], ['ss-stub', '--master', 'spark://host:7077', '--jars', 'jar1.jar,jar2.jar', '--py-files', 'file1.py,file2.py', '--files', 'file1,file2', '--archives', 'archive1', '--conf', 'prop1=val1', 'test.py'])
 
-    @patch('luigi.contrib.spark.tempfile.TemporaryFile')
-    @patch('luigi.contrib.spark.subprocess.Popen')
-    def test_handle_failed_job(self, proc, file):
+    @patch('luigi.contrib.external_program.logger')
+    @patch('luigi.contrib.external_program.tempfile.TemporaryFile')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
+    def test_handle_failed_job(self, proc, file, logger):
         proc.return_value.returncode = 1
-        file.return_value = BytesIO(b'stderr')
+        file.return_value = BytesIO(b'spark test error')
         try:
             job = TestSparkSubmitTask()
             job.run()
-        except SparkJobError as e:
-            self.assertEqual(e.err, 'stderr')
-            self.assertTrue('STDERR: stderr' in six.text_type(e))
+        except ExternalProgramRunError as e:
+            self.assertEqual(e.err, 'spark test error')
+            self.assertIn('spark test error', six.text_type(e))
+            self.assertIn(call.info('Program stderr:\nspark test error'),
+                          logger.mock_calls)
         else:
-            self.fail("Should have thrown SparkJobError")
+            self.fail("Should have thrown ExternalProgramRunError")
+
+    @patch('luigi.contrib.external_program.logger')
+    @patch('luigi.contrib.external_program.tempfile.TemporaryFile')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
+    def test_dont_log_stderr_on_success(self, proc, file, logger):
+        proc.return_value.returncode = 0
+        file.return_value = BytesIO(b'spark normal error output')
+        job = TestSparkSubmitTask()
+        job.run()
+
+        self.assertNotIn(call.info(
+            'Program stderr:\nspark normal error output'),
+            logger.mock_calls)
 
-    @patch('luigi.contrib.spark.subprocess.Popen')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
     def test_app_must_be_set(self, proc):
         with self.assertRaises(NotImplementedError):
             job = SparkSubmitTask()
             job.run()
 
-    @patch('luigi.contrib.spark.subprocess.Popen')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
     def test_app_interruption(self, proc):
 
         def interrupt():
@@ -170,7 +187,7 @@ class PySparkTaskTest(unittest.TestCase):
     ss = 'ss-stub'
 
     @with_config({'spark': {'spark-submit': ss, 'master': "spark://host:7077"}})
-    @patch('luigi.contrib.spark.subprocess.Popen')
+    @patch('luigi.contrib.external_program.subprocess.Popen')
     def test_run(self, proc):
         setup_run_process(proc)
         job = TestPySparkTask()