Skip to content

Commit

Permalink
Add Pool
Browse files Browse the repository at this point in the history
  • Loading branch information
chuwy committed Jul 12, 2021
1 parent ad7e781 commit 1cfb5d2
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 4 deletions.
4 changes: 3 additions & 1 deletion build.sbt
Expand Up @@ -70,6 +70,7 @@ lazy val loader = project.in(file("modules/loader"))
.settings(resolvers ++= Dependencies.resolutionRepos)
.settings(
addCompilerPlugin("com.olegpy" %% "better-monadic-for" % "0.3.1"),
Compile / scalacOptions ~= filterConsoleScalacOptions, // TODO: REMOVE
libraryDependencies ++= Seq(
Dependencies.slf4j,
Dependencies.redshift,
Expand All @@ -89,7 +90,8 @@ lazy val loader = project.in(file("modules/loader"))

Dependencies.specs2,
Dependencies.specs2ScalaCheck,
Dependencies.scalaCheck
Dependencies.scalaCheck,
Dependencies.catsTesting,
)
)
.dependsOn(common % "compile->compile;test->test", aws)
Expand Down
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2012-2021 Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Apache License Version 2.0,
* and you may not use this file except in compliance with the Apache License Version 2.0.
* You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the Apache License Version 2.0 is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the Apache License Version 2.0 for the specific language governing permissions and limitations there under.
*/
package com.snowplowanalytics.snowplow.rdbloader.db

import cats.implicits._

import cats.effect.Concurrent
import cats.effect.concurrent.{Ref, Semaphore}

import fs2.concurrent.Queue

/**
 * A connection-pool-like entity that manages acquisition and reuse of several
 * resources `R`. From the caller's point of view `use` looks like
 * `Resource#use`: a function is handed a resource and its result is returned.
 * For a `Pool`, however, the call is expected to do one of the following:
 * 1. Receive a free, previously allocated resource
 * 2. Trigger creation of a new resource if capacity allows
 * 3. Block (semantically) until another fiber gives a resource back, if
 *    capacity doesn't allow creating one
 *
 * @tparam F an effect type, usually `IO`
 * @tparam R a resource type, such as DB
 */
trait Pool[F[_], R] {
/**
 * Run `f` against a resource managed by the pool.
 *
 * @param f action to run with the resource
 * @tparam O result type of the action
 * @return the result of `f`, inside `F`
 */
def use[O](f: R => F[O]): F[O]
}

object Pool {

  /**
   * Create a queue-backed [[Pool]] that lazily allocates up to `max` resources.
   *
   * Resources are created on demand via `acquire`, handed back to an internal
   * bounded queue after a successful `use`, and destroyed via `release` when
   * the function passed to `use` fails (the original error is then re-raised).
   *
   * @param acquire effect allocating a fresh resource
   * @param release effect destroying a resource; run when `use` fails
   * @param max maximum number of fibers allowed to hold a resource at once;
   *            also the capacity of the internal queue
   * @return a pool, wrapped in `F` because its internal state is allocated effectfully
   */
  def createQ[F[_]: Concurrent, R](acquire: F[R], release: R => F[Unit], max: Int): F[Pool[F, R]] = {
    val resourceP: F[ResourceP[F, R]] =
      acquire.map(res => ResourceP[F, R](res, release(res)))

    for {
      resourceQueue <- Queue.bounded[F, ResourceP[F, R]](max)
      semaphore     <- Semaphore[F](max.toLong)
      availableR    <- Ref.of[F, Int](max)
    } yield new Pool[F, R] {
      def use[O](f: R => F[O]): F[O] = {
        // Run `f`; on success give the resource back to the queue, on failure
        // destroy it and re-raise the original error.
        // NOTE(review): `attempt` does not observe cancellation, so a resource held
        // by a cancelled fiber is neither re-enqueued nor released - confirm whether
        // callers can cancel `use`.
        def useAndReturn(r: ResourceP[F, R]): F[O] =
          Concurrent[F].attempt(f(r.resource)).flatMap {
            case Right(result) => resourceQueue.enqueue1(r).as(result)
            case Left(error)   => availableR.update(_ + 1) *> r.release *> Concurrent[F].raiseError[O](error)
          }

        // Allocation happens after the counter was decremented. If allocation itself
        // fails, the counter must be restored, otherwise `max` failed allocations would
        // push every later caller into the `available <= 0` branch with an empty queue.
        val acquireNew: F[ResourceP[F, R]] =
          Concurrent[F].handleErrorWith(resourceP) { error =>
            availableR.update(_ + 1) *> Concurrent[F].raiseError[ResourceP[F, R]](error)
          }

        val acquireAndUse: F[O] =
          resourceQueue.tryDequeue1.flatMap {
            case Some(r) =>
              useAndReturn(r)
            case None =>
              availableR.get.flatMap { available =>
                if (available <= 0) resourceQueue.dequeue1.flatMap(useAndReturn)
                else availableR.update(_ - 1) *> acquireNew.flatMap(useAndReturn) <* availableR.update(_ + 1)
              }
          }

        // The semaphore protects otherwise thread-unsafe Ref.get -> Ref.update chain
        // Otherwise max+1 fibers could run into (available <= 0) branch deadlocking dequeue.
        // `withPermit` (rather than acquire/release chained with `*>`/`<*`) guarantees
        // the permit is given back when `f` or `acquire` fails, so failures cannot
        // gradually exhaust the pool.
        semaphore.withPermit(acquireAndUse)
      }
    }
  }

  /** A resource paired with the effect that destroys it */
  private final case class ResourceP[F[_], R](resource: R, release: F[Unit])
}
@@ -0,0 +1,196 @@
/*
* Copyright (c) 2012-2021 Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Apache License Version 2.0,
* and you may not use this file except in compliance with the Apache License Version 2.0.
* You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the Apache License Version 2.0 is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the Apache License Version 2.0 for the specific language governing permissions and limitations there under.
*/
package com.snowplowanalytics.snowplow.rdbloader.db

import scala.concurrent.duration._

import cats.Applicative
import cats.implicits._

import cats.effect.{Timer, IO, Concurrent, Sync}
import cats.effect.concurrent.Ref
import cats.effect.implicits._

import com.snowplowanalytics.snowplow.rdbloader.db.Pool.{createQ, ResourceP}
import cats.effect.testing.specs2.CatsIO

import org.scalacheck.Arbitrary._
import org.scalacheck.{Gen, Shrink}
import org.specs2.mutable.Specification
import org.specs2.ScalaCheck

/** Property-based specification of `Pool.createQ`, driven through the helpers in the companion object */
class PoolSpec extends Specification with CatsIO with ScalaCheck {
  import PoolSpec.State.Message

  override val Timeout = 5.seconds

  // Declared before the spec body, because the `should` block is evaluated during construction

  /** `(resourcesN, fibersN)` pairs with at least as many fibers/jobs as resources */
  private val contendedArgsGen: Gen[(Int, Int)] =
    for {
      resN <- Gen.chooseNum(2, 10)
      fibN <- Gen.chooseNum(resN, resN * 10)
    } yield (resN, fibN)

  /** Disables shrinking: a failing case is reported with the originally generated arguments */
  private val noShrink: Shrink[(Int, Int)] =
    Shrink(_ => Stream.empty[(Int, Int)])

  "Pool" should {
    "create no more resources than max" in {
      prop { args: (Int, Int) =>
        val (resourceN, fibersN) = args
        PoolSpec
          .runConcurrentTest[IO](PoolSpec.runUseNoSleep[IO])(resourceN, fibersN)
          .map(_.getLog)
          .map { log =>
            val resources = log.collect { case Message.Using(id, _) => id }
            val iterations = log.collect { case Message.Using(_, iteration) => iteration }

            resources.toSet.size must beBetween(2, resourceN)
            iterations must haveSize(fibersN)
          }
      }.setGen(contendedArgsGen).setShrink(noShrink)
    }

    "maintain a single resource with sequential jobs" in {
      val argsGen = for {
        resN <- Gen.chooseNum(1, 100)
        jobsN <- Gen.chooseNum(1, 15)
      } yield (resN, jobsN)

      prop { args: (Int, Int) =>
        val (resourceN, jobsN) = args
        PoolSpec
          .runSyncTest[IO](PoolSpec.runUseNoSleep[IO])(resourceN, jobsN)
          .map(_.getLog)
          .map { log =>
            val resources = log.collect { case Message.Using(id, _) => id }
            val iterations = log.collect { case Message.Using(_, iteration) => iteration }

            resources.toSet.size must beEqualTo(1)
            iterations must haveSize(jobsN)
          }
      }.setGen(argsGen).setShrink(noShrink)
    }

    "acquire new resources in case of a failure" in {
      // Every 10th use logs a `Failure` and raises instead of logging `Using`
      def runUse(stateRef: Ref[IO, PoolSpec.State])(resource: Int): IO[Unit] =
        for {
          state <- stateRef.updateAndGet(_.incrementUse)
          _ <- if (state.use % 10 == 0)
            stateRef.update(_.write(Message.Failure(resource, state.use))) *>
              IO.raiseError(new RuntimeException(s"Interrupting resource $resource at ${state.use}"))
          else IO.unit
          _ <- stateRef.update(_.write(Message.Using(resource, state.use)))
        } yield ()

      prop { args: (Int, Int) =>
        val (resourceN, fibersN) = args
        PoolSpec
          .runSyncTest[IO](runUse)(resourceN, fibersN)
          .map { state =>
            val log = state.getLog
            val resources = log.collect { case Message.Using(id, _) => id }
            val iterations = log.collect { case Message.Using(_, iteration) => iteration }

            resources.toSet.size must beLessThanOrEqualTo(resourceN)
            // NOTE(review): failed uses are logged as `Failure`, not `Using`, and
            // `runSyncTest` does not recover from a raised error - confirm this
            // expectation is what is intended once fibersN reaches 10.
            iterations must haveSize(fibersN)
          }
      }.setGen(contendedArgsGen).setShrink(noShrink)
    }
  }
}

/** Test fixtures for [[PoolSpec]]: an in-memory event log and drivers that exercise a pool */
object PoolSpec {

  /**
   * Shared test state.
   *
   * @param resourceId id of the most recently acquired resource (0 before any acquisition)
   * @param use number of times a `use` function has been entered
   * @param log recorded events, newest first
   */
  final case class State(resourceId: Int, use: Int, log: List[State.Message]) {
    /** Register acquisition of a new resource */
    def increment: State =
      copy(resourceId = resourceId + 1)

    /** Register one more entry into a `use` function */
    def incrementUse: State =
      copy(use = use + 1)

    /** Events in chronological order (the log is stored newest-first) */
    def getLog: List[State.Message] =
      log.reverse

    /** Prepend an event to the log */
    def write(event: State.Message): State =
      copy(log = event :: log)
  }

  object State {
    def init: State =
      State(0, 0, Nil)

    sealed trait Message
    object Message {
      final case class Using(resourceId: Int, iteration: Int) extends Message
      final case class Released(resourceId: Int) extends Message
      final case class Failure(resourceId: Int, iteration: Int) extends Message
    }
  }

  /**
   * Build the `acquire`/`release` pair handed to the pool: acquisition bumps the
   * resource counter and returns the new id, release logs a `Released` event.
   */
  def resource[F[_]: Sync](stateRef: Ref[F, State]): (F[Int], Int => F[Unit]) = {
    val acquire = stateRef.updateAndGet(_.increment).map(_.resourceId)
    val release = (id: Int) => stateRef.update(_.write(State.Message.Released(id)))
    (acquire, release)
  }

  /** A `getSleep` strategy for `runUse` that never sleeps */
  def noSleep[F[_]: Applicative](s: State): F[Option[Long]] =
    Applicative[F].pure(None)

  /**
   * Start `fibersN` fibers that all `use` a pool of at most `resourcesN` resources,
   * wait for every fiber (ignoring individual failures) and return the final state.
   */
  def runConcurrentTest[F[_]: Concurrent: Timer](runUse: Ref[F, State] => Int => F[Unit])
                                                (resourcesN: Int, fibersN: Int): F[State] =
    for {
      state <- Ref.of[F, State](State.init)
      (acquire, release) = resource[F](state)
      _ <- createQ[F, Int](acquire, release, resourcesN).flatMap { pool =>
        pool
          .use(runUse(state))
          .start
          .replicateA(fibersN)
          .flatMap(fibers => fibers.traverse_(f => Concurrent[F].attempt(f.join)))
      }
      latest <- state.get
    } yield latest

  /**
   * Run `jobsN` uses of a pool of at most `resourcesN` resources one after another
   * and return the final state. An error raised by a job is not recovered here.
   */
  def runSyncTest[F[_]: Concurrent: Timer](runUse: Ref[F, State] => Int => F[Unit])
                                          (resourcesN: Int, jobsN: Int): F[State] =
    for {
      state <- Ref.of[F, State](State.init)
      (acquire, release) = resource[F](state)
      _ <- createQ[F, Int](acquire, release, resourcesN).flatMap { pool =>
        pool
          .use(runUse(state))
          .replicateA(jobsN)
      }
      latest <- state.get
    } yield latest

  /**
   * A `use` function that optionally sleeps before logging a `Using` event.
   *
   * @param getSleep decides, from the state after the use counter was bumped,
   *                 how many milliseconds to sleep (`None` for no sleep)
   */
  def runUse[F[_]: Timer: Sync](getSleep: State => F[Option[Long]], stateRef: Ref[F, State])(resource: Int): F[Unit] =
    for {
      state <- stateRef.updateAndGet(_.incrementUse)
      pause <- getSleep(state)
      // Sequenced with flatMap: merely mapping to the sleep effect would build it
      // and throw it away without ever sleeping
      _ <- pause.fold(Sync[F].unit)(ms => Timer[F].sleep(ms.millis))
      _ <- stateRef.update(_.write(State.Message.Using(resource, state.use)))
    } yield ()

  /** A `use` function that bumps the use counter and logs a `Using` event right away */
  def runUseNoSleep[F[_]: Sync: Timer](stateRef: Ref[F, PoolSpec.State])(resource: Int): F[Unit] =
    for {
      state <- stateRef.updateAndGet(_.incrementUse)
      _ <- stateRef.update(_.write(PoolSpec.State.Message.Using(resource, state.use)))
    } yield ()

}
8 changes: 5 additions & 3 deletions project/Dependencies.scala
Expand Up @@ -51,6 +51,7 @@ object Dependencies {

// Scala (test only)
val specs2 = "4.10.5"
val catsTesting = "0.5.3"
val scalaCheck = "1.14.3"
}

Expand Down Expand Up @@ -111,7 +112,8 @@ object Dependencies {
val aws2kinesis = "software.amazon.awssdk" % "kinesis" % V.aws2

// Scala (test only)
val specs2 = "org.specs2" %% "specs2-core" % V.specs2 % Test
val specs2ScalaCheck = "org.specs2" %% "specs2-scalacheck" % V.specs2 % Test
val scalaCheck = "org.scalacheck" %% "scalacheck" % V.scalaCheck % Test
val specs2 = "org.specs2" %% "specs2-core" % V.specs2 % Test
val specs2ScalaCheck = "org.specs2" %% "specs2-scalacheck" % V.specs2 % Test
val scalaCheck = "org.scalacheck" %% "scalacheck" % V.scalaCheck % Test
val catsTesting = "com.codecommit" %% "cats-effect-testing-specs2" % V.catsTesting % Test
}

0 comments on commit 1cfb5d2

Please sign in to comment.