-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Issue-1177 TArray * STM.die when array index out of bounds
- Loading branch information
Showing
2 changed files
with
324 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
/* | ||
* Copyright 2017-2019 John A. De Goes and the ZIO Contributors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package zio.stm | ||
|
||
import zio._ | ||
|
||
final class TArraySpec(implicit ee: org.specs2.concurrent.ExecutionEnv) extends TestRuntime { | ||
def is = "TArraySpec".title ^ s2""" | ||
apply: | ||
happy-path $applyHappy | ||
dies with ArrayIndexOutOfBounds when index is out of bounds $applyOutOfBounds | ||
collect: | ||
is atomic $collectAtomic | ||
is safe for empty array $collectEmpty | ||
fold: | ||
is atomic $foldAtomic | ||
foldM: | ||
is atomic $foldMAtomic | ||
returns effect failure $foldMFailure | ||
foreach: | ||
side-effect is transactional $foreachTransactional | ||
map: | ||
creates new array atomically $mapAtomically | ||
mapM: | ||
creates new array atomically $mapMAtomically | ||
returns effect failure $mapMFailure | ||
transform: | ||
updates values atomically $transformAtomically | ||
transformM: | ||
updates values atomically $transformMAtomically | ||
updates all or nothing $transformMTransactionally | ||
update: | ||
happy-path $updateHappy | ||
dies with ArrayIndexOutOfBounds when index is out of bounds $updateOutOfBounds | ||
updateM: | ||
happy-path $updateMHappy | ||
dies with ArrayIndexOutOfBounds when index is out of bounds $updateMOutOfBounds | ||
updateM failure $updateMFailure | ||
""" | ||
|
||
def applyHappy = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(1)(42) | ||
value <- tArray(0).commit | ||
} yield value | ||
) mustEqual 42 | ||
|
||
def applyOutOfBounds = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(1)(42) | ||
_ <- tArray(-1).commit | ||
} yield () | ||
) must throwA[FiberFailure] | ||
|
||
def collectAtomic = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(N)("alpha-bravo-charlie") | ||
_ <- STM.foreach(tArray.array)(_.update(_.take(11))).commit.fork | ||
collected <- tArray.collect { | ||
case a if a.length == 11 => a | ||
}.commit | ||
} yield collected.array.size | ||
) must (equalTo(0) or equalTo(N)) | ||
|
||
def collectEmpty = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(0)("nothing") | ||
collected <- tArray.collect { | ||
case _ => () | ||
}.commit | ||
} yield collected.array.isEmpty | ||
) mustEqual true | ||
|
||
def foldAtomic = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(N)(0) | ||
sum1Fiber <- tArray.fold(0)(_ + _).commit.fork | ||
_ <- STM.foreach(0 until N)(i => tArray.array(i).update(_ + 1)).commit | ||
sum1 <- sum1Fiber.join | ||
} yield sum1 | ||
) must (equalTo(0) or equalTo(N)) | ||
|
||
def foldMAtomic = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(N)(0) | ||
sum1Fiber <- tArray.foldM(0)((z, a) => STM.succeed(z + a)).commit.fork | ||
_ <- STM.foreach(0 until N)(i => tArray.array(i).update(_ + 1)).commit | ||
sum1 <- sum1Fiber.join | ||
} yield sum1 | ||
) must (equalTo(0) or equalTo(N)) | ||
|
||
def foldMFailure = { | ||
def failInTheMiddle(acc: Int, a: Int): STM[Exception, Int] = | ||
if (acc == N / 2) STM.fail(boom) else STM.succeed(acc + a) | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(N)(1) | ||
res <- tArray.foldM(0)(failInTheMiddle).commit.either | ||
} yield res | ||
) mustEqual Left(boom) | ||
} | ||
|
||
def foreachTransactional = | ||
unsafeRun(for { | ||
ref <- TRef.make(0).commit | ||
tArray <- makeTArray(n)(1) | ||
_ <- tArray.foreach(a => ref.update(_ + a).const(())).commit.fork | ||
value <- ref.get.commit | ||
} yield value) must (equalTo(0) or equalTo(n)) | ||
|
||
def mapAtomically = | ||
unsafeRun(for { | ||
tArray <- makeTArray(N)("alpha-bravo-charlie") | ||
lengthsFiber <- tArray.map(_.length).commit.fork | ||
_ <- STM.foreach(0 until N)(i => tArray.array(i).set("abc")).commit | ||
lengths <- lengthsFiber.join | ||
firstAndLast <- lengths.array(0).get.zip(lengths.array(N - 1).get).commit | ||
} yield firstAndLast) must (equalTo((19, 19)) or equalTo((3, 3))) | ||
|
||
def mapMAtomically = | ||
unsafeRun(for { | ||
tArray <- makeTArray(N)("thisStringLengthIs20") | ||
lengthsFiber <- tArray.mapM(a => STM.succeedLazy(a.length)).commit.fork | ||
_ <- STM.foreach(0 until N)(idx => tArray.array(idx).set("abc")).commit | ||
lengths <- lengthsFiber.join | ||
first <- lengths.array(0).get.commit | ||
last <- lengths.array(N - 1).get.commit | ||
} yield (first, last)) must (equalTo((20, 20)) or equalTo((3, 3))) | ||
|
||
def mapMFailure = | ||
unsafeRun(for { | ||
tArray <- makeTArray(N)("abc") | ||
_ <- tArray.array(N / 2).update(_ => "").commit | ||
result <- tArray.mapM(a => if (a.isEmpty) STM.fail(boom) else STM.succeed(())).commit.either | ||
} yield result) mustEqual (Left(boom)) | ||
|
||
def transformAtomically = | ||
unsafeRun(for { | ||
tArray <- makeTArray(N)("a") | ||
transformFiber <- tArray.transform(_ + "+b").commit.fork | ||
_ <- STM.foreach(0 until N)(idx => tArray.array(idx).update(_ + "+c")).commit | ||
_ <- transformFiber.join | ||
first <- tArray.array(0).get.commit | ||
last <- tArray.array(N - 1).get.commit | ||
} yield (first, last)) must (equalTo(("a+b+c", "a+b+c")) or equalTo(("a+c+b", "a+c+b"))) | ||
|
||
def transformMAtomically = | ||
unsafeRun(for { | ||
tArray <- makeTArray(N)("a") | ||
transformFiber <- tArray.transformM(a => STM.succeedLazy(a + "+b")).commit.fork | ||
_ <- STM.foreach(0 until N)(idx => tArray.array(idx).update(_ + "+c")).commit | ||
_ <- transformFiber.join | ||
first <- tArray.array(0).get.commit | ||
last <- tArray.array(N - 1).get.commit | ||
} yield (first, last)) must (equalTo(("a+b+c", "a+b+c")) or equalTo(("a+c+b", "a+c+b"))) | ||
|
||
def transformMTransactionally = | ||
unsafeRun(for { | ||
tArray <- makeTArray(N)(0) | ||
_ <- tArray.array(N / 2).update(_ => 1).commit | ||
result <- tArray.transformM(a => if (a == 0) STM.succeed(42) else STM.fail(boom)).commit.either | ||
first <- tArray.array(0).get.commit | ||
} yield (first, result)) mustEqual ((0, Left(boom))) | ||
|
||
def updateHappy = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(1)(42) | ||
v <- tArray.update(0, a => -a).commit | ||
} yield v | ||
) mustEqual -42 | ||
|
||
def updateOutOfBounds = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(1)(42) | ||
_ <- tArray.update(-1, identity).commit | ||
} yield () | ||
) must throwA[FiberFailure] | ||
|
||
def updateMHappy = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(1)(42) | ||
v <- tArray.updateM(0, a => STM.succeedLazy(-a)).commit | ||
} yield v | ||
) mustEqual -42 | ||
|
||
def updateMOutOfBounds = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(10)(0) | ||
_ <- tArray.updateM(10, STM.succeed(_)).commit | ||
} yield () | ||
) must throwA[FiberFailure] | ||
|
||
def updateMFailure = | ||
unsafeRun( | ||
for { | ||
tArray <- makeTArray(n)(0) | ||
result <- tArray.updateM(0, _ => STM.fail(boom)).commit.either | ||
} yield result.fold(_.getMessage, _ => "unexpected") | ||
) mustEqual ("Boom!") | ||
|
||
private val N = 1000 | ||
private val n = 10 | ||
private val boom = new Exception("Boom!") | ||
|
||
private def makeTArray[T](n: Int)(a: T) = | ||
ZIO.sequence(List.fill(n)(TRef.makeCommit(a))).map(refs => TArray(refs.toArray)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright 2017-2019 John A. De Goes and the ZIO Contributors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package zio.stm | ||
|
||
/** Wraps array of [[TRef]] and adds methods for convenience. | ||
* Caution: most of methods are not stack-safe. | ||
* */ | ||
class TArray[A] private (val array: Array[TRef[A]]) extends AnyVal { | ||
|
||
/** Extracts value from ref in array. */ | ||
final def apply(index: Int): STM[Nothing, A] = | ||
if (0 <= index && index < array.size) array(index).get | ||
else STM.die(new ArrayIndexOutOfBoundsException(index)) | ||
|
||
final def collect[B](pf: PartialFunction[A, B]): STM[Nothing, TArray[B]] = | ||
this | ||
.foldM(List.empty[TRef[B]]) { | ||
case (acc, a) => | ||
if (pf.isDefinedAt(a)) TRef.make(pf(a)).map(tref => tref :: acc) | ||
else STM.succeed(acc) | ||
} | ||
.map(l => TArray(l.reverse.toArray)) | ||
|
||
/* Atomically folds [[TArray]] with pure function. */ | ||
final def fold[Z](acc: Z)(op: (Z, A) => Z): STM[Nothing, Z] = | ||
if (array.isEmpty) STM.succeed(acc) | ||
else array.head.get.flatMap(a => new TArray(array.tail).fold(op(acc, a))(op)) | ||
|
||
/** Atomically folds [[TArray]] with STM function. */ | ||
final def foldM[E, Z](acc: Z)(op: (Z, A) => STM[E, Z]): STM[E, Z] = | ||
if (array.isEmpty) STM.succeed(acc) | ||
else | ||
for { | ||
a <- array.head.get | ||
acc2 <- op(acc, a) | ||
res <- new TArray(array.tail).foldM(acc2)(op) | ||
} yield res | ||
|
||
/** Atomically performs side-effect for each item in array */ | ||
final def foreach[E](f: A => STM[E, Unit]): STM[E, Unit] = | ||
this.foldM(())((_, a) => f(a)) | ||
|
||
/** Creates [[TArray]] of new [[TRef]]s, mapped with pure function. */ | ||
final def map[B](f: A => B): STM[Nothing, TArray[B]] = | ||
this.mapM(f andThen STM.succeed) | ||
|
||
/** Creates [[TArray]] of new [[TRef]]s, mapped with transactional effect. */ | ||
final def mapM[E, B](f: A => STM[E, B]): STM[E, TArray[B]] = | ||
STM.foreach(array)(_.get.flatMap(f).flatMap(b => TRef.make(b))).map(l => new TArray(l.toArray)) | ||
|
||
/** Atomically updates all [[TRef]]s inside this array using pure function. */ | ||
final def transform(f: A => A): STM[Nothing, Unit] = | ||
(0 to array.size - 1).foldLeft(STM.succeed(())) { | ||
case (tx, idx) => array(idx).update(f) *> tx | ||
} | ||
|
||
/** Atomically updates all elements using transactional effect. */ | ||
final def transformM[E](f: A => STM[E, A]): STM[E, Unit] = | ||
(0 to array.size - 1).foldLeft[STM[E, Unit]](STM.succeed(())) { | ||
case (tx, idx) => | ||
val ref = array(idx) | ||
ref.get.flatMap(f).flatMap(a => ref.set(a)).flatMap(_ => tx) | ||
} | ||
|
||
/** Updates element in the array with given function. */ | ||
final def update(index: Int, fn: A => A): STM[Nothing, A] = | ||
if (0 <= index && index < array.size) array(index).update(fn) | ||
else STM.die(new ArrayIndexOutOfBoundsException(index)) | ||
|
||
/** Atomically updates element in the array with given transactionall effect. */ | ||
final def updateM[E](index: Int, fn: A => STM[E, A]): STM[E, A] = | ||
if (0 <= index && index < array.size) array(index).get.flatMap(fn) | ||
else STM.die(new ArrayIndexOutOfBoundsException(index)) | ||
} | ||
|
||
object TArray { | ||
|
||
final def apply[A](array: Array[TRef[A]]): TArray[A] = new TArray(array) | ||
|
||
} |