diff --git a/core-tests/shared/src/test/scala/zio/stm/ZSTMSpec.scala b/core-tests/shared/src/test/scala/zio/stm/ZSTMSpec.scala index b62128efc4d..d33dbf9cd94 100644 --- a/core-tests/shared/src/test/scala/zio/stm/ZSTMSpec.scala +++ b/core-tests/shared/src/test/scala/zio/stm/ZSTMSpec.scala @@ -873,6 +873,24 @@ object ZSTMSpec extends ZIOBaseSpec { val right = STM.fail("right") (left orElse right).commit.exit.map(assert(_)(fails(equalTo("right")))) + }, + test("retries on LHS variable change") { + for { + ref1 <- TRef.makeCommit(0) + ref2 <- TRef.makeCommit(0) + txn1 = ref1.get.flatMap { + case 0 => STM.retry + case n => STM.succeed(n) + } orElse ref2.get.flatMap { + case 0 => STM.retry + case n => STM.succeed(n) + } + txn2 = ref1.set(1) + fib <- txn1.commit.forkDaemon + _ <- liveClockSleep(1.second) + _ <- txn2.commit + result <- fib.join + } yield assert(result)(equalTo(1)) } ) @@ zioTag(errors), test("orElseEither returns result of the first successful transaction wrapped in either") { diff --git a/core/shared/src/main/scala/zio/stm/ZSTM.scala b/core/shared/src/main/scala/zio/stm/ZSTM.scala index 98747a5dbf4..612f9b0c5b1 100644 --- a/core/shared/src/main/scala/zio/stm/ZSTM.scala +++ b/core/shared/src/main/scala/zio/stm/ZSTM.scala @@ -1621,15 +1621,30 @@ object ZSTM { * Creates a function that can reset the journal. */ def prepareResetJournal(journal: Journal): () => Any = { - val saved = new MutableMap[TRef[_], Entry](journal.size) - - val it = journal.entrySet.iterator - while (it.hasNext) { - val entry = it.next - saved.put(entry.getKey, entry.getValue.copy()) + val currentNewValues = new MutableMap[TRef[_], Any] + val itCapture = journal.entrySet.iterator + while (itCapture.hasNext) { + val entry = itCapture.next() + currentNewValues.put(entry.getKey, entry.getValue.unsafeGet[Any]) } - () => { journal.clear(); journal.putAll(saved); () } + () => { + val saved = new MutableMap[TRef[_], Entry](journal.size) + val it = journal.entrySet.iterator + while (it.hasNext) { + val entry = it.next() + val key = entry.getKey + val resetValue = if (currentNewValues.containsKey(key)) { + currentNewValues.get(key) + } else { + entry.getValue.expected.value + } + saved.put(entry.getKey, entry.getValue.copy().reset(resetValue)) + } + journal.clear() + journal.putAll(saved) + () + } } /** @@ -2051,6 +2066,18 @@ object ZSTM { _isChanged = self.isChanged } + /** + * Resets the Entry with a given value. + */ + private[stm] def reset(resetValue: Any): Entry = new Entry { + type S = self.S + val tref = self.tref + val expected = self.expected + val isNew = self.isNew + var newValue = resetValue.asInstanceOf[S] + _isChanged = false + } + /** * Determines if the entry is invalid. This is the negated version of * `isValid`.