diff --git a/.cargo/config.toml b/.cargo/config.toml index 30944a2..b6389ea 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,4 +1,6 @@ [alias] + +coverage = ["tarpaulin"] run_clippy = [ "clippy", "--features", diff --git a/.gitignore b/.gitignore index 02cbf32..036f2a6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,9 @@ settings.json + # test artifacts /coverage +.coverage *.lcov *.profraw \ No newline at end of file diff --git a/src/generated_schema/2024_11_05/schema_utils.rs b/src/generated_schema/2024_11_05/schema_utils.rs index 7fd1fcf..97d374b 100644 --- a/src/generated_schema/2024_11_05/schema_utils.rs +++ b/src/generated_schema/2024_11_05/schema_utils.rs @@ -1223,5 +1223,9 @@ mod tests { ); let result = detect_message_type(&json!(message)); assert!(matches!(result, MessageTypes::Error)); + + // default + let result = detect_message_type(&json!({})); + assert!(matches!(result, MessageTypes::Request)); } } diff --git a/tarpaulin.toml b/tarpaulin.toml new file mode 100644 index 0000000..dd03f74 --- /dev/null +++ b/tarpaulin.toml @@ -0,0 +1,13 @@ +[version_2024_11_05] +no-default-features = true +features = "2024_11_05 schema_utils" +# release = true + +[version_draft] +no-default-features = true +features = "draft schema_utils" +# release = true + +[report] +output-dir = "./.coverage" +out = ["Html", "Xml"] diff --git a/tests/common/sample_mcp_messages.json b/tests/common/sample_mcp_messages.json index b380fc0..5e94ede 100644 --- a/tests/common/sample_mcp_messages.json +++ b/tests/common/sample_mcp_messages.json @@ -57,7 +57,7 @@ "req_tools_call_4": {"jsonrpc":"2.0","id":18,"method":"tools/call","params":{"_meta":{"progressToken":3},"name":"sampleLLM","arguments":{"a":15,"b":21,"prompt":"my prompt","maxTokens":5}}}, /* ServerRequest::CreateMessageRequest */ "req_sampling_create_message_2": {"method":"sampling/createMessage","params":{"messages":[{"role":"user","content":{"type":"text","text":"Resource sampleLLM context: my prompt"}}],"systemPrompt":"You are a helpful test server.","maxTokens":5,"temperature":0.7,"includeContext":"thisServer"},"jsonrpc":"2.0","id":1}, - /* ServerRequest::CreateMessageRequest */ + /* ClientResult::CreateMessageResult */ "res_sampling_create_message_2": {"jsonrpc":"2.0","id":1,"result":{"model":"stub-model","stopReason":"endTurn","role":"assistant","content":{"type":"text","text":"This is a stub response."}}}, /* ServerResult::CallToolResult */ "res_tools_call_4": {"result":{"content":[{"type":"text","text":"LLM sampling result: This is a stub response."}]},"jsonrpc":"2.0","id":18}, diff --git a/tests/test_deserialize.rs b/tests/test_deserialize.rs index ab49ec3..520afe4 100644 --- a/tests/test_deserialize.rs +++ b/tests/test_deserialize.rs @@ -2,10 +2,8 @@ pub mod common; mod test_deserialize { - use rust_mcp_schema::{ - schema_utils::*, ClientNotification, ClientRequest, RequestId, ServerRequest, ServerResult, JSONRPC_VERSION, - LATEST_PROTOCOL_VERSION, - }; + use rust_mcp_schema::schema_utils::*; + use rust_mcp_schema::*; use super::common::get_message; @@ -126,6 +124,16 @@ mod test_deserialize { )); } + /* ---------------------- CLIENT RESPONSES ---------------------- */ + #[test] + fn test_list_tools_result() { + let message = get_message("res_sampling_create_message_2"); + assert!(matches!(message, ClientMessage::Response(client_message) + if matches!(&client_message.result, ResultFromClient::ClientResult(client_result) + if matches!( client_result, ClientResult::CreateMessageResult(_)) + ) + )); + } /* ---------------------- SERVER RESPONSES ---------------------- */ #[test] diff --git a/tests/test_serialize.rs b/tests/test_serialize.rs index 6dde39d..bb97db9 100644 --- a/tests/test_serialize.rs +++ b/tests/test_serialize.rs @@ -27,9 +27,11 @@ mod test_serialize { protocol_version: LATEST_PROTOCOL_VERSION.to_string(), }); + let client_request = ClientRequest::InitializeRequest(request); + let message: ClientMessage = ClientMessage::Request(ClientJsonrpcRequest::new( RequestId::Integer(15), - RequestFromClient::ClientRequest(ClientRequest::InitializeRequest(request)), + RequestFromClient::ClientRequest(client_request.clone()), )); let message: ClientMessage = re_serialize(message); @@ -38,6 +40,17 @@ mod test_serialize { if matches!(&client_message.request, RequestFromClient::ClientRequest(client_request) if matches!(client_request, ClientRequest::InitializeRequest(_))) )); + + // test From for RequestFromClient + let message: ClientMessage = + ClientMessage::Request(ClientJsonrpcRequest::new(RequestId::Integer(15), client_request.into())); + + let message: ClientMessage = re_serialize(message); + + assert!(matches!(message, ClientMessage::Request(client_message) + if matches!(&client_message.request, RequestFromClient::ClientRequest(client_request) + if matches!(client_request, ClientRequest::InitializeRequest(_))) + )); } #[test] @@ -158,6 +171,67 @@ mod test_serialize { )); } + #[test] + fn test_client_custom_request() { + let message: ClientMessage = ClientMessage::Request(ClientJsonrpcRequest::new( + RequestId::Integer(15), + RequestFromClient::CustomRequest(json!({"method":"my_custom_method"})), + )); + + let message: ClientMessage = re_serialize(message); + + assert!(matches!(message, ClientMessage::Request(client_message) + if matches!(&client_message.request, RequestFromClient::CustomRequest(_)) && client_message.method == "my_custom_method" + )); + + // test From for RequestFromClient + let message: ClientMessage = ClientMessage::Request(ClientJsonrpcRequest::new( + RequestId::Integer(15), + json!({"method":"my_custom_method"}).into(), + )); + + let message: ClientMessage = re_serialize(message); + + assert!(matches!(message, ClientMessage::Request(client_message) + if matches!(&client_message.request, RequestFromClient::CustomRequest(_)) && client_message.method == "my_custom_method" + )); + } + + /* ---------------------- CLIENT RESPONSES ---------------------- */ + #[test] + fn test_list_tools_result() { + let client_result = ClientResult::CreateMessageResult(CreateMessageResult { + content: CreateMessageResultContent::TextContent(TextContent::new(None, "This is a stub response.".to_string())), + meta: None, + model: "stub-model".to_string(), + role: Role::Assistant, + stop_reason: Some("endTurn".to_string()), + }); + + let message: ClientMessage = ClientMessage::Response(ClientJsonrpcResponse::new( + RequestId::Integer(15), + ResultFromClient::ClientResult(client_result.clone()), + )); + + let message: ClientMessage = re_serialize(message); + + assert!(matches!(message, ClientMessage::Response(client_message) + if matches!(&client_message.result, ResultFromClient::ClientResult(client_result) + if matches!( client_result, ClientResult::CreateMessageResult(_)) + ) + )); + + // test From for ResultFromClient + let message: ClientMessage = + ClientMessage::Response(ClientJsonrpcResponse::new(RequestId::Integer(15), client_result.into())); + + assert!(matches!(message, ClientMessage::Response(client_message) + if matches!(&client_message.result, ResultFromClient::ClientResult(client_result) + if matches!( client_result, ClientResult::CreateMessageResult(_)) + ) + )); + } + /* ---------------------- SERVER RESPONSES ---------------------- */ #[test] @@ -348,6 +422,52 @@ mod test_serialize { if matches!( client_notification, ClientNotification::CancelledNotification(notification) if notification.params.reason == Some("Request timed out".to_string()))) )); } + + #[test] + fn test_client_custom_notification() { + let message: ClientMessage = ClientMessage::Notification(ClientJsonrpcNotification::new( + NotificationFromClient::CustomNotification(json!({"method":"my_notification"})), + )); + + let message: ClientMessage = re_serialize(message); + + // test Display trait + let str = message.to_string(); + assert_eq!(str, "{\"jsonrpc\":\"2.0\",\"method\":\"my_notification\",\"params\":{\"method\":\"my_notification\",\"params\":{\"method\":\"my_notification\"}}}"); + + assert!(matches!(message, ClientMessage::Notification(client_message) + if matches!(&client_message.notification, NotificationFromClient::CustomNotification(_)) && client_message.method == "my_notification" + )); + } + + /* ---------------------- SERVER NOTIFICATIONS ---------------------- */ + #[test] + fn test_server_cancel_notification() { + let cancel_notification = CancelledNotification::new(CancelledNotificationParams { + reason: Some("Request timed out".to_string()), + request_id: RequestId::Integer(15), + }); + let message: ServerMessage = + ServerMessage::Notification(ServerJsonrpcNotification::new(NotificationFromServer::ServerNotification( + ServerNotification::CancelledNotification(cancel_notification.clone()), + ))); + + let message: ServerMessage = re_serialize(message); + + assert!(matches!(message, ServerMessage::Notification(client_message) + if matches!(&client_message.notification,NotificationFromServer::ServerNotification(client_notification) + if matches!( client_notification, ServerNotification::CancelledNotification(_))) + )); + + // test From for NotificationFromServer + let message: ServerMessage = ServerMessage::Notification(ServerJsonrpcNotification::new(cancel_notification.into())); + + assert!(matches!(message, ServerMessage::Notification(client_message) + if matches!(&client_message.notification,NotificationFromServer::ServerNotification(client_notification) + if matches!( client_notification, ServerNotification::CancelledNotification(_))) + )); + } + /* ---------------------- SERVER REQUESTS ---------------------- */ #[test] fn test_server_requests() { @@ -375,6 +495,27 @@ mod test_serialize { )); } + #[test] + fn test_client_custom_server_request() { + let message: ServerMessage = ServerMessage::Request(ServerJsonrpcRequest::new( + RequestId::Integer(15), + RequestFromServer::CustomRequest(json!({"method":"my_custom_method"})), + )); + + // test Display trait + let str = message.to_string(); + assert_eq!( + str, + "{\"id\":15,\"jsonrpc\":\"2.0\",\"method\":\"my_custom_method\",\"params\":{\"method\":\"my_custom_method\"}}" + ); + + let message: ServerMessage = re_serialize(message); + + assert!(matches!(message, ServerMessage::Request(server_message) + if matches!(&server_message.request, RequestFromServer::CustomRequest(_)) && server_message.method == "my_custom_method" + )); + } + /* ---------------------- CLIENT & SERVER ERRORS ---------------------- */ #[test]