Skip to content

Commit

Permalink
SCM sockets transmit listeners in binary format
Browse files Browse the repository at this point in the history
  • Loading branch information
Keksoj committed Feb 2, 2024
1 parent a6ffebe commit 6d43eb1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
10 changes: 10 additions & 0 deletions command/src/command.proto
Original file line number Diff line number Diff line change
Expand Up @@ -674,4 +674,14 @@ message ServerConfig {
required uint64 command_buffer_size = 13 [default = 1000000];
required uint64 max_command_buffer_size = 14 [default = 2000000];
optional ServerMetricsConfig metrics = 15;
}

// Addresses of listeners, passed to new workers
message ListenersCount {
// socket addresses of HTTP listeners
repeated string http = 1;
// socket addresses of HTTPS listeners
repeated string tls = 2;
// socket addresses of TCP listeners
repeated string tcp = 3;
}
64 changes: 37 additions & 27 deletions command/src/scm_socket.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::{
io::{IoSlice, IoSliceMut},
net::SocketAddr,
net::{AddrParseError, SocketAddr},
os::unix::{
io::{FromRawFd, IntoRawFd, RawFd},
net::UnixStream as StdUnixStream,
},
str::from_utf8,
};

use mio::net::TcpListener;
use nix::{cmsg_space, sys::socket};
use serde_json;
use prost::{DecodeError, Message};

use crate::proto::command::ListenersCount;

pub const MAX_FDS_OUT: usize = 200;
pub const MAX_BYTES_OUT: usize = 4096;
Expand All @@ -30,6 +31,13 @@ pub enum ScmSocketError {
InvalidCharSet(String),
#[error("Could not deserialize utf8 string into listeners: {0}")]
ListenerParse(String),
#[error("Wrong socket address {address}: {error}")]
WrongSocketAddress {
address: String,
error: AddrParseError,
},
#[error("error decoding the protobuf format of the listeners: {0}")]
DecodeError(DecodeError),
}

/// A unix socket specialized for file descriptor passing
Expand Down Expand Up @@ -80,14 +88,12 @@ impl ScmSocket {
/// Send listeners (socket addresses and file descriptors) via an scm socket
pub fn send_listeners(&self, listeners: &Listeners) -> Result<(), ScmSocketError> {
let listeners_count = ListenersCount {
http: listeners.http.iter().map(|t| t.0).collect(),
tls: listeners.tls.iter().map(|t| t.0).collect(),
tcp: listeners.tcp.iter().map(|t| t.0).collect(),
http: listeners.http.iter().map(|t| t.0.to_string()).collect(),
tls: listeners.tls.iter().map(|t| t.0.to_string()).collect(),
tcp: listeners.tcp.iter().map(|t| t.0.to_string()).collect(),
};

let message = serde_json::to_string(&listeners_count)
.map(|s| s.into_bytes())
.unwrap_or_else(|_| Vec::new());
let message = listeners_count.encode_length_delimited_to_vec();

let mut file_descriptors: Vec<RawFd> = Vec::new();

Expand All @@ -109,18 +115,18 @@ impl ScmSocket {

debug!("{} received :{:?}", self.fd, (size, file_descriptor_length));

let raw_listener_list = from_utf8(&buf[..size])
.map_err(|utf8_error| ScmSocketError::InvalidCharSet(utf8_error.to_string()))?;
let listeners_count = ListenersCount::decode_length_delimited(&buf[..size])
.map_err(|error| ScmSocketError::DecodeError(error))?;

let mut listeners_count = serde_json::from_str::<ListenersCount>(raw_listener_list)
.map_err(|error| ScmSocketError::ListenerParse(error.to_string()))?;
let mut http_addresses = parse_addresses(&listeners_count.http)?;
let mut tls_addresses = parse_addresses(&listeners_count.tls)?;
let mut tcp_addresses = parse_addresses(&listeners_count.tcp)?;

let mut index = 0;
let len = listeners_count.http.len();
let mut http = Vec::new();
http.extend(
listeners_count
.http
http_addresses
.drain(..)
.zip(received_fds[index..index + len].iter().cloned()),
);
Expand All @@ -129,17 +135,15 @@ impl ScmSocket {
let len = listeners_count.tls.len();
let mut tls = Vec::new();
tls.extend(
listeners_count
.tls
tls_addresses
.drain(..)
.zip(received_fds[index..index + len].iter().cloned()),
);

index += len;
let mut tcp = Vec::new();
tcp.extend(
listeners_count
.tcp
tcp_addresses
.drain(..)
.zip(received_fds[index..file_descriptor_length].iter().cloned()),
);
Expand Down Expand Up @@ -208,21 +212,14 @@ impl ScmSocket {
}
}

/// Socket addresses and file descriptors needed by a Proxy to start listening
/// Socket addresses and file descriptors of TCP sockets, needed by a Proxy to start listening
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq)]
pub struct Listeners {
pub http: Vec<(SocketAddr, RawFd)>,
pub tls: Vec<(SocketAddr, RawFd)>,
pub tcp: Vec<(SocketAddr, RawFd)>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
struct ListenersCount {
pub http: Vec<SocketAddr>,
pub tls: Vec<SocketAddr>,
pub tcp: Vec<SocketAddr>,
}

impl Listeners {
pub fn get_http(&mut self, addr: &SocketAddr) -> Option<RawFd> {
self.http
Expand Down Expand Up @@ -267,6 +264,19 @@ impl Listeners {
}
}

fn parse_addresses(addresses: &Vec<String>) -> Result<Vec<SocketAddr>, ScmSocketError> {
let mut parsed_addresses = Vec::new();
for address in addresses {
parsed_addresses.push(address.parse::<SocketAddr>().map_err(|error| {
ScmSocketError::WrongSocketAddress {
address: address.to_owned(),
error,
}
})?);
}
Ok(parsed_addresses)
}

#[cfg(test)]
mod tests {

Expand Down

0 comments on commit 6d43eb1

Please sign in to comment.