[SPARK-47346][PYTHON] Make daemon mode configurable when creating Python planner workers

### What changes were proposed in this pull request?

This PR adds a `useDaemon` flag to `env.createPythonWorker` so that daemon mode can be chosen per call, giving more flexibility when creating Python planner workers.
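
A minimal sketch (not part of this diff) of how a caller might use the new overload. It assumes code living inside the `org.apache.spark` package (the method is `private[spark]`) and a running `SparkEnv`; the executable, worker module, and environment values are placeholders.

```scala
import org.apache.spark.SparkEnv

// Sketch: request a dedicated, non-daemon worker for a one-off planner task.
// The Option[Long] in the result is the worker's PID when one is reported.
def launchPlannerWorker(env: SparkEnv) = {
  val (worker, pid) = env.createPythonWorker(
    "python3",                           // pythonExec (placeholder)
    "pyspark.worker",                    // workerModule (placeholder)
    Map("SPARK_BUFFER_SIZE" -> "65536"), // envVars (placeholder)
    useDaemon = false)                   // flag introduced by this PR
  (worker, pid)
}
```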

### Why are the changes needed?

To make Python workers more flexible: callers can now choose daemon or non-daemon mode per worker instead of relying on, or temporarily mutating, the global `PYTHON_USE_DAEMON` configuration (as `StreamingPythonRunner` did before this change).

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#45468 from allisonwang-db/spark-47346-py-worker.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
allisonwang-db authored and ueshin committed Mar 15, 2024
1 parent 7c81bdf commit add49b3
Showing 6 changed files with 39 additions and 26 deletions.
core/src/main/scala/org/apache/spark/SparkEnv.scala (27 changes: 23 additions & 4 deletions)
@@ -141,20 +141,39 @@ class SparkEnv (
       pythonExec: String,
       workerModule: String,
       daemonModule: String,
-      envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
+      envVars: Map[String, String],
+      useDaemon: Boolean): (PythonWorker, Option[Long]) = {
     synchronized {
       val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
-      pythonWorkers.getOrElseUpdate(key,
-        new PythonWorkerFactory(pythonExec, workerModule, daemonModule, envVars)).create()
+      val workerFactory = pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(
+        pythonExec, workerModule, daemonModule, envVars, useDaemon))
+      if (workerFactory.useDaemonEnabled != useDaemon) {
+        throw SparkException.internalError("PythonWorkerFactory is already created with " +
+          s"useDaemon = ${workerFactory.useDaemonEnabled}, but now is requested with " +
+          s"useDaemon = $useDaemon. This is not allowed to change after the PythonWorkerFactory " +
+          s"is created given the same key: $key.")
+      }
+      workerFactory.create()
     }
   }
 
+  private[spark] def createPythonWorker(
+      pythonExec: String,
+      workerModule: String,
+      envVars: Map[String, String],
+      useDaemon: Boolean): (PythonWorker, Option[Long]) = {
+    createPythonWorker(
+      pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, useDaemon)
+  }
+
   private[spark] def createPythonWorker(
       pythonExec: String,
       workerModule: String,
+      daemonModule: String,
       envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
+    val useDaemon = conf.get(Python.PYTHON_USE_DAEMON)
     createPythonWorker(
-      pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars)
+      pythonExec, workerModule, daemonModule, envVars, useDaemon)
   }
 
   private[spark] def destroyPythonWorker(
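
Illustrative note (not from the commit message): the factory is cached per `(pythonExec, workerModule, daemonModule, envVars)` key, so the new guard rejects a later request for a different daemon mode under the same key. A hypothetical sketch, again assuming code inside the `org.apache.spark` package and placeholder arguments:

```scala
import org.apache.spark.SparkEnv

// Hypothetical demonstration of the guard added in SparkEnv.createPythonWorker.
def daemonModeGuardExample(env: SparkEnv): Unit = {
  val envVars = Map.empty[String, String]
  // First call creates and caches a PythonWorkerFactory with useDaemon = true.
  env.createPythonWorker("python3", "pyspark.worker", envVars, useDaemon = true)
  // Requesting a different useDaemon value for the same key now throws
  // SparkException.internalError instead of silently reusing the cached factory.
  env.createPythonWorker("python3", "pyspark.worker", envVars, useDaemon = false)
}
```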

PythonWorkerFactory.scala

@@ -31,7 +31,6 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark._
 import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.Python._
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}
 
@@ -47,14 +46,17 @@ private[spark] class PythonWorkerFactory(
     pythonExec: String,
     workerModule: String,
     daemonModule: String,
-    envVars: Map[String, String])
+    envVars: Map[String, String],
+    val useDaemonEnabled: Boolean)
   extends Logging { self =>
 
   def this(
       pythonExec: String,
       workerModule: String,
-      envVars: Map[String, String]) =
-    this(pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars)
+      envVars: Map[String, String],
+      useDaemonEnabled: Boolean) =
+    this(pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule,
+      envVars, useDaemonEnabled)
 
   import PythonWorkerFactory._
 
@@ -63,8 +65,6 @@ private[spark] class PythonWorkerFactory(
   // currently only works on UNIX-based systems now because it uses signals for child management,
   // so we can also fall back to launching workers, pyspark/worker.py (by default) directly.
   private val useDaemon = {
-    val useDaemonEnabled = SparkEnv.get.conf.get(PYTHON_USE_DAEMON)
-
     // This flag is ignored on Windows as it's unable to fork.
     !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled
   }

StreamingPythonRunner.scala

@@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTHON_USE_DAEMON}
+import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
 
 
 private[spark] object StreamingPythonRunner {
@@ -68,17 +68,11 @@ private[spark] class StreamingPythonRunner(
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
-    val prevConf = conf.get(PYTHON_USE_DAEMON)
-    conf.set(PYTHON_USE_DAEMON, false)
-    try {
-      val workerFactory =
-        new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
-      val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true)
-      pythonWorker = Some(worker)
-      pythonWorkerFactory = Some(workerFactory)
-    } finally {
-      conf.set(PYTHON_USE_DAEMON, prevConf)
-    }
+    val workerFactory =
+      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap, false)
+    val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true)
+    pythonWorker = Some(worker)
+    pythonWorkerFactory = Some(workerFactory)
 
     val stream = new BufferedOutputStream(
       pythonWorker.get.channel.socket().getOutputStream, bufferSize)

PythonWorkerFactorySuite.scala

@@ -38,7 +38,7 @@ class PythonWorkerFactorySuite extends SparkFunSuite with SharedSparkContext {
     // E.g. the worker might fail at the beginning before it tries to connect back.
 
     val workerFactory = new PythonWorkerFactory(
-      "python3", "pyspark.testing.non_existing_worker_module", Map.empty
+      "python3", "pyspark.testing.non_existing_worker_module", Map.empty, false
     )
 
     // Create the worker in a separate thread so that if there is a bug where it does not

PythonPlannerRunner.scala

@@ -44,7 +44,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
 
   protected def receiveFromPython(dataIn: DataInputStream): T
 
-  def runInPython(): T = {
+  def runInPython(useDaemon: Boolean = SparkEnv.get.conf.get(PYTHON_USE_DAEMON)): T = {
     val env = SparkEnv.get
     val bufferSize: Int = env.conf.get(BUFFER_SIZE)
     val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
@@ -82,7 +82,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
       /* valueCompare = */ false)
 
     val (worker: PythonWorker, _) =
-      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, useDaemon)
     var releasedOrClosed = false
     val bufferStream = new DirectByteBufferOutputStream()
     try {
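
Usage note (illustrative, not part of the diff): because `useDaemon` defaults to `SparkEnv.get.conf.get(PYTHON_USE_DAEMON)`, existing callers of `runInPython()` keep their current behavior, while a caller that needs a non-daemon worker can opt out per invocation. `planner` below is a hypothetical `PythonPlannerRunner` instance.

```scala
planner.runInPython()                   // unchanged: follows PYTHON_USE_DAEMON
planner.runInPython(useDaemon = false)  // new: force a non-daemon worker for this run
```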

PythonStreamingSourceRunner.scala

@@ -81,7 +81,7 @@ class PythonStreamingSourceRunner(
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
 
     val workerFactory =
-      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
+      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap, false)
     val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true)
     pythonWorker = Some(worker)
     pythonWorkerFactory = Some(workerFactory)
