Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions gateway-sp-comms/src/communicator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ use crate::management_switch::ManagementSwitchDiscovery;
use crate::management_switch::SpSocket;
use crate::management_switch::SwitchPort;
use crate::recv_handler::RecvHandler;
use crate::Elapsed;
use crate::KnownSps;
use crate::SpIdentifier;
use crate::Timeout;
use futures::stream::FuturesUnordered;
use futures::Future;
use futures::Stream;
Expand Down Expand Up @@ -43,7 +45,6 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use tokio_tungstenite::tungstenite::handshake;

/// Helper trait that allows us to return an `impl FuturesUnordered<_>` where
Expand Down Expand Up @@ -105,7 +106,7 @@ impl Communicator {
pub async fn get_ignition_state(
&self,
sp: SpIdentifier,
timeout: Instant,
timeout: Timeout,
) -> Result<IgnitionState, Error> {
let controller = self.switch.ignition_controller();
let port = self.id_to_port(sp)?;
Expand All @@ -124,7 +125,7 @@ impl Communicator {
/// Ask the local ignition controller for the ignition state of all SPs.
pub async fn get_ignition_state_all(
&self,
timeout: Instant,
timeout: Timeout,
) -> Result<Vec<(SpIdentifier, IgnitionState)>, Error> {
let controller = self.switch.ignition_controller();
let request = RequestKind::BulkIgnitionState;
Expand Down Expand Up @@ -165,7 +166,7 @@ impl Communicator {
&self,
target_sp: SpIdentifier,
command: IgnitionCommand,
timeout: Instant,
timeout: Timeout,
) -> Result<(), Error> {
let controller = self.switch.ignition_controller();
let target = self.id_to_port(target_sp)?.as_ignition_target();
Expand Down Expand Up @@ -285,7 +286,7 @@ impl Communicator {
&self,
port: SwitchPort,
packet: SerialConsole,
timeout: Instant,
timeout: Timeout,
) -> Result<(), Error> {
// We can only send to an SP's serial console if we've attached to it,
// which means we know its address.
Expand All @@ -310,15 +311,15 @@ impl Communicator {
pub async fn get_state(
&self,
sp: SpIdentifier,
timeout: Instant,
timeout: Timeout,
) -> Result<SpState, Error> {
self.get_state_maybe_timeout(sp, Some(timeout)).await
}

/// Get the state of a given SP without a timeout; it is the caller's
/// responsibility to ensure a reasonable timeout is applied higher up in
/// the chain.
// TODO we could have one method that takes `Option<Instant>` for a timeout,
// TODO we could have one method that takes `Option<Timeout>` for a timeout,
// and/or apply that to _all_ the methods in this class. I don't want to
// make it easy to accidentally call a method without providing a timeout,
// though, so went with the current design.
Expand All @@ -332,7 +333,7 @@ impl Communicator {
async fn get_state_maybe_timeout(
&self,
sp: SpIdentifier,
timeout: Option<Instant>,
timeout: Option<Timeout>,
) -> Result<SpState, Error> {
let port = self.id_to_port(sp)?;
let sp =
Expand Down Expand Up @@ -366,14 +367,10 @@ impl Communicator {
pub fn query_all_online_sps<F, T, Fut>(
&self,
ignition_state: &[(SpIdentifier, IgnitionState)],
timeout: Instant,
timeout: Timeout,
f: F,
) -> impl FuturesUnorderedImpl<
Item = (
SpIdentifier,
IgnitionState,
Option<Result<T, tokio::time::error::Elapsed>>,
),
Item = (SpIdentifier, IgnitionState, Option<Result<T, Elapsed>>),
>
where
F: FnMut(SpIdentifier) -> Fut + Clone,
Expand All @@ -386,7 +383,7 @@ impl Communicator {
let mut f = f.clone();
async move {
let val = if state.is_powered_on() {
Some(tokio::time::timeout_at(timeout, f(id)).await)
Some(timeout.timeout_at(f(id)).await)
} else {
None
};
Expand All @@ -400,22 +397,22 @@ impl Communicator {
&self,
sp: &SpSocket<'_>,
mut kind: RequestKind,
timeout: Option<Instant>,
timeout: Option<Timeout>,
mut map_response_kind: F,
) -> Result<T, Error>
where
F: FnMut(ResponseKind) -> Result<T, BadResponseType>,
{
// helper to wrap a future in a timeout if we have one
async fn maybe_with_timeout<F, U>(
timeout: Option<Instant>,
timeout: Option<Timeout>,
fut: F,
) -> Result<U, tokio::time::error::Elapsed>
) -> Result<U, Elapsed>
where
F: Future<Output = U>,
{
match timeout {
Some(t) => tokio::time::timeout_at(t, fut).await,
Some(t) => t.timeout_at(fut).await,
None => Ok(fut.await),
}
}
Expand All @@ -435,7 +432,12 @@ impl Communicator {
let duration = backoff
.next_backoff()
.expect("internal backoff policy gave up");
maybe_with_timeout(timeout, tokio::time::sleep(duration)).await?;
maybe_with_timeout(timeout, tokio::time::sleep(duration))
.await
.map_err(|err| Error::Timeout {
timeout: err.duration(),
sp: self.port_to_id(sp.port()),
})?;

// request IDs will eventually roll over; since we enforce timeouts
// this should be a non-issue in practice. does this need testing?
Expand All @@ -461,7 +463,11 @@ impl Communicator {

Ok::<ResponseKind, SpCommunicationError>(response_fut.await?)
})
.await?;
.await
.map_err(|err| Error::Timeout {
timeout: err.duration(),
sp: self.port_to_id(sp.port()),
})?;

match result {
Ok(response_kind) => {
Expand Down
11 changes: 3 additions & 8 deletions gateway-sp-comms/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::SpIdentifier;
use gateway_messages::ResponseError;
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use thiserror::Error;

#[derive(Debug, Error)]
Expand All @@ -26,8 +27,8 @@ pub enum Error {
.0.slot,
)]
SpAddressUnknown(SpIdentifier),
#[error("timeout elapsed")]
Timeout,
#[error("timeout ({timeout:?}) elapsed communicating with {sp:?}")]
Timeout { timeout: Duration, sp: SpIdentifier },
#[error("error communicating with SP: {0}")]
SpCommunicationFailed(#[from] SpCommunicationError),
#[error("serial console is already attached")]
Expand All @@ -54,9 +55,3 @@ pub struct BadResponseType {
pub expected: &'static str,
pub got: &'static str,
}

impl From<tokio::time::error::Elapsed> for Error {
fn from(_: tokio::time::error::Elapsed) -> Self {
Self::Timeout
}
}
3 changes: 3 additions & 0 deletions gateway-sp-comms/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
mod communicator;
mod management_switch;
mod recv_handler;
mod timeout;

pub use usdt::register_probes;

Expand All @@ -25,6 +26,8 @@ pub use communicator::Communicator;
pub use communicator::FuturesUnorderedImpl;
pub use management_switch::SpIdentifier;
pub use management_switch::SpType;
pub use timeout::Elapsed;
pub use timeout::Timeout;

// TODO these will remain public for a while, but eventually will be removed
// altogther; currently these provide a way to hard-code the rack topology,
Expand Down
3 changes: 2 additions & 1 deletion gateway-sp-comms/src/recv_handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::management_switch::ManagementSwitch;
use crate::management_switch::ManagementSwitchDiscovery;
use crate::management_switch::SwitchPort;
use crate::Communicator;
use crate::Timeout;
use futures::future::Fuse;
use futures::FutureExt;
use futures::SinkExt;
Expand Down Expand Up @@ -414,7 +415,7 @@ impl SerialConsoleTask {
.serial_console_send_packet(
self.port,
packet,
tokio::time::Instant::now() + self.sp_ack_timeout,
Timeout::from_now(self.sp_ack_timeout),
)
.map_ok(move |()| packet_data_len)
.fuse();
Expand Down
58 changes: 58 additions & 0 deletions gateway-sp-comms/src/timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Copyright 2022 Oxide Computer Company

use futures::Future;
use futures::TryFutureExt;
use std::time::Duration;
use tokio::time::Instant;

/// Error type returned from [`Timeout::timeout_at()`].
#[derive(Debug, Clone, Copy)]
pub struct Elapsed(pub Timeout);

impl Elapsed {
/// Get the duration of the timeout that elapsed.
pub fn duration(&self) -> Duration {
self.0.duration()
}
}

/// Representation of a timeout as both its starting time and its duration.
#[derive(Debug, Clone, Copy)]
pub struct Timeout {
start: Instant,
duration: Duration,
}

impl Timeout {
/// Create a new `Timeout` with the given duration starting from
/// [`Instant::now()`].
pub fn from_now(duration: Duration) -> Self {
Self { start: Instant::now(), duration }
}

/// Get the [`Instant`] when this timeout expires.
pub fn end(&self) -> Instant {
self.start + self.duration
}

/// Get the duration of this timeout.
pub fn duration(&self) -> Duration {
self.duration
}

/// Wrap a future with this timeout.
pub fn timeout_at<T>(
self,
future: T,
) -> impl Future<Output = Result<T::Output, Elapsed>>
where
T: Future,
{
tokio::time::timeout_at(self.end(), future)
.map_err(move |_| Elapsed(self))
}
}
11 changes: 6 additions & 5 deletions gateway/src/bulk_state_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ use crate::error::InvalidPageToken;
use futures::StreamExt;
use gateway_messages::IgnitionState;
use gateway_sp_comms::Communicator;
use gateway_sp_comms::Elapsed;
use gateway_sp_comms::FuturesUnorderedImpl;
use gateway_sp_comms::SpIdentifier;
use gateway_sp_comms::Timeout;
use serde::Deserialize;
use serde::Serialize;
use slog::debug;
Expand All @@ -75,7 +77,6 @@ use std::sync::Mutex;
use std::sync::RwLock;
use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::Instant;
use uuid::Uuid;

use crate::http_entrypoints::SpState;
Expand Down Expand Up @@ -144,7 +145,7 @@ impl BulkSpStateRequests {

pub(crate) async fn start(
&self,
timeout: Instant,
timeout: Timeout,
retain_grace_period: Duration,
) -> Result<SpStateRequestId, Error> {
// set up the receiving end of all SP responses
Expand Down Expand Up @@ -189,7 +190,7 @@ impl BulkSpStateRequests {
&self,
id: &SpStateRequestId,
last_seen: Option<SpIdentifier>,
timeout: Instant,
timeout: Timeout,
limit: usize,
) -> Result<BulkStateProgress, Error> {
let log = self.log.new(slog::o!(
Expand All @@ -201,7 +202,7 @@ impl BulkSpStateRequests {

// Go ahead and create (and pin) the timeout, but we don't actually
// await it until the loop at the bottom of this function.
let timeout = tokio::time::sleep_until(timeout);
let timeout = tokio::time::sleep_until(timeout.end());
tokio::pin!(timeout);

let collector =
Expand Down Expand Up @@ -381,7 +382,7 @@ async fn wait_for_sp_responses<S>(
Item = (
SpIdentifier,
IgnitionState,
Option<Result<SpStateResult, tokio::time::error::Elapsed>>,
Option<Result<SpStateResult, Elapsed>>,
),
>,
{
Expand Down
2 changes: 1 addition & 1 deletion gateway/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ where
err.to_string(),
),
SpCommsError::SpAddressUnknown(_)
| SpCommsError::Timeout
| SpCommsError::Timeout { .. }
| SpCommsError::SpCommunicationFailed(_) => {
HttpError::for_internal_error(err.to_string())
}
Expand Down
Loading