diff --git a/Cargo.toml b/Cargo.toml index 178434d..6e7b0d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,11 +7,10 @@ license = "MIT" repository = "https://github.com/restatedev/sdk-rust" [features] -default = ["http", "anyhow"] +default = ["http"] http = ["hyper", "http-body-util", "hyper-util", "tokio/net", "tokio/signal", "restate-sdk-shared-core/http"] [dependencies] -anyhow = {version = "1.0", optional = true} bytes = "1.6.1" futures = "0.3" http-body-util = { version = "0.1", optional = true } diff --git a/macros/src/gen.rs b/macros/src/gen.rs index da878a8..8d8b066 100644 --- a/macros/src/gen.rs +++ b/macros/src/gen.rs @@ -178,7 +178,7 @@ impl<'a> ServiceGenerator<'a> { let service_literal = Literal::string(restate_name); - let service_ty = match service_ty { + let service_ty_token = match service_ty { ServiceType::Service => quote! { ::restate_sdk::discovery::ServiceType::Service }, ServiceType::Object => { quote! { ::restate_sdk::discovery::ServiceType::VirtualObject } @@ -191,6 +191,8 @@ impl<'a> ServiceGenerator<'a> { let handler_ty = if handler.is_shared { quote! { Some(::restate_sdk::discovery::HandlerType::Shared) } + } else if *service_ty == ServiceType::Workflow { + quote! { Some(::restate_sdk::discovery::HandlerType::Workflow) } } else { // Macro has same defaulting rules of the discovery manifest quote! { None } @@ -212,7 +214,7 @@ impl<'a> ServiceGenerator<'a> { { fn discover() -> ::restate_sdk::discovery::Service { ::restate_sdk::discovery::Service { - ty: #service_ty, + ty: #service_ty_token, name: ::restate_sdk::discovery::ServiceName::try_from(#service_literal.to_string()) .expect("Service name valid"), handlers: vec![#( #handlers ),*], diff --git a/src/errors.rs b/src/errors.rs index 0bf4f96..e9ad1d4 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,7 +1,6 @@ use restate_sdk_shared_core::Failure; use std::error::Error as StdError; use std::fmt; -use thiserror::__private::AsDynError; #[derive(Debug)] pub(crate) enum HandlerErrorInner { @@ -23,7 +22,7 @@ impl fmt::Display for HandlerErrorInner { impl StdError for HandlerErrorInner { fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { - HandlerErrorInner::Retryable(e) => Some(e.as_dyn_error()), + HandlerErrorInner::Retryable(e) => Some(e.as_ref()), HandlerErrorInner::Terminal(e) => Some(e), } } @@ -32,16 +31,9 @@ impl StdError for HandlerErrorInner { #[derive(Debug)] pub struct HandlerError(pub(crate) HandlerErrorInner); -impl HandlerError { - #[cfg(feature = "anyhow")] - pub fn from_anyhow(err: anyhow::Error) -> Self { - Self(HandlerErrorInner::Retryable(err.into())) - } -} - -impl From for HandlerError { +impl>> From for HandlerError { fn from(value: E) -> Self { - Self(HandlerErrorInner::Retryable(Box::new(value))) + Self(HandlerErrorInner::Retryable(value.into())) } } diff --git a/src/serde.rs b/src/serde.rs index bae9065..b445aa7 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -174,3 +174,9 @@ where serde_json::from_slice(bytes).map(Json) } } + +impl Default for Json { + fn default() -> Self { + Self(T::default()) + } +} diff --git a/test-services/README.md b/test-services/README.md new file mode 100644 index 0000000..259a391 --- /dev/null +++ b/test-services/README.md @@ -0,0 +1,13 @@ +# Test services + +To build (from the repo root): + +```shell +$ podman build -f test-services/Dockerfile -t restatedev/rust-test-services . +``` + +To run (download the [sdk-test-suite](https://github.com/restatedev/sdk-test-suite) first): + +```shell +$ java -jar restate-sdk-test-suite.jar run restatedev/rust-test-services +``` \ No newline at end of file diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index c9b30bd..f9e3440 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -8,7 +8,6 @@ exclusions: - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - "dev.restate.sdktesting.tests.UserErrors" - - "dev.restate.sdktesting.tests.WorkflowAPI" "default": - "dev.restate.sdktesting.tests.AwaitTimeout" - "dev.restate.sdktesting.tests.CallOrdering" @@ -21,7 +20,6 @@ exclusions: - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - "dev.restate.sdktesting.tests.UserErrors" - - "dev.restate.sdktesting.tests.WorkflowAPI" "persistedTimers": - "dev.restate.sdktesting.tests.Sleep" "singleThreadSinglePartition": @@ -36,4 +34,3 @@ exclusions: - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - "dev.restate.sdktesting.tests.UserErrors" - - "dev.restate.sdktesting.tests.WorkflowAPI" diff --git a/test-services/src/awakeable_holder.rs b/test-services/src/awakeable_holder.rs new file mode 100644 index 0000000..1d28650 --- /dev/null +++ b/test-services/src/awakeable_holder.rs @@ -0,0 +1,39 @@ +use restate_sdk::prelude::*; + +#[restate_sdk::object] +#[name = "AwakeableHolder"] +pub(crate) trait AwakeableHolder { + #[name = "hold"] + async fn hold(id: String) -> HandlerResult<()>; + #[name = "hasAwakeable"] + #[shared] + async fn has_awakeable() -> HandlerResult; + #[name = "unlock"] + async fn unlock(payload: String) -> HandlerResult<()>; +} + +pub(crate) struct AwakeableHolderImpl; + +const ID: &str = "id"; + +impl AwakeableHolder for AwakeableHolderImpl { + async fn hold(&self, context: ObjectContext<'_>, id: String) -> HandlerResult<()> { + context.set(ID, id); + Ok(()) + } + + async fn has_awakeable(&self, context: SharedObjectContext<'_>) -> HandlerResult { + Ok(context.get::(ID).await?.is_some()) + } + + async fn unlock(&self, context: ObjectContext<'_>, payload: String) -> HandlerResult<()> { + let k: String = context.get(ID).await?.ok_or_else(|| { + TerminalError::new(format!( + "No awakeable stored for awakeable holder {}", + context.key() + )) + })?; + context.resolve_awakeable(&k, payload); + Ok(()) + } +} diff --git a/test-services/src/block_and_wait_workflow.rs b/test-services/src/block_and_wait_workflow.rs new file mode 100644 index 0000000..e9b092f --- /dev/null +++ b/test-services/src/block_and_wait_workflow.rs @@ -0,0 +1,49 @@ +use restate_sdk::prelude::*; + +#[restate_sdk::workflow] +#[name = "BlockAndWaitWorkflow"] +pub(crate) trait BlockAndWaitWorkflow { + #[name = "run"] + async fn run(input: String) -> HandlerResult; + #[name = "unblock"] + #[shared] + async fn unblock(output: String) -> HandlerResult<()>; + #[name = "getState"] + #[shared] + async fn get_state() -> HandlerResult>>; +} + +pub(crate) struct BlockAndWaitWorkflowImpl; + +const MY_PROMISE: &str = "my-promise"; +const MY_STATE: &str = "my-state"; + +impl BlockAndWaitWorkflow for BlockAndWaitWorkflowImpl { + async fn run(&self, context: WorkflowContext<'_>, input: String) -> HandlerResult { + context.set(MY_STATE, input); + + let promise: String = context.promise(MY_PROMISE).await?; + + if context.peek_promise::(MY_PROMISE).await?.is_none() { + return Err(TerminalError::new("Durable promise should be completed").into()); + } + + Ok(promise) + } + + async fn unblock( + &self, + context: SharedWorkflowContext<'_>, + output: String, + ) -> HandlerResult<()> { + context.resolve_promise(MY_PROMISE, output); + Ok(()) + } + + async fn get_state( + &self, + context: SharedWorkflowContext<'_>, + ) -> HandlerResult>> { + Ok(Json(context.get::(MY_STATE).await?)) + } +} diff --git a/test-services/src/list_object.rs b/test-services/src/list_object.rs new file mode 100644 index 0000000..7988c9f --- /dev/null +++ b/test-services/src/list_object.rs @@ -0,0 +1,45 @@ +use restate_sdk::prelude::*; + +#[restate_sdk::object] +#[name = "ListObject"] +pub(crate) trait ListObject { + #[name = "append"] + async fn append(value: String) -> HandlerResult<()>; + #[name = "get"] + async fn get() -> HandlerResult>>; + #[name = "clear"] + async fn clear() -> HandlerResult>>; +} + +pub(crate) struct ListObjectImpl; + +const LIST: &str = "list"; + +impl ListObject for ListObjectImpl { + async fn append(&self, ctx: ObjectContext<'_>, value: String) -> HandlerResult<()> { + let mut list = ctx + .get::>>(LIST) + .await? + .unwrap_or_default() + .into_inner(); + list.push(value); + ctx.set(LIST, Json(list)); + Ok(()) + } + + async fn get(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { + Ok(ctx + .get::>>(LIST) + .await? + .unwrap_or_default()) + } + + async fn clear(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { + let get = ctx + .get::>>(LIST) + .await? + .unwrap_or_default(); + ctx.clear(LIST); + Ok(get) + } +} diff --git a/test-services/src/main.rs b/test-services/src/main.rs index 96ff8cb..3a0b813 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -1,4 +1,7 @@ +mod awakeable_holder; +mod block_and_wait_workflow; mod counter; +mod list_object; mod map_object; mod proxy; @@ -22,6 +25,19 @@ async fn main() { if services == "*" || services.contains("MapObject") { builder = builder.with_service(map_object::MapObject::serve(map_object::MapObjectImpl)) } + if services == "*" || services.contains("ListObject") { + builder = builder.with_service(list_object::ListObject::serve(list_object::ListObjectImpl)) + } + if services == "*" || services.contains("AwakeableHolder") { + builder = builder.with_service(awakeable_holder::AwakeableHolder::serve( + awakeable_holder::AwakeableHolderImpl, + )) + } + if services == "*" || services.contains("BlockAndWaitWorkflow") { + builder = builder.with_service(block_and_wait_workflow::BlockAndWaitWorkflow::serve( + block_and_wait_workflow::BlockAndWaitWorkflowImpl, + )) + } HyperServer::new(builder.build()) .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) diff --git a/test-services/src/map_object.rs b/test-services/src/map_object.rs index 4bc4320..cf5ab76 100644 --- a/test-services/src/map_object.rs +++ b/test-services/src/map_object.rs @@ -44,7 +44,7 @@ impl MapObject for MapObjectImpl { let value = ctx .get(&k) .await? - .ok_or_else(|| HandlerError::from_anyhow(anyhow!("Missing key {k}")))?; + .ok_or_else(|| anyhow!("Missing key {k}"))?; entries.push(Entry { key: k, value }) }