Skip to content

Commit

Permalink
STM - Next level (#863)
Browse files Browse the repository at this point in the history
* STM optimization

* Further optimizations

* Missing parens
  • Loading branch information
jdegoes committed May 17, 2019
1 parent e42a962 commit dfbb551
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 448 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ lazy val root = project
.aggregate(
coreJVM,
coreJS,
docs,
streamsJVM,
streamsJS,
interopSharedJVM,
Expand All @@ -50,8 +51,7 @@ lazy val root = project
interopReactiveStreamsJVM,
interopTwitterJVM,
benchmarks,
testkitJVM,
docs
testkitJVM
)
.enablePlugins(ScalaJSPlugin)

Expand Down
138 changes: 23 additions & 115 deletions core/jvm/src/test/scala/scalaz/zio/stm/STMSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,26 @@ final class STMSpec(implicit ee: org.specs2.concurrent.ExecutionEnv) extends Tes

def e15 =
unsafeRun(
unsafeRun(
for {
tVar <- TRef.makeCommit(0)
fiber <- ZIO.forkAll(List.fill(10)(incrementVarN(99, tVar)))
_ <- fiber.join
} yield tVar.get
).commit
for {
tVar <- TRef.makeCommit(0)
fiber <- ZIO.forkAll(List.fill(10)(incrementVarN(99, tVar)))
_ <- fiber.join
value <- tVar.get.commit
} yield value
) must_=== 1000

def e16 =
unsafeRun(
unsafeRun(
for {
tVars <- STM
.atomically(
TRef.make(10000) <*> TRef.make(0) <*> TRef.make(0)
)
tvar1 <*> tvar2 <*> tvar3 = tVars
fiber <- ZIO.forkAll(List.fill(10)(compute3VarN(99, tvar1, tvar2, tvar3)))
_ <- fiber.join
} yield tvar3.get
).commit
for {
tVars <- STM
.atomically(
TRef.make(10000) <*> TRef.make(0) <*> TRef.make(0)
)
tvar1 <*> tvar2 <*> tvar3 = tVars
fiber <- ZIO.forkAll(List.fill(10)(compute3VarN(99, tvar1, tvar2, tvar3)))
_ <- fiber.join
value <- tvar3.get.commit
} yield value
) must_=== 10000

def e17 =
Expand Down Expand Up @@ -240,26 +238,20 @@ final class STMSpec(implicit ee: org.specs2.concurrent.ExecutionEnv) extends Tes
for {
tvar1 <- TRef.makeCommit(10)
tvar2 <- TRef.makeCommit("Failed!")
fiber <- (
for {
v1 <- tvar1.get
_ <- STM.check(v1 > 0)
_ <- tvar2.set("Succeeded!")
v2 <- tvar2.get
} yield v2
).commit.fork
join <- fiber.join
join <- (for {
v1 <- tvar1.get
_ <- STM.check(v1 > 0)
_ <- tvar2.set("Succeeded!")
v2 <- tvar2.get
} yield v2).commit
} yield join must_=== "Succeeded!"
}

def e20 =
unsafeRun {
for {
tvar <- TRef.makeCommit(42)
fiber <- STM.atomically {
tvar.get.filter(_ == 42)
}.fork
join <- fiber.join
join <- tvar.get.filter(_ == 42).commit
_ <- tvar.set(9).commit
v <- tvar.get.commit
} yield (v must_=== 9) and (join must_=== 42)
Expand Down Expand Up @@ -558,87 +550,3 @@ final class STMSpec(implicit ee: org.specs2.concurrent.ExecutionEnv) extends Tes
} yield ()

}

object Examples {
object mutex {
type Mutex = TRef[Boolean]
val makeMutex = TRef.make(false).commit
def acquire(mutex: Mutex): UIO[Unit] =
(for {
value <- mutex.get
_ <- STM.check(!value)
_ <- mutex.set(true)
} yield ()).commit
def release(mutex: Mutex): UIO[Unit] =
mutex.set(false).commit.unit
def withMutex[R, E, A](mutex: Mutex)(zio: ZIO[R, E, A]): ZIO[R, E, A] =
acquire(mutex).bracket_[R, E].apply[R](release(mutex))[R, E, A](zio)
}
object semaphore {
type Semaphore = TRef[Int]
def makeSemaphore(n: Int): UIO[Semaphore] = TRef.makeCommit(n)
def acquire(semaphore: Semaphore, n: Int): UIO[Unit] =
(for {
value <- semaphore.get
_ <- STM.check(value >= n)
_ <- semaphore.set(value - n)
} yield ()).commit
def release(semaphore: Semaphore, n: Int): UIO[Unit] =
semaphore.update(_ + n).commit.unit
}
object promise {
type Promise[A] = TRef[Option[A]]
def makePromise[A]: UIO[Promise[A]] = TRef.makeCommit(None)
def complete[A](promise: Promise[A], v: A): UIO[Boolean] =
(for {
value <- promise.get
change <- value match {
case Some(_) => STM.succeed(false)
case None =>
promise.set(Some(v)) *>
STM.succeed(true)
}
} yield change).commit
def await[A](promise: Promise[A]): UIO[A] =
promise.get.collect {
case Some(a) => a
}.commit
}
object queue {
import scala.collection.immutable.{ Queue => ScalaQueue }

case class Queue[A](capacity: Int, tvar: TRef[ScalaQueue[A]])
def makeQueue[A](capacity: Int): UIO[Queue[A]] =
TRef.makeCommit(ScalaQueue.empty[A]).map(Queue(capacity, _))
def offer[A](queue: Queue[A], a: A): UIO[Unit] =
(for {
q <- queue.tvar.get
_ <- STM.check(q.length < queue.capacity)
_ <- queue.tvar.update(_ enqueue a)
} yield ()).commit
def take[A](queue: Queue[A]): UIO[A] =
(for {
q <- queue.tvar.get
a <- q.dequeueOption match {
case Some((a, as)) =>
queue.tvar.set(as) *> STM.succeed(a)
case _ => STM.retry
}
} yield a).commit
}
object fun {
case class Phone(value: String)
case class Developer(name: String, phone: Phone)
def page(phone: Phone, message: String): UIO[Unit] = ???
def pager(sysErrors: TRef[Int], onDuty: TRef[Set[Developer]]): UIO[Unit] =
(for {
errors <- sysErrors.get
_ <- STM.check(errors > 100)
devs <- onDuty.get
any <- devs.headOption match {
case Some(dev) => STM.succeed(dev)
case _ => STM.retry
}
} yield any).commit.flatMap(dev => page(dev.phone, "Wake up, too many bugs!"))
}
}
43 changes: 20 additions & 23 deletions core/shared/src/main/scala/scalaz/zio/Promise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package scalaz.zio

import java.util.concurrent.atomic.AtomicReference
import scalaz.zio.internal.Executor
import Promise.internal._

/**
Expand Down Expand Up @@ -108,35 +107,32 @@ class Promise[E, A] private (private val state: AtomicReference[State[E, A]]) ex
* has already been completed, the method will produce false.
*/
final def done(io: IO[E, A]): UIO[Boolean] =
IO.flatten(IO.effectTotal {
var action: UIO[Boolean] = null.asInstanceOf[UIO[Boolean]]
var retry = true
IO.effectTotal {
var action: () => Boolean = null.asInstanceOf[() => Boolean]
var retry = true

while (retry) {
val oldState = state.get

val newState = oldState match {
case Pending(joiners) =>
action =
IO.forkAll_(joiners.map(k => IO.effectTotal[Unit](k(io)))) *>
IO.succeed[Boolean](true)
while (retry) {
val oldState = state.get

Done(io)
val newState = oldState match {
case Pending(joiners) =>
action = () => { joiners.foreach(_(io)); true }

case Done(_) =>
action = IO.succeed[Boolean](false)
Done(io)

oldState
}
case Done(_) =>
action = Promise.ConstFalse

retry = !state.compareAndSet(oldState, newState)
oldState
}

action
})
.uninterruptible
retry = !state.compareAndSet(oldState, newState)
}

action()
}

private[zio] final def unsafeDone(io: IO[E, A], exec: Executor): Unit = {
private[zio] final def unsafeDone(io: IO[E, A]): Unit = {
var retry: Boolean = true
var joiners: List[IO[E, A] => Unit] = null

Expand All @@ -153,7 +149,7 @@ class Promise[E, A] private (private val state: AtomicReference[State[E, A]]) ex
retry = !state.compareAndSet(oldState, newState)
}

if (joiners ne null) joiners.reverse.foreach(k => exec.submit(() => k(io)))
if (joiners ne null) joiners.reverse.foreach(_(io))
}

private def interruptJoiner(joiner: IO[E, A] => Unit): Canceler = IO.effectTotal {
Expand All @@ -175,6 +171,7 @@ class Promise[E, A] private (private val state: AtomicReference[State[E, A]]) ex
}
}
object Promise {
private val ConstFalse: () => Boolean = () => false

/**
* Makes a new promise.
Expand Down

0 comments on commit dfbb551

Please sign in to comment.