Skip to content

Commit

Permalink
Fixes in preparation for ZIO 2.0.15 (#35320)
Browse files Browse the repository at this point in the history
* batch event loop fix

* fix worker timeout interruption. cleanup code

GitOrigin-RevId: 31023194ccc982dfe8720620bfbfe26dc274918c
  • Loading branch information
leonbur authored and wix-oss committed Oct 5, 2023
1 parent c0b828d commit f8c6082
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,49 +104,46 @@ class ParallelConsumerIT extends BaseTestWithSharedEnv[Env, TestResources] {
fastMessages = allMessages - 1
drainTimeout = 5.seconds

keyWithSlowHandling = "slow-key"
numProcessedMessges <- Ref.make[Int](0)
fastMessagesLatch <- CountDownLatch.make(fastMessages)
numProcessedMessages <- Ref.make[Int](0)
fastMessagesLatch <- CountDownLatch.make(fastMessages)

randomKeys <- ZIO.foreach(1 to fastMessages)(i => randomKey(i.toString)).map(_.toSeq)

fastRecords = randomKeys.map { key => recordWithKey(topic, key, partition) }
slowRecord = recordWithKey(topic, keyWithSlowHandling, partition)
slowRecord = recordWithoutKey(topic, partition)

finishRebalance <- Promise.make[Nothing, Unit]

// handler that sleeps only on the slow key
handler = RecordHandler { cr: ConsumerRecord[Chunk[Byte], Chunk[Byte]] =>
(cr.key match {
case Some(k) if k == Chunk.fromArray(keyWithSlowHandling.getBytes) =>
// make sure the handler doesn't finish before the rebalance is done, including drain timeout
finishRebalance.await *> ZIO.sleep(drainTimeout + 1.second)
case _ => fastMessagesLatch.countDown
}) *> numProcessedMessges.update(_ + 1)
}
_ <-
for {
consumer <- makeParallelConsumer(handler, kafka, topic, group, cId, drainTimeout = drainTimeout, startPaused = true)
_ <- produceRecords(producer, Seq(slowRecord))
_ <- produceRecords(producer, fastRecords)
_ <- ZIO.sleep(2.seconds)
// produce is done synchronously to make sure all records are produced before consumer starts, so all records are polled at once
_ <- consumer.resume
_ <- fastMessagesLatch.await
_ <- ZIO.sleep(3.second) // sleep to ensure commit is done before rebalance
// start another consumer to trigger a rebalance before slow handler is done
_ <- makeParallelConsumer(
handler,
kafka,
topic,
group,
cId,
drainTimeout = drainTimeout,
onAssigned = _ => finishRebalance.succeed()
)
} yield ()

_ <- eventuallyZ(numProcessedMessges.get, 25.seconds)(_ == allMessages)
handler = RecordHandler { cr: ConsumerRecord[Chunk[Byte], Chunk[Byte]] =>
(cr.key match {
case Some(_) =>
fastMessagesLatch.countDown
case None =>
// make sure the handler doesn't finish before the rebalance is done, including drain timeout
finishRebalance.await *> ZIO.sleep(drainTimeout + 5.second)
}) *> numProcessedMessages.update(_ + 1)
}
consumer <- makeParallelConsumer(handler, kafka, topic, group, cId, drainTimeout = drainTimeout, startPaused = true)
_ <- produceRecords(producer, Seq(slowRecord))
_ <- produceRecords(producer, fastRecords)
_ <- ZIO.sleep(2.seconds)
// produce is done synchronously to make sure all records are produced before consumer starts, so all records are polled at once
_ <- consumer.resume
_ <- fastMessagesLatch.await
_ <- ZIO.sleep(3.second) // sleep to ensure commit is done before rebalance
// start another consumer to trigger a rebalance before slow handler is done
_ <- makeParallelConsumer(
handler,
kafka,
topic,
group,
cId,
drainTimeout = drainTimeout,
onAssigned = _ => finishRebalance.succeed()
)

_ <- eventuallyZ(numProcessedMessages.get, 25.seconds)(_ == allMessages)
} yield {
ok
}
Expand Down Expand Up @@ -319,6 +316,9 @@ class ParallelConsumerIT extends BaseTestWithSharedEnv[Env, TestResources] {
private def recordWithKey(topic: String, key: String, partition: Int) =
ProducerRecord(topic, "", Some(key), partition = Some(partition))

private def recordWithoutKey(topic: String, partition: Int) =
ProducerRecord(topic, "", None, partition = Some(partition))

private def randomKey(prefix: String) =
randomId.map(r => s"$prefix-$r")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,18 @@ object Dispatcher {
ZIO
.foreachParDiscard(workers) {
case (partition, worker) =>
report(StoppingWorker(group, clientId, partition, drainTimeout.toMillis, consumerAttributes)) *>
workersShutdownRef.get.flatMap(_.get(partition).fold(ZIO.unit)(promise => promise.onShutdown.shuttingDown)) *>
worker.shutdown
.catchSomeCause {
case _: Cause[InterruptedException] => ZIO.unit
} // happens on revoke - must not fail on it so we have visibility to worker completion
.timed
.map(_._1)
.flatMap(duration => report(WorkerStopped(group, clientId, partition, duration.toMillis, consumerAttributes)))
for {
_ <- report(StoppingWorker(group, clientId, partition, drainTimeout.toMillis, consumerAttributes))
workersShutdownMap <- workersShutdownRef.get
_ <- workersShutdownMap.get(partition).fold(ZIO.unit)(promise => promise.onShutdown.shuttingDown)
duration <- worker.shutdown
.catchSomeCause {
case _: Cause[InterruptedException] => ZIO.unit
} // happens on revoke - must not fail on it so we have visibility to worker completion
.timed
.map(_._1)
_ <- report(WorkerStopped(group, clientId, partition, duration.toMillis, consumerAttributes))
} yield ()
}
.resurrect
.ignore
Expand Down Expand Up @@ -324,7 +327,7 @@ object Dispatcher {
override def shutdown: URIO[Any, Unit] =
for {
_ <- internalState.update(_.shutdown).commit
timeout <- fiber.join.ignore.disconnect.timeout(drainTimeout)
timeout <- fiber.join.ignore.interruptible.timeout(drainTimeout)
_ <- ZIO.when(timeout.isEmpty)(fiber.interruptFork)
} yield ()

Expand Down Expand Up @@ -404,19 +407,26 @@ object Dispatcher {
case DispatcherState.Running =>
queue.poll.flatMap {
case Some(record) =>
report(TookRecordFromQueue(record, group, clientId, consumerAttributes)) *>
ZIO
.attempt(currentTimeMillis())
.flatMap(t => internalState.updateAndGet(_.startedWith(t)).commit)
.tapBoth(
e => report(FailToUpdateCurrentExecutionStarted(record, group, clientId, consumerAttributes, e)),
t => report(CurrentExecutionStartedEvent(partition, group, clientId, t.currentExecutionStarted))
) *> handle(record).interruptible.ignore *> isActive(internalState)
case None => isActive(internalState).delay(5.millis)
for {
_ <- report(TookRecordFromQueue(record, group, clientId, consumerAttributes))
clock <- ZIO.clock
executionStartTime <- clock.currentTime(TimeUnit.MILLISECONDS)
_ <- internalState
.updateAndGet(_.startedWith(executionStartTime))
.commit
_ <- report(CurrentExecutionStartedEvent(partition, group, clientId, Some(executionStartTime)))
_ <- handle(record).interruptible.ignore
active <- isActive(internalState)
} yield active
case None =>
isActive(internalState).delay(5.millis)
}
case DispatcherState.Paused(resume) =>
report(WorkerWaitingForResume(group, clientId, partition, consumerAttributes)) *> resume.await.timeout(30.seconds) *>
isActive(internalState)
for {
_ <- report(WorkerWaitingForResume(group, clientId, partition, consumerAttributes))
_ <- resume.await.timeout(30.seconds)
active <- isActive(internalState)
} yield active
case DispatcherState.ShuttingDown =>
ZIO.succeed(false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,21 @@ private[greyhound] class BatchEventLoopImpl[R](
)
)

private def pollAndHandle()(implicit trace: Trace): URIO[R, Unit] = for {
_ <- pauseAndResume().provide(ZLayer.succeed(capturedR))
records <-
consumer
.poll(config.fetchTimeout)
.provide(ZLayer.succeed(capturedR))
.catchAll(_ => ZIO.succeed(Nil))
.flatMap(records => seekRequests.get.map(seeks => records.filterNot(record => seeks.keys.toSet.contains(record.topicPartition))))
_ <- handleRecords(records).timed
.tap { case (duration, _) => report(FullBatchHandled(clientId, group, records.toSeq, duration, consumerAttributes)) }
private def pollAndHandle()(implicit trace: Trace): URIO[R with GreyhoundMetrics, Unit] = for {
_ <- pauseAndResume().ignore
allRecords <- consumer
.poll(config.fetchTimeout)
.catchAll(_ => ZIO.succeed(Nil))
seeks <- seekRequests.get.map(_.keySet)
records = allRecords.filterNot(record => seeks.contains(record.topicPartition))
_ <- handleRecords(records).timed
.tap { case (duration, _) => report(FullBatchHandled(clientId, group, records.toSeq, duration, consumerAttributes)) }
} yield ()

private def pauseAndResume()(implicit trace: Trace) = for {
pr <- elState.shouldPauseAndResume()
_ <- ZIO.when(pr.toPause.nonEmpty)((consumer.pause(pr.toPause) *> elState.partitionsPaused(pr.toPause)).ignore)
_ <- ZIO.when(pr.toResume.nonEmpty)((consumer.resume(pr.toResume) *> elState.partitionsResumed(pr.toResume)).ignore)
_ <- ZIO.when(pr.toPause.nonEmpty)(consumer.pause(pr.toPause) *> elState.partitionsPaused(pr.toPause))
_ <- ZIO.when(pr.toResume.nonEmpty)(consumer.resume(pr.toResume) *> elState.partitionsResumed(pr.toResume))
} yield ()

private def handleRecords(polled: Records)(implicit trace: Trace): ZIO[R, Nothing, Unit] = {
Expand Down Expand Up @@ -512,10 +511,10 @@ private[greyhound] class BatchEventLoopState(
partitionsPaused(pauseResume.toPause) *> partitionsResumed(pauseResume.toResume)

def shouldPauseAndResume[R]()(implicit trace: Trace): URIO[R, PauseResume] = for {
pending <- allPending
pending <- allPending.map(_.keySet)
paused <- pausedPartitions
toPause = pending.keySet -- paused
toResume = paused -- pending.keySet
toPause = pending -- paused
toResume = paused -- pending
} yield PauseResume(toPause, toResume)

def appendPending(records: Consumer.Records)(implicit trace: Trace): UIO[Unit] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
ZIO.scoped(BatchEventLoop.make(group, ConsumerSubscription.topics(topics: _*), consumer, handler, clientId, retry).flatMap {
loop =>
for {
_ <- ZIO.succeed(println(s"Should not retry for retry: $retry, cause: $cause"))
_ <- ZIO.debug(s"Should not retry for retry: $retry, cause: $cause")
_ <- givenHandleError(failOnPartition(0, cause))
_ <- givenRecords(consumerRecords)
handledRecords <- handled.await(_.nonEmpty)
Expand All @@ -79,7 +79,7 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
ZIO.scoped(BatchEventLoop.make(group, ConsumerSubscription.topics(topics: _*), consumer, handler, clientId, Some(retry)).flatMap {
loop =>
for {
_ <- ZIO.succeed(println(s"Should retry for cause: $cause"))
_ <- ZIO.debug(s"Should retry for cause: $cause")
_ <- givenHandleError(failOnPartition(0, cause))
_ <- givenRecords(consumerRecords)
handled1 <- handled.await(_.nonEmpty)
Expand Down Expand Up @@ -153,13 +153,14 @@ class BatchEventLoopTest extends JUnitRunnableSpec {

val consumer = new EmptyConsumer {
override def poll(timeout: Duration)(implicit trace: Trace): Task[Records] =
queue.take
queue.take.interruptible
.timeout(timeout)
.map(_.getOrElse(Iterable.empty))
.tap(r => ZIO.succeed(println(s"poll($timeout): $r")))
.tap(r => ZIO.debug(s"poll($timeout): $r"))


override def commit(offsets: Map[TopicPartition, Offset])(implicit trace: Trace): Task[Unit] = {
ZIO.succeed(println(s"commit($offsets)")) *> committedOffsetsRef.update(_ ++ offsets)
ZIO.debug(s"commit($offsets)") *> committedOffsetsRef.update(_ ++ offsets)
}

override def commitWithMetadata(offsetsAndMetadata: Map[TopicPartition, OffsetAndMetadata])(
Expand Down Expand Up @@ -190,14 +191,16 @@ class BatchEventLoopTest extends JUnitRunnableSpec {
}
}

val handler = new BatchRecordHandler[Any, Throwable, Chunk[Byte], Chunk[Byte]] {
override def handle(records: RecordBatch): ZIO[Any, HandleError[Throwable], Any] = {
ZIO.succeed(println(s"handle($records)")) *>
(handlerErrorsRef.get.flatMap(he => he(records.records).fold(ZIO.unit: IO[HandleError[Throwable], Unit])(ZIO.failCause(_))) *>
handled.update(_ :+ records.records))
.tapErrorCause(e => ZIO.succeed(println(s"handle failed with $e, records: $records")))
.tap(_ => ZIO.succeed(println(s"handled $records")))
}
val handler = new BatchRecordHandler[Any, Throwable, Chunk[Byte], Chunk[Byte]] {
override def handle(records: RecordBatch): ZIO[Any, HandleError[Throwable], Any] = for {
_ <- ZIO.debug(s"handle($records)")
he <- handlerErrorsRef.get
_ <- he(records.records).fold(ZIO.unit: IO[HandleError[Throwable], Unit])(ZIO.failCause(_))
_ <- handled
.update(_ :+ records.records)
.tapErrorCause(e => ZIO.debug(s"handle failed with $e, records: $records"))
.tap(_ => ZIO.debug(s"handled $records"))
} yield ()
}

def givenRecords(records: Seq[Consumer.Record]) = queue.offer(records)
Expand Down

0 comments on commit f8c6082

Please sign in to comment.