Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 81 additions & 53 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import dev.restate.sdk.common.HandlerRequest
import dev.restate.sdk.common.StateKey
import dev.restate.sdk.common.TerminalException
import dev.restate.sdk.endpoint.definition.HandlerContext
import dev.restate.sdk.kotlin.internal.InsideRunElement
import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun
import dev.restate.serde.Serde
import dev.restate.serde.SerdeFactory
import dev.restate.serde.TypeTag
Expand All @@ -31,6 +33,7 @@ internal constructor(
internal val handlerContext: HandlerContext,
internal val contextSerdeFactory: SerdeFactory,
) : WorkflowContext {

override fun key(): String {
return this.handlerContext.objectKey()
}
Expand All @@ -39,75 +42,89 @@ internal constructor(
return this.handlerContext.request()
}

override suspend fun <T : Any> get(key: StateKey<T>): T? =
resolveSerde<T?>(key.serdeInfo())
.let { serde ->
SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap {
it.getOrNull()?.let { serde.deserialize(it) }
}
override suspend fun <T : Any> get(key: StateKey<T>): T? {
checkNotInsideRun()
return resolveSerde<T?>(key.serdeInfo())
.let { serde ->
SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap {
it.getOrNull()?.let { serde.deserialize(it) }
}
.await()
}
.await()
}

override suspend fun stateKeys(): Collection<String> =
SingleDurableFutureImpl(handlerContext.getKeys().await()).await()
override suspend fun stateKeys(): Collection<String> {
checkNotInsideRun()
return SingleDurableFutureImpl(handlerContext.getKeys().await()).await()
}

override suspend fun <T : Any> set(key: StateKey<T>, value: T) {
checkNotInsideRun()
handlerContext.set(key.name(), resolveAndSerialize(key.serdeInfo(), value)).await()
}

override suspend fun clear(key: StateKey<*>) {
checkNotInsideRun()
handlerContext.clear(key.name()).await()
}

override suspend fun clearAll() {
checkNotInsideRun()
handlerContext.clearAll().await()
}

override suspend fun timer(duration: Duration, name: String?): DurableFuture<Unit> =
SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await()).map {}
override suspend fun timer(duration: Duration, name: String?): DurableFuture<Unit> {
checkNotInsideRun()
return SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await())
.map {}
}

override suspend fun <Req : Any?, Res : Any?> call(
request: Request<Req, Res>
): CallDurableFuture<Res> =
resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
val callHandle =
handlerContext
.call(
request.getTarget(),
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
request.getIdempotencyKey(),
request.getHeaders()?.entries,
)
.await()

val callAsyncResult =
callHandle.callAsyncResult.map {
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
}
): CallDurableFuture<Res> {
checkNotInsideRun()
return resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
val callHandle =
handlerContext
.call(
request.getTarget(),
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
request.getIdempotencyKey(),
request.getHeaders()?.entries,
)
.await()

return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult)
}
val callAsyncResult =
callHandle.callAsyncResult.map {
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
}

return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult)
}
}

override suspend fun <Req : Any?, Res : Any?> send(
request: Request<Req, Res>,
delay: Duration?,
): InvocationHandle<Res> =
resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
val invocationIdAsyncResult =
handlerContext
.send(
request.getTarget(),
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
request.getIdempotencyKey(),
request.getHeaders()?.entries,
delay?.toJavaDuration(),
)
.await()
): InvocationHandle<Res> {
checkNotInsideRun()
return resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
val invocationIdAsyncResult =
handlerContext
.send(
request.getTarget(),
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
request.getIdempotencyKey(),
request.getHeaders()?.entries,
delay?.toJavaDuration(),
)
.await()

object : BaseInvocationHandle<Res>(handlerContext, responseSerde) {
override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await()
}
object : BaseInvocationHandle<Res>(handlerContext, responseSerde) {
override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await()
}
}
}

override fun <Res> invocationHandle(
invocationId: String,
Expand All @@ -125,6 +142,7 @@ internal constructor(
retryPolicy: RetryPolicy?,
block: suspend () -> T,
): DurableFuture<T> {
checkNotInsideRun()
val serde: Serde<T> = resolveSerde(typeTag)
val coroutineCtx = currentCoroutineContext()
val javaRetryPolicy =
Expand All @@ -138,7 +156,10 @@ internal constructor(
.setMaxDuration(it.maxDuration?.toJavaDuration())
}

val scope = CoroutineScope(coroutineCtx + CoroutineName("restate-run-$name"))
val scope =
CoroutineScope(
coroutineCtx + CoroutineName("restate-run-$name") + InsideRunElement.INSTANCE
)

val asyncResult =
handlerContext
Expand All @@ -159,6 +180,7 @@ internal constructor(
}

override suspend fun <T : Any> awakeable(typeTag: TypeTag<T>): Awakeable<T> {
checkNotInsideRun()
val serde: Serde<T> = resolveSerde(typeTag)
val awk = handlerContext.awakeable().await()
return AwakeableImpl(awk.asyncResult, serde, awk.id)
Expand All @@ -184,22 +206,27 @@ internal constructor(
DurablePromise<T> {
val serde: Serde<T> = resolveSerde(key.serdeInfo())

override suspend fun future(): DurableFuture<T> =
SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap {
serde.deserialize(it)
}
override suspend fun future(): DurableFuture<T> {
checkNotInsideRun()
return SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap {
serde.deserialize(it)
}
}

override suspend fun peek(): Output<T> =
SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await())
.simpleMap { it.map { serde.deserialize(it) } }
.await()
override suspend fun peek(): Output<T> {
checkNotInsideRun()
return SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await())
.simpleMap { it.map { serde.deserialize(it) } }
.await()
}
}

inner class DurablePromiseHandleImpl<T : Any>(private val key: DurablePromiseKey<T>) :
DurablePromiseHandle<T> {
val serde: Serde<T> = resolveSerde(key.serdeInfo())

override suspend fun resolve(payload: T) {
checkNotInsideRun()
SingleDurableFutureImpl(
handlerContext
.resolvePromise(
Expand All @@ -212,6 +239,7 @@ internal constructor(
}

override suspend fun reject(reason: String) {
checkNotInsideRun()
SingleDurableFutureImpl(
handlerContext.rejectPromise(key.name(), TerminalException(reason)).await()
)
Expand Down
30 changes: 19 additions & 11 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import dev.restate.sdk.common.TerminalException
import dev.restate.sdk.common.TimeoutException
import dev.restate.sdk.endpoint.definition.AsyncResult
import dev.restate.sdk.endpoint.definition.HandlerContext
import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun
import dev.restate.serde.Serde
import dev.restate.serde.TypeTag
import java.util.concurrent.CompletableFuture
Expand All @@ -32,6 +33,7 @@ internal abstract class BaseDurableFutureImpl<T : Any?> : DurableFuture<T> {
get() = SelectClauseImpl(this)

override suspend fun await(): T {
checkNotInsideRun()
return asyncResult().poll().await()
}

Expand Down Expand Up @@ -193,20 +195,25 @@ internal constructor(
private val responseSerde: Serde<Res>,
) : InvocationHandle<Res> {
override suspend fun cancel() {
checkNotInsideRun()
val ignored = handlerContext.cancelInvocation(invocationId()).await()
}

override suspend fun attach(): DurableFuture<Res> =
SingleDurableFutureImpl(
handlerContext.attachInvocation(invocationId()).await().map {
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
}
)
override suspend fun attach(): DurableFuture<Res> {
checkNotInsideRun()
return SingleDurableFutureImpl(
handlerContext.attachInvocation(invocationId()).await().map {
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
}
)
}

override suspend fun output(): Output<Res> =
SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await())
.simpleMap { it.map { responseSerde.deserialize(it) } }
.await()
override suspend fun output(): Output<Res> {
checkNotInsideRun()
return SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await())
.simpleMap { it.map { responseSerde.deserialize(it) } }
.await()
}
}

internal class AwakeableImpl<T : Any?>
Expand All @@ -218,13 +225,14 @@ internal constructor(asyncResult: AsyncResult<Slice>, serde: Serde<T>, override

internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) : AwakeableHandle {
override suspend fun <T : Any> resolve(typeTag: TypeTag<T>, payload: T) {
checkNotInsideRun()
contextImpl.handlerContext
.resolveAwakeable(id, contextImpl.resolveAndSerialize(typeTag, payload))
.await()
}

override suspend fun reject(reason: String) {
return
checkNotInsideRun()
contextImpl.handlerContext.rejectAwakeable(id, TerminalException(reason)).await()
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.kotlin.internal

import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.currentCoroutineContext

/**
* Coroutine context element that marks the current coroutine as executing inside a `ctx.run()`
* block. Context methods check for this element and throw [IllegalStateException] if present.
*/
internal class InsideRunElement private constructor() : AbstractCoroutineContextElement(Key) {
companion object Key : CoroutineContext.Key<InsideRunElement> {
val INSTANCE = InsideRunElement()

suspend fun checkNotInsideRun() {
if (currentCoroutineContext()[Key] != null) {
throw IllegalStateException(
"Cannot invoke context method inside ctx.run(). " +
"The run closure is meant for non-deterministic operations (e.g., HTTP calls, database reads). " +
"You MUST use context methods outside of ctx.run(), check the documentation: https://docs.restate.dev/develop/java/durable-steps#run"
)
}
}
}
}
Loading
Loading