Commit aa29873

Merge remote-tracking branch 'upstream/master' into decisiontree-python-new

jkbradley committed Aug 2, 2014
2 parents bf21be4 + e8e0fd6
Showing 12 changed files with 1,201 additions and 556 deletions.
1 change: 1 addition & 0 deletions .rat-excludes
@@ -55,3 +55,4 @@ dist/*
.*ipr
.*iws
logs
.*scalastyle-output.xml
69 changes: 49 additions & 20 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -25,7 +25,7 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
import scala.collection.JavaConversions._
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.Try
import scala.util.{Try, Success, Failure}

import net.razorvine.pickle.{Pickler, Unpickler}

@@ -536,25 +536,6 @@ private[spark] object PythonRDD extends Logging {
file.close()
}

/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
* It is only used by pyspark.sql.
* TODO: Support more Python types.
*/
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
unpickle.loads(row) match {
// in case the objects were pickled in batch mode
case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
baseConf: Configuration): Configuration = {
val conf = PythonHadoopUtil.mapToConf(confAsMap)
@@ -701,6 +682,54 @@ private[spark] object PythonRDD extends Logging {
}
}


/**
* Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
* This function is outdated; PySpark does not use it anymore.
*/
@deprecated
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
unpickle.loads(row) match {
// in case the objects were pickled in batch mode
case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
// not in batch mode
case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
}
}
}
}

/**
* Convert an RDD of serialized Python tuples to Arrays (no recursive conversions).
* It is only used by pyspark.sql.
*/
def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {

def toArray(obj: Any): Array[_] = {
obj match {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}
}

pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].map(toArray)
} else {
Seq(toArray(obj))
}
}
}.toJavaRDD()
}

/**
* Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
* PySpark.
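The new pythonToJavaArray expects each pickled payload to be either a single row (not batched) or a list of rows (batched), which is what the toArray helper and the flatMap above handle. A minimal Python sketch of the two layouts, using the standard pickle module in place of the Pyrolite Unpickler that the Scala side actually uses; the rows here are hypothetical:

import pickle

# Hypothetical rows standing in for what pyspark.sql ships to the JVM.
rows = [(1, "a"), (2, "b"), (3, "c")]

# Unbatched mode: one pickled row per record; the Scala side wraps the
# unpickled object in Seq(toArray(obj)).
unbatched_payloads = [pickle.dumps(row) for row in rows]

# Batched mode: a whole list of rows is pickled at once; the Scala side sees
# a java.util.ArrayList and maps toArray over its elements.
batched_payload = pickle.dumps(rows)

assert pickle.loads(unbatched_payloads[0]) == (1, "a")
assert pickle.loads(batched_payload) == rows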
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -64,10 +64,16 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

// Attempt to connect, restart and retry once if it fails
try {
new Socket(daemonHost, daemonPort)
val socket = new Socket(daemonHost, daemonPort)
val launchStatus = new DataInputStream(socket.getInputStream).readInt()
if (launchStatus != 0) {
throw new IllegalStateException("Python daemon failed to launch worker")
}
socket
} catch {
case exc: SocketException =>
logWarning("Python daemon unexpectedly quit, attempting to restart")
logWarning("Failed to open socket to Python daemon:", exc)
logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
new Socket(daemonHost, daemonPort)
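The worker launch is now acknowledged over the socket before any task data flows: the forked child writes 0 via write_int, the daemon writes -1 if the fork fails, and PythonWorkerFactory reads that value with DataInputStream.readInt and throws unless it is 0. A rough sketch of the framing, assuming write_int packs a 4-byte big-endian integer (the layout readInt expects):

import io
import struct

def write_int(value, stream):
    # Sketch of pyspark.serializers.write_int: a 4-byte big-endian frame,
    # matching what java.io.DataInputStream.readInt reads on the Scala side.
    stream.write(struct.pack("!i", value))

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

# Round trip through an in-memory buffer instead of the worker socket.
buf = io.BytesIO()
write_int(0, buf)           # daemon child: fork succeeded
buf.seek(0)
assert read_int(buf) == 0   # Scala side: launchStatus == 0, keep the socket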
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -85,6 +85,17 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
*/
private def mergeSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
if (propertiesFile == null) {
sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir =>
val sep = File.separator
val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf"
val file = new File(defaultPath)
if (file.exists()) {
propertiesFile = file.getAbsolutePath
}
}
}

if (propertiesFile == null) {
sys.env.get("SPARK_HOME").foreach { sparkHome =>
val sep = File.separator
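With this change, spark-defaults.conf is resolved in two steps: $SPARK_CONF_DIR/spark-defaults.conf if that variable is set and the file exists, otherwise the pre-existing $SPARK_HOME fallback. A small Python sketch of that precedence; it assumes the fallback lives at conf/spark-defaults.conf under SPARK_HOME (the truncated hunk above does not show the exact path) and, unlike the Scala code, checks existence for both candidates:

import os

def default_properties_file():
    # Sketch of the lookup order in SparkSubmitArguments.mergeSparkProperties:
    # prefer $SPARK_CONF_DIR/spark-defaults.conf, then fall back to SPARK_HOME.
    candidates = []
    conf_dir = os.environ.get("SPARK_CONF_DIR")
    if conf_dir:
        candidates.append(os.path.join(conf_dir, "spark-defaults.conf"))
    spark_home = os.environ.get("SPARK_HOME")
    if spark_home:
        candidates.append(os.path.join(spark_home, "conf", "spark-defaults.conf"))
    for path in candidates:
        if os.path.isfile(path):
            return os.path.abspath(path)
    return None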
179 changes: 71 additions & 108 deletions python/pyspark/daemon.py
@@ -15,64 +15,39 @@
# limitations under the License.
#

import numbers
import os
import signal
import select
import socket
import sys
import traceback
import multiprocessing
from ctypes import c_bool
from errno import EINTR, ECHILD
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from pyspark.worker import main as worker_main
from pyspark.serializers import write_int

try:
POOLSIZE = multiprocessing.cpu_count()
except NotImplementedError:
POOLSIZE = 4

exit_flag = multiprocessing.Value(c_bool, False)


def should_exit():
global exit_flag
return exit_flag.value


def compute_real_exit_code(exit_code):
# SystemExit's code can be integer or string, but os._exit only accepts integers
import numbers
if isinstance(exit_code, numbers.Integral):
return exit_code
else:
return 1


def worker(listen_sock):
def worker(sock):
"""
Called by a worker process after the fork().
"""
# Redirect stdout to stderr
os.dup2(2, 1)
sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1

# Manager sends SIGHUP to request termination of workers in the pool
def handle_sighup(*args):
assert should_exit()
signal.signal(SIGHUP, handle_sighup)

# Cleanup zombie children
def handle_sigchld(*args):
pid = status = None
try:
while (pid, status) != (0, 0):
pid, status = os.waitpid(0, os.WNOHANG)
except EnvironmentError as err:
if err.errno == EINTR:
# retry
handle_sigchld()
elif err.errno != ECHILD:
raise
signal.signal(SIGCHLD, handle_sigchld)
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)

# Blocks until the socket is closed by draining the input stream
# until it raises an exception or returns EOF.
@@ -85,55 +60,23 @@ def waitSocketClose(sock):
except:
pass

# Handle clients
while not should_exit():
# Wait until a client arrives or we have to exit
sock = None
while not should_exit() and sock is None:
try:
sock, addr = listen_sock.accept()
except EnvironmentError as err:
if err.errno != EINTR:
raise

if sock is not None:
# Fork a child to handle the client.
# The client is handled in the child so that the manager
# never receives SIGCHLD unless a worker crashes.
if os.fork() == 0:
# Leave the worker pool
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
listen_sock.close()
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
worker_main(infile, outfile)
except SystemExit as exc:
exit_code = exc.code
finally:
outfile.flush()
# The Scala side will close the socket upon task completion.
waitSocketClose(sock)
os._exit(compute_real_exit_code(exit_code))
else:
sock.close()


def launch_worker(listen_sock):
if os.fork() == 0:
try:
worker(listen_sock)
except Exception as err:
traceback.print_exc()
os._exit(1)
else:
assert should_exit()
os._exit(0)
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
write_int(0, outfile) # Acknowledge that the fork was successful
outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
exit_code = exc.code
finally:
outfile.flush()
# The Scala side will close the socket upon task completion.
waitSocketClose(sock)
os._exit(compute_real_exit_code(exit_code))


def manager():
@@ -143,29 +86,28 @@ def manager():
# Create a listening socket on the AF_INET loopback interface
listen_sock = socket.socket(AF_INET, SOCK_STREAM)
listen_sock.bind(('127.0.0.1', 0))
listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
listen_sock.listen(max(1024, SOMAXCONN))
listen_host, listen_port = listen_sock.getsockname()
write_int(listen_port, sys.stdout)

# Launch initial worker pool
for idx in range(POOLSIZE):
launch_worker(listen_sock)
listen_sock.close()

def shutdown():
global exit_flag
exit_flag.value = True
def shutdown(code):
signal.signal(SIGTERM, SIG_DFL)
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
exit(code)

# Gracefully exit on SIGTERM, don't die on SIGHUP
signal.signal(SIGTERM, lambda signum, frame: shutdown())
signal.signal(SIGHUP, SIG_IGN)
def handle_sigterm(*args):
shutdown(1)
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP

# Cleanup zombie children
def handle_sigchld(*args):
try:
pid, status = os.waitpid(0, os.WNOHANG)
if status != 0 and not should_exit():
raise RuntimeError("worker crashed: %s, %s" % (pid, status))
if status != 0:
msg = "worker %s crashed abruptly with exit status %s" % (pid, status)
print >> sys.stderr, msg
except EnvironmentError as err:
if err.errno not in (ECHILD, EINTR):
raise
@@ -174,20 +116,41 @@ def handle_sigchld(*args):
# Initialization complete
sys.stdout.close()
try:
while not should_exit():
while True:
try:
# Spark tells us to exit by closing stdin
if os.read(0, 512) == '':
shutdown()
except EnvironmentError as err:
if err.errno != EINTR:
shutdown()
ready_fds = select.select([0, listen_sock], [], [])[0]
except select.error as ex:
if ex[0] == EINTR:
continue
else:
raise
if 0 in ready_fds:
# Spark told us to exit by closing stdin
shutdown(0)
if listen_sock in ready_fds:
sock, addr = listen_sock.accept()
# Launch a worker process
try:
fork_return_code = os.fork()
if fork_return_code == 0:
listen_sock.close()
try:
worker(sock)
except:
traceback.print_exc()
os._exit(1)
else:
os._exit(0)
else:
sock.close()
except OSError as e:
print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
write_int(-1, outfile) # Signal that the fork failed
outfile.flush()
sock.close()
finally:
signal.signal(SIGTERM, SIG_DFL)
exit_flag.value = True
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
shutdown(1)


if __name__ == '__main__':
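Taken together, the rewritten manager drops the pre-forked pool and the shared exit_flag: it select()s on stdin and the listening socket, treats a readable stdin as the shutdown signal from Spark, and forks one short-lived worker per accepted connection. A stripped-down, standalone sketch of that loop, with a hypothetical handle_client in place of pyspark.worker.main:

import os
import select
import socket

def serve(handle_client):
    # Minimal sketch of the daemon's new select-and-fork loop;
    # handle_client(sock) is a hypothetical stand-in for pyspark.worker.main.
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listen_sock.bind(("127.0.0.1", 0))
    listen_sock.listen(socket.SOMAXCONN)
    while True:
        ready = select.select([0, listen_sock], [], [])[0]
        if 0 in ready:
            break                       # Spark signals shutdown by closing stdin
        if listen_sock in ready:
            sock, _ = listen_sock.accept()
            if os.fork() == 0:          # child: serve exactly one client, then exit
                listen_sock.close()
                handle_client(sock)
                os._exit(0)
            sock.close()                # parent: keep accepting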
(Remaining changed files not loaded.)
