diff --git a/Cargo.toml b/Cargo.toml index 37d70d1..af1f95d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,11 @@ name = "tracing" path = "examples/tracing.rs" required-features = ["tracing-span-filter"] +[[example]] +name = "schema" +path = "examples/schema.rs" +required-features = ["schemars"] + [features] default = ["http_server", "rand", "uuid", "tracing-span-filter"] hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] @@ -30,6 +35,7 @@ rand = { version = "0.9", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.4", path = "macros" } restate-sdk-shared-core = { version = "0.3.0", features = ["request_identity", "sha2_random_seed", "http"] } +schemars = { version = "1.0.0-alpha.17", optional = true } serde = "1.0" serde_json = "1.0" thiserror = "2.0" @@ -44,6 +50,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] } trybuild = "1.0" reqwest = { version = "0.12", features = ["json"] } rand = "0.9" +schemars = "1.0.0-alpha.17" [build-dependencies] jsonptr = "0.5.1" diff --git a/examples/schema.rs b/examples/schema.rs new file mode 100644 index 0000000..a7072b6 --- /dev/null +++ b/examples/schema.rs @@ -0,0 +1,65 @@ +//! Run with auto-generated schemas for `Json` using `schemars`: +//! cargo run --example schema --features schemars +//! +//! Run with primitive schemas only: +//! cargo run --example schema + +use restate_sdk::prelude::*; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize, JsonSchema)] +struct Product { + id: String, + name: String, + price_cents: u32, +} + +#[restate_sdk::service] +trait CatalogService { + async fn get_product_by_id(product_id: String) -> Result, HandlerError>; + async fn save_product(product: Json) -> Result; + async fn is_in_stock(product_id: String) -> Result; +} + +struct CatalogServiceImpl; + +impl CatalogService for CatalogServiceImpl { + async fn get_product_by_id( + &self, + ctx: Context<'_>, + product_id: String, + ) -> Result, HandlerError> { + ctx.sleep(Duration::from_millis(50)).await?; + Ok(Json(Product { + id: product_id, + name: "Sample Product".to_string(), + price_cents: 1995, + })) + } + + async fn save_product( + &self, + _ctx: Context<'_>, + product: Json, + ) -> Result { + Ok(product.0.id) + } + + async fn is_in_stock( + &self, + _ctx: Context<'_>, + product_id: String, + ) -> Result { + Ok(!product_id.contains("out-of-stock")) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(CatalogServiceImpl.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/macros/src/gen.rs b/macros/src/gen.rs index a882b4d..e489ab0 100644 --- a/macros/src/gen.rs +++ b/macros/src/gen.rs @@ -194,11 +194,32 @@ impl<'a> ServiceGenerator<'a> { quote! { None } }; + let input_schema = match &handler.arg { + Some(PatType { ty, .. }) => { + quote! { + Some(::restate_sdk::discovery::InputPayload::from_metadata::<#ty>()) + } + } + None => quote! { + Some(::restate_sdk::discovery::InputPayload::empty()) + } + }; + + let output_ty = &handler.output_ok; + let output_schema = match output_ty { + syn::Type::Tuple(tuple) if tuple.elems.is_empty() => quote! { + Some(::restate_sdk::discovery::OutputPayload::empty()) + }, + _ => quote! { + Some(::restate_sdk::discovery::OutputPayload::from_metadata::<#output_ty>()) + } + }; + quote! { ::restate_sdk::discovery::Handler { name: ::restate_sdk::discovery::HandlerName::try_from(#handler_literal).expect("Handler name valid"), - input: None, - output: None, + input: #input_schema, + output: #output_schema, ty: #handler_ty, } } diff --git a/src/discovery.rs b/src/discovery.rs index 949ad7e..5c544da 100644 --- a/src/discovery.rs +++ b/src/discovery.rs @@ -8,3 +8,43 @@ mod generated { } pub use generated::*; + +use crate::serde::PayloadMetadata; + +impl InputPayload { + pub fn empty() -> Self { + Self { + content_type: None, + json_schema: None, + required: None, + } + } + + pub fn from_metadata() -> Self { + let input_metadata = T::input_metadata(); + Self { + content_type: Some(input_metadata.accept_content_type.to_owned()), + json_schema: T::json_schema(), + required: Some(input_metadata.is_required), + } + } +} + +impl OutputPayload { + pub fn empty() -> Self { + Self { + content_type: None, + json_schema: None, + set_content_type_if_empty: Some(false), + } + } + + pub fn from_metadata() -> Self { + let output_metadata = T::output_metadata(); + Self { + content_type: Some(output_metadata.content_type.to_owned()), + json_schema: T::json_schema(), + set_content_type_if_empty: Some(output_metadata.set_content_type_if_empty), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 533e192..d183e97 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ //! - [Scheduling & Timers][crate::context::ContextTimers]: Let a handler pause for a certain amount of time. Restate durably tracks the timer across failures. //! - [Awakeables][crate::context::ContextAwakeables]: Durable Futures to wait for events and the completion of external tasks. //! - [Error Handling][crate::errors]: Restate retries failures infinitely. Use `TerminalError` to stop retries. -//! - [Serialization][crate::serde]: The SDK serializes results to send them to the Server. +//! - [Serialization][crate::serde]: The SDK serializes results to send them to the Server. Includes [Schema Generation and payload metadata](crate::serde::PayloadMetadata) for documentation & discovery. //! - [Serving][crate::http_server]: Start an HTTP server to expose services. //! //! # SDK Overview diff --git a/src/serde.rs b/src/serde.rs index 2967a0b..c78cf9b 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -4,8 +4,8 @@ //! //! Therefore, the types of the values that are stored, need to either: //! - be a primitive type -//! - use a wrapper type [`Json`] for using [`serde-json`](https://serde.rs/) -//! - have the [`Serialize`] and [`Deserialize`] trait implemented +//! - use a wrapper type [`Json`] for using [`serde-json`](https://serde.rs/). To enable JSON schema generation, you'll need to enable the `schemars` feature. See [PayloadMetadata] for more details. +//! - have the [`Serialize`] and [`Deserialize`] trait implemented. If you need to use a type for the handler input/output, you'll also need to implement [PayloadMetadata] to reply with correct content type and enable **JSON schema generation**. //! use bytes::Bytes; @@ -40,11 +40,160 @@ where fn deserialize(bytes: &mut Bytes) -> Result; } -/// Trait encapsulating `content-type` information for the given serializer/deserializer. +/// ## Payload metadata and Json Schemas /// -/// This is used by service discovery to correctly specify the content type. -pub trait WithContentType { - fn content_type() -> &'static str; +/// The SDK propagates during discovery some metadata to restate-server service catalog. This includes: +/// +/// * The JSON schema of the payload. See below for more details. +/// * The [InputMetadata] used to instruct restate how to accept requests. +/// * The [OutputMetadata] used to instruct restate how to send responses out. +/// +/// There are three approaches for generating JSON Schemas for handler inputs and outputs: +/// +/// ### 1. Primitive Types +/// +/// Primitive types (like `String`, `u32`, `bool`) have built-in schema implementations +/// that work automatically without additional code: +/// +/// ```rust +/// use restate_sdk::prelude::*; +/// +/// #[restate_sdk::service] +/// trait SimpleService { +/// async fn greet(name: String) -> HandlerResult; +/// } +/// ``` +/// +/// ### 2. Using `Json` with schemars +/// +/// For complex types wrapped in `Json`, you need to add the `schemars` feature and derive `JsonSchema`: +/// +/// ```rust +/// use restate_sdk::prelude::*; +/// +/// #[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] +/// struct User { +/// name: String, +/// age: u32, +/// } +/// +/// #[restate_sdk::service] +/// trait UserService { +/// async fn register(user: Json) -> HandlerResult>; +/// } +/// ``` +/// +/// To enable rich schema generation with `Json`, add the `schemars` feature to your dependency: +/// +/// ```toml +/// [dependencies] +/// restate-sdk = { version = "0.3", features = ["schemars"] } +/// schemars = "1.0.0-alpha.17" +/// ``` +/// +/// ### 3. Custom Implementation +/// +/// You can also implement the [PayloadMetadata] trait directly for your types to provide +/// custom schemas without relying on the `schemars` feature: +/// +/// ```rust +/// use restate_sdk::serde::{PayloadMetadata, Serialize, Deserialize}; +/// +/// #[derive(serde::Serialize, serde::Deserialize)] +/// struct User { +/// name: String, +/// age: u32, +/// } +/// +/// // Implement PayloadMetadata directly and override the json_schema implementation +/// impl PayloadMetadata for User { +/// fn json_schema() -> Option { +/// Some(serde_json::json!({ +/// "type": "object", +/// "properties": { +/// "name": {"type": "string"}, +/// "age": {"type": "integer", "minimum": 0} +/// }, +/// "required": ["name", "age"] +/// })) +/// } +/// } +/// ``` +/// +/// Trait encapsulating JSON Schema information for the given serializer/deserializer. +/// +/// This trait allows types to provide JSON Schema information that can be used for +/// documentation, validation, and client generation. +/// +/// ## Behavior with `schemars` Feature Flag +/// +/// When the `schemars` feature is enabled, implementations for complex types use +/// the `schemars` crate to automatically generate rich, JSON Schema 2020-12 conforming schemas. +/// When the feature is disabled, primitive types still provide basic schemas, +/// but complex types return empty schemas, unless manually implemented. +pub trait PayloadMetadata { + /// Generate a JSON Schema for this type. + /// + /// Returns a JSON value representing the schema for this type. When the `schemars` + /// feature is enabled, this returns an auto-generated JSON Schema 2020-12 conforming schema. When the feature is disabled, + /// this returns an empty schema for complex types, but basic schemas for primitives. + /// + /// If returns none, no schema is provided. This should be used when the payload is not expected to be json + fn json_schema() -> Option { + Some(serde_json::Value::Object(serde_json::Map::default())) + } + + /// Returns the [InputMetadata]. The default implementation returns metadata suitable for JSON payloads. + fn input_metadata() -> InputMetadata { + InputMetadata::default() + } + + /// Returns the [OutputMetadata]. The default implementation returns metadata suitable for JSON payloads. + fn output_metadata() -> OutputMetadata { + OutputMetadata::default() + } +} + +/// This struct encapsulates input payload metadata used by discovery. +/// +/// The default implementation works well with Json payloads. +pub struct InputMetadata { + /// Content type of the input. It can accept wildcards, in the same format as the 'Accept' header. + /// + /// By default, is `application/json`. + pub accept_content_type: &'static str, + /// If true, Restate itself will reject requests **without content-types**. + pub is_required: bool, +} + +impl Default for InputMetadata { + fn default() -> Self { + Self { + accept_content_type: APPLICATION_JSON, + is_required: true, + } + } +} + +/// This struct encapsulates output payload metadata used by discovery. +/// +/// The default implementation works for Json payloads. +pub struct OutputMetadata { + /// Content type of the output. + /// + /// By default, is `application/json`. + pub content_type: &'static str, + /// If true, the specified content-type is set even if the output is empty. This should be set to `true` only for encodings that can return a serialized empty byte array (e.g. Protobuf). + pub set_content_type_if_empty: bool, +} + +impl Default for OutputMetadata { + fn default() -> Self { + Self { + content_type: APPLICATION_JSON, + set_content_type_if_empty: false, + } + } } // --- Default implementation for Unit type @@ -65,12 +214,6 @@ impl Deserialize for () { } } -impl WithContentType for () { - fn content_type() -> &'static str { - "" - } -} - // --- Passthrough implementation impl Serialize for Vec { @@ -89,9 +232,23 @@ impl Deserialize for Vec { } } -impl WithContentType for Vec { - fn content_type() -> &'static str { - APPLICATION_OCTET_STREAM +impl PayloadMetadata for Vec { + fn json_schema() -> Option { + None + } + + fn input_metadata() -> InputMetadata { + InputMetadata { + accept_content_type: "*/*", + is_required: true, + } + } + + fn output_metadata() -> OutputMetadata { + OutputMetadata { + content_type: APPLICATION_OCTET_STREAM, + set_content_type_if_empty: false, + } } } @@ -111,15 +268,68 @@ impl Deserialize for Bytes { } } -impl WithContentType for Bytes { - fn content_type() -> &'static str { - APPLICATION_OCTET_STREAM +impl PayloadMetadata for Bytes { + fn json_schema() -> Option { + None + } + + fn input_metadata() -> InputMetadata { + InputMetadata { + accept_content_type: "*/*", + is_required: true, + } + } + + fn output_metadata() -> OutputMetadata { + OutputMetadata { + content_type: APPLICATION_OCTET_STREAM, + set_content_type_if_empty: false, + } + } +} +// --- Option implementation + +impl Serialize for Option { + type Error = T::Error; + + fn serialize(&self) -> Result { + if self.is_none() { + return Ok(Bytes::new()); + } + T::serialize(self.as_ref().unwrap()) + } +} + +impl Deserialize for Option { + type Error = T::Error; + + fn deserialize(b: &mut Bytes) -> Result { + if b.is_empty() { + return Ok(None); + } + T::deserialize(b).map(Some) + } +} + +impl PayloadMetadata for Option { + fn input_metadata() -> InputMetadata { + InputMetadata { + accept_content_type: T::input_metadata().accept_content_type, + is_required: false, + } + } + + fn output_metadata() -> OutputMetadata { + OutputMetadata { + content_type: T::output_metadata().content_type, + set_content_type_if_empty: false, + } } } // --- Primitives -macro_rules! impl_serde_primitives { +macro_rules! impl_integer_primitives { ($ty:ty) => { impl Serialize for $ty { type Error = serde_json::Error; @@ -137,30 +347,77 @@ macro_rules! impl_serde_primitives { } } - impl WithContentType for $ty { - fn content_type() -> &'static str { - APPLICATION_JSON + impl PayloadMetadata for $ty { + fn json_schema() -> Option { + let min = <$ty>::MIN; + let max = <$ty>::MAX; + Some(serde_json::json!({ "type": "integer", "minimum": min, "maximum": max })) + } + } + }; +} + +impl_integer_primitives!(u8); +impl_integer_primitives!(u16); +impl_integer_primitives!(u32); +impl_integer_primitives!(u64); +impl_integer_primitives!(u128); +impl_integer_primitives!(i8); +impl_integer_primitives!(i16); +impl_integer_primitives!(i32); +impl_integer_primitives!(i64); +impl_integer_primitives!(i128); + +macro_rules! impl_serde_primitives { + ($ty:ty) => { + impl Serialize for $ty { + type Error = serde_json::Error; + + fn serialize(&self) -> Result { + serde_json::to_vec(&self).map(Bytes::from) + } + } + + impl Deserialize for $ty { + type Error = serde_json::Error; + + fn deserialize(bytes: &mut Bytes) -> Result { + serde_json::from_slice(&bytes) } } }; } impl_serde_primitives!(String); -impl_serde_primitives!(u8); -impl_serde_primitives!(u16); -impl_serde_primitives!(u32); -impl_serde_primitives!(u64); -impl_serde_primitives!(u128); -impl_serde_primitives!(i8); -impl_serde_primitives!(i16); -impl_serde_primitives!(i32); -impl_serde_primitives!(i64); -impl_serde_primitives!(i128); impl_serde_primitives!(bool); impl_serde_primitives!(f32); impl_serde_primitives!(f64); -// --- Json responses +impl PayloadMetadata for String { + fn json_schema() -> Option { + Some(serde_json::json!({ "type": "string" })) + } +} + +impl PayloadMetadata for bool { + fn json_schema() -> Option { + Some(serde_json::json!({ "type": "boolean" })) + } +} + +impl PayloadMetadata for f32 { + fn json_schema() -> Option { + Some(serde_json::json!({ "type": "number" })) + } +} + +impl PayloadMetadata for f64 { + fn json_schema() -> Option { + Some(serde_json::json!({ "type": "number" })) + } +} + +// --- Json wrapper /// Wrapper type to use [`serde_json`] with Restate's [`Serialize`]/[`Deserialize`] traits. pub struct Json(pub T); @@ -204,3 +461,19 @@ impl Default for Json { Self(T::default()) } } + +// When schemars is disabled - works with any T +#[cfg(not(feature = "schemars"))] +impl PayloadMetadata for Json { + fn json_schema() -> Option { + Some(serde_json::json!({})) + } +} + +// When schemars is enabled - requires T: JsonSchema +#[cfg(feature = "schemars")] +impl PayloadMetadata for Json { + fn json_schema() -> Option { + Some(schemars::schema_for!(T).to_value()) + } +} diff --git a/test-services/Cargo.toml b/test-services/Cargo.toml index 79c4e48..477e793 100644 --- a/test-services/Cargo.toml +++ b/test-services/Cargo.toml @@ -6,9 +6,11 @@ publish = false [dependencies] anyhow = "1.0" +bytes = "1.10.1" tokio = { version = "1", features = ["full"] } tracing-subscriber = "0.3" futures = "0.3" -restate-sdk = { path = ".." } +restate-sdk = { path = "..", features = ["schemars"] } +schemars = "1.0.0-alpha.17" serde = { version = "1", features = ["derive"] } tracing = "0.1.40" diff --git a/test-services/src/cancel_test.rs b/test-services/src/cancel_test.rs index 768b951..ecbd866 100644 --- a/test-services/src/cancel_test.rs +++ b/test-services/src/cancel_test.rs @@ -1,10 +1,11 @@ use crate::awakeable_holder; use anyhow::anyhow; use restate_sdk::prelude::*; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::time::Duration; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub(crate) enum BlockingOperation { Call, diff --git a/test-services/src/counter.rs b/test-services/src/counter.rs index 0b51d86..7456c7e 100644 --- a/test-services/src/counter.rs +++ b/test-services/src/counter.rs @@ -1,8 +1,9 @@ use restate_sdk::prelude::*; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tracing::info; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct CounterUpdateResponse { old_value: u64, diff --git a/test-services/src/map_object.rs b/test-services/src/map_object.rs index cf5ab76..5dae831 100644 --- a/test-services/src/map_object.rs +++ b/test-services/src/map_object.rs @@ -1,8 +1,9 @@ use anyhow::anyhow; use restate_sdk::prelude::*; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct Entry { key: String, diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 6b1f221..67f91af 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -2,10 +2,11 @@ use futures::future::BoxFuture; use futures::FutureExt; use restate_sdk::context::RequestTarget; use restate_sdk::prelude::*; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::time::Duration; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct ProxyRequest { service_name: String, @@ -33,7 +34,7 @@ impl ProxyRequest { } } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct ManyCallRequest { proxy_request: ProxyRequest, diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index 0152062..325d208 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -15,7 +15,7 @@ pub(crate) trait TestUtilsService { #[name = "uppercaseEcho"] async fn uppercase_echo(input: String) -> HandlerResult; #[name = "rawEcho"] - async fn raw_echo(input: Vec) -> Result, Infallible>; + async fn raw_echo(input: bytes::Bytes) -> Result, Infallible>; #[name = "echoHeaders"] async fn echo_headers() -> HandlerResult>>; #[name = "sleepConcurrently"] @@ -37,8 +37,8 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(input.to_ascii_uppercase()) } - async fn raw_echo(&self, _: Context<'_>, input: Vec) -> Result, Infallible> { - Ok(input) + async fn raw_echo(&self, _: Context<'_>, input: bytes::Bytes) -> Result, Infallible> { + Ok(input.to_vec()) } async fn echo_headers( diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/virtual_object_command_interpreter.rs index d401c91..b620ef4 100644 --- a/test-services/src/virtual_object_command_interpreter.rs +++ b/test-services/src/virtual_object_command_interpreter.rs @@ -1,16 +1,17 @@ use anyhow::anyhow; use futures::TryFutureExt; use restate_sdk::prelude::*; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::time::Duration; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct InterpretRequest { commands: Vec, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(tag = "type")] #[serde(rename_all_fields = "camelCase")] pub(crate) enum Command { @@ -39,7 +40,7 @@ pub(crate) enum Command { GetEnvVariable { env_name: String }, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(tag = "type")] #[serde(rename_all_fields = "camelCase")] pub(crate) enum AwaitableCommand { @@ -51,14 +52,14 @@ pub(crate) enum AwaitableCommand { RunThrowTerminalException { reason: String }, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct ResolveAwakeable { awakeable_key: String, value: String, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub(crate) struct RejectAwakeable { awakeable_key: String, diff --git a/tests/schema.rs b/tests/schema.rs new file mode 100644 index 0000000..0619ef7 --- /dev/null +++ b/tests/schema.rs @@ -0,0 +1,182 @@ +use restate_sdk::prelude::*; +use restate_sdk::serde::{Json, PayloadMetadata}; +use restate_sdk::service::Discoverable; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[cfg(feature = "schemars")] +use schemars::JsonSchema; + +#[derive(Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(JsonSchema))] +struct TestUser { + name: String, + age: u32, +} + +#[derive(Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(JsonSchema))] +struct Person { + name: String, + age: u32, + address: Address, +} + +#[derive(Serialize, Deserialize, Default)] +#[cfg_attr(feature = "schemars", derive(JsonSchema))] +struct Address { + street: String, + city: String, +} + +#[restate_sdk::service] +trait SchemaTestService { + async fn string_handler(input: String) -> HandlerResult; + async fn no_input_handler() -> HandlerResult; + async fn json_handler(input: Json) -> HandlerResult>; + async fn complex_handler(input: Json) -> HandlerResult>>; + async fn empty_output_handler(input: String) -> HandlerResult<()>; +} + +struct SchemaTestServiceImpl; + +impl SchemaTestService for SchemaTestServiceImpl { + async fn string_handler(&self, _ctx: Context<'_>, _input: String) -> HandlerResult { + Ok(42) + } + async fn no_input_handler(&self, _ctx: Context<'_>) -> HandlerResult { + Ok("No input".to_string()) + } + async fn json_handler( + &self, + _ctx: Context<'_>, + input: Json, + ) -> HandlerResult> { + Ok(input) + } + async fn complex_handler( + &self, + _ctx: Context<'_>, + input: Json, + ) -> HandlerResult>> { + Ok(Json(HashMap::from([("original".to_string(), input.0)]))) + } + async fn empty_output_handler(&self, _ctx: Context<'_>, _input: String) -> HandlerResult<()> { + Ok(()) + } +} + +#[test] +fn schema_discovery_and_validation() { + let discovery = ServeSchemaTestService::::discover(); + assert_eq!(discovery.name.to_string(), "SchemaTestService"); + assert_eq!(discovery.handlers.len(), 5); + + for handler in &discovery.handlers { + let input = handler + .input + .as_ref() + .expect("Handler should have input schema"); + let output = handler + .output + .as_ref() + .expect("Handler should have output schema"); + + match handler.name.to_string().as_str() { + "string_handler" | "json_handler" | "complex_handler" | "empty_output_handler" => { + let input_schema = input + .json_schema + .as_ref() + .expect("Input schema should exist for handlers with input"); + let output_schema = output.json_schema.as_ref(); + + match handler.name.to_string().as_str() { + "string_handler" => { + assert_eq!( + input_schema.get("type").and_then(|v| v.as_str()), + Some("string") + ); + assert!(output_schema.is_some()); + assert_eq!( + output_schema.unwrap().get("type").and_then(|v| v.as_str()), + Some("integer") + ); + } + "json_handler" => { + #[cfg(feature = "schemars")] + { + let obj = input_schema + .as_object() + .expect("Schema should be an object"); + assert!( + obj.contains_key("properties"), + "Json schema should have properties" + ); + assert!(obj["properties"]["name"]["type"] == "string"); + assert!(obj["properties"]["age"]["type"] == "integer"); + } + #[cfg(not(feature = "schemars"))] + assert_eq!(input_schema, &serde_json::json!({})); + } + "complex_handler" => { + #[cfg(feature = "schemars")] + { + let obj = input_schema + .as_object() + .expect("Schema should be an object"); + assert!(obj.contains_key("properties") || obj.contains_key("$ref")); + let props = obj.get("properties").or_else(|| obj.get("$ref")).unwrap(); + assert!(props.is_object(), "Complex schema should define structure"); + } + #[cfg(not(feature = "schemars"))] + assert_eq!(input_schema, &serde_json::json!({})); + } + "empty_output_handler" => { + assert_eq!( + input_schema.get("type").and_then(|v| v.as_str()), + Some("string") + ); + // For empty output handler, we don't expect json_schema to be set in output + assert!( + output_schema.is_none(), + "Empty output handler should have json_schema set to None" + ); + // Verify that set_content_type_if_empty is set + assert_eq!(output.set_content_type_if_empty, Some(false)); + } + _ => unreachable!("Unexpected handler"), + } + } + "no_input_handler" => { + // For no_input_handler, we don't expect json_schema to be set + assert!( + input.json_schema.is_none(), + "No input handler should have json_schema set to None" + ); + + let output_schema = output + .json_schema + .as_ref() + .expect("Output schema should exist"); + + assert_eq!( + output_schema.get("type").and_then(|v| v.as_str()), + Some("string") + ); + } + _ => unreachable!("Unexpected handler"), + } + } +} + +#[test] +fn schema_generation() { + let string_schema = ::json_schema().unwrap(); + assert_eq!(string_schema["type"], "string"); + + let json_schema = as PayloadMetadata>::json_schema().unwrap(); + #[cfg(feature = "schemars")] + assert!(json_schema["properties"]["name"]["type"] == "string"); + #[cfg(not(feature = "schemars"))] + assert_eq!(json_schema, serde_json::json!({})); +}