diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 95449eb43e581..10c9950051bb0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -687,13 +687,25 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
 
     @ignore_unicode_prefix
-    def pipe(self, command, env={}):
+    def pipe(self, command, env={}, mode='permissive'):
         """
         Return an RDD created by piping elements to a forked external process.
 
         >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
         [u'1', u'2', u'', u'3']
         """
+        if mode == 'permissive':
+            def fail_condition(x):
+                return False
+        elif mode == 'strict':
+            def fail_condition(x):
+                return x != 0
+        elif mode == 'grep':
+            def fail_condition(x):
+                return x != 0 and x != 1
+        else:
+            raise ValueError("mode must be one of 'permissive', 'strict' or 'grep'.")
+
         def func(iterator):
             pipe = Popen(
                 shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
@@ -707,7 +719,7 @@ def pipe_objs(out):
 
         def check_return_code():
             pipe.wait()
-            if pipe.returncode:
+            if fail_condition(pipe.returncode):
                 raise Exception("Pipe function `%s' exited "
                                 "with error code %d" % (command, pipe.returncode))
             else:
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index ca0fca2972860..42a14bf6dd292 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -878,10 +878,14 @@ def test_pipe_functions(self):
         data = ['1', '2', '3']
         rdd = self.sc.parallelize(data)
         with QuietTest(self.sc):
-            self.assertRaises(Py4JJavaError, rdd.pipe('cc').collect)
+            self.assertEqual([], rdd.pipe('cc').collect())
+            self.assertRaises(Py4JJavaError, rdd.pipe('cc', mode='strict').collect)
             result = rdd.pipe('cat').collect()
             result.sort()
             [self.assertEqual(x, y) for x, y in zip(data, result)]
+            self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', mode='strict').collect)
+            self.assertEqual([], rdd.pipe('grep 4').collect())
+            self.assertEqual([], rdd.pipe('grep 4', mode='grep').collect())
 
 
 class ProfilerTests(PySparkTestCase):