diff --git a/Cargo.lock b/Cargo.lock index 3348380929..b269512b47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8794,6 +8794,16 @@ dependencies = [ "wac-graph 0.10.0", ] +[[package]] +name = "spin-connection-semaphore" +version = "4.1.0-pre0" +dependencies = [ + "anyhow", + "spin-telemetry", + "tokio", + "tracing", +] + [[package]] name = "spin-core" version = "4.1.0-pre0" @@ -8991,6 +9001,7 @@ dependencies = [ "spin-factors-test", "spin-telemetry", "spin-world", + "terminal", "tokio", "tokio-rustls 0.26.4", "tower-service", @@ -9028,6 +9039,7 @@ dependencies = [ "anyhow", "futures", "mysql_async", + "serde", "spin-core", "spin-factor-otel", "spin-factor-outbound-networking", @@ -9048,6 +9060,7 @@ name = "spin-factor-outbound-networking" version = "4.1.0-pre0" dependencies = [ "anyhow", + "async-trait", "futures-util", "http 1.3.1", "ip_network", @@ -9056,6 +9069,8 @@ dependencies = [ "rustls-pki-types", "rustls-platform-verifier", "serde", + "spin-connection-semaphore", + "spin-factor-outbound-mqtt", "spin-factor-variables", "spin-factor-wasi", "spin-factors", @@ -9064,11 +9079,13 @@ dependencies = [ "spin-manifest", "spin-outbound-networking-config", "spin-serde", + "spin-world", "tempfile", "tokio", "toml 0.8.19", "tracing", "url", + "wasmtime", "wasmtime-wasi", "webpki-root-certs", ] @@ -9087,6 +9104,7 @@ dependencies = [ "postgres-native-tls", "postgres_range", "rust_decimal", + "serde", "serde_json", "spin-common", "spin-core", @@ -9114,6 +9132,7 @@ version = "4.1.0-pre0" dependencies = [ "anyhow", "redis", + "serde", "spin-core", "spin-factor-otel", "spin-factor-outbound-networking", @@ -9165,9 +9184,11 @@ dependencies = [ "async-trait", "bytes", "spin-common", + "spin-connection-semaphore", "spin-factors", "spin-factors-test", "tokio", + "tracing", "wasmtime", "wasmtime-wasi", ] diff --git a/crates/connection-semaphore/Cargo.toml b/crates/connection-semaphore/Cargo.toml new file mode 100644 index 0000000000..eeb77a7103 --- /dev/null +++ b/crates/connection-semaphore/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "spin-connection-semaphore" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } + +[dependencies] +anyhow = { workspace = true } +spin-telemetry = { path = "../telemetry" } +tokio = { workspace = true, features = ["sync"] } +tracing = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt"] } + +[lints] +workspace = true diff --git a/crates/connection-semaphore/src/lib.rs b/crates/connection-semaphore/src/lib.rs new file mode 100644 index 0000000000..5b900c6e10 --- /dev/null +++ b/crates/connection-semaphore/src/lib.rs @@ -0,0 +1,302 @@ +use std::sync::Arc; + +use anyhow::anyhow; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +/// Wraps an optional global and an optional factor-specific semaphore. +#[derive(Clone)] +pub struct ConnectionSemaphore { + global: Option>, + factor_specific: Option>, + factor: &'static str, +} + +impl ConnectionSemaphore { + /// Creates a new `ConnectionSemaphore`. + pub fn new( + global: Option>, + factor_specific_limit: Option, + factor: &'static str, + ) -> Self { + Self { + global, + factor_specific: factor_specific_limit.map(|n| Arc::new(Semaphore::new(n))), + factor, + } + } + + #[cfg(test)] + pub(crate) fn from_raw( + global: Option>, + factor_specific: Option>, + factor: &'static str, + ) -> Self { + Self { + global, + factor_specific, + factor, + } + } + + /// Acquire both configured semaphore slots, returning a permit that holds + /// them until dropped. + /// + /// When both a global and a factor-specific semaphore are configured, this + /// method acquires factor-specific first, then global, ensuring the global + /// permit is never held while blocking on a factor-specific backlog. + pub async fn acquire(&self) -> anyhow::Result { + /// Acquires a single permit from `sem`, trying non-blocking first. + /// + /// Sets `*waited = true` if a blocking wait was required. + async fn acquire_one( + sem: &Arc, + waited: &mut bool, + label: &str, + ) -> anyhow::Result { + match sem.clone().try_acquire_owned() { + Ok(p) => Ok(p), + Err(TryAcquireError::NoPermits) => { + *waited = true; + sem.clone() + .acquire_owned() + .await + .map_err(|_| anyhow!("{label} connection semaphore closed")) + } + Err(_) => Err(anyhow!("{label} connection semaphore closed")), + } + } + let mut waited = false; + let start = std::time::Instant::now(); + + // Acquire factor-specific first, then global. This ensures we never hold + // the global permit while blocking on factor-specific backlog. + let factor_specific = match &self.factor_specific { + Some(f) => Some(acquire_one(f, &mut waited, "factor").await?), + None => None, + }; + // It's fine to hold the factor-specific permit while waiting for the global slot, since + // other consumers of the factor-specific would also end up waiting for the same global slot. + let global = match &self.global { + Some(g) => Some(acquire_one(g, &mut waited, "global").await?), + None => None, + }; + + let factor = self.factor; + if waited { + spin_telemetry::histogram!( + outbound_connection_permit_wait_duration_ms = start.elapsed().as_millis() as f64, + kind = factor + ); + } + spin_telemetry::monotonic_counter!( + outbound_connection_permits_acquired = 1, + kind = factor, + waited = waited + ); + + Ok(ConnectionPermit { + _global: global, + _factor_specific: factor_specific, + }) + } + + /// Attempt to acquire both configured slots without waiting. + /// Returns `None` if either semaphore is exhausted. + /// + /// If the global permit is acquired but the factor-specific permit is not + /// available, the global permit is released before returning `None`. + pub fn try_acquire(&self) -> Option { + match self.try_acquire_permits() { + Ok(permit) => { + spin_telemetry::monotonic_counter!( + outbound_connection_permits_acquired = 1, + kind = self.factor, + waited = false + ); + Some(permit) + } + Err(limit) => { + spin_telemetry::monotonic_counter!( + outbound_connection_permits_rejected = 1, + kind = self.factor, + limit = limit + ); + None + } + } + } + + /// Inner logic for [`Self::try_acquire`], separated so the caller can emit + /// telemetry based on whether a permit was obtained. + /// + /// Returns `Err("global")` or `Err("factor")` to indicate which limit was + /// exhausted, so the caller can tag the rejection metric accordingly. + fn try_acquire_permits(&self) -> Result { + // Acquire global first. If it fails, nothing is consumed. + let global = match &self.global { + Some(s) => match s.clone().try_acquire_owned() { + Ok(p) => Some(p), + Err(_) => return Err("global"), + }, + None => None, + }; + // Now attempt the factor-specific permit. + // On failure, `global` is dropped here, releasing the global slot. + let factor_specific = match &self.factor_specific { + Some(s) => match s.clone().try_acquire_owned() { + Ok(p) => Some(p), + Err(_) => return Err("factor"), + }, + None => None, + }; + Ok(ConnectionPermit { + _global: global, + _factor_specific: factor_specific, + }) + } +} + +/// Holds up to two semaphore permits (global + factor-specific). +/// Both permits are released when this value is dropped. +/// All-`None` fields are valid and represent the no-limits case. +/// +/// Fields are intentionally prefixed with `_` — they exist solely to be dropped. +pub struct ConnectionPermit { + _global: Option, + _factor_specific: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn no_limits_acquire_always_succeeds() { + let sem = ConnectionSemaphore::new(None, None, "test"); + let permit = sem.acquire().await.expect("should succeed"); + drop(permit); + let _permit2 = sem.acquire().await.expect("should succeed again"); + } + + #[test] + fn no_limits_try_acquire_always_succeeds() { + let sem = ConnectionSemaphore::new(None, None, "test"); + let permit = sem.try_acquire().expect("should succeed"); + drop(permit); + let _permit2 = sem.try_acquire().expect("should succeed again"); + } + + #[test] + fn global_limit_only_exhausted() { + let global = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test"); + let permit1 = sem.try_acquire().expect("first should succeed"); + assert!( + sem.try_acquire().is_none(), + "second should fail: global exhausted" + ); + drop(permit1); + assert_eq!(global.available_permits(), 1); + let _permit3 = sem.try_acquire().expect("after release should succeed"); + } + + #[test] + fn factor_limit_only_exhausted() { + let sem = ConnectionSemaphore::new(None, Some(1), "test"); + let permit1 = sem.try_acquire().expect("first should succeed"); + assert!( + sem.try_acquire().is_none(), + "second should fail: factor exhausted" + ); + drop(permit1); + let _permit3 = sem.try_acquire().expect("after release should succeed"); + } + + #[test] + fn both_limits_global_exhausted_first() { + let global = Arc::new(Semaphore::new(1)); + let factor = Arc::new(Semaphore::new(2)); + let sem = ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test"); + + let permit1 = sem.try_acquire().expect("first should succeed"); + // After permit1: global=0, factor=1 + let factor_before = factor.available_permits(); + + // Second try_acquire should fail because global is exhausted. + assert!(sem.try_acquire().is_none(), "should fail: global exhausted"); + // Factor must NOT have been consumed by the failed attempt. + assert_eq!( + factor.available_permits(), + factor_before, + "factor permits should not be consumed when global is exhausted" + ); + drop(permit1); + } + + #[test] + fn both_limits_factor_exhausted_global_released() { + let global = Arc::new(Semaphore::new(2)); + let factor = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test"); + + let permit1 = sem.try_acquire().expect("first should succeed"); + // Global still has 1, factor exhausted + let result = sem.try_acquire(); + assert!(result.is_none(), "should fail: factor exhausted"); + // Global slot must have been released (back to 1) + assert_eq!(global.available_permits(), 1); + drop(permit1); + assert_eq!(global.available_permits(), 2); + } + + #[tokio::test] + async fn acquire_waits_for_release() { + let global = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::new(Some(global.clone()), None, "test"); + + let permit = sem.try_acquire().expect("first should succeed"); + + let sem2 = sem.clone(); + let handle = tokio::spawn(async move { + let _p = sem2.acquire().await.expect("should eventually acquire"); + }); + + drop(permit); // release so the spawned task can proceed + handle.await.expect("task should complete"); + } + + /// Verifies that when factor-specific is exhausted, acquire() releases + /// the global permit while waiting — so other connection types aren't blocked. + #[tokio::test] + async fn acquire_releases_global_while_waiting_for_factor() { + let global = Arc::new(Semaphore::new(1)); + let factor = Arc::new(Semaphore::new(1)); + let sem = ConnectionSemaphore::from_raw(Some(global.clone()), Some(factor.clone()), "test"); + + // Exhaust factor-specific from outside. + let _factor_hold = factor.clone().acquire_owned().await.unwrap(); + + let global_clone = global.clone(); + let sem_clone = sem.clone(); + let handle = tokio::spawn(async move { + sem_clone + .acquire() + .await + .expect("should succeed after factor is released") + }); + + // Yield twice: first to let the spawned task run until it blocks waiting + // for factor-specific; second to confirm it has released the global permit. + tokio::task::yield_now().await; + tokio::task::yield_now().await; + + assert_eq!( + global_clone.available_permits(), + 1, + "global should be free while acquire() waits for factor-specific" + ); + + drop(_factor_hold); + handle.await.expect("task should complete"); + } +} diff --git a/crates/factor-outbound-http/Cargo.toml b/crates/factor-outbound-http/Cargo.toml index 7ccfd3f63b..619a6cc382 100644 --- a/crates/factor-outbound-http/Cargo.toml +++ b/crates/factor-outbound-http/Cargo.toml @@ -22,6 +22,7 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-telemetry = { path = "../telemetry" } spin-world = { path = "../world" } +terminal = { path = "../terminal" } tokio = { workspace = true, features = ["macros", "rt", "net"] } tokio-rustls = { workspace = true } tower-service = { workspace = true } diff --git a/crates/factor-outbound-http/src/lib.rs b/crates/factor-outbound-http/src/lib.rs index e7b43e0bed..fc7d9cf609 100644 --- a/crates/factor-outbound-http/src/lib.rs +++ b/crates/factor-outbound-http/src/lib.rs @@ -16,14 +16,14 @@ use intercept::OutboundHttpInterceptor; use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ - ComponentTlsClientConfigs, OutboundNetworkingFactor, + ComponentTlsClientConfigs, ConnectionSemaphore, OutboundNetworkingFactor, + build_connection_semaphore, config::{allowed_hosts::OutboundAllowedHosts, blocked_networks::BlockedNetworks}, }; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; -use tokio::sync::Semaphore; use wasmtime_wasi_http::WasiHttpCtx; pub use wasmtime_wasi_http::p2::{ @@ -56,14 +56,15 @@ impl Factor for OutboundHttpFactor { mut ctx: ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); + Ok(AppState { wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled), connection_pooling_enabled: config.connection_pooling_enabled, - concurrent_outbound_connections_semaphore: config - .max_concurrent_connections - // Permit count is the max concurrent connections + 1. - // i.e., 0 concurrent connections means 1 total connection. - .map(|n| Arc::new(Semaphore::new(n + 1))), + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "http", + config.max_concurrent_connections, + ), }) } @@ -87,10 +88,7 @@ impl Factor for OutboundHttpFactor { spin_http_client: None, wasi_http_clients: ctx.app_state().wasi_http_clients.clone(), connection_pooling_enabled: ctx.app_state().connection_pooling_enabled, - concurrent_outbound_connections_semaphore: ctx - .app_state() - .concurrent_outbound_connections_semaphore - .clone(), + semaphore: ctx.app_state().semaphore.clone(), otel, }, }) @@ -121,8 +119,8 @@ struct InstanceHttpHooks { wasi_http_clients: wasi::HttpClients, /// Whether connection pooling is enabled for this instance. connection_pooling_enabled: bool, - /// A semaphore to limit the number of concurrent outbound connections. - concurrent_outbound_connections_semaphore: Option>, + /// Semaphore to limit concurrent outbound connections. + semaphore: ConnectionSemaphore, /// Manages access to the OtelFactor state. otel: OtelFactorState, } @@ -153,66 +151,6 @@ impl InstanceState { impl SelfInstanceBuilder for InstanceState {} -/// Helper module for acquiring permits from the outbound connections semaphore. -/// -/// This is used by the outbound HTTP implementations to limit concurrent outbound connections. -mod concurrent_outbound_connections { - use super::*; - - /// Acquires a semaphore permit for the given interface, if a semaphore is configured. - pub async fn acquire_semaphore<'a>( - interface: &str, - semaphore: &'a Option>, - ) -> Option> { - let s = semaphore.as_ref()?; - acquire(interface, || s.try_acquire(), async || s.acquire().await).await - } - - /// Acquires an owned semaphore permit for the given interface, if a semaphore is configured. - pub async fn acquire_owned_semaphore( - interface: &str, - semaphore: &Option>, - ) -> Option { - let s = semaphore.as_ref()?; - acquire( - interface, - || s.clone().try_acquire_owned(), - async || s.clone().acquire_owned().await, - ) - .await - } - - /// Helper function to acquire a semaphore permit, either immediately or by waiting. - /// - /// Allows getting either a borrowed or owned permit. - async fn acquire( - interface: &str, - try_acquire: impl Fn() -> Result, - acquire: impl AsyncFnOnce() -> Result, - ) -> Option { - // Try to acquire a permit without waiting first - // Keep track of whether we had to wait for metrics purposes. - let mut waited = false; - let permit = match try_acquire() { - Ok(p) => Ok(p), - // No available permits right now; wait for one - Err(tokio::sync::TryAcquireError::NoPermits) => { - waited = true; - acquire().await.map_err(|_| ()) - } - Err(_) => Err(()), - }; - if permit.is_ok() { - spin_telemetry::monotonic_counter!( - outbound_http.concurrent_connection_permits_acquired = 1, - interface = interface, - waited = waited - ); - } - permit.ok() - } -} - pub type Request = http::Request; pub type Response = http::Response; @@ -268,8 +206,8 @@ pub struct AppState { wasi_http_clients: wasi::HttpClients, /// Whether connection pooling is enabled for this app. connection_pooling_enabled: bool, - /// A semaphore to limit the number of concurrent outbound connections. - concurrent_outbound_connections_semaphore: Option>, + /// Semaphore to limit concurrent outbound connections. + semaphore: ConnectionSemaphore, } /// Removes IPs in the given [`BlockedNetworks`]. diff --git a/crates/factor-outbound-http/src/runtime_config/spin.rs b/crates/factor-outbound-http/src/runtime_config/spin.rs index fc32c2cdcc..77e5c0800d 100644 --- a/crates/factor-outbound-http/src/runtime_config/spin.rs +++ b/crates/factor-outbound-http/src/runtime_config/spin.rs @@ -7,16 +7,36 @@ use spin_factors::runtime_config::toml::GetTomlValue; /// ```toml /// [outbound_http] /// connection_pooling = true # optional, defaults to true -/// max_concurrent_requests = 10 # optional, defaults to unlimited +/// max_connections = 10 # optional, defaults to unlimited; 0 = no connections allowed +/// # max_concurrent_requests is deprecated, use max_connections instead /// ``` pub fn config_from_table( table: &impl GetTomlValue, ) -> anyhow::Result> { if let Some(outbound_http) = table.get("outbound_http") { - let outbound_http_toml = outbound_http.clone().try_into::()?; + let toml = outbound_http.clone().try_into::()?; + + let max_connections = match (toml.max_connections, toml.max_concurrent_requests) { + (Some(_), Some(_)) => anyhow::bail!( + "cannot set both `max_connections` and `max_concurrent_requests` in \ + `[outbound_http]`; use `max_connections` only" + ), + (Some(n), None) => Some(n), + (None, Some(n)) => { + terminal::warn!( + "`max_concurrent_requests` in `[outbound_http]` is deprecated; \ + use `max_connections` instead (note: `max_connections = 0` blocks all \ + connections, whereas `max_concurrent_requests = 0` allowed 1 connection)" + ); + // Preserve old semaphore semantics: n+1 permits so that 0 allowed 1 connection + Some(n + 1) + } + (None, None) => None, + }; + Ok(Some(super::RuntimeConfig { - connection_pooling_enabled: outbound_http_toml.connection_pooling, - max_concurrent_connections: outbound_http_toml.max_concurrent_requests, + connection_pooling_enabled: toml.connection_pooling, + max_concurrent_connections: max_connections, })) } else { Ok(None) @@ -29,5 +49,8 @@ struct OutboundHttpToml { #[serde(default)] connection_pooling: bool, #[serde(default)] + max_connections: Option, + /// Deprecated. Use `max_connections` instead. + #[serde(default)] max_concurrent_requests: Option, } diff --git a/crates/factor-outbound-http/src/spin.rs b/crates/factor-outbound-http/src/spin.rs index 5c47204bfb..ef352d7e85 100644 --- a/crates/factor-outbound-http/src/spin.rs +++ b/crates/factor-outbound-http/src/spin.rs @@ -112,11 +112,12 @@ impl spin_http::Host for crate::InstanceState { // If we're limiting concurrent outbound requests, acquire a permit // Note: since we don't have access to the underlying connection, we can only // limit the number of concurrent requests, not connections. - let permit = crate::concurrent_outbound_connections::acquire_semaphore( - "spin", - &self.hooks.concurrent_outbound_connections_semaphore, - ) - .await; + let permit = self + .hooks + .semaphore + .acquire() + .await + .map_err(|_| HttpError::RuntimeError)?; let resp = client.execute(req).await.map_err(log_reqwest_error)?; drop(permit); diff --git a/crates/factor-outbound-http/src/wasi.rs b/crates/factor-outbound-http/src/wasi.rs index 2fc151562a..11171e45a1 100644 --- a/crates/factor-outbound-http/src/wasi.rs +++ b/crates/factor-outbound-http/src/wasi.rs @@ -35,7 +35,6 @@ use spin_factors::RuntimeFactorsInstanceState; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, net::TcpStream, - sync::{OwnedSemaphorePermit, Semaphore}, time::timeout, }; use tokio_rustls::client::TlsStream; @@ -53,6 +52,8 @@ use wasmtime_wasi_http::{ p3::{self, bindings::http::types as p3_types}, }; +use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore}; + use crate::{ InstanceHttpHooks, OutboundHttpFactor, SelfRequestOrigin, intercept::{InterceptOutcome, OutboundHttpInterceptor}, @@ -184,9 +185,7 @@ impl p3::WasiHttpHooks for InstanceHttpHooks { self_request_origin: self.self_request_origin.clone(), blocked_networks: self.blocked_networks.clone(), http_clients: self.wasi_http_clients.clone(), - concurrent_outbound_connections_semaphore: self - .concurrent_outbound_connections_semaphore - .clone(), + semaphore: self.semaphore.clone(), }; let config = OutgoingRequestConfig { use_tls: request.uri().scheme() == Some(&Scheme::HTTPS), @@ -442,9 +441,7 @@ impl p2::WasiHttpHooks for InstanceHttpHooks { self_request_origin: self.self_request_origin.clone(), blocked_networks: self.blocked_networks.clone(), http_clients: self.wasi_http_clients.clone(), - concurrent_outbound_connections_semaphore: self - .concurrent_outbound_connections_semaphore - .clone(), + semaphore: self.semaphore.clone(), }; Ok(HostFutureIncomingResponse::Pending( wasmtime_wasi::runtime::spawn( @@ -470,7 +467,7 @@ struct RequestSender { self_request_origin: Option, request_interceptor: Option>, http_clients: HttpClients, - concurrent_outbound_connections_semaphore: Option>, + semaphore: ConnectionSemaphore, } impl RequestSender { @@ -624,8 +621,7 @@ impl RequestSender { connect_timeout, tls_client_config, override_connect_addr, - concurrent_outbound_connections_semaphore: self - .concurrent_outbound_connections_semaphore, + semaphore: self.semaphore, }, async move { if use_tls { @@ -719,8 +715,8 @@ struct ConnectOptions { tls_client_config: Option, /// If set, override the address to connect to instead of using the given `uri`'s authority. override_connect_addr: Option, - /// A semaphore to limit the number of concurrent outbound connections. - concurrent_outbound_connections_semaphore: Option>, + /// Semaphore to limit concurrent outbound connections. + semaphore: ConnectionSemaphore, } impl ConnectOptions { @@ -758,11 +754,7 @@ impl ConnectOptions { let connect = async { // If we're limiting concurrent outbound requests, acquire a permit - let permit = crate::concurrent_outbound_connections::acquire_owned_semaphore( - "wasi", - &self.concurrent_outbound_connections_semaphore, - ) - .await; + let permit = self.semaphore.acquire().await; (TcpStream::connect(&*socket_addrs).await, permit) }; @@ -771,6 +763,7 @@ impl ConnectOptions { let (stream, permit) = timeout(self.connect_timeout, connect) .await .map_err(|_| ErrorCode::ConnectionTimeout)?; + let permit = permit.map_err(|_| ErrorCode::ConnectionRefused)?; let stream = stream.map_err(|err| match err.kind() { std::io::ErrorKind::AddrNotAvailable => dns_error("address not available".into(), 0), _ => ErrorCode::ConnectionRefused, @@ -912,7 +905,7 @@ impl AsyncWrite for RustlsStream { } } -/// A TCP stream that holds an optional permit indicating that it is allowed to exist. +/// A TCP stream that holds a permit indicating that it is allowed to exist. struct PermittedTcpStream { /// The wrapped TCP stream. inner: TcpStream, @@ -920,7 +913,7 @@ struct PermittedTcpStream { /// /// When this stream is dropped, the permit is also dropped, allowing another /// connection to be established. - _permit: Option, + _permit: ConnectionPermit, } impl Connection for PermittedTcpStream { @@ -1219,11 +1212,12 @@ mod tests { /// `ConnectionTimeout` within the configured deadline. #[tokio::test] async fn connect_timeout_applies_to_permit_acquisition() { - // Create a semaphore with exactly 1 permit and hold it immediately, - // leaving 0 permits available. This simulates all outbound-connection - // slots being occupied. - let semaphore = Arc::new(Semaphore::new(1)); - let _held = semaphore.clone().try_acquire_owned().unwrap(); + // Create a semaphore with exactly 1 permit and immediately exhaust it, leaving + // 0 permits available. This simulates all outbound-connection slots being occupied. + let conn_semaphore = ConnectionSemaphore::new(None, Some(1), "test"); + let _held = conn_semaphore + .try_acquire() + .expect("exhausting the single permit"); let options = ConnectOptions { // No blocked networks; we want the address to pass the filter. @@ -1233,7 +1227,7 @@ mod tests { tls_client_config: None, // Skip DNS by supplying the address directly. override_connect_addr: Some("127.0.0.1:1".parse().unwrap()), - concurrent_outbound_connections_semaphore: Some(semaphore), + semaphore: conn_semaphore, }; // `connect_tcp` must time out while waiting for a permit rather than diff --git a/crates/factor-outbound-mqtt/Cargo.toml b/crates/factor-outbound-mqtt/Cargo.toml index 5561c30744..c72a17691b 100644 --- a/crates/factor-outbound-mqtt/Cargo.toml +++ b/crates/factor-outbound-mqtt/Cargo.toml @@ -6,9 +6,9 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } -serde = { workspace = true } # Upstream hasn't been updating dependencies: https://github.com/bytebeamio/rumqtt/issues/1046 rumqttc = { git = "https://github.com/spinframework/rumqtt", rev = "65b7b39a70b12d1781acb61cc07f1f1b680e7643", default-features = false, features = ["use-rustls-no-provider", "url"] } +serde = { workspace = true, features = ["derive"] } spin-core = { path = "../core" } spin-factor-otel = { path = "../factor-otel" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } diff --git a/crates/factor-outbound-mqtt/src/host.rs b/crates/factor-outbound-mqtt/src/host.rs index efc6a6d229..5a090f3772 100644 --- a/crates/factor-outbound-mqtt/src/host.rs +++ b/crates/factor-outbound-mqtt/src/host.rs @@ -7,6 +7,7 @@ use spin_core::{ }; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts; +use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore}; use spin_world::spin::mqtt::mqtt as v3; use spin_world::v2::mqtt as v2; use tracing::{Level, instrument}; @@ -15,8 +16,9 @@ use crate::{ClientCreator, allowed_hosts::AllowedHostChecker}; pub struct InstanceState { allowed_hosts: AllowedHostChecker, - connections: spin_resource_table::Table>, + connections: spin_resource_table::Table<(Arc, ConnectionPermit)>, create_client: Arc, + semaphore: ConnectionSemaphore, otel: OtelFactorState, max_payload_size_bytes: Option, } @@ -25,6 +27,7 @@ impl InstanceState { pub fn new( allowed_hosts: OutboundAllowedHosts, create_client: Arc, + semaphore: ConnectionSemaphore, otel: OtelFactorState, max_payload_size_bytes: Option, ) -> Self { @@ -32,6 +35,7 @@ impl InstanceState { allowed_hosts: AllowedHostChecker::new(allowed_hosts), create_client, connections: spin_resource_table::Table::new(1024), + semaphore, otel, max_payload_size_bytes, } @@ -60,8 +64,15 @@ impl InstanceState { password: String, keep_alive_interval: Duration, ) -> Result, v2::Error> { + let permit = self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::TooManyConnections)?; + let client = + (self.create_client).create(address, username, password, keep_alive_interval)?; self.connections - .push((self.create_client).create(address, username, password, keep_alive_interval)?) + .push((client, permit)) .map(Resource::new_own) .map_err(|_| v2::Error::TooManyConnections) } @@ -72,7 +83,7 @@ impl InstanceState { .ok_or(v2::Error::Other( "could not find connection for resource".into(), )) - .map(|c| c.as_ref()) + .map(|(c, _permit)| c.as_ref()) } fn get_conn_v3( @@ -81,7 +92,7 @@ impl InstanceState { ) -> Result, v3::Error> { self.connections .get(connection.rep()) - .cloned() + .map(|(c, _permit)| c.clone()) .ok_or(v3::Error::Other( "could not find connection for resource".into(), )) @@ -110,10 +121,14 @@ impl v3::HostConnectionWithStore for crate::MqttFactorData { password: String, keep_alive_interval_in_secs: u64, ) -> Result, v3::Error> { - let (allowed_host_checker, create_client) = accessor.with(|mut access| { + let (allowed_host_checker, create_client, semaphore) = accessor.with(|mut access| { let host = access.get(); host.otel.reparent_tracing_span(); - (host.allowed_hosts.clone(), host.create_client.clone()) + ( + host.allowed_hosts.clone(), + host.create_client.clone(), + host.semaphore.clone(), + ) }); if !allowed_host_checker @@ -126,19 +141,22 @@ impl v3::HostConnectionWithStore for crate::MqttFactorData { ))); } - let client = create_client - .create( - address, - username, - password, - Duration::from_secs(keep_alive_interval_in_secs), - ) - .unwrap(); + let permit = semaphore + .acquire() + .await + .map_err(|_| v3::Error::TooManyConnections)?; + + let client = create_client.create( + address, + username, + password, + Duration::from_secs(keep_alive_interval_in_secs), + )?; accessor.with(|mut access| { let host = access.get(); host.connections - .push(client) + .push((client, permit)) .map(Resource::new_own) .map_err(|_| v3::Error::TooManyConnections) }) diff --git a/crates/factor-outbound-mqtt/src/lib.rs b/crates/factor-outbound-mqtt/src/lib.rs index b42444c788..2b74331f1d 100644 --- a/crates/factor-outbound-mqtt/src/lib.rs +++ b/crates/factor-outbound-mqtt/src/lib.rs @@ -9,9 +9,12 @@ use host::InstanceState; use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS}; use spin_core::async_trait; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::{ + ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, +}; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, + anyhow, }; use spin_world::spin::mqtt::mqtt as v3; use spin_world::v2::mqtt as v2; @@ -20,6 +23,7 @@ use tokio::sync::Mutex; pub use host::MqttClient; use crate::host::other_error_v3; +use crate::runtime_config::RuntimeConfig; pub struct OutboundMqttFactor { create_client: Arc, @@ -32,11 +36,14 @@ impl OutboundMqttFactor { } pub struct AppState { + /// Optional maximum payload size in bytes for MQTT messages. If `None`, no limit is enforced. max_payload_size_bytes: Option, + /// Semaphore to limit concurrent outbound MQTT connections. + pub semaphore: ConnectionSemaphore, } impl Factor for OutboundMqttFactor { - type RuntimeConfig = runtime_config::RuntimeConfig; + type RuntimeConfig = RuntimeConfig; type AppState = AppState; type InstanceBuilder = InstanceState; @@ -51,7 +58,13 @@ impl Factor for OutboundMqttFactor { mut ctx: ConfigureAppContext, ) -> anyhow::Result { let config = ctx.take_runtime_config().unwrap_or_default(); + Ok(AppState { + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "mqtt", + config.max_connections, + ), max_payload_size_bytes: config.max_payload_size_bytes, }) } @@ -68,6 +81,7 @@ impl Factor for OutboundMqttFactor { Ok(InstanceState::new( allowed_hosts, self.create_client.clone(), + ctx.app_state().semaphore.clone(), otel, ctx.app_state().max_payload_size_bytes, )) diff --git a/crates/factor-outbound-mqtt/src/runtime_config.rs b/crates/factor-outbound-mqtt/src/runtime_config.rs index 786a2eb6f7..702f04233d 100644 --- a/crates/factor-outbound-mqtt/src/runtime_config.rs +++ b/crates/factor-outbound-mqtt/src/runtime_config.rs @@ -9,4 +9,9 @@ pub struct RuntimeConfig { /// should set this to prevent tenants from sending excessively large payloads. /// Configure via `[outbound_mqtt] max_payload_size_bytes` in the runtime config TOML. pub max_payload_size_bytes: Option, + /// If set, limits the number of concurrent outbound MQTT connections. + /// + /// When `None` (the default), no limit is enforced. Operators in multi-tenant deployments + /// should set this to prevent tenants from exhausting connection resources. + pub max_connections: Option, } diff --git a/crates/factor-outbound-mqtt/src/runtime_config/spin.rs b/crates/factor-outbound-mqtt/src/runtime_config/spin.rs index debe8d79e7..b7c3400193 100644 --- a/crates/factor-outbound-mqtt/src/runtime_config/spin.rs +++ b/crates/factor-outbound-mqtt/src/runtime_config/spin.rs @@ -1,4 +1,4 @@ -use anyhow::Context; +use anyhow::Context as _; use serde::Deserialize; use spin_factors::runtime_config::toml::GetTomlValue; @@ -8,6 +8,7 @@ use spin_factors::runtime_config::toml::GetTomlValue; /// ```toml /// [outbound_mqtt] /// max_payload_size_bytes = 65536 # optional, no limit by default +/// max_connections = 10 # optional, defaults to unlimited /// ``` pub fn config_from_table( table: &impl GetTomlValue, @@ -19,6 +20,7 @@ pub fn config_from_table( .context("failed to parse [outbound_mqtt] table")?; Ok(Some(super::RuntimeConfig { max_payload_size_bytes: toml.max_payload_size_bytes, + max_connections: toml.max_connections, })) } else { Ok(None) @@ -28,6 +30,6 @@ pub fn config_from_table( #[derive(Debug, Default, Deserialize)] #[serde(deny_unknown_fields)] struct OutboundMqttToml { - #[serde(default)] max_payload_size_bytes: Option, + max_connections: Option, } diff --git a/crates/factor-outbound-mqtt/tests/factor_test.rs b/crates/factor-outbound-mqtt/tests/factor_test.rs index b532fd5293..dbe4d6a08e 100644 --- a/crates/factor-outbound-mqtt/tests/factor_test.rs +++ b/crates/factor-outbound-mqtt/tests/factor_test.rs @@ -9,7 +9,7 @@ use spin_factor_variables::VariablesFactor; use spin_factors::{RuntimeFactors, anyhow}; use spin_factors_test::{TestEnvironment, toml}; use spin_world::spin::mqtt::mqtt::{Error, Qos}; -use spin_world::v2::mqtt as v2; +use spin_world::v2::mqtt as v2_mqtt; pub struct MockMqttClient {} @@ -62,7 +62,7 @@ fn test_env() -> TestEnvironment { #[tokio::test] async fn disallowed_host_fails() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; let env = TestEnvironment::new(factors()).extend_manifest(toml! { [component.test-component] @@ -82,14 +82,14 @@ async fn disallowed_host_fails() -> anyhow::Result<()> { let Err(err) = res else { bail!("expected Err, got Ok"); }; - assert!(matches!(err, v2::Error::ConnectionFailed(_))); + assert!(matches!(err, v2_mqtt::Error::ConnectionFailed(_))); Ok(()) } #[tokio::test] async fn allowed_host_succeeds() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; let mut state = test_env().build_instance_state().await?; @@ -111,7 +111,7 @@ async fn allowed_host_succeeds() -> anyhow::Result<()> { #[tokio::test] async fn exercise_publish() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; let mut state = test_env().build_instance_state().await?; @@ -131,7 +131,7 @@ async fn exercise_publish() -> anyhow::Result<()> { res, "message".to_string(), b"test message".to_vec(), - v2::Qos::ExactlyOnce, + v2_mqtt::Qos::ExactlyOnce, ) .await?; @@ -140,13 +140,14 @@ async fn exercise_publish() -> anyhow::Result<()> { #[tokio::test] async fn oversized_payload_rejected() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; const LIMIT: usize = 10; let env = test_env().runtime_config(TestFactorsRuntimeConfig { mqtt: Some(spin_factor_outbound_mqtt::runtime_config::RuntimeConfig { max_payload_size_bytes: Some(LIMIT), + ..Default::default() }), ..Default::default() })?; @@ -166,10 +167,15 @@ async fn oversized_payload_rejected() -> anyhow::Result<()> { let oversized = vec![0u8; LIMIT + 1]; let err = state .mqtt - .publish(conn, "topic".to_string(), oversized, v2::Qos::AtMostOnce) + .publish( + conn, + "topic".to_string(), + oversized, + v2_mqtt::Qos::AtMostOnce, + ) .await; assert!( - matches!(err, Err(v2::Error::Other(_))), + matches!(err, Err(v2_mqtt::Error::Other(_))), "expected Other error for oversized payload, got {err:?}" ); @@ -178,13 +184,14 @@ async fn oversized_payload_rejected() -> anyhow::Result<()> { #[tokio::test] async fn payload_at_limit_succeeds() -> anyhow::Result<()> { - use v2::HostConnection; + use v2_mqtt::HostConnection; const LIMIT: usize = 10; let env = test_env().runtime_config(TestFactorsRuntimeConfig { mqtt: Some(spin_factor_outbound_mqtt::runtime_config::RuntimeConfig { max_payload_size_bytes: Some(LIMIT), + ..Default::default() }), ..Default::default() })?; @@ -208,9 +215,72 @@ async fn payload_at_limit_succeeds() -> anyhow::Result<()> { conn, "topic".to_string(), exactly_limit, - v2::Qos::AtMostOnce, + v2_mqtt::Qos::AtMostOnce, + ) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn connection_limit_blocks_when_exhausted() -> anyhow::Result<()> { + use v2_mqtt::HostConnection; + + let env = TestEnvironment::new(factors()) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["mqtt://*:*"] + }) + .runtime_config(TestFactorsRuntimeConfig { + mqtt: Some(spin_factor_outbound_mqtt::runtime_config::RuntimeConfig { + max_connections: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + + // Open first connection - should succeed immediately. + let conn1 = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await?; + + // Second open should block (wait for a permit) since the limit is 1. + let timed_out = tokio::time::timeout( + Duration::from_millis(10), + state.mqtt.open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ), + ) + .await + .is_err(); + assert!(timed_out, "expected second open to block when limit is 1"); + + // Releasing the first connection returns its permit to the semaphore. + state.mqtt.drop(conn1).await?; + + // Now a new connection should succeed. + let conn2 = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, ) .await?; + state.mqtt.drop(conn2).await?; Ok(()) } diff --git a/crates/factor-outbound-mysql/Cargo.toml b/crates/factor-outbound-mysql/Cargo.toml index 64f51e3db0..b99020589e 100644 --- a/crates/factor-outbound-mysql/Cargo.toml +++ b/crates/factor-outbound-mysql/Cargo.toml @@ -10,6 +10,7 @@ doctest = false [dependencies] anyhow = { workspace = true } futures = { workspace = true } +serde = { workspace = true } # Removing default features for mysql_async to remove flate2/zlib feature mysql_async = { version = "0.35", default-features = false, features = [ "minimal-rust", @@ -23,7 +24,7 @@ spin-resource-table = { path = "../table" } spin-telemetry = { path = "../telemetry" } spin-world = { path = "../world" } spin-wasi-async = { path = "../wasi-async" } -tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio = { workspace = true, features = ["rt-multi-thread", "sync"] } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/factor-outbound-mysql/src/host.rs b/crates/factor-outbound-mysql/src/host.rs index 6f5577d80f..e86d2ecd9e 100644 --- a/crates/factor-outbound-mysql/src/host.rs +++ b/crates/factor-outbound-mysql/src/host.rs @@ -1,12 +1,14 @@ +use std::sync::Arc; + use anyhow::Result; use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; +use spin_factor_outbound_networking::ConnectionPermit; use spin_telemetry::traces::{self, Blame}; use spin_world::MAX_HOST_BUFFERED_BYTES; use spin_world::spin::mysql::mysql as v3; use spin_world::v1::mysql as v1; use spin_world::v2::mysql as v2; use spin_world::v2::rdbms_types as v2_types; -use std::sync::Arc; use tokio::sync::Mutex; use tracing::field::Empty; use tracing::{Level, instrument}; @@ -15,7 +17,11 @@ use crate::client::Client; use crate::{InstanceState, InstanceStateInner, MysqlFactorData}; impl InstanceStateInner { - async fn open_connection(&mut self, address: &str) -> Result { + async fn open_connection( + &mut self, + address: &str, + permit: ConnectionPermit, + ) -> Result { spin_factor_outbound_networking::record_address_fields(address); if !self.is_address_allowed(address).await.map_err(|e| { @@ -40,7 +46,7 @@ impl InstanceStateInner { err })?; self.connections - .push(Arc::new(Mutex::new(client))) + .push((Arc::new(Mutex::new(client)), permit)) .map_err(|_| { // The guest exceeded the host-imposed connection limit. let err = v2::Error::ConnectionFailed("too many connections".into()); @@ -50,13 +56,16 @@ impl InstanceStateInner { } fn get_client(&mut self, connection: u32) -> Result>, v2::Error> { - self.connections.get(connection).cloned().ok_or_else(|| { - // The connection table is managed entirely by the host, so a - // missing handle indicates a host-side bug, not a guest mistake. - let err = v2::Error::ConnectionFailed("no connection found".into()); - traces::mark_as_error(&err, Some(Blame::Host)); - err - }) + self.connections + .get(connection) + .map(|(conn, _permit)| conn.clone()) + .ok_or_else(|| { + // The connection table is managed entirely by the host, so a + // missing handle indicates a host-side bug, not a guest mistake. + let err = v2::Error::ConnectionFailed("no connection found".into()); + traces::mark_as_error(&err, Some(Blame::Host)); + err + }) } async fn is_address_allowed(&self, address: &str) -> Result { @@ -72,7 +81,7 @@ impl v3::Host for InstanceState { impl v3::HostConnection for InstanceState { async fn drop(&mut self, connection: Resource) -> Result<()> { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.connections.remove(connection.rep()); Ok(()) } @@ -90,10 +99,18 @@ impl v3::HostConnectionWithStore for MysqlFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let state = accessor.with(|mut access| access.get().0.clone()); - let mut state = state.lock().await; + let (state_arc, semaphore) = accessor.with(|mut access| { + let host = access.get(); + (host.inner.clone(), host.semaphore.clone()) + }); + let permit = semaphore.acquire().await.map_err(|_| { + v3::Error::from(v2::Error::ConnectionFailed("too many connections".into())) + })?; + let mut state = state_arc.lock().await; state.otel.reparent_tracing_span(); - Ok(Resource::new_own(state.open_connection(&address).await?)) + Ok(Resource::new_own( + state.open_connection(&address, permit).await?, + )) } #[instrument(name = "spin_outbound_mysql.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))] @@ -103,7 +120,7 @@ impl v3::HostConnectionWithStore for MysqlFactorData { statement: String, params: Vec, ) -> Result<(), v3::Error> { - let state = accessor.with(|mut access| access.get().0.clone()); + let state = accessor.with(|mut access| access.get().inner.clone()); let client = { let mut state = state.lock().await; state.otel.reparent_tracing_span(); @@ -125,7 +142,7 @@ impl v3::HostConnectionWithStore for MysqlFactorData { statement: String, params: Vec, ) -> Result { - let state = accessor.with(|mut access| access.get().0.clone()); + let state = accessor.with(|mut access| access.get().inner.clone()); let client = { let mut state = state.lock().await; state.otel.reparent_tracing_span(); @@ -161,9 +178,17 @@ impl v2::Host for InstanceState {} impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_mysql.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v2::Error> { - let mut state = self.0.lock().await; + let permit = self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?; + let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); - state.open_connection(&address).await.map(Resource::new_own) + state + .open_connection(&address, permit) + .await + .map(Resource::new_own) } #[instrument(name = "spin_outbound_mysql.execute", skip(self, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "mysql", otel.name = statement))] @@ -173,7 +198,7 @@ impl v2::HostConnection for InstanceState { statement: String, params: Vec, ) -> Result<(), v2::Error> { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); state .get_client(connection.rep())? @@ -191,7 +216,7 @@ impl v2::HostConnection for InstanceState { statement: String, params: Vec, ) -> Result { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.otel.reparent_tracing_span(); state .get_client(connection.rep())? @@ -203,7 +228,7 @@ impl v2::HostConnection for InstanceState { } async fn drop(&mut self, connection: Resource) -> Result<()> { - let mut state = self.0.lock().await; + let mut state = self.inner.lock().await; state.connections.remove(connection.rep()); Ok(()) } @@ -218,13 +243,23 @@ impl v2_types::Host for InstanceState { /// Delegate a function call to the v2::HostConnection implementation macro_rules! delegate { ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ + let permit = $self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::ConnectionFailed("too many connections".into()))?; let connection = { - let mut state = $self.0.lock().await; - Resource::new_own(state.open_connection(&$address).await?) + let mut state = $self.inner.lock().await; + Resource::new_own(state.open_connection(&$address, permit).await?) }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(Into::into) + .map_err(Into::into); + $self.inner.lock().await.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-mysql/src/lib.rs b/crates/factor-outbound-mysql/src/lib.rs index 10a51f1e5f..72966d610e 100644 --- a/crates/factor-outbound-mysql/src/lib.rs +++ b/crates/factor-outbound-mysql/src/lib.rs @@ -1,26 +1,35 @@ pub mod client; mod host; +pub mod runtime_config; + +use std::sync::Arc; use client::Client; use mysql_async::Conn as MysqlClient; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; use spin_factor_outbound_networking::{ - OutboundNetworkingFactor, config::allowed_hosts::OutboundAllowedHosts, + ConnectionPermit, ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, + config::allowed_hosts::OutboundAllowedHosts, }; use spin_factors::{Factor, FactorData, InitContext, RuntimeFactors, SelfInstanceBuilder}; use spin_world::spin::mysql::mysql as v3; use spin_world::v1::mysql as v1; use spin_world::v2::mysql as v2; -use std::sync::Arc; use tokio::sync::Mutex; pub struct OutboundMysqlFactor { _phantom: std::marker::PhantomData, } +pub struct AppState { + /// Semaphore to limit concurrent outbound MySQL connections. + pub semaphore: ConnectionSemaphore, +} + impl Factor for OutboundMysqlFactor { - type RuntimeConfig = (); - type AppState = (); + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl InitContext) -> anyhow::Result<()> { @@ -32,9 +41,17 @@ impl Factor for OutboundMysqlFactor { fn configure_app( &self, - _ctx: spin_factors::ConfigureAppContext, + mut ctx: spin_factors::ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + let config = ctx.take_runtime_config().unwrap_or_default(); + + Ok(AppState { + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "mysql", + config.max_connections, + ), + }) } fn prepare( @@ -46,11 +63,14 @@ impl Factor for OutboundMysqlFactor { .allowed_hosts(); let otel = OtelFactorState::from_prepare_context(&mut ctx)?; - Ok(InstanceState(Arc::new(Mutex::new(InstanceStateInner { - allowed_hosts, - connections: Default::default(), - otel, - })))) + Ok(InstanceState { + inner: Arc::new(Mutex::new(InstanceStateInner { + allowed_hosts, + connections: Default::default(), + otel, + })), + semaphore: ctx.app_state().semaphore.clone(), + }) } } @@ -70,11 +90,14 @@ impl OutboundMysqlFactor { pub struct InstanceStateInner { allowed_hosts: OutboundAllowedHosts, - connections: spin_resource_table::Table>>, + connections: spin_resource_table::Table<(Arc>, ConnectionPermit)>, otel: OtelFactorState, } -pub struct InstanceState(Arc>>); +pub struct InstanceState { + pub(crate) inner: Arc>>, + pub semaphore: ConnectionSemaphore, +} impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-mysql/src/runtime_config.rs b/crates/factor-outbound-mysql/src/runtime_config.rs new file mode 100644 index 0000000000..5a96047a1f --- /dev/null +++ b/crates/factor-outbound-mysql/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound MySQL. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound MySQL connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-mysql/src/runtime_config/spin.rs b/crates/factor-outbound-mysql/src/runtime_config/spin.rs new file mode 100644 index 0000000000..85c38253cc --- /dev/null +++ b/crates/factor-outbound-mysql/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound MySQL from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_mysql] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_mysql) = table.get("outbound_mysql") { + let toml = outbound_mysql.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundMysqlToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/factor-outbound-networking/Cargo.toml b/crates/factor-outbound-networking/Cargo.toml index f340c27093..9836e16cf1 100644 --- a/crates/factor-outbound-networking/Cargo.toml +++ b/crates/factor-outbound-networking/Cargo.toml @@ -13,6 +13,7 @@ rustls = { workspace = true } rustls-pki-types = { workspace = true } rustls-platform-verifier = { workspace = true } serde = { workspace = true } +spin-connection-semaphore = { path = "../connection-semaphore" } spin-factor-variables = { path = "../factor-variables" } spin-factor-wasi = { path = "../factor-wasi" } spin-factors = { path = "../factors" } @@ -20,16 +21,21 @@ spin-locked-app = { path = "../locked-app" } spin-manifest = { path = "../manifest" } spin-outbound-networking-config = { path = "../outbound-networking-config" } spin-serde = { path = "../serde" } +tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } opentelemetry-semantic-conventions = { workspace = true } url = { workspace = true } webpki-root-certs = "1.0.7" [dev-dependencies] +async-trait = { workspace = true } +spin-factor-outbound-mqtt = { path = "../factor-outbound-mqtt" } spin-factors-test = { path = "../factors-test" } +spin-world = { path = "../world" } tempfile = { workspace = true } tokio = { workspace = true, features = ["macros", "rt"] } toml = { workspace = true } +wasmtime = { workspace = true } wasmtime-wasi = { workspace = true } [features] diff --git a/crates/factor-outbound-networking/src/connection_semaphore.rs b/crates/factor-outbound-networking/src/connection_semaphore.rs new file mode 100644 index 0000000000..f91c0f8583 --- /dev/null +++ b/crates/factor-outbound-networking/src/connection_semaphore.rs @@ -0,0 +1 @@ +pub use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore}; diff --git a/crates/factor-outbound-networking/src/lib.rs b/crates/factor-outbound-networking/src/lib.rs index 5b20c46be3..cbf91a40e2 100644 --- a/crates/factor-outbound-networking/src/lib.rs +++ b/crates/factor-outbound-networking/src/lib.rs @@ -1,4 +1,5 @@ mod allowed_hosts; +pub mod connection_semaphore; pub mod runtime_config; mod tls; @@ -7,18 +8,20 @@ use std::{collections::HashMap, sync::Arc}; use futures_util::FutureExt as _; use opentelemetry_semantic_conventions::attribute::SERVER_PORT; use spin_factor_variables::VariablesFactor; -use spin_factor_wasi::{SocketAddrUse, WasiFactor}; +use spin_factor_wasi::{SocketAddrUse, SocketPermitState, WasiFactor}; use spin_factors::{ ConfigureAppContext, Error, Factor, FactorInstanceBuilder, PrepareContext, RuntimeFactors, anyhow::{self, Context}, }; use spin_outbound_networking_config::allowed_hosts::{DisallowedHostHandler, OutboundAllowedHosts}; +use tokio::sync::Semaphore; use url::Url; use crate::{ allowed_hosts::allowed_outbound_hosts, runtime_config::RuntimeConfig, tls::TlsClientConfigs, }; pub use allowed_hosts::validate_service_chaining_for_components; +pub use connection_semaphore::{ConnectionPermit, ConnectionSemaphore}; pub use crate::tls::{ComponentTlsClientConfigs, TlsClientConfig}; use config::allowed_hosts::AllowedHostsConfig; @@ -69,15 +72,44 @@ impl Factor for OutboundNetworkingFactor { client_tls_configs, blocked_ip_networks: block_networks, block_private_networks, + max_socket_connections, + max_total_connections, } = ctx.take_runtime_config().unwrap_or_default(); let blocked_networks = BlockedNetworks::new(block_networks, block_private_networks); let tls_client_configs = TlsClientConfigs::new(client_tls_configs)?; + let global_connection_semaphore = + max_total_connections.map(|n| Arc::new(Semaphore::new(n))); + + if let (Some(socket_cap), Some(global_cap)) = + (max_socket_connections, max_total_connections) + && socket_cap > global_cap + { + tracing::warn!( + "outbound_networking max_socket_connections ({socket_cap}) exceeds \ + max_total_connections ({global_cap}); the global limit will be the effective \ + cap for TCP/UDP sockets" + ); + } + + let socket_connection_semaphore = + if max_socket_connections.is_some() || global_connection_semaphore.is_some() { + Some(ConnectionSemaphore::new( + global_connection_semaphore.clone(), + max_socket_connections, + "wasi-sockets", + )) + } else { + None + }; Ok(AppState { component_allowed_hosts, blocked_networks, tls_client_configs, + socket_connection_semaphore, + global_connection_semaphore, + max_total_connections, }) } @@ -123,10 +155,18 @@ impl Factor for OutboundNetworkingFactor { self.disallowed_host_handler.clone(), ); let blocked_networks = ctx.app_state().blocked_networks.clone(); + let permit_state = ctx + .app_state() + .socket_connection_semaphore + .clone() + .map(SocketPermitState::new); match ctx.instance_builder::() { Ok(wasi_builder) => { - // Update Wasi socket allowed ports + if let Some(state) = permit_state { + wasi_builder.set_socket_permit_state(state); + } + let allowed_hosts = allowed_hosts.clone(); wasi_builder.outbound_socket_addr_check(move |addr, addr_use| { let allowed_hosts = allowed_hosts.clone(); @@ -185,6 +225,42 @@ pub struct AppState { blocked_networks: BlockedNetworks, /// TLS client configs tls_client_configs: TlsClientConfigs, + /// Pre-built semaphore for TCP/UDP socket quota enforcement (global + socket-specific). + /// `None` means no limits are configured. + socket_connection_semaphore: Option, + /// App-wide semaphore capping total concurrent outbound connections across ALL types. + /// `None` means unlimited. + global_connection_semaphore: Option>, + /// The configured global connection limit (for warning comparisons in other factors). + max_total_connections: Option, +} + +/// Builds a [`ConnectionSemaphore`] for an outbound factor, incorporating the optional global +/// connection limit from the networking factor's app state. +/// +/// Emits a warning when the per-factor limit exceeds the global cap (the global limit would +/// be the effective ceiling in that case). +pub fn build_connection_semaphore( + networking: Option<&AppState>, + factor: &'static str, + factor_limit: Option, +) -> ConnectionSemaphore { + if let (Some(per_factor), Some(global_limit)) = ( + factor_limit, + networking.and_then(|n| n.max_total_connections), + ) && per_factor > global_limit + { + tracing::warn!( + "outbound_{factor} max_connections ({per_factor}) exceeds global \ + max_total_connections ({global_limit}); the global limit will be the \ + effective cap" + ); + } + ConnectionSemaphore::new( + networking.and_then(|n| n.global_connection_semaphore.clone()), + factor_limit, + factor, + ) } pub struct InstanceBuilder { diff --git a/crates/factor-outbound-networking/src/runtime_config.rs b/crates/factor-outbound-networking/src/runtime_config.rs index 887742febb..278882a3e2 100644 --- a/crates/factor-outbound-networking/src/runtime_config.rs +++ b/crates/factor-outbound-networking/src/runtime_config.rs @@ -12,6 +12,12 @@ pub struct RuntimeConfig { pub block_private_networks: bool, /// TLS client configs pub client_tls_configs: Vec, + /// Maximum number of outbound TCP/UDP socket connections across all instances of this app. + /// `None` means unlimited (default). + pub max_socket_connections: Option, + /// Maximum number of outbound connections across ALL connection types (global cap). + /// `None` means unlimited (default). + pub max_total_connections: Option, } /// TLS configuration for one or more component(s) and host(s). diff --git a/crates/factor-outbound-networking/src/runtime_config/spin.rs b/crates/factor-outbound-networking/src/runtime_config/spin.rs index f41c4a0d75..e97fba28cd 100644 --- a/crates/factor-outbound-networking/src/runtime_config/spin.rs +++ b/crates/factor-outbound-networking/src/runtime_config/spin.rs @@ -46,52 +46,49 @@ impl SpinRuntimeConfig { &self, table: &impl GetTomlValue, ) -> anyhow::Result> { - let maybe_blocked_networks = self - .blocked_networks_from_table(table) + let maybe_outbound_networking = self + .outbound_networking_from_table(table) .context("failed to parse [outbound_networking] table")?; let maybe_tls_configs = self .tls_configs_from_table(table) .context("failed to parse [[client_tls]] table")?; - if maybe_blocked_networks.is_none() && maybe_tls_configs.is_none() { + if maybe_outbound_networking.is_none() && maybe_tls_configs.is_none() { return Ok(None); } - let (blocked_ip_networks, block_private_networks) = - maybe_blocked_networks.unwrap_or_default(); - - let client_tls_configs = maybe_tls_configs.unwrap_or_default(); + let outbound_networking = maybe_outbound_networking.unwrap_or_default(); + let mut blocked_ip_networks = vec![]; + let mut block_private_networks = false; + for block_network in outbound_networking.block_networks { + match block_network { + CidrOrPrivate::Cidr(ip_network) => blocked_ip_networks.push(ip_network), + CidrOrPrivate::Private => { + block_private_networks = true; + } + } + } let runtime_config = super::RuntimeConfig { blocked_ip_networks, block_private_networks, - client_tls_configs, + client_tls_configs: maybe_tls_configs.unwrap_or_default(), + max_socket_connections: outbound_networking.max_socket_connections, + max_total_connections: outbound_networking.max_total_connections, }; Ok(Some(runtime_config)) } - /// Attempts to parse (blocked_ip_networks, block_private_networks) from a - /// `[outbound_networking]` table. - fn blocked_networks_from_table( + /// Attempts to parse the `[outbound_networking]` table. + fn outbound_networking_from_table( &self, table: &impl GetTomlValue, - ) -> anyhow::Result, bool)>> { + ) -> anyhow::Result> { let Some(value) = table.get("outbound_networking") else { return Ok(None); }; let outbound_networking: OutboundNetworkingToml = value.clone().try_into()?; - - let mut ip_networks = vec![]; - let mut private_networks = false; - for block_network in outbound_networking.block_networks { - match block_network { - CidrOrPrivate::Cidr(ip_network) => ip_networks.push(ip_network), - CidrOrPrivate::Private => { - private_networks = true; - } - } - } - Ok(Some((ip_networks, private_networks))) + Ok(Some(outbound_networking)) } fn tls_configs_from_table( @@ -225,6 +222,8 @@ fn deserialize_hosts<'de, D: Deserializer<'de>>(deserializer: D) -> Result, + max_socket_connections: Option, + max_total_connections: Option, } #[derive(Debug)] diff --git a/crates/factor-outbound-networking/tests/factor_test.rs b/crates/factor-outbound-networking/tests/factor_test.rs index ce7f0bd479..bdeaa8fa79 100644 --- a/crates/factor-outbound-networking/tests/factor_test.rs +++ b/crates/factor-outbound-networking/tests/factor_test.rs @@ -1,12 +1,58 @@ +use std::sync::Arc; +use std::time::Duration; + +use spin_factor_outbound_mqtt::{ClientCreator, MqttClient, OutboundMqttFactor}; use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::runtime_config::RuntimeConfig; use spin_factor_outbound_networking::runtime_config::spin::SpinRuntimeConfig; use spin_factor_variables::VariablesFactor; use spin_factor_wasi::{DummyFilesMounter, WasiFactor}; -use spin_factors::{RuntimeFactors, anyhow}; +use spin_factors::anyhow::Context as _; +use spin_factors::{App, RuntimeFactors, anyhow}; use spin_factors_test::{TestEnvironment, toml}; +use spin_world::spin::mqtt::mqtt as v3_mqtt; +use spin_world::v2::mqtt as v2_mqtt; use wasmtime_wasi::p2::bindings::sockets::instance_network::Host; +use wasmtime_wasi::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily}; +use wasmtime_wasi::p2::bindings::sockets::tcp as p2_tcp; +use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; use wasmtime_wasi::sockets::SocketAddrUse; +struct MockMqttClient; + +#[async_trait::async_trait] +impl MqttClient for MockMqttClient { + async fn publish_bytes( + &self, + _topic: String, + _qos: v3_mqtt::Qos, + _payload: Vec, + ) -> anyhow::Result<(), v3_mqtt::Error> { + Ok(()) + } +} + +impl ClientCreator for MockMqttClient { + fn create( + &self, + _address: String, + _username: String, + _password: String, + _keep_alive_interval: Duration, + ) -> anyhow::Result, v3_mqtt::Error> { + Ok(Arc::new(MockMqttClient)) + } +} + +#[derive(RuntimeFactors)] +struct TestFactorsWithMqtt { + wasi: WasiFactor, + variables: VariablesFactor, + networking: OutboundNetworkingFactor, + mqtt: OutboundMqttFactor, +} + #[derive(RuntimeFactors)] struct TestFactors { wasi: WasiFactor, @@ -81,3 +127,379 @@ async fn wasi_factor_is_optional() -> anyhow::Result<()> { .await?; Ok(()) } + +#[tokio::test] +async fn socket_quota_blocks_excess_connections() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_socket_connections: Some(2), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // First two connections should be accepted (non-blocking connect initiated) + let net1 = sockets.instance_network()?; + let sock1 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock1, net1, addr.into()).await?; + + let net2 = sockets.instance_network()?; + let sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock2, net2, addr.into()).await?; + + // Third should fail — quota exhausted + let net3 = sockets.instance_network()?; + let sock3 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock3, net3, addr.into()) + .await + .unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); + Ok(()) +} + +#[tokio::test] +async fn socket_quota_releases_on_instance_drop() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_socket_connections: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let locked_app = env.build_locked_app().await?; + let TestEnvironment { + factors, + runtime_config, + .. + } = env; + let app = App::new("test-app", locked_app); + let configured_app = factors.configure_app(app, runtime_config)?; + let component_id = configured_app + .app() + .components() + .last() + .context("no components")? + .id() + .to_string(); + + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // First instance: fill the quota (1 socket) + { + let builders = factors.prepare(&configured_app, &component_id)?; + let mut state = factors.build_instance_state(builders)?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; + // sockets state dropped here releasing the permit back to the semaphore + } + + // Second instance: quota should be fully available again + let builders = factors.prepare(&configured_app, &component_id)?; + let mut state = factors.build_instance_state(builders)?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; + Ok(()) +} + +#[tokio::test] +async fn no_socket_quota_allows_unlimited() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors).extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }); + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + for _ in 0..10 { + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()).await?; + } + Ok(()) +} + +#[tokio::test] +async fn socket_quota_still_enforces_allowed_hosts() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_socket_connections: Some(10), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + + // Allowed host succeeds + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let allowed_addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, allowed_addr.into()).await?; + + // Disallowed host is rejected even with quota available + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let disallowed_addr: std::net::SocketAddr = "1.2.3.4:80".parse().unwrap(); + assert!( + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, disallowed_addr.into()) + .await + .is_err() + ); + Ok(()) +} + +#[tokio::test] +async fn socket_quota_releases_on_socket_drop() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_socket_connections: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // Acquire the only permit via start_connect. Save the rep so we can reconstruct + // a handle afterwards — start_connect consumes the Resource but leaves the socket + // alive in the ResourceTable. + let net1 = sockets.instance_network()?; + let sock1 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let sock1_rep = sock1.rep(); + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock1, net1, addr.into()).await?; + + // A second start_connect should fail while the permit is held. + let net2 = sockets.instance_network()?; + let sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock2, net2, addr.into()) + .await + .unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); + + // Explicitly drop sock1 before finish_connect — this should release the permit. + let sock1_handle = + wasmtime::component::Resource::::new_own(sock1_rep); + p2_tcp::HostTcpSocket::drop(&mut sockets, sock1_handle)?; + + // After the drop the quota is free again, so a new start_connect must succeed. + let net3 = sockets.instance_network()?; + let sock3 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock3, net3, addr.into()).await?; + + Ok(()) +} + +#[tokio::test] +async fn socket_quota_blocks_excess_udp_sockets() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_socket_connections: Some(2), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + + // First two UDP socket creations should succeed. + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + + // Third should fail — quota exhausted. + let err = + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4).unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); + Ok(()) +} + +#[tokio::test] +async fn socket_quota_shared_between_tcp_and_udp() -> anyhow::Result<()> { + let factors = TestFactors { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsRuntimeConfig { + networking: Some(RuntimeConfig { + max_socket_connections: Some(2), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + + // Consume one permit with a TCP connection. + let net = sockets.instance_network()?; + let tcp_sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, tcp_sock, net, addr.into()).await?; + + // Consume the second permit with a UDP socket — quota now full. + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + + // Any further allocation must fail — shared quota exhausted. + // UDP: + let err = + p2_udp_create::Host::create_udp_socket(&mut sockets, IpAddressFamily::Ipv4).unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); + // TCP: + let net = sockets.instance_network()?; + let tcp_sock2 = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, tcp_sock2, net, addr.into()) + .await + .unwrap_err(); + assert_eq!(err.downcast_ref(), Some(&ErrorCode::NewSocketLimit)); + Ok(()) +} + +/// Verifies that the global connection limit is shared across factors: a permit +/// held by an MQTT connection blocks a WASI TCP socket (and vice-versa). +#[tokio::test] +async fn global_connection_limit_enforced_across_factors() -> anyhow::Result<()> { + use v2_mqtt::HostConnection as _; + + let factors = TestFactorsWithMqtt { + wasi: WasiFactor::new(DummyFilesMounter), + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor::new(), + mqtt: OutboundMqttFactor::new(Arc::new(MockMqttClient)), + }; + let env = TestEnvironment::new(factors) + .extend_manifest(toml! { + [component.test-component] + source = "does-not-exist.wasm" + allowed_outbound_hosts = ["mqtt://*:*", "*://123.0.2.1:12345"] + }) + .runtime_config(TestFactorsWithMqttRuntimeConfig { + networking: Some(RuntimeConfig { + max_total_connections: Some(1), + ..Default::default() + }), + ..Default::default() + })?; + + let mut state = env.build_instance_state().await?; + + // Acquire the single global permit via an MQTT connection. + let conn = state + .mqtt + .open( + "mqtt://mqtt.test:1883".to_string(), + "username".to_string(), + "password".to_string(), + 1, + ) + .await?; + + // With the global permit held by MQTT, a TCP socket start_connect must fail immediately. + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let addr: std::net::SocketAddr = "123.0.2.1:12345".parse().unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + let err = p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()) + .await + .unwrap_err(); + assert_eq!( + err.downcast_ref(), + Some(&ErrorCode::NewSocketLimit), + "TCP socket should fail while global permit is held by MQTT" + ); + drop(sockets); + + // Releasing the MQTT connection returns the global permit. + state.mqtt.drop(conn).await?; + + // Now the TCP socket start_connect must succeed. + let mut sockets = WasiFactor::get_sockets_impl(&mut state).unwrap(); + let net = sockets.instance_network()?; + let sock = p2_tcp_create::Host::create_tcp_socket(&mut sockets, IpAddressFamily::Ipv4)?; + p2_tcp::HostTcpSocket::start_connect(&mut sockets, sock, net, addr.into()) + .await + .expect("TCP socket should succeed after MQTT connection is released"); + + Ok(()) +} diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index 45dcc7a22f..b7cac624b8 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -7,6 +7,7 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } bytes = {workspace = true } +serde = { workspace = true } chrono = { workspace = true } deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] } futures = { workspace = true } @@ -27,7 +28,7 @@ spin-resource-table = { path = "../table" } spin-telemetry = { path = "../telemetry" } spin-wasi-async = { path = "../wasi-async" } spin-world = { path = "../world" } -tokio = { workspace = true, features = ["rt-multi-thread"] } +tokio = { workspace = true, features = ["rt-multi-thread", "sync"] } tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] } tracing = { workspace = true } url = { workspace = true } diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 5faf7663a8..f18a721b31 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -24,6 +24,11 @@ impl InstanceState { address: &str, root_ca: Option, ) -> Result, v4::Error> { + let permit = self.semaphore.acquire().await.map_err(|_| { + let err = v4::Error::ConnectionFailed("too many connections".into()); + traces::mark_as_error(&err, Some(Blame::Guest)); + err + })?; let client = self .client_factory .get_client(address, root_ca) @@ -37,7 +42,7 @@ impl InstanceState { err })?; self.connections - .push(client) + .push((client, permit)) .map_err(|_| { // The guest exceeded the host-imposed connection limit. let err = v4::Error::ConnectionFailed("too many connections".into()); @@ -51,13 +56,16 @@ impl InstanceState { &self, connection: Resource, ) -> Result<&CF::Client, v4::Error> { - self.connections.get(connection.rep()).ok_or_else(|| { - // The connection table is managed entirely by the host, so a - // missing handle indicates a host-side bug, not a guest mistake. - let err = v4::Error::ConnectionFailed("no connection found".into()); - traces::mark_as_error(&err, Some(Blame::Host)); - err - }) + self.connections + .get(connection.rep()) + .map(|(client, _permit)| client) + .ok_or_else(|| { + // The connection table is managed entirely by the host, so a + // missing handle indicates a host-side bug, not a guest mistake. + let err = v4::Error::ConnectionFailed("no connection found".into()); + traces::mark_as_error(&err, Some(Blame::Host)); + err + }) } fn allowed_host_checker(&self) -> AllowedHostChecker { @@ -260,7 +268,10 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio ) -> Result { let client = accessor.with(|mut access| { let host = access.get(); - host.connections.get(connection.rep()).unwrap().clone() + host.connections + .get(connection.rep()) + .map(|(client, _permit)| client.clone()) + .unwrap() }); client @@ -286,7 +297,10 @@ impl spin_world::spin::postgres4_2_0::postgres::HostConnectio > { let client = accessor.with(|mut access| { let host = access.get(); - host.connections.get(connection.rep()).unwrap().clone() + host.connections + .get(connection.rep()) + .map(|(client, _permit)| client.clone()) + .unwrap() }); let QueryAsyncResult { @@ -368,11 +382,17 @@ impl crate::PgFactorData { address: &str, root_ca: Option, ) -> Result, v4::Error> { - let cf = accessor.with(|mut access| { + let (cf, semaphore) = accessor.with(|mut access| { let host = access.get(); - host.client_factory.clone() + (host.client_factory.clone(), host.semaphore.clone()) }); + let permit = semaphore.acquire().await.map_err(|_| { + let err = v4::Error::ConnectionFailed("too many connections".into()); + traces::mark_as_error(&err, Some(Blame::Guest)); + err + })?; + let client = cf.get_client(address, root_ca).await.map_err(|e| { let err = v4::Error::ConnectionFailed(format!("{e:?}")); traces::mark_as_error(&err, Some(Blame::Guest)); @@ -382,7 +402,7 @@ impl crate::PgFactorData { accessor.with(|mut access| { let host = access.get(); host.connections - .push(client) + .push((client, permit)) .map_err(|_| { let err = v4::Error::ConnectionFailed("too many connections".into()); traces::mark_as_error(&err, Some(Blame::Guest)); @@ -429,7 +449,7 @@ impl v4::Host for InstanceState { } } -/// Delegate a function call to the v3::HostConnection implementation +/// Delegate a function call to the v4::HostConnection implementation macro_rules! delegate { ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ $self.ensure_address_allowed(&$address).await?; @@ -437,9 +457,14 @@ macro_rules! delegate { Ok(c) => c, Err(e) => return Err(e.into()), }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(|e| e.into()) + .map_err(|e| e.into()); + $self.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index d20cfe492f..b36198d3bc 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,14 +1,19 @@ mod allowed_hosts; pub mod client; mod host; +pub mod runtime_config; mod types; -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; +use std::sync::Arc; use allowed_hosts::AllowedHostChecker; use client::ClientFactory; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::{ + ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, +}; use spin_factors::{ ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, }; @@ -17,9 +22,15 @@ pub struct OutboundPgFactor { _phantom: std::marker::PhantomData, } +pub struct AppState { + pub client_factories: HashMap>, + /// Semaphore to limit concurrent outbound PostgreSQL connections. + pub semaphore: ConnectionSemaphore, +} + impl Factor for OutboundPgFactor { - type RuntimeConfig = (); - type AppState = HashMap>; + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { @@ -36,13 +47,22 @@ impl Factor for OutboundPgFactor { fn configure_app( &self, - ctx: ConfigureAppContext, + mut ctx: ConfigureAppContext, ) -> anyhow::Result { + let config = ctx.take_runtime_config().unwrap_or_default(); let mut client_factories = HashMap::new(); for comp in ctx.app().components() { client_factories.insert(comp.id().to_string(), Arc::new(CF::default())); } - Ok(client_factories) + + Ok(AppState { + client_factories, + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "pg", + config.max_connections, + ), + }) } fn prepare( @@ -53,7 +73,11 @@ impl Factor for OutboundPgFactor { .instance_builder::()? .allowed_hosts(); let otel = OtelFactorState::from_prepare_context(&mut ctx)?; - let cf = ctx.app_state().get(ctx.app_component().id()).unwrap(); + let cf = ctx + .app_state() + .client_factories + .get(ctx.app_component().id()) + .unwrap(); Ok(InstanceState { allowed_host_checker: AllowedHostChecker::new(allowed_hosts), @@ -61,6 +85,7 @@ impl Factor for OutboundPgFactor { connections: Default::default(), otel, builders: Default::default(), + semaphore: ctx.app_state().semaphore.clone(), }) } } @@ -82,9 +107,13 @@ impl OutboundPgFactor { pub struct InstanceState { allowed_host_checker: AllowedHostChecker, client_factory: Arc, - connections: spin_resource_table::Table, + connections: spin_resource_table::Table<( + CF::Client, + spin_factor_outbound_networking::ConnectionPermit, + )>, otel: OtelFactorState, builders: spin_resource_table::Table, + pub semaphore: ConnectionSemaphore, } impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/src/runtime_config.rs b/crates/factor-outbound-pg/src/runtime_config.rs new file mode 100644 index 0000000000..7cf9745c0e --- /dev/null +++ b/crates/factor-outbound-pg/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound PostgreSQL. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound PostgreSQL connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-pg/src/runtime_config/spin.rs b/crates/factor-outbound-pg/src/runtime_config/spin.rs new file mode 100644 index 0000000000..b82c60ea43 --- /dev/null +++ b/crates/factor-outbound-pg/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound PostgreSQL from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_pg] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_pg) = table.get("outbound_pg") { + let toml = outbound_pg.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundPgToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/factor-outbound-redis/Cargo.toml b/crates/factor-outbound-redis/Cargo.toml index 6518459d79..55e641ce6e 100644 --- a/crates/factor-outbound-redis/Cargo.toml +++ b/crates/factor-outbound-redis/Cargo.toml @@ -7,13 +7,14 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } redis = { workspace = true , features = ["tokio-comp", "tokio-native-tls-comp", "aio"] } +serde = { workspace = true } spin-core = { path = "../core" } spin-factor-otel = { path = "../factor-otel" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-resource-table = { path = "../table" } spin-world = { path = "../world" } -tokio = { workspace = true } +tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } [dev-dependencies] diff --git a/crates/factor-outbound-redis/src/host.rs b/crates/factor-outbound-redis/src/host.rs index 61fd05708b..0454e98537 100644 --- a/crates/factor-outbound-redis/src/host.rs +++ b/crates/factor-outbound-redis/src/host.rs @@ -6,6 +6,7 @@ use redis::io::AsyncDNSResolver; use redis::{AsyncCommands, FromRedisValue, Value, aio::MultiplexedConnection}; use spin_core::wasmtime::component::{Accessor, Resource}; use spin_factor_otel::OtelFactorState; +use spin_factor_outbound_networking::ConnectionSemaphore; use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks; use spin_world::MAX_HOST_BUFFERED_BYTES; use spin_world::spin::redis::redis as v3; @@ -19,7 +20,11 @@ use crate::allowed_hosts::AllowedHostChecker; pub struct InstanceState { pub(crate) allowed_host_checker: AllowedHostChecker, pub blocked_networks: BlockedNetworks, - pub connections: spin_resource_table::Table, + pub connections: spin_resource_table::Table<( + MultiplexedConnection, + spin_factor_outbound_networking::ConnectionPermit, + )>, + pub semaphore: ConnectionSemaphore, pub otel: OtelFactorState, } @@ -32,6 +37,11 @@ impl InstanceState { &mut self, address: String, ) -> Result, v2::Error> { + let permit = self + .semaphore + .acquire() + .await + .map_err(|_| v2::Error::TooManyConnections)?; let config = AsyncConnectionConfig::new() .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone())); let conn = redis::Client::open(address.as_str()) @@ -40,7 +50,7 @@ impl InstanceState { .await .map_err(other_error_v2)?; self.connections - .push(conn) + .push((conn, permit)) .map(Resource::new_own) .map_err(|_| v2::Error::TooManyConnections) } @@ -51,6 +61,7 @@ impl InstanceState { ) -> Result<&mut MultiplexedConnection, v2::Error> { self.connections .get_mut(connection.rep()) + .map(|(conn, _permit)| conn) .ok_or(v2::Error::Other( "could not find connection for resource".into(), )) @@ -62,7 +73,7 @@ impl InstanceState { ) -> Result { self.connections .get(connection.rep()) - .cloned() + .map(|(conn, _permit)| conn.clone()) .ok_or(v3::Error::Other( "could not find connection for resource".into(), )) @@ -229,12 +240,13 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor: &Accessor, address: String, ) -> Result, v3::Error> { - let (allowed_host_checker, blocked_networks) = accessor.with(|mut access| { + let (allowed_host_checker, blocked_networks, semaphore) = accessor.with(|mut access| { let host = access.get(); host.otel.reparent_tracing_span(); ( host.allowed_host_checker.clone(), host.blocked_networks.clone(), + host.semaphore.clone(), ) }); @@ -246,6 +258,11 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { return Err(v3::Error::InvalidAddress); } + let permit = semaphore + .acquire() + .await + .map_err(|_| v3::Error::TooManyConnections)?; + let config = AsyncConnectionConfig::new().set_dns_resolver(SpinDnsResolver(blocked_networks)); let conn = redis::Client::open(address.as_str()) @@ -257,7 +274,7 @@ impl v3::HostConnectionWithStore for crate::RedisFactorData { accessor.with(|mut access| { let host = access.get(); host.connections - .push(conn) + .push((conn, permit)) .map(Resource::new_own) .map_err(|_| v3::Error::TooManyConnections) }) @@ -532,9 +549,14 @@ macro_rules! delegate { Ok(c) => c, Err(_) => return Err(v1::Error::Error), }; - ::$name($self, connection, $($arg),*) + // v1 has no persistent connections, so remove the table entry immediately + // after the call to release the semaphore permit. + let rep = connection.rep(); + let result = ::$name($self, connection, $($arg),*) .await - .map_err(|_| v1::Error::Error) + .map_err(|_| v1::Error::Error); + $self.connections.remove(rep); + result }}; } diff --git a/crates/factor-outbound-redis/src/lib.rs b/crates/factor-outbound-redis/src/lib.rs index 494c5ca800..47a05ff745 100644 --- a/crates/factor-outbound-redis/src/lib.rs +++ b/crates/factor-outbound-redis/src/lib.rs @@ -1,9 +1,13 @@ mod allowed_hosts; mod host; +pub mod runtime_config; use host::InstanceState; +use runtime_config::RuntimeConfig; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_networking::{ + ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore, +}; use spin_factors::{ ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, anyhow, @@ -24,9 +28,14 @@ impl OutboundRedisFactor { } } +pub struct AppState { + /// Semaphore to limit concurrent outbound Redis connections. + pub semaphore: ConnectionSemaphore, +} + impl Factor for OutboundRedisFactor { - type RuntimeConfig = (); - type AppState = (); + type RuntimeConfig = RuntimeConfig; + type AppState = AppState; type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { @@ -38,9 +47,17 @@ impl Factor for OutboundRedisFactor { fn configure_app( &self, - _ctx: ConfigureAppContext, + mut ctx: ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + let config = ctx.take_runtime_config().unwrap_or_default(); + + Ok(AppState { + semaphore: build_connection_semaphore( + ctx.app_state::().ok(), + "redis", + config.max_connections, + ), + }) } fn prepare( @@ -54,6 +71,7 @@ impl Factor for OutboundRedisFactor { allowed_host_checker: AllowedHostChecker::new(outbound_networking.allowed_hosts()), blocked_networks: outbound_networking.blocked_networks(), connections: spin_resource_table::Table::new(1024), + semaphore: ctx.app_state().semaphore.clone(), otel, }) } diff --git a/crates/factor-outbound-redis/src/runtime_config.rs b/crates/factor-outbound-redis/src/runtime_config.rs new file mode 100644 index 0000000000..38d2d7ea7d --- /dev/null +++ b/crates/factor-outbound-redis/src/runtime_config.rs @@ -0,0 +1,8 @@ +pub mod spin; + +/// Runtime configuration for outbound Redis. +#[derive(Default)] +pub struct RuntimeConfig { + /// If set, limits the number of concurrent outbound Redis connections. + pub max_connections: Option, +} diff --git a/crates/factor-outbound-redis/src/runtime_config/spin.rs b/crates/factor-outbound-redis/src/runtime_config/spin.rs new file mode 100644 index 0000000000..82c0efeaff --- /dev/null +++ b/crates/factor-outbound-redis/src/runtime_config/spin.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; +use spin_factors::runtime_config::toml::GetTomlValue; + +/// Get the runtime configuration for outbound Redis from a TOML table. +/// +/// Expects table to be in the format: +/// ```toml +/// [outbound_redis] +/// max_connections = 10 # optional, defaults to unlimited +/// ``` +pub fn config_from_table( + table: &impl GetTomlValue, +) -> anyhow::Result> { + if let Some(outbound_redis) = table.get("outbound_redis") { + let toml = outbound_redis.clone().try_into::()?; + Ok(Some(super::RuntimeConfig { + max_connections: toml.max_connections, + })) + } else { + Ok(None) + } +} + +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +struct OutboundRedisToml { + #[serde(default)] + max_connections: Option, +} diff --git a/crates/factor-wasi/Cargo.toml b/crates/factor-wasi/Cargo.toml index 3647670beb..e957b22ba7 100644 --- a/crates/factor-wasi/Cargo.toml +++ b/crates/factor-wasi/Cargo.toml @@ -8,8 +8,10 @@ edition = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } spin-common = { path = "../common" } +spin-connection-semaphore = { path = "../connection-semaphore" } spin-factors = { path = "../factors" } -tokio = { workspace = true } +tokio = { workspace = true, features = ["sync"] } +tracing = { workspace = true } wasmtime = { workspace = true } wasmtime-wasi = { workspace = true } diff --git a/crates/factor-wasi/src/lib.rs b/crates/factor-wasi/src/lib.rs index f86332764b..ff7e5fc1ea 100644 --- a/crates/factor-wasi/src/lib.rs +++ b/crates/factor-wasi/src/lib.rs @@ -1,4 +1,5 @@ mod io; +pub mod sockets; pub mod spin; mod wasi_2023_10_18; mod wasi_2023_11_10; @@ -8,6 +9,7 @@ use std::{ io::{Read, Write}, net::SocketAddr, path::Path, + sync::Arc, }; use io::{PipeReadStream, PipedWriteStream}; @@ -23,6 +25,7 @@ use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; use wasmtime_wasi::{DirPerms, FilePerms, ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView}; +pub use sockets::{SocketPermitState, SpinSockets, SpinSocketsView}; pub use wasmtime_wasi::sockets::SocketAddrUse; pub struct WasiFactor { @@ -58,11 +61,14 @@ impl WasiFactor { pub fn get_sockets_impl( runtime_instance_state: &mut impl RuntimeFactorsInstanceState, - ) -> Option> { + ) -> Option> { let (state, table) = runtime_instance_state.get_with_table::()?; - Some(WasiSocketsCtxView { - ctx: state.ctx.sockets(), - table, + Some(SpinSocketsView { + inner: WasiSocketsCtxView { + ctx: state.ctx.sockets(), + table, + }, + permit_state: state.socket_permit_state.clone(), }) } } @@ -176,6 +182,27 @@ trait InitContextExt: InitContext { add_to_linker(self.linker(), &O::default(), Self::get_sockets) } + fn get_spin_sockets(data: &mut Self::StoreData) -> SpinSocketsView<'_> { + let (state, table) = Self::get_data_with_table(data); + SpinSocketsView { + inner: WasiSocketsCtxView { + ctx: state.ctx.sockets(), + table, + }, + permit_state: state.socket_permit_state.clone(), + } + } + + fn link_spin_sockets_bindings( + &mut self, + add_to_linker: fn( + &mut wasmtime::component::Linker, + fn(&mut Self::StoreData) -> SpinSocketsView<'_>, + ) -> wasmtime::Result<()>, + ) -> wasmtime::Result<()> { + add_to_linker(self.linker(), Self::get_spin_sockets) + } + fn link_io_bindings( &mut self, add_to_linker: fn( @@ -208,7 +235,7 @@ trait InitContextExt: InitContext { fn(&mut Self::StoreData) -> WasiClocksCtxView<'_>, fn(&mut Self::StoreData) -> WasiCliCtxView<'_>, fn(&mut Self::StoreData) -> WasiFilesystemCtxView<'_>, - fn(&mut Self::StoreData) -> WasiSocketsCtxView<'_>, + fn(&mut Self::StoreData) -> SpinSocketsView<'_>, ) -> anyhow::Result<()>, ) -> anyhow::Result<()> { add_to_linker( @@ -218,7 +245,7 @@ trait InitContextExt: InitContext { Self::get_clocks, Self::get_cli, Self::get_filesystem, - Self::get_sockets, + Self::get_spin_sockets, ) } } @@ -294,13 +321,17 @@ impl Factor for WasiFactor { ctx.link_cli_bindings(p3::bindings::cli::terminal_stdout::add_to_linker::<_, WasiCli>)?; ctx.link_cli_bindings(p2::bindings::cli::terminal_stderr::add_to_linker::<_, WasiCli>)?; ctx.link_cli_bindings(p3::bindings::cli::terminal_stderr::add_to_linker::<_, WasiCli>)?; - ctx.link_sockets_bindings(p2::bindings::sockets::tcp::add_to_linker::<_, WasiSockets>)?; - ctx.link_sockets_bindings( - p2::bindings::sockets::tcp_create_socket::add_to_linker::<_, WasiSockets>, + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::tcp::add_to_linker::<_, SpinSockets>, )?; - ctx.link_sockets_bindings(p2::bindings::sockets::udp::add_to_linker::<_, WasiSockets>)?; - ctx.link_sockets_bindings( - p2::bindings::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>, + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>, + )?; + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::udp::add_to_linker::<_, SpinSockets>, + )?; + ctx.link_spin_sockets_bindings( + p2::bindings::sockets::udp_create_socket::add_to_linker::<_, SpinSockets>, )?; ctx.link_sockets_bindings( p2::bindings::sockets::instance_network::add_to_linker::<_, WasiSockets>, @@ -314,6 +345,7 @@ impl Factor for WasiFactor { ctx.link_sockets_bindings( p3::bindings::sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>, )?; + // TODO(rylev): switch to SpinSockets once possible ctx.link_sockets_bindings(p3::bindings::sockets::types::add_to_linker::<_, WasiSockets>)?; ctx.link_all_bindings(wasi_2023_10_18::add_to_linker)?; @@ -339,7 +371,10 @@ impl Factor for WasiFactor { self.files_mounter .mount_files(ctx.app_component(), mount_ctx)?; - let mut builder = InstanceBuilder { ctx: wasi_ctx }; + let mut builder = InstanceBuilder { + ctx: wasi_ctx, + socket_permit_state: None, + }; // Apply environment variables builder.env(ctx.app_component().environment()); @@ -396,6 +431,7 @@ impl MountFilesContext<'_> { pub struct InstanceBuilder { ctx: WasiCtxBuilder, + socket_permit_state: Option>, } impl InstanceBuilder { @@ -466,14 +502,23 @@ impl FactorInstanceBuilder for InstanceBuilder { type InstanceState = InstanceState; fn build(self) -> anyhow::Result { - let InstanceBuilder { ctx: mut wasi_ctx } = self; + let InstanceBuilder { + ctx: mut wasi_ctx, + socket_permit_state, + } = self; Ok(InstanceState { ctx: wasi_ctx.build(), + socket_permit_state, }) } } impl InstanceBuilder { + /// Sets the socket permit state for per-connection quota tracking. + pub fn set_socket_permit_state(&mut self, state: Arc) { + self.socket_permit_state = Some(state); + } + pub fn outbound_socket_addr_check(&mut self, check: F) where F: Fn(SocketAddr, SocketAddrUse) -> Fut + Send + Sync + Clone + 'static, @@ -496,4 +541,5 @@ impl InstanceBuilder { pub struct InstanceState { ctx: WasiCtx, + socket_permit_state: Option>, } diff --git a/crates/factor-wasi/src/sockets.rs b/crates/factor-wasi/src/sockets.rs new file mode 100644 index 0000000000..9980c9c2ff --- /dev/null +++ b/crates/factor-wasi/src/sockets.rs @@ -0,0 +1,538 @@ +//! Socket quota tracking and WASI socket host implementations. +//! +//! This module provides [`SocketPermitState`], [`SpinSocketsView`], and +//! [`SpinSockets`] — the types needed to intercept WASI TCP/UDP socket +//! creation and enforce a per-app cap on the number of concurrently open +//! sockets. + +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use spin_connection_semaphore::{ConnectionPermit, ConnectionSemaphore}; +use wasmtime::component::{HasData, Resource}; +use wasmtime_wasi::p2::bindings::sockets::network::{ + ErrorCode as SocketErrorCode, Host as NetworkHost, Network, +}; +use wasmtime_wasi::p2::bindings::sockets::tcp::{self as p2_tcp, IpSocketAddress, ShutdownType}; +use wasmtime_wasi::p2::bindings::sockets::tcp_create_socket as p2_tcp_create; +use wasmtime_wasi::p2::bindings::sockets::udp as p2_udp; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; +use wasmtime_wasi::p2::{DynInputStream, DynOutputStream, DynPollable}; +use wasmtime_wasi::sockets::{TcpSocket, UdpSocket, WasiSocketsCtxView}; + +/// Shared state for tracking per-socket semaphore permits. Permits are +/// acquired when a socket is allocated (at `start_connect` for TCP, at +/// `create_udp_socket` for UDP) and released when the socket resource is dropped. +pub struct SocketPermitState { + semaphore: ConnectionSemaphore, + /// Active permits keyed by socket resource rep, released when the resource is dropped. + active: Mutex>, +} + +impl SocketPermitState { + pub fn new(semaphore: ConnectionSemaphore) -> Arc { + Arc::new(Self { + semaphore, + active: Mutex::new(HashMap::new()), + }) + } +} + +/// A view over WASI socket state that carries an optional per-instance socket +/// permit store, enabling per-connection quota tracking. +pub struct SpinSocketsView<'a> { + pub(crate) inner: WasiSocketsCtxView<'a>, + pub(crate) permit_state: Option>, +} + +impl<'a> std::ops::Deref for SpinSocketsView<'a> { + type Target = WasiSocketsCtxView<'a>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl std::ops::DerefMut for SpinSocketsView<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +/// [`HasData`] accessor for [`SpinSocketsView`], used in place of [`WasiSockets`] +/// when registering TCP socket bindings so that `start_connect` and `drop` can +/// participate in socket quota tracking. +pub struct SpinSockets; + +impl HasData for SpinSockets { + type Data<'a> = SpinSocketsView<'a>; +} + +impl SpinSocketsView<'_> { + /// Attempts to acquire a connection permit from the semaphore. + /// + /// Returns `Ok(None)` when no quota is configured, `Ok(Some(permit))` on + /// success, or `Err(())` when the quota is exhausted. + /// + /// The returned permit is unregistered — call [`Self::register_permit`] once + /// the socket resource rep is known to tie its lifetime to the socket. + pub(crate) fn try_acquire(&self) -> Result, ()> { + let Some(state) = &self.permit_state else { + return Ok(None); + }; + state.semaphore.try_acquire().map(Some).ok_or(()) + } + + /// Registers `permit` under `socket_rep` so it is held until the socket is + /// dropped. No-op when `permit` is `None` (no quota configured). + pub(crate) fn register_permit(&self, socket_rep: u32, permit: Option) { + let (Some(state), Some(permit)) = (&self.permit_state, permit) else { + return; + }; + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .insert(socket_rep, permit); + } + + /// Releases the connection permit for `socket_rep`, if any. + pub(crate) fn release_permit(&self, socket_rep: u32) { + if let Some(state) = &self.permit_state { + state + .active + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(&socket_rep); + } + } +} + +impl p2_tcp::Host for SpinSocketsView<'_> {} + +impl p2_tcp::HostTcpSocket for SpinSocketsView<'_> { + async fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::start_bind(&mut self.inner, this, network, local_address).await + } + + fn finish_bind(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::finish_bind(&mut self.inner, this) + } + + async fn start_connect( + &mut self, + this: Resource, + network: Resource, + remote_address: IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + let socket_rep = this.rep(); + // Unlike outbound HTTP (which queues when its permit pool is exhausted), + // sockets fail immediately. Waiting would risk deadlock if a component + // holds sockets open across async yield points, and raw-socket callers + // are better positioned to implement their own retry logic. + let Ok(permit) = self.try_acquire() else { + tracing::warn!("TCP socket connection refused: connection quota exhausted"); + return Err(SocketErrorCode::NewSocketLimit.into()); + }; + let result = + p2_tcp::HostTcpSocket::start_connect(&mut self.inner, this, network, remote_address) + .await; + if result.is_ok() { + self.register_permit(socket_rep, permit); + } + // On error, `permit` is dropped here, automatically releasing the semaphore slot. + result + } + + fn finish_connect( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<(Resource, Resource)> + { + p2_tcp::HostTcpSocket::finish_connect(&mut self.inner, this) + } + + fn start_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::start_listen(&mut self.inner, this) + } + + fn finish_listen(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::finish_listen(&mut self.inner, this) + } + + fn accept( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<( + Resource, + Resource, + Resource, + )> { + p2_tcp::HostTcpSocket::accept(&mut self.inner, this) + } + + fn local_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::local_address(&mut self.inner, this) + } + + fn remote_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::remote_address(&mut self.inner, this) + } + + fn is_listening(&mut self, this: Resource) -> wasmtime::Result { + p2_tcp::HostTcpSocket::is_listening(&mut self.inner, this) + } + + fn address_family( + &mut self, + this: Resource, + ) -> wasmtime::Result { + p2_tcp::HostTcpSocket::address_family(&mut self.inner, this) + } + + fn set_listen_backlog_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_listen_backlog_size(&mut self.inner, this, value) + } + + fn keep_alive_enabled( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_enabled(&mut self.inner, this) + } + + fn set_keep_alive_enabled( + &mut self, + this: Resource, + value: bool, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_enabled(&mut self.inner, this, value) + } + + fn keep_alive_idle_time( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_idle_time(&mut self.inner, this) + } + + fn set_keep_alive_idle_time( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_idle_time(&mut self.inner, this, value) + } + + fn keep_alive_interval( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_interval(&mut self.inner, this) + } + + fn set_keep_alive_interval( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_interval(&mut self.inner, this, value) + } + + fn keep_alive_count( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::keep_alive_count(&mut self.inner, this) + } + + fn set_keep_alive_count( + &mut self, + this: Resource, + value: u32, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_keep_alive_count(&mut self.inner, this, value) + } + + fn hop_limit(&mut self, this: Resource) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::hop_limit(&mut self.inner, this) + } + + fn set_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_hop_limit(&mut self.inner, this, value) + } + + fn receive_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::receive_buffer_size(&mut self.inner, this) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_receive_buffer_size(&mut self.inner, this, value) + } + + fn send_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_tcp::HostTcpSocket::send_buffer_size(&mut self.inner, this) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::set_send_buffer_size(&mut self.inner, this, value) + } + + fn subscribe(&mut self, this: Resource) -> wasmtime::Result> { + p2_tcp::HostTcpSocket::subscribe(&mut self.inner, this) + } + + fn shutdown( + &mut self, + this: Resource, + shutdown_type: ShutdownType, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_tcp::HostTcpSocket::shutdown(&mut self.inner, this, shutdown_type) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.release_permit(this.rep()); + p2_tcp::HostTcpSocket::drop(&mut self.inner, this) + } +} + +impl NetworkHost for SpinSocketsView<'_> { + fn convert_error_code( + &mut self, + error: wasmtime_wasi::p2::SocketError, + ) -> wasmtime::Result { + NetworkHost::convert_error_code(&mut self.inner, error) + } + + fn network_error_code( + &mut self, + err: Resource, + ) -> wasmtime::Result> { + NetworkHost::network_error_code(&mut self.inner, err) + } +} + +impl wasmtime_wasi::p2::bindings::sockets::network::HostNetwork for SpinSocketsView<'_> { + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + wasmtime_wasi::p2::bindings::sockets::network::HostNetwork::drop(&mut self.inner, this) + } +} + +impl p2_tcp_create::Host for SpinSocketsView<'_> { + fn create_tcp_socket( + &mut self, + address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, + ) -> wasmtime_wasi::p2::SocketResult> { + p2_tcp_create::Host::create_tcp_socket(&mut self.inner, address_family) + } +} + +impl p2_udp::Host for SpinSocketsView<'_> {} + +impl p2_udp::HostUdpSocket for SpinSocketsView<'_> { + async fn start_bind( + &mut self, + this: Resource, + network: Resource, + local_address: p2_udp::IpSocketAddress, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::start_bind(&mut self.inner, this, network, local_address).await + } + + fn finish_bind( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::finish_bind(&mut self.inner, this) + } + + async fn stream( + &mut self, + this: Resource, + remote_address: Option, + ) -> wasmtime_wasi::p2::SocketResult<( + Resource, + Resource, + )> { + p2_udp::HostUdpSocket::stream(&mut self.inner, this, remote_address).await + } + + fn local_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::local_address(&mut self.inner, this) + } + + fn remote_address( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::remote_address(&mut self.inner, this) + } + + fn address_family( + &mut self, + this: Resource, + ) -> wasmtime::Result { + p2_udp::HostUdpSocket::address_family(&mut self.inner, this) + } + + fn unicast_hop_limit( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::unicast_hop_limit(&mut self.inner, this) + } + + fn set_unicast_hop_limit( + &mut self, + this: Resource, + value: u8, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_unicast_hop_limit(&mut self.inner, this, value) + } + + fn receive_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::receive_buffer_size(&mut self.inner, this) + } + + fn set_receive_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_receive_buffer_size(&mut self.inner, this, value) + } + + fn send_buffer_size( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostUdpSocket::send_buffer_size(&mut self.inner, this) + } + + fn set_send_buffer_size( + &mut self, + this: Resource, + value: u64, + ) -> wasmtime_wasi::p2::SocketResult<()> { + p2_udp::HostUdpSocket::set_send_buffer_size(&mut self.inner, this, value) + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostUdpSocket::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.release_permit(this.rep()); + p2_udp::HostUdpSocket::drop(&mut self.inner, this) + } +} + +impl p2_udp::HostIncomingDatagramStream for SpinSocketsView<'_> { + fn receive( + &mut self, + this: Resource, + max_results: u64, + ) -> wasmtime_wasi::p2::SocketResult> { + p2_udp::HostIncomingDatagramStream::receive(&mut self.inner, this, max_results) + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostIncomingDatagramStream::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + p2_udp::HostIncomingDatagramStream::drop(&mut self.inner, this) + } +} + +impl p2_udp::HostOutgoingDatagramStream for SpinSocketsView<'_> { + fn check_send( + &mut self, + this: Resource, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostOutgoingDatagramStream::check_send(&mut self.inner, this) + } + + async fn send( + &mut self, + this: Resource, + datagrams: Vec, + ) -> wasmtime_wasi::p2::SocketResult { + p2_udp::HostOutgoingDatagramStream::send(&mut self.inner, this, datagrams).await + } + + fn subscribe( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + p2_udp::HostOutgoingDatagramStream::subscribe(&mut self.inner, this) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + p2_udp::HostOutgoingDatagramStream::drop(&mut self.inner, this) + } +} + +impl p2_udp_create::Host for SpinSocketsView<'_> { + fn create_udp_socket( + &mut self, + address_family: wasmtime_wasi::p2::bindings::sockets::network::IpAddressFamily, + ) -> wasmtime_wasi::p2::SocketResult> { + // Check quota before allocating the socket resource. + // See the analogous comment in `start_connect` for why we fail + // immediately rather than waiting (as outbound HTTP does). + let Ok(permit) = self.try_acquire() else { + tracing::warn!("UDP socket creation refused: connection quota exhausted"); + return Err(SocketErrorCode::NewSocketLimit.into()); + }; + let sock = p2_udp_create::Host::create_udp_socket(&mut self.inner, address_family)?; + self.register_permit(sock.rep(), permit); + Ok(sock) + } +} diff --git a/crates/factor-wasi/src/wasi_2023_10_18.rs b/crates/factor-wasi/src/wasi_2023_10_18.rs index cc0a86111d..79e2c4373b 100644 --- a/crates/factor-wasi/src/wasi_2023_10_18.rs +++ b/crates/factor-wasi/src/wasi_2023_10_18.rs @@ -1,3 +1,4 @@ +use crate::sockets::{SpinSockets, SpinSocketsView}; use spin_factors::anyhow::Result; use std::mem; use wasmtime::component::{Linker, Resource, ResourceTable}; @@ -6,8 +7,8 @@ use wasmtime_wasi::cli::{WasiCli, WasiCliCtxView}; use wasmtime_wasi::clocks::{WasiClocks, WasiClocksCtxView}; use wasmtime_wasi::filesystem::{WasiFilesystem, WasiFilesystemCtxView}; use wasmtime_wasi::p2::DynPollable; +use wasmtime_wasi::p2::bindings::sockets::udp_create_socket as p2_udp_create; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; mod latest { pub use wasmtime_wasi::p2::bindings::*; @@ -126,7 +127,7 @@ pub fn add_to_linker( clocks_closure: fn(&mut T) -> WasiClocksCtxView<'_>, cli_closure: fn(&mut T) -> WasiCliCtxView<'_>, filesystem_closure: fn(&mut T) -> WasiFilesystemCtxView<'_>, - sockets_closure: fn(&mut T) -> WasiSocketsCtxView<'_>, + sockets_closure: fn(&mut T) -> SpinSocketsView<'_>, ) -> Result<()> where T: Send + 'static, @@ -150,13 +151,13 @@ where wasi::cli::terminal_stdin::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stdout::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stderr::add_to_linker::<_, WasiCli>(linker, cli_closure)?; - wasi::sockets::tcp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::tcp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::instance_network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; + wasi::sockets::tcp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::instance_network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::ip_name_lookup::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; Ok(()) } @@ -900,9 +901,9 @@ impl wasi::cli::terminal_output::HostTerminalOutput for WasiCliCtxView<'_> { } } -impl wasi::sockets::tcp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::tcp::Host for SpinSocketsView<'_> {} -impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp::HostTcpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -935,6 +936,10 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { network: Resource, remote_address: IpSocketAddress, ) -> wasmtime::Result> { + // Delegate to the P2 SpinSocketsView impl (passing `self`, not `&mut self.inner`). + // This snapshot uses the raw P2 TcpSocket type — the resource rep is the same at + // start_connect and drop time — so the P2 impl's quota acquire/register/release + // logic round-trips correctly without any wrapper-level bookkeeping here. convert_result( latest::sockets::tcp::HostTcpSocket::start_connect( self, @@ -1147,7 +1152,7 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp_create_socket::Host for SpinSocketsView<'_> { fn create_tcp_socket( &mut self, address_family: IpAddressFamily, @@ -1159,7 +1164,7 @@ impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::udp::Host for SpinSocketsView<'_> {} /// Between the snapshot of WASI that this file is implementing and the current /// implementation of WASI UDP sockets were redesigned slightly to deal with @@ -1180,7 +1185,7 @@ pub enum UdpSocket { impl UdpSocket { async fn finish_connect( - table: &mut WasiSocketsCtxView<'_>, + table: &mut SpinSocketsView<'_>, socket: &Resource, explicit: bool, ) -> wasmtime::Result> { @@ -1197,8 +1202,12 @@ impl UdpSocket { }; let borrow = Resource::new_borrow(new_socket.rep()); let result = convert_result( - latest::sockets::udp::HostUdpSocket::stream(table, borrow, addr.map(|a| a.into())) - .await, + latest::sockets::udp::HostUdpSocket::stream( + &mut table.inner, + borrow, + addr.map(|a| a.into()), + ) + .await, )?; let (incoming, outgoing) = match result { Ok(pair) => pair, @@ -1223,7 +1232,7 @@ impl UdpSocket { } } -impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostUdpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -1233,7 +1242,7 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { let socket = self.table.get(&self_)?.inner()?; convert_result( latest::sockets::udp::HostUdpSocket::start_bind( - self, + &mut self.inner, socket, network, local_address.into(), @@ -1248,7 +1257,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::finish_bind( - self, socket, + &mut self.inner, + socket, )) } @@ -1358,7 +1368,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::local_address( - self, socket, + &mut self.inner, + socket, )) } @@ -1368,13 +1379,15 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::remote_address( - self, socket, + &mut self.inner, + socket, )) } fn address_family(&mut self, self_: Resource) -> wasmtime::Result { let socket = self.table.get(&self_)?.inner()?; - latest::sockets::udp::HostUdpSocket::address_family(self, socket).map(|e| e.into()) + latest::sockets::udp::HostUdpSocket::address_family(&mut self.inner, socket) + .map(|e| e.into()) } fn ipv6_only( @@ -1398,7 +1411,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::unicast_hop_limit( - self, socket, + &mut self.inner, + socket, )) } @@ -1409,7 +1423,9 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::set_unicast_hop_limit( - self, socket, value, + &mut self.inner, + socket, + value, )) } @@ -1419,7 +1435,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::receive_buffer_size( - self, socket, + &mut self.inner, + socket, )) } @@ -1430,7 +1447,11 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result( - latest::sockets::udp::HostUdpSocket::set_receive_buffer_size(self, socket, value), + latest::sockets::udp::HostUdpSocket::set_receive_buffer_size( + &mut self.inner, + socket, + value, + ), ) } @@ -1440,7 +1461,8 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::send_buffer_size( - self, socket, + &mut self.inner, + socket, )) } @@ -1451,17 +1473,25 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { ) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; convert_result(latest::sockets::udp::HostUdpSocket::set_send_buffer_size( - self, socket, value, + &mut self.inner, + socket, + value, )) } fn subscribe(&mut self, self_: Resource) -> wasmtime::Result> { let socket = self.table.get(&self_)?.inner()?; - latest::sockets::udp::HostUdpSocket::subscribe(self, socket) + latest::sockets::udp::HostUdpSocket::subscribe(&mut self.inner, socket) } fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { + let socket_rep = rep.rep(); + // Delete before releasing: the only error case that matters is `HasChildren`, + // where the socket still exists and the permit must stay held. `NotPresent` + // (double-drop) is unreachable from a guest, and `release_permit` is idempotent + // anyway since `HashMap::remove` is a no-op for absent keys. let me = self.table.delete(rep)?; + self.release_permit(socket_rep); let socket = match me { UdpSocket::Initial(s) => s, UdpSocket::Connecting(s, _) => s, @@ -1470,49 +1500,63 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { incoming, outgoing, } => { - latest::sockets::udp::HostIncomingDatagramStream::drop(self, incoming)?; - latest::sockets::udp::HostOutgoingDatagramStream::drop(self, outgoing)?; + latest::sockets::udp::HostIncomingDatagramStream::drop(&mut self.inner, incoming)?; + latest::sockets::udp::HostOutgoingDatagramStream::drop(&mut self.inner, outgoing)?; socket } UdpSocket::Dummy => return Ok(()), }; - latest::sockets::udp::HostUdpSocket::drop(self, socket) + // Drop the inner P2 socket directly, bypassing quota tracking for rep R. + latest::sockets::udp::HostUdpSocket::drop(&mut self.inner, socket) } } -impl wasi::sockets::udp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp_create_socket::Host for SpinSocketsView<'_> { fn create_udp_socket( &mut self, address_family: IpAddressFamily, ) -> wasmtime::Result, SocketErrorCode>> { - let result = convert_result(latest::sockets::udp_create_socket::Host::create_udp_socket( - self, + // Cannot delegate to the P2 SpinSocketsView impl here (unlike TCP). This snapshot + // wraps the P2 UdpSocket in a custom UdpSocket enum stored in a separate resource + // table, so the outer wrapper rep (used at drop time) differs from the inner P2 + // socket rep (which the P2 impl would register the permit under). Delegating would + // cause release_permit at drop time to look up the wrong rep and silently leak the + // semaphore slot. Instead, quota is checked explicitly here and the permit is + // registered under the wrapper rep. + let Ok(permit) = self.try_acquire() else { + tracing::warn!("UDP socket creation refused: connection quota exhausted"); + return Ok(Err(SocketErrorCode::NewSocketLimit)); + }; + // Create the inner P2 socket via self.inner to avoid charging quota at the P2 level. + let result = convert_result(p2_udp_create::Host::create_udp_socket( + &mut self.inner, address_family.into(), ))?; let socket = match result { Ok(socket) => socket, Err(e) => return Ok(Err(e)), }; - let socket = self.table.push(UdpSocket::Initial(socket))?; - Ok(Ok(socket)) + let wrapped = self.table.push(UdpSocket::Initial(socket))?; + self.register_permit(wrapped.rep(), permit); + Ok(Ok(wrapped)) } } -impl wasi::sockets::instance_network::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::instance_network::Host for SpinSocketsView<'_> { fn instance_network(&mut self) -> wasmtime::Result> { - latest::sockets::instance_network::Host::instance_network(self) + latest::sockets::instance_network::Host::instance_network(&mut self.inner) } } -impl wasi::sockets::network::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::network::Host for SpinSocketsView<'_> {} -impl wasi::sockets::network::HostNetwork for WasiSocketsCtxView<'_> { +impl wasi::sockets::network::HostNetwork for SpinSocketsView<'_> { fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::network::HostNetwork::drop(self, rep) + latest::sockets::network::HostNetwork::drop(&mut self.inner, rep) } } -impl wasi::sockets::ip_name_lookup::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::Host for SpinSocketsView<'_> { fn resolve_addresses( &mut self, network: Resource, @@ -1521,19 +1565,22 @@ impl wasi::sockets::ip_name_lookup::Host for WasiSocketsCtxView<'_> { _include_unavailable: bool, ) -> wasmtime::Result, SocketErrorCode>> { convert_result(latest::sockets::ip_name_lookup::Host::resolve_addresses( - self, network, name, + &mut self.inner, + network, + name, )) } } -impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for SpinSocketsView<'_> { fn resolve_next_address( &mut self, self_: Resource, ) -> wasmtime::Result, SocketErrorCode>> { convert_result( latest::sockets::ip_name_lookup::HostResolveAddressStream::resolve_next_address( - self, self_, + &mut self.inner, + self_, ) .map(|e| e.map(|e| e.into())), ) @@ -1543,11 +1590,11 @@ impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxV &mut self, self_: Resource, ) -> wasmtime::Result> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(self, self_) + latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(&mut self.inner, self_) } fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(self, rep) + latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(&mut self.inner, rep) } } diff --git a/crates/factor-wasi/src/wasi_2023_11_10.rs b/crates/factor-wasi/src/wasi_2023_11_10.rs index 81de4c6c74..c296a4dad0 100644 --- a/crates/factor-wasi/src/wasi_2023_11_10.rs +++ b/crates/factor-wasi/src/wasi_2023_11_10.rs @@ -1,11 +1,11 @@ use super::wasi_2023_10_18::{convert, convert_result}; +use crate::sockets::{SpinSockets, SpinSocketsView}; use spin_factors::anyhow::Result; use wasmtime::component::{Linker, Resource, ResourceTable}; use wasmtime_wasi::cli::{WasiCli, WasiCliCtxView}; use wasmtime_wasi::clocks::{WasiClocks, WasiClocksCtxView}; use wasmtime_wasi::filesystem::{WasiFilesystem, WasiFilesystemCtxView}; use wasmtime_wasi::random::{WasiRandom, WasiRandomCtx}; -use wasmtime_wasi::sockets::{WasiSockets, WasiSocketsCtxView}; mod latest { pub use wasmtime_wasi::p2::bindings::*; @@ -119,7 +119,7 @@ pub fn add_to_linker( clocks_closure: fn(&mut T) -> WasiClocksCtxView<'_>, cli_closure: fn(&mut T) -> WasiCliCtxView<'_>, filesystem_closure: fn(&mut T) -> WasiFilesystemCtxView<'_>, - sockets_closure: fn(&mut T) -> WasiSocketsCtxView<'_>, + sockets_closure: fn(&mut T) -> SpinSocketsView<'_>, ) -> Result<()> where T: Send + 'static, @@ -144,13 +144,13 @@ where wasi::cli::terminal_stdin::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stdout::add_to_linker::<_, WasiCli>(linker, cli_closure)?; wasi::cli::terminal_stderr::add_to_linker::<_, WasiCli>(linker, cli_closure)?; - wasi::sockets::tcp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::tcp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::udp_create_socket::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::instance_network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::network::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; - wasi::sockets::ip_name_lookup::add_to_linker::<_, WasiSockets>(linker, sockets_closure)?; + wasi::sockets::tcp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::tcp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::udp_create_socket::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::instance_network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::network::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; + wasi::sockets::ip_name_lookup::add_to_linker::<_, SpinSockets>(linker, sockets_closure)?; Ok(()) } @@ -830,9 +830,9 @@ impl wasi::cli::terminal_output::HostTerminalOutput for WasiCliCtxView<'_> { } } -impl wasi::sockets::tcp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::tcp::Host for SpinSocketsView<'_> {} -impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp::HostTcpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -865,6 +865,10 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { network: Resource, remote_address: IpSocketAddress, ) -> wasmtime::Result> { + // Delegate to the P2 SpinSocketsView impl (passing `self`, not `&mut self.inner`). + // This snapshot uses the raw P2 TcpSocket type — the resource rep is the same at + // start_connect and drop time — so the P2 impl's quota acquire/register/release + // logic round-trips correctly without any wrapper-level bookkeeping here. convert_result( latest::sockets::tcp::HostTcpSocket::start_connect( self, @@ -1123,7 +1127,7 @@ impl wasi::sockets::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::tcp_create_socket::Host for SpinSocketsView<'_> { fn create_tcp_socket( &mut self, address_family: IpAddressFamily, @@ -1135,9 +1139,9 @@ impl wasi::sockets::tcp_create_socket::Host for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::udp::Host for SpinSocketsView<'_> {} -impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostUdpSocket for SpinSocketsView<'_> { async fn start_bind( &mut self, self_: Resource, @@ -1290,7 +1294,7 @@ impl wasi::sockets::udp::HostUdpSocket for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostOutgoingDatagramStream for SpinSocketsView<'_> { fn check_send( &mut self, self_: Resource, @@ -1325,7 +1329,7 @@ impl wasi::sockets::udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp::HostIncomingDatagramStream for SpinSocketsView<'_> { fn receive( &mut self, self_: Resource, @@ -1351,7 +1355,7 @@ impl wasi::sockets::udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::udp_create_socket::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::udp_create_socket::Host for SpinSocketsView<'_> { fn create_udp_socket( &mut self, address_family: IpAddressFamily, @@ -1363,40 +1367,43 @@ impl wasi::sockets::udp_create_socket::Host for WasiSocketsCtxView<'_> { } } -impl wasi::sockets::instance_network::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::instance_network::Host for SpinSocketsView<'_> { fn instance_network(&mut self) -> wasmtime::Result> { - latest::sockets::instance_network::Host::instance_network(self) + latest::sockets::instance_network::Host::instance_network(&mut self.inner) } } -impl wasi::sockets::network::Host for WasiSocketsCtxView<'_> {} +impl wasi::sockets::network::Host for SpinSocketsView<'_> {} -impl wasi::sockets::network::HostNetwork for WasiSocketsCtxView<'_> { +impl wasi::sockets::network::HostNetwork for SpinSocketsView<'_> { fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::network::HostNetwork::drop(self, rep) + latest::sockets::network::HostNetwork::drop(&mut self.inner, rep) } } -impl wasi::sockets::ip_name_lookup::Host for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::Host for SpinSocketsView<'_> { fn resolve_addresses( &mut self, network: Resource, name: String, ) -> wasmtime::Result, SocketErrorCode>> { convert_result(latest::sockets::ip_name_lookup::Host::resolve_addresses( - self, network, name, + &mut self.inner, + network, + name, )) } } -impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxView<'_> { +impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for SpinSocketsView<'_> { fn resolve_next_address( &mut self, self_: Resource, ) -> wasmtime::Result, SocketErrorCode>> { convert_result( latest::sockets::ip_name_lookup::HostResolveAddressStream::resolve_next_address( - self, self_, + &mut self.inner, + self_, ) .map(|e| e.map(|e| e.into())), ) @@ -1406,11 +1413,11 @@ impl wasi::sockets::ip_name_lookup::HostResolveAddressStream for WasiSocketsCtxV &mut self, self_: Resource, ) -> wasmtime::Result> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(self, self_) + latest::sockets::ip_name_lookup::HostResolveAddressStream::subscribe(&mut self.inner, self_) } fn drop(&mut self, rep: Resource) -> wasmtime::Result<()> { - latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(self, rep) + latest::sockets::ip_name_lookup::HostResolveAddressStream::drop(&mut self.inner, rep) } } diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index b5b2f6b1b8..f164d0dd50 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -79,6 +79,46 @@ impl ResolvedRuntimeConfig { summaries.push(format!("[llm_compute: {ty}")); } } + // [outbound_networking: max_total_connections=N] + if let Some(table) = self + .toml + .get("outbound_networking") + .and_then(Value::as_table) + { + if let Some(max) = table + .get("max_total_connections") + .and_then(Value::as_integer) + { + summaries.push(format!( + "[outbound_networking: max_total_connections={max}]" + )); + } + } + // [outbound_redis: max_connections=N], [outbound_pg: max_connections=N], [outbound_mysql: max_connections=N], [outbound_mqtt: max_connections=N], [outbound_http: max_connections=N] + for key in [ + "outbound_redis", + "outbound_pg", + "outbound_mysql", + "outbound_mqtt", + "outbound_http", + ] { + if let Some(table) = self.toml.get(key).and_then(Value::as_table) { + if let Some(max) = table.get("max_connections").and_then(Value::as_integer) { + summaries.push(format!("[{key}: max_connections={max}]")); + } + } + } + // [outbound_http: max_concurrent_requests=N (deprecated)] + if let Some(table) = self.toml.get("outbound_http").and_then(Value::as_table) { + if let Some(max) = table + .get("max_concurrent_requests") + .and_then(Value::as_integer) + { + summaries.push(format!( + "[outbound_http: max_concurrent_requests={max} (deprecated, use max_connections)]" + )); + } + } if !summaries.is_empty() { let summaries = summaries.join(", "); let from_path = runtime_config_path @@ -350,14 +390,18 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_pg::runtime_config::spin::config_from_table(&self.toml.table) } } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_mysql::runtime_config::spin::config_from_table(&self.toml.table) } } @@ -368,8 +412,10 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config( + &mut self, + ) -> anyhow::Result::RuntimeConfig>> { + spin_factor_outbound_redis::runtime_config::spin::config_from_table(&self.toml.table) } }