[SPARK-46895][CORE] Replace Timer with single thread scheduled executor
### What changes were proposed in this pull request?
This PR proposes to replace `Timer` with a single-thread scheduled executor.

### Why are the changes needed?
The javadoc recommends `ScheduledThreadPoolExecutor` instead of `Timer`.
![Screenshot 2024-01-12 at 12.47.57 PM](https://github.com/apache/spark/assets/8486025/4fc5ed61-6bb9-4768-915a-ad919a067d04)

This change is based on the following two points.
**System time sensitivity**

`Timer` scheduling is based on the operating system's absolute time and is sensitive to it: once the system time changes, `Timer` scheduling is no longer precise.
`ScheduledThreadPoolExecutor` scheduling is based on relative time and is not affected by changes to the operating system's time.
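
As a minimal standalone sketch of that difference (illustrative only, not part of this PR's diff): `Timer` turns a delay into an absolute wall-clock target, while a scheduled executor measures the delay against the monotonic `System.nanoTime` clock.

```scala
import java.util.{Timer, TimerTask}
import java.util.concurrent.{Executors, TimeUnit}

object SchedulingSketch {
  def main(args: Array[String]): Unit = {
    // Timer computes the next execution time as an absolute
    // System.currentTimeMillis target, so a wall-clock jump (e.g. an NTP
    // correction) shifts when the task actually fires.
    val timer = new Timer("timer-demo", true)
    timer.schedule(new TimerTask {
      override def run(): Unit = println("Timer fired")
    }, 1000L)

    // A scheduled executor measures the delay with System.nanoTime, which is
    // monotonic, so changes to the wall clock do not affect the schedule.
    val executor = Executors.newSingleThreadScheduledExecutor()
    executor.schedule(new Runnable {
      override def run(): Unit = println("executor fired")
    }, 1L, TimeUnit.SECONDS)

    Thread.sleep(2000L)
    timer.cancel()
    executor.shutdown()
  }
}
```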

**Whether exceptions are caught**

`Timer` does not catch exceptions thrown by `TimerTask`s, and in addition, `Timer` is single-threaded: once a scheduled task throws an exception, the entire thread terminates and the other scheduled tasks are never executed.
`ScheduledThreadPoolExecutor` implements scheduling on top of a thread pool, so after one task throws an exception, the other tasks still execute normally.
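
A minimal sketch of the failure mode (again illustrative, not from this PR's diff): after one `TimerTask` throws, the single `Timer` thread is dead and further scheduling fails, whereas a scheduled executor replaces the failed worker and keeps running later tasks.

```scala
import java.util.{Timer, TimerTask}
import java.util.concurrent.{Executors, TimeUnit}

object ExceptionSketch {
  def main(args: Array[String]): Unit = {
    val timer = new Timer("timer-demo")
    timer.schedule(new TimerTask {
      override def run(): Unit = throw new RuntimeException("boom")
    }, 0L)
    Thread.sleep(200L)
    try {
      // The uncaught exception killed the single Timer thread, so this
      // throws IllegalStateException: "Timer already cancelled."
      timer.schedule(new TimerTask {
        override def run(): Unit = println("never runs")
      }, 0L)
    } catch {
      case e: IllegalStateException => println(s"Timer is unusable: $e")
    }

    val executor = Executors.newSingleThreadScheduledExecutor()
    executor.execute(new Runnable {
      override def run(): Unit = throw new RuntimeException("boom")
    })
    // The pool replaces the failed worker thread, so later tasks still run.
    executor.schedule(new Runnable {
      override def run(): Unit = println("still runs")
    }, 100L, TimeUnit.MILLISECONDS)
    Thread.sleep(300L)
    executor.shutdown()
  }
}
```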

### Does this PR introduce _any_ user-facing change?
'No'.

### How was this patch tested?
GA tests.

### Was this patch authored or co-authored using generative AI tooling?
'No'.

Closes apache#44718 from beliefer/replace-timer-with-threadpool.

Authored-by: beliefer <beliefer@163.com>
Signed-off-by: yangjie01 <yangjie01@baidu.com>
beliefer authored and LuciferYang committed Feb 6, 2024
1 parent cfbf3c7 commit 5d5b3a5
Showing 6 changed files with 47 additions and 28 deletions.
11 changes: 7 additions & 4 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark
 
-import java.util.{Timer, TimerTask}
-import java.util.concurrent.ConcurrentHashMap
+import java.util.TimerTask
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.function.Consumer
 
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
+import org.apache.spark.util.ThreadUtils
 
 /**
  * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus
@@ -51,7 +52,8 @@ private[spark] class BarrierCoordinator(
 
   // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to
   // fetch result, we shall fix the issue.
-  private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer")
+  private lazy val timer = ThreadUtils.newSingleThreadScheduledExecutor(
+    "BarrierCoordinator barrier epoch increment timer")
 
   // Listen to StageCompleted event, clear corresponding ContextBarrierState.
   private val listener = new SparkListener {
@@ -77,6 +79,7 @@ private[spark] class BarrierCoordinator(
       states.forEachValue(1, clearStateConsumer)
       states.clear()
       listenerBus.removeListener(listener)
+      ThreadUtils.shutdown(timer)
     } finally {
       super.onStop()
     }
@@ -168,7 +171,7 @@ private[spark] class BarrierCoordinator(
       // we may timeout for the sync.
       if (requesters.isEmpty) {
         initTimerTask(this)
-        timer.schedule(timerTask, timeoutInSecs * 1000)
+        timer.schedule(timerTask, timeoutInSecs, TimeUnit.SECONDS)
       }
       // Add the requester to array of RPCCallContexts pending for reply.
       requesters += requester
14 changes: 10 additions & 4 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark
 
-import java.util.{Properties, Timer, TimerTask}
+import java.util.{Properties, TimerTask}
+import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}
 
 import scala.concurrent.duration._
 import scala.jdk.CollectionConverters._
@@ -70,8 +71,8 @@ class BarrierTaskContext private[spark] (
           s"current barrier epoch is $barrierEpoch.")
       }
     }
-    // Log the update of global sync every 60 seconds.
-    timer.schedule(timerTask, 60000, 60000)
+    // Log the update of global sync every 1 minute.
+    timer.scheduleAtFixedRate(timerTask, 1, 1, TimeUnit.MINUTES)
 
     try {
       val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
@@ -283,6 +284,11 @@ object BarrierTaskContext {
   @Since("2.4.0")
   def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]
 
-  private val timer = new Timer("Barrier task timer for barrier() calls.")
+  private val timer = {
+    val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+      "Barrier task timer for barrier() calls.")
+    assert(executor.isInstanceOf[ScheduledThreadPoolExecutor])
+    executor.asInstanceOf[ScheduledThreadPoolExecutor]
+  }
 
 }
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
-import java.util.{Timer, TimerTask}
+import java.util.TimerTask
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
 
@@ -135,7 +135,8 @@ private[spark] class TaskSchedulerImpl(
 
   @volatile private var hasReceivedTask = false
   @volatile private var hasLaunchedTask = false
-  private val starvationTimer = new Timer("task-starvation-timer", true)
+  private val starvationTimer = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+    "task-starvation-timer")
 
   // Incrementing task IDs
   val nextTaskId = new AtomicLong(0)
@@ -166,7 +167,7 @@ private[spark] class TaskSchedulerImpl(
 
   protected val executorIdToHost = new HashMap[String, String]
 
-  private val abortTimer = new Timer("task-abort-timer", true)
+  private val abortTimer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-abort-timer")
   // Exposed for testing
   val unschedulableTaskSetToExpiryTime = new HashMap[TaskSetManager, Long]
 
@@ -282,7 +283,7 @@ private[spark] class TaskSchedulerImpl(
             this.cancel()
           }
         }
-      }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
+      }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS, TimeUnit.MILLISECONDS)
     }
     hasReceivedTask = true
   }
@@ -737,7 +738,7 @@ private[spark] class TaskSchedulerImpl(
       logInfo(s"Waiting for $timeout ms for completely " +
         s"excluded task to be schedulable again before aborting stage ${taskSet.stageId}.")
       abortTimer.schedule(
-        createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout)
+        createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout, TimeUnit.MILLISECONDS)
     }
 
   private def createUnschedulableTaskSetAbortTimer(
@@ -963,8 +964,8 @@ private[spark] class TaskSchedulerImpl(
         barrierCoordinator.stop()
       }
     }
-    starvationTimer.cancel()
-    abortTimer.cancel()
+    ThreadUtils.shutdown(starvationTimer)
+    ThreadUtils.shutdown(abortTimer)
   }
 
   override def defaultParallelism(): Int = backend.defaultParallelism()
11 changes: 4 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.ui
 
-import java.util.concurrent.{Executors, TimeUnit}
-
-import com.google.common.util.concurrent.ThreadFactoryBuilder
+import java.util.concurrent.TimeUnit
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.UI._
 import org.apache.spark.status.api.v1.StageData
+import org.apache.spark.util.ThreadUtils
 
 /**
  * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the
@@ -48,9 +47,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
   private var lastProgressBar = ""
 
   // Schedule a refresh thread to run periodically
-  private val threadFactory =
-    new ThreadFactoryBuilder().setDaemon(true).setNameFormat("refresh progress").build()
-  private val timer = Executors.newSingleThreadScheduledExecutor(threadFactory)
+  private val timer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("refresh progress")
   timer.scheduleAtFixedRate(
     () => refresh(), firstDelayMSec, updatePeriodMSec, TimeUnit.MILLISECONDS)
 
@@ -124,5 +121,5 @@
    * Tear down the timer thread. The timer thread is a GC root, and it retains the entire
    * SparkContext if it's not terminated.
    */
-  def stop(): Unit = timer.shutdown()
+  def stop(): Unit = ThreadUtils.shutdown(timer)
 }
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -160,7 +160,7 @@ private[spark] object ThreadUtils {
   }
 
   /**
-   * Wrapper over newSingleThreadExecutor.
+   * Wrapper over newFixedThreadPool with single daemon thread.
    */
   def newDaemonSingleThreadExecutor(threadName: String): ThreadPoolExecutor = {
     val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
@@ -189,7 +189,7 @@ private[spark] object ThreadUtils {
   }
 
   /**
-   * Wrapper over ScheduledThreadPoolExecutor.
+   * Wrapper over ScheduledThreadPoolExecutor the pool with daemon threads.
    */
   def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = {
     val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
@@ -200,6 +200,18 @@ private[spark] object ThreadUtils {
     executor
   }
 
+  /**
+   * Wrapper over ScheduledThreadPoolExecutor the pool with non-daemon threads.
+   */
+  def newSingleThreadScheduledExecutor(threadName: String): ScheduledThreadPoolExecutor = {
+    val threadFactory = new ThreadFactoryBuilder().setNameFormat(threadName).build()
+    val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
+    // By default, a cancelled task is not automatically removed from the work queue until its delay
+    // elapses. We have to enable it manually.
+    executor.setRemoveOnCancelPolicy(true)
+    executor
+  }
+
   /**
    * Wrapper over ScheduledThreadPoolExecutor.
    */
launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java
@@ -225,7 +225,7 @@ private void acceptConnections() {
     try {
       while (running) {
         final Socket client = server.accept();
-        TimerTask timeout = new TimerTask() {
+        TimerTask timerTask = new TimerTask() {
          @Override
          public void run() {
            LOG.warning("Timed out waiting for hello message from client.");
@@ -236,7 +236,7 @@ public void run() {
          }
        }
      };
-      ServerConnection clientConnection = new ServerConnection(client, timeout);
+      ServerConnection clientConnection = new ServerConnection(client, timerTask);
      Thread clientThread = factory.newThread(clientConnection);
      clientConnection.setConnectionThread(clientThread);
      synchronized (clients) {
@@ -247,9 +247,9 @@ public void run() {
      // 0 is used for testing to avoid issues with clock resolution / thread scheduling,
      // and force an immediate timeout.
      if (timeoutMs > 0) {
-        timeoutTimer.schedule(timeout, timeoutMs);
+        timeoutTimer.schedule(timerTask, timeoutMs);
      } else {
-        timeout.run();
+        timerTask.run();
      }

      clientThread.start();
