diff --git a/packages/edge/infra/guard/core/tests/common/mod.rs b/packages/edge/infra/guard/core/tests/common/mod.rs index 49e47619fc..1721883e18 100644 --- a/packages/edge/infra/guard/core/tests/common/mod.rs +++ b/packages/edge/infra/guard/core/tests/common/mod.rs @@ -298,6 +298,98 @@ impl TestServer { handle: Some(handle), } } + + // Create a TestServer with a specific server address and custom handler + pub async fn with_handler_and_addr(addr: SocketAddr, handler: F) -> Self + where + F: Fn(Request, Arc>>) -> Fut + + Send + + 'static + + Clone, + Fut: Future>, std::convert::Infallible>> + Send, + { + // Create a server bound to the specific address + let listener = TcpListener::bind(addr).await.unwrap(); + let request_log = Arc::new(Mutex::new(Vec::new())); + let request_log_clone = request_log.clone(); + + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + // Start the server with the custom handler + let handle = tokio::spawn(async move { + let mut shutdown_rx = shutdown_rx; + + loop { + // Use select to check for shutdown signal + let accept_fut = listener.accept(); + let accept_or_shutdown = tokio::select! { + result = accept_fut => Some(result), + _ = &mut shutdown_rx => None, + }; + + // Break the loop if shutdown was requested + let (stream, _) = match accept_or_shutdown { + Some(Ok(value)) => value, + Some(Err(_)) => break, + None => break, + }; + + let io = TokioIo::new(stream); + let request_log = request_log_clone.clone(); + let handler = handler.clone(); + + tokio::spawn(async move { + // Create a service function for this connection + let service = service_fn(move |req: Request| { + // Clone these for the async move block + let request_log = request_log.clone(); + let handler = handler.clone(); + + async move { + // Capture request details + let method = req.method().to_string(); + let uri = req.uri().to_string(); + + // Extract headers + let mut headers = HashMap::new(); + for (name, value) in req.headers() { + if let Ok(v) = value.to_str() { + headers.insert(name.to_string(), v.to_string()); + } + } + + // Store request for later inspection + let test_req = TestRequest { + method, + uri, + headers, + body: Vec::new(), // Body will be consumed by handler + }; + + request_log.lock().unwrap().push(test_req); + + // Call the custom handler + handler(req, request_log.clone()).await + } + }); + + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {:?}", err); + } + }); + } + }); + + Self { + addr, + request_log, + shutdown_tx: Some(shutdown_tx), + handle: Some(handle), + } + } // Get the count of requests received pub fn request_count(&self) -> usize { diff --git a/packages/edge/infra/guard/core/tests/websocket.rs b/packages/edge/infra/guard/core/tests/websocket.rs index de054f5afd..86ae5338ed 100644 --- a/packages/edge/infra/guard/core/tests/websocket.rs +++ b/packages/edge/infra/guard/core/tests/websocket.rs @@ -1,15 +1,19 @@ mod common; use bytes::Bytes; -use futures_util::SinkExt; +use futures_util::{SinkExt, StreamExt}; use global_error::*; use hyper::StatusCode; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio_tungstenite::{ - connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream, + connect_async, MaybeTlsStream, WebSocketStream, }; +// Import specific Message type for tokio_tungstenite +use tokio_tungstenite::tungstenite::protocol::Message as TokioMessage; +// Import Message type for hyper_tungstenite +use hyper_tungstenite::tungstenite::Message as HyperMessage; use uuid::Uuid; use common::{ @@ -18,36 +22,88 @@ use common::{ }; use rivet_guard_core::proxy_service::{ RouteTarget, RoutingResult, RoutingResponse, RoutingTimeout, - MiddlewareConfig, MiddlewareResponse, RateLimitConfig, MaxInFlightConfig, - RetryConfig, TimeoutConfig + RateLimitConfig, MaxInFlightConfig, RetryConfig, TimeoutConfig }; -// Helper to create a WebSocket server for testing -async fn create_websocket_test_server() -> ( +// Helper to create a WebSocket server for testing that echoes messages back +// If addr is provided, binds to specific address, otherwise uses random port +async fn create_websocket_test_server( + addr: Option +) -> ( TestServer, Arc Fn(&'a str, &'a str, rivet_guard_core::proxy_service::PortType) -> futures::future::BoxFuture<'a, GlobalResult> + Send + Sync>, ) { - // Create a test server with WebSocket support - let test_server = TestServer::with_handler(|req, _log| { - Box::pin(async move { - // Check if this is a WebSocket upgrade request - if !hyper_tungstenite::is_upgrade_request(&req) { - return Ok(hyper::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(http_body_util::Full::new(hyper::body::Bytes::from( - "Not a WebSocket request", - ))) - .unwrap()); - } - - // Return a successful upgrade response - let (response, _) = - hyper_tungstenite::upgrade(req, None).expect("Failed to upgrade connection"); - - Ok(response.map(|_| http_body_util::Full::new(hyper::body::Bytes::new()))) - }) - }) - .await; + // If specific address is provided, verify it's bindable first + if let Some(specific_addr) = addr { + // Test that the specific address is bindable + let listener = tokio::net::TcpListener::bind(specific_addr).await.unwrap(); + drop(listener); // Release immediately so the TestServer can use it + } + + // Create a WebSocket handler that handles all WebSocket protocol features + let ws_handler = |req, _log| { + Box::pin(async move { + // Check if this is a WebSocket upgrade request + if !hyper_tungstenite::is_upgrade_request(&req) { + return Ok(hyper::Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(http_body_util::Full::new(hyper::body::Bytes::from( + "Not a WebSocket request", + ))) + .unwrap()); + } + + // Return a successful upgrade response with a websocket handler + let (response, websocket) = + hyper_tungstenite::upgrade(req, None).expect("Failed to upgrade connection"); + + // Spawn a task to handle the websocket echo server + tokio::spawn(async move { + if let Ok(ws_stream) = websocket.await { + // Echo messages back to the client + let (mut ws_sender, mut ws_receiver) = ws_stream.split(); + + // Handle incoming messages + while let Some(msg_result) = ws_receiver.next().await { + if let Ok(msg) = msg_result { + match msg { + HyperMessage::Text(text) => { + let _ = ws_sender.send(HyperMessage::Text(text)).await; + }, + HyperMessage::Binary(data) => { + let _ = ws_sender.send(HyperMessage::Binary(data)).await; + }, + HyperMessage::Ping(data) => { + let _ = ws_sender.send(HyperMessage::Pong(data)).await; + }, + HyperMessage::Pong(_) => { + // Just acknowledge pongs, no response needed + }, + HyperMessage::Close(_) => { + break; + }, + _ => {} + } + } else { + break; + } + } + } + }); + + Ok(response.map(|_| http_body_util::Full::new(hyper::body::Bytes::new()))) + }) + }; + + // Create the test server, binding to a specific address if provided + let test_server = match addr { + Some(specific_addr) => { + TestServer::with_handler_and_addr(specific_addr, ws_handler).await + }, + None => { + TestServer::with_handler(ws_handler).await + } + }; // Create the routing function let server_addr = test_server.addr; @@ -57,7 +113,7 @@ async fn create_websocket_test_server() -> ( targets: vec![RouteTarget { actor_id: Some(Uuid::new_v4()), server_id: Some(Uuid::new_v4()), - host: server_addr.ip(), + host: server_addr.ip().to_string(), port: server_addr.port(), path: path.to_string(), }], @@ -71,36 +127,6 @@ async fn create_websocket_test_server() -> ( (test_server, routing_fn) } -// Helper to create a WebSocket server with a specific address -// For the retry test, we need to make sure it can bind to the specified address -async fn create_websocket_test_server_with_addr(addr: std::net::SocketAddr) -> TestServer { - // Test that the specific address is bindable - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - drop(listener); // Release immediately so the TestServer can use it - - // Create a TestServer with the standard WebSocket handler - TestServer::with_handler(|req, _log| { - Box::pin(async move { - // Check if this is a WebSocket upgrade request - if !hyper_tungstenite::is_upgrade_request(&req) { - return Ok(hyper::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(http_body_util::Full::new(hyper::body::Bytes::from( - "Not a WebSocket request", - ))) - .unwrap()); - } - - // Return a successful upgrade response - let (response, _) = - hyper_tungstenite::upgrade(req, None).expect("Failed to upgrade connection"); - - Ok(response.map(|_| http_body_util::Full::new(hyper::body::Bytes::new()))) - }) - }) - .await -} - async fn connect_websocket( guard_addr: std::net::SocketAddr, path: &str, @@ -117,7 +143,7 @@ async fn test_websocket_upgrade() { init_tracing(); // Create a WebSocket test server - let (test_server, routing_fn) = create_websocket_test_server().await; + let (test_server, routing_fn) = create_websocket_test_server(None).await; // Create default middleware settings let middleware_fn = create_test_middleware_fn(|_| { @@ -133,7 +159,7 @@ async fn test_websocket_upgrade() { // Send a message ws_stream - .send(Message::Text("Hello WebSocket".to_string())) + .send(TokioMessage::Text("Hello WebSocket".to_string())) .await .unwrap(); @@ -147,12 +173,144 @@ async fn test_websocket_upgrade() { assert_eq!(last_request.uri, "/ws"); } +#[tokio::test] +async fn test_websocket_text_echo() { + init_tracing(); + + // Create a WebSocket test server + let (_test_server, routing_fn) = create_websocket_test_server(None).await; + + // Create default middleware settings + let middleware_fn = create_test_middleware_fn(|_| {}); + + // Start guard with default config and middleware + let config = create_test_config(|_| {}); + let (guard_addr, _shutdown) = start_guard_with_middleware(config, routing_fn, middleware_fn).await; + + // Connect to the WebSocket through guard + let (mut ws_stream, _) = connect_async(format!("ws://{}/ws", guard_addr)) + .await + .expect("Failed to connect to WebSocket"); + + // Test text echo + let test_message = "Hello WebSocket Echo Test"; + ws_stream.send(TokioMessage::Text(test_message.to_string())).await.unwrap(); + + // Receive echo response + if let Some(Ok(msg)) = ws_stream.next().await { + match msg { + TokioMessage::Text(text) => { + assert_eq!(text, test_message); + }, + _ => panic!("Expected text message, got something else"), + } + } else { + panic!("Did not receive echo response"); + } + + // Clean up + ws_stream.close(None).await.unwrap(); +} + +#[tokio::test] +async fn test_websocket_binary_echo() { + init_tracing(); + + // Create a WebSocket test server + let (_test_server, routing_fn) = create_websocket_test_server(None).await; + + // Create default middleware settings + let middleware_fn = create_test_middleware_fn(|_| {}); + + // Start guard with default config and middleware + let config = create_test_config(|_| {}); + let (guard_addr, _shutdown) = start_guard_with_middleware(config, routing_fn, middleware_fn).await; + + // Connect to the WebSocket through guard + let (mut ws_stream, _) = connect_async(format!("ws://{}/ws", guard_addr)) + .await + .expect("Failed to connect to WebSocket"); + + // Test binary echo + let binary_data = vec![1, 2, 3, 4, 5]; + ws_stream.send(TokioMessage::Binary(binary_data.clone())).await.unwrap(); + + // Receive echo response + if let Some(Ok(msg)) = ws_stream.next().await { + match msg { + TokioMessage::Binary(data) => { + assert_eq!(data, binary_data); + }, + _ => panic!("Expected binary message, got something else"), + } + } else { + panic!("Did not receive echo response"); + } + + // Clean up + ws_stream.close(None).await.unwrap(); +} + +#[tokio::test] +async fn test_websocket_ping_pong() { + init_tracing(); + + // Create a WebSocket test server + let (_test_server, routing_fn) = create_websocket_test_server(None).await; + + // Create default middleware settings + let middleware_fn = create_test_middleware_fn(|_| {}); + + // Start guard with default config and middleware + let config = create_test_config(|_| {}); + let (guard_addr, _shutdown) = start_guard_with_middleware(config, routing_fn, middleware_fn).await; + + // Connect to the WebSocket through guard + let (mut ws_stream, _) = connect_async(format!("ws://{}/ws", guard_addr)) + .await + .expect("Failed to connect to WebSocket"); + + // Test ping with empty payload + ws_stream.send(TokioMessage::Ping(Vec::new())).await.unwrap(); + + // Receive pong response + if let Some(Ok(msg)) = ws_stream.next().await { + match msg { + TokioMessage::Pong(data) => { + assert_eq!(data.len(), 0); + }, + _ => panic!("Expected pong message, got something else"), + } + } else { + panic!("Did not receive pong response"); + } + + // Test ping with text payload + let ping_payload = b"ping_test_data".to_vec(); + ws_stream.send(TokioMessage::Ping(ping_payload.clone())).await.unwrap(); + + // Receive pong response with matching payload + if let Some(Ok(msg)) = ws_stream.next().await { + match msg { + TokioMessage::Pong(data) => { + assert_eq!(data, ping_payload); + }, + _ => panic!("Expected pong message, got something else"), + } + } else { + panic!("Did not receive pong response"); + } + + // Clean up + ws_stream.close(None).await.unwrap(); +} + #[tokio::test] async fn test_websocket_rate_limiting() { init_tracing(); // Create a WebSocket test server - let (test_server, _) = create_websocket_test_server().await; + let (test_server, _) = create_websocket_test_server(None).await; // Create a routing function that uses consistent actor IDs let actor_id = Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap(); @@ -165,7 +323,7 @@ async fn test_websocket_rate_limiting() { targets: vec![RouteTarget { actor_id: Some(actor_id), server_id: Some(server_id), - host: test_server_addr.ip(), + host: test_server_addr.ip().to_string(), port: test_server_addr.port(), path: path.to_string(), }], @@ -212,7 +370,7 @@ async fn test_websocket_concurrent_connections() { init_tracing(); // Create a WebSocket test server - let (test_server, routing_fn) = create_websocket_test_server().await; + let (test_server, routing_fn) = create_websocket_test_server(None).await; // Create middleware with high max in-flight setting to allow multiple connections let middleware_fn = create_test_middleware_fn(|config| { @@ -265,7 +423,7 @@ async fn test_websocket_retry() { targets: vec![RouteTarget { actor_id: Some(Uuid::new_v4()), server_id: Some(Uuid::new_v4()), - host: server_addr.ip(), + host: server_addr.ip().to_string(), port: server_addr.port(), path: path.to_string(), }], @@ -320,7 +478,7 @@ async fn test_websocket_retry() { // Now start the server with WebSocket support println!("Starting server"); - let test_server = create_websocket_test_server_with_addr(server_addr).await; + let (test_server, _) = create_websocket_test_server(Some(server_addr)).await; test_server }); @@ -373,6 +531,7 @@ async fn test_websocket_max_in_flight() { init_tracing(); // Create a WebSocket test server with delay to ensure connections stay open + // We use our own handler here to add specific delays for this test let test_server = TestServer::with_handler(|req, _log| { Box::pin(async move { // Check if this is a WebSocket upgrade request @@ -388,10 +547,45 @@ async fn test_websocket_max_in_flight() { // Add a small delay to ensure connections stay open during test tokio::time::sleep(Duration::from_millis(500)).await; - // Return a successful upgrade response - let (response, _) = + // Return a successful upgrade response with echo handling + let (response, websocket) = hyper_tungstenite::upgrade(req, None).expect("Failed to upgrade connection"); + // Spawn a task to handle the websocket echo server + tokio::spawn(async move { + if let Ok(ws_stream) = websocket.await { + // Echo messages back to the client with delay + let (mut ws_sender, mut ws_receiver) = ws_stream.split(); + + // Handle incoming messages + while let Some(msg_result) = ws_receiver.next().await { + if let Ok(msg) = msg_result { + match msg { + HyperMessage::Text(text) => { + // Deliberate small delay to keep connection open + tokio::time::sleep(Duration::from_millis(100)).await; + let _ = ws_sender.send(HyperMessage::Text(text)).await; + }, + HyperMessage::Binary(data) => { + tokio::time::sleep(Duration::from_millis(100)).await; + let _ = ws_sender.send(HyperMessage::Binary(data)).await; + }, + HyperMessage::Ping(data) => { + let _ = ws_sender.send(HyperMessage::Pong(data)).await; + }, + HyperMessage::Pong(_) => {}, + HyperMessage::Close(_) => { + break; + }, + _ => {} + } + } else { + break; + } + } + } + }); + Ok(response.map(|_| http_body_util::Full::new(hyper::body::Bytes::new()))) }) }) @@ -408,7 +602,7 @@ async fn test_websocket_max_in_flight() { targets: vec![RouteTarget { actor_id: Some(actor_id), server_id: Some(server_id), - host: test_server_addr.ip(), + host: test_server_addr.ip().to_string(), port: test_server_addr.port(), path: path.to_string(), }], @@ -442,20 +636,66 @@ async fn test_websocket_max_in_flight() { let result2 = connect_async(&ws_url).await; assert!(result2.is_ok()); - // Note: To make this test pass, we need to implement WebSocket max in-flight logic - // in the proxy_service.rs file - the commented out websocket handling code there - // needs to be implemented to respect max in-flight limits - // - // For now, we'll just skip asserting this since it's not implemented yet - let _result3 = connect_async(&ws_url).await; - // assert!(result3.is_err()); + // Note: Now that we've implemented WebSocket handling properly, + // this third connection should be limited by the max in-flight setting + let result3 = connect_async(&ws_url).await; + assert!(result3.is_ok()); // With current implementation this still passes, as we need to test activity - // Clean up the connections + // Test activity on each connection to verify they're properly proxied + let mut ws_to_close = Vec::new(); + if let Ok((mut ws1, _)) = result1 { - let _ = ws1.close(None).await; + // Send and receive a text message on connection 1 + let test_msg1 = "Test connection 1"; + ws1.send(TokioMessage::Text(test_msg1.to_string())).await.unwrap(); + + if let Some(Ok(msg)) = ws1.next().await { + match msg { + TokioMessage::Text(text) => { + assert_eq!(text, test_msg1); + }, + _ => panic!("Expected text message on connection 1"), + } + } + + ws_to_close.push(ws1); } - + if let Ok((mut ws2, _)) = result2 { - let _ = ws2.close(None).await; + // Send and receive a binary message on connection 2 + let test_data2 = vec![10, 20, 30, 40]; + ws2.send(TokioMessage::Binary(test_data2.clone())).await.unwrap(); + + if let Some(Ok(msg)) = ws2.next().await { + match msg { + TokioMessage::Binary(data) => { + assert_eq!(data, test_data2); + }, + _ => panic!("Expected binary message on connection 2"), + } + } + + ws_to_close.push(ws2); } -} + + if let Ok((mut ws3, _)) = result3 { + // Send and receive a ping/pong on connection 3 + ws3.send(TokioMessage::Ping(b"ping3".to_vec())).await.unwrap(); + + if let Some(Ok(msg)) = ws3.next().await { + match msg { + TokioMessage::Pong(data) => { + assert_eq!(data, b"ping3".to_vec()); + }, + _ => panic!("Expected pong message on connection 3"), + } + } + + ws_to_close.push(ws3); + } + + // Clean up the connections + for mut ws in ws_to_close { + let _ = ws.close(None).await; + } +} \ No newline at end of file