Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #8696 #8845

Merged
merged 8 commits into from
May 22, 2024
18 changes: 18 additions & 0 deletions core-tests/shared/src/test/scala/zio/stm/ZSTMSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,24 @@ object ZSTMSpec extends ZIOBaseSpec {
val right = STM.fail("right")

(left orElse right).commit.exit.map(assert(_)(fails(equalTo("right"))))
},
test("retries on LHS variable change") {
for {
ref1 <- TRef.makeCommit(0)
ref2 <- TRef.makeCommit(0)
txn1 = ref1.get.flatMap {
case 0 => STM.retry
case n => STM.succeed(n)
} orElse ref2.get.flatMap {
case 0 => STM.retry
case n => STM.succeed(n)
}
txn2 = ref1.set(1) // if this is `ref2.set(1)`, it doesn't hang
fib <- txn1.commit.forkDaemon
_ <- liveClockSleep(1.second)
_ <- txn2.commit
result <- fib.join
} yield assert(result)(equalTo(1))
}
) @@ zioTag(errors),
test("orElseEither returns result of the first successful transaction wrapped in either") {
Expand Down
41 changes: 34 additions & 7 deletions core/shared/src/main/scala/zio/stm/ZSTM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1621,15 +1621,30 @@ object ZSTM {
* Creates a function that can reset the journal.
*/
def prepareResetJournal(journal: Journal): () => Any = {
BijenderKumar1 marked this conversation as resolved.
Show resolved Hide resolved
val saved = new MutableMap[TRef[_], Entry](journal.size)

val it = journal.entrySet.iterator
while (it.hasNext) {
val entry = it.next
saved.put(entry.getKey, entry.getValue.copy())
val currentNewValues = new MutableMap[TRef[_], Any]
val itCapture = journal.entrySet.iterator
while (itCapture.hasNext) {
val entry = itCapture.next()
currentNewValues.put(entry.getKey, entry.getValue.unsafeGet[Any])
}

() => { journal.clear(); journal.putAll(saved); () }
() => {
val saved = new MutableMap[TRef[_], Entry](journal.size)
val it = journal.entrySet.iterator
while (it.hasNext) {
val entry = it.next()
val key = entry.getKey
val resetValue = if (currentNewValues.containsKey(key)) {
currentNewValues.get(key)
} else {
entry.getValue.expected.value
}
saved.put(entry.getKey, entry.getValue.copy().reset(resetValue))
}
journal.clear()
journal.putAll(saved)
()
}
}

/**
Expand Down Expand Up @@ -2051,6 +2066,18 @@ object ZSTM {
_isChanged = self.isChanged
}

/**
* Resets the Entry with a given value.
*/
def reset(resetValue: Any): Entry = new Entry {
BijenderKumar1 marked this conversation as resolved.
Show resolved Hide resolved
type S = self.S
val tref = self.tref
val expected = self.expected
val isNew = self.isNew
var newValue = resetValue.asInstanceOf[S]
_isChanged = false
}

/**
* Determines if the entry is invalid. This is the negated version of
* `isValid`.
Expand Down