From a944dbcff2de49e45d5fa99edb227c85a5c3d40f Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Tue, 28 Mar 2023 15:47:09 -0400 Subject: [PATCH] feat(console): add support for Unix domain sockets (#388) Add support for console connections that use Unix domain sockets rather than TCP. Closes #296. Co-authored-by: Eliza Weisman --- Cargo.lock | 1 + console-subscriber/Cargo.toml | 2 +- console-subscriber/examples/uds.rs | 40 ++++++++++ console-subscriber/src/builder.rs | 113 ++++++++++++++++++++++++++--- console-subscriber/src/lib.rs | 33 ++++++--- tokio-console/Cargo.toml | 1 + tokio-console/args.example | 3 + tokio-console/src/config.rs | 4 + tokio-console/src/conn.rs | 33 ++++++++- 9 files changed, 208 insertions(+), 22 deletions(-) create mode 100644 console-subscriber/examples/uds.rs diff --git a/Cargo.lock b/Cargo.lock index be4f27905..02156e518 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,6 +1454,7 @@ dependencies = [ "tokio", "toml", "tonic", + "tower", "tracing", "tracing-journald", "tracing-subscriber", diff --git a/console-subscriber/Cargo.toml b/console-subscriber/Cargo.toml index 3ed323c32..3c044c8d7 100644 --- a/console-subscriber/Cargo.toml +++ b/console-subscriber/Cargo.toml @@ -33,7 +33,7 @@ env-filter = ["tracing-subscriber/env-filter"] crossbeam-utils = "0.8.7" tokio = { version = "^1.21", features = ["sync", "time", "macros", "tracing"] } -tokio-stream = "0.1" +tokio-stream = { version = "0.1", features = ["net"] } thread_local = "1.1.3" console-api = { version = "0.4.0", path = "../console-api", features = ["transport"] } tonic = { version = "0.8", features = ["transport"] } diff --git a/console-subscriber/examples/uds.rs b/console-subscriber/examples/uds.rs new file mode 100644 index 000000000..03f6f2d4a --- /dev/null +++ b/console-subscriber/examples/uds.rs @@ -0,0 +1,40 @@ +//! Demonstrates serving the console API over a [Unix domain socket] (UDS) +//! connection, rather than over TCP. +//! +//! Note that this example only works on Unix operating systems that +//! support UDS, such as Linux, BSDs, and macOS. +//! +//! [Unix domain socket]: https://en.wikipedia.org/wiki/Unix_domain_socket + +#[cfg(unix)] +use { + std::time::Duration, + tokio::{fs, task, time}, + tracing::info, +}; + +#[cfg(unix)] +#[tokio::main] +async fn main() -> Result<(), Box> { + let cwd = fs::canonicalize(".").await?; + let addr = cwd.join("console-server"); + console_subscriber::ConsoleLayer::builder() + .server_addr(&*addr) + .init(); + info!( + "listening for console connections at file://localhost{}", + addr.display() + ); + task::Builder::default() + .name("sleepy") + .spawn(async move { time::sleep(Duration::from_secs(90)).await }) + .unwrap() + .await?; + + Ok(()) +} + +#[cfg(not(unix))] +fn main() { + panic!("only supported on Unix platforms") +} diff --git a/console-subscriber/src/builder.rs b/console-subscriber/src/builder.rs index 2c9fb0cc6..1e0819dde 100644 --- a/console-subscriber/src/builder.rs +++ b/console-subscriber/src/builder.rs @@ -1,6 +1,8 @@ use super::{ConsoleLayer, Server}; +#[cfg(unix)] +use std::path::Path; use std::{ - net::{SocketAddr, ToSocketAddrs}, + net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, path::PathBuf, thread, time::Duration, @@ -32,7 +34,7 @@ pub struct Builder { pub(crate) retention: Duration, /// The address on which to serve the RPC server. - pub(super) server_addr: SocketAddr, + pub(super) server_addr: ServerAddr, /// If and where to save a recording of the events. pub(super) recording_path: Option, @@ -58,7 +60,7 @@ impl Default for Builder { publish_interval: ConsoleLayer::DEFAULT_PUBLISH_INTERVAL, retention: ConsoleLayer::DEFAULT_RETENTION, poll_duration_max: ConsoleLayer::DEFAULT_POLL_DURATION_MAX, - server_addr: SocketAddr::new(Server::DEFAULT_IP, Server::DEFAULT_PORT), + server_addr: ServerAddr::Tcp(SocketAddr::new(Server::DEFAULT_IP, Server::DEFAULT_PORT)), recording_path: None, filter_env_var: "RUST_LOG".to_string(), self_trace: false, @@ -137,8 +139,38 @@ impl Builder { /// before falling back on constructing a socket address from those /// defaults. /// + /// The socket address can be either a TCP socket address or a + /// [Unix domain socket] (UDS) address. Unix domain sockets are only + /// supported on Unix-compatible operating systems, such as Linux, BSDs, + /// and macOS. + /// + /// Each call to this method will overwrite the previously set value. + /// + /// # Examples + /// + /// Connect to the TCP address `localhost:1234`: + /// + /// ``` + /// # use console_subscriber::Builder; + /// use std::net::Ipv4Addr; + /// let builder = Builder::default().server_addr((Ipv4Addr::LOCALHOST, 1234)); + /// ``` + /// + /// Connect to the UDS address `/tmp/tokio-console`: + /// + /// ``` + /// # use console_subscriber::Builder; + /// # #[cfg(unix)] + /// use std::path::Path; + /// + /// // Unix domain sockets are only available on Unix-compatible operating systems. + /// #[cfg(unix)] + /// let builder = Builder::default().server_addr(Path::new("/tmp/tokio-console")); + /// ``` + /// /// [environment variable]: `Builder::with_default_env` - pub fn server_addr(self, server_addr: impl Into) -> Self { + /// [Unix domain socket]: https://en.wikipedia.org/wiki/Unix_domain_socket + pub fn server_addr(self, server_addr: impl Into) -> Self { Self { server_addr: server_addr.into(), ..self @@ -231,11 +263,14 @@ impl Builder { } if let Ok(bind) = std::env::var("TOKIO_CONSOLE_BIND") { - self.server_addr = bind - .to_socket_addrs() - .expect("TOKIO_CONSOLE_BIND must be formatted as HOST:PORT, such as localhost:4321") - .next() - .expect("tokio console could not resolve TOKIO_CONSOLE_BIND"); + self.server_addr = ServerAddr::Tcp( + bind.to_socket_addrs() + .expect( + "TOKIO_CONSOLE_BIND must be formatted as HOST:PORT, such as localhost:4321", + ) + .next() + .expect("tokio console could not resolve TOKIO_CONSOLE_BIND"), + ); } if let Some(interval) = duration_from_env("TOKIO_CONSOLE_PUBLISH_INTERVAL") { @@ -456,6 +491,66 @@ impl Builder { } } +/// Specifies the address on which a [`Server`] should listen. +/// +/// This type is passed as an argument to the [`Builder::server_addr`] +/// method, and may be either a TCP socket address, or a [Unix domain socket] +/// (UDS) address. Unix domain sockets are only supported on Unix-compatible +/// operating systems, such as Linux, BSDs, and macOS. +/// +/// [`Server`]: crate::Server +/// [Unix domain socket]: https://en.wikipedia.org/wiki/Unix_domain_socket +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum ServerAddr { + /// A TCP address. + Tcp(SocketAddr), + /// A Unix socket address. + #[cfg(unix)] + Unix(PathBuf), +} + +impl From for ServerAddr { + fn from(addr: SocketAddr) -> ServerAddr { + ServerAddr::Tcp(addr) + } +} + +impl From for ServerAddr { + fn from(addr: SocketAddrV4) -> ServerAddr { + ServerAddr::Tcp(addr.into()) + } +} + +impl From for ServerAddr { + fn from(addr: SocketAddrV6) -> ServerAddr { + ServerAddr::Tcp(addr.into()) + } +} + +impl From<(I, u16)> for ServerAddr +where + I: Into, +{ + fn from(pieces: (I, u16)) -> ServerAddr { + ServerAddr::Tcp(pieces.into()) + } +} + +#[cfg(unix)] +impl From for ServerAddr { + fn from(path: PathBuf) -> ServerAddr { + ServerAddr::Unix(path) + } +} + +#[cfg(unix)] +impl<'a> From<&'a Path> for ServerAddr { + fn from(path: &'a Path) -> ServerAddr { + ServerAddr::Unix(path.to_path_buf()) + } +} + /// Initializes the console [tracing `Subscriber`][sub] and starts the console /// subscriber [`Server`] on its own background thread. /// diff --git a/console-subscriber/src/lib.rs b/console-subscriber/src/lib.rs index bde80347c..d91ee392a 100644 --- a/console-subscriber/src/lib.rs +++ b/console-subscriber/src/lib.rs @@ -5,7 +5,7 @@ use serde::Serialize; use std::{ cell::RefCell, fmt, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -13,7 +13,11 @@ use std::{ time::{Duration, Instant}, }; use thread_local::ThreadLocal; +#[cfg(unix)] +use tokio::net::UnixListener; use tokio::sync::{mpsc, oneshot}; +#[cfg(unix)] +use tokio_stream::wrappers::UnixListenerStream; use tracing_core::{ span::{self, Id}, subscriber::{self, Subscriber}, @@ -36,7 +40,7 @@ pub(crate) mod sync; mod visitors; use aggregator::Aggregator; -pub use builder::Builder; +pub use builder::{Builder, ServerAddr}; use callsites::Callsites; use record::Recorder; use stack::SpanStack; @@ -134,7 +138,7 @@ pub struct ConsoleLayer { /// [cli]: https://crates.io/crates/tokio-console pub struct Server { subscribe: mpsc::Sender, - addr: SocketAddr, + addr: ServerAddr, aggregator: Option, client_buffer: usize, } @@ -945,13 +949,22 @@ impl Server { .take() .expect("cannot start server multiple times"); let aggregate = spawn_named(aggregate.run(), "console::aggregate"); - let addr = self.addr; - let serve = builder - .add_service(proto::instrument::instrument_server::InstrumentServer::new( - self, - )) - .serve(addr); - let res = spawn_named(serve, "console::serve").await; + let addr = self.addr.clone(); + let router = builder.add_service( + proto::instrument::instrument_server::InstrumentServer::new(self), + ); + let res = match addr { + ServerAddr::Tcp(addr) => { + let serve = router.serve(addr); + spawn_named(serve, "console::serve").await + } + #[cfg(unix)] + ServerAddr::Unix(path) => { + let incoming = UnixListener::bind(path)?; + let serve = router.serve_with_incoming(UnixListenerStream::new(incoming)); + spawn_named(serve, "console::serve").await + } + }; aggregate.abort(); res?.map_err(Into::into) } diff --git a/tokio-console/Cargo.toml b/tokio-console/Cargo.toml index 6b0909312..a6f0de3f3 100644 --- a/tokio-console/Cargo.toml +++ b/tokio-console/Cargo.toml @@ -34,6 +34,7 @@ tokio = { version = "1", features = ["full", "rt-multi-thread"] } tonic = { version = "0.8", features = ["transport"] } futures = "0.3" tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } +tower = "0.4.12" tracing = "0.1" tracing-subscriber = { version = "0.3.0", features = ["env-filter"] } tracing-journald = { version = "0.2", optional = true } diff --git a/tokio-console/args.example b/tokio-console/args.example index db39c8f85..a6b59aaf8 100644 --- a/tokio-console/args.example +++ b/tokio-console/args.example @@ -7,6 +7,9 @@ ARGS: This may be an IP address and port, or a DNS name. + On Unix platforms, this may also be a URI with the `file` scheme that specifies the path + to a Unix domain socket, as in `file://localhost/path/to/socket`. + [default: http://127.0.0.1:6669] OPTIONS: diff --git a/tokio-console/src/config.rs b/tokio-console/src/config.rs index 52b980c2d..9513d7609 100644 --- a/tokio-console/src/config.rs +++ b/tokio-console/src/config.rs @@ -26,6 +26,10 @@ pub struct Config { /// /// This may be an IP address and port, or a DNS name. /// + /// On Unix platforms, this may also be a URI with the `file` scheme that + /// specifies the path to a Unix domain socket, as in + /// `file://localhost/path/to/socket`. + /// /// [default: http://127.0.0.1:6669] #[clap(value_hint = ValueHint::Url)] pub(crate) target_addr: Option, diff --git a/tokio-console/src/conn.rs b/tokio-console/src/conn.rs index faf42e9e0..6330b329b 100644 --- a/tokio-console/src/conn.rs +++ b/tokio-console/src/conn.rs @@ -5,7 +5,12 @@ use console_api::instrument::{ use console_api::tasks::TaskDetails; use futures::stream::StreamExt; use std::{error::Error, pin::Pin, time::Duration}; -use tonic::{transport::Channel, transport::Uri, Streaming}; +#[cfg(unix)] +use tokio::net::UnixStream; +use tonic::{ + transport::{Channel, Endpoint, Uri}, + Streaming, +}; #[derive(Debug)] pub struct Connection { @@ -78,7 +83,31 @@ impl Connection { tokio::time::sleep(backoff).await; } let try_connect = async { - let mut client = InstrumentClient::connect(self.target.clone()).await?; + let channel = match self.target.scheme_str() { + #[cfg(unix)] + Some("file") => { + // Dummy endpoint is ignored by the connector. + let endpoint = Endpoint::from_static("http://localhost"); + if !matches!(self.target.host(), None | Some("localhost")) { + return Err("cannot connect to non-localhost unix domain socket".into()); + } + let path = self.target.path().to_owned(); + endpoint + .connect_with_connector(tower::service_fn(move |_| { + UnixStream::connect(path.clone()) + })) + .await? + } + #[cfg(not(unix))] + Some("file") => { + return Err("unix domain sockets are not supported on this platform".into()); + } + _ => { + let endpoint = Endpoint::try_from(self.target.clone())?; + endpoint.connect().await? + } + }; + let mut client = InstrumentClient::new(channel); let request = tonic::Request::new(InstrumentRequest {}); let stream = Box::new(client.watch_updates(request).await?.into_inner()); Ok::>(State::Connected { client, stream })