diff --git a/Cargo.lock b/Cargo.lock index bdad796d13d..73de6bfcae3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1691,8 +1691,10 @@ dependencies = [ "bitflags", "hubpack", "serde", + "serde-big-array 0.4.1", "serde_repr", "smoltcp", + "static_assertions", ] [[package]] @@ -5253,6 +5255,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "steno" version = "0.2.0" diff --git a/gateway-messages/Cargo.toml b/gateway-messages/Cargo.toml index 4f0c73bcf03..9b5bdcbf39e 100644 --- a/gateway-messages/Cargo.toml +++ b/gateway-messages/Cargo.toml @@ -7,7 +7,9 @@ license = "MPL-2.0" [dependencies] bitflags = "1.3.2" serde = { version = "1.0.144", default-features = false, features = ["derive"] } +serde-big-array = "0.4.1" serde_repr = { version = "0.1" } +static_assertions = "1.1.0" hubpack = { git = "https://github.com/cbiffle/hubpack", rev = "df08cc3a6e1f97381cd0472ae348e310f0119e25" } diff --git a/gateway-messages/src/lib.rs b/gateway-messages/src/lib.rs index 9f95a3fff3b..f2d2445b973 100644 --- a/gateway-messages/src/lib.rs +++ b/gateway-messages/src/lib.rs @@ -5,19 +5,23 @@ #![cfg_attr(all(not(test), not(feature = "std")), no_std)] pub mod sp_impl; -mod variable_packet; use bitflags::bitflags; use core::fmt; +use core::mem; use core::str; use serde::Deserialize; use serde::Serialize; use serde_repr::Deserialize_repr; use serde_repr::Serialize_repr; +use static_assertions::const_assert; pub use hubpack::error::Error as HubpackError; pub use hubpack::{deserialize, serialize, SerializedSize}; +/// Maximum size in bytes for a serialized message. 
+pub const MAX_SERIALIZED_SIZE: usize = 1024; + pub mod version { pub const V1: u32 = 1; } @@ -34,19 +38,23 @@ pub struct Request { } #[derive(Debug, Clone, SerializedSize, Serialize, Deserialize)] -// TODO: Rework how we serialize packets that contain a large amount of data -// (`SerialConsole`, `UpdateChunk`) to make this enum smaller. -#[allow(clippy::large_enum_variant)] pub enum RequestKind { Discover, // TODO do we want to be able to request IgnitionState for all targets in // one message? - IgnitionState { target: u8 }, + IgnitionState { + target: u8, + }, BulkIgnitionState, - IgnitionCommand { target: u8, command: IgnitionCommand }, + IgnitionCommand { + target: u8, + command: IgnitionCommand, + }, SpState, - SerialConsoleWrite(SerialConsole), + /// `SerialConsoleWrite` always includes trailing raw data. + SerialConsoleWrite(SpComponent), UpdateStart(UpdateStart), + /// `UpdateChunk` always includes trailing raw data. UpdateChunk(UpdateChunk), SysResetPrepare, SysResetTrigger, @@ -70,9 +78,6 @@ pub enum SpPort { Two = 2, } -// TODO: Not all SPs are capable of crafting all these response kinds, but the -// way we're using hubpack requires everyone to allocate Response::MAX_SIZE. Is -// that okay, or should we break this up more? #[derive(Debug, Clone, SerializedSize, Serialize, Deserialize)] pub enum ResponseKind { Discover(DiscoverResponse), @@ -185,7 +190,7 @@ pub enum SpMessageKind { /// Data traveling from an SP-attached component (in practice, a CPU) on the /// component's serial console. - SerialConsole(SerialConsole), + SerialConsole(SpComponent), } #[derive( @@ -204,79 +209,10 @@ pub struct UpdateStart { // TODO should we inline the first chunk? } -#[derive(Debug, Clone, PartialEq, SerializedSize)] +#[derive(Debug, Clone, PartialEq, SerializedSize, Serialize, Deserialize)] pub struct UpdateChunk { /// Offset in bytes of this chunk from the beginning of the update data. pub offset: u32, - /// Length in bytes of this chunk. 
- pub chunk_length: u16, - /// Data of this chunk; only the first `chunk_length` bytes should be used. - pub data: [u8; Self::MAX_CHUNK_SIZE], -} - -mod update_chunk_serde { - use super::variable_packet::VariablePacket; - use super::*; - - #[derive(Debug, Deserialize, Serialize)] - pub(crate) struct Header { - offset: u32, - chunk_length: u16, - } - - impl VariablePacket for UpdateChunk { - type Header = Header; - type Element = u8; - - const MAX_ELEMENTS: usize = Self::MAX_CHUNK_SIZE; - const DESERIALIZE_NAME: &'static str = "update chunk"; - - fn header(&self) -> Self::Header { - Header { offset: self.offset, chunk_length: self.chunk_length } - } - - fn num_elements(&self) -> u16 { - self.chunk_length - } - - fn elements(&self) -> &[Self::Element] { - &self.data - } - - fn elements_mut(&mut self) -> &mut [Self::Element] { - &mut self.data - } - - fn from_header(header: Self::Header) -> Self { - Self { - offset: header.offset, - chunk_length: header.chunk_length, - data: [0; Self::MAX_CHUNK_SIZE], - } - } - } - - impl Serialize for UpdateChunk { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - VariablePacket::serialize(self, serializer) - } - } - - impl<'de> Deserialize<'de> for UpdateChunk { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - VariablePacket::deserialize(deserializer) - } - } -} - -impl UpdateChunk { - pub const MAX_CHUNK_SIZE: usize = 512; } #[derive( @@ -317,112 +253,23 @@ bitflags! { } } -#[derive(Clone, PartialEq, SerializedSize)] +#[derive(Debug, Clone, PartialEq, SerializedSize, Serialize, Deserialize)] pub struct BulkIgnitionState { - /// Number of ignition targets present in `targets`. - pub num_targets: u16, /// Ignition state for each target. /// /// TODO The ignition target is implicitly the array index; is that /// reasonable or should we specify target indices explicitly? 
+ #[serde(with = "serde_big_array::BigArray")] pub targets: [IgnitionState; Self::MAX_IGNITION_TARGETS], } -impl fmt::Debug for BulkIgnitionState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut debug = f.debug_struct("BulkIgnitionState"); - debug.field("num_targets", &self.num_targets); - let targets = &self.targets[..usize::from(self.num_targets)]; - debug.field("targets", &targets); - debug.finish() - } -} - impl BulkIgnitionState { - // TODO We need to decide how to set max sizes for packets that may contain - // a variable amount of data. There are (at least) three concerns: - // - // 1. It determines a max packet size; we need to make sure this stays under - // whatever limit is in place on the management network. - // 2. It determines the size of the relevant structs/enums (and - // corresponding serialization/deserialization buffers). This is almost - // certainly irrelevant for MGS, but is very relevant for SPs. - // 3. What are the implications on versioning of changing the size? It - // doesn't actually affect the packet format on the wire, but a receiver - // with a lower compiled-in max size will reject packets it receives with - // more data than its max size. - // - // plus one note: these max sizes do not include the header overhead for the - // packets; that needs to be accounted for (particularly for point 1 above). - // - // Another question specific to `BulkIgnitionState`: Will we always send - // "max number of targets in the rack" states, even if some slots are - // unpopulated? Maybe this message shouldn't be variable at all. For now we - // leave it like it is; it's certainly "variable" in the sense that our - // simulated racks for tests have fewer than 36 targets. + // TODO-cleanup Is it okay to hard code this number to what we know the + // value is for the initial rack? For now assuming yes, and any changes in + // future products could use a different message. 
pub const MAX_IGNITION_TARGETS: usize = 36; } -mod bulk_ignition_state_serde { - use super::variable_packet::VariablePacket; - use super::*; - - #[derive(Debug, Deserialize, Serialize)] - pub(crate) struct Header { - num_targets: u16, - } - - impl VariablePacket for BulkIgnitionState { - type Header = Header; - type Element = IgnitionState; - - const MAX_ELEMENTS: usize = Self::MAX_IGNITION_TARGETS; - const DESERIALIZE_NAME: &'static str = "bulk ignition state packet"; - - fn header(&self) -> Self::Header { - Header { num_targets: self.num_targets } - } - - fn num_elements(&self) -> u16 { - self.num_targets - } - - fn elements(&self) -> &[Self::Element] { - &self.targets - } - - fn elements_mut(&mut self) -> &mut [Self::Element] { - &mut self.targets - } - - fn from_header(header: Self::Header) -> Self { - Self { - num_targets: header.num_targets, - targets: [IgnitionState::default(); - BulkIgnitionState::MAX_IGNITION_TARGETS], - } - } - } - - impl Serialize for BulkIgnitionState { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - VariablePacket::serialize(self, serializer) - } - } - - impl<'de> Deserialize<'de> for BulkIgnitionState { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - VariablePacket::deserialize(deserializer) - } - } -} - #[derive( Debug, Clone, Copy, SerializedSize, Serialize, Deserialize, PartialEq, )] @@ -451,6 +298,10 @@ impl SpComponent { /// Maximum number of bytes for a component ID. pub const MAX_ID_LENGTH: usize = 16; + /// The `sp3` host CPU. + pub const SP3_HOST_CPU: Self = + Self { id: *b"sp3\0\0\0\0\0\0\0\0\0\0\0\0\0" }; + /// Interpret the component name as a human-readable string. /// /// Our current expectation of component names is that this should never @@ -499,159 +350,51 @@ impl TryFrom<&str> for SpComponent { } } -// We could derive `Copy`, but `data` is large-ish so we want callers to think -// abount cloning. 
-#[derive(Clone, SerializedSize)] -pub struct SerialConsole { - /// Source component with an attached serial console. - pub component: SpComponent, - - /// Offset of this chunk of data relative to all console data this - /// source has sent since it booted. The receiver can determine if it's - /// missed data and reconstruct out-of-order packets based on this value - /// plus `len`. - pub offset: u64, - - /// Number of bytes in `data`. - pub len: u16, - - /// Actual serial console data. - pub data: [u8; Self::MAX_DATA_PER_PACKET], -} - -impl fmt::Debug for SerialConsole { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut debug = f.debug_struct("SerialConsole"); - debug.field("component", &self.component); - debug.field("offset", &self.offset); - debug.field("len", &self.len); - let data = &self.data[..usize::from(self.len)]; - if let Ok(s) = str::from_utf8(data) { - debug.field("data", &s); - } else { - debug.field("data", &data); - } - debug.finish() - } -} - -impl SerialConsole { - // TODO: See discussion on `BulkIgnitionState::MAX_IGNITION_TARGETS` for - // concerns about setting this limit. - // - // A concern specific to `SerialConsole`: What should we do (if anything) to - // account for something like "user `cat`s a large file, which is now - // streaming across the management network"? A couple possibilities: - // - // 1. One packet per line, and truncate any lines longer than - // `MAX_DATA_PER_PACKET` (seems like this could be _very_ annoying if a - // user bumped into it without realizing it). - // 2. Rate limiting (enforced where?) 
- pub const MAX_DATA_PER_PACKET: usize = 128; -} - -mod serial_console_serde { - use super::variable_packet::VariablePacket; - use super::*; - - #[derive(Debug, Deserialize, Serialize)] - pub(crate) struct Header { - component: SpComponent, - offset: u64, - len: u16, - } - - impl VariablePacket for SerialConsole { - type Header = Header; - type Element = u8; - - const MAX_ELEMENTS: usize = Self::MAX_DATA_PER_PACKET; - const DESERIALIZE_NAME: &'static str = "serial console packet"; - - fn header(&self) -> Self::Header { - Header { - component: self.component, - offset: self.offset, - len: self.len, - } - } - - fn num_elements(&self) -> u16 { - self.len - } - - fn elements(&self) -> &[Self::Element] { - &self.data - } - - fn elements_mut(&mut self) -> &mut [Self::Element] { - &mut self.data - } - - fn from_header(header: Self::Header) -> Self { - Self { - component: header.component, - offset: header.offset, - len: header.len, - data: [0; Self::MAX_DATA_PER_PACKET], - } - } - } - - impl Serialize for SerialConsole { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - VariablePacket::serialize(self, serializer) - } - } - - impl<'de> Deserialize<'de> for SerialConsole { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - VariablePacket::deserialize(deserializer) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn roundtrip_serial_console() { - let line = b"hello world\n"; - let mut console = SerialConsole { - component: SpComponent { id: *b"0000111122223333" }, - offset: 12345, - len: line.len() as u16, - data: [0xff; SerialConsole::MAX_DATA_PER_PACKET], - }; - console.data[..line.len()].copy_from_slice(line); - - let mut serialized = [0; SerialConsole::MAX_SIZE]; - let n = serialize(&mut serialized, &console).unwrap(); - - // serialized size should be limited to actual line length, not - // the size of `console.data` (`MAX_DATA_PER_PACKET`) - assert_eq!( - n, - 
SpComponent::MAX_SIZE + u64::MAX_SIZE + u16::MAX_SIZE + line.len() - ); - - let (deserialized, _) = - deserialize::(&serialized[..n]).unwrap(); - assert_eq!(deserialized.len, console.len); - assert_eq!(&deserialized.data[..line.len()], line); - } - - #[test] - fn serial_console_data_length_fits_in_u16() { - // this is just a sanity check that if we bump `MAX_DATA_PER_PACKET` - // above 65535 we also need to change the type of `SerialConsole::len` - assert!(SerialConsole::MAX_DATA_PER_PACKET <= usize::from(u16::MAX)); - } +/// Sealed trait restricting the types that can be passed to +/// [`serialize_with_trailing_data()`]. +pub trait GatewayMessage: SerializedSize + Serialize + private::Sealed {} +mod private { + pub trait Sealed {} +} +impl GatewayMessage for Request {} +impl GatewayMessage for SpMessage {} +impl private::Sealed for Request {} +impl private::Sealed for SpMessage {} + +// `GatewayMessage` imlementers can be followed by binary data; we want the +// majority of our packet to be available for that data. Statically check that +// our serialized message headers haven't gotten too large. The specific value +// here is arbitrary; if this check starts failing, it's probably fine to reduce +// it some. The check is here to force us to think about it. +const_assert!(MAX_SERIALIZED_SIZE - Request::MAX_SIZE > 700); +const_assert!(MAX_SERIALIZED_SIZE - SpMessage::MAX_SIZE > 700); + +/// Returns `(serialized_size, data_bytes_written)` where `serialized_size` is +/// the message size written to `out` and `data_bytes_written` is the number of +/// bytes included in `out` from `data`. +pub fn serialize_with_trailing_data( + out: &mut [u8; MAX_SERIALIZED_SIZE], + header: &T, + data: &[u8], +) -> (usize, usize) +where + T: GatewayMessage, +{ + // We know `T` is either `Request` or `SpMessage`, both of which we know + // statically (confirmed by `const_assert`s above) are significantly smaller + // than `MAX_SERIALIZED_SIZE`. 
They cannot fail to serialize for any reason + // other than an undersized buffer, so we can unwrap here. + let n = hubpack::serialize(out, header).unwrap(); + let out = &mut out[n..]; + + // How much data can we fit in what's left, leaving room for a 2-byte + // prefix? We know `out.len() > 2` thanks to the static assertion comparing + // `Request::MAX_SIZE` and `MAX_SERIALIZED_SIZE` at the root of our crate. + let to_write = usize::min(data.len(), out.len() - mem::size_of::()); + + out[..mem::size_of::()] + .copy_from_slice(&(to_write as u16).to_le_bytes()); + out[mem::size_of::()..][..to_write].copy_from_slice(&data[..to_write]); + + (n + mem::size_of::() + to_write, to_write) } diff --git a/gateway-messages/src/sp_impl.rs b/gateway-messages/src/sp_impl.rs index 9aadf81fa6d..c113ecf54bb 100644 --- a/gateway-messages/src/sp_impl.rs +++ b/gateway-messages/src/sp_impl.rs @@ -13,7 +13,6 @@ use crate::Request; use crate::RequestKind; use crate::ResponseError; use crate::ResponseKind; -use crate::SerialConsole; use crate::SpComponent; use crate::SpMessage; use crate::SpMessageKind; @@ -22,7 +21,7 @@ use crate::SpState; use crate::UpdateChunk; use crate::UpdateStart; use core::convert::Infallible; -use hubpack::SerializedSize; +use core::mem; #[cfg(feature = "std")] use std::net::SocketAddrV6; @@ -80,6 +79,7 @@ pub trait SpHandler { sender: SocketAddrV6, port: SpPort, chunk: UpdateChunk, + data: &[u8], ) -> Result<(), ResponseError>; // TODO Should we return "number of bytes written" here, or is it sufficient @@ -89,17 +89,18 @@ pub trait SpHandler { &mut self, sender: SocketAddrV6, port: SpPort, - packet: SerialConsole, + component: SpComponent, + data: &[u8], ) -> Result<(), ResponseError>; - fn sys_reset_prepare( + fn reset_prepare( &mut self, sender: SocketAddrV6, port: SpPort, ) -> Result<(), ResponseError>; // On success, this method cannot return (it should perform a reset). 
- fn sys_reset_trigger( + fn reset_trigger( &mut self, sender: SocketAddrV6, port: SpPort, @@ -124,80 +125,18 @@ impl From for Error { } } -#[derive(Debug)] -pub struct SerialConsolePacketizer { - component: SpComponent, - offset: u64, -} - -impl SerialConsolePacketizer { - pub fn new(component: SpComponent) -> Self { - Self { component, offset: 0 } - } - - pub fn packetize<'a, 'b>( - &'a mut self, - data: &'b [u8], - ) -> SerialConsolePackets<'a, 'b> { - SerialConsolePackets { parent: self, data } - } - - /// Extract the first packet from `data`, returning that packet and any - /// remaining data (which may be empty). - /// - /// Panics if `data` is empty. - pub fn first_packet<'a>( - &mut self, - data: &'a [u8], - ) -> (SerialConsole, &'a [u8]) { - if data.is_empty() { - panic!(); - } - - let (this_packet, remaining) = data.split_at(usize::min( - data.len(), - SerialConsole::MAX_DATA_PER_PACKET, - )); - - let mut packet = SerialConsole { - component: self.component, - offset: self.offset, - len: this_packet.len() as u16, - data: [0; SerialConsole::MAX_DATA_PER_PACKET], - }; - packet.data[..this_packet.len()].copy_from_slice(this_packet); - - self.offset += this_packet.len() as u64; - - (packet, remaining) +/// Unpack the 2-byte length-prefixed trailing data that comes after some +/// packets (e.g., update chunks, serial console). +pub fn unpack_trailing_data(data: &[u8]) -> hubpack::error::Result<&[u8]> { + if data.len() < mem::size_of::() { + return Err(hubpack::error::Error::Truncated); } - - // TODO this function exists only to allow callers to inject artifical gaps - // in the data they're sending; should we gate it behind a cargo feature? 
- pub fn danger_emulate_dropped_packets(&mut self, bytes_to_skip: u64) { - self.offset += bytes_to_skip; - } -} - -#[derive(Debug)] -pub struct SerialConsolePackets<'a, 'b> { - parent: &'a mut SerialConsolePacketizer, - data: &'b [u8], -} - -impl Iterator for SerialConsolePackets<'_, '_> { - type Item = SerialConsole; - - fn next(&mut self) -> Option { - if self.data.is_empty() { - return None; - } - - let (packet, remaining) = self.parent.first_packet(self.data); - self.data = remaining; - - Some(packet) + let (prefix, data) = data.split_at(mem::size_of::()); + let len = u16::from_le_bytes([prefix[0], prefix[1]]); + if data.len() != usize::from(len) { + return Err(hubpack::error::Error::Invalid); } + Ok(data) } /// Handle a single incoming message. @@ -213,16 +152,13 @@ pub fn handle_message( port: SpPort, data: &[u8], handler: &mut H, - out: &mut [u8; SpMessage::MAX_SIZE], + out: &mut [u8; crate::MAX_SERIALIZED_SIZE], ) -> Result { // parse request, with sanity checks on sizes - if data.len() > Request::MAX_SIZE { + if data.len() > crate::MAX_SERIALIZED_SIZE { return Err(Error::DataTooLarge); } let (request, leftover) = hubpack::deserialize::(data)?; - if !leftover.is_empty() { - return Err(Error::LeftoverData); - } // `version` is intentionally the first 4 bytes of the packet; we could // check it before trying to deserialize? @@ -230,6 +166,20 @@ pub fn handle_message( return Err(Error::UnsupportedVersion(request.version)); } + // Do we expect any trailing raw data? Only for specific kinds of messages; + // if we get any for other messages, bail out. + let trailing_data = match &request.kind { + RequestKind::UpdateChunk(_) | RequestKind::SerialConsoleWrite(_) => { + unpack_trailing_data(leftover)? 
+ } + _ => { + if !leftover.is_empty() { + return Err(Error::LeftoverData); + } + &[] + } + }; + // call out to handler to provide response let result = match request.kind { RequestKind::Discover => { @@ -251,16 +201,16 @@ pub fn handle_message( .update_start(sender, port, update) .map(|()| ResponseKind::UpdateStartAck), RequestKind::UpdateChunk(chunk) => handler - .update_chunk(sender, port, chunk) + .update_chunk(sender, port, chunk, trailing_data) .map(|()| ResponseKind::UpdateChunkAck), RequestKind::SerialConsoleWrite(packet) => handler - .serial_console_write(sender, port, packet) + .serial_console_write(sender, port, packet, trailing_data) .map(|()| ResponseKind::SerialConsoleWriteAck), RequestKind::SysResetPrepare => handler - .sys_reset_prepare(sender, port) + .reset_prepare(sender, port) .map(|()| ResponseKind::SysResetPrepareAck), RequestKind::SysResetTrigger => { - handler.sys_reset_trigger(sender, port).map(|infallible| { + handler.reset_trigger(sender, port).map(|infallible| { // A bit of type system magic here; `sys_reset_trigger`'s // success type (`Infallible`) cannot be instantiated. We can // provide an empty match to teach the type system that an diff --git a/gateway-messages/src/variable_packet.rs b/gateway-messages/src/variable_packet.rs deleted file mode 100644 index 764a5dd8457..00000000000 --- a/gateway-messages/src/variable_packet.rs +++ /dev/null @@ -1,119 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -// Copyright 2022 Oxide Computer Company - -//! Helper trait to write serde `Serialize`/`Deserialize` implementations that -//! traffic in a header followed by a variable number of elements (the count of -//! which is described by the header). This allows us to live within `hubpack`'s -//! static world with a fixed max size without sending padding for -//! 
underpopulated messages. - -use core::marker::PhantomData; -use serde::de::{DeserializeOwned, Error, Visitor}; -use serde::ser::SerializeTuple; -use serde::Serialize; - -pub(crate) trait VariablePacket { - type Header: DeserializeOwned + Serialize; - type Element: DeserializeOwned + Serialize; - - const MAX_ELEMENTS: usize; - const DESERIALIZE_NAME: &'static str; - - // construct a header from this instance - fn header(&self) -> Self::Header; - - // number of elements actually contained in this instance - fn num_elements(&self) -> u16; - - // `elements` and `elements_mut` can return slices up to - // `Self::MAX_ELEMENTS` long; the `serialize`/`deserialize` implementations - // will shorten them to `num_elements()` as needed - fn elements(&self) -> &[Self::Element]; - fn elements_mut(&mut self) -> &mut [Self::Element]; - - // construct an instance from `header` with empty/zero'd elements that - // `deserialize` will populate before returning - fn from_header(header: Self::Header) -> Self; - - // We can't `impl Serialize for T { .. }` due to - // coherence rules, so instead we'll plop the implementation here, and all - // our types that implement `VariablePacket` can now have 1-line - // `Serialize`/`Deserialize` implementations. - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let header = self.header(); - let num_elements = usize::from(self.num_elements()); - - // serialize ourselves as a tuple containing our header + each element - let mut tup = serializer.serialize_tuple(1 + num_elements)?; - tup.serialize_element(&header)?; - - // This is the same as what serde's default serialize implementation - // does, but we should confirm this generates reasonable code if - // `Self::Element == u8`. Ideally rustc/llvm will reduce this loop to - // something approximating memcpy; TODO check this on the stm32. 
- for element in &self.elements()[..num_elements] { - tup.serialize_element(element)?; - } - - tup.end() - } - - fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - Self: Sized, - { - struct TupleVisitor(PhantomData); - - impl<'de, T> Visitor<'de> for TupleVisitor - where - T: VariablePacket, - { - type Value = T; - - fn expecting( - &self, - formatter: &mut core::fmt::Formatter, - ) -> core::fmt::Result { - write!(formatter, "{}", T::DESERIALIZE_NAME) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let header: T::Header = match seq.next_element()? { - Some(header) => header, - None => { - return Err(A::Error::custom("missing packet header")) - } - }; - let mut out = T::from_header(header); - let num_elements = usize::from(out.num_elements()); - if num_elements > T::MAX_ELEMENTS { - return Err(A::Error::custom("packet length too long")); - } - for element in &mut out.elements_mut()[..num_elements] { - *element = match seq.next_element()? 
{ - Some(element) => element, - None => { - return Err(A::Error::custom( - "invalid packet length", - )) - } - }; - } - Ok(out) - } - } - - let visitor = TupleVisitor(PhantomData); - deserializer.deserialize_tuple(1 + Self::MAX_ELEMENTS, visitor) - } -} diff --git a/gateway-sp-comms/src/communicator.rs b/gateway-sp-comms/src/communicator.rs index a0f4217d826..94efe7e0522 100644 --- a/gateway-sp-comms/src/communicator.rs +++ b/gateway-sp-comms/src/communicator.rs @@ -22,6 +22,7 @@ use gateway_messages::DiscoverResponse; use gateway_messages::IgnitionCommand; use gateway_messages::IgnitionState; use gateway_messages::ResponseKind; +use gateway_messages::SpComponent; use gateway_messages::SpState; use slog::info; use slog::o; @@ -127,14 +128,18 @@ impl Communicator { .ok_or(Error::LocalIgnitionControllerAddressUnknown)?; let bulk_state = controller.bulk_ignition_state().await?; - // deserializing checks that `num_targets` is reasonably sized, so we - // don't need to guard that here - let targets = - &bulk_state.targets[..usize::from(bulk_state.num_targets)]; - // map ignition target indices back to `SpIdentifier`s for our caller - targets + bulk_state + .targets .iter() + .filter(|state| { + // TODO-cleanup `state.id` should match one of the constants + // defined in RFD 142 section 5.2.2, all of which are nonzero. + // What does the real ignition controller return for unpopulated + // sleds? Our simulator returns 0 for unpopulated targets; + // filter those out. + state.id != 0 + }) .copied() .enumerate() .map(|(target, state)| { @@ -178,10 +183,11 @@ impl Communicator { pub async fn serial_console_attach( &self, sp: SpIdentifier, + component: SpComponent, ) -> Result { let port = self.id_to_port(sp)?; let sp = self.switch.sp(port).ok_or(Error::SpAddressUnknown(sp))?; - Ok(sp.serial_console_attach().await?) + Ok(sp.serial_console_attach(component).await?) 
} /// Detach any existing connection to the given SP component's serial @@ -207,7 +213,7 @@ impl Communicator { pub async fn update( &self, sp: SpIdentifier, - image: &[u8], + image: Vec, ) -> Result<(), Error> { let port = self.id_to_port(sp)?; let sp = self.switch.sp(port).ok_or(Error::SpAddressUnknown(sp))?; diff --git a/gateway-sp-comms/src/single_sp.rs b/gateway-sp-comms/src/single_sp.rs index 5865d05dfed..a512a581c55 100644 --- a/gateway-sp-comms/src/single_sp.rs +++ b/gateway-sp-comms/src/single_sp.rs @@ -11,6 +11,7 @@ use crate::error::BadResponseType; use crate::error::SerialConsoleAlreadyAttached; use crate::error::SpCommunicationError; use crate::error::UpdateError; +use gateway_messages::sp_impl; use gateway_messages::version; use gateway_messages::BulkIgnitionState; use gateway_messages::IgnitionCommand; @@ -19,8 +20,7 @@ use gateway_messages::Request; use gateway_messages::RequestKind; use gateway_messages::ResponseError; use gateway_messages::ResponseKind; -use gateway_messages::SerialConsole; -use gateway_messages::SerializedSize; +use gateway_messages::SpComponent; use gateway_messages::SpMessage; use gateway_messages::SpMessageKind; use gateway_messages::SpPort; @@ -36,6 +36,9 @@ use slog::trace; use slog::warn; use slog::Logger; use std::convert::TryInto; +use std::io::Cursor; +use std::io::Seek; +use std::io::SeekFrom; use std::net::Ipv6Addr; use std::net::SocketAddr; use std::net::SocketAddrV6; @@ -161,9 +164,14 @@ impl SingleSp { /// Update th SP. /// - /// This is a bulk operation that will call [`Self::update_start()`] - /// followed by [`Self::update_chunk()`] the necessary number of times. - pub async fn update(&self, image: &[u8]) -> Result<(), UpdateError> { + /// This is a bulk operation that will make multiple RPC calls to the SP to + /// deliver all of `image`. + /// + /// # Panics + /// + /// Panics if `image.is_empty()`. 
+ pub async fn update(&self, image: Vec) -> Result<(), UpdateError> { + assert!(!image.is_empty()); let total_size = image .len() .try_into() @@ -172,16 +180,19 @@ impl SingleSp { info!(self.log, "starting SP update"; "total_size" => total_size); self.update_start(total_size).await.map_err(UpdateError::Start)?; - for (i, data) in image.chunks(UpdateChunk::MAX_CHUNK_SIZE).enumerate() { - let offset = (i * UpdateChunk::MAX_CHUNK_SIZE) as u32; - debug!( - self.log, "sending update chunk"; - "offset" => offset, - "size" => data.len(), - ); - self.update_chunk(offset, data) + let mut image = Cursor::new(image); + let mut offset = 0; + while !CursorExt::is_empty(&image) { + let prior_pos = image.position(); + debug!(self.log, "sending update chunk"; "offset" => offset); + + image = self + .update_chunk(offset, image) .await .map_err(|err| UpdateError::Chunk { offset, err })?; + + // Update our offset according to how far our cursor advanced. + offset += (image.position() - prior_pos) as u32; } info!(self.log, "update complete"); @@ -192,7 +203,7 @@ impl SingleSp { /// /// This should be followed by a series of `update_chunk()` calls totalling /// `total_size` bytes of data. - pub async fn update_start(&self, total_size: u32) -> Result<()> { + async fn update_start(&self, total_size: u32) -> Result<()> { self.rpc(RequestKind::UpdateStart(UpdateStart { total_size })) .await .and_then(|(_peer, response)| { @@ -210,20 +221,24 @@ impl SingleSp { /// update starts). /// /// Panics if `chunk.len() > UpdateChunk::MAX_CHUNK_SIZE`. 
- pub async fn update_chunk(&self, offset: u32, chunk: &[u8]) -> Result<()> { - assert!(chunk.len() <= UpdateChunk::MAX_CHUNK_SIZE); - let mut update_chunk = UpdateChunk { - offset, - chunk_length: chunk.len() as u16, - data: [0; UpdateChunk::MAX_CHUNK_SIZE], - }; - update_chunk.data[..chunk.len()].copy_from_slice(chunk); + async fn update_chunk( + &self, + offset: u32, + data: Cursor>, + ) -> Result>> { + let update_chunk = UpdateChunk { offset }; + let (result, data) = self + .rpc_with_trailing_data( + RequestKind::UpdateChunk(update_chunk), + data, + ) + .await; - self.rpc(RequestKind::UpdateChunk(update_chunk)).await.and_then( - |(_peer, response)| { - response.expect_update_chunk_ack().map_err(Into::into) - }, - ) + result.and_then(|(_peer, response)| { + response.expect_update_chunk_ack().map_err(Into::into) + })?; + + Ok(data) } /// Instruct the SP that a reset trigger will be coming. @@ -268,6 +283,7 @@ impl SingleSp { /// incoming serial console packets from the SP. pub async fn serial_console_attach( &self, + component: SpComponent, ) -> Result { let (tx, rx) = oneshot::channel(); @@ -280,6 +296,7 @@ impl SingleSp { Ok(AttachedSerialConsole { key: attachment.key, rx: attachment.incoming, + component, inner_tx: self.cmds_tx.clone(), }) } @@ -298,20 +315,46 @@ impl SingleSp { &self, kind: RequestKind, ) -> Result<(SocketAddrV6, ResponseKind)> { - rpc(&self.cmds_tx, kind).await + rpc(&self.cmds_tx, kind, None).await.result + } + + async fn rpc_with_trailing_data( + &self, + kind: RequestKind, + trailing_data: Cursor>, + ) -> (Result<(SocketAddrV6, ResponseKind)>, Cursor>) { + rpc_with_trailing_data(&self.cmds_tx, kind, trailing_data).await } } +async fn rpc_with_trailing_data( + inner_tx: &mpsc::Sender, + kind: RequestKind, + trailing_data: Cursor>, +) -> (Result<(SocketAddrV6, ResponseKind)>, Cursor>) { + let RpcResponse { result, trailing_data } = + rpc(inner_tx, kind, Some(trailing_data)).await; + + // We sent `Some(_)` trailing data, so we get `Some(_)` 
back; unwrap it + // so our caller can remain ignorant of this detail. + (result, trailing_data.unwrap()) +} + async fn rpc( inner_tx: &mpsc::Sender, kind: RequestKind, -) -> Result<(SocketAddrV6, ResponseKind)> { + trailing_data: Option>>, +) -> RpcResponse { let (resp_tx, resp_rx) = oneshot::channel(); // `Inner::run()` doesn't exit as long as `inner_tx` exists, so unwrapping // here only panics if it itself panicked. inner_tx - .send(InnerCommand::Rpc(RpcRequest { kind, response: resp_tx })) + .send(InnerCommand::Rpc(RpcRequest { + kind, + trailing_data, + response_tx: resp_tx, + })) .await .unwrap(); @@ -321,7 +364,8 @@ async fn rpc( #[derive(Debug)] pub struct AttachedSerialConsole { key: u64, - rx: mpsc::Receiver, + component: SpComponent, + rx: mpsc::Receiver>, inner_tx: mpsc::Sender, } @@ -332,6 +376,7 @@ impl AttachedSerialConsole { ( AttachedSerialConsoleSend { key: self.key, + component: self.component, inner_tx: self.inner_tx, }, AttachedSerialConsoleRecv { rx: self.rx }, @@ -343,16 +388,28 @@ impl AttachedSerialConsole { pub struct AttachedSerialConsoleSend { key: u64, inner_tx: mpsc::Sender, + component: SpComponent, } impl AttachedSerialConsoleSend { /// Write `data` to the serial console of the SP. - pub async fn write(&self, data: SerialConsole) -> Result<()> { - rpc(&self.inner_tx, RequestKind::SerialConsoleWrite(data)) - .await - .and_then(|(_peer, response)| { + pub async fn write(&self, data: Vec) -> Result<()> { + let mut data = Cursor::new(data); + while !CursorExt::is_empty(&data) { + let (result, new_data) = rpc_with_trailing_data( + &self.inner_tx, + RequestKind::SerialConsoleWrite(self.component), + data, + ) + .await; + + result.and_then(|(_peer, response)| { response.expect_serial_console_write_ack().map_err(Into::into) - }) + })?; + + data = new_data; + } + Ok(()) } /// Detach this serial console connection. 
@@ -366,7 +423,7 @@ impl AttachedSerialConsoleSend { #[derive(Debug)] pub struct AttachedSerialConsoleRecv { - rx: mpsc::Receiver, + rx: mpsc::Receiver>, } impl AttachedSerialConsoleRecv { @@ -374,21 +431,39 @@ impl AttachedSerialConsoleRecv { /// /// Returns `None` if the underlying channel has been closed (e.g., if the /// serial console has been detached). - pub async fn recv(&mut self) -> Option { + pub async fn recv(&mut self) -> Option> { self.rx.recv().await } } +// All RPC request/responses are handled by message passing to the `Inner` task +// below. `trailing_data` deserves some extra documentation: Some packet types +// (e.g., update chunks) want to send potentially-large binary data. We +// serialize this data with `gateway_messages::serialize_with_trailing_data()`, +// which appends as much data as will fit after the message header, but the +// caller doesn't know how much data that is until serialization happens. To +// handle this, we traffic in `Cursor>`s for communicating trailing data +// to `Inner`. If `trailing_data` in the `RpcRequest` is `Some(_)`, it will +// always be returned as `Some(_)` in the response as well, and the cursor will +// have been advanced by however much data was packed into the single RPC packet +// exchanged with the SP. 
#[derive(Debug)] struct RpcRequest { kind: RequestKind, - response: oneshot::Sender>, + trailing_data: Option>>, + response_tx: oneshot::Sender, +} + +#[derive(Debug)] +struct RpcResponse { + result: Result<(SocketAddrV6, ResponseKind)>, + trailing_data: Option>>, } #[derive(Debug)] struct SerialConsoleAttachment { key: u64, - incoming: mpsc::Receiver, + incoming: mpsc::Receiver>, } #[derive(Debug)] @@ -417,7 +492,7 @@ struct Inner { discovery_addr: SocketAddrV6, max_attempts: usize, per_attempt_timeout: Duration, - serial_console_tx: Option>, + serial_console_tx: Option>>, cmds_rx: mpsc::Receiver, request_id: u32, serial_console_connection_key: u64, @@ -448,7 +523,7 @@ impl Inner { } async fn run(mut self) { - let mut incoming_buf = [0; SpMessage::MAX_SIZE]; + let mut incoming_buf = [0; gateway_messages::MAX_SERIALIZED_SIZE]; let maybe_known_addr = *self.sp_addr_tx.borrow(); let mut sp_addr = match maybe_known_addr { @@ -524,10 +599,15 @@ impl Inner { async fn discover( &mut self, - incoming_buf: &mut [u8; SpMessage::MAX_SIZE], + incoming_buf: &mut [u8; gateway_messages::MAX_SERIALIZED_SIZE], ) -> Result { let (addr, response) = self - .rpc_call(self.discovery_addr, RequestKind::Discover, incoming_buf) + .rpc_call( + self.discovery_addr, + RequestKind::Discover, + None, + incoming_buf, + ) .await?; let discovery = response.expect_discover()?; @@ -544,7 +624,7 @@ impl Inner { &mut self, sp_addr: SocketAddrV6, command: InnerCommand, - incoming_buf: &mut [u8; SpMessage::MAX_SIZE], + incoming_buf: &mut [u8; gateway_messages::MAX_SERIALIZED_SIZE], ) { // When a caller attaches to the SP's serial console, we return an // `mpsc::Receiver<_>` on which we send any packets received from the @@ -558,11 +638,19 @@ impl Inner { const SERIAL_CONSOLE_CHANNEL_DEPTH: usize = 32; match command { - InnerCommand::Rpc(rpc) => { - let result = - self.rpc_call(sp_addr, rpc.kind, incoming_buf).await; - - if rpc.response.send(result).is_err() { + InnerCommand::Rpc(mut rpc) => { + let 
result = self + .rpc_call( + sp_addr, + rpc.kind, + rpc.trailing_data.as_mut(), + incoming_buf, + ) + .await; + let response = + RpcResponse { result, trailing_data: rpc.trailing_data }; + + if rpc.response_tx.send(response).is_err() { warn!( self.log, "RPC requester disappeared while waiting for response" @@ -595,10 +683,12 @@ impl Inner { fn handle_incoming_message( &mut self, - result: Result<(SocketAddrV6, SpMessage)>, + result: Result<(SocketAddrV6, SpMessage, &[u8])>, ) { - let (peer, message) = match result { - Ok((peer, message)) => (peer, message), + let (peer, message, trailing_data) = match result { + Ok((peer, message, trailing_data)) => { + (peer, message, trailing_data) + } Err(err) => { error!( self.log, @@ -633,8 +723,8 @@ impl Inner { "result" => ?result, ); } - SpMessageKind::SerialConsole(serial_console) => { - self.forward_serial_console(serial_console); + SpMessageKind::SerialConsole(component) => { + self.forward_serial_console(component, trailing_data); } } } @@ -643,19 +733,36 @@ impl Inner { &mut self, addr: SocketAddrV6, kind: RequestKind, - incoming_buf: &mut [u8; SpMessage::MAX_SIZE], + trailing_data: Option<&mut Cursor>>, + incoming_buf: &mut [u8; gateway_messages::MAX_SERIALIZED_SIZE], ) -> Result<(SocketAddrV6, ResponseKind)> { // Build and serialize our request once. self.request_id += 1; let request = Request { version: version::V1, request_id: self.request_id, kind }; - // We know statically that `outgoing_buf` is large enough to hold any - // `Request`, which in practice is the only possible serialization - // error. Therefore, we can `.unwrap()`. 
- let mut outgoing_buf = [0; Request::MAX_SIZE]; - let n = gateway_messages::serialize(&mut outgoing_buf[..], &request) - .unwrap(); + let mut outgoing_buf = [0; gateway_messages::MAX_SERIALIZED_SIZE]; + let n = match trailing_data { + Some(data) => { + let (n, written) = + gateway_messages::serialize_with_trailing_data( + &mut outgoing_buf, + &request, + CursorExt::remaining_slice(data), + ); + // `data` is an in-memory cursor; seeking can only fail if we + // provide a bogus offset, so it's safe to unwrap here. + data.seek(SeekFrom::Current(written as i64)).unwrap(); + n + } + None => { + // We know statically that `outgoing_buf` is large enough to + // hold any `Request`, which in practice is the only possible + // serialization error. Therefore, we can `.unwrap()`. + gateway_messages::serialize(&mut outgoing_buf[..], &request) + .unwrap() + } + }; let outgoing_buf = &outgoing_buf[..n]; for attempt in 1..=self.max_attempts { @@ -687,7 +794,7 @@ impl Inner { addr: SocketAddrV6, request_id: u32, serialized_request: &[u8], - incoming_buf: &mut [u8; SpMessage::MAX_SIZE], + incoming_buf: &mut [u8; gateway_messages::MAX_SERIALIZED_SIZE], ) -> Result> { // We consider an RPC attempt to be our attempt to contact the SP. 
It's // possible for the SP to respond and say it's busy; we shouldn't count @@ -708,8 +815,8 @@ impl Inner { Err(_elapsed) => return Ok(None), }; - let (peer, response) = match result { - Ok((peer, response)) => (peer, response), + let (peer, response, trailing_data) = match result { + Ok((peer, response, data)) => (peer, response, data), Err(err) => { warn!( self.log, "error receiving response"; @@ -721,6 +828,9 @@ impl Inner { let result = match response.kind { SpMessageKind::Response { request_id: response_id, result } => { + if !trailing_data.is_empty() { + warn!(self.log, "received unexpected trailing data with response (discarding)"); + } if response_id == request_id { result } else { @@ -733,7 +843,7 @@ impl Inner { } } SpMessageKind::SerialConsole(serial_console) => { - self.forward_serial_console(serial_console); + self.forward_serial_console(serial_console, trailing_data); continue; } }; @@ -750,9 +860,14 @@ impl Inner { } } - fn forward_serial_console(&mut self, serial_console: SerialConsole) { + fn forward_serial_console(&mut self, _component: SpComponent, data: &[u8]) { + // TODO-cleanup component support for serial console is half baked; + // should we check here that it matches the attached serial console? For + // the foreseeable future we only support one component, so we skip that + // for now. 
+ if let Some(tx) = self.serial_console_tx.as_ref() { - match tx.try_send(serial_console) { + match tx.try_send(data.to_vec()) { Ok(()) => return, Err(mpsc::error::TrySendError::Closed(_)) => { self.serial_console_tx = None; @@ -786,11 +901,11 @@ async fn send( Ok(()) } -async fn recv( +async fn recv<'a>( socket: &UdpSocket, - incoming_buf: &mut [u8; SpMessage::MAX_SIZE], + incoming_buf: &'a mut [u8; gateway_messages::MAX_SERIALIZED_SIZE], log: &Logger, -) -> Result<(SocketAddrV6, SpMessage)> { +) -> Result<(SocketAddrV6, SpMessage, &'a [u8])> { let (n, peer) = socket .recv_from(&mut incoming_buf[..]) .await @@ -809,7 +924,7 @@ async fn recv( } }; - let (message, _n) = + let (message, leftover) = gateway_messages::deserialize::(&incoming_buf[..n]) .map_err(|err| SpCommunicationError::Deserialize { peer, err })?; @@ -819,7 +934,14 @@ async fn recv( "message" => ?message, ); - Ok((peer, message)) + let trailing_data = if leftover.is_empty() { + &[] + } else { + sp_impl::unpack_trailing_data(leftover) + .map_err(|err| SpCommunicationError::Deserialize { peer, err })? + }; + + Ok((peer, message, trailing_data)) } fn sp_busy_policy() -> backoff::ExponentialBackoff { @@ -836,6 +958,24 @@ fn sp_busy_policy() -> backoff::ExponentialBackoff { } } +// Helper trait to provide methods on `io::Cursor` that are currently unstable. +trait CursorExt { + fn is_empty(&self) -> bool; + fn remaining_slice(&self) -> &[u8]; +} + +impl CursorExt for Cursor> { + fn is_empty(&self) -> bool { + self.position() as usize >= self.get_ref().len() + } + + fn remaining_slice(&self) -> &[u8] { + let data = self.get_ref(); + let pos = usize::min(self.position() as usize, data.len()); + &data[pos..] 
+ } +} + #[usdt::provider(provider = "gateway_sp_comms")] mod probes { fn recv_packet( diff --git a/gateway/faux-mgs/src/main.rs b/gateway/faux-mgs/src/main.rs index e7b40b356a2..2e955922f98 100644 --- a/gateway/faux-mgs/src/main.rs +++ b/gateway/faux-mgs/src/main.rs @@ -191,7 +191,7 @@ async fn main() -> Result<()> { let data = fs::read(&image).with_context(|| { format!("failed to read image {}", image.display()) })?; - sp.update(&data).await.with_context(|| { + sp.update(data).await.with_context(|| { format!("updating to {} failed", image.display()) })?; } diff --git a/gateway/faux-mgs/src/usart.rs b/gateway/faux-mgs/src/usart.rs index 8eb68e2fe22..f537d8cdd86 100644 --- a/gateway/faux-mgs/src/usart.rs +++ b/gateway/faux-mgs/src/usart.rs @@ -6,7 +6,6 @@ use anyhow::Context; use anyhow::Result; -use gateway_messages::sp_impl::SerialConsolePacketizer; use gateway_messages::SpComponent; use gateway_sp_comms::AttachedSerialConsoleSend; use gateway_sp_comms::SingleSp; @@ -39,14 +38,14 @@ pub(crate) async fn run( let mut out_buf = StdinOutBuf::new(raw); let mut flush_delay = FlushDelay::new(stdin_buffer_time); let console = sp - .serial_console_attach() + .serial_console_attach(SpComponent::SP3_HOST_CPU) .await .with_context(|| "failed to attach to serial console")?; let (console_tx, mut console_rx) = console.split(); let (send_tx, send_rx) = mpsc::channel(8); tokio::spawn(async move { - packetize_and_send(console_tx, send_rx).await.unwrap(); + relay_data_to_sp(console_tx, send_rx).await.unwrap(); }); loop { @@ -69,7 +68,7 @@ pub(crate) async fn run( let chunk = chunk.unwrap(); trace!(log, "writing {chunk:?} data to stdout"); let mut stdout = io::stdout().lock(); - stdout.write_all(&chunk.data[..usize::from(chunk.len)]).unwrap(); + stdout.write_all(&chunk).unwrap(); stdout.flush().unwrap(); } @@ -83,21 +82,16 @@ pub(crate) async fn run( } } -async fn packetize_and_send( +async fn relay_data_to_sp( console_tx: AttachedSerialConsoleSend, mut data_rx: mpsc::Receiver>, ) 
-> Result<()> { - let mut packetizer = - SerialConsolePacketizer::new(SpComponent::try_from("sp3").unwrap()); loop { let data = match data_rx.recv().await { Some(data) => data, None => return Ok(()), }; - - for chunk in packetizer.packetize(&data) { - console_tx.write(chunk).await?; - } + console_tx.write(data).await?; } } diff --git a/gateway/src/http_entrypoints.rs b/gateway/src/http_entrypoints.rs index 68f2d49a314..5afcbf4d114 100644 --- a/gateway/src/http_entrypoints.rs +++ b/gateway/src/http_entrypoints.rs @@ -505,7 +505,7 @@ async fn sp_update( let sp = path.into_inner().sp; let image = body.into_inner().image; - comms.update(sp.into(), &image).await.map_err(http_err_from_comms_err)?; + comms.update(sp.into(), image).await.map_err(http_err_from_comms_err)?; Ok(HttpResponseUpdatedNoContent {}) } diff --git a/gateway/src/serial_console.rs b/gateway/src/serial_console.rs index 26a97c0cfbc..9d39e9bc1c7 100644 --- a/gateway/src/serial_console.rs +++ b/gateway/src/serial_console.rs @@ -5,12 +5,10 @@ // Copyright 2022 Oxide Computer Company use crate::error::Error; -use futures::future::Fuse; -use futures::FutureExt; +use futures::stream::SplitSink; +use futures::stream::SplitStream; use futures::SinkExt; use futures::StreamExt; -use futures::TryFutureExt; -use gateway_messages::sp_impl::SerialConsolePacketizer; use gateway_messages::SpComponent; use gateway_sp_comms::AttachedSerialConsole; use gateway_sp_comms::AttachedSerialConsoleSend; @@ -25,8 +23,8 @@ use slog::error; use slog::info; use slog::Logger; use std::borrow::Cow; -use std::collections::VecDeque; use std::ops::Deref; +use tokio::sync::mpsc; use tokio_tungstenite::tungstenite::handshake; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::CloseFrame; @@ -87,7 +85,7 @@ pub(crate) async fn attach( .map(|key| handshake::derive_accept_key(key)) .ok_or(Error::BadWebsocketConnection("missing websocket key"))?; - let console = 
sp_comms.serial_console_attach(sp).await?; + let console = sp_comms.serial_console_attach(sp, component).await?; let upgrade_fut = upgrade::on(request); tokio::spawn(async move { let upgraded = match upgrade_fut.await { @@ -108,7 +106,7 @@ ) .await; - let task = SerialConsoleTask { console, component, ws_stream }; + let task = SerialConsoleTask { console, ws_stream }; match task.run(&log).await { Ok(()) => debug!(log, "serial task complete"), Err(e) => { @@ -139,66 +137,61 @@ enum SerialTaskError { struct SerialConsoleTask { console: AttachedSerialConsole, - component: SpComponent, ws_stream: WebSocketStream, } impl SerialConsoleTask { async fn run(self, log: &Logger) -> Result<(), SerialTaskError> { - let (mut ws_sink, mut ws_stream) = self.ws_stream.split(); + let (ws_sink, ws_stream) = self.ws_stream.split(); + + // Spawn a task to send any messages received from the SP to the client + // websocket. + // + // TODO-cleanup We have no way to apply backpressure to the SP, and are + // willing to buffer up an arbitrary amount of data in memory. We should + // apply some form of backpressure (which the SP could only handle by + // discarding data). + let (ws_sink_tx, ws_sink_rx) = mpsc::unbounded_channel(); + let mut ws_sink_handle = + tokio::spawn(Self::ws_sink_task(ws_sink, ws_sink_rx)); + + // Spawn a task to send any messages received from the client websocket + // to the SP. let (console_tx, mut console_rx) = self.console.split(); let console_tx = DetachOnDrop::new(console_tx); - - // TODO Currently we do not apply any backpressure on the SP and are - // willing to buffer up an arbitrary amount of data in memory. Is it - // reasonable to apply backpressure to the SP over UDP? Should we have - // caps on memory and start discarding data if we exceed them? We _do_ - // apply backpressure to the websocket, delaying reading from it if we - // still have data waiting to be sent to the SP. 
- let mut data_from_sp: VecDeque> = VecDeque::new(); - let mut data_to_sp: Vec = Vec::new(); - let mut packetizer_to_sp = SerialConsolePacketizer::new(self.component); + let mut ws_recv_handle = tokio::spawn(Self::ws_recv_task( + ws_stream, + console_tx, + log.clone(), + )); loop { - let ws_send = if let Some(data) = data_from_sp.pop_front() { - ws_sink.send(Message::Binary(data)).fuse() - } else { - Fuse::terminated() - }; - - let ws_recv; - let sp_send; - if data_to_sp.is_empty() { - sp_send = Fuse::terminated(); - ws_recv = ws_stream.next().fuse(); - } else { - ws_recv = Fuse::terminated(); - - let (packet, _remaining) = - packetizer_to_sp.first_packet(data_to_sp.as_slice()); - let packet_data_len = usize::from(packet.len); - - sp_send = console_tx - .write(packet) - .map_ok(move |()| packet_data_len) - .fuse(); - } - tokio::select! { - // Send a UDP packet to the SP - send_success = sp_send => { - let n = send_success - .map_err(gateway_sp_comms::error::Error::from) - .map_err(Error::from)?; - data_to_sp.drain(..n); + // Our ws_sink task completed; this is only possible if it + // fails, since it loops until we drop ws_sink_tx (which doesn't + // happen until we return!). + join_result = &mut ws_sink_handle => { + let result = join_result.expect("ws sink task panicked"); + return result; + } + + // Our ws_recv task completed; this is possible if the websocket + // connection fails or is closed by the client. In either case, + // we're also done. + join_result = &mut ws_recv_handle => { + let result = join_result.expect("ws recv task panicked"); + return result; } // Receive a UDP packet from the SP. 
packet = console_rx.recv() => { - match packet.as_ref() { - Some(packet) => { - let data = &packet.data[..usize::from(packet.len)]; - data_from_sp.push_back(data.to_vec()); + match packet { + Some(data) => { + info!( + log, "received serial console data from SP"; + "length" => data.len(), + ); + let _ = ws_sink_tx.send(Message::Binary(data)); } None => { // Sender is closed; i.e., we've been detached. @@ -208,47 +201,55 @@ impl SerialConsoleTask { code: CloseCode::Policy, reason: Cow::Borrowed("serial console was detached"), }; - ws_sink.send(Message::Close(Some(close))).await?; + let _ = ws_sink_tx.send(Message::Close(Some(close))); return Ok(()); } } } + } + } + } - // Send a previously-received UDP packet of data to the websocket - // client - write_success = ws_send => { - write_success?; - } + async fn ws_sink_task( + mut ws_sink: SplitSink, Message>, + mut messages: mpsc::UnboundedReceiver, + ) -> Result<(), SerialTaskError> { + while let Some(message) = messages.recv().await { + ws_sink.send(message).await?; + } + Ok(()) + } - // Receive from the websocket to send to the SP. 
- msg = ws_recv => { - match msg { - Some(Ok(Message::Binary(mut data))) => { - // we only populate ws_recv when we have no data - // currently queued up; sanity check that here - assert!(data_to_sp.is_empty()); - data_to_sp.append(&mut data); - } - Some(Ok(Message::Close(_))) | None => { - info!( - log, - "remote end closed websocket; terminating task", - ); - return Ok(()); - } - Some(other) => { - let wrong_message = other?; - error!( - log, - "bogus websocket message; terminating task"; - "message" => ?wrong_message, - ); - return Ok(()); - } - } + async fn ws_recv_task( + mut ws_stream: SplitStream>, + console_tx: DetachOnDrop, + log: Logger, + ) -> Result<(), SerialTaskError> { + while let Some(message) = ws_stream.next().await { + match message { + Ok(Message::Binary(data)) => { + console_tx + .write(data) + .await + .map_err(gateway_sp_comms::error::Error::from) + .map_err(Error::from)?; + } + Ok(Message::Close(_)) => { + break; + } + Ok(other) => { + error!( + log, + "bogus websocket message; terminating task"; + "message" => ?other, + ); + return Ok(()); } + Err(err) => return Err(err.into()), } } + info!(log, "remote end closed websocket; terminating task",); + Ok(()) } } diff --git a/sp-sim/src/gimlet.rs b/sp-sim/src/gimlet.rs index 866c4d8a491..e32e3454013 100644 --- a/sp-sim/src/gimlet.rs +++ b/sp-sim/src/gimlet.rs @@ -10,13 +10,11 @@ use crate::{Responsiveness, SimulatedSp}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use futures::future; -use gateway_messages::sp_impl::{SerialConsolePacketizer, SpHandler}; +use gateway_messages::sp_impl::SpHandler; use gateway_messages::version; use gateway_messages::DiscoverResponse; use gateway_messages::ResponseError; -use gateway_messages::SerialConsole; use gateway_messages::SerialNumber; -use gateway_messages::SerializedSize; use gateway_messages::SpComponent; use gateway_messages::SpMessage; use gateway_messages::SpMessageKind; @@ -221,10 +219,10 @@ impl Gimlet { struct 
SerialConsoleTcpTask { listener: TcpListener, - incoming_serial_console: UnboundedReceiver, + incoming_serial_console: UnboundedReceiver>, socks: [Arc; 2], gateway_addresses: Arc; 2]>>, - console_packetizer: SerialConsolePacketizer, + component: SpComponent, log: Logger, } @@ -232,7 +230,7 @@ impl SerialConsoleTcpTask { fn new( component: SpComponent, listener: TcpListener, - incoming_serial_console: UnboundedReceiver, + incoming_serial_console: UnboundedReceiver>, socks: [Arc; 2], gateway_addresses: Arc; 2]>>, log: Logger, @@ -242,12 +240,12 @@ impl SerialConsoleTcpTask { incoming_serial_console, socks, gateway_addresses, - console_packetizer: SerialConsolePacketizer::new(component), + component, log, } } - async fn send_serial_console(&mut self, mut data: &[u8]) -> Result<()> { + async fn send_serial_console(&mut self, data: &[u8]) -> Result<()> { let gateway_addrs = *self.gateway_addresses.lock().unwrap(); for (i, (sock, &gateway_addr)) in self.socks.iter().zip(&gateway_addrs).enumerate() @@ -267,26 +265,19 @@ impl SerialConsoleTcpTask { } }; - // if we're told to send something starting with "SKIP ", emulate a - // dropped packet spanning 10 bytes before sending the rest of the - // data. - if let Some(remaining) = data.strip_prefix(b"SKIP ") { - self.console_packetizer.danger_emulate_dropped_packets(10); - data = remaining; - } - - let mut out = [0; SpMessage::MAX_SIZE]; - for packet in self.console_packetizer.packetize(data) { + let mut out = [0; gateway_messages::MAX_SERIALIZED_SIZE]; + let mut remaining = data; + while !remaining.is_empty() { let message = SpMessage { version: version::V1, - kind: SpMessageKind::SerialConsole(packet), + kind: SpMessageKind::SerialConsole(self.component), }; - - // We know `out` is big enough for any `SpMessage`, so no need - // to bubble up an error here. 
- let n = gateway_messages::serialize(&mut out[..], &message) - .unwrap(); + let (n, written) = + gateway_messages::serialize_with_trailing_data( + &mut out, &message, remaining, + ); sock.send_to(&out[..n], gateway_addr).await?; + remaining = &remaining[written..]; } } @@ -356,14 +347,13 @@ impl SerialConsoleTcpTask { incoming = self.incoming_serial_console.recv() => { // we can only get `None` if the tx half was dropped, // which means we're in the process of shutting down - let incoming = match incoming { - Some(incoming) => incoming, + let data = match incoming { + Some(data) => data, None => return Ok(()), }; - let data = &incoming.data[..usize::from(incoming.len)]; conn - .write_all(data) + .write_all(&data) .await .with_context(|| "TCP write error")?; } @@ -393,10 +383,7 @@ impl UdpTask { servers: [UdpServer; 2], gateway_addresses: Arc; 2]>>, serial_number: SerialNumber, - incoming_serial_console: HashMap< - SpComponent, - UnboundedSender, - >, + incoming_serial_console: HashMap>>, commands: mpsc::UnboundedReceiver<( Command, oneshot::Sender, @@ -418,7 +405,7 @@ impl UdpTask { } async fn run(mut self) -> Result<()> { - let mut out_buf = [0; SpMessage::MAX_SIZE]; + let mut out_buf = [0; gateway_messages::MAX_SERIALIZED_SIZE]; let mut responsiveness = Responsiveness::Responsive; loop { select! 
{ @@ -470,8 +457,7 @@ struct Handler { log: Logger, serial_number: SerialNumber, gateway_addresses: Arc; 2]>>, - incoming_serial_console: - HashMap>, + incoming_serial_console: HashMap>>, } impl Handler { @@ -554,7 +540,8 @@ impl SpHandler for Handler { &mut self, sender: SocketAddrV6, port: SpPort, - packet: gateway_messages::SerialConsole, + component: SpComponent, + data: &[u8], ) -> Result<(), ResponseError> { self.update_gateway_address(sender, port); debug!( @@ -562,14 +549,13 @@ impl SpHandler for Handler { "received serial console packet"; "sender" => %sender, "port" => ?port, - "len" => packet.len, - "offset" => packet.offset, - "component" => ?packet.component, + "len" => data.len(), + "component" => ?component, ); let incoming_serial_console = self .incoming_serial_console - .get(&packet.component) + .get(&component) .ok_or(ResponseError::RequestUnsupportedForComponent)?; // should we sanity check `offset`? for now just assume everything @@ -577,7 +563,7 @@ impl SpHandler for Handler { // // if the receiving half is gone, we're in the process of shutting down; // ignore errors here - let _ = incoming_serial_console.send(packet); + let _ = incoming_serial_console.send(data.to_vec()); Ok(()) } @@ -622,6 +608,7 @@ impl SpHandler for Handler { sender: SocketAddrV6, port: SpPort, chunk: gateway_messages::UpdateChunk, + data: &[u8], ) -> Result<(), ResponseError> { warn!( &self.log, @@ -629,12 +616,12 @@ impl SpHandler for Handler { "sender" => %sender, "port" => ?port, "offset" => chunk.offset, - "length" => chunk.chunk_length, + "length" => data.len(), ); Err(ResponseError::RequestUnsupportedForSp) } - fn sys_reset_prepare( + fn reset_prepare( &mut self, sender: SocketAddrV6, port: SpPort, @@ -647,7 +634,7 @@ impl SpHandler for Handler { Err(ResponseError::RequestUnsupportedForSp) } - fn sys_reset_trigger( + fn reset_trigger( &mut self, sender: SocketAddrV6, port: SpPort, diff --git a/sp-sim/src/server.rs b/sp-sim/src/server.rs index c3dbb3a8e4d..ebd8ee7d6d9 
100644 --- a/sp-sim/src/server.rs +++ b/sp-sim/src/server.rs @@ -10,9 +10,6 @@ use anyhow::Context; use anyhow::Result; use gateway_messages::sp_impl; use gateway_messages::sp_impl::SpHandler; -use gateway_messages::Request; -use gateway_messages::SerializedSize; -use gateway_messages::SpMessage; use gateway_messages::SpPort; use slog::debug; use slog::error; @@ -24,11 +21,12 @@ use std::net::SocketAddrV6; use std::sync::Arc; use tokio::net::UdpSocket; -/// Thin wrapper pairing a [`UdpSocket`] with a buffer sized for [`Request`]s. +/// Thin wrapper pairing a [`UdpSocket`] with a buffer sized for gateway +/// messages. pub(crate) struct UdpServer { sock: Arc, local_addr: SocketAddrV6, - buf: [u8; Request::MAX_SIZE], + buf: [u8; gateway_messages::MAX_SERIALIZED_SIZE], } impl UdpServer { @@ -70,7 +68,11 @@ impl UdpServer { "multicast_addr" => %multicast_addr, ); - Ok(Self { sock, local_addr, buf: [0; Request::MAX_SIZE] }) + Ok(Self { + sock, + local_addr, + buf: [0; gateway_messages::MAX_SERIALIZED_SIZE], + }) } pub(crate) fn socket(&self) -> &Arc { @@ -123,7 +125,7 @@ pub fn logger(config: &Config) -> Result { pub(crate) async fn handle_request<'a, H: SpHandler>( handler: &mut H, recv: Result<(&[u8], SocketAddrV6)>, - out: &'a mut [u8; SpMessage::MAX_SIZE], + out: &'a mut [u8; gateway_messages::MAX_SERIALIZED_SIZE], responsiveness: Responsiveness, port_num: SpPort, ) -> Result> { diff --git a/sp-sim/src/sidecar.rs b/sp-sim/src/sidecar.rs index 611c8298521..23c54c342b9 100644 --- a/sp-sim/src/sidecar.rs +++ b/sp-sim/src/sidecar.rs @@ -21,8 +21,7 @@ use gateway_messages::IgnitionFlags; use gateway_messages::IgnitionState; use gateway_messages::ResponseError; use gateway_messages::SerialNumber; -use gateway_messages::SerializedSize; -use gateway_messages::SpMessage; +use gateway_messages::SpComponent; use gateway_messages::SpPort; use gateway_messages::SpState; use slog::debug; @@ -225,7 +224,7 @@ impl Inner { } async fn run(mut self) -> Result<()> { - let mut out_buf = 
[0; SpMessage::MAX_SIZE]; + let mut out_buf = [0; gateway_messages::MAX_SERIALIZED_SIZE]; let mut responsiveness = Responsiveness::Responsive; loop { select! { @@ -346,7 +345,6 @@ impl SpHandler for Handler { BulkIgnitionState::MAX_IGNITION_TARGETS ); let mut out = BulkIgnitionState { - num_targets: u16::try_from(num_targets).unwrap(), targets: [IgnitionState::default(); BulkIgnitionState::MAX_IGNITION_TARGETS], }; @@ -394,7 +392,8 @@ impl SpHandler for Handler { &mut self, sender: SocketAddrV6, port: SpPort, - _packet: gateway_messages::SerialConsole, + _component: SpComponent, + _data: &[u8], ) -> Result<(), ResponseError> { warn!( &self.log, "received serial console write; unsupported by sidecar"; @@ -443,6 +442,7 @@ impl SpHandler for Handler { sender: SocketAddrV6, port: SpPort, chunk: gateway_messages::UpdateChunk, + data: &[u8], ) -> Result<(), ResponseError> { warn!( &self.log, @@ -450,12 +450,12 @@ impl SpHandler for Handler { "sender" => %sender, "port" => ?port, "offset" => chunk.offset, - "length" => chunk.chunk_length, + "length" => data.len(), ); Err(ResponseError::RequestUnsupportedForSp) } - fn sys_reset_prepare( + fn reset_prepare( &mut self, sender: SocketAddrV6, port: SpPort, @@ -468,7 +468,7 @@ impl SpHandler for Handler { Err(ResponseError::RequestUnsupportedForSp) } - fn sys_reset_trigger( + fn reset_trigger( &mut self, sender: SocketAddrV6, port: SpPort,