diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f512f21..28d061b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -67,7 +67,7 @@ jobs: fail-fast: false matrix: settings: - - host: macos-latest + - host: macos-13 target: x86_64-apple-darwin build: pnpm run build --target x86_64-apple-darwin - host: macos-latest @@ -185,7 +185,6 @@ jobs: - '3.11' - '3.12' - '3.13' - # - '3.14-rc' runs-on: ${{ matrix.settings.host }} steps: - uses: actions/checkout@v4 @@ -206,6 +205,11 @@ jobs: with: name: bindings-${{ matrix.settings.target }} path: . + - name: Remove old prebuilt binaries + run: | + echo "Removing old prebuilt binaries from node_modules..." + rm -rf node_modules/@platformatic/python-node-* + shell: bash - name: List packages run: ls -R . shell: bash @@ -273,6 +277,11 @@ jobs: with: name: bindings-${{ matrix.settings.target }} path: . + - name: Remove old prebuilt binaries + run: | + echo "Removing old prebuilt binaries from node_modules..." + rm -rf node_modules/@platformatic/python-node-* + shell: bash - name: List packages run: ls -R . shell: bash diff --git a/Cargo.lock b/Cargo.lock index 0d2dd19..cae1b56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,9 +10,9 @@ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] @@ -117,9 +117,9 @@ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "clap" -version = "4.5.50" +version = "4.5.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2cfd7bf8a6017ddaa4e32ffe7403d547790db06bd171c1c53926faab501623" +checksum = "4c26d721170e0295f191a69bd9a1f93efcdb0aff38684b61ab5750468972e5f5" dependencies = [ "clap_builder", "clap_derive", @@ -127,9 +127,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.50" +version = "4.5.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4c05b9e80c5ccd3a7ef080ad7b6ba7d6fc00a985b8b157197075677c82c7a0" +checksum = "75835f0c7bf681bfd05abe44e965760fea999a5286c6eb2d59883634fd02011a" dependencies = [ "anstream", "anstyle", @@ -181,9 +181,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59c9b8bdf64ee849747c1b12eb861d21aa47fa161564f48332f1afe2373bf899" +checksum = "3ffc71fcdcdb40d6f087edddf7f8f1f8f79e6cf922f555a9ee8779752d4819bd" dependencies = [ "ctor-proc-macro", "dtor", @@ -197,9 +197,9 @@ checksum = "52560adf09603e58c9a7ee1fe1dcb95a16927b17c127f0ac02d6e768a0e25bc1" [[package]] name = "dtor" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e58a0764cddb55ab28955347b45be00ade43d4d6f3ba4bf3dc354e4ec9432934" +checksum = "404d02eeb088a82cfd873006cb713fe411306c7d182c344905e101fb1167d301" dependencies = [ "dtor-proc-macro", ] @@ -225,9 +225,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" +checksum = 
"bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" dependencies = [ "crc32fast", "miniz_oxide", @@ -377,23 +377,50 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + [[package]] name = "http-handler" version = "1.0.0" -source = "git+https://github.com/platformatic/http-handler#f60dbc830e12b1cedae4c8a4f370bc4133a868fc" +source = "git+https://github.com/platformatic/http-handler#f82b8791b8c2149739f63a55980398800fe9291e" dependencies = [ - "async-trait", "bytes", + "futures-core", "http", + "http-body", + "http-body-util", "napi", "napi-build", "napi-derive", + "tokio", + "tokio-util", ] [[package]] name = "http-rewriter" version = "1.0.0" -source = "git+https://github.com/platformatic/http-rewriter#2c2319e6721f20a0eebcb37904d15aa411cc668f" +source = "git+https://github.com/platformatic/http-rewriter#244abaece1bafa3225334030e2edcfe34a500d43" dependencies = [ "bytes", "http", @@ -640,9 +667,9 @@ checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "proc-macro2" -version = "1.0.101" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" dependencies = [ "unicode-ident", ] @@ -727,6 +754,8 @@ version = "1.0.0" dependencies = [ "async-trait", "bytes", + "http", + "http-body-util", "http-handler", "http-rewriter", "libc", @@ -890,9 +919,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "syn" -version = "2.0.107" +version = "2.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a26dbd934e5451d21ef060c018dae56fc073894c5a7896f882928a76e6d081b" +checksum = "2f17c7e013e88258aa9543dcbe81aca68a667a9ac37cd69c9fbc07858bfe0e2f" dependencies = [ "proc-macro2", "quote", @@ -953,6 +982,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-util" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -965,9 +1007,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" [[package]] name = "unicode-segmentation" diff --git a/Cargo.toml b/Cargo.toml index c17217a..ed6fb40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,13 @@ napi-support = ["dep:napi", "dep:napi-derive", "dep:napi-build", "http-handler/n [lib] name = "python_node" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] [dependencies] async-trait = "0.1.88" bytes = "1.10.1" 
+http = "1.0" +http-body-util = "0.1" http-handler = { git = "https://github.com/platformatic/http-handler" } # http-handler = { path = "../http-handler" } http-rewriter = { git = "https://github.com/platformatic/http-rewriter" } diff --git a/src/asgi/event_loop_handle.rs b/src/asgi/event_loop_handle.rs new file mode 100644 index 0000000..396d203 --- /dev/null +++ b/src/asgi/event_loop_handle.rs @@ -0,0 +1,281 @@ +//! Python event loop handle management for Rust async integration. +//! +//! This module provides [`EventLoopHandle`], a type that wraps a Python asyncio +//! event loop and ensures proper cleanup when dropped, along with factory functions +//! for creating and managing the shared event loop. + +use std::sync::{Arc, Mutex, OnceLock, Weak}; + +use pyo3::prelude::*; + +use crate::HandlerError; + +/// Handle to a shared Python event loop. +/// +/// This handle manages a Python asyncio event loop that runs in a background thread. +/// When the last handle is dropped, the event loop is stopped. +/// +/// # Thread Safety +/// +/// This type implements `Send` and `Sync` to allow sharing across threads, though +/// the actual Python event loop runs in its own dedicated thread. +pub struct EventLoopHandle { + event_loop: Py, +} + +impl EventLoopHandle { + /// Create a new EventLoopHandle with a Python event loop. + /// + /// This constructor: + /// 1. Ensures Python is initialized with proper symbol visibility + /// 2. Creates a new Python asyncio event loop + /// 3. Starts a background thread to run the event loop + /// 4. Returns a handle to the event loop + /// + /// # Errors + /// + /// Returns `HandlerError` if: + /// - Python initialization fails + /// - Creating the event loop fails + /// - Starting the event loop thread fails + pub fn new() -> Result { + // Ensure Python is initialized with proper symbol visibility + crate::asgi::ensure_python_initialized(); + + // Create event loop + let event_loop = Python::attach(|py| -> Result, HandlerError> { + let asyncio = py.import("asyncio")?; + let event_loop = asyncio.call_method0("new_event_loop")?; + let event_loop_py = event_loop.unbind(); + + // Start Python thread that just runs the event loop + let loop_ = event_loop_py.clone_ref(py); + + // Spawn a dedicated std::thread for the Python event loop + // We cannot use tokio's spawn_blocking because: + // 1. It would attach to the current runtime (e.g., test runtime) + // 2. When that runtime drops, it waits for blocking tasks to complete + // 3. But the Python event loop runs forever, causing a deadlock + std::thread::Builder::new() + .name("python-event-loop".to_string()) + .spawn(move || { + Self::loop_thread(loop_); + }) + .expect("Failed to spawn Python event loop thread"); + + Ok(event_loop_py) + })?; + + Ok(Self::with_loop(event_loop)) + } + + /// Create an EventLoopHandle from an existing Python event loop object. + /// + /// # Arguments + /// + /// * `event_loop` - A Python `asyncio.AbstractEventLoop` object + /// + /// # Note + /// + /// The event loop should already be running in a background thread before + /// creating this handle. This handle only manages the lifecycle, it doesn't + /// start the event loop. + pub fn with_loop(event_loop: Py) -> Self { + Self { event_loop } + } + + /// Get or create a shared Python event loop. + /// + /// This method maintains a weak reference to the shared event loop. If the event loop + /// is still alive, it returns a strong reference to it. Otherwise, it creates a new + /// event loop. 
+ /// + /// # Thread Safety + /// + /// This method is thread-safe and uses a mutex to protect concurrent access to + /// the weak reference. + /// + /// # Errors + /// + /// Returns `HandlerError` if: + /// - The mutex is poisoned + /// - Creating a new event loop fails + pub fn get_or_create() -> Result, HandlerError> { + let mut guard = PYTHON_EVENT_LOOP + .get_or_init(|| Mutex::new(Weak::new())) + .lock()?; + + // Try to upgrade the weak reference + if let Some(handle) = guard.upgrade() { + return Ok(handle); + } + + // Create new handle + let new_handle = Arc::new(Self::new()?); + *guard = Arc::downgrade(&new_handle); + + Ok(new_handle) + } + + /// Run a Python event loop forever in the current thread. + /// + /// This function sets the given event loop as the current event loop for the thread + /// and runs it forever. It's intended to be called in a blocking context (e.g., from + /// `tokio::task::spawn_blocking`). + /// + /// # Arguments + /// + /// * `event_loop` - A Python asyncio event loop object to run + /// + /// # Panics + /// + /// If the Python event loop encounters a fatal error, the error is printed to stderr + /// but the function does not panic. + pub fn loop_thread(event_loop: Py) { + Python::attach(|py| { + // Set the event loop for this thread and run it + let asyncio = py.import("asyncio")?; + asyncio.call_method1("set_event_loop", (event_loop.bind(py),))?; + + // Get the current event loop and run it forever + asyncio + .call_method0("get_event_loop")? + .call_method0("run_forever")?; + + Ok::<(), PyErr>(()) + }) + .unwrap_or_else(|e| { + eprintln!("Python event loop thread error: {e}"); + }); + } + + /// Get a reference to the Python event loop object. + /// + /// Returns a reference to the underlying `Py` that represents + /// the Python asyncio event loop. 
+ pub fn event_loop(&self) -> &Py { + &self.event_loop + } +} + +impl Drop for EventLoopHandle { + fn drop(&mut self) { + // Stop the Python event loop when the last handle is dropped + Python::attach(|py| { + if let Err(e) = self.event_loop.bind(py).call_method0("stop") { + eprintln!("Failed to stop Python event loop: {e}"); + } + }); + } +} + +unsafe impl Send for EventLoopHandle {} +unsafe impl Sync for EventLoopHandle {} + +/// Global Python event loop handle storage +static PYTHON_EVENT_LOOP: OnceLock>> = OnceLock::new(); + +#[cfg(test)] +mod tests { + use super::*; + use crate::asgi::ensure_python_initialized; + + fn ensure_test_python() { + ensure_python_initialized(); + } + + #[test] + fn test_event_loop_handle_creation() { + ensure_test_python(); + + Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + let event_loop_py = event_loop.unbind(); + + let _handle = EventLoopHandle::with_loop(event_loop_py); + // Just verify we can create it + }); + } + + #[test] + fn test_event_loop_handle_getter() { + ensure_test_python(); + + Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + let event_loop_py = event_loop.unbind(); + let event_loop_clone = event_loop_py.clone_ref(py); + + let handle = EventLoopHandle::with_loop(event_loop_py); + + // Verify event_loop() returns the same event loop + assert!(handle.event_loop().is(&event_loop_clone)); + }); + } + + #[test] + fn test_event_loop_handle_is_running() { + ensure_test_python(); + + Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + let event_loop_py = event_loop.unbind(); + + let handle = EventLoopHandle::with_loop(event_loop_py); + + // Verify we can check if the loop is running + let is_running: bool = handle + .event_loop() + .bind(py) + .call_method0("is_running") + .unwrap() + .extract() + .unwrap(); + + // Should not be running since we never started it + assert!(!is_running); + }); + } + + #[test] + fn test_event_loop_handle_drop_stops_loop() { + ensure_test_python(); + + let event_loop_py = Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + event_loop.unbind() + }); + + let event_loop_clone = Python::attach(|py| event_loop_py.clone_ref(py)); + + { + let handle = EventLoopHandle::with_loop(event_loop_py); + drop(handle); // Explicitly drop to trigger stop() + } + + // After drop, calling stop again should be idempotent (or at least not crash) + Python::attach(|py| { + // The loop should still be accessible but calling stop again is fine + let result = event_loop_clone.bind(py).call_method0("stop"); + // Either it succeeds (idempotent) or it was already stopped + // We just verify it doesn't panic + let _ = result; + }); + } + + #[test] + fn test_event_loop_handle_send_sync() { + ensure_test_python(); + + // This test verifies that EventLoopHandle implements Send and Sync + fn is_send() {} + fn is_sync() {} + + is_send::(); + is_sync::(); + } +} diff --git a/src/asgi/http.rs b/src/asgi/http.rs index f47ee32..121fe6c 100644 --- a/src/asgi/http.rs +++ b/src/asgi/http.rs @@ -320,26 +320,26 @@ impl<'py> FromPyObject<'py> for HttpSendMessage { let mut headers: Vec<(String, String)> = Vec::new(); if let Ok(headers_list) = headers_py.downcast::() { for item in headers_list.iter() { - if let 
Ok(header_pair) = item.downcast::() { - if header_pair.len() == 2 { - let name = header_pair.get_item(0)?; - let value = header_pair.get_item(1)?; - - // Convert bytes to string - let name_str = if let Ok(bytes) = name.downcast::() { - String::from_utf8_lossy(bytes.as_bytes()).to_string() - } else { - name.extract::()? - }; - - let value_str = if let Ok(bytes) = value.downcast::() { - String::from_utf8_lossy(bytes.as_bytes()).to_string() - } else { - value.extract::()? - }; - - headers.push((name_str, value_str)); - } + if let Ok(header_pair) = item.downcast::() + && header_pair.len() == 2 + { + let name = header_pair.get_item(0)?; + let value = header_pair.get_item(1)?; + + // Convert bytes to string + let name_str = if let Ok(bytes) = name.downcast::() { + String::from_utf8_lossy(bytes.as_bytes()).to_string() + } else { + name.extract::()? + }; + + let value_str = if let Ok(bytes) = value.downcast::() { + String::from_utf8_lossy(bytes.as_bytes()).to_string() + } else { + value.extract::()? + }; + + headers.push((name_str, value_str)); } } } @@ -414,7 +414,7 @@ mod tests { .header("authorization", "Bearer token123") .header("user-agent", "test-client/1.0") .header("x-custom-header", "custom-value") - .body(bytes::BytesMut::from("request body")) + .body(http_handler::RequestBody::new()) .unwrap(); // Set socket info extension @@ -482,7 +482,7 @@ mod tests { let request = Builder::new() .method(Method::GET) .uri("/") - .body(bytes::BytesMut::new()) + .body(http_handler::RequestBody::new()) .unwrap(); let scope: HttpConnectionScope = (&request) @@ -509,7 +509,7 @@ mod tests { .method(Method::PUT) .uri("http://api.example.com/resource/123") .version(Version::HTTP_2) - .body(bytes::BytesMut::new()) + .body(http_handler::RequestBody::new()) .unwrap(); let scope: HttpConnectionScope = (&request) diff --git a/src/asgi/lifespan.rs b/src/asgi/lifespan.rs index 61c0490..fd61497 100644 --- a/src/asgi/lifespan.rs +++ b/src/asgi/lifespan.rs @@ -5,6 +5,7 @@ use pyo3::types::PyDict; use crate::asgi::AsgiInfo; /// The lifespan scope exists for the duration of the event loop. 
+#[allow(dead_code)] #[derive(Debug)] pub struct LifespanScope { /// An empty namespace where the application can persist state to be used diff --git a/src/asgi/mod.rs b/src/asgi/mod.rs index 7cf34b8..1c1699e 100644 --- a/src/asgi/mod.rs +++ b/src/asgi/mod.rs @@ -3,50 +3,46 @@ use std::{ ffi::CString, fs::{read_dir, read_to_string}, path::{Path, PathBuf}, - sync::{Arc, Mutex, OnceLock, Weak}, + sync::Arc, }; #[cfg(target_os = "linux")] use std::{ffi::CStr, mem}; use bytes::BytesMut; -use http_handler::{Handler, Request, RequestExt, Response, extensions::DocumentRoot}; +use http_handler::{ + Handler, Request, RequestBody, RequestExt, Response, ResponseException, WebSocketMode, + extensions::DocumentRoot, websocket::WebSocketEncoder, +}; use pyo3::prelude::*; use pyo3::types::PyModule; -use tokio::sync::oneshot; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::{Mutex, oneshot}; use crate::{HandlerError, PythonHandlerTarget}; -/// HTTP response tuple: (status_code, headers, body) -type HttpResponse = (u16, Vec<(String, String)>, Vec); +/// HTTP response tuple: (status_code, headers) +type HttpResponse = (u16, Vec<(String, String)>); /// Result type for HTTP response operations type HttpResponseResult = Result; -/// Global runtime for when no tokio runtime is available -static FALLBACK_RUNTIME: OnceLock = OnceLock::new(); - -fn fallback_handle() -> tokio::runtime::Handle { - tokio::runtime::Handle::try_current().unwrap_or_else(|_| { - // No runtime exists, create a fallback one - let rt = FALLBACK_RUNTIME.get_or_init(|| { - tokio::runtime::Runtime::new().expect("Failed to create fallback tokio runtime") - }); - rt.handle().clone() - }) -} - -/// Global Python event loop handle storage -static PYTHON_EVENT_LOOP: OnceLock>> = OnceLock::new(); - +mod event_loop_handle; mod http; mod http_method; mod http_version; mod info; mod lifespan; +mod python_future_poller; mod receiver; +mod runtime_handle; mod sender; mod websocket; +use event_loop_handle::EventLoopHandle; +use python_future_poller::PythonFuturePoller; + +pub(crate) use runtime_handle::fallback_handle; + pub use http::{HttpConnectionScope, HttpReceiveMessage, HttpSendMessage}; pub use http_method::HttpMethod; pub use http_version::HttpVersion; @@ -60,79 +56,6 @@ pub use websocket::{ WebSocketConnectionScope, WebSocketReceiveMessage, WebSocketSendException, WebSocketSendMessage, }; -/// Handle to a shared Python event loop -pub struct EventLoopHandle { - event_loop: Py, -} - -impl EventLoopHandle { - /// Get the Python event loop object - pub fn event_loop(&self) -> &Py { - &self.event_loop - } -} - -impl Drop for EventLoopHandle { - fn drop(&mut self) { - // Stop the Python event loop when the last handle is dropped - Python::attach(|py| { - if let Err(e) = self.event_loop.bind(py).call_method0("stop") { - eprintln!("Failed to stop Python event loop: {e}"); - } - }); - } -} - -unsafe impl Send for EventLoopHandle {} -unsafe impl Sync for EventLoopHandle {} - -/// Ensure a Python event loop exists and return a handle to it -fn ensure_python_event_loop() -> Result, HandlerError> { - let mut guard = PYTHON_EVENT_LOOP - .get_or_init(|| Mutex::new(Weak::new())) - .lock()?; - - // Try to upgrade the weak reference - if let Some(handle) = guard.upgrade() { - return Ok(handle); - } - - // Create new handle - let new_handle = Arc::new(create_event_loop_handle()?); - *guard = Arc::downgrade(&new_handle); - - Ok(new_handle) -} - -/// Create a new EventLoopHandle with a Python event loop -fn create_event_loop_handle() -> Result { - // 
Ensure Python symbols are globally available before initializing - #[cfg(target_os = "linux")] - ensure_python_symbols_global(); - - // Initialize Python if not already initialized - Python::initialize(); - - // Create event loop - let event_loop = Python::attach(|py| -> Result, HandlerError> { - let asyncio = py.import("asyncio")?; - let event_loop = asyncio.call_method0("new_event_loop")?; - let event_loop_py = event_loop.unbind(); - - // Start Python thread that just runs the event loop - let loop_ = event_loop_py.clone_ref(py); - - // Try to use current runtime, fallback to creating a new one - fallback_handle().spawn_blocking(move || { - start_python_event_loop_thread(loop_); - }); - - Ok(event_loop_py) - })?; - - Ok(EventLoopHandle { event_loop }) -} - /// Core ASGI handler that loads and manages a Python ASGI application pub struct Asgi { docroot: PathBuf, @@ -152,12 +75,13 @@ impl Asgi { app_target: Option, ) -> Result { let target = app_target.unwrap_or_default(); + let docroot = docroot .map(|d| Ok(PathBuf::from(d))) .unwrap_or_else(|| current_dir().map_err(HandlerError::CurrentDirectoryError))?; // Get or create shared Python event loop - let event_loop_handle = ensure_python_event_loop()?; + let event_loop_handle = EventLoopHandle::get_or_create()?; // Load Python app let app_function = Python::attach(|py| -> Result, HandlerError> { @@ -198,117 +122,549 @@ impl Asgi { pub fn docroot(&self) -> &Path { &self.docroot } - - /// Handle a request synchronously - pub fn handle_sync(&self, request: Request) -> Result { - fallback_handle().block_on(self.handle(request)) - } } -#[async_trait::async_trait] -impl Handler for Asgi { - type Error = HandlerError; - - async fn handle(&self, request: Request) -> Result { - // Set document root extension - let mut request = request; - request.set_document_root(DocumentRoot { - path: self.docroot.clone(), - }); +// Helper function: Forward HTTP request data from DuplexStream to Python +// Returns when stream ends or error occurs +async fn forward_http_request( + mut request_stream: R, + rx: tokio::sync::mpsc::UnboundedSender, +) where + R: tokio::io::AsyncRead + Unpin, +{ + const BUFFER_SIZE: usize = 64 * 1024; // 64KB buffer + let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); + loop { + let n = match request_stream.read_buf(&mut buffer).await { + Ok(n) => n, + Err(_) => break, + }; - // Create ASGI scope - let scope: HttpConnectionScope = (&request).try_into()?; + if n == 0 { + // EOF - send final message + let _ = rx.send(HttpReceiveMessage::Request { + body: vec![], + more_body: false, + }); + break; + } - // Create channels for ASGI communication - let (rx_receiver, rx) = Receiver::http(); - let (tx_sender, tx_receiver) = Sender::http(); + // Send the data we read + let data = buffer.split_to(n).to_vec(); + if rx + .send(HttpReceiveMessage::Request { + body: data, + more_body: true, + }) + .is_err() + { + // Python dropped receiver + break; + } + } +} - // Send request body - let request_message = HttpReceiveMessage::Request { - body: request.body().to_vec(), - more_body: false, - }; - rx.send(request_message).map_err(|_| { - HandlerError::PythonError(PyErr::new::( - "Failed to send request", - )) - })?; +// Helper function: Forward WebSocket request data from DuplexStream to Python +// Returns when stream ends or error occurs +async fn forward_websocket_request( + request_stream: R, + rx: tokio::sync::mpsc::UnboundedSender, +) where + R: tokio::io::AsyncRead + Unpin + Send + 'static, +{ + // JavaScript now sends WebSocket frames 
(auto-encoded by Request::write()) + // Use WebSocketDecoder to decode them + let mut decoder = http_handler::websocket::WebSocketDecoder::new(request_stream); - // Create response channel - let (response_tx, response_rx) = oneshot::channel(); + loop { + match decoder.read_message().await { + Ok(Some(frame)) => { + // Got a WebSocket frame - forward to Python based on type + if frame.is_text() { + if let Some(text) = frame.payload_as_text() + && rx + .send(WebSocketReceiveMessage::Receive { + text: Some(text), + bytes: None, + }) + .is_err() + { + // Python receiver dropped + break; + } + } else if frame.is_binary() { + if rx + .send(WebSocketReceiveMessage::Receive { + text: None, + bytes: Some(frame.payload), + }) + .is_err() + { + // Python receiver dropped + break; + } + } else if frame.is_close() { + // Got close frame - send disconnect to Python + let (code, reason) = frame.parse_close_payload().unzip(); + let _ = rx.send(WebSocketReceiveMessage::Disconnect { code, reason }); + break; + } + // Ignore ping/pong frames (handled automatically) + } + Ok(None) => { + // Stream ended - send disconnect + let _ = rx.send(WebSocketReceiveMessage::Disconnect { + code: Some(1000), + reason: None, + }); + break; + } + Err(_) => { + // Error reading - send disconnect + let _ = rx.send(WebSocketReceiveMessage::Disconnect { + code: Some(1006), // Abnormal closure + reason: None, + }); + break; + } + } + } +} - // Submit the ASGI app call to Python event loop - let future = Python::attach(|py| { - let scope_py = scope.into_pyobject(py)?; - let coro = self - .app_function - .call1(py, (scope_py, rx_receiver, tx_sender))?; +// Helper function: Handle HTTP response message from Python and write to DuplexStream +// Returns true if the loop should break +async fn handle_http_response_message( + msg: Option>, + response_tx: &mut Option>, + response_stream: &mut W, +) -> bool +where + W: tokio::io::AsyncWrite + Unpin, +{ + match msg { + Some(AcknowledgedMessage { + message: HttpSendMessage::HttpResponseStart { + status, headers, .. 
+ }, + ack, + }) if response_tx.is_some() => { + // Send response.start back to main task (oneshot - only once) + if let Some(tx) = response_tx.take() + && tx.send(Ok((status, headers))).is_err() + { + // Main task dropped receiver - stop forwarding + return true; + } + if ack.send(()).is_err() { + // Python dropped receiver - stop forwarding + return true; + } + } + Some(AcknowledgedMessage { + message: HttpSendMessage::HttpResponseBody { body, more_body }, + ack, + }) => { + // Acknowledge receipt + if ack.send(()).is_err() { + // Python dropped receiver - stop forwarding + return true; + } - let asyncio = py.import("asyncio")?; - let future = asyncio.call_method1( - "run_coroutine_threadsafe", - (coro, self.event_loop_handle.event_loop()), - )?; + // Write body data if not empty + if !body.is_empty() { + if response_stream.write_all(&body).await.is_err() { + // Client disconnected + return true; + } + } - Ok::, HandlerError>(future.unbind()) - })?; + // Check if this was the final chunk + if !more_body { + // Close the write side + let _ = response_stream.shutdown().await; + return true; + } + } + None => { + // Python sender closed + return true; + } + _ => { + // Ignore other message types (e.g., duplicate response.start) + } + } + false +} - // Spawn task to collect response and monitor for Python exceptions - tokio::spawn(collect_response_with_exception_handling( - tx_receiver, - response_tx, - future, - )); +// Helper function: Handle WebSocket response message from Python and write to DuplexStream +// Returns true if the loop should break +async fn handle_websocket_response_message( + msg: Option>, + encoder: &http_handler::websocket::WebSocketEncoder, +) -> bool +where + W: tokio::io::AsyncWrite + Unpin + Send, +{ + match msg { + Some(ack_msg) => { + match ack_msg.message { + WebSocketSendMessage::Send { text, bytes } => { + // Send WebSocket frames to JavaScript (will be auto-decoded by Response::next()) + let result = if let Some(text) = text { + encoder.write_text(&text, false).await + } else if let Some(bytes) = bytes { + encoder.write_binary(&bytes, false).await + } else { + Ok(()) + }; + + if result.is_err() { + // Client disconnected or write error + return true; + } + } + WebSocketSendMessage::Close { code, reason } => { + // Send close frame + let reason_str = reason.as_deref(); + let _ = encoder.write_close(code, reason_str).await; + return true; + } + _ => {} + } + // Acknowledge receipt + if ack_msg.ack.send(()).is_err() { + // Python dropped receiver - stop forwarding + return true; + } + } + None => { + // Python sender closed + return true; + } + } + false +} - // Wait for response - let (status, headers, body) = response_rx.await??; +// Helper function: Handle Python exception +// Returns true if the loop should break +async fn handle_python_exception( + result: Result, PyErr>, + response_tx: Option>, + response_stream: Option<&mut W>, + response_exception: Option<&Arc>>>, +) -> bool +where + W: tokio::io::AsyncWrite + Unpin, +{ + if let Err(py_err) = result { + let error_msg = py_err.to_string(); + + // Python exception - send error via oneshot if response not yet started + if let Some(tx) = response_tx { + let _ = tx.send(Err(HandlerError::PythonError(py_err))); + } else { + // Response already started - store error for later retrieval and close stream + if let Some(exception_holder) = response_exception { + let mut exc = exception_holder.lock().await; + *exc = Some(ResponseException::new(error_msg)); + } - // Build response - let mut builder = 
http_handler::response::Builder::new().status(status); - for (name, value) in headers { - builder = builder.header(name.as_bytes(), value.as_bytes()); + if let Some(stream) = response_stream { + use tokio::io::AsyncWriteExt; + let _ = stream.shutdown().await; + } } + } + // Always return true when Python future completes (success or error) + // to exit the forwarding loop + true +} - builder - .body(BytesMut::from(&body[..])) - .map_err(HandlerError::HttpHandlerError) +// Helper function: Handle response timeout +// Returns true if the loop should break +fn handle_response_timeout(response_tx: Option>) -> bool { + // Send timeout error via oneshot + if let Some(tx) = response_tx { + let _ = tx.send(Err(HandlerError::NoResponse)); } + true } -/// Load Python library with RTLD_GLOBAL on Linux to expose interpreter symbols -#[cfg(target_os = "linux")] -fn ensure_python_symbols_global() { - // Only perform the promotion once per process - static GLOBALIZE_ONCE: OnceLock<()> = OnceLock::new(); - - GLOBALIZE_ONCE.get_or_init(|| unsafe { - let mut info: libc::Dl_info = mem::zeroed(); - if libc::dladdr(pyo3::ffi::Py_Initialize as *const _, &mut info) == 0 - || info.dli_fname.is_null() - { - eprintln!("unable to locate libpython for RTLD_GLOBAL promotion"); - return; +// Spawn HTTP forwarding task +fn spawn_http_forwarding_task( + request_stream: R, + mut tx_receiver: tokio::sync::mpsc::UnboundedReceiver>, + rx: tokio::sync::mpsc::UnboundedSender, + response_stream: W, + response_tx: oneshot::Sender, + future: Py, + response_exception: Arc>>, +) where + R: tokio::io::AsyncRead + Unpin + Send + 'static, + W: tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + tokio::spawn(async move { + let mut response_tx = Some(response_tx); + let mut future_poller = PythonFuturePoller::new(future); + let timeout = tokio::time::sleep(tokio::time::Duration::from_secs(30)); + tokio::pin!(timeout); + + // Spawn request forwarding as separate task + let mut request_done = Some(tokio::spawn(forward_http_request(request_stream, rx))); + let mut response_stream = response_stream; + + loop { + tokio::select! 
{ + // Forward response messages from Python to Node.js + response_msg = tx_receiver.recv() => { + if handle_http_response_message( + response_msg, + &mut response_tx, + &mut response_stream, + ).await { + break; + } + } + + // Monitor Python future for exceptions + result = Pin::new(&mut future_poller) => { + if handle_python_exception(result, response_tx.take(), Some(&mut response_stream), Some(&response_exception)).await { + break; + } + } + + // Timeout after 30 seconds without response.start + _ = &mut timeout, if response_tx.is_some() => { + if handle_response_timeout(response_tx.take()) { + break; + } + } + + // Wait for request forwarding to complete (only poll once) + _ = async { request_done.as_mut().unwrap().await }, if request_done.is_some() => { + request_done = None; + } + + // Exit loop if all branches are done + else => break, + } } + }); +} + +// Spawn WebSocket forwarding task +fn spawn_websocket_forwarding_task( + request_stream: R, + mut tx_receiver: tokio::sync::mpsc::UnboundedReceiver>, + rx: tokio::sync::mpsc::UnboundedSender, + response_stream: W, + future: Py, +) where + R: tokio::io::AsyncRead + Unpin + Send + 'static, + W: tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + tokio::spawn(async move { + let mut future_poller = PythonFuturePoller::new(future); + + // Create WebSocket encoder for sending frames to client + let encoder = WebSocketEncoder::new(response_stream); + + // Spawn request forwarding as separate task + let mut request_done = Some(tokio::spawn(forward_websocket_request(request_stream, rx))); + + // Track if close frame was sent (write_close also closes the stream) + let mut close_sent = false; + + loop { + tokio::select! { + // Forward WebSocket messages from Python to client + response_msg = tx_receiver.recv() => { + if handle_websocket_response_message(response_msg, &encoder).await { + close_sent = true; + break; + } + } - let path_cstr = CStr::from_ptr(info.dli_fname); - let path_str = path_cstr.to_string_lossy(); + // Monitor Python future for exceptions + result = Pin::new(&mut future_poller) => { + if handle_python_exception::(result, None, None, None).await { + break; + } + } - // Clear any prior dlerror state before attempting to reopen - libc::dlerror(); + // Wait for request forwarding to complete (only poll once) + _ = async { request_done.as_mut().unwrap().await }, if request_done.is_some() => { + request_done = None; + } - let handle = libc::dlopen(info.dli_fname, libc::RTLD_NOW | libc::RTLD_GLOBAL); - if handle.is_null() { - let error = libc::dlerror(); - if !error.is_null() { - let msg = CStr::from_ptr(error).to_string_lossy(); - eprintln!("dlopen({path_str}) failed with RTLD_GLOBAL: {msg}",); - } else { - eprintln!("dlopen({path_str}) returned null without dlerror",); + // Exit loop if all branches complete + else => break, } } + + // Close the response stream only if close frame wasn't sent + // (write_close already closes the stream) + if !close_sent { + let _ = encoder.end().await; + } }); } +impl Handler for Asgi { + type Error = HandlerError; + + async fn handle(&self, request: Request) -> Result { + // Set document root extension + let mut request = request; + request.set_document_root(DocumentRoot { + path: self.docroot.clone(), + }); + + // Check if this is a WebSocket request + let is_websocket = request.extensions().get::().is_some(); + + // Extract parts + let (parts, body) = request.into_parts(); + + // Create response body + let response_body = body.create_response(); + + // Clone bodies for bidirectional 
forwarding + // RequestBody implements AsyncRead (reads from read_side) + // ResponseBody implements AsyncWrite (writes to write_side) + let request_reader = body.clone(); + let response_writer = response_body.clone(); + + if is_websocket { + // WebSocket mode + // Create WebSocket scope from parts by temporarily reconstructing a request + let temp_request = Request::from_parts(parts.clone(), RequestBody::new()); + let scope: WebSocketConnectionScope = (&temp_request).try_into()?; + + // Create channels for ASGI communication + let (rx_receiver, rx) = Receiver::websocket(); + let (tx_sender, mut tx_receiver) = Sender::websocket(); + + // Send connect + rx.send(WebSocketReceiveMessage::Connect) + .map_err(|_| HandlerError::NoResponse)?; + + // Submit ASGI app to Python + let future = Python::attach(|py| { + let scope_py = scope.into_pyobject(py)?; + let coro = self + .app_function + .call1(py, (scope_py, rx_receiver, tx_sender))?; + + let asyncio = py.import("asyncio")?; + let future = asyncio.call_method1( + "run_coroutine_threadsafe", + (coro, self.event_loop_handle.event_loop()), + )?; + + Ok::, HandlerError>(future.unbind()) + })?; + + // Wait for accept + match tx_receiver.recv().await { + Some(AcknowledgedMessage { + message: WebSocketSendMessage::Accept { .. }, + ack, + }) => { + // Acknowledge receipt + if ack.send(()).is_err() { + // Python dropped receiver - cannot continue + return Err(HandlerError::WebSocketNotAccepted); + } + } + _ => return Err(HandlerError::WebSocketNotAccepted), + } + + // Spawn WebSocket forwarding task + spawn_websocket_forwarding_task(request_reader, tx_receiver, rx, response_writer, future); + + // Return 101 Switching Protocols response with WebSocket body + http_handler::response::Builder::new() + .status(101) + .extension(WebSocketMode) // Mark response as WebSocket for auto-decoding + .body(response_body) + .map_err(HandlerError::HttpHandlerError) + } else { + // HTTP mode + // Create HTTP scope from parts by temporarily reconstructing a request + let temp_request = Request::from_parts(parts.clone(), RequestBody::new()); + let scope: HttpConnectionScope = (&temp_request).try_into()?; + + // Create ASGI channels + let (rx_receiver, rx) = Receiver::http(); + let (tx_sender, tx_receiver) = Sender::http(); + + // Create oneshot channel for sending response.start (or error) back to main task + let (response_tx, response_rx) = oneshot::channel::(); + + // Submit ASGI app to Python to get the future + let future = Python::attach(|py| { + let scope_py = scope.into_pyobject(py)?; + let coro = self + .app_function + .call1(py, (scope_py, rx_receiver, tx_sender))?; + + let asyncio = py.import("asyncio")?; + let future = asyncio.call_method1( + "run_coroutine_threadsafe", + (coro, self.event_loop_handle.event_loop()), + )?; + + Ok::, HandlerError>(future.unbind()) + })?; + + // Create exception holder to capture errors that occur after response.start + let response_exception = Arc::new(Mutex::new(None)); + let response_exception_clone = Arc::clone(&response_exception); + + // Spawn HTTP forwarding task + spawn_http_forwarding_task( + request_reader, + tx_receiver, + rx, + response_writer, + response_tx, + future, + response_exception_clone, + ); + + // Wait for response.start (errors are propagated from the forwarding task) + let (status, headers) = response_rx + .await + .map_err(|_| HandlerError::NoResponse)? 
// Channel closed without sending + ?; // Unwrap Result from the task + + // Build and return response with headers and streaming body + let mut builder = http_handler::response::Builder::new().status(status); + for (name, value) in headers { + builder = builder.header(name.as_bytes(), value.as_bytes()); + } + + let mut response = builder + .body(response_body) + .map_err(HandlerError::HttpHandlerError)?; + + // Insert response exception extension so NAPI layer can check for errors after stream ends + // The exception will be set if a Python error occurs during streaming + response.extensions_mut().insert(response_exception); + + Ok(response) + } + } +} + +impl Asgi { + /// Handle a request synchronously (continued for compatibility) + pub fn handle_sync(&self, request: Request) -> Result { + fallback_handle().block_on(self.handle(request)) + } +} + /// Find all Python site-packages directories in a virtual environment fn find_python_site_packages(venv_path: &Path) -> Vec { let mut site_packages_paths = Vec::new(); @@ -319,14 +675,14 @@ fn find_python_site_packages(venv_path: &Path) -> Vec { if let Ok(entries) = read_dir(lib_path) { for entry in entries.flatten() { let entry_path = entry.path(); - if entry_path.is_dir() { - if let Some(dir_name) = entry_path.file_name().and_then(|n| n.to_str()) { - // Look for directories matching python3.* pattern - if dir_name.starts_with("python3.") { - let site_packages = entry_path.join("site-packages"); - if site_packages.exists() { - site_packages_paths.push(site_packages); - } + if entry_path.is_dir() + && let Some(dir_name) = entry_path.file_name().and_then(|n| n.to_str()) + { + // Look for directories matching python3.* pattern + if dir_name.starts_with("python3.") { + let site_packages = entry_path.join("site-packages"); + if site_packages.exists() { + site_packages_paths.push(site_packages); } } } @@ -364,111 +720,1085 @@ fn setup_python_paths(py: Python, docroot: &Path) -> PyResult<()> { Ok(()) } -/// Start a Python thread that runs the event loop forever -fn start_python_event_loop_thread(event_loop: Py) { - Python::attach(|py| { - // Set the event loop for this thread and run it - let asyncio = py.import("asyncio")?; - asyncio.call_method1("set_event_loop", (event_loop.bind(py),))?; - - // Get the current event loop and run it forever - asyncio - .call_method0("get_event_loop")? - .call_method0("run_forever")?; - - Ok::<(), PyErr>(()) - }) - .unwrap_or_else(|e| { - eprintln!("Python event loop thread error: {e}"); +/// Ensure Python is initialized exactly once with proper symbol visibility. +/// +/// This function uses a OnceLock to ensure Python initialization happens only once +/// per process, even when called from multiple threads or tests. 
+pub(crate) fn ensure_python_initialized() { + use std::sync::OnceLock; + static INIT: OnceLock<()> = OnceLock::new(); + INIT.get_or_init(|| { + // On Linux, load Python library with RTLD_GLOBAL to expose interpreter symbols + #[cfg(target_os = "linux")] + unsafe { + let mut info: libc::Dl_info = mem::zeroed(); + if libc::dladdr(pyo3::ffi::Py_Initialize as *const _, &mut info) == 0 + || info.dli_fname.is_null() + { + eprintln!("unable to locate libpython for RTLD_GLOBAL promotion"); + } else { + let path_cstr = CStr::from_ptr(info.dli_fname); + let path_str = path_cstr.to_string_lossy(); + + // Clear any prior dlerror state before attempting to reopen + libc::dlerror(); + + let handle = libc::dlopen(info.dli_fname, libc::RTLD_NOW | libc::RTLD_GLOBAL); + if handle.is_null() { + let error = libc::dlerror(); + if !error.is_null() { + let msg = CStr::from_ptr(error).to_string_lossy(); + eprintln!("dlopen({path_str}) failed with RTLD_GLOBAL: {msg}",); + } else { + eprintln!("dlopen({path_str}) returned null without dlerror",); + } + } + } + } + + Python::initialize(); }); } -/// Collect ASGI response messages while monitoring for Python exceptions -async fn collect_response_with_exception_handling( - mut tx_receiver: tokio::sync::mpsc::UnboundedReceiver>, - response_tx: oneshot::Sender, - python_future: Py, -) { - let mut status = 500u16; - let mut headers = Vec::new(); - let mut body = Vec::new(); - let mut response_started = false; - - // Spawn a task to monitor the Python future for exceptions - let future_clone = Python::attach(|py| python_future.clone_ref(py)); - let mut exception_handle = tokio::task::spawn_blocking(move || { +#[cfg(test)] +mod tests { + use super::*; + use std::env; + use std::fs; + use std::sync::Arc; + use tokio::io::DuplexStream; + + /// Helper to create a test duplex stream pair + fn create_test_streams() -> (DuplexStream, DuplexStream) { + tokio::io::duplex(1024) + } + + #[test] + fn test_find_python_site_packages_empty() { + // Test with a non-existent directory + let temp_dir = std::env::temp_dir().join("nonexistent_venv"); + let result = find_python_site_packages(&temp_dir); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_find_python_site_packages_with_structure() { + // Create a temporary directory structure that mimics a virtual environment + let temp_dir = std::env::temp_dir().join("test_venv_structure"); + let lib_dir = temp_dir.join("lib"); + let python_dir = lib_dir.join("python3.12"); + let site_packages = python_dir.join("site-packages"); + + // Create the directory structure + fs::create_dir_all(&site_packages).ok(); + + let result = find_python_site_packages(&temp_dir); + + // Verify we found at least one site-packages directory + assert!(!result.is_empty()); + assert!(result.iter().any(|p| p.ends_with("site-packages"))); + + // Cleanup + fs::remove_dir_all(&temp_dir).ok(); + } + + #[test] + fn test_find_python_site_packages_multiple_versions() { + // Create a temporary directory with multiple Python versions + let temp_dir = std::env::temp_dir().join("test_venv_multi"); + + for version in &["python3.11", "python3.12"] { + let lib_dir = temp_dir.join("lib"); + let python_dir = lib_dir.join(version); + let site_packages = python_dir.join("site-packages"); + fs::create_dir_all(&site_packages).ok(); + } + + let result = find_python_site_packages(&temp_dir); + + // Should find both site-packages directories + assert!(result.len() >= 1); + + // Cleanup + fs::remove_dir_all(&temp_dir).ok(); + } + + #[tokio::test] + async fn test_setup_python_paths() 
{ + ensure_python_initialized(); + + Python::attach(|py| { + let docroot = PathBuf::from("/test/docroot"); + let result = setup_python_paths(py, &docroot); + + // Should succeed + assert!(result.is_ok()); + + // Verify sys.path was modified + let sys = py.import("sys").unwrap(); + let path = sys.getattr("path").unwrap(); + let path_list: Vec = path.extract().unwrap(); + + // The docroot should be in the path + assert!(path_list.iter().any(|p| p.contains("test/docroot"))); + }); + } + + #[tokio::test] + async fn test_setup_python_paths_with_venv() { + ensure_python_initialized(); + + // Create a test virtual environment structure + let temp_venv = std::env::temp_dir().join("test_venv_for_paths"); + let lib_dir = temp_venv.join("lib"); + let python_dir = lib_dir.join("python3.12"); + let site_packages = python_dir.join("site-packages"); + fs::create_dir_all(&site_packages).ok(); + + // Set VIRTUAL_ENV + unsafe { + env::set_var("VIRTUAL_ENV", temp_venv.to_string_lossy().to_string()); + } + + Python::attach(|py| { + let docroot = PathBuf::from("/test/docroot"); + let result = setup_python_paths(py, &docroot); + assert!(result.is_ok()); + }); + + // Cleanup + unsafe { + env::remove_var("VIRTUAL_ENV"); + } + fs::remove_dir_all(&temp_venv).ok(); + } + + #[tokio::test] + async fn test_handle_python_exception_with_error() { + ensure_python_initialized(); + + let (tx, mut rx) = oneshot::channel::(); + + let py_err = Python::attach(|_py| { + // Create a Python exception + pyo3::exceptions::PyValueError::new_err("Test error") + }); + + let should_break = + handle_python_exception::(Err(py_err), Some(tx), None, None).await; + + // Should return true to break the loop + assert!(should_break); + + // Should have sent an error + let result = rx.try_recv(); + assert!(result.is_ok()); + let result = result.unwrap(); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_handle_python_exception_without_tx() { + ensure_python_initialized(); + + let py_err = Python::attach(|_py| pyo3::exceptions::PyValueError::new_err("Test error")); + + // Should handle gracefully when no sender is provided + let should_break = + handle_python_exception::(Err(py_err), None, None, None).await; + assert!(should_break); + } + + #[tokio::test] + async fn test_handle_python_exception_with_success() { + ensure_python_initialized(); + + let (tx, _rx) = oneshot::channel::(); + + let py_obj = Python::attach(|py| py.None()); + + let should_break = + handle_python_exception::(Ok(py_obj), Some(tx), None, None).await; + + // Should still return true (Python future completed) + assert!(should_break); + } + + #[test] + fn test_handle_response_timeout() { + let (tx, mut rx) = oneshot::channel::(); + + let should_break = handle_response_timeout(Some(tx)); + + // Should return true to break the loop + assert!(should_break); + + // Should have sent a timeout error + let result = rx.try_recv(); + assert!(result.is_ok()); + let result = result.unwrap(); + assert!(result.is_err()); + + match result { + Err(HandlerError::NoResponse) => (), + _ => panic!("Expected NoResponse error"), + } + } + + #[test] + fn test_handle_response_timeout_no_tx() { + // Should handle gracefully when no sender is provided + let should_break = handle_response_timeout(None); + assert!(should_break); + } + + #[tokio::test] + async fn test_handle_http_response_message_start() { + let (_request_stream, mut response_stream) = create_test_streams(); + + let (tx, mut rx) = oneshot::channel::(); + let mut response_tx = Some(tx); + + let (ack_tx, mut ack_rx) = 
oneshot::channel::<()>(); + + let msg = Some(AcknowledgedMessage { + message: HttpSendMessage::HttpResponseStart { + status: 200, + headers: vec![("content-type".to_string(), "text/plain".to_string())], + trailers: false, + }, + ack: ack_tx, + }); + + let should_break = + handle_http_response_message(msg, &mut response_tx, &mut response_stream).await; + + // Should not break yet + assert!(!should_break); + + // Response should have been sent + assert!(response_tx.is_none()); + + let result = rx.try_recv(); + assert!(result.is_ok()); + let (status, headers) = result.unwrap().unwrap(); + assert_eq!(status, 200); + assert_eq!(headers.len(), 1); + + // Ack should have been sent + assert!(ack_rx.try_recv().is_ok()); + } + + #[tokio::test] + async fn test_handle_http_response_message_body() { + let (_request_stream, mut response_stream) = create_test_streams(); + + let mut response_tx = None; // Already sent + + let (ack_tx, mut ack_rx) = oneshot::channel::<()>(); + + let msg = Some(AcknowledgedMessage { + message: HttpSendMessage::HttpResponseBody { + body: b"Hello, World!".to_vec(), + more_body: true, + }, + ack: ack_tx, + }); + + let should_break = + handle_http_response_message(msg, &mut response_tx, &mut response_stream).await; + + // Should not break (more_body is true) + assert!(!should_break); + + // Ack should have been sent + assert!(ack_rx.try_recv().is_ok()); + } + + #[tokio::test] + async fn test_handle_http_response_message_body_final() { + let (_request_stream, mut response_stream) = create_test_streams(); + + let mut response_tx = None; + + let (ack_tx, mut ack_rx) = oneshot::channel::<()>(); + + let msg = Some(AcknowledgedMessage { + message: HttpSendMessage::HttpResponseBody { + body: b"Final chunk".to_vec(), + more_body: false, // Final chunk + }, + ack: ack_tx, + }); + + let should_break = + handle_http_response_message(msg, &mut response_tx, &mut response_stream).await; + + // Should break (final chunk) + assert!(should_break); + + // Ack should have been sent + assert!(ack_rx.try_recv().is_ok()); + } + + #[tokio::test] + async fn test_handle_http_response_message_none() { + let (_request_stream, mut response_stream) = create_test_streams(); + + let mut response_tx = None; + + let should_break = + handle_http_response_message(None, &mut response_tx, &mut response_stream).await; + + // Should break (channel closed) + assert!(should_break); + } + + #[tokio::test] + async fn test_handle_websocket_response_message_send_text() { + let (_request_stream, response_stream) = create_test_streams(); + let encoder = WebSocketEncoder::new(response_stream); + + let (ack_tx, mut ack_rx) = oneshot::channel::<()>(); + + let msg = Some(AcknowledgedMessage { + message: WebSocketSendMessage::Send { + text: Some("Hello, WebSocket!".to_string()), + bytes: None, + }, + ack: ack_tx, + }); + + let should_break = handle_websocket_response_message(msg, &encoder).await; + + // Should not break + assert!(!should_break); + + // Ack should have been sent + assert!(ack_rx.try_recv().is_ok()); + } + + #[tokio::test] + async fn test_handle_websocket_response_message_send_bytes() { + let (_request_stream, response_stream) = create_test_streams(); + let encoder = WebSocketEncoder::new(response_stream); + + let (ack_tx, mut ack_rx) = oneshot::channel::<()>(); + + let msg = Some(AcknowledgedMessage { + message: WebSocketSendMessage::Send { + text: None, + bytes: Some(vec![1, 2, 3, 4]), + }, + ack: ack_tx, + }); + + let should_break = handle_websocket_response_message(msg, &encoder).await; + + // Should not break 
+ assert!(!should_break); + + // Ack should have been sent + assert!(ack_rx.try_recv().is_ok()); + } + + #[tokio::test] + async fn test_handle_websocket_response_message_close() { + let (_request_stream, response_stream) = create_test_streams(); + let encoder = WebSocketEncoder::new(response_stream); + + let (ack_tx, _ack_rx) = oneshot::channel::<()>(); + + let msg = Some(AcknowledgedMessage { + message: WebSocketSendMessage::Close { + code: Some(1000), + reason: Some("Normal closure".to_string()), + }, + ack: ack_tx, + }); + + let should_break = handle_websocket_response_message(msg, &encoder).await; + + // Should break (close message) + assert!(should_break); + } + + #[tokio::test] + async fn test_handle_websocket_response_message_none() { + let (_request_stream, response_stream) = create_test_streams(); + let encoder = WebSocketEncoder::new(response_stream); + + let should_break = handle_websocket_response_message(None, &encoder).await; + + // Should break (channel closed) + assert!(should_break); + } + + #[tokio::test] + async fn test_forward_http_request_with_data() { + let (request_stream, write_stream) = create_test_streams(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + + // Write some data to the stream + tokio::spawn(async move { + let mut stream = write_stream; + stream.write_all(b"Test data").await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + // Start forwarding + tokio::spawn(forward_http_request(request_stream, tx)); + + // Should receive the data + let msg = rx.recv().await; + assert!(msg.is_some()); + + match msg.unwrap() { + HttpReceiveMessage::Request { body, more_body } => { + assert_eq!(body, b"Test data"); + assert!(more_body); + } + HttpReceiveMessage::Disconnect => panic!("Expected Request message"), + } + + // Should receive EOF + let msg = rx.recv().await; + assert!(msg.is_some()); + + match msg.unwrap() { + HttpReceiveMessage::Request { body, more_body } => { + assert!(body.is_empty()); + assert!(!more_body); + } + HttpReceiveMessage::Disconnect => panic!("Expected Request message"), + } + } + + #[tokio::test] + async fn test_forward_http_request_empty_stream() { + let (request_stream, write_stream) = create_test_streams(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + + // Close the stream immediately + drop(write_stream); + + // Start forwarding + tokio::spawn(forward_http_request(request_stream, tx)); + + // Should receive EOF immediately + let msg = rx.recv().await; + assert!(msg.is_some()); + + match msg.unwrap() { + HttpReceiveMessage::Request { body, more_body } => { + assert!(body.is_empty()); + assert!(!more_body); + } + HttpReceiveMessage::Disconnect => panic!("Expected Request message"), + } + } + + #[tokio::test] + async fn test_forward_websocket_request_text_frame() { + use http_handler::websocket::WebSocketEncoder; + + let (request_stream, write_stream) = create_test_streams(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + + // Write a text frame + tokio::spawn(async move { + let encoder = WebSocketEncoder::new(write_stream); + encoder + .write_text("Hello, WebSocket!", false) + .await + .unwrap(); + encoder.end().await.unwrap(); + }); + + // Start forwarding + tokio::spawn(forward_websocket_request(request_stream, tx)); + + // Should receive the text message + let msg = rx.recv().await; + assert!(msg.is_some()); + + match msg.unwrap() { + WebSocketReceiveMessage::Receive { text, bytes } => { + assert_eq!(text, Some("Hello, WebSocket!".to_string())); + assert!(bytes.is_none()); + } 
+ _ => panic!("Expected Receive message"), + } + } + + #[tokio::test] + async fn test_forward_websocket_request_binary_frame() { + use http_handler::websocket::WebSocketEncoder; + + let (request_stream, write_stream) = create_test_streams(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + + // Write a binary frame + tokio::spawn(async move { + let encoder = WebSocketEncoder::new(write_stream); + encoder.write_binary(&[1, 2, 3, 4], false).await.unwrap(); + encoder.end().await.unwrap(); + }); + + // Start forwarding + tokio::spawn(forward_websocket_request(request_stream, tx)); + + // Should receive the binary message + let msg = rx.recv().await; + assert!(msg.is_some()); + + match msg.unwrap() { + WebSocketReceiveMessage::Receive { text, bytes } => { + assert!(text.is_none()); + assert_eq!(bytes, Some(vec![1, 2, 3, 4])); + } + _ => panic!("Expected Receive message"), + } + } + + #[tokio::test] + async fn test_forward_websocket_request_close_frame() { + use http_handler::websocket::WebSocketEncoder; + + let (request_stream, write_stream) = create_test_streams(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + + // Write a close frame + tokio::spawn(async move { + let encoder = WebSocketEncoder::new(write_stream); + encoder + .write_close(Some(1000), Some("Normal closure")) + .await + .unwrap(); + }); + + // Start forwarding + tokio::spawn(forward_websocket_request(request_stream, tx)); + + // Should receive the disconnect message + let msg = rx.recv().await; + assert!(msg.is_some()); + + match msg.unwrap() { + WebSocketReceiveMessage::Disconnect { code, reason } => { + assert_eq!(code, Some(1000)); + assert!(reason.is_some()); + } + _ => panic!("Expected Disconnect message"), + } + } + + #[test] + fn test_ensure_python_initialized_idempotent() { + // Should be safe to call multiple times + ensure_python_initialized(); + ensure_python_initialized(); + ensure_python_initialized(); + + // Python should be initialized Python::attach(|py| { - let future_bound = future_clone.bind(py); - // Wait for the future to complete (with 30 second timeout) - match future_bound.call_method1("result", (30.0,)) { - Ok(_) => None, // Success - no exception - Err(e) => Some(e), // Exception occurred + // Should be able to use Python + let sys = py.import("sys").unwrap(); + assert!(sys.hasattr("version").unwrap()); + }); + } + + // ===== Integration Tests ===== + // These tests verify the full Asgi handler flow end-to-end + + /// Helper to create a test request and spawn body writing task + /// Returns the request immediately while body writing happens concurrently + fn create_test_request(method: &str, path: &str, body: Vec) -> Request { + use http_handler::{Method, Uri, Version}; + use tokio::io::AsyncWriteExt; + + let method_enum = match method { + "GET" => Method::GET, + "POST" => Method::POST, + "PUT" => Method::PUT, + "DELETE" => Method::DELETE, + _ => Method::GET, + }; + + let uri: Uri = path.parse().unwrap(); + + // Build a basic HTTP request using ::http::Request builder + let http_request = ::http::Request::builder() + .method(method_enum) + .uri(uri) + .version(Version::HTTP_11) + .body(()) + .unwrap(); + + // Split into parts and body + let (parts, _) = http_request.into_parts(); + + // Create request from parts with proper body + let request = Request::from_parts(parts, RequestBody::new()); + + // Spawn a task to write body data to the request stream + // This allows the request to be returned immediately while body writing happens concurrently + let mut body_writer = 
request.body().clone(); + tokio::spawn(async move { + if !body.is_empty() { + body_writer.write_all(&body).await.unwrap(); } - }) - }); - loop { - tokio::select! { - // Check for messages from the ASGI app - msg = tx_receiver.recv() => { - match msg { - Some(ack_msg) => { - let AcknowledgedMessage { message, ack } = ack_msg; - - match message { - HttpSendMessage::HttpResponseStart { - status: s, - headers: h, - .. - } => { - status = s; - headers = h; - response_started = true; - } - HttpSendMessage::HttpResponseBody { body: b, more_body } => { - if response_started { - body.extend_from_slice(&b); - if !more_body { - let _ = ack.send(()); - let _ = response_tx.send(Ok((status, headers, body))); - return; - } - } - } - } + // Always close the stream when done + body_writer.shutdown().await.unwrap(); + }); - let _ = ack.send(()); - } - None => { - // Channel closed without a complete response - let _ = response_tx.send(Err(if response_started { - HandlerError::ResponseInterrupted - } else { - HandlerError::NoResponse - })); - return; - } + request + } + + /// Helper to read full response body + async fn read_response_body(response: Response) -> (u16, Vec<(String, String)>, Vec) { + use http_body_util::BodyExt; + + let (parts, mut body) = response.into_parts(); + let status = parts.status.as_u16(); + + let headers: Vec<(String, String)> = parts + .headers + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let mut body_bytes = Vec::new(); + while let Some(result) = body.frame().await { + if let Ok(frame) = result { + if let Ok(data) = frame.into_data() { + body_bytes.extend_from_slice(&data); } } - // Check if the Python coroutine raised an exception - exception_result = &mut exception_handle => { - match exception_result { - Ok(Some(py_err)) => { - // Python exception occurred - let _ = response_tx.send(Err(HandlerError::PythonError(py_err))); - return; - } - Ok(None) => { - // Python coroutine completed successfully - // Continue waiting for response messages - } - Err(e) => { - // Tokio task error - let _ = response_tx.send(Err(HandlerError::TokioError(e.to_string()))); - return; - } - } + } + + (status, headers, body_bytes) + } + + #[tokio::test] + async fn test_asgi_integration_basic_request() { + ensure_python_initialized(); + + // Create Asgi handler with echo_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "echo_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create a simple POST request with body + let request = create_test_request("POST", "/test/path", b"Hello, World!".to_vec()); + + // Handle the request + let response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + + // Read the full response + let (status, headers, body) = read_response_body(response).await; + + // Verify response + assert_eq!(status, 200); + assert!( + headers + .iter() + .any(|(k, v)| k == "content-type" && v.contains("application/json")) + ); + assert!( + headers + .iter() + .any(|(k, v)| k == "x-echo-method" && v == "POST") + ); + assert!( + headers + .iter() + .any(|(k, v)| k == "x-echo-path" && v == "/test/path") + ); + + let body_str = String::from_utf8(body).unwrap(); + assert!(body_str.contains("Hello, World!")); + assert!(body_str.contains("POST")); + assert!(body_str.contains("/test/path")); + } + + #[tokio::test] 
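+ // Streaming path (see the test below): asgi.handle() is expected to resolve as soon
+ // as response headers are available; the body chunks from stream_app are read afterwards.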
+ async fn test_asgi_integration_streaming_response() { + ensure_python_initialized(); + + // Create Asgi handler with stream_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "stream_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create a GET request to streaming endpoint + let request = create_test_request("GET", "/stream?count=3", vec![]); + + // Handle the request - this should return IMMEDIATELY after headers are ready + let response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + + // Verify we got headers back (status code available) + let status = response.status().as_u16(); + assert_eq!(status, 200, "Should get 200 status immediately"); + + // Now read the streaming body + let (_, headers, body) = read_response_body(response).await; + + // Verify response headers + assert!( + headers + .iter() + .any(|(k, v)| k == "content-type" && v.contains("text/plain")) + ); + + // Verify we got all chunks + let body_str = String::from_utf8(body).unwrap(); + assert!(body_str.contains("Chunk 1")); + assert!(body_str.contains("Chunk 2")); + assert!(body_str.contains("Chunk 3")); + } + + #[tokio::test] + async fn test_asgi_integration_early_header_return() { + ensure_python_initialized(); + + // Create Asgi handler with stream_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "stream_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create request + let request = create_test_request("GET", "/stream?count=5", vec![]); + + // Track timing - handle() should return quickly with just headers + let start = std::time::Instant::now(); + let response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + let header_time = start.elapsed(); + + // Headers should be available almost immediately (well before all chunks) + // The stream_app has 0.01s delay per chunk, so 5 chunks = ~50ms + // We should get headers back in much less time + assert!( + header_time.as_millis() < 30, + "Headers should return quickly, took {}ms", + header_time.as_millis() + ); + + // Verify we have a valid response with headers + assert_eq!(response.status().as_u16(), 200); + assert!(response.headers().get("content-type").is_some()); + + // The body stream should still be available for reading + let (_, _, body) = read_response_body(response).await; + let body_str = String::from_utf8(body).unwrap(); + + // Verify all chunks arrived + for i in 1..=5 { + assert!(body_str.contains(&format!("Chunk {}", i))); + } + } + + #[tokio::test] + async fn test_asgi_integration_streaming_request_body() { + ensure_python_initialized(); + + // Create Asgi handler with echo_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "echo_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create request with body + let large_body = "x".repeat(10000); // 10KB body + let request = create_test_request("POST", "/test", large_body.as_bytes().to_vec()); + + // Handle request + let 
response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + + // Verify response + let (status, _, body) = read_response_body(response).await; + assert_eq!(status, 200); + + let body_str = String::from_utf8(body).unwrap(); + // The echo app should echo back our large body + assert!( + body_str.contains(&large_body[..100]), + "Should contain start of body" + ); + } + + #[tokio::test] + async fn test_asgi_integration_websocket_connection() { + ensure_python_initialized(); + + // Create Asgi handler with websocket_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "websocket_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create WebSocket upgrade request + let mut request = create_test_request("GET", "/echo", vec![]); + + // Add WebSocket mode extension + request.extensions_mut().insert(http_handler::WebSocketMode); + + // Handle request + let response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + + // Verify we got 101 Switching Protocols + assert_eq!( + response.status().as_u16(), + 101, + "Should get 101 for WebSocket upgrade" + ); + + // Response body should be the WebSocket stream + // We can't easily test the full WebSocket flow here without more infrastructure, + // but we've verified the upgrade works + } + + #[tokio::test] + async fn test_asgi_integration_error_handling() { + ensure_python_initialized(); + + // Create Asgi handler with error_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "error_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create request that should trigger error + let request = create_test_request("GET", "/error", vec![]); + + // Handle request - should return error + let result = asgi.handle(request).await; + + // Should get an error + assert!(result.is_err(), "Error path should return an error"); + + match result { + Err(HandlerError::PythonError(_)) => { + // Expected - Python raised an exception } + Err(e) => panic!("Expected PythonError, got: {:?}", e), + Ok(_) => panic!("Expected error but got success"), + } + } + + #[tokio::test] + async fn test_asgi_integration_status_codes() { + ensure_python_initialized(); + + // Create Asgi handler with status_app + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "status_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Test different status codes + for status_code in &[200, 201, 404, 500] { + let request = create_test_request("GET", &format!("/status/{}", status_code), vec![]); + let response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + + assert_eq!( + response.status().as_u16(), + *status_code, + "Should return status code {}", + status_code + ); + } + } + + #[tokio::test] + async fn test_asgi_integration_concurrent_requests() { + ensure_python_initialized(); + + // Create Asgi handler + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Arc::new( + 
Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "echo_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"), + ); + + // Launch multiple concurrent requests + let mut handles = vec![]; + + for i in 0..10 { + let asgi = Arc::clone(&asgi); + let handle = tokio::spawn(async move { + let body = format!("Request {}", i); + let request = create_test_request("POST", "/test", body.as_bytes().to_vec()); + + let response = asgi + .handle(request) + .await + .expect("Failed to handle request"); + let (status, _, response_body) = read_response_body(response).await; + + assert_eq!(status, 200); + let response_str = String::from_utf8(response_body).unwrap(); + assert!(response_str.contains(&format!("Request {}", i))); + }); + handles.push(handle); } + + // Wait for all requests to complete + for handle in handles { + handle.await.expect("Task failed"); + } + } + + #[tokio::test] + async fn test_asgi_docroot() { + ensure_python_initialized(); + + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "echo_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Verify docroot is set correctly + assert_eq!(asgi.docroot(), test_fixtures.as_path()); + } + + /// Test that replicates the exact NAPI layer pattern to see if we can reproduce the hang + #[test] + fn test_asgi_integration_napi_pattern() { + use tokio::io::AsyncWriteExt; + + ensure_python_initialized(); + + // Create Asgi handler + let test_fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/fixtures"); + let asgi = Asgi::new( + Some(test_fixtures.to_string_lossy().to_string()), + Some(PythonHandlerTarget { + file: "echo_app".to_string(), + function: "app".to_string(), + }), + ) + .expect("Failed to create Asgi handler"); + + // Create request without spawning concurrent body writer + let method_enum = http_handler::Method::POST; + let uri: http_handler::Uri = "/test/path".parse().unwrap(); + let http_request = ::http::Request::builder() + .method(method_enum) + .uri(uri) + .version(http_handler::Version::HTTP_11) + .body(()) + .unwrap(); + let (parts, _) = http_request.into_parts(); + let mut request = Request::from_parts(parts, RequestBody::new()); + + // Replicate the NAPI pattern: use fallback_handle().block_on() + let response = super::fallback_handle().block_on(async { + let body_data = b"Hello, World!"; + + // Write body data synchronously (like NAPI layer does) + { + let body = request.body_mut(); + body + .write_all(body_data) + .await + .expect("Failed to write body"); + } + + // Shutdown stream (like NAPI layer does for non-WebSocket) + { + let body = request.body_mut(); + body.shutdown().await.expect("Failed to shutdown stream"); + } + + // Now call handle + asgi + .handle(request) + .await + .expect("Failed to handle request") + }); + + // Read and verify response + let (status, headers, body) = fallback_handle().block_on(read_response_body(response)); + assert_eq!(status, 200); + assert!( + headers + .iter() + .any(|(k, v)| k == "content-type" && v.contains("application/json")) + ); + + let body_str = String::from_utf8(body).expect("Invalid UTF-8 in response body"); + assert!( + body_str.contains("Hello, World!"), + "Response should echo the request body" + ); } } diff --git a/src/asgi/python_future_poller.rs 
b/src/asgi/python_future_poller.rs new file mode 100644 index 0000000..9429f70 --- /dev/null +++ b/src/asgi/python_future_poller.rs @@ -0,0 +1,190 @@ +//! Python Future polling implementation for Rust async integration. +//! +//! This module provides [`PythonFuturePoller`], a type that implements the Rust +//! [`Future`] trait by polling a Python `concurrent.futures.Future` object. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pyo3::prelude::*; + +/// Future that polls a Python concurrent.futures.Future for completion. +/// +/// This future polls a Python concurrent.futures.Future and returns a Result +/// containing either the success value or the exception when the future completes. +pub struct PythonFuturePoller(Py<PyAny>); + +impl PythonFuturePoller { + /// Create a new `PythonFuturePoller` from a Python future object. + /// + /// # Arguments + /// + /// * `future` - A Python `concurrent.futures.Future` or `asyncio.Future` object + pub fn new(future: Py<PyAny>) -> Self { + Self(future) + } +} + +impl Future for PythonFuturePoller { + type Output = Result<Py<PyAny>, PyErr>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + Python::attach(|py| { + let future_bound = self.0.bind(py); + + // First check if future is done + let is_done: bool = future_bound + .call_method0("done") + .ok() + .and_then(|result| result.extract().ok()) + .unwrap_or(false); + + if is_done { + // Future is done - get the result (Ok for success, Err for exception) + Poll::Ready(match future_bound.call_method0("result") { + Ok(value) => Ok(value.unbind()), + Err(err) => Err(err), + }) + } else { + // Not done yet, wake the task to poll again + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::asgi::ensure_python_initialized; + + /// Ensure Python is initialized for tests (only once) + fn ensure_test_python() { + ensure_python_initialized(); + } + + #[test] + fn test_python_future_poller_creation() { + ensure_test_python(); + + Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let loop_ = asyncio.call_method0("new_event_loop").unwrap(); + let future = loop_.call_method0("create_future").unwrap(); + + let _poller = PythonFuturePoller::new(future.unbind()); + // Just verify we can create it + }); + } + + #[tokio::test] + async fn test_python_future_poller_with_completed_future() { + ensure_test_python(); + + let future = Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let loop_ = asyncio.call_method0("new_event_loop").unwrap(); + let future = loop_.call_method0("create_future").unwrap(); + + // Immediately complete the future with a result + future.call_method1("set_result", (42,)).unwrap(); + + future.unbind() + }); + + let poller = PythonFuturePoller::new(future); + let result = poller.await; + + assert!(result.is_ok()); + Python::attach(|py| { + let value: i32 = result.unwrap().extract(py).unwrap(); + assert_eq!(value, 42); + }); + } + + #[tokio::test] + async fn test_python_future_poller_with_exception() { + ensure_test_python(); + + let future = Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let loop_ = asyncio.call_method0("new_event_loop").unwrap(); + let future = loop_.call_method0("create_future").unwrap(); + + // Set an exception on the future + let exception = py + .import("builtins") + .unwrap() + .getattr("ValueError") + .unwrap() + .call1(("test error",)) + .unwrap(); + future.call_method1("set_exception", (exception,)).unwrap(); + 
future.unbind() + }); + + let poller = PythonFuturePoller::new(future); + let result = poller.await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + let err_str = format!("{:?}", err); + assert!( + err_str.contains("ValueError") || err_str.contains("test error"), + "Expected ValueError with 'test error', got: {}", + err_str + ); + } + + #[tokio::test] + async fn test_python_future_poller_with_string_result() { + ensure_test_python(); + + let future = Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let loop_ = asyncio.call_method0("new_event_loop").unwrap(); + let future = loop_.call_method0("create_future").unwrap(); + + // Complete with a string + future.call_method1("set_result", ("hello world",)).unwrap(); + + future.unbind() + }); + + let poller = PythonFuturePoller::new(future); + let result = poller.await; + + assert!(result.is_ok()); + Python::attach(|py| { + let value: String = result.unwrap().extract(py).unwrap(); + assert_eq!(value, "hello world"); + }); + } + + #[tokio::test] + async fn test_python_future_poller_with_none() { + ensure_test_python(); + + let future = Python::attach(|py| { + let asyncio = py.import("asyncio").unwrap(); + let loop_ = asyncio.call_method0("new_event_loop").unwrap(); + let future = loop_.call_method0("create_future").unwrap(); + + // Complete with None + future.call_method1("set_result", (py.None(),)).unwrap(); + + future.unbind() + }); + + let poller = PythonFuturePoller::new(future); + let result = poller.await; + + assert!(result.is_ok()); + Python::attach(|py| { + assert!(result.unwrap().is_none(py)); + }); + } +} diff --git a/src/asgi/receiver.rs b/src/asgi/receiver.rs index e8c9622..2dbe3d6 100644 --- a/src/asgi/receiver.rs +++ b/src/asgi/receiver.rs @@ -74,3 +74,95 @@ impl Receiver { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::asgi::ensure_python_initialized; + use crate::asgi::http::HttpReceiveMessage; + use crate::asgi::lifespan::LifespanReceiveMessage; + use crate::asgi::websocket::WebSocketReceiveMessage; + + #[test] + fn test_receiver_http_creation() { + ensure_python_initialized(); + + let (_receiver, tx) = Receiver::http(); + // Verify we can send a message + let result = tx.send(HttpReceiveMessage::Request { + body: vec![], + more_body: false, + }); + assert!(result.is_ok(), "Should be able to send message"); + } + + #[test] + fn test_receiver_websocket_creation() { + ensure_python_initialized(); + + let (_receiver, tx) = Receiver::websocket(); + // Verify we can send a message + let result = tx.send(WebSocketReceiveMessage::Connect); + assert!(result.is_ok(), "Should be able to send message"); + } + + #[test] + fn test_receiver_lifespan_creation() { + ensure_python_initialized(); + + let (_receiver, tx) = Receiver::lifespan(); + // Verify we can send a message + let result = tx.send(LifespanReceiveMessage::LifespanStartup); + assert!(result.is_ok(), "Should be able to send message"); + } + + #[tokio::test] + async fn test_receiver_http_message_flow() { + ensure_python_initialized(); + + let (_receiver, tx) = Receiver::http(); + + // Send a message and verify it succeeds + let message = HttpReceiveMessage::Request { + body: b"test body".to_vec(), + more_body: false, + }; + let result = tx.send(message); + assert!( + result.is_ok(), + "Should be able to send message through channel" + ); + } + + #[tokio::test] + async fn test_receiver_websocket_message_flow() { + ensure_python_initialized(); + + let (_receiver, tx) = Receiver::websocket(); + + // Send a message and verify it 
succeeds + let message = WebSocketReceiveMessage::Receive { + bytes: None, + text: Some("test message".to_string()), + }; + let result = tx.send(message); + assert!( + result.is_ok(), + "Should be able to send message through channel" + ); + } + + #[tokio::test] + async fn test_receiver_lifespan_message_flow() { + ensure_python_initialized(); + + let (_receiver, tx) = Receiver::lifespan(); + + // Send multiple messages and verify they succeed + let result1 = tx.send(LifespanReceiveMessage::LifespanStartup); + assert!(result1.is_ok(), "Should be able to send startup message"); + + let result2 = tx.send(LifespanReceiveMessage::LifespanShutdown); + assert!(result2.is_ok(), "Should be able to send shutdown message"); + } +} diff --git a/src/asgi/runtime_handle.rs b/src/asgi/runtime_handle.rs new file mode 100644 index 0000000..1941283 --- /dev/null +++ b/src/asgi/runtime_handle.rs @@ -0,0 +1,139 @@ +//! Tokio runtime handle management for async operations. +//! +//! This module provides [`fallback_handle()`], a function that ensures a Tokio +//! runtime handle is available for async operations, creating a fallback runtime +//! if necessary. + +use std::sync::OnceLock; + +/// Global fallback runtime for when no tokio runtime is available +static FALLBACK_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new(); + +/// Get a tokio runtime handle, creating a fallback if needed. +/// +/// This function attempts to get the handle of the current tokio runtime. +/// If no runtime is available (i.e., not running within a tokio context), +/// it creates and returns a handle to a global fallback runtime. +/// +/// # Returns +/// +/// A `tokio::runtime::Handle` that can be used to spawn tasks and perform +/// async operations. +/// +/// # Panics +/// +/// Panics if creating the fallback runtime fails, though this is extremely +/// unlikely in normal operation. +/// +/// # Thread Safety +/// +/// This function is thread-safe. The fallback runtime is created once and +/// shared across all threads that need it.
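+///
+/// # Examples
+///
+/// A minimal sketch (illustrative only, mirroring how the tests below use it):
+/// block on an async operation from synchronous code, whether or not a Tokio
+/// runtime is already running.
+///
+/// ```ignore
+/// let value = fallback_handle().block_on(async { 21 * 2 });
+/// assert_eq!(value, 42);
+/// ```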
+pub(crate) fn fallback_handle() -> tokio::runtime::Handle { + tokio::runtime::Handle::try_current().unwrap_or_else(|_| { + // No runtime exists, create a fallback one + let rt = FALLBACK_RUNTIME.get_or_init(|| { + tokio::runtime::Runtime::new().expect("Failed to create fallback tokio runtime") + }); + rt.handle().clone() + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fallback_handle_outside_runtime() { + // This test runs outside a tokio runtime + // Should create and use the fallback runtime + let handle = fallback_handle(); + + // Verify we can use the handle to block_on a future + handle.block_on(async { + // Simple async operation + let value = 42; + assert_eq!(value, 42); + }); + } + + #[tokio::test] + async fn test_fallback_handle_inside_runtime() { + // This test runs inside a tokio runtime + // Should use the current runtime's handle + let handle = fallback_handle(); + + // Verify we can spawn a task + let result = handle + .spawn(async { + // Simple async operation + 42 + }) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_fallback_handle_is_consistent() { + // Calling fallback_handle multiple times outside a runtime + // should return handles to the same fallback runtime + let handle1 = fallback_handle(); + let handle2 = fallback_handle(); + + // Both handles should work + handle1.block_on(async { + assert!(true); + }); + + handle2.block_on(async { + assert!(true); + }); + } + + #[test] + fn test_fallback_handle_can_spawn_blocking() { + // Test that we can use spawn_blocking with the handle + let handle = fallback_handle(); + + let result = handle.block_on(async { + handle + .spawn_blocking(|| { + // CPU-bound work + let mut sum = 0; + for i in 0..100 { + sum += i; + } + sum + }) + .await + }); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 4950); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_fallback_handle_multi_thread() { + // Test that fallback_handle works correctly in a multi-threaded runtime + let handle = fallback_handle(); + + // Spawn multiple tasks concurrently + let tasks: Vec<_> = (0..10) + .map(|i| { + handle.spawn(async move { + // Simple async work + i * 2 + }) + }) + .collect(); + + // Wait for all tasks to complete + for (i, task) in tasks.into_iter().enumerate() { + let result = task.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), i * 2); + } + } +} diff --git a/src/asgi/sender.rs b/src/asgi/sender.rs index 602eeb6..4be5946 100644 --- a/src/asgi/sender.rs +++ b/src/asgi/sender.rs @@ -100,3 +100,88 @@ impl Sender { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::asgi::ensure_python_initialized; + + #[test] + fn test_sender_http_creation() { + ensure_python_initialized(); + + let (_sender, mut rx) = Sender::http(); + // Verify receiver is open + assert!(rx.try_recv().is_err(), "Channel should be empty but open"); + } + + #[test] + fn test_sender_websocket_creation() { + ensure_python_initialized(); + + let (_sender, mut rx) = Sender::websocket(); + // Verify receiver is open + assert!(rx.try_recv().is_err(), "Channel should be empty but open"); + } + + #[test] + fn test_sender_lifespan_creation() { + ensure_python_initialized(); + + let (_sender, mut rx) = Sender::lifespan(); + // Verify receiver is open + assert!(rx.try_recv().is_err(), "Channel should be empty but open"); + } + + #[test] + fn test_sender_http_channel_closed() { + ensure_python_initialized(); + + let (sender, rx) = Sender::http(); + + 
// Drop the receiver to close the channel + drop(rx); + + // Sender should still exist but attempts to send will fail + // We can't easily test the __call__ method without Python, but we've verified + // the channel setup works + drop(sender); + } + + #[test] + fn test_sender_websocket_channel_closed() { + ensure_python_initialized(); + + let (sender, rx) = Sender::websocket(); + + // Drop the receiver to close the channel + drop(rx); + + // Sender should still exist + drop(sender); + } + + #[test] + fn test_sender_lifespan_channel_closed() { + ensure_python_initialized(); + + let (sender, rx) = Sender::lifespan(); + + // Drop the receiver to close the channel + drop(rx); + + // Sender should still exist + drop(sender); + } + + #[tokio::test] + async fn test_acknowledged_message_structure() { + ensure_python_initialized(); + + let (_sender, mut rx) = Sender::http(); + + // We can't easily send through the sender without Python, but we can verify + // the receiver side works + assert!(rx.try_recv().is_err(), "Channel should be empty initially"); + } +} diff --git a/src/asgi/websocket.rs b/src/asgi/websocket.rs index b2aa8fd..3cc9db2 100644 --- a/src/asgi/websocket.rs +++ b/src/asgi/websocket.rs @@ -1,6 +1,8 @@ +use http_handler::{Request, RequestExt, Version}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; +use std::net::SocketAddr; use crate::asgi::{AsgiInfo, HttpVersion}; @@ -58,6 +60,99 @@ pub struct WebSocketConnectionScope { state: Option>, } +impl TryFrom<&Request> for WebSocketConnectionScope { + type Error = PyErr; + + fn try_from(request: &Request) -> Result { + // Extract HTTP version + let http_version = match request.version() { + Version::HTTP_10 => HttpVersion::V1_0, + Version::HTTP_11 => HttpVersion::V1_1, + Version::HTTP_2 => HttpVersion::V2_0, + Version::HTTP_3 => HttpVersion::V2_0, // treat HTTP/3 as HTTP/2 for ASGI + _ => HttpVersion::V1_1, // default fallback + }; + + // Extract scheme from URI (typically wss or ws for WebSocket) + let scheme = request + .uri() + .scheme_str() + .map(|s| { + if s == "https" || s == "wss" { + "wss" + } else { + "ws" + } + }) + .unwrap_or("ws") + .to_string(); + + // Extract path + let path = request.uri().path().to_string(); + + // Extract raw path (same as path for now, as we don't have the raw bytes) + let raw_path = path.clone(); + + // Extract query string + let query_string = request.uri().query().unwrap_or("").to_string(); + + // Extract root path from DocumentRoot extension + let root_path = request + .document_root() + .map(|doc_root| doc_root.path.to_string_lossy().to_string()) + .unwrap_or_default(); + + // Convert headers + let headers: Vec<(String, String)> = request + .headers() + .iter() + .map(|(name, value)| { + ( + name.as_str().to_lowercase(), + value.to_str().unwrap_or("").to_string(), + ) + }) + .collect(); + + // Extract client and server from socket info if available + let (client, server) = if let Some(socket_info) = request.socket_info() { + let client = socket_info.remote.map(|addr| match addr { + SocketAddr::V4(v4) => (v4.ip().to_string(), v4.port()), + SocketAddr::V6(v6) => (v6.ip().to_string(), v6.port()), + }); + let server = socket_info.local.map(|addr| match addr { + SocketAddr::V4(v4) => (v4.ip().to_string(), v4.port()), + SocketAddr::V6(v6) => (v6.ip().to_string(), v6.port()), + }); + (client, server) + } else { + (None, None) + }; + + // Extract subprotocols from Sec-WebSocket-Protocol header + let subprotocols: Vec = request + .headers() + 
.get("sec-websocket-protocol") + .and_then(|h| h.to_str().ok()) + .map(|protocols| protocols.split(',').map(|p| p.trim().to_string()).collect()) + .unwrap_or_default(); + + Ok(WebSocketConnectionScope { + http_version, + scheme, + path, + raw_path, + query_string, + root_path, + headers, + client, + server, + subprotocols, + state: None, + }) + } +} + impl<'py> IntoPyObject<'py> for WebSocketConnectionScope { type Target = PyDict; type Output = Bound<'py, Self::Target>; @@ -234,10 +329,7 @@ impl<'py> FromPyObject<'py> for WebSocketSendMessage { let headers: Vec<(String, String)> = dict .get_item("headers")? - .ok_or_else(|| { - PyValueError::new_err("Missing 'headers' key in WebSocket accept message") - })? - .extract()?; + .map_or(Ok(vec![]), |v| v.extract())?; Ok(WebSocketSendMessage::Accept { subprotocol, diff --git a/src/lib.rs b/src/lib.rs index f6830a0..1c7dc97 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,10 +11,13 @@ #[cfg(feature = "napi-support")] use std::{ffi::c_char, sync::Arc}; +#[cfg(feature = "napi-support")] +use bytes::{Bytes, BytesMut}; + #[cfg(feature = "napi-support")] use http_handler::napi::{Request as NapiRequest, Response as NapiResponse}; #[cfg(feature = "napi-support")] -use http_handler::{Handler, Request, Response}; +use http_handler::{BodyBuffer, Handler, Request, Response, ResponseBody}; #[cfg(feature = "napi-support")] #[allow(unused_imports)] use http_rewriter::napi::Rewriter; @@ -213,7 +216,11 @@ impl PythonHandler { self.asgi.docroot().display().to_string() } - /// Handle a Python request. + /// Handle a Python request with buffered response (backward compatible). + /// + /// This method uses the same asgi.handle() as handleStream, but buffers the + /// response body before returning. The body is available synchronously via + /// response.body getter. /// /// # Examples /// @@ -229,30 +236,60 @@ impl PythonHandler { /// })); /// /// console.log(response.status); - /// console.log(response.body); + /// console.log(response.body.toString()); // Body is buffered and ready /// ``` #[napi] - pub async fn handle_request(&self, request: &NapiRequest) -> Result { - let response = self - .asgi - .handle(request.clone().into_inner()) - .await - .map_err(|e| Error::from_reason(e.to_string()))?; - Ok(response.into()) + pub fn handle_request( + &self, + request: NapiRequest, + signal: Option, + ) -> AsyncTask { + AsyncTask::with_optional_signal( + PythonRequestTask { + asgi: self.asgi.clone(), + request: Some(request.into_inner()), + }, + signal, + ) + } + + /// Handle a Python request with streaming response. + /// + /// This method uses the same asgi.handle() as handleRequest, but returns + /// immediately with a streaming response. Use AsyncIterator to read chunks. 
+ /// + /// # Examples + /// + /// ```js + /// const python = new Python({ + /// docroot: process.cwd(), + /// argv: process.argv + /// }); + /// + /// const response = await python.handleStream(new Request({ + /// method: 'GET', + /// url: 'http://example.com' + /// })); + /// + /// // Read response via AsyncIterator + /// for await (const chunk of response) { + /// console.log(chunk.toString()); + /// } + /// ``` + #[napi] + pub fn handle_stream( + &self, + request: NapiRequest, + signal: Option, + ) -> AsyncTask { + AsyncTask::with_optional_signal( + PythonStreamTask { + asgi: self.asgi.clone(), + request: Some(request.into_inner()), + }, + signal, + ) } - // pub fn handle_request( - // &self, - // request: &NapiRequest, - // signal: Option, - // ) -> AsyncTask { - // AsyncTask::with_optional_signal( - // PythonRequestTask { - // asgi: self.asgi.clone(), - // request: request.clone().into_inner(), - // }, - // signal, - // ) - // } /// Handle a PHP request synchronously. /// @@ -273,107 +310,198 @@ impl PythonHandler { /// console.log(response.body); /// ``` #[napi] - pub fn handle_request_sync(&self, request: &NapiRequest) -> Result { + pub fn handle_request_sync(&self, request: NapiRequest) -> Result { let mut task = PythonRequestTask { asgi: self.asgi.clone(), - request: request.clone().into_inner(), + request: Some(request.into_inner()), }; task.compute().map(Into::::into) } } -/// Task container to run a Python request in a worker thread. +/// Task for buffered request handling. +/// Uses identical asgi.handle() call, just buffers the body afterward. #[cfg(feature = "napi-support")] pub struct PythonRequestTask { asgi: Arc, - request: Request, + request: Option, } -/// Error types for the Python request handler. -#[allow(clippy::large_enum_variant)] -#[derive(thiserror::Error, Debug)] -pub enum HandlerError { - /// IO errors that may occur during file operations. - #[error("IO Error: {0}")] - IoError(#[from] std::io::Error), - - /// Error when the current directory cannot be determined. - #[error("Failed to get current directory: {0}")] - CurrentDirectoryError(std::io::Error), - - /// Error when the entry point for the Python application is not found. - #[error("Entry point not found: {0}")] - EntrypointNotFoundError(std::io::Error), - - /// Error when converting a string to a C-compatible string. - #[error("Failed to convert string: {0}")] - StringCovertError(#[from] std::ffi::NulError), - - /// Error when a Python operation fails. - #[error("Python error: {0}")] - PythonError(#[from] pyo3::prelude::PyErr), - - /// Error when response channel is closed before sending a response. - #[error("No response sent")] - NoResponse, - - /// Error when response is interrupted. - #[error("Response interrupted")] - ResponseInterrupted, - - /// Error when response channel is closed. - #[error("Response channel closed: {0}")] - ResponseChannelClosed(#[from] RecvError), - - /// Error when unable to send message to Python. - #[error("Unable to send message to Python: {0}")] - UnableToSendMessageToPython(#[from] SendError), - - /// Error when creating an HTTP response fails. - #[error("Failed to create response: {0}")] - HttpHandlerError(#[from] http_handler::Error), - - /// Error when event loop is closed. 
- #[error("Event loop closed")] - EventLoopClosed, - - /// Error when PYTHON_NODE_WORKERS is invalid - #[error("Invalid PYTHON_NODE_WORKERS count: {0}")] - InvalidWorkerCount(#[from] std::num::ParseIntError), +#[cfg(feature = "napi-support")] +impl Task for PythonRequestTask { + type Output = Response; + type JsValue = NapiResponse; - /// Error when a lock is poisoned - #[error("Lock poisoned: {0}")] - LockPoisoned(String), + fn compute(&mut self) -> Result { + // Take ownership of the request (FromNapiValue already created fresh body with BodyBuffer) + let request = self + .request + .take() + .ok_or_else(|| Error::from_reason("Request already consumed"))?; + + // Use the shared fallback runtime handle + asgi::fallback_handle().block_on(async { + // Spawn task to send pending body data if present (from Request constructor) + // This prevents deadlock when body size exceeds duplex buffer size + let body_writer = if let Some(body_buffer) = request.extensions().get::() { + let data = Bytes::copy_from_slice(body_buffer.as_bytes()); + let mut body = request.body().clone(); + + Some(tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + + body.write_all(&data).await?; + body.shutdown().await?; + Ok::<(), std::io::Error>(()) + })) + } else { + // No body provided - close stream immediately + let mut body = request.body().clone(); + Some(tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + + body.shutdown().await?; + Ok::<(), std::io::Error>(()) + })) + }; + + // Invoke handler immediately (starts reader task) + // Handler returns when headers are ready, body streaming continues in background + let response = self + .asgi + .handle(request) + .await + .map_err(|e| Error::from_reason(e.to_string()))?; + + // Wait for body writing to complete + if let Some(writer) = body_writer { + writer + .await + .map_err(|e| Error::from_reason(format!("Body writer task failed: {}", e)))? + .map_err(|e| Error::from_reason(e.to_string()))?; + } + + // Extract parts to buffer the body + let (mut parts, mut body) = response.into_parts(); + + // Buffer all body chunks - consuming starts immediately after headers + use http_body_util::BodyExt; + let mut buf = BytesMut::new(); + while let Some(result) = body.frame().await { + match result { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + buf.extend_from_slice(&data); + } + } + Err(e) => { + return Err(Error::from_reason(e)); + } + } + } + let bytes = buf.freeze(); + + // Store buffered body in extension + parts.extensions.insert(BodyBuffer::from_bytes(bytes)); + + // Create a dummy ResponseBody since body is buffered in extension + let dummy_body = ResponseBody::new(); + let buffered_response = http::Response::from_parts(parts, dummy_body); + + Ok::, Error>(buffered_response) + }) + } - /// Error when a Tokio task fails - #[error("Tokio task error: {0}")] - TokioError(String), + fn resolve(&mut self, _env: Env, output: Self::Output) -> Result { + Ok(output.into()) + } } -impl From> for HandlerError { - fn from(err: std::sync::PoisonError) -> Self { - HandlerError::LockPoisoned(err.to_string()) - } +/// Task container to run a Python streaming request in a worker thread. 
+#[cfg(feature = "napi-support")] +pub struct PythonStreamTask { + asgi: Arc, + request: Option, } #[cfg(feature = "napi-support")] -#[cfg_attr(feature = "napi-support", napi)] -impl Task for PythonRequestTask { +#[napi] +impl Task for PythonStreamTask { type Output = Response; - type JsValue = NapiResponse; + type JsValue = Object<'static>; - // Handle the Python request in the worker thread. + // Handle the Python streaming request in the worker thread. fn compute(&mut self) -> Result { - self - .asgi - .handle_sync(self.request.clone()) - .map_err(|err| Error::from_reason(err.to_string())) + // Take ownership of the request to avoid cloning + let request = self + .request + .take() + .ok_or_else(|| Error::from_reason("Request already consumed"))?; + + // Get the current runtime handle or use the global fallback runtime + // This ensures background tasks stay alive even after compute() returns + asgi::fallback_handle().block_on(async { + // Check if this is a WebSocket request + let is_websocket = request + .extensions() + .get::() + .is_some(); + + // Spawn task to send pending body data if present (from Request constructor) + // This prevents deadlock when body size exceeds duplex buffer size + let body_writer = if let Some(body_buffer) = request.extensions().get::() { + let data = Bytes::copy_from_slice(body_buffer.as_bytes()); + let mut body = request.body().clone(); + let should_close = !is_websocket; + + Some(tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + + body.write_all(&data).await?; + if should_close { + body.shutdown().await?; + } + Ok::<(), std::io::Error>(()) + })) + } else if !is_websocket { + // No body provided - close stream immediately for non-WebSocket requests + let mut body = request.body().clone(); + Some(tokio::spawn(async move { + use tokio::io::AsyncWriteExt; + + body.shutdown().await?; + Ok::<(), std::io::Error>(()) + })) + } else { + None + }; + + // Invoke handler immediately (starts reader task) + // For streaming responses, returns when headers are ready + let response = self + .asgi + .handle(request) + .await + .map_err(|e| Error::from_reason(e.to_string()))?; + + // Wait for body writing to complete + if let Some(writer) = body_writer { + writer + .await + .map_err(|e| Error::from_reason(format!("Body writer task failed: {}", e)))? + .map_err(|e| Error::from_reason(e.to_string()))?; + } + + Ok(response) + }) } - // Handle converting the PHP response to a JavaScript response in the main thread. - fn resolve(&mut self, _env: Env, output: Self::Output) -> Result { - Ok(Into::::into(output)) + // Handle converting the Python response to a JavaScript response in the main thread. + fn resolve(&mut self, env: Env, output: Self::Output) -> Result { + // Convert to NapiResponse and set up async iterator + let response: NapiResponse = output.into(); + response.make_streamable(env) } } @@ -488,3 +616,78 @@ mod tests { assert_eq!(set.len(), 1); } } + +/// Error types for the Python request handler. +#[allow(clippy::large_enum_variant)] +#[derive(thiserror::Error, Debug)] +pub enum HandlerError { + /// IO errors that may occur during file operations. + #[error("IO Error: {0}")] + IoError(#[from] std::io::Error), + + /// Error when the current directory cannot be determined. + #[error("Failed to get current directory: {0}")] + CurrentDirectoryError(std::io::Error), + + /// Error when the entry point for the Python application is not found. 
+ #[error("Entry point not found: {0}")] + EntrypointNotFoundError(std::io::Error), + + /// Error when converting a string to a C-compatible string. + #[error("Failed to convert string: {0}")] + StringCovertError(#[from] std::ffi::NulError), + + /// Error when a Python operation fails. + #[error("Python error: {0}")] + PythonError(#[from] pyo3::prelude::PyErr), + + /// Error when response channel is closed before sending a response. + #[error("No response sent")] + NoResponse, + + /// Error when response is interrupted. + #[error("Response interrupted")] + ResponseInterrupted, + + /// Error when response channel is closed. + #[error("Response channel closed: {0}")] + ResponseChannelClosed(#[from] RecvError), + + /// Error when unable to send message to Python. + #[error("Unable to send message to Python: {0}")] + UnableToSendMessageToPython(#[from] SendError), + + /// Error when creating an HTTP response fails. + #[error("Failed to create response: {0}")] + HttpHandlerError(#[from] http_handler::Error), + + /// Error when event loop is closed. + #[error("Event loop closed")] + EventLoopClosed, + + /// Error when PYTHON_NODE_WORKERS is invalid + #[error("Invalid PYTHON_NODE_WORKERS count: {0}")] + InvalidWorkerCount(#[from] std::num::ParseIntError), + + /// Error when a lock is poisoned + #[error("Lock poisoned: {0}")] + LockPoisoned(String), + + /// Error when a Tokio task fails + #[error("Tokio task error: {0}")] + TokioError(String), + + /// Error when request stream has already been consumed + #[error("Request stream already consumed")] + StreamAlreadyConsumed, + + /// Error when WebSocket connection was not accepted + #[error("WebSocket connection not accepted")] + WebSocketNotAccepted, +} + +impl From> for HandlerError { + fn from(err: std::sync::PoisonError) -> Self { + HandlerError::LockPoisoned(err.to_string()) + } +} diff --git a/test/concurrency.test.mjs b/test/concurrency.test.mjs index 74928f4..fea00a1 100644 --- a/test/concurrency.test.mjs +++ b/test/concurrency.test.mjs @@ -136,12 +136,13 @@ test('Python - concurrent handleRequest calls', async (t) => { if (index % 2 === 0) { assert.strictEqual(response.body.toString(), 'Hello, world!'); } else { - assert.strictEqual(response.body.toString(), 'Chunk 1\nChunk 2\nChunk 3\n'); + assert.strictEqual(response.body.toString(), 'Chunk 1\nChunk 2\nChunk 3\nChunk 4\nChunk 5\n'); } }); // Should complete reasonably quickly (streaming requests have 30ms delay) - assert.ok(duration < 200, `Requests took too long: ${duration}ms`); + // Allow 1500ms to account for system load variations and overhead + assert.ok(duration < 2500, `Requests took too long: ${duration}ms`); }); await t.test('handles requests with large payloads concurrently', async () => { diff --git a/test/fixtures/error_app.py b/test/fixtures/error_app.py index 773487d..df69af3 100644 --- a/test/fixtures/error_app.py +++ b/test/fixtures/error_app.py @@ -1,16 +1,16 @@ async def app(scope, receive, send): # Read request to consume it await receive() - + if scope['path'] == '/error': raise Exception('Test error') - + await send({ 'type': 'http.response.start', 'status': 200, 'headers': [], }) - + await send({ 'type': 'http.response.body', 'body': b'OK', diff --git a/test/fixtures/stream_app.py b/test/fixtures/stream_app.py index 0c13cf5..69964a1 100644 --- a/test/fixtures/stream_app.py +++ b/test/fixtures/stream_app.py @@ -1,20 +1,49 @@ import asyncio async def app(scope, receive, send): + path = scope['path'] + query_string = scope.get('query_string', b'').decode('utf-8') + + # 
Parse simple query parameters + params = {} + if query_string: + for param in query_string.split('&'): + if '=' in param: + key, value = param.split('=', 1) + params[key] = value + # Read request to consume it await receive() - + await send({ 'type': 'http.response.start', 'status': 200, 'headers': [[b'content-type', b'text/plain']], }) - - # Send response in chunks - for i in range(3): + + # Handle different paths + if path == '/empty': + # Send empty response await send({ 'type': 'http.response.body', - 'body': f'Chunk {i + 1}\n'.encode(), - 'more_body': i < 2, + 'body': b'', + 'more_body': False, }) - await asyncio.sleep(0.01) # Small delay to simulate streaming \ No newline at end of file + else: + # Send response in chunks + # Support 'count' parameter to control number of chunks (default: 5) + # Support 'newlines' parameter to control newlines (default: true) + count = int(params.get('count', '5')) + use_newlines = params.get('newlines', 'true').lower() != 'false' + + for i in range(count): + chunk_text = f'Chunk {i + 1}' + if use_newlines: + chunk_text += '\n' + + await send({ + 'type': 'http.response.body', + 'body': chunk_text.encode(), + 'more_body': i < count - 1, + }) + await asyncio.sleep(0.01) # Small delay to simulate streaming \ No newline at end of file diff --git a/test/fixtures/stream_error_app.py b/test/fixtures/stream_error_app.py new file mode 100644 index 0000000..613f880 --- /dev/null +++ b/test/fixtures/stream_error_app.py @@ -0,0 +1,34 @@ +async def app(scope, receive, send): + # Read request to consume it + await receive() + + # Send response start + await send({ + 'type': 'http.response.start', + 'status': 200, + 'headers': [], + }) + + if scope['path'] == '/error-during-stream': + # Send some chunks before error + await send({ + 'type': 'http.response.body', + 'body': b'Chunk 1\n', + 'more_body': True, + }) + + await send({ + 'type': 'http.response.body', + 'body': b'Chunk 2\n', + 'more_body': True, + }) + + # Now raise an error during streaming + raise Exception('Error during streaming') + + # Normal response + await send({ + 'type': 'http.response.body', + 'body': b'OK', + 'more_body': False, + }) diff --git a/test/fixtures/websocket_app.py b/test/fixtures/websocket_app.py new file mode 100644 index 0000000..2b6317b --- /dev/null +++ b/test/fixtures/websocket_app.py @@ -0,0 +1,122 @@ +async def app(scope, receive, send): + """ + WebSocket ASGI application for testing. 
+ + Supports various test scenarios based on the path: + - /echo: Echo back received messages + - /uppercase: Convert text messages to uppercase + - /close: Accept connection then immediately close + - /ping-pong: Respond to 'ping' with 'pong' + """ + + if scope['type'] == 'websocket': + # Accept the WebSocket connection + await send({ + 'type': 'websocket.accept', + }) + + path = scope.get('path', '/') + + if path == '/close': + # Immediately close after accepting + await send({ + 'type': 'websocket.close', + 'code': 1000, + 'reason': 'Normal closure', + }) + return + + # Handle messages until disconnect + while True: + message = await receive() + + if message['type'] == 'websocket.disconnect': + # Client disconnected + break + + if message['type'] == 'websocket.receive': + text = message.get('text') + bytes_data = message.get('bytes') + + if path == '/echo': + # Echo back the message + if text is not None: + await send({ + 'type': 'websocket.send', + 'text': text, + }) + elif bytes_data is not None: + await send({ + 'type': 'websocket.send', + 'bytes': bytes_data, + }) + + elif path == '/uppercase': + # Convert to uppercase (text only) + if text is not None: + await send({ + 'type': 'websocket.send', + 'text': text.upper(), + }) + elif bytes_data is not None: + # For binary, convert to uppercase if valid UTF-8 + try: + text = bytes_data.decode('utf-8') + await send({ + 'type': 'websocket.send', + 'text': text.upper(), + }) + except UnicodeDecodeError: + # Send back as-is if not valid UTF-8 + await send({ + 'type': 'websocket.send', + 'bytes': bytes_data, + }) + + elif path == '/ping-pong': + # Respond to 'ping' with 'pong' + if text == 'ping': + await send({ + 'type': 'websocket.send', + 'text': 'pong', + }) + else: + # Echo other messages + if text is not None: + await send({ + 'type': 'websocket.send', + 'text': text, + }) + elif bytes_data is not None: + await send({ + 'type': 'websocket.send', + 'bytes': bytes_data, + }) + + else: + # Default: echo + if text is not None: + await send({ + 'type': 'websocket.send', + 'text': text, + }) + elif bytes_data is not None: + await send({ + 'type': 'websocket.send', + 'bytes': bytes_data, + }) + + else: + # Not a WebSocket request - return 426 Upgrade Required + await send({ + 'type': 'http.response.start', + 'status': 426, + 'headers': [ + (b'upgrade', b'WebSocket'), + ], + }) + await send({ + 'type': 'http.response.body', + 'body': b'Upgrade Required', + 'more_body': False, + }) diff --git a/test/handler.test.mjs b/test/handler.test.mjs index d72dca5..3b5f45f 100644 --- a/test/handler.test.mjs +++ b/test/handler.test.mjs @@ -120,12 +120,12 @@ test('Python', async t => { const request = new Request({ method: 'GET', - url: '/stream' + url: '/' }) const response = await python.handleRequest(request) strictEqual(response.status, 200, 'should return 200 status') - strictEqual(response.body.toString(), 'Chunk 1\nChunk 2\nChunk 3\n', 'should concatenate all chunks') + strictEqual(response.body.toString(), 'Chunk 1\nChunk 2\nChunk 3\nChunk 4\nChunk 5\n', 'should concatenate all chunks') }) await t.test('handleRequest - root_path', async () => { diff --git a/test/streaming.test.mjs b/test/streaming.test.mjs new file mode 100644 index 0000000..b1895e2 --- /dev/null +++ b/test/streaming.test.mjs @@ -0,0 +1,269 @@ +import { test } from 'node:test' +import assert, { strictEqual, deepStrictEqual } from 'node:assert' +import { join } from 'node:path' + +import { Python, Request } from '../index.js' + +const fixturesDir = join(import.meta.dirname, 
'fixtures')
+
+test('Python - streaming', async (t) => {
+  await t.test('handleStream - basic response', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'main:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/'
+    })
+
+    const res = await python.handleStream(req)
+    strictEqual(res.status, 200)
+
+    // Collect streaming body
+    let body = ''
+    for await (const chunk of res) {
+      body += chunk.toString('utf8')
+    }
+    strictEqual(body, 'Hello, world!')
+  })
+
+  await t.test('handleStream - chunked output', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'stream_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/?count=3&newlines=false'
+    })
+
+    const res = await python.handleStream(req)
+    strictEqual(res.status, 200)
+
+    // Collect all chunks
+    const chunks = []
+    for await (const chunk of res) {
+      chunks.push(chunk.toString('utf8'))
+    }
+
+    // Verify complete body (chunks may be combined)
+    const body = chunks.join('')
+    strictEqual(body, 'Chunk 1Chunk 2Chunk 3')
+
+    // Verify we received at least one chunk (streaming is working)
+    assert.ok(chunks.length > 0, 'should receive at least one chunk')
+    assert.ok(chunks.length <= 3, 'should not receive more than 3 chunks')
+  })
+
+  await t.test('handleStream - headers available immediately', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'echo_app:app'
+    })
+
+    const req = new Request({
+      method: 'POST',
+      url: 'http://example.com/test',
+      headers: {
+        'Content-Type': 'application/json'
+      },
+      body: Buffer.from(JSON.stringify({ status: 'ok' }))
+    })
+
+    const res = await python.handleStream(req)
+
+    // Headers should be available immediately
+    strictEqual(res.status, 200)
+    strictEqual(res.headers.get('content-type'), 'application/json')
+
+    // Body can be consumed after
+    let body = ''
+    for await (const chunk of res) {
+      body += chunk.toString('utf8')
+    }
+
+    const responseBody = JSON.parse(body)
+    strictEqual(responseBody.method, 'POST')
+    strictEqual(responseBody.path, '/test')
+  })
+
+  await t.test('handleStream - POST with buffered body', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'echo_app:app'
+    })
+
+    const req = new Request({
+      method: 'POST',
+      url: 'http://example.com/echo',
+      headers: {
+        'Content-Type': 'text/plain'
+      },
+      body: Buffer.from('Hello from client!')
+    })
+
+    const res = await python.handleStream(req)
+    strictEqual(res.status, 200)
+
+    let body = ''
+    for await (const chunk of res) {
+      body += chunk.toString('utf8')
+    }
+
+    const responseBody = JSON.parse(body)
+    strictEqual(responseBody.body, 'Hello from client!')
+  })
+
+  await t.test('handleStream - POST with streamed body', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'echo_app:app'
+    })
+
+    const req = new Request({
+      method: 'POST',
+      url: 'http://example.com/echo',
+      headers: {
+        'Content-Type': 'text/plain'
+      }
+    })
+
+    // Stream the body in chunks using write() and end()
+    await req.write('Hello ')
+    await req.write('from ')
+    await req.write('streaming!')
+    await req.end()
+
+    const res = await python.handleStream(req)
+    strictEqual(res.status, 200)
+
+    let body = ''
+    for await (const chunk of res) {
+      body += chunk.toString('utf8')
+    }
+
+    const responseBody = JSON.parse(body)
+    strictEqual(responseBody.body, 'Hello from streaming!')
+  })
+
+  await t.test('handleStream - empty response', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'stream_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/empty'
+    })
+
+    const res = await python.handleStream(req)
+    strictEqual(res.status, 200)
+
+    let body = ''
+    for await (const chunk of res) {
+      body += chunk.toString('utf8')
+    }
+    strictEqual(body, '')
+  })
+
+  await t.test('handleStream - large streaming response', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'stream_app:app'
+    })
+
+    const chunkCount = 100
+    const req = new Request({
+      method: 'GET',
+      url: `http://example.com/?count=${chunkCount}`
+    })
+
+    const res = await python.handleStream(req)
+    strictEqual(res.status, 200)
+
+    // Collect all chunks
+    const chunks = []
+    for await (const chunk of res) {
+      chunks.push(chunk.toString('utf8'))
+    }
+
+    // Join all chunks to get complete body
+    const body = chunks.join('')
+
+    // Generate expected body - all chunks concatenated
+    const expectedBody = Array.from({ length: chunkCount }, (_, i) => `Chunk ${i + 1}\n`).join('')
+
+    // Verify we received all expected data in correct order
+    strictEqual(body, expectedBody, 'should receive all chunks in correct order')
+
+    // Verify we received at least some chunks (streaming is working)
+    assert.ok(chunks.length > 0, 'should receive at least one chunk')
+    assert.ok(chunks.length <= chunkCount, 'should not receive more chunks than sent')
+  })
+
+  await t.test('handleStream - error handling', async (t) => {
+    await t.test('exception before response.start', async () => {
+      const python = new Python({
+        docroot: fixturesDir,
+        appTarget: 'error_app:app'
+      })
+
+      const req = new Request({
+        method: 'GET',
+        url: 'http://example.com/error'
+      })
+
+      // Should throw error before response.start is sent
+      await assert.rejects(
+        async () => await python.handleStream(req),
+        (err) => {
+          // Verify error message contains "Test error"
+          return err.message.includes('Test error')
+        },
+        'Should throw Python exception before response.start'
+      )
+    })
+
+    await t.test('exception after response.start during streaming', async () => {
+      const python = new Python({
+        docroot: fixturesDir,
+        appTarget: 'stream_error_app:app'
+      })
+
+      const req = new Request({
+        method: 'GET',
+        url: 'http://example.com/error-during-stream'
+      })
+
+      const res = await python.handleStream(req)
+      strictEqual(res.status, 200, 'should return 200 status (response.start sent)')
+
+      // Collect chunks until error
+      const chunks = []
+      await assert.rejects(
+        async () => {
+          for await (const chunk of res) {
+            chunks.push(chunk.toString('utf8'))
+          }
+        },
+        (err) => {
+          // Verify error message contains "Error during streaming"
+          // err might be a string or Error object
+          const errorMsg = typeof err === 'string' ? err : err.message
+          return errorMsg.includes('Error during streaming')
+        },
+        'Should propagate exception as error in stream'
+      )
+
+      // Verify we received data before the error (chunks may be combined)
+      const body = chunks.join('')
+      strictEqual(body, 'Chunk 1\nChunk 2\n', 'should receive expected data before error')
+      assert.ok(chunks.length > 0, 'should receive at least one chunk before error')
+      assert.ok(chunks.length <= 2, 'should not receive more than 2 chunks')
+    })
+  })
+})
diff --git a/test/websocket-integration.test.mjs b/test/websocket-integration.test.mjs
new file mode 100644
index 0000000..e31dae4
--- /dev/null
+++ b/test/websocket-integration.test.mjs
@@ -0,0 +1,236 @@
+import { test } from 'node:test'
+import { strictEqual } from 'node:assert'
+import { join } from 'node:path'
+import { createServer, request as httpRequest } from 'node:http'
+import { once } from 'node:events'
+
+import { Python, Request } from '../index.js'
+
+const fixturesDir = join(import.meta.dirname, 'fixtures')
+
+test('Python - WebSocket Integration', async (t) => {
+  await t.test('HTTP upgrade to WebSocket simulation', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    // Create a simple HTTP server
+    const server = createServer(async (nodeReq, nodeRes) => {
+      // Check if this is a WebSocket upgrade request
+      const isUpgrade = nodeReq.headers.upgrade?.toLowerCase() === 'websocket'
+
+      if (isUpgrade) {
+        // In a real implementation, you'd:
+        // 1. Perform WebSocket handshake
+        // 2. Switch protocols
+        // 3. Forward socket data to Python
+
+        // For this test, we simulate by creating a WebSocket request
+        const req = new Request({
+          method: nodeReq.method,
+          url: `http://${nodeReq.headers.host}${nodeReq.url}`,
+          headers: nodeReq.headers,
+          websocket: true
+        })
+
+        const res = await python.handleStream(req)
+
+        // Simulate sending a message through the WebSocket
+        await req.write('Integration test message')
+
+        // Read response
+        const chunk = await res.next()
+        const response = chunk.toString('utf8')
+
+        await req.end()
+
+        // Send response back through HTTP for test purposes
+        nodeRes.writeHead(200, { 'Content-Type': 'text/plain' })
+        nodeRes.end(response)
+      } else {
+        // Regular HTTP request
+        const req = new Request({
+          method: nodeReq.method,
+          url: `http://${nodeReq.headers.host}${nodeReq.url}`,
+          headers: nodeReq.headers,
+          websocket: false
+        })
+
+        const res = await python.handleRequest(req)
+        nodeRes.writeHead(res.status, Object.fromEntries(res.headers.entries()))
+        nodeRes.end(res.body)
+      }
+    })
+
+    server.listen(0)
+    await once(server, 'listening')
+
+    const { port } = server.address()
+
+    try {
+      // Test WebSocket upgrade request using http.request (fetch doesn't support upgrade headers)
+      const upgradeResponse = await new Promise((resolve, reject) => {
+        const req = httpRequest({
+          hostname: 'localhost',
+          port,
+          path: '/echo',
+          method: 'GET',
+          headers: {
+            'Upgrade': 'websocket',
+            'Connection': 'Upgrade',
+            'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
+            'Sec-WebSocket-Version': '13'
+          }
+        }, (res) => {
+          let data = ''
+          res.on('data', chunk => { data += chunk })
+          res.on('end', () => {
+            resolve({ status: res.statusCode, body: data })
+          })
+        })
+        req.on('error', reject)
+        req.end()
+      })
+
+      strictEqual(upgradeResponse.body, 'Integration test message', 'Should echo the message through WebSocket')
+
+      // Test regular HTTP request (should get 426 Upgrade Required)
+      const httpResponse = await fetch(`http://localhost:${port}/echo`)
+      strictEqual(httpResponse.status, 426, 'Should require upgrade for WebSocket app')
+      const body = await httpResponse.text()
+      strictEqual(body, 'Upgrade Required')
+    } finally {
+      server.close()
+    }
+  })
+
+  await t.test('Bidirectional communication simulation', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    // Simulate a bidirectional WebSocket connection
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Simulate multiple back-and-forth messages
+    const messages = [
+      { send: 'Message 1', expect: 'Message 1' },
+      { send: 'Message 2', expect: 'Message 2' },
+      { send: 'Message 3', expect: 'Message 3' }
+    ]
+
+    for (const { send, expect } of messages) {
+      // Client sends message
+      await req.write(send)
+
+      // Server responds
+      const chunk = await res.next()
+      strictEqual(chunk.toString('utf8'), expect)
+    }
+
+    await req.end()
+  })
+
+  await t.test('Concurrent WebSocket connections', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    // Create multiple concurrent WebSocket connections
+    const connections = []
+    const numConnections = 10
+
+    for (let i = 0; i < numConnections; i++) {
+      const req = new Request({
+        method: 'GET',
+        url: 'http://example.com/echo',
+        websocket: true
+      })
+
+      const res = python.handleStream(req)
+      connections.push({ req, res, id: i })
+    }
+
+    // Wait for all connections to establish
+    await Promise.all(connections.map(c => c.res))
+
+    // Send messages on all connections concurrently
+    const sendPromises = connections.map(async ({ req, res, id }) => {
+      const message = `Connection ${id}`
+      await req.write(message)
+
+      const awaitedRes = await res
+      const chunk = await awaitedRes.next()
+      strictEqual(chunk.toString('utf8'), message)
+
+      await req.end()
+    })
+
+    await Promise.all(sendPromises)
+  })
+
+  await t.test('WebSocket with large message payload', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Send a large message (1MB)
+    const largeMessage = 'x'.repeat(1024 * 1024)
+    await req.write(largeMessage)
+
+    // Receive and verify
+    const chunk = await res.next()
+    strictEqual(chunk.toString('utf8').length, largeMessage.length)
+    strictEqual(chunk.toString('utf8'), largeMessage)
+
+    await req.end()
+  })
+
+  await t.test('WebSocket message buffering', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Send multiple messages quickly without waiting for responses
+    const messages = ['fast1', 'fast2', 'fast3', 'fast4', 'fast5']
+
+    for (const msg of messages) {
+      await req.write(msg)
+    }
+
+    // Now read all responses
+    for (const msg of messages) {
+      const chunk = await res.next()
+      strictEqual(chunk.toString('utf8'), msg)
+    }
+
+    await req.end()
+  })
+})
diff --git a/test/websocket.test.mjs b/test/websocket.test.mjs
new file mode 100644
index 0000000..6c1947e
--- /dev/null
+++ b/test/websocket.test.mjs
@@ -0,0 +1,215 @@
+import { test } from 'node:test'
+import { strictEqual } from 'node:assert'
+import { join } from 'node:path'
+
+import { Python, Request } from '../index.js'
+
+const fixturesDir = join(import.meta.dirname, 'fixtures')
+
+test('Python - WebSocket', async (t) => {
+  await t.test('basic WebSocket echo', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // WebSocket accepts don't have traditional HTTP status codes
+    // The response iterator handles the WebSocket messages
+
+    // Send a text message
+    await req.write('Hello WebSocket!')
+
+    // Read the echo response
+    const chunk = await res.next()
+    strictEqual(chunk.toString('utf8'), 'Hello WebSocket!')
+
+    // Send another message
+    await req.write('Second message')
+
+    const chunk2 = await res.next()
+    strictEqual(chunk2.toString('utf8'), 'Second message')
+
+    // Close the connection
+    await req.end()
+  })
+
+  await t.test('WebSocket binary messages', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Send binary data
+    const binaryData = Buffer.from([0x01, 0x02, 0x03, 0x04, 0x05])
+    await req.write(binaryData)
+
+    // Read the echo response
+    const chunk = await res.next()
+    strictEqual(Buffer.compare(chunk, binaryData), 0, 'Binary data should match')
+
+    await req.end()
+  })
+
+  await t.test('WebSocket uppercase transformation', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/uppercase',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Send lowercase text
+    await req.write('hello world')
+
+    // Should receive uppercase
+    const chunk = await res.next()
+    strictEqual(chunk.toString('utf8'), 'HELLO WORLD')
+
+    await req.end()
+  })
+
+  await t.test('WebSocket ping-pong', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/ping-pong',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Send ping
+    await req.write('ping')
+
+    // Should receive pong
+    const chunk = await res.next()
+    strictEqual(chunk.toString('utf8'), 'pong')
+
+    // Send other message
+    await req.write('hello')
+
+    // Should be echoed
+    const chunk2 = await res.next()
+    strictEqual(chunk2.toString('utf8'), 'hello')
+
+    await req.end()
+  })
+
+  await t.test('WebSocket immediate close', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/close',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Server immediately closes after accepting
+    const chunk = await res.next()
+
+    // Should receive null/undefined indicating connection closed
+    strictEqual(chunk, null, 'Connection should be closed')
+  })
+
+  await t.test('WebSocket multiple messages in sequence', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    const messages = ['msg1', 'msg2', 'msg3', 'msg4', 'msg5']
+
+    for (const msg of messages) {
+      await req.write(msg)
+      const chunk = await res.next()
+      strictEqual(chunk.toString('utf8'), msg, `Should echo ${msg}`)
+    }
+
+    await req.end()
+  })
+
+  await t.test('WebSocket with headers', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      headers: {
+        'Sec-WebSocket-Protocol': 'chat',
+        'Sec-WebSocket-Version': '13',
+        'Origin': 'http://example.com'
+      },
+      websocket: true
+    })
+
+    const res = await python.handleStream(req)
+
+    // Send and receive a message to verify connection works
+    await req.write('test')
+    const chunk = await res.next()
+    strictEqual(chunk.toString('utf8'), 'test')
+
+    await req.end()
+  })
+
+  await t.test('Non-WebSocket request to WebSocket app', async () => {
+    const python = new Python({
+      docroot: fixturesDir,
+      appTarget: 'websocket_app:app'
+    })
+
+    const req = new Request({
+      method: 'GET',
+      url: 'http://example.com/echo',
+      websocket: false // Regular HTTP request
+    })
+
+    const res = await python.handleRequest(req)
+
+    // Should receive 426 Upgrade Required
+    strictEqual(res.status, 426)
+    strictEqual(res.body.toString('utf8'), 'Upgrade Required')
+  })
+})