Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions packages/edge/infra/guard/core/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,98 @@ impl TestServer {
handle: Some(handle),
}
}

// Create a TestServer with a specific server address and custom handler
pub async fn with_handler_and_addr<F, Fut>(addr: SocketAddr, handler: F) -> Self
where
F: Fn(Request<hyper::body::Incoming>, Arc<Mutex<Vec<TestRequest>>>) -> Fut
+ Send
+ 'static
+ Clone,
Fut: Future<Output = Result<Response<Full<Bytes>>, std::convert::Infallible>> + Send,
{
// Create a server bound to the specific address
let listener = TcpListener::bind(addr).await.unwrap();
let request_log = Arc::new(Mutex::new(Vec::new()));
let request_log_clone = request_log.clone();

let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

// Start the server with the custom handler
let handle = tokio::spawn(async move {
let mut shutdown_rx = shutdown_rx;

loop {
// Use select to check for shutdown signal
let accept_fut = listener.accept();
let accept_or_shutdown = tokio::select! {
result = accept_fut => Some(result),
_ = &mut shutdown_rx => None,
};

// Break the loop if shutdown was requested
let (stream, _) = match accept_or_shutdown {
Some(Ok(value)) => value,
Some(Err(_)) => break,
None => break,
};

let io = TokioIo::new(stream);
let request_log = request_log_clone.clone();
let handler = handler.clone();

tokio::spawn(async move {
// Create a service function for this connection
let service = service_fn(move |req: Request<hyper::body::Incoming>| {
// Clone these for the async move block
let request_log = request_log.clone();
let handler = handler.clone();

async move {
// Capture request details
let method = req.method().to_string();
let uri = req.uri().to_string();

// Extract headers
let mut headers = HashMap::new();
for (name, value) in req.headers() {
if let Ok(v) = value.to_str() {
headers.insert(name.to_string(), v.to_string());
}
}

// Store request for later inspection
let test_req = TestRequest {
method,
uri,
headers,
body: Vec::new(), // Body will be consumed by handler
};

request_log.lock().unwrap().push(test_req);

// Call the custom handler
handler(req, request_log.clone()).await
}
});

if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
eprintln!("Error serving connection: {:?}", err);
}
});
}
});

Self {
addr,
request_log,
shutdown_tx: Some(shutdown_tx),
handle: Some(handle),
}
}
Comment on lines +302 to +392
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Duplicate logic; similar to with_addr. Consider consolidating common parts for reuse.


// Get the count of requests received
pub fn request_count(&self) -> usize {
Expand Down
Loading
Loading