diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b715cd..7232304 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,19 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Added + +- Command line option `--connections-per-ip` that allows limiting the number of connections per ip address. Default is unlimited ([#22]) + ### Fixed - Raise `ffmpeg` errors as early as possible, e.g. when the `ffmpeg` command is not found +[#22]: https://github.com/sbernauer/breakwater/pull/22 + ## [0.13.0] - 2024-05-15 -## Added +### Added - Also release binary for `aarch64-apple-darwin` diff --git a/breakwater/src/cli_args.rs b/breakwater/src/cli_args.rs index f3a8ed0..caf7dcc 100644 --- a/breakwater/src/cli_args.rs +++ b/breakwater/src/cli_args.rs @@ -70,4 +70,8 @@ pub struct CliArgs { #[cfg(feature = "vnc")] #[clap(short, long, default_value_t = 5900)] pub vnc_port: u16, + + /// Allow only a certain number of connections per ip address + #[clap(short, long)] + pub connections_per_ip: Option, } diff --git a/breakwater/src/main.rs b/breakwater/src/main.rs index 7f01fb4..9ef19a0 100644 --- a/breakwater/src/main.rs +++ b/breakwater/src/main.rs @@ -114,7 +114,7 @@ async fn main() -> Result<(), Error> { statistics_save_mode, ); - let server = Server::new( + let mut server = Server::new( &args.listen_address, Arc::clone(&fb), statistics_tx.clone(), @@ -124,6 +124,7 @@ async fn main() -> Result<(), Error> { .context(InvalidNetworkBufferSizeSnafu { network_buffer_size: args.network_buffer_size, })?, + args.connections_per_ip, ) .await .context(StartPixelflutServerSnafu)?; diff --git a/breakwater/src/server.rs b/breakwater/src/server.rs index 8db2b0f..c4c3cb1 100644 --- a/breakwater/src/server.rs +++ b/breakwater/src/server.rs @@ -1,3 +1,5 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::{cmp::min, net::IpAddr, sync::Arc, time::Duration}; use breakwater_core::framebuffer::FrameBuffer; @@ -42,6 +44,8 @@ pub struct Server { fb: Arc, statistics_tx: mpsc::Sender, network_buffer_size: usize, + connections_per_ip: HashMap, + max_connections_per_ip: Option, } impl Server { @@ -50,6 +54,7 @@ impl Server { fb: Arc, statistics_tx: mpsc::Sender, network_buffer_size: usize, + max_connections_per_ip: Option, ) -> Result { let listener = TcpListener::bind(listen_address) .await @@ -61,23 +66,52 @@ impl Server { fb, statistics_tx, network_buffer_size, + connections_per_ip: HashMap::new(), + max_connections_per_ip, }) } - pub async fn start(&self) -> Result<(), Error> { + pub async fn start(&mut self) -> Result<(), Error> { + let (connection_dropped_tx, mut connection_dropped_rx) = + mpsc::unbounded_channel::(); + let connection_dropped_tx = self.max_connections_per_ip.map(|_| connection_dropped_tx); loop { - let (socket, socket_addr) = self + let (mut socket, socket_addr) = self .listener .accept() .await .context(AcceptNewClientConnectionSnafu)?; + + // If connections are unlimited, will execute one try_recv per new connection + while let Ok(ip) = connection_dropped_rx.try_recv() { + if let Entry::Occupied(mut o) = self.connections_per_ip.entry(ip) { + let connections = o.get_mut(); + *connections -= 1; + if *connections == 0 { + o.remove_entry(); + } + } + } + // If you connect via IPv4 you often show up as embedded inside an IPv6 address // Extracting the embedded information here, so we get the real (TM) address let ip = socket_addr.ip().to_canonical(); + if let Some(limit) = self.max_connections_per_ip { + let current_connections = self.connections_per_ip.entry(ip).or_default(); + if *current_connections < limit { + *current_connections += 1; + } else { + // Errors if session is dropped prematurely + let _ = socket.shutdown().await; + continue; + } + }; + let fb_for_thread = Arc::clone(&self.fb); let statistics_tx_for_thread = self.statistics_tx.clone(); let network_buffer_size = self.network_buffer_size; + let connection_dropped_tx_clone = connection_dropped_tx.clone(); tokio::spawn(async move { handle_connection( socket, @@ -85,6 +119,7 @@ impl Server { fb_for_thread, statistics_tx_for_thread, network_buffer_size, + connection_dropped_tx_clone, ) .await }); @@ -98,6 +133,7 @@ pub async fn handle_connection( fb: Arc, statistics_tx: mpsc::Sender, network_buffer_size: usize, + connection_dropped_tx: Option>, ) -> Result<(), Error> { debug!("Handling connection from {ip}"); @@ -195,5 +231,10 @@ pub async fn handle_connection( .await .context(WriteToStatisticsChannelSnafu)?; + if let Some(tx) = connection_dropped_tx { + // Will fail if the server thread ends before the client thread + let _ = tx.send(ip); + } + Ok(()) } diff --git a/breakwater/src/tests.rs b/breakwater/src/tests.rs index 521ce18..dc49c92 100644 --- a/breakwater/src/tests.rs +++ b/breakwater/src/tests.rs @@ -60,6 +60,7 @@ async fn test_correct_responses_to_general_commands( fb, statistics_channel.0, DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap(); @@ -125,6 +126,7 @@ async fn test_setting_pixel( fb, statistics_channel.0, DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap(); @@ -152,6 +154,7 @@ async fn test_safe( fb.clone(), statistics_channel.0, DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap(); @@ -225,6 +228,7 @@ async fn test_drawing_rect( Arc::clone(&fb), statistics_channel.0.clone(), DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap(); @@ -238,6 +242,7 @@ async fn test_drawing_rect( Arc::clone(&fb), statistics_channel.0.clone(), DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap(); @@ -251,6 +256,7 @@ async fn test_drawing_rect( Arc::clone(&fb), statistics_channel.0.clone(), DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap(); @@ -264,6 +270,7 @@ async fn test_drawing_rect( Arc::clone(&fb), statistics_channel.0.clone(), DEFAULT_NETWORK_BUFFER_SIZE, + None, ) .await .unwrap();