Skip to content

Commit

Permalink
feat: Add cli option to limit number of connections per ip (#22)
Browse files Browse the repository at this point in the history
* Add cli option to limit number of connections per ip
  • Loading branch information
fabi321 committed May 30, 2024
1 parent aaff580 commit a08b06d
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 4 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

### Added

- Command line option `--connections-per-ip` that allows limiting the number of connections per ip address. Default is unlimited ([#22])

### Fixed

- Raise `ffmpeg` errors as early as possible, e.g. when the `ffmpeg` command is not found

[#22]: https://github.com/sbernauer/breakwater/pull/22

## [0.13.0] - 2024-05-15

## Added
### Added

- Also release binary for `aarch64-apple-darwin`

Expand Down
4 changes: 4 additions & 0 deletions breakwater/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,8 @@ pub struct CliArgs {
#[cfg(feature = "vnc")]
#[clap(short, long, default_value_t = 5900)]
pub vnc_port: u16,

/// Allow only a certain number of connections per ip address
#[clap(short, long)]
pub connections_per_ip: Option<u64>,
}
3 changes: 2 additions & 1 deletion breakwater/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async fn main() -> Result<(), Error> {
statistics_save_mode,
);

let server = Server::new(
let mut server = Server::new(
&args.listen_address,
Arc::clone(&fb),
statistics_tx.clone(),
Expand All @@ -124,6 +124,7 @@ async fn main() -> Result<(), Error> {
.context(InvalidNetworkBufferSizeSnafu {
network_buffer_size: args.network_buffer_size,
})?,
args.connections_per_ip,
)
.await
.context(StartPixelflutServerSnafu)?;
Expand Down
45 changes: 43 additions & 2 deletions breakwater/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::{cmp::min, net::IpAddr, sync::Arc, time::Duration};

use breakwater_core::framebuffer::FrameBuffer;
Expand Down Expand Up @@ -42,6 +44,8 @@ pub struct Server {
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
connections_per_ip: HashMap<IpAddr, u64>,
max_connections_per_ip: Option<u64>,
}

impl Server {
Expand All @@ -50,6 +54,7 @@ impl Server {
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
max_connections_per_ip: Option<u64>,
) -> Result<Self, Error> {
let listener = TcpListener::bind(listen_address)
.await
Expand All @@ -61,30 +66,60 @@ impl Server {
fb,
statistics_tx,
network_buffer_size,
connections_per_ip: HashMap::new(),
max_connections_per_ip,
})
}

pub async fn start(&self) -> Result<(), Error> {
pub async fn start(&mut self) -> Result<(), Error> {
let (connection_dropped_tx, mut connection_dropped_rx) =
mpsc::unbounded_channel::<IpAddr>();
let connection_dropped_tx = self.max_connections_per_ip.map(|_| connection_dropped_tx);
loop {
let (socket, socket_addr) = self
let (mut socket, socket_addr) = self
.listener
.accept()
.await
.context(AcceptNewClientConnectionSnafu)?;

// If connections are unlimited, will execute one try_recv per new connection
while let Ok(ip) = connection_dropped_rx.try_recv() {
if let Entry::Occupied(mut o) = self.connections_per_ip.entry(ip) {
let connections = o.get_mut();
*connections -= 1;
if *connections == 0 {
o.remove_entry();
}
}
}

// If you connect via IPv4 you often show up as embedded inside an IPv6 address
// Extracting the embedded information here, so we get the real (TM) address
let ip = socket_addr.ip().to_canonical();

if let Some(limit) = self.max_connections_per_ip {
let current_connections = self.connections_per_ip.entry(ip).or_default();
if *current_connections < limit {
*current_connections += 1;
} else {
// Errors if session is dropped prematurely
let _ = socket.shutdown().await;
continue;
}
};

let fb_for_thread = Arc::clone(&self.fb);
let statistics_tx_for_thread = self.statistics_tx.clone();
let network_buffer_size = self.network_buffer_size;
let connection_dropped_tx_clone = connection_dropped_tx.clone();
tokio::spawn(async move {
handle_connection(
socket,
ip,
fb_for_thread,
statistics_tx_for_thread,
network_buffer_size,
connection_dropped_tx_clone,
)
.await
});
Expand All @@ -98,6 +133,7 @@ pub async fn handle_connection(
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
connection_dropped_tx: Option<mpsc::UnboundedSender<IpAddr>>,
) -> Result<(), Error> {
debug!("Handling connection from {ip}");

Expand Down Expand Up @@ -195,5 +231,10 @@ pub async fn handle_connection(
.await
.context(WriteToStatisticsChannelSnafu)?;

if let Some(tx) = connection_dropped_tx {
// Will fail if the server thread ends before the client thread
let _ = tx.send(ip);
}

Ok(())
}
7 changes: 7 additions & 0 deletions breakwater/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async fn test_correct_responses_to_general_commands(
fb,
statistics_channel.0,
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand Down Expand Up @@ -125,6 +126,7 @@ async fn test_setting_pixel(
fb,
statistics_channel.0,
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand Down Expand Up @@ -152,6 +154,7 @@ async fn test_safe(
fb.clone(),
statistics_channel.0,
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand Down Expand Up @@ -225,6 +228,7 @@ async fn test_drawing_rect(
Arc::clone(&fb),
statistics_channel.0.clone(),
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand All @@ -238,6 +242,7 @@ async fn test_drawing_rect(
Arc::clone(&fb),
statistics_channel.0.clone(),
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand All @@ -251,6 +256,7 @@ async fn test_drawing_rect(
Arc::clone(&fb),
statistics_channel.0.clone(),
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand All @@ -264,6 +270,7 @@ async fn test_drawing_rect(
Arc::clone(&fb),
statistics_channel.0.clone(),
DEFAULT_NETWORK_BUFFER_SIZE,
None,
)
.await
.unwrap();
Expand Down

0 comments on commit a08b06d

Please sign in to comment.