From 9a91b13b77ec784efb9d6ae7c7a28a48ed321968 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Thu, 7 May 2026 15:51:14 +0000 Subject: [PATCH 1/2] chore: std::sync::Mutex -> parking_lot::Mutex This patch migrates from stdlib mutexes to parking lot mutexes. This probably provides some small perf win, but the main reason I want it is for the arc_lock method that I'm going to need in order to fix a race condition in how we handle attaches. --- Cargo.lock | 57 +++++++++++++++++++++- libshpool/Cargo.toml | 1 + libshpool/src/daemon/exit_notify.rs | 26 +++++----- libshpool/src/daemon/pager.rs | 15 +++--- libshpool/src/daemon/server.rs | 76 +++++++++++++++-------------- libshpool/src/daemon/shell.rs | 11 +++-- libshpool/src/daemon/show_motd.rs | 11 ++--- libshpool/src/daemon/ttl_reaper.rs | 5 +- libshpool/src/lib.rs | 4 +- libshpool/src/test_hooks.rs | 14 +++--- shpool/Cargo.toml | 1 + shpool/tests/daemon.rs | 4 +- shpool/tests/support/daemon.rs | 13 ++--- shpool/tests/support/mod.rs | 7 +-- 14 files changed, 150 insertions(+), 95 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e94c2e44..fa2c2bda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -327,7 +327,7 @@ checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.4.1", "windows-sys 0.52.0", ] @@ -554,6 +554,7 @@ dependencies = [ "nix", "notify", "ntest", + "parking_lot", "rmp-serde", "serde", "serde_derive", @@ -584,6 +585,15 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.22" @@ -735,6 +745,29 @@ version = "1.19.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.5.18", + "smallvec", + "windows-link", +] + [[package]] name = "paste" version = "1.0.15" @@ -822,6 +855,15 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags 2.4.2", +] + [[package]] name = "regex" version = "1.12.2" @@ -919,6 +961,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.228" @@ -1001,6 +1049,7 @@ dependencies = [ "libshpool", "nix", "ntest", + "parking_lot", "rand", "regex", "serde_json", @@ -1449,6 +1498,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/libshpool/Cargo.toml b/libshpool/Cargo.toml index 0d192d31..c80d78d4 100644 --- a/libshpool/Cargo.toml +++ b/libshpool/Cargo.toml @@ -45,6 +45,7 @@ 
strip-ansi-escapes = "0.2.0" # cleaning up strings for pager display notify = { version = "7", features = ["crossbeam-channel"] } # watch config file for updates libproc = "0.14.8" # sniffing shells by examining the subprocess daemonize = "0.5" # autodaemonization +parking_lot = "0.12" # faster more featureful sync primitives shpool-protocol = { version = "0.3.5", path = "../shpool-protocol" } # client-server protocol # rusty wrapper for unix apis diff --git a/libshpool/src/daemon/exit_notify.rs b/libshpool/src/daemon/exit_notify.rs index da898a4a..7e627e69 100644 --- a/libshpool/src/daemon/exit_notify.rs +++ b/libshpool/src/daemon/exit_notify.rs @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - sync::{Condvar, Mutex}, - time::Duration, -}; +use std::time::Duration; + +use parking_lot::{Condvar, Mutex}; #[derive(Debug)] pub struct ExitNotifier { @@ -30,7 +29,7 @@ impl ExitNotifier { /// Notify all waiters that the process has exited. pub fn notify_exit(&self, status: i32) { - let mut slot = self.slot.lock().unwrap(); + let mut slot = self.slot.lock(); *slot = Some(status); self.cond.notify_all(); } @@ -38,7 +37,7 @@ impl ExitNotifier { /// Wait for the process to exit, with an optional timeout /// to allow the caller to wake up periodically. pub fn wait(&self, timeout: Option) -> Option { - let slot = self.slot.lock().unwrap(); + let mut slot = self.slot.lock(); // If a thread waits on the exit status when the child has already // exited, we just want to immediately return. 
@@ -48,19 +47,16 @@ impl ExitNotifier { match timeout { Some(t) => { - // returns a lock result, so we want to unwrap - // to propagate the lock poisoning - let (exit_status, wait_res) = self - .cond - .wait_timeout_while(slot, t, |exit_status| exit_status.is_none()) - .unwrap(); - if wait_res.timed_out() { + if self.cond.wait_for(&mut slot, t).timed_out() { None } else { - *exit_status + *slot } } - None => *self.cond.wait_while(slot, |exit_status| exit_status.is_none()).unwrap(), + None => { + self.cond.wait(&mut slot); + *slot + } } } } diff --git a/libshpool/src/daemon/pager.rs b/libshpool/src/daemon/pager.rs index 625c44c8..e4c08539 100644 --- a/libshpool/src/daemon/pager.rs +++ b/libshpool/src/daemon/pager.rs @@ -40,13 +40,14 @@ use std::{ }, process, sync::atomic::{AtomicBool, Ordering}, - sync::{Arc, Mutex}, + sync::Arc, thread, time::{Duration, Instant}, }; use anyhow::{anyhow, Context}; use nix::{poll, sys::signal, unistd}; +use parking_lot::Mutex; use shpool_protocol::{Chunk, ChunkKind, TtySize}; use tracing::{error, info, instrument, span, trace, warn, Level}; @@ -111,7 +112,7 @@ impl Pager { let (tty_size_change_tx, tty_size_change_rx) = crossbeam_channel::bounded(0); let (tty_size_change_ack_tx, tty_size_change_ack_rx) = crossbeam_channel::bounded(0); { - let mut ctl_handle = ctl_slot.lock().unwrap(); + let mut ctl_handle = ctl_slot.lock(); if ctl_handle.is_some() { return Err(anyhow!("only one pager per session at a time allowed")); } @@ -195,7 +196,7 @@ impl Pager { { // register the new size so it will get returned - let mut tty_size = tty_size_ref.lock().unwrap(); + let mut tty_size = tty_size_ref.lock(); *tty_size = size; } @@ -220,7 +221,7 @@ impl Pager { ]; let nready = poll::poll(&mut poll_fds, POLL_MS).context("polling both streams")?; if pager_exited.load(Ordering::Relaxed) { - let tty_size = tty_size.lock().unwrap(); + let tty_size = tty_size.lock(); return Ok(tty_size.clone()); } if nready == 0 { @@ -289,13 +290,13 @@ impl Pager { // 
assume the pager proc just quit normally and the // timing was such that we didn't pick it up with our // exit watcher thread. - let tty_size = tty_size.lock().unwrap(); + let tty_size = tty_size.lock(); return Ok(tty_size.clone()); } if let Err(e) = pty_master.flush() { info!("Error flushing pager pty, nbd though: {:?}", e); // same logic as above - let tty_size = tty_size.lock().unwrap(); + let tty_size = tty_size.lock(); return Ok(tty_size.clone()); } } @@ -312,7 +313,7 @@ struct PagerCltGuard { impl std::ops::Drop for PagerCltGuard { fn drop(&mut self) { - let mut pager_ctl = self.ctl_slot.lock().unwrap(); + let mut pager_ctl = self.ctl_slot.lock(); // N.B. clobbering the handles here will cause the listening // thread to exit because it drops the senders. This ensures // that no callers can grab the lock on the ctl handles and diff --git a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index adbeb622..fc87c61f 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -26,13 +26,14 @@ use std::{ }, path::{Path, PathBuf}, process, - sync::{Arc, Mutex}, + sync::Arc, thread, time, time::{Duration, Instant}, }; use anyhow::{anyhow, Context}; use nix::unistd; +use parking_lot::Mutex; use shpool_protocol::{ AttachHeader, AttachReplyHeader, AttachStatus, ConnectHeader, DetachReply, DetachRequest, KillReply, KillRequest, ListReply, LogLevel, ResizeReply, Session, SessionMessageDetachReply, @@ -222,7 +223,7 @@ impl Server { let user_info = user::info().context("resolving user info")?; let shell_env = self.build_shell_env(&user_info, &header).context("building shell env")?; - let (child_exit_notifier, inner_to_stream, pager_ctl_slot, status) = + let (shell_results, status) = match self.select_shell_desc(stream, conn_id, &header, &user_info, &shell_env) { Ok(t) => t, Err(err) @@ -239,11 +240,9 @@ impl Server { self.link_ssh_auth_sock(&header).context("linking SSH_AUTH_SOCK")?; 
self.populate_session_env_file(&header).context("populating session env file")?; - if let (Some(child_exit_notifier), Some(inner), Some(pager_ctl_slot)) = - (child_exit_notifier, inner_to_stream, pager_ctl_slot) - { + if let Some((child_exit_notifier, inner, pager_ctl_slot)) = shell_results { let mut child_done = false; - let mut inner = inner.lock().unwrap(); + let mut inner = inner.lock(); let client_stream = match inner.client_stream.as_mut() { Some(s) => s, None => { @@ -309,7 +308,7 @@ impl Server { { let _s = span!(Level::INFO, "2_lock(shells)").entered(); - let mut shells = self.shells.lock().unwrap(); + let mut shells = self.shells.lock(); shells.remove(&header.name); } @@ -327,9 +326,9 @@ impl Server { // Client disconnected but shell is still running - set last_disconnected_at { let _s = span!(Level::INFO, "disconnect_lock(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); if let Some(session) = shells.get(&header.name) { - session.lifecycle_timestamps.lock().unwrap().last_disconnected_at = + session.lifecycle_timestamps.lock().last_disconnected_at = Some(time::SystemTime::now()); } } @@ -355,9 +354,11 @@ impl Server { user_info: &user::Info, shell_env: &[(OsString, OsString)], ) -> anyhow::Result<( - Option>, - Option>>, - Option>>>, + Option<( + Arc, + Arc>, + Arc>>, + )>, AttachStatus, )> { let warnings = vec![]; @@ -370,11 +371,11 @@ impl Server { { // we unwrap to propagate the poison as an unwind let _s = span!(Level::INFO, "select_shell_lock_1(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); if let Some(session) = shells.get(&header.name) { info!("found entry for '{}'", header.name); - if let Ok(mut inner) = session.inner.try_lock() { + if let Some(mut inner) = session.inner.try_lock() { let _s = span!(Level::INFO, "aquired_lock(session.inner)", s = header.name) .entered(); // We have an existing session in our table, but the subshell @@ -392,7 
+393,7 @@ impl Server { // the channel is still open so the subshell is still running info!("taking over existing session inner"); inner.client_stream = Some(stream.try_clone()?); - session.lifecycle_timestamps.lock().unwrap().last_connected_at = + session.lifecycle_timestamps.lock().last_connected_at = Some(time::SystemTime::now()); if inner @@ -466,12 +467,11 @@ impl Server { matches!(motd, MotdDisplayMode::Dump), )?; - session.lifecycle_timestamps.lock().unwrap().last_connected_at = - Some(time::SystemTime::now()); + session.lifecycle_timestamps.lock().last_connected_at = Some(time::SystemTime::now()); { // we unwrap to propagate the poison as an unwind let _s = span!(Level::INFO, "select_shell_lock_2(shells)").entered(); - let mut shells = self.shells.lock().unwrap(); + let mut shells = self.shells.lock(); shells.insert(header.name.clone(), Box::new(session)); } // fallthrough to bidi streaming @@ -481,20 +481,22 @@ impl Server { // we unwrap to propagate the poison as an unwind let _s = span!(Level::INFO, "select_shell_lock_3(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); // return a reference to the inner session so that // we can work with it without the global session // table lock held if let Some(session) = shells.get(&header.name) { Ok(( - Some(Arc::clone(&session.child_exit_notifier)), - Some(Arc::clone(&session.inner)), - Some(Arc::clone(&session.pager_ctl)), + Some(( + Arc::clone(&session.child_exit_notifier), + Arc::clone(&session.inner), + Arc::clone(&session.pager_ctl), + )), status, )) } else { - Ok((None, None, None, status)) + Ok((None, status)) } } @@ -555,11 +557,11 @@ impl Server { let mut not_attached_sessions = vec![]; { let _s = span!(Level::INFO, "lock(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); for session in request.sessions.into_iter() { if let Some(s) = shells.get(&session) { let _s = span!(Level::INFO, 
"lock(shell_to_client_ctl)", s = session).entered(); - let shell_to_client_ctl = s.shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = s.shell_to_client_ctl.lock(); shell_to_client_ctl .client_connection .send(shell::ClientConnectionMsg::Disconnect) @@ -572,7 +574,7 @@ impl Server { if let shell::ClientConnectionStatus::DetachNone = status { not_attached_sessions.push(session); } else { - s.lifecycle_timestamps.lock().unwrap().last_disconnected_at = + s.lifecycle_timestamps.lock().last_disconnected_at = Some(time::SystemTime::now()); } } else { @@ -614,7 +616,7 @@ impl Server { let mut not_found_sessions = vec![]; { let _s = span!(Level::INFO, "lock(shells)").entered(); - let mut shells = self.shells.lock().unwrap(); + let mut shells = self.shells.lock(); let mut to_remove = Vec::with_capacity(request.sessions.len()); for session in request.sessions.into_iter() { @@ -645,17 +647,17 @@ impl Server { #[instrument(skip_all)] fn handle_list(&self, mut stream: UnixStream) -> anyhow::Result<()> { let _s = span!(Level::INFO, "lock(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); let sessions: anyhow::Result> = shells .iter() .map(|(k, v)| { let status = match v.inner.try_lock() { - Ok(_) => SessionStatus::Disconnected, - Err(_) => SessionStatus::Attached, + Some(_) => SessionStatus::Disconnected, + None => SessionStatus::Attached, }; - let timestamps = v.lifecycle_timestamps.lock().unwrap(); + let timestamps = v.lifecycle_timestamps.lock(); let last_connected_at_unix_ms = timestamps .last_connected_at .map(|t| t.duration_since(time::UNIX_EPOCH).map(|d| d.as_millis() as i64)) @@ -704,7 +706,7 @@ impl Server { SessionMessageRequestPayload::Resize(resize_request) => { let pager_ctl = { let _s = span!(Level::INFO, "resize_lock_1(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); if let Some(session) = shells.get(&header.session_name) { 
Arc::clone(&session.pager_ctl) } else { @@ -712,7 +714,7 @@ impl Server { } }; let _s = span!(Level::INFO, "lock(pager_ctl)").entered(); - let pager_ctl = pager_ctl.lock().unwrap(); + let pager_ctl = pager_ctl.lock(); if let Some(pager_ctl) = pager_ctl.as_ref() { info!("resizing pager"); @@ -727,7 +729,7 @@ impl Server { } else { let shell_to_client_ctl = { let _s = span!(Level::INFO, "resize_lock_2(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); if let Some(session) = shells.get(&header.session_name) { Arc::clone(&session.shell_to_client_ctl) } else { @@ -735,7 +737,7 @@ impl Server { } }; let _s = span!(Level::INFO, "lock(shell_to_client_ctl)").entered(); - let shell_to_client_ctl = shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = shell_to_client_ctl.lock(); shell_to_client_ctl .tty_size_change @@ -752,7 +754,7 @@ impl Server { SessionMessageRequestPayload::Detach => { let shell_to_client_ctl = { let _s = span!(Level::INFO, "detach_lock(shells)").entered(); - let shells = self.shells.lock().unwrap(); + let shells = self.shells.lock(); if let Some(session) = shells.get(&header.session_name) { Arc::clone(&session.shell_to_client_ctl) } else { @@ -760,7 +762,7 @@ impl Server { } }; let _s = span!(Level::INFO, "lock(shell_to_client_ctl)").entered(); - let shell_to_client_ctl = shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = shell_to_client_ctl.lock(); shell_to_client_ctl .client_connection diff --git a/libshpool/src/daemon/shell.rs b/libshpool/src/daemon/shell.rs index ee74bc2c..7c40ae76 100644 --- a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -20,7 +20,7 @@ use std::{ os::unix::net::UnixStream, sync::{ atomic::{AtomicBool, Ordering}, - Arc, Mutex, + Arc, }, thread, time, time::Duration, @@ -28,6 +28,7 @@ use std::{ use anyhow::{anyhow, Context}; use nix::{poll, poll::PollFlags, sys::signal, unistd::Pid}; +use parking_lot::Mutex; use shpool_protocol::{Chunk, 
ChunkKind, TtySize}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; @@ -593,7 +594,7 @@ impl SessionInner { { let _s = span!(Level::INFO, "initial_attach_lock(shell_to_client_ctl)").entered(); - let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = self.shell_to_client_ctl.lock(); shell_to_client_ctl .client_connection .send_timeout( @@ -659,7 +660,7 @@ impl SessionInner { let c_done = child_done.load(Ordering::Acquire); { let _s = span!(Level::INFO, "disconnect_lock(shell_to_client_ctl)").entered(); - let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = self.shell_to_client_ctl.lock(); let send_res = shell_to_client_ctl.client_connection.send_timeout(if c_done { let exit_status = child_exit_notifier .wait(Some(Duration::from_secs(0))) @@ -891,7 +892,7 @@ impl SessionInner { return Ok(()); } { - let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = self.shell_to_client_ctl.lock(); match shell_to_client_ctl .heartbeat .send_timeout((), SHELL_TO_CLIENT_CTL_TIMEOUT) @@ -984,7 +985,7 @@ impl SessionInner { #[instrument(skip_all)] fn action_detach(&self) -> anyhow::Result<()> { - let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap(); + let shell_to_client_ctl = self.shell_to_client_ctl.lock(); shell_to_client_ctl .client_connection .send_timeout(ClientConnectionMsg::Disconnect, SHELL_TO_CLIENT_CTL_TIMEOUT) diff --git a/libshpool/src/daemon/show_motd.rs b/libshpool/src/daemon/show_motd.rs index 2a36bc30..a31f4618 100644 --- a/libshpool/src/daemon/show_motd.rs +++ b/libshpool/src/daemon/show_motd.rs @@ -12,15 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::{ - ffi::OsString, - io, - os::unix::net::UnixStream, - sync::{Arc, Mutex}, - time, -}; +use std::{ffi::OsString, io, os::unix::net::UnixStream, sync::Arc, time}; use anyhow::{anyhow, Context}; +use parking_lot::Mutex; use shpool_protocol::{Chunk, ChunkKind, TtySize}; use tracing::{info, instrument}; @@ -186,7 +181,7 @@ impl Debouncer { #[instrument(skip_all)] fn should_fire(&self) -> anyhow::Result { - let mut last_fired = self.last_fired.lock().unwrap(); + let mut last_fired = self.last_fired.lock(); if last_fired.elapsed()? >= self.dur { let old_ts: chrono::DateTime = (*last_fired).into(); *last_fired = time::SystemTime::now(); diff --git a/libshpool/src/daemon/ttl_reaper.rs b/libshpool/src/daemon/ttl_reaper.rs index f5664b75..b85b8c21 100644 --- a/libshpool/src/daemon/ttl_reaper.rs +++ b/libshpool/src/daemon/ttl_reaper.rs @@ -23,10 +23,11 @@ use std::{ cmp, collections::{BinaryHeap, HashMap}, - sync::{Arc, Mutex}, + sync::Arc, time::Instant, }; +use parking_lot::Mutex; use tracing::{info, span, warn, Level}; use super::shell; @@ -103,7 +104,7 @@ pub fn run( } let _s = span!(Level::INFO, "lock(shells)").entered(); - let mut shells = shells.lock().unwrap(); + let mut shells = shells.lock(); if let Some(sess) = shells.get(&reapable.session_name) { if let Err(e) = sess.kill() { warn!("error trying to kill '{}': {:?}", diff --git a/libshpool/src/lib.rs b/libshpool/src/lib.rs index 528fcd62..03758038 100644 --- a/libshpool/src/lib.rs +++ b/libshpool/src/lib.rs @@ -18,12 +18,12 @@ use std::{ hash::{Hash, Hasher}, io, path::PathBuf, - sync::{Mutex, MutexGuard}, }; use anyhow::{anyhow, Context}; use clap::{Parser, Subcommand}; pub use hooks::Hooks; +use parking_lot::{Mutex, MutexGuard}; use tracing::error; use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; @@ -266,7 +266,7 @@ impl<'writer> tracing_subscriber::fmt::MakeWriter<'writer> for LogWriterBuilder fn make_writer(&'writer self) -> Self::Writer { if let Some(log_file) = &self.log_file { - 
Box::new(MutexGuardWriter(log_file.lock().expect("poisoned"))) + Box::new(MutexGuardWriter(log_file.lock())) } else if self.is_daemon { Box::new(io::stderr()) } else { diff --git a/libshpool/src/test_hooks.rs b/libshpool/src/test_hooks.rs index b1598cfb..286b5d0c 100644 --- a/libshpool/src/test_hooks.rs +++ b/libshpool/src/test_hooks.rs @@ -23,16 +23,16 @@ use std::{ io::Write, os::unix::net::{UnixListener, UnixStream}, - sync::Mutex, time, }; use anyhow::{anyhow, Context}; +use parking_lot::Mutex; use tracing::{error, info}; #[cfg(feature = "test_hooks")] pub fn emit(event: &str) { - let sock_path = TEST_HOOK_SERVER.sock_path.lock().unwrap(); + let sock_path = TEST_HOOK_SERVER.sock_path.lock(); if sock_path.is_some() { TEST_HOOK_SERVER.emit_event(event); } @@ -83,7 +83,7 @@ impl TestHookServer { } pub fn set_socket_path(&self, path: String) { - let mut sock_path = self.sock_path.lock().unwrap(); + let mut sock_path = self.sock_path.lock(); *sock_path = Some(path); } @@ -91,7 +91,7 @@ impl TestHookServer { let mut sleep_dur = time::Duration::from_millis(5); for _ in 0..12 { { - let clients = self.clients.lock().unwrap(); + let clients = self.clients.lock(); if !clients.is_empty() { return Ok(()); } @@ -112,7 +112,7 @@ impl TestHookServer { pub fn start(&self) { let sock_path: String; { - let sock_path_m = self.sock_path.lock().unwrap(); + let sock_path_m = self.sock_path.lock(); match &*sock_path_m { Some(s) => { sock_path = String::from(s); @@ -141,7 +141,7 @@ impl TestHookServer { continue; } }; - let mut clients = self.clients.lock().unwrap(); + let mut clients = self.clients.lock(); clients.push(stream); } } @@ -149,7 +149,7 @@ impl TestHookServer { fn emit_event(&self, event: &str) { info!("emitting event '{}'", event); let event_line = format!("{event}\n"); - let clients = self.clients.lock().unwrap(); + let clients = self.clients.lock(); for mut client in clients.iter() { if let Err(e) = client.write_all(event_line.as_bytes()) { error!("error emitting '{}' 
event: {:?}", event, e); diff --git a/shpool/Cargo.toml b/shpool/Cargo.toml index 717e870c..13b7650b 100644 --- a/shpool/Cargo.toml +++ b/shpool/Cargo.toml @@ -31,6 +31,7 @@ regex = "1" # test assertions serde_json = "1" # json parsing ntest = "0.9" # test timeouts rand = "0.8" # tmp files for tests +parking_lot = "0.12" # faster more featureful sync primitives # rusty wrapper for unix apis [dependencies.nix] diff --git a/shpool/tests/daemon.rs b/shpool/tests/daemon.rs index bfd0917e..dd0c0592 100644 --- a/shpool/tests/daemon.rs +++ b/shpool/tests/daemon.rs @@ -225,11 +225,11 @@ fn hooks() -> anyhow::Result<()> { sh1_proc.run_cmd("exit")?; // 1 shell disconnect support::wait_until(|| { - let hook_records = daemon_proc.hook_records.as_ref().unwrap().lock().unwrap(); + let hook_records = daemon_proc.hook_records.as_ref().unwrap().lock(); Ok(!hook_records.shell_disconnects.is_empty()) })?; - let hook_records = daemon_proc.hook_records.as_ref().unwrap().lock().unwrap(); + let hook_records = daemon_proc.hook_records.as_ref().unwrap().lock(); eprintln!("hook_records: {hook_records:?}"); assert_eq!(hook_records.new_sessions[0], "sh1"); assert_eq!(hook_records.reattaches[0], "sh1"); diff --git a/shpool/tests/support/daemon.rs b/shpool/tests/support/daemon.rs index 91cd2aa7..2f670e1c 100644 --- a/shpool/tests/support/daemon.rs +++ b/shpool/tests/support/daemon.rs @@ -8,11 +8,12 @@ use std::{ path::{Path, PathBuf}, process, process::{Command, Stdio}, - sync::{Arc, Mutex}, + sync::Arc, thread, time, }; use anyhow::{anyhow, Context}; +use parking_lot::Mutex; use super::{attach, events::Events, shpool_bin, testdata_file, tmpdir, wait_until}; @@ -63,35 +64,35 @@ pub struct HooksRecorder { impl libshpool::Hooks for HooksRecorder { fn on_new_session(&self, session_name: &str) -> anyhow::Result<()> { eprintln!("on_new_session: {session_name}"); - let mut recs = self.records.lock().unwrap(); + let mut recs = self.records.lock(); recs.new_sessions.push(String::from(session_name)); 
Ok(()) } fn on_reattach(&self, session_name: &str) -> anyhow::Result<()> { eprintln!("on_reattach: {session_name}"); - let mut recs = self.records.lock().unwrap(); + let mut recs = self.records.lock(); recs.reattaches.push(String::from(session_name)); Ok(()) } fn on_busy(&self, session_name: &str) -> anyhow::Result<()> { eprintln!("on_busy: {session_name}"); - let mut recs = self.records.lock().unwrap(); + let mut recs = self.records.lock(); recs.busys.push(String::from(session_name)); Ok(()) } fn on_client_disconnect(&self, session_name: &str) -> anyhow::Result<()> { eprintln!("on_client_disconnect: {session_name}"); - let mut recs = self.records.lock().unwrap(); + let mut recs = self.records.lock(); recs.client_disconnects.push(String::from(session_name)); Ok(()) } fn on_shell_disconnect(&self, session_name: &str) -> anyhow::Result<()> { eprintln!("on_shell_disconnect: {session_name}"); - let mut recs = self.records.lock().unwrap(); + let mut recs = self.records.lock(); recs.shell_disconnects.push(String::from(session_name)); Ok(()) } diff --git a/shpool/tests/support/mod.rs b/shpool/tests/support/mod.rs index becd4457..067d5bca 100644 --- a/shpool/tests/support/mod.rs +++ b/shpool/tests/support/mod.rs @@ -8,7 +8,6 @@ use std::{ io::BufRead, path::{Path, PathBuf}, process::Command, - sync::Mutex, time, }; @@ -28,8 +27,10 @@ pub fn testdata_file>(file: P) -> PathBuf { } lazy_static::lazy_static! { - // cache the result and make sure we only ever compile once - static ref SHPOOL_BIN_PATH: Mutex> = Mutex::new(None); + // cache the result and make sure we only ever compile once. + // We can't use a parking lot mutex here because of the Sized + // constraint for static vars. + static ref SHPOOL_BIN_PATH: std::sync::Mutex> = std::sync::Mutex::new(None); } pub fn wait_until

(mut pred: P) -> anyhow::Result<()> From 8021f79df1c498da03f0cd50ed0a5d44a3a34ecb Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Thu, 7 May 2026 16:17:42 +0000 Subject: [PATCH 2/2] fix: race in attaching to existing sessions This patch fixes a race condition that could happen when two different procs raced to attach to a given session. Since we release the session lock briefly as part of the attach handshake, we were leaving a window open in which two racing attaches could trample on one another. The regression test in this patch shows the issue in detail. To fix the issue, I had select_shell_desc start returning the lock guard for the session so that there was no gap in which the lock was released. In order to do that, I needed to switch to use parking_lot mutexes, since it is impossible to return the lock guard for an Arc> using the stdlib types. --- libshpool/Cargo.toml | 2 +- libshpool/src/daemon/server.rs | 84 +++++++++++++++++++--------------- libshpool/src/test_hooks.rs | 82 +++++++++++++++++++++++++++++---- shpool/tests/regression.rs | 48 +++++++++++++++++++ shpool/tests/support/daemon.rs | 24 ++++++---- shpool/tests/support/events.rs | 22 ++++++--- 6 files changed, 201 insertions(+), 61 deletions(-) diff --git a/libshpool/Cargo.toml b/libshpool/Cargo.toml index c80d78d4..f7bb985a 100644 --- a/libshpool/Cargo.toml +++ b/libshpool/Cargo.toml @@ -45,7 +45,7 @@ strip-ansi-escapes = "0.2.0" # cleaning up strings for pager display notify = { version = "7", features = ["crossbeam-channel"] } # watch config file for updates libproc = "0.14.8" # sniffing shells by examining the subprocess daemonize = "0.5" # autodaemonization -parking_lot = "0.12" # faster more featureful sync primitives +parking_lot = { version = "0.12", features = ["arc_lock"] } # faster more featureful sync primitives shpool-protocol = { version = "0.3.5", path = "../shpool-protocol" } # client-server protocol # rusty wrapper for unix apis diff --git a/libshpool/src/daemon/server.rs 
b/libshpool/src/daemon/server.rs index fc87c61f..ce73a8d0 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -33,7 +33,7 @@ use std::{ use anyhow::{anyhow, Context}; use nix::unistd; -use parking_lot::Mutex; +use parking_lot::{ArcMutexGuard, Mutex, RawMutex}; use shpool_protocol::{ AttachHeader, AttachReplyHeader, AttachStatus, ConnectHeader, DetachReply, DetachRequest, KillReply, KillRequest, ListReply, LogLevel, ResizeReply, Session, SessionMessageDetachReply, @@ -223,6 +223,7 @@ impl Server { let user_info = user::info().context("resolving user info")?; let shell_env = self.build_shell_env(&user_info, &header).context("building shell env")?; + test_hooks::emit("handle-attach-before-select-shell"); let (shell_results, status) = match self.select_shell_desc(stream, conn_id, &header, &user_info, &shell_env) { Ok(t) => t, @@ -240,9 +241,10 @@ impl Server { self.link_ssh_auth_sock(&header).context("linking SSH_AUTH_SOCK")?; self.populate_session_env_file(&header).context("populating session env file")?; - if let Some((child_exit_notifier, inner, pager_ctl_slot)) = shell_results { + test_hooks::emit("handle-attach-before-inner-session-lock"); + + if let Some((child_exit_notifier, mut inner, pager_ctl_slot)) = shell_results { let mut child_done = false; - let mut inner = inner.lock(); let client_stream = match inner.client_stream.as_mut() { Some(s) => s, None => { @@ -356,13 +358,12 @@ impl Server { ) -> anyhow::Result<( Option<( Arc, - Arc>, + ArcMutexGuard, Arc>>, )>, AttachStatus, )> { let warnings = vec![]; - let mut status = AttachStatus::Attached { warnings: warnings.clone() }; // Critical section for the global shells lock. 
We only hold it // while grubbing about for an existing session, then release it @@ -375,7 +376,7 @@ impl Server { if let Some(session) = shells.get(&header.name) { info!("found entry for '{}'", header.name); - if let Some(mut inner) = session.inner.try_lock() { + if let Some(mut inner) = session.inner.try_lock_arc() { let _s = span!(Level::INFO, "aquired_lock(session.inner)", s = header.name) .entered(); // We have an existing session in our table, but the subshell @@ -405,7 +406,22 @@ impl Server { warn!( "child_exited chan unclosed, but shell->client thread has exited, clobbering with new subshell" ); - status = AttachStatus::Created { warnings }; + } else { + if let Err(err) = self.hooks.on_reattach(&header.name) { + warn!("reattach hook: {:?}", err); + } + // Immediately return so that we never give + // up the inner lock for a session we are + // reattaching to. We only want to give it + // up when creating a new session. + return Ok(( + Some(( + Arc::clone(&session.child_exit_notifier), + inner, + Arc::clone(&session.pager_ctl), + )), + AttachStatus::Attached { warnings }, + )); } // status is already attached @@ -416,7 +432,6 @@ impl Server { "stale inner, (child exited with status {}) clobbering with new subshell", exit_status ); - status = AttachStatus::Created { warnings }; } } @@ -432,7 +447,6 @@ impl Server { .map_err(|e| anyhow!("joining shell->client on reattach: {:?}", e))? .context("within shell->client thread on reattach")?; } - assert!(matches!(status, AttachStatus::Created { .. })); } // fallthrough to bidi streaming @@ -448,35 +462,29 @@ impl Server { } } else { info!("no existing '{}' session, creating new one", &header.name); - status = AttachStatus::Created { warnings }; } }; - if matches!(status, AttachStatus::Created { .. 
}) { - info!("creating new subshell"); - if let Err(err) = self.hooks.on_new_session(&header.name) { - warn!("new_session hook: {:?}", err); - } - let motd = self.config.get().motd.clone().unwrap_or_default(); - let session = self.spawn_subshell( - conn_id, - stream, - header, - user_info, - shell_env, - matches!(motd, MotdDisplayMode::Dump), - )?; - - session.lifecycle_timestamps.lock().last_connected_at = Some(time::SystemTime::now()); - { - // we unwrap to propagate the poison as an unwind - let _s = span!(Level::INFO, "select_shell_lock_2(shells)").entered(); - let mut shells = self.shells.lock(); - shells.insert(header.name.clone(), Box::new(session)); - } - // fallthrough to bidi streaming - } else if let Err(err) = self.hooks.on_reattach(&header.name) { - warn!("reattach hook: {:?}", err); + info!("creating new subshell"); + if let Err(err) = self.hooks.on_new_session(&header.name) { + warn!("new_session hook: {:?}", err); + } + let motd = self.config.get().motd.clone().unwrap_or_default(); + let session = self.spawn_subshell( + conn_id, + stream, + header, + user_info, + shell_env, + matches!(motd, MotdDisplayMode::Dump), + )?; + + session.lifecycle_timestamps.lock().last_connected_at = Some(time::SystemTime::now()); + { + // we unwrap to propagate the poison as an unwind + let _s = span!(Level::INFO, "select_shell_lock_2(shells)").entered(); + let mut shells = self.shells.lock(); + shells.insert(header.name.clone(), Box::new(session)); } // we unwrap to propagate the poison as an unwind @@ -490,13 +498,13 @@ impl Server { Ok(( Some(( Arc::clone(&session.child_exit_notifier), - Arc::clone(&session.inner), + session.inner.lock_arc(), Arc::clone(&session.pager_ctl), )), - status, + AttachStatus::Created { warnings }, )) } else { - Ok((None, status)) + Ok((None, AttachStatus::UnexpectedError(String::from("selecting session")))) } } diff --git a/libshpool/src/test_hooks.rs b/libshpool/src/test_hooks.rs index 286b5d0c..5fa4b2ab 100644 --- 
a/libshpool/src/test_hooks.rs +++ b/libshpool/src/test_hooks.rs @@ -21,20 +21,22 @@ // we publish a unix socket and then clients can listen for specific // named events in order to block until they have occurred. use std::{ - io::Write, + collections::HashSet, + io::{BufRead, Write}, os::unix::net::{UnixListener, UnixStream}, - time, + thread, time, }; use anyhow::{anyhow, Context}; -use parking_lot::Mutex; +use parking_lot::{Condvar, Mutex}; use tracing::{error, info}; #[cfg(feature = "test_hooks")] pub fn emit(event: &str) { - let sock_path = TEST_HOOK_SERVER.sock_path.lock(); - if sock_path.is_some() { + let has_sock = TEST_HOOK_SERVER.sock_path.lock().is_some(); + if has_sock { TEST_HOOK_SERVER.emit_event(event); + TEST_HOOK_SERVER.maybe_pause(event); } } @@ -75,11 +77,18 @@ lazy_static::lazy_static! { pub struct TestHookServer { sock_path: Mutex>, clients: Mutex>, + pending_pauses: Mutex>, + pause_cv: Condvar, } impl TestHookServer { fn new() -> Self { - TestHookServer { sock_path: Mutex::new(None), clients: Mutex::new(vec![]) } + TestHookServer { + sock_path: Mutex::new(None), + clients: Mutex::new(vec![]), + pending_pauses: Mutex::new(HashSet::new()), + pause_cv: Condvar::new(), + } } pub fn set_socket_path(&self, path: String) { @@ -141,8 +150,19 @@ impl TestHookServer { continue; } }; - let mut clients = self.clients.lock(); - clients.push(stream); + match stream.try_clone() { + Ok(stream_clone) => { + let mut clients = self.clients.lock(); + clients.push(stream); + + thread::spawn(move || { + TEST_HOOK_SERVER.handle_client(stream_clone); + }); + } + Err(e) => { + error!("error cloning test hook stream: {:?}", e); + } + } } } @@ -156,4 +176,50 @@ impl TestHookServer { } } } + + fn handle_client(&self, stream: UnixStream) { + let mut reader = std::io::BufReader::new(stream); + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line) { + Ok(0) => break, // EOF + Ok(_) => { + let parts: Vec<&str> = line.trim().splitn(2, ' 
').collect(); + if parts.len() == 2 { + let cmd = parts[0]; + let event = parts[1]; + match cmd { + "pause-at" => { + info!("test requested pause at '{}'", event); + self.pending_pauses.lock().insert(event.to_string()); + } + "release" => { + info!("test requested release of '{}'", event); + self.pending_pauses.lock().remove(event); + self.pause_cv.notify_all(); + } + _ => error!("unknown test hook command: {}", cmd), + } + } + } + Err(e) => { + error!("error reading from test hook client: {:?}", e); + break; + } + } + } + } + + fn maybe_pause(&self, event: &str) { + let mut pending = self.pending_pauses.lock(); + if pending.contains(event) { + info!("pausing at '{}'", event); + self.emit_event(&format!("paused-at {event}")); + while pending.contains(event) { + self.pause_cv.wait(&mut pending); + } + info!("resuming from '{}'", event); + } + } } diff --git a/shpool/tests/regression.rs b/shpool/tests/regression.rs index 91ca0e0c..7e2420a0 100644 --- a/shpool/tests/regression.rs +++ b/shpool/tests/regression.rs @@ -252,3 +252,51 @@ fn pager_eof_does_not_spin() -> anyhow::Result<()> { Ok(()) } + +/// Regression test for a race condition where concurrent attach attempts to +/// a disconnected session can clobber each other's stream, leading to +/// "no client stream" error in the daemon. +#[test] +#[timeout(30000)] +fn concurrent_attach_to_existing_session_race() -> anyhow::Result<()> { + let mut daemon_proc = support::daemon::Proc::new("norc.toml", DaemonArgs::default()) + .context("starting daemon proc")?; + + let mut attach_first = + daemon_proc.attach("main", Default::default()).context("starting first attach proc")?; + daemon_proc.await_event("daemon-bidi-stream-enter")?; + + attach_first.proc.kill()?; + daemon_proc.wait_until_list_matches(|out| out.contains("disconnected"))?; + + // Pause the daemon before it locks the session to trigger the race. 
+ daemon_proc.send_event_command("pause-at handle-attach-before-inner-session-lock")?; + + let _attach_a = daemon_proc.attach("main", Default::default()).context("starting attach A")?; + daemon_proc.await_event("paused-at handle-attach-before-inner-session-lock")?; + + let _attach_b = daemon_proc.attach("main", Default::default()).context("starting attach B")?; + daemon_proc.await_event("handle-attach-before-select-shell")?; + + // In the fixed code, attach B cannot finish shell selection because it blocks on the lock, so + // we need a sleep here to give it time to slip through in the broken code. + std::thread::sleep(Duration::from_millis(500)); + + daemon_proc.send_event_command("release handle-attach-before-inner-session-lock")?; + + // Disconnect B to trigger the error for A, which is using B's stream. + drop(_attach_b); + + std::thread::sleep(Duration::from_millis(500)); + + let log_content = fs::read_to_string(&daemon_proc.log_file)?; + + // On buggy code, the clobbered stream causes a "no client stream" error + // when the second attach tries to take over after the first exits. + assert!( + !log_content.contains("no client stream, should be impossible"), + "REGRESSION: Daemon logged 'no client stream' error!"
+ ); + + Ok(()) +} diff --git a/shpool/tests/support/daemon.rs b/shpool/tests/support/daemon.rs index 2f670e1c..b9ebb4c3 100644 --- a/shpool/tests/support/daemon.rs +++ b/shpool/tests/support/daemon.rs @@ -116,7 +116,7 @@ impl Proc { let test_hook_socket_path = tmp_dir.path().join("hook.sock"); let log_file = tmp_dir.path().join("daemon.log"); - eprintln!("spawning daemon proc with log {:?}", &log_file); + eprintln!("spawning daemon proc with log {:?}", log_file); let resolved_config = if config.as_ref().exists() { PathBuf::from(config.as_ref()) @@ -196,7 +196,7 @@ impl Proc { testdata_file(config) }; - eprintln!("spawning instrumented daemon thread with log {:?}", &log_file); + eprintln!("spawning instrumented daemon thread with log {:?}", log_file); let args = libshpool::Args { log_file: Some( @@ -289,7 +289,7 @@ impl Proc { self.tmp_dir.path().join(format!("attach_{}_{}.log", name, self.subproc_counter)); let test_hook_socket_path = self.tmp_dir.path().join(format!("ah{}_{}.sock", name, self.subproc_counter)); - eprintln!("spawning attach proc with log {:?}", &log_file); + eprintln!("spawning attach proc with log {:?}", log_file); self.subproc_counter += 1; let mut cmd = Command::new(&self.bin_path); @@ -350,7 +350,7 @@ impl Proc { pub fn detach(&mut self, sessions: Vec) -> anyhow::Result { let log_file = self.tmp_dir.path().join(format!("detach_{}.log", self.subproc_counter)); - eprintln!("spawning detach proc with log {:?}", &log_file); + eprintln!("spawning detach proc with log {:?}", log_file); self.subproc_counter += 1; let mut cmd = Command::new(&self.bin_path); @@ -369,7 +369,7 @@ impl Proc { pub fn kill(&mut self, sessions: Vec) -> anyhow::Result { let log_file = self.tmp_dir.path().join(format!("kill_{}.log", self.subproc_counter)); - eprintln!("spawning kill proc with log {:?}", &log_file); + eprintln!("spawning kill proc with log {:?}", log_file); self.subproc_counter += 1; let mut cmd = Command::new(&self.bin_path); @@ -406,7 +406,7 @@ impl Proc { 
/// output and returns it as a string pub fn list(&mut self) -> anyhow::Result { let log_file = self.tmp_dir.path().join(format!("list_{}.log", self.subproc_counter)); - eprintln!("spawning list proc with log {:?}", &log_file); + eprintln!("spawning list proc with log {:?}", log_file); self.subproc_counter += 1; Command::new(&self.bin_path) @@ -422,7 +422,7 @@ impl Proc { pub fn list_json(&mut self) -> anyhow::Result { let log_file = self.tmp_dir.path().join(format!("list_{}.log", self.subproc_counter)); - eprintln!("spawning list --json proc with log {:?}", &log_file); + eprintln!("spawning list --json proc with log {:?}", log_file); self.subproc_counter += 1; Command::new(&self.bin_path) @@ -441,7 +441,7 @@ impl Proc { pub fn set_log_level(&mut self, level: &str) -> anyhow::Result { let log_file = self.tmp_dir.path().join(format!("set_log_level_{}.log", self.subproc_counter)); - eprintln!("spawning set-log-level proc with log {:?}", &log_file); + eprintln!("spawning set-log-level proc with log {:?}", log_file); self.subproc_counter += 1; Command::new(&self.bin_path) @@ -463,6 +463,14 @@ impl Proc { Err(anyhow!("no events stream")) } } + + pub fn send_event_command(&mut self, cmd: &str) -> anyhow::Result<()> { + if let Some(events) = &mut self.events { + events.send_command(cmd) + } else { + Err(anyhow!("no events stream")) + } + } } impl std::ops::Drop for Proc { diff --git a/shpool/tests/support/events.rs b/shpool/tests/support/events.rs index a5aaf98f..253d977e 100644 --- a/shpool/tests/support/events.rs +++ b/shpool/tests/support/events.rs @@ -8,6 +8,7 @@ use anyhow::anyhow; /// an EventWaiter with the `waiter` or `await_event` routines. 
pub struct Events { lines: io::Lines>, + writer: UnixStream, } impl Events { @@ -15,7 +16,8 @@ impl Events { let mut sleep_dur = time::Duration::from_millis(5); for _ in 0..12 { if let Ok(s) = UnixStream::connect(&sock) { - return Ok(Events { lines: io::BufReader::new(s).lines() }); + let writer = s.try_clone()?; + return Ok(Events { lines: io::BufReader::new(s).lines(), writer }); } else { std::thread::sleep(sleep_dur); sleep_dur *= 2; @@ -69,7 +71,8 @@ impl Events { } if return_lines { - tx.send(WaiterEvent::Done((events[offset].clone(), self.lines))).unwrap(); + tx.send(WaiterEvent::Done((events[offset].clone(), self.lines, self.writer))) + .unwrap(); } }); @@ -90,6 +93,13 @@ impl Events { Ok(()) } + + pub fn send_command(&mut self, cmd: &str) -> anyhow::Result<()> { + use std::io::Write; + self.writer.write_all(format!("{cmd}\n").as_bytes())?; + self.writer.flush()?; + Ok(()) + } } /// EventWaiter represents waiting for a particular event. @@ -101,7 +111,7 @@ pub struct EventWaiter { enum WaiterEvent { Event(String), - Done((String, io::Lines>)), + Done((String, io::Lines>, UnixStream)), } impl EventWaiter { @@ -115,7 +125,7 @@ impl EventWaiter { Err(anyhow!("Got '{}' event, want '{}'", e, event)) } } - WaiterEvent::Done((e, _)) => { + WaiterEvent::Done((e, _, _)) => { if e == event { Ok(()) } else { @@ -131,9 +141,9 @@ impl EventWaiter { WaiterEvent::Event(e) => { Err(anyhow!("Got non-fianl '{}' event, want final '{}'", e, event)) } - WaiterEvent::Done((e, lines)) => { + WaiterEvent::Done((e, lines, writer)) => { if e == event { - Ok(Events { lines }) + Ok(Events { lines, writer }) } else { Err(anyhow!("Got '{}' event, want '{}'", e, event)) }