diff --git a/Cargo.toml b/Cargo.toml index c355c0d..2220896 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,19 +10,21 @@ members = [ ] [workspace.package] -version = "0.25.0" +version = "0.26.0" edition = "2024" authors = ["shellrow "] [workspace.dependencies] -nex-core = { version = "0.25.0", path = "nex-core" } -nex-datalink = { version = "0.25.0", path = "nex-datalink" } -nex-packet = { version = "0.25.0", path = "nex-packet" } -nex-sys = { version = "0.25.0", path = "nex-sys" } -nex-socket = { version = "0.25.0", path = "nex-socket" } +nex-core = { version = "0.26.0", path = "nex-core" } +nex-datalink = { version = "0.26.0", path = "nex-datalink" } +nex-packet = { version = "0.26.0", path = "nex-packet" } +nex-sys = { version = "0.26.0", path = "nex-sys" } +nex-socket = { version = "0.26.0", path = "nex-socket" } serde = { version = "1" } libc = "0.2" -netdev = { version = "0.40" } +netdev = { version = "0.41.0" } +mac-addr = { version = "0.3.0" } +ipnet = { version = "2.12" } bytes = "1" tokio = { version = "1" } rand = "0.8" diff --git a/examples/tcp_socket.rs b/examples/tcp_socket.rs index 3bacaf2..d2cfe5a 100644 --- a/examples/tcp_socket.rs +++ b/examples/tcp_socket.rs @@ -6,6 +6,7 @@ use nex_socket::tcp::TcpSocket; use std::env; use std::io::{Read, Write}; use std::net::{IpAddr, SocketAddr}; +use std::time::Duration; fn main() -> std::io::Result<()> { let ip: IpAddr = env::args().nth(1).expect("IP").parse().expect("ip"); @@ -20,8 +21,7 @@ fn main() -> std::io::Result<()> { SocketAddr::V4(_) => TcpSocket::v4_stream()?, SocketAddr::V6(_) => TcpSocket::v6_stream()?, }; - socket.connect(addr)?; - let mut stream = socket.to_tcp_stream()?; + let mut stream = socket.connect_timeout(addr, Duration::from_secs(5))?; let req = format!("GET / HTTP/1.1\r\nHost: {}\r\n\r\n", ip); stream.write_all(req.as_bytes())?; diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 0000000..f83457a --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ +artifacts/ +corpus/ +coverage/ +target/ diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..f69cd22 --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "nex-fuzz" +version = "0.0.0" +publish = false +edition = "2024" + +[workspace] + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" +nex-packet = { path = "../nex-packet" } +bytes = "1" + +[[bin]] +name = "frame_parse" +path = "fuzz_targets/frame_parse.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "ipv4_parse" +path = "fuzz_targets/ipv4_parse.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "ipv6_parse" +path = "fuzz_targets/ipv6_parse.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "tcp_options" +path = "fuzz_targets/tcp_options.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "dns_name" +path = "fuzz_targets/dns_name.rs" +test = false +doc = false +bench = false diff --git a/fuzz/README.md b/fuzz/README.md new file mode 100644 index 0000000..3ae8a2e --- /dev/null +++ b/fuzz/README.md @@ -0,0 +1,15 @@ +# Fuzz Targets + +This directory contains minimal `cargo-fuzz` targets for malformed-input hardening. + +Examples: + +```bash +cargo +nightly fuzz run frame_parse +cargo +nightly fuzz run ipv4_parse +cargo +nightly fuzz run ipv6_parse +cargo +nightly fuzz run tcp_options +cargo +nightly fuzz run dns_name +``` + +Targets focus on parser totality and malformed-input robustness. Panics and unbounded traversal are considered bugs. diff --git a/fuzz/fuzz_targets/dns_name.rs b/fuzz/fuzz_targets/dns_name.rs new file mode 100644 index 0000000..e368316 --- /dev/null +++ b/fuzz/fuzz_targets/dns_name.rs @@ -0,0 +1,9 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use nex_packet::dns::DnsName; + +fuzz_target!(|data: &[u8]| { + let _ = DnsName::from_bytes(data); + let _ = DnsName::try_from_bytes(data); +}); diff --git a/fuzz/fuzz_targets/frame_parse.rs b/fuzz/fuzz_targets/frame_parse.rs new file mode 100644 index 0000000..e304592 --- /dev/null +++ b/fuzz/fuzz_targets/frame_parse.rs @@ -0,0 +1,11 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use nex_packet::frame::{Frame, FrameView, ParseOption}; + +fuzz_target!(|data: &[u8]| { + let _ = Frame::from_buf(data, ParseOption::default()); + let _ = Frame::try_from_buf(data, ParseOption::default()); + let _ = Frame::try_from_buf_strict(data, ParseOption::default()); + let _ = FrameView::from_buf(data, ParseOption::default()); +}); diff --git a/fuzz/fuzz_targets/ipv4_parse.rs b/fuzz/fuzz_targets/ipv4_parse.rs new file mode 100644 index 0000000..322b74d --- /dev/null +++ b/fuzz/fuzz_targets/ipv4_parse.rs @@ -0,0 +1,11 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use nex_packet::packet::Packet; +use nex_packet::ipv4::Ipv4Packet; + +fuzz_target!(|data: &[u8]| { + let _ = Ipv4Packet::from_buf(data); + let _ = Ipv4Packet::try_from_buf(data); + let _ = Ipv4Packet::try_from_buf_strict(data); +}); diff --git a/fuzz/fuzz_targets/ipv6_parse.rs b/fuzz/fuzz_targets/ipv6_parse.rs new file mode 100644 index 0000000..a59f4a0 --- /dev/null +++ b/fuzz/fuzz_targets/ipv6_parse.rs @@ -0,0 +1,11 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use nex_packet::packet::Packet; +use nex_packet::ipv6::Ipv6Packet; + +fuzz_target!(|data: &[u8]| { + let _ = Ipv6Packet::from_buf(data); + let _ = Ipv6Packet::try_from_buf(data); + let _ = Ipv6Packet::try_from_buf_strict(data); +}); diff --git a/fuzz/fuzz_targets/tcp_options.rs b/fuzz/fuzz_targets/tcp_options.rs new file mode 100644 index 0000000..28b0a40 --- /dev/null +++ b/fuzz/fuzz_targets/tcp_options.rs @@ -0,0 +1,12 @@ +#![no_main] + +use bytes::Bytes; +use libfuzzer_sys::fuzz_target; +use nex_packet::packet::Packet; +use nex_packet::tcp::TcpPacket; + +fuzz_target!(|data: &[u8]| { + let _ = TcpPacket::from_buf(data); + let _ = TcpPacket::try_from_buf(data); + let _ = TcpPacket::try_from_bytes(Bytes::copy_from_slice(data)); +}); diff --git a/nex-core/Cargo.toml b/nex-core/Cargo.toml index 49decfc..1c98ea1 100644 --- a/nex-core/Cargo.toml +++ b/nex-core/Cargo.toml @@ -12,7 +12,13 @@ license = "MIT" [dependencies] netdev = { workspace = true } +mac-addr = { workspace = true } +ipnet = { workspace = true } +libc = { workspace = true } +nex-sys = { workspace = true } serde = { workspace = true, features = ["derive"], optional = true } [features] -serde = ["dep:serde", "netdev/serde"] +default = ["gateway"] +gateway = ["netdev/gateway"] +serde = ["dep:serde", "mac-addr/serde", "ipnet/serde", "netdev/serde"] diff --git a/nex-core/src/interface.rs b/nex-core/src/interface.rs index fbd89d8..92dd5c9 100644 --- a/nex-core/src/interface.rs +++ b/nex-core/src/interface.rs @@ -1 +1,605 @@ -pub use netdev::*; +use crate::ip::{is_global_ip, is_global_ipv4, is_global_ipv6}; +use crate::mac::MacAddr; +pub use ipnet::{self, Ipv4Net, Ipv6Net}; +use std::convert::TryFrom; +use std::io; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::time::SystemTime; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(unix)] +pub const IFF_UP: u32 = nex_sys::IFF_UP as u32; +#[cfg(windows)] +pub const IFF_UP: u32 = nex_sys::IFF_UP; + +#[cfg(unix)] +pub const IFF_BROADCAST: u32 = nex_sys::IFF_BROADCAST as u32; +#[cfg(windows)] +pub const IFF_BROADCAST: u32 = nex_sys::IFF_BROADCAST; + +#[cfg(unix)] +pub const IFF_LOOPBACK: u32 = nex_sys::IFF_LOOPBACK as u32; +#[cfg(windows)] +pub const IFF_LOOPBACK: u32 = nex_sys::IFF_LOOPBACK; + +#[cfg(unix)] +pub const IFF_POINTOPOINT: u32 = nex_sys::IFF_POINTOPOINT as u32; +#[cfg(windows)] +pub const IFF_POINTOPOINT: u32 = nex_sys::IFF_POINTOPOINT; + +#[cfg(unix)] +pub const IFF_MULTICAST: u32 = nex_sys::IFF_MULTICAST as u32; +#[cfg(windows)] +pub const IFF_MULTICAST: u32 = nex_sys::IFF_MULTICAST; + +#[cfg(unix)] +pub const IFF_RUNNING: u32 = libc::IFF_RUNNING as u32; + +/// Operational state of a network interface. +#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum OperState { + Unknown, + NotPresent, + Down, + LowerLayerDown, + Testing, + Dormant, + Up, +} + +impl OperState { + pub fn as_str(&self) -> &'static str { + match self { + OperState::Unknown => "unknown", + OperState::NotPresent => "notpresent", + OperState::Down => "down", + OperState::LowerLayerDown => "lowerlayerdown", + OperState::Testing => "testing", + OperState::Dormant => "dormant", + OperState::Up => "up", + } + } + + pub fn from_if_flags(if_flags: u32) -> Self { + #[cfg(unix)] + { + if if_flags & IFF_UP != 0 { + if if_flags & IFF_RUNNING != 0 { + OperState::Up + } else { + OperState::Dormant + } + } else { + OperState::Down + } + } + + #[cfg(windows)] + { + if if_flags & IFF_UP != 0 { + OperState::Up + } else { + OperState::Down + } + } + } +} + +impl std::fmt::Display for OperState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl std::str::FromStr for OperState { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "unknown" => Ok(OperState::Unknown), + "notpresent" => Ok(OperState::NotPresent), + "down" => Ok(OperState::Down), + "lowerlayerdown" => Ok(OperState::LowerLayerDown), + "testing" => Ok(OperState::Testing), + "dormant" => Ok(OperState::Dormant), + "up" => Ok(OperState::Up), + _ => Err(()), + } + } +} + +impl From for OperState { + fn from(value: netdev::interface::state::OperState) -> Self { + match value { + netdev::interface::state::OperState::Unknown => OperState::Unknown, + netdev::interface::state::OperState::NotPresent => OperState::NotPresent, + netdev::interface::state::OperState::Down => OperState::Down, + netdev::interface::state::OperState::LowerLayerDown => OperState::LowerLayerDown, + netdev::interface::state::OperState::Testing => OperState::Testing, + netdev::interface::state::OperState::Dormant => OperState::Dormant, + netdev::interface::state::OperState::Up => OperState::Up, + } + } +} + +/// Cross-platform classification of a network interface. +#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum InterfaceType { + Unknown, + Ethernet, + TokenRing, + Fddi, + BasicIsdn, + PrimaryIsdn, + Ppp, + Loopback, + Ethernet3Megabit, + Slip, + Atm, + GenericModem, + ProprietaryVirtual, + FastEthernetT, + Isdn, + FastEthernetFx, + Wireless80211, + AsymmetricDsl, + RateAdaptDsl, + SymmetricDsl, + VeryHighSpeedDsl, + IPOverAtm, + GigabitEthernet, + Tunnel, + MultiRateSymmetricDsl, + HighPerformanceSerialBus, + Wman, + Wwanpp, + Wwanpp2, + Bridge, + Can, + PeerToPeerWireless, + UnknownWithValue(u32), +} + +impl InterfaceType { + pub fn name(&self) -> String { + match *self { + InterfaceType::Unknown => String::from("Unknown"), + InterfaceType::Ethernet => String::from("Ethernet"), + InterfaceType::TokenRing => String::from("Token Ring"), + InterfaceType::Fddi => String::from("FDDI"), + InterfaceType::BasicIsdn => String::from("Basic ISDN"), + InterfaceType::PrimaryIsdn => String::from("Primary ISDN"), + InterfaceType::Ppp => String::from("PPP"), + InterfaceType::Loopback => String::from("Loopback"), + InterfaceType::Ethernet3Megabit => String::from("Ethernet 3 megabit"), + InterfaceType::Slip => String::from("SLIP"), + InterfaceType::Atm => String::from("ATM"), + InterfaceType::GenericModem => String::from("Generic Modem"), + InterfaceType::ProprietaryVirtual => String::from("Proprietary Virtual/Internal"), + InterfaceType::FastEthernetT => String::from("Fast Ethernet T"), + InterfaceType::Isdn => String::from("ISDN"), + InterfaceType::FastEthernetFx => String::from("Fast Ethernet FX"), + InterfaceType::Wireless80211 => String::from("Wireless IEEE 802.11"), + InterfaceType::AsymmetricDsl => String::from("Asymmetric DSL"), + InterfaceType::RateAdaptDsl => String::from("Rate Adaptive DSL"), + InterfaceType::SymmetricDsl => String::from("Symmetric DSL"), + InterfaceType::VeryHighSpeedDsl => String::from("Very High Data Rate DSL"), + InterfaceType::IPOverAtm => String::from("IP over ATM"), + InterfaceType::GigabitEthernet => String::from("Gigabit Ethernet"), + InterfaceType::Tunnel => String::from("Tunnel"), + InterfaceType::MultiRateSymmetricDsl => String::from("Multi-Rate Symmetric DSL"), + InterfaceType::HighPerformanceSerialBus => String::from("High Performance Serial Bus"), + InterfaceType::Wman => String::from("WMAN"), + InterfaceType::Wwanpp => String::from("WWANPP"), + InterfaceType::Wwanpp2 => String::from("WWANPP2"), + InterfaceType::Bridge => String::from("Bridge"), + InterfaceType::Can => String::from("CAN"), + InterfaceType::PeerToPeerWireless => String::from("Peer-to-Peer Wireless"), + InterfaceType::UnknownWithValue(v) => format!("Unknown ({v})"), + } + } +} + +impl From for InterfaceType { + fn from(value: netdev::interface::types::InterfaceType) -> Self { + match value { + netdev::interface::types::InterfaceType::Unknown => InterfaceType::Unknown, + netdev::interface::types::InterfaceType::Ethernet => InterfaceType::Ethernet, + netdev::interface::types::InterfaceType::TokenRing => InterfaceType::TokenRing, + netdev::interface::types::InterfaceType::Fddi => InterfaceType::Fddi, + netdev::interface::types::InterfaceType::BasicIsdn => InterfaceType::BasicIsdn, + netdev::interface::types::InterfaceType::PrimaryIsdn => InterfaceType::PrimaryIsdn, + netdev::interface::types::InterfaceType::Ppp => InterfaceType::Ppp, + netdev::interface::types::InterfaceType::Loopback => InterfaceType::Loopback, + netdev::interface::types::InterfaceType::Ethernet3Megabit => { + InterfaceType::Ethernet3Megabit + } + netdev::interface::types::InterfaceType::Slip => InterfaceType::Slip, + netdev::interface::types::InterfaceType::Atm => InterfaceType::Atm, + netdev::interface::types::InterfaceType::GenericModem => InterfaceType::GenericModem, + netdev::interface::types::InterfaceType::ProprietaryVirtual => { + InterfaceType::ProprietaryVirtual + } + netdev::interface::types::InterfaceType::FastEthernetT => InterfaceType::FastEthernetT, + netdev::interface::types::InterfaceType::Isdn => InterfaceType::Isdn, + netdev::interface::types::InterfaceType::FastEthernetFx => { + InterfaceType::FastEthernetFx + } + netdev::interface::types::InterfaceType::Wireless80211 => InterfaceType::Wireless80211, + netdev::interface::types::InterfaceType::AsymmetricDsl => InterfaceType::AsymmetricDsl, + netdev::interface::types::InterfaceType::RateAdaptDsl => InterfaceType::RateAdaptDsl, + netdev::interface::types::InterfaceType::SymmetricDsl => InterfaceType::SymmetricDsl, + netdev::interface::types::InterfaceType::VeryHighSpeedDsl => { + InterfaceType::VeryHighSpeedDsl + } + netdev::interface::types::InterfaceType::IPOverAtm => InterfaceType::IPOverAtm, + netdev::interface::types::InterfaceType::GigabitEthernet => { + InterfaceType::GigabitEthernet + } + netdev::interface::types::InterfaceType::Tunnel => InterfaceType::Tunnel, + netdev::interface::types::InterfaceType::MultiRateSymmetricDsl => { + InterfaceType::MultiRateSymmetricDsl + } + netdev::interface::types::InterfaceType::HighPerformanceSerialBus => { + InterfaceType::HighPerformanceSerialBus + } + netdev::interface::types::InterfaceType::Wman => InterfaceType::Wman, + netdev::interface::types::InterfaceType::Wwanpp => InterfaceType::Wwanpp, + netdev::interface::types::InterfaceType::Wwanpp2 => InterfaceType::Wwanpp2, + netdev::interface::types::InterfaceType::Bridge => InterfaceType::Bridge, + netdev::interface::types::InterfaceType::Can => InterfaceType::Can, + netdev::interface::types::InterfaceType::PeerToPeerWireless => { + InterfaceType::PeerToPeerWireless + } + netdev::interface::types::InterfaceType::UnknownWithValue(v) => { + InterfaceType::UnknownWithValue(v) + } + } + } +} + +impl TryFrom for InterfaceType { + type Error = (); + + fn try_from(v: u32) -> Result { + Ok(InterfaceType::from( + netdev::interface::types::InterfaceType::try_from(v)?, + )) + } +} + +/// Address information for a related network device. +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct NetworkDevice { + pub mac_addr: MacAddr, + pub ipv4: Vec, + pub ipv6: Vec, +} + +impl NetworkDevice { + pub fn new() -> NetworkDevice { + NetworkDevice { + mac_addr: MacAddr::zero(), + ipv4: Vec::new(), + ipv6: Vec::new(), + } + } +} + +impl Default for NetworkDevice { + fn default() -> Self { + Self::new() + } +} + +impl From for NetworkDevice { + fn from(value: netdev::NetworkDevice) -> Self { + NetworkDevice { + mac_addr: value.mac_addr, + ipv4: value.ipv4, + ipv6: value.ipv6, + } + } +} + +/// Interface traffic statistics at a given point in time. +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct InterfaceStats { + pub rx_bytes: u64, + pub tx_bytes: u64, + pub timestamp: Option, +} + +impl From for InterfaceStats { + fn from(value: netdev::stats::counters::InterfaceStats) -> Self { + InterfaceStats { + rx_bytes: value.rx_bytes, + tx_bytes: value.tx_bytes, + timestamp: value.timestamp, + } + } +} + +/// A network interface. +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Interface { + pub index: u32, + pub name: String, + pub friendly_name: Option, + pub description: Option, + pub if_type: InterfaceType, + pub mac_addr: Option, + pub ipv4: Vec, + pub ipv6: Vec, + pub ipv6_scope_ids: Vec, + pub flags: u32, + pub oper_state: OperState, + pub transmit_speed: Option, + pub receive_speed: Option, + pub stats: Option, + #[cfg(feature = "gateway")] + pub gateway: Option, + #[cfg(feature = "gateway")] + pub dns_servers: Vec, + pub mtu: Option, + #[cfg(feature = "gateway")] + pub default: bool, +} + +impl Interface { + #[cfg(feature = "gateway")] + #[allow(clippy::should_implement_trait)] + pub fn default() -> Result { + get_default_interface() + } + + pub fn dummy() -> Interface { + Interface { + index: 0, + name: String::new(), + friendly_name: None, + description: None, + if_type: InterfaceType::Unknown, + mac_addr: None, + ipv4: Vec::new(), + ipv6: Vec::new(), + ipv6_scope_ids: Vec::new(), + flags: 0, + oper_state: OperState::Unknown, + transmit_speed: None, + receive_speed: None, + stats: None, + #[cfg(feature = "gateway")] + gateway: None, + #[cfg(feature = "gateway")] + dns_servers: Vec::new(), + mtu: None, + #[cfg(feature = "gateway")] + default: false, + } + } + + /// Refresh all interface fields from the operating system. + /// + /// This performs a fresh system lookup and may be more expensive than + /// the accessor methods on `Interface`. + pub fn refresh(&mut self) -> io::Result<()> { + let refreshed = lookup_interface(&self.name, self.index).ok_or_else(|| { + io::Error::new(io::ErrorKind::NotFound, "interface could not be refreshed") + })?; + *self = refreshed.into(); + Ok(()) + } + + pub fn is_up(&self) -> bool { + self.flags & IFF_UP != 0 + } + + pub fn is_loopback(&self) -> bool { + self.flags & IFF_LOOPBACK != 0 + } + + pub fn is_point_to_point(&self) -> bool { + self.flags & IFF_POINTOPOINT != 0 + } + + pub fn is_multicast(&self) -> bool { + self.flags & IFF_MULTICAST != 0 + } + + pub fn is_broadcast(&self) -> bool { + self.flags & IFF_BROADCAST != 0 + } + + pub fn is_tun(&self) -> bool { + self.is_up() && self.is_point_to_point() && !self.is_broadcast() && !self.is_loopback() + } + + pub fn is_running(&self) -> bool { + #[cfg(unix)] + { + self.flags & IFF_RUNNING != 0 + } + #[cfg(windows)] + { + self.is_up() + } + } + + pub fn is_physical(&self) -> bool { + lookup_interface(&self.name, self.index) + .map(|iface| iface.is_physical()) + .unwrap_or_else(|| { + self.is_up() && self.is_running() && !self.is_tun() && !self.is_loopback() + }) + } + + pub fn oper_state(&self) -> OperState { + self.oper_state + } + + pub fn is_oper_up(&self) -> bool { + self.oper_state == OperState::Up + } + + /// Refresh the operational state from the operating system. + /// + /// This may perform a fresh interface lookup. + pub fn refresh_oper_state(&mut self) -> io::Result<()> { + if let Some(iface) = lookup_interface(&self.name, self.index) { + self.oper_state = iface.oper_state.into(); + return Ok(()); + } + Err(io::Error::new( + io::ErrorKind::NotFound, + "interface operational state could not be refreshed", + )) + } + + /// Refresh the operational state from the operating system. + /// + /// This may perform a fresh interface lookup. + pub fn update_oper_state(&mut self) { + let _ = self.refresh_oper_state(); + } + + /// Iterate IPv4 addresses without allocating a new vector. + pub fn ipv4_addr_iter(&self) -> impl Iterator + '_ { + self.ipv4.iter().map(|net| net.addr()) + } + + pub fn ipv4_addrs(&self) -> Vec { + self.ipv4_addr_iter().collect() + } + + /// Iterate IPv6 addresses without allocating a new vector. + pub fn ipv6_addr_iter(&self) -> impl Iterator + '_ { + self.ipv6.iter().map(|net| net.addr()) + } + + pub fn ipv6_addrs(&self) -> Vec { + self.ipv6_addr_iter().collect() + } + + /// Iterate IP addresses without allocating a new vector. + pub fn ip_addr_iter(&self) -> impl Iterator + '_ { + self.ipv4_addr_iter() + .map(IpAddr::V4) + .chain(self.ipv6_addr_iter().map(IpAddr::V6)) + } + + pub fn ip_addrs(&self) -> Vec { + self.ip_addr_iter().collect() + } + + pub fn has_ipv4(&self) -> bool { + !self.ipv4.is_empty() + } + + pub fn has_ipv6(&self) -> bool { + !self.ipv6.is_empty() + } + + pub fn has_global_ipv4(&self) -> bool { + self.ipv4_addrs().iter().any(is_global_ipv4) + } + + pub fn has_global_ipv6(&self) -> bool { + self.ipv6_addrs().iter().any(is_global_ipv6) + } + + pub fn has_global_ip(&self) -> bool { + self.ip_addrs().iter().any(is_global_ip) + } + + pub fn global_ipv4_addrs(&self) -> Vec { + self.ipv4_addr_iter().filter(is_global_ipv4).collect() + } + + pub fn global_ipv6_addrs(&self) -> Vec { + self.ipv6_addr_iter().filter(is_global_ipv6).collect() + } + + pub fn global_ip_addrs(&self) -> Vec { + self.ip_addr_iter().filter(is_global_ip).collect() + } + + /// Refresh interface statistics from the operating system. + /// + /// This may perform a fresh interface lookup. + pub fn refresh_stats(&mut self) -> io::Result<()> { + if let Some(iface) = lookup_interface(&self.name, self.index) { + self.stats = iface.stats.map(Into::into); + return Ok(()); + } + Err(io::Error::new( + io::ErrorKind::NotFound, + "interface statistics could not be refreshed", + )) + } + + /// Refresh interface statistics from the operating system. + /// + /// This may perform a fresh interface lookup. + pub fn update_stats(&mut self) -> io::Result<()> { + self.refresh_stats() + } +} + +impl From for Interface { + fn from(value: netdev::Interface) -> Self { + Interface { + index: value.index, + name: value.name, + friendly_name: value.friendly_name, + description: value.description, + if_type: value.if_type.into(), + mac_addr: value.mac_addr, + ipv4: value.ipv4, + ipv6: value.ipv6, + ipv6_scope_ids: value.ipv6_scope_ids, + flags: value.flags, + oper_state: value.oper_state.into(), + transmit_speed: value.transmit_speed, + receive_speed: value.receive_speed, + stats: value.stats.map(Into::into), + #[cfg(feature = "gateway")] + gateway: value.gateway.map(Into::into), + #[cfg(feature = "gateway")] + dns_servers: value.dns_servers, + mtu: value.mtu, + #[cfg(feature = "gateway")] + default: value.default, + } + } +} + +pub fn get_interfaces() -> Vec { + netdev::get_interfaces() + .into_iter() + .map(Into::into) + .collect() +} + +#[cfg(feature = "gateway")] +pub fn get_default_interface() -> Result { + netdev::get_default_interface().map(Into::into) +} + +#[cfg(feature = "gateway")] +pub fn get_default_gateway() -> Result { + netdev::get_default_gateway().map(Into::into) +} + +fn lookup_interface(name: &str, index: u32) -> Option { + netdev::get_interfaces() + .into_iter() + .find(|iface| iface.index == index || iface.name == name) +} diff --git a/nex-core/src/ip.rs b/nex-core/src/ip.rs index d63ee0b..932c400 100644 --- a/nex-core/src/ip.rs +++ b/nex-core/src/ip.rs @@ -67,7 +67,7 @@ pub fn is_global_ipv6(ipv6_addr: &Ipv6Addr) -> bool { || matches!(ipv6_addr.segments(), [0x2001, 4, 0x112, _, _, _, _, _]) // ORCHIDv2 (`2001:20::/28`) // Drone Remote ID Protocol Entity Tags (DETs) Prefix (`2001:30::/28`)` - || matches!(ipv6_addr.segments(), [0x2001, b, _, _, _, _, _, _] if b >= 0x20 && b <= 0x3F) + || matches!(ipv6_addr.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x3F).contains(&b)) )) // 6to4 (`2002::/16`) - it's not explicitly documented as globally reachable, // IANA says N/A. diff --git a/nex-core/src/lib.rs b/nex-core/src/lib.rs index 852b5b3..a8935dd 100644 --- a/nex-core/src/lib.rs +++ b/nex-core/src/lib.rs @@ -1,8 +1,6 @@ //! Core network types and helpers shared across the `nex` crates. //! Includes interface, MAC/IP, and bitfield utilities used by low-level networking code. -pub use netdev; - pub mod bitfield; pub mod interface; pub mod ip; diff --git a/nex-core/src/mac.rs b/nex-core/src/mac.rs index 350c3a2..f7c5b14 100644 --- a/nex-core/src/mac.rs +++ b/nex-core/src/mac.rs @@ -1 +1 @@ -pub use netdev::net::mac::*; +pub use mac_addr::MacAddr; diff --git a/nex-datalink/Cargo.toml b/nex-datalink/Cargo.toml index 1603140..cb61f1b 100644 --- a/nex-datalink/Cargo.toml +++ b/nex-datalink/Cargo.toml @@ -13,7 +13,6 @@ license = "MIT" [dependencies] libc = { workspace = true } bytes = { workspace = true } -netdev = { workspace = true } serde = { workspace = true, features = ["derive"], optional = true } pcap = { version = "2.0", optional = true } nex-core = { workspace = true } @@ -21,7 +20,7 @@ nex-sys = { workspace = true } futures-core = "0.3" [target.'cfg(windows)'.dependencies.windows-sys] -version = "0.59.0" +version = "0.61" features = [ "Win32_Foundation", "Win32_Networking_WinSock", @@ -31,7 +30,7 @@ features = [ ] [features] -serde = ["dep:serde", "netdev/serde"] +serde = ["dep:serde", "nex-core/serde"] pcap = ["dep:pcap"] [dev-dependencies] diff --git a/nex-datalink/src/async_io/wpcap.rs b/nex-datalink/src/async_io/wpcap.rs index f52ba29..09d86d9 100644 --- a/nex-datalink/src/async_io/wpcap.rs +++ b/nex-datalink/src/async_io/wpcap.rs @@ -122,6 +122,7 @@ impl Stream for AsyncWpcapSocketReceiver { /// Create a new asynchronous WinPcap channel. pub fn channel(network_interface: &Interface, config: Config) -> io::Result { + let read_buffer_size = config.read_buffer_size; let mut write_buffer = vec![0u8; config.write_buffer_size]; let adapter = unsafe { @@ -141,7 +142,7 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result io::Result>>, } -fn lock_capture( +fn lock_capture( capture: &Mutex>, ) -> io::Result>> { capture @@ -172,12 +172,12 @@ impl RawSender for InvalidRawSenderImpl { } } -struct RawReceiverImpl { +struct RawReceiverImpl { capture: Arc>>, read_buffer: Vec, } -impl RawReceiver for RawReceiverImpl { +impl RawReceiver for RawReceiverImpl { fn next(&mut self) -> io::Result<&[u8]> { let mut cap = lock_capture(&self.capture)?; match cap.next_packet() { @@ -206,11 +206,17 @@ pub fn interfaces() -> Vec { mac_addr: None, ipv4: Vec::new(), ipv6: Vec::new(), + ipv6_scope_ids: Vec::new(), flags: dev.flags.if_flags.bits(), + oper_state: nex_core::interface::OperState::from_if_flags( + dev.flags.if_flags.bits(), + ), transmit_speed: None, receive_speed: None, + stats: None, gateway: None, dns_servers: Vec::new(), + mtu: None, default: false, }) .collect() diff --git a/nex-datalink/src/wpcap.rs b/nex-datalink/src/wpcap.rs index b46a650..d0938bc 100644 --- a/nex-datalink/src/wpcap.rs +++ b/nex-datalink/src/wpcap.rs @@ -47,6 +47,9 @@ pub struct Config { pub read_buffer_size: usize, } +const DEFAULT_WRITE_BUFFER_SIZE: usize = 4096; +const DEFAULT_READ_BUFFER_SIZE: usize = 65536; + impl<'a> From<&'a super::Config> for Config { fn from(config: &super::Config) -> Config { Config { @@ -59,8 +62,8 @@ impl<'a> From<&'a super::Config> for Config { impl Default for Config { fn default() -> Config { Config { - write_buffer_size: 4096, - read_buffer_size: 4096, + write_buffer_size: DEFAULT_WRITE_BUFFER_SIZE, + read_buffer_size: DEFAULT_READ_BUFFER_SIZE, } } } @@ -84,29 +87,31 @@ pub fn channel(network_interface: &Interface, config: Config) -> io::Result io::Result io::Result io::Result, _write_buffer: Vec, diff --git a/nex-packet/Cargo.toml b/nex-packet/Cargo.toml index 1f989b0..b2bef75 100644 --- a/nex-packet/Cargo.toml +++ b/nex-packet/Cargo.toml @@ -18,3 +18,10 @@ rand = { workspace = true } [features] serde = ["dep:serde", "nex-core/serde", "bytes/serde"] + +[dev-dependencies] +criterion = "0.5" + +[[bench]] +name = "packet_parse" +harness = false diff --git a/nex-packet/benches/packet_parse.rs b/nex-packet/benches/packet_parse.rs new file mode 100644 index 0000000..6a990b5 --- /dev/null +++ b/nex-packet/benches/packet_parse.rs @@ -0,0 +1,68 @@ +use bytes::Bytes; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use nex_packet::{ + frame::{Frame, FrameView, ParseOption}, + packet::Packet, + tcp::TcpPacket, + udp::UdpPacket, +}; + +fn ipv4_tcp_frame() -> Bytes { + Bytes::from_static(&[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0x08, 0x00, 0x45, 0x00, 0x00, 0x30, 0x12, 0x34, 0x40, + 0x00, 64, 0x06, 0, 0, 192, 0, 2, 1, 198, 51, 100, 2, 0x04, 0xd2, 0x00, 0x50, 0, 0, 0, 1, 0, + 0, 0, 0, 0x50, 0x18, 0x20, 0x00, 0, 0, 0, 0, b'h', b'e', b'l', b'l', b'o', b'!', b'!', + b'!', + ]) +} + +fn ipv6_udp_frame() -> Bytes { + Bytes::from_static(&[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0x86, 0xdd, 0x60, 0, 0, 0, 0, 16, 17, 64, 0xfe, 0x80, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 2, 0x04, 0xd2, 0x00, 0x35, 0x00, 0x10, 0, 0, b'd', b'n', b's', b'!', 0, 1, 2, 3, + ]) +} + +fn bench_packet_parse(c: &mut Criterion) { + let mut group = c.benchmark_group("packet_parse"); + let ipv4_tcp = ipv4_tcp_frame(); + let ipv6_udp = ipv6_udp_frame(); + let tcp_segment = ipv4_tcp.slice(14 + 20..); + let udp_datagram = ipv6_udp.slice(14 + 40..); + + group.bench_function("frame_from_buf_ipv4_tcp", |b| { + b.iter(|| Frame::from_buf(&ipv4_tcp, ParseOption::default())) + }); + group.bench_function("frame_try_from_bytes_ipv4_tcp", |b| { + b.iter(|| Frame::try_from_bytes(ipv4_tcp.clone(), ParseOption::default())) + }); + group.bench_function("frame_view_from_buf_ipv4_tcp", |b| { + b.iter(|| FrameView::from_buf(&ipv4_tcp, ParseOption::default())) + }); + group.bench_function("tcp_from_buf", |b| { + b.iter(|| TcpPacket::from_buf(&tcp_segment)) + }); + group.bench_function("tcp_from_bytes", |b| { + b.iter(|| TcpPacket::from_bytes(tcp_segment.clone())) + }); + group.bench_function("udp_from_buf", |b| { + b.iter(|| UdpPacket::from_buf(&udp_datagram)) + }); + group.bench_function("udp_from_bytes", |b| { + b.iter(|| UdpPacket::from_bytes(udp_datagram.clone())) + }); + + for (name, packet) in [("ipv4_tcp", ipv4_tcp), ("ipv6_udp", ipv6_udp)] { + group.bench_with_input( + BenchmarkId::new("frame_view", name), + &packet, + |b, packet| b.iter(|| FrameView::from_buf(packet, ParseOption::default())), + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_packet_parse); +criterion_main!(benches); diff --git a/nex-packet/src/dns.rs b/nex-packet/src/dns.rs index ae1d3fd..bcf2f36 100644 --- a/nex-packet/src/dns.rs +++ b/nex-packet/src/dns.rs @@ -1,4 +1,7 @@ -use crate::packet::{GenericMutablePacket, Packet}; +use crate::{ + packet::{GenericMutablePacket, Packet}, + parse::ParseError, +}; use bytes::{BufMut, Bytes, BytesMut}; use core::str; use nex_core::bitfield::{u1, u16be, u32be}; @@ -720,6 +723,12 @@ impl DnsQueryPacket { } Ok(qname) } + + /// Parse the query name with compression-pointer validation. + pub fn try_get_qname_parsed(&self) -> Result { + decode_dns_name(&self.qname, 0).map(|(name, _)| name) + } + pub fn qname_length(&self) -> usize { self.to_bytes().iter().take_while(|w| *w != &0).count() + 1 } @@ -1013,83 +1022,11 @@ pub struct DnsPacket { impl Packet for DnsPacket { type Header = (); fn from_buf(buf: &[u8]) -> Option { - if buf.len() < 12 { - return None; - } - - let mut cursor = buf; - - // Read DNS header - let id = u16::from_be_bytes([cursor[0], cursor[1]]); - let flags = u16::from_be_bytes([cursor[2], cursor[3]]); - let query_count = u16::from_be_bytes([cursor[4], cursor[5]]); - let response_count = u16::from_be_bytes([cursor[6], cursor[7]]); - let authority_rr_count = u16::from_be_bytes([cursor[8], cursor[9]]); - let additional_rr_count = u16::from_be_bytes([cursor[10], cursor[11]]); - cursor = &cursor[12..]; - - let header = DnsHeader { - id: id.into(), - is_response: ((flags >> 15) & 0x1) as u8, - opcode: OpCode::new(((flags >> 11) & 0xF) as u8), - is_authoriative: ((flags >> 10) & 0x1) as u8, - is_truncated: ((flags >> 9) & 0x1) as u8, - is_recursion_desirable: ((flags >> 8) & 0x1) as u8, - is_recursion_available: ((flags >> 7) & 0x1) as u8, - zero_reserved: ((flags >> 6) & 0x1) as u8, - is_answer_authenticated: ((flags >> 5) & 0x1) as u8, - is_non_authenticated_data: ((flags >> 4) & 0x1) as u8, - rcode: RetCode::new((flags & 0xF) as u8), - query_count: query_count.into(), - response_count: response_count.into(), - authority_rr_count: authority_rr_count.into(), - additional_rr_count: additional_rr_count.into(), - }; - - // Parse each section, passing mutable slices - fn parse_queries(count: usize, buf: &mut &[u8]) -> Option> { - (0..count) - .map(|_| DnsQueryPacket::from_buf_mut(buf)) - .collect() - } - - fn parse_responses(count: usize, buf: &mut &[u8]) -> Option> { - let mut packets = Vec::with_capacity(count); - for _ in 0..count { - match DnsResponsePacket::from_buf_mut(buf) { - Some(pkt) => { - packets.push(pkt); - } - _ => { - break; - } - } - } - Some(packets) - } - - let mut working_buf = cursor; - - let queries = parse_queries(query_count as usize, &mut working_buf)?; - let responses = parse_responses(response_count as usize, &mut working_buf)?; - let authorities = parse_responses(authority_rr_count as usize, &mut working_buf)?; - let additionals = parse_responses(additional_rr_count as usize, &mut working_buf)?; - - // Remaining data becomes the payload - let payload = Bytes::copy_from_slice(working_buf); - - Some(Self { - header, - queries, - responses, - authorities, - additionals, - payload, - }) + Self::try_from_buf(buf).ok() } - fn from_bytes(mut bytes: Bytes) -> Option { - Self::from_buf(&mut bytes) + fn from_bytes(bytes: Bytes) -> Option { + Self::try_from_bytes(bytes).ok() } fn to_bytes(&self) -> Bytes { @@ -1157,7 +1094,7 @@ impl Packet for DnsPacket { } fn total_len(&self) -> usize { - self.header_len() + self.payload_len() + self.header_len() + self.payload.len() } fn into_parts(self) -> (Self::Header, Bytes) { @@ -1167,7 +1104,111 @@ impl Packet for DnsPacket { } } +impl DnsPacket { + /// Parse a DNS packet and return a structured error on failure. + pub fn try_from_buf(buf: &[u8]) -> Result { + if buf.len() < 12 { + return Err(ParseError::BufferTooShort { + context: "DNS packet", + minimum: 12, + actual: buf.len(), + }); + } + + let mut cursor = buf; + + // Read DNS header + let id = u16::from_be_bytes([cursor[0], cursor[1]]); + let flags = u16::from_be_bytes([cursor[2], cursor[3]]); + let query_count = u16::from_be_bytes([cursor[4], cursor[5]]); + let response_count = u16::from_be_bytes([cursor[6], cursor[7]]); + let authority_rr_count = u16::from_be_bytes([cursor[8], cursor[9]]); + let additional_rr_count = u16::from_be_bytes([cursor[10], cursor[11]]); + cursor = &cursor[12..]; + + let header = DnsHeader { + id: id.into(), + is_response: ((flags >> 15) & 0x1) as u8, + opcode: OpCode::new(((flags >> 11) & 0xF) as u8), + is_authoriative: ((flags >> 10) & 0x1) as u8, + is_truncated: ((flags >> 9) & 0x1) as u8, + is_recursion_desirable: ((flags >> 8) & 0x1) as u8, + is_recursion_available: ((flags >> 7) & 0x1) as u8, + zero_reserved: ((flags >> 6) & 0x1) as u8, + is_answer_authenticated: ((flags >> 5) & 0x1) as u8, + is_non_authenticated_data: ((flags >> 4) & 0x1) as u8, + rcode: RetCode::new((flags & 0xF) as u8), + query_count: query_count.into(), + response_count: response_count.into(), + authority_rr_count: authority_rr_count.into(), + additional_rr_count: additional_rr_count.into(), + }; + + // Parse each section, passing mutable slices + fn parse_queries(count: usize, buf: &mut &[u8]) -> Option> { + (0..count) + .map(|_| DnsQueryPacket::from_buf_mut(buf)) + .collect() + } + + fn parse_responses(count: usize, buf: &mut &[u8]) -> Option> { + let mut packets = Vec::with_capacity(count); + for _ in 0..count { + match DnsResponsePacket::from_buf_mut(buf) { + Some(pkt) => { + packets.push(pkt); + } + _ => { + break; + } + } + } + Some(packets) + } + + let mut working_buf = cursor; + + let queries = + parse_queries(query_count as usize, &mut working_buf).ok_or(ParseError::Malformed { + context: "DNS query section", + })?; + let responses = parse_responses(response_count as usize, &mut working_buf).ok_or( + ParseError::Malformed { + context: "DNS answer section", + }, + )?; + let authorities = parse_responses(authority_rr_count as usize, &mut working_buf).ok_or( + ParseError::Malformed { + context: "DNS authority section", + }, + )?; + let additionals = parse_responses(additional_rr_count as usize, &mut working_buf).ok_or( + ParseError::Malformed { + context: "DNS additional section", + }, + )?; + + // Remaining data becomes the payload + let payload = Bytes::copy_from_slice(working_buf); + + Ok(Self { + header, + queries, + responses, + authorities, + additionals, + payload, + }) + } + + /// Parse a DNS packet from owned bytes and return a structured error on failure. + pub fn try_from_bytes(bytes: Bytes) -> Result { + Self::try_from_buf(&bytes) + } +} + /// Represents a DNS name +#[derive(Clone, Debug, PartialEq, Eq)] pub struct DnsName(String); impl DnsName { @@ -1203,6 +1244,11 @@ impl DnsName { pub fn labels(&self) -> Vec<&str> { self.0.split('.').collect() } + + /// Parses a DNS name with compression-pointer validation. + pub fn try_from_bytes(buf: &[u8]) -> Result { + decode_dns_name(buf, 0).map(|(name, _)| DnsName(name)) + } } impl std::fmt::Display for DnsName { @@ -1211,6 +1257,97 @@ impl std::fmt::Display for DnsName { } } +const DNS_MAX_COMPRESSION_DEPTH: usize = 16; + +fn decode_dns_name(buf: &[u8], start: usize) -> Result<(String, usize), ParseError> { + let mut labels = Vec::new(); + let mut pos = start; + let mut consumed = 0usize; + let mut jumped = false; + let mut visited = Vec::new(); + let mut depth = 0usize; + + loop { + if pos >= buf.len() { + return Err(ParseError::Truncated { + context: "DNS name", + expected: pos + 1, + actual: buf.len(), + }); + } + + let len = buf[pos]; + if !jumped { + consumed += 1; + } + + if len == 0 { + break; + } + + if (len & 0xC0) == 0xC0 { + if pos + 1 >= buf.len() { + return Err(ParseError::Truncated { + context: "DNS compression pointer", + expected: pos + 2, + actual: buf.len(), + }); + } + let pointer = (((len & 0x3F) as usize) << 8) | buf[pos + 1] as usize; + if pointer >= buf.len() { + return Err(ParseError::InvalidCompression { + context: "DNS compression pointer", + }); + } + if visited.contains(&pointer) { + return Err(ParseError::CompressionLoop { + context: "DNS name", + }); + } + visited.push(pointer); + depth += 1; + if depth > DNS_MAX_COMPRESSION_DEPTH { + return Err(ParseError::CompressionLoop { + context: "DNS name", + }); + } + if !jumped { + consumed += 1; + } + pos = pointer; + jumped = true; + continue; + } + + if (len & 0xC0) != 0 { + return Err(ParseError::InvalidCompression { + context: "DNS label encoding", + }); + } + + let label_len = len as usize; + pos += 1; + if pos + label_len > buf.len() { + return Err(ParseError::Truncated { + context: "DNS label", + expected: pos + label_len, + actual: buf.len(), + }); + } + let label = + str::from_utf8(&buf[pos..pos + label_len]).map_err(|_| ParseError::InvalidUtf8 { + context: "DNS label", + })?; + labels.push(label.to_string()); + if !jumped { + consumed += label_len; + } + pos += label_len; + } + + Ok((labels.join("."), consumed)) +} + /// Represents a mutable DNS packet. pub type MutableDnsPacket<'a> = GenericMutablePacket<'a, DnsPacket>; @@ -1316,4 +1453,28 @@ mod tests { assert_eq!(frozen.header.id, 0x1234); assert_eq!(frozen.payload[0], 0xaa); } + + #[test] + fn dns_name_detects_compression_loop() { + let err = DnsName::try_from_bytes(&[0xc0, 0x00]).expect_err("loop should fail"); + assert!(matches!(err, ParseError::CompressionLoop { .. })); + } + + #[test] + fn dns_query_try_get_qname_parsed_supports_compression() { + let query = DnsQueryPacket { + qname: vec![ + 0x03, b'w', b'w', b'w', 0xc0, 0x06, 0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', + 0x03, b'c', b'o', b'm', 0x00, + ], + qtype: DnsType::A, + qclass: DnsClass::IN, + payload: Bytes::new(), + }; + + assert_eq!( + query.try_get_qname_parsed().expect("compressed name"), + "www.example.com" + ); + } } diff --git a/nex-packet/src/ethernet.rs b/nex-packet/src/ethernet.rs index caecaf2..93b1c75 100644 --- a/nex-packet/src/ethernet.rs +++ b/nex-packet/src/ethernet.rs @@ -7,7 +7,10 @@ use nex_core::mac::MacAddr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::packet::{MutablePacket, Packet}; +use crate::{ + packet::{MutablePacket, Packet}, + parse::ParseError, +}; /// Represents the Ethernet header length. pub const ETHERNET_HEADER_LEN: usize = 14; @@ -194,27 +197,10 @@ impl Packet for EthernetPacket { type Header = EthernetHeader; fn from_buf(bytes: &[u8]) -> Option { - if bytes.len() < ETHERNET_HEADER_LEN { - return None; - } - let destination = - MacAddr::from_octets([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]]); - let source = - MacAddr::from_octets([bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11]]); - let ethertype = EtherType::new(u16::from_be_bytes([bytes[12], bytes[13]])); - let payload = Bytes::copy_from_slice(&bytes[ETHERNET_HEADER_LEN..]); - - Some(EthernetPacket { - header: EthernetHeader { - destination, - source, - ethertype, - }, - payload, - }) + Self::try_from_buf(bytes).ok() } fn from_bytes(bytes: Bytes) -> Option { - Self::from_buf(&bytes) + Self::try_from_bytes(bytes).ok() } fn to_bytes(&self) -> Bytes { let mut buf = Vec::with_capacity(ETHERNET_HEADER_LEN + self.payload.len()); @@ -271,6 +257,58 @@ impl EthernetPacket { None } } + + /// Parse an Ethernet packet and return a structured error on failure. + pub fn try_from_buf(bytes: &[u8]) -> Result { + if bytes.len() < ETHERNET_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "Ethernet packet", + minimum: ETHERNET_HEADER_LEN, + actual: bytes.len(), + }); + } + + let destination = + MacAddr::from_octets([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]]); + let source = + MacAddr::from_octets([bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11]]); + let ethertype = EtherType::new(u16::from_be_bytes([bytes[12], bytes[13]])); + + Ok(EthernetPacket { + header: EthernetHeader { + destination, + source, + ethertype, + }, + payload: Bytes::copy_from_slice(&bytes[ETHERNET_HEADER_LEN..]), + }) + } + + /// Parse an Ethernet packet from owned bytes while preserving the payload slice. + pub fn try_from_bytes(bytes: Bytes) -> Result { + if bytes.len() < ETHERNET_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "Ethernet packet", + minimum: ETHERNET_HEADER_LEN, + actual: bytes.len(), + }); + } + + let destination = + MacAddr::from_octets([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]]); + let source = + MacAddr::from_octets([bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11]]); + let ethertype = EtherType::new(u16::from_be_bytes([bytes[12], bytes[13]])); + + Ok(EthernetPacket { + header: EthernetHeader { + destination, + source, + ethertype, + }, + payload: bytes.slice(ETHERNET_HEADER_LEN..), + }) + } } impl fmt::Display for EthernetPacket { diff --git a/nex-packet/src/frame.rs b/nex-packet/src/frame.rs index 412d357..077c6b5 100644 --- a/nex-packet/src/frame.rs +++ b/nex-packet/src/frame.rs @@ -10,6 +10,7 @@ use crate::{ ipv4::{Ipv4Header, Ipv4Packet}, ipv6::{Ipv6Header, Ipv6Packet}, packet::Packet, + parse::ParseError, tcp::{TcpHeader, TcpPacket}, udp::{UdpHeader, UdpPacket}, }; @@ -67,36 +68,69 @@ pub struct Frame { } impl Frame { + /// Parse a frame from a raw buffer. + /// + /// Unknown or currently unsupported payloads are preserved in `payload` + /// so callers can still inspect the raw bytes. pub fn from_buf(packet: &[u8], option: ParseOption) -> Option { - let mut frame = Frame { - datalink: None, - ip: None, - transport: None, - payload: Bytes::new(), - packet_len: packet.len(), - }; - - let ethernet_packet = if option.from_ip_packet { - create_dummy_ethernet_packet(packet, option.offset)? - } else { - EthernetPacket::from_buf(packet)? - }; - - let ether_type = ethernet_packet.get_ethertype(); - let (ether_header, ether_payload) = ethernet_packet.into_parts(); - frame.datalink = Some(DatalinkLayer { - ethernet: Some(ether_header), - arp: None, - }); - - match ether_type { - EtherType::Ipv4 => parse_ipv4_packet(ether_payload, &mut frame), - EtherType::Ipv6 => parse_ipv6_packet(ether_payload, &mut frame), - EtherType::Arp => parse_arp_packet(ether_payload, &mut frame), - _ => {} - } + Self::try_from_buf(packet, option).ok() + } + + /// Parse a frame and return a structured error on failure. + pub fn try_from_buf(packet: &[u8], option: ParseOption) -> Result { + parse_frame_from_bytes(Bytes::copy_from_slice(packet), option, false) + } + + /// Parse a frame from owned bytes while preserving payload slices when possible. + pub fn try_from_bytes(packet: Bytes, option: ParseOption) -> Result { + parse_frame_from_bytes(packet, option, false) + } + + /// Parse a frame using validation-oriented strict IP parsing. + pub fn try_from_buf_strict(packet: &[u8], option: ParseOption) -> Result { + parse_frame_from_bytes(Bytes::copy_from_slice(packet), option, true) + } + + /// Parse a frame from owned bytes using validation-oriented strict IP parsing. + pub fn try_from_bytes_strict(packet: Bytes, option: ParseOption) -> Result { + parse_frame_from_bytes(packet, option, true) + } + + /// Parse a frame using validation-oriented strict IP parsing. + pub fn from_buf_strict(packet: &[u8], option: ParseOption) -> Option { + Self::try_from_buf_strict(packet, option).ok() + } +} + +/// Borrowed frame view for zero-copy packet inspection on hot paths. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FrameView<'a> { + pub datalink: Option, + pub ip: Option, + pub transport: Option, + pub payload: &'a [u8], + pub packet_len: usize, +} + +impl<'a> FrameView<'a> { + /// Parse a frame view without allocating payload storage. + pub fn from_buf(packet: &'a [u8], option: ParseOption) -> Option { + Self::try_from_buf(packet, option).ok() + } - Some(frame) + /// Parse a frame view and return a structured error on failure. + pub fn try_from_buf(packet: &'a [u8], option: ParseOption) -> Result { + let offset = option.offset; + let from_ip_packet = option.from_ip_packet; + let frame = Frame::try_from_buf(packet, option)?; + let payload = find_payload_slice(packet, &frame, offset, from_ip_packet); + Ok(FrameView { + datalink: frame.datalink, + ip: frame.ip, + transport: frame.transport, + payload, + packet_len: frame.packet_len, + }) } } @@ -107,9 +141,9 @@ pub fn create_dummy_ethernet_packet(packet: &[u8], offset: usize) -> Option Option bool { + if packet.len() < 20 { + return false; + } + let version = packet[0] >> 4; + let header_length = (packet[0] & 0x0f) as usize; + version == 4 && header_length >= 5 && header_length * 4 <= packet.len() +} + +fn is_likely_ipv6_packet(packet: &[u8]) -> bool { + if packet.len() < 40 { + return false; + } + (packet[0] >> 4) == 6 +} + fn parse_arp_packet(packet: Bytes, frame: &mut Frame) { match ArpPacket::from_buf(&packet) { Some(arp_packet) => { @@ -143,9 +193,14 @@ fn parse_arp_packet(packet: Bytes, frame: &mut Frame) { } } -fn parse_ipv4_packet(packet: Bytes, frame: &mut Frame) { - match Ipv4Packet::from_bytes(packet) { - Some(ipv4_packet) => { +fn parse_ipv4_packet(packet: Bytes, frame: &mut Frame, strict: bool) -> Result<(), ParseError> { + let parsed = if strict { + Ipv4Packet::try_from_bytes_strict(packet) + } else { + Ipv4Packet::try_from_bytes(packet) + }; + match parsed { + Ok(ipv4_packet) => { let (header, payload) = ipv4_packet.into_parts(); let proto = header.next_level_protocol; frame.ip = Some(IpLayer { @@ -168,21 +223,29 @@ fn parse_ipv4_packet(packet: Bytes, frame: &mut Frame) { frame.payload = payload; } } + Ok(()) } - None => { + Err(err) if strict => Err(err), + Err(_) => { frame.ip = Some(IpLayer { ipv4: None, ipv6: None, icmp: None, icmpv6: None, }); + Ok(()) } } } -fn parse_ipv6_packet(packet: Bytes, frame: &mut Frame) { - match Ipv6Packet::from_bytes(packet) { - Some(ipv6_packet) => { +fn parse_ipv6_packet(packet: Bytes, frame: &mut Frame, strict: bool) -> Result<(), ParseError> { + let parsed = if strict { + Ipv6Packet::try_from_bytes_strict(packet) + } else { + Ipv6Packet::try_from_bytes(packet) + }; + match parsed { + Ok(ipv6_packet) => { let (header, payload) = ipv6_packet.into_parts(); let proto = header.next_header; frame.ip = Some(IpLayer { @@ -205,14 +268,17 @@ fn parse_ipv6_packet(packet: Bytes, frame: &mut Frame) { frame.payload = payload; } } + Ok(()) } - None => { + Err(err) if strict => Err(err), + Err(_) => { frame.ip = Some(IpLayer { ipv4: None, ipv6: None, icmp: None, icmpv6: None, }); + Ok(()) } } } @@ -257,6 +323,148 @@ fn parse_udp_packet(packet: Bytes, frame: &mut Frame) { } } +fn parse_frame_from_bytes( + packet: Bytes, + option: ParseOption, + strict: bool, +) -> Result { + let packet_len = packet.len(); + let mut frame = Frame { + datalink: None, + ip: None, + transport: None, + payload: Bytes::new(), + packet_len, + }; + + let ethernet_packet = if option.from_ip_packet { + create_dummy_ethernet_packet(&packet, option.offset).ok_or(ParseError::Malformed { + context: "Frame dummy Ethernet classification", + })? + } else { + EthernetPacket::try_from_bytes(packet)? + }; + + let ether_type = ethernet_packet.get_ethertype(); + let (ether_header, ether_payload) = ethernet_packet.into_parts(); + frame.datalink = Some(DatalinkLayer { + ethernet: Some(ether_header), + arp: None, + }); + + match ether_type { + EtherType::Ipv4 => parse_ipv4_packet(ether_payload, &mut frame, strict)?, + EtherType::Ipv6 => parse_ipv6_packet(ether_payload, &mut frame, strict)?, + EtherType::Arp => parse_arp_packet(ether_payload, &mut frame), + _ => frame.payload = ether_payload, + } + + Ok(frame) +} + +fn find_payload_slice<'a>( + packet: &'a [u8], + frame: &Frame, + offset: usize, + from_ip_packet: bool, +) -> &'a [u8] { + let start = if from_ip_packet { offset } else { 14 }; + let available = packet.get(start..).unwrap_or(&[]); + let payload_len = frame.payload.len(); + if payload_len > available.len() { + return &[]; + } + &available[available.len() - payload_len..] +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ethernet::ETHERNET_HEADER_LEN; + + #[test] + fn frame_preserves_unknown_ethertype_payload() { + let payload = [0xde, 0xad, 0xbe, 0xef]; + let mut raw = vec![0u8; ETHERNET_HEADER_LEN + payload.len()]; + raw[12] = 0x88; + raw[13] = 0xb5; + raw[ETHERNET_HEADER_LEN..].copy_from_slice(&payload); + + let frame = Frame::from_buf(&raw, ParseOption::default()).expect("frame"); + + assert_eq!(frame.payload, Bytes::from(payload.to_vec())); + assert!(frame.ip.is_none()); + assert!(frame.transport.is_none()); + } + + #[test] + fn frame_keeps_known_ethertype_parsing_behavior() { + let mut raw = vec![0u8; ETHERNET_HEADER_LEN + 20 + 8 + 4]; + raw[12] = 0x08; + raw[13] = 0x00; + raw[14] = 0x45; + raw[15] = 0x00; + raw[16] = 0x00; + raw[17] = 0x20; + raw[18] = 0x00; + raw[19] = 0x01; + raw[20] = 0x00; + raw[21] = 0x00; + raw[22] = 64; + raw[23] = IpNextProtocol::Udp.value(); + raw[24] = 0; + raw[25] = 0; + raw[26] = 192; + raw[27] = 0; + raw[28] = 2; + raw[29] = 1; + raw[30] = 198; + raw[31] = 51; + raw[32] = 100; + raw[33] = 2; + raw[34] = 0x04; + raw[35] = 0xd2; + raw[36] = 0x00; + raw[37] = 0x35; + raw[38] = 0x00; + raw[39] = 0x0c; + raw[40] = 0x00; + raw[41] = 0x00; + raw[42..46].copy_from_slice(&[1, 2, 3, 4]); + + let frame = Frame::from_buf(&raw, ParseOption::default()).expect("frame"); + + assert_eq!( + frame + .ip + .as_ref() + .and_then(|ip| ip.ipv4.as_ref()) + .map(|h| h.version), + Some(4) + ); + assert_eq!( + frame + .transport + .as_ref() + .and_then(|tr| tr.udp.as_ref()) + .map(|h| h.destination), + Some(53) + ); + assert_eq!(frame.payload, Bytes::from_static(&[1, 2, 3, 4])); + } + + #[test] + fn dummy_ethernet_packet_uses_lightweight_ip_detection() { + let ipv4 = [ + 0x45, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 64, 17, 0, 0, 127, 0, 0, 1, 127, 0, 0, + 1, + ]; + let packet = create_dummy_ethernet_packet(&ipv4, 0).expect("dummy ethernet"); + assert_eq!(packet.header.ethertype, EtherType::Ipv4); + assert_eq!(packet.payload, Bytes::from(ipv4.to_vec())); + } +} + fn parse_icmp_packet(packet: Bytes, frame: &mut Frame) { match IcmpPacket::from_bytes(packet.clone()) { Some(icmp_packet) => { diff --git a/nex-packet/src/ipv4.rs b/nex-packet/src/ipv4.rs index c2c6c54..4376842 100644 --- a/nex-packet/src/ipv4.rs +++ b/nex-packet/src/ipv4.rs @@ -4,6 +4,7 @@ use crate::{ checksum::{ChecksumMode, ChecksumState}, ip::IpNextProtocol, packet::{MutablePacket, Packet}, + parse::ParseError, util, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -211,113 +212,11 @@ impl Packet for Ipv4Packet { type Header = Ipv4Header; fn from_buf(bytes: &[u8]) -> Option { - if bytes.len() < IPV4_HEADER_LEN { - return None; - } - - let version = (bytes[0] & 0xF0) >> 4; - let header_length = (bytes[0] & 0x0F) as usize; - let total_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize; - let total_length = if total_length > bytes.len() { - // fallback - bytes.len() - } else { - total_length - }; - - if header_length < 5 { - return None; - } - - let ihl_bytes = header_length * 4; - if ihl_bytes < IPV4_HEADER_LEN || ihl_bytes > total_length { - return None; - } - let payload = Bytes::copy_from_slice(&bytes[ihl_bytes..total_length]); - - let mut options = Vec::new(); - let mut i = IPV4_HEADER_LEN; - - while i < ihl_bytes { - let b = bytes[i]; - let copied = (b >> 7) & 0x01; - let class = (b >> 5) & 0x03; - let number = Ipv4OptionType::new(b & 0b0001_1111); - - match number { - Ipv4OptionType::EOL => { - options.push(Ipv4OptionPacket { - header: Ipv4OptionHeader { - copied, - class, - number, - length: None, - }, - data: Bytes::new(), - }); - break; - } - Ipv4OptionType::NOP => { - options.push(Ipv4OptionPacket { - header: Ipv4OptionHeader { - copied, - class, - number, - length: None, - }, - data: Bytes::new(), - }); - i += 1; - } - _ => { - if i + 2 > ihl_bytes { - break; - } - let len = bytes[i + 1] as usize; - if len < 2 || i + len > ihl_bytes { - break; - } - - let data = Bytes::copy_from_slice(&bytes[i + 2..i + len]); - - options.push(Ipv4OptionPacket { - header: Ipv4OptionHeader { - copied, - class, - number, - length: Some(len as u8), - }, - data, - }); - - i += len; - } - } - } - - Some(Self { - header: Ipv4Header { - version: version as u4, - header_length: header_length as u4, - dscp: (bytes[1] >> 2) as u6, - ecn: (bytes[1] & 0x03) as u2, - total_length: u16::from_be_bytes([bytes[2], bytes[3]]) as u16be, - identification: u16::from_be_bytes([bytes[4], bytes[5]]) as u16be, - flags: (bytes[6] >> 5) as u3, - fragment_offset: ((u16::from_be_bytes([bytes[6], bytes[7]])) & 0x1FFF) as u13be, - ttl: bytes[8], - next_level_protocol: IpNextProtocol::new(bytes[9]), - checksum: u16::from_be_bytes([bytes[10], bytes[11]]) as u16be, - source: Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15]), - destination: Ipv4Addr::new(bytes[16], bytes[17], bytes[18], bytes[19]), - options, - }, - payload, - }) + Self::try_from_buf(bytes).ok() } fn from_bytes(bytes: Bytes) -> Option { - Self::from_buf(&bytes) + Self::try_from_bytes(bytes).ok() } fn to_bytes(&self) -> Bytes { @@ -405,12 +304,209 @@ impl Packet for Ipv4Packet { } impl Ipv4Packet { + /// Parse an IPv4 packet and return a structured error on failure. + pub fn try_from_buf(bytes: &[u8]) -> Result { + parse_ipv4_from_slice(bytes, false) + } + + /// Parse an IPv4 packet from owned bytes while preserving payload slices when possible. + pub fn try_from_bytes(bytes: Bytes) -> Result { + parse_ipv4_from_bytes(bytes, false) + } + + /// Parse an IPv4 packet using validation-oriented strict checks. + pub fn try_from_buf_strict(bytes: &[u8]) -> Result { + parse_ipv4_from_slice(bytes, true) + } + + /// Parse an IPv4 packet from owned bytes using validation-oriented strict checks. + pub fn try_from_bytes_strict(bytes: Bytes) -> Result { + parse_ipv4_from_bytes(bytes, true) + } + + /// Parse an IPv4 packet using validation-oriented strict checks. + pub fn from_buf_strict(bytes: &[u8]) -> Option { + Self::try_from_buf_strict(bytes).ok() + } + + /// Parse an IPv4 packet from owned bytes using validation-oriented strict checks. + pub fn from_bytes_strict(bytes: Bytes) -> Option { + Self::try_from_bytes_strict(bytes).ok() + } + pub fn with_computed_checksum(mut self) -> Self { self.header.checksum = checksum(&self); self } } +fn parse_ipv4_from_slice(bytes: &[u8], strict: bool) -> Result { + parse_ipv4_parts(bytes, strict, |range| Bytes::copy_from_slice(&bytes[range])) +} + +fn parse_ipv4_from_bytes(bytes: Bytes, strict: bool) -> Result { + parse_ipv4_parts(&bytes, strict, |range| bytes.slice(range)) +} + +fn parse_ipv4_parts( + bytes: &[u8], + strict: bool, + mut slice_bytes: F, +) -> Result +where + F: FnMut(std::ops::Range) -> Bytes, +{ + if bytes.len() < IPV4_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "IPv4 packet", + minimum: IPV4_HEADER_LEN, + actual: bytes.len(), + }); + } + + let version = (bytes[0] & 0xF0) >> 4; + if version != 4 { + return Err(ParseError::Malformed { + context: "IPv4 packet version", + }); + } + + let header_length = (bytes[0] & 0x0F) as usize; + if header_length < 5 { + return Err(ParseError::InvalidLength { + context: "IPv4 header length", + value: header_length, + }); + } + + let ihl_bytes = header_length * IPV4_HEADER_LENGTH_BYTE_UNITS; + if ihl_bytes < IPV4_HEADER_LEN || ihl_bytes > bytes.len() { + return Err(ParseError::Truncated { + context: "IPv4 header", + expected: ihl_bytes, + actual: bytes.len(), + }); + } + + let declared_total_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize; + let effective_declared_total_length = if declared_total_length == 0 { + // Some offloaded captures report a zero IPv4 total length even though the + // full packet bytes are present in the capture buffer. Treat those as + // "use the captured buffer length" for non-strict parsing. + bytes.len() + } else { + declared_total_length + }; + + if effective_declared_total_length < ihl_bytes { + return Err(ParseError::InvalidLength { + context: "IPv4 total length", + value: declared_total_length, + }); + } + + let total_length = if strict { + if effective_declared_total_length > bytes.len() { + return Err(ParseError::Truncated { + context: "IPv4 packet", + expected: effective_declared_total_length, + actual: bytes.len(), + }); + } + effective_declared_total_length + } else { + effective_declared_total_length.min(bytes.len()) + }; + + let mut options = Vec::new(); + let mut i = IPV4_HEADER_LEN; + while i < ihl_bytes { + let b = bytes[i]; + let copied = (b >> 7) & 0x01; + let class = (b >> 5) & 0x03; + let number = Ipv4OptionType::new(b & 0b0001_1111); + + match number { + Ipv4OptionType::EOL => { + options.push(Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied, + class, + number, + length: None, + }, + data: Bytes::new(), + }); + break; + } + Ipv4OptionType::NOP => { + options.push(Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied, + class, + number, + length: None, + }, + data: Bytes::new(), + }); + i += 1; + } + _ => { + if i + 2 > ihl_bytes { + if strict { + return Err(ParseError::Malformed { + context: "IPv4 options", + }); + } + break; + } + let len = bytes[i + 1] as usize; + if len < 2 || i + len > ihl_bytes { + if strict { + return Err(ParseError::InvalidLength { + context: "IPv4 option length", + value: len, + }); + } + break; + } + + options.push(Ipv4OptionPacket { + header: Ipv4OptionHeader { + copied, + class, + number, + length: Some(len as u8), + }, + data: slice_bytes(i + 2..i + len), + }); + + i += len; + } + } + } + + Ok(Ipv4Packet { + header: Ipv4Header { + version: version as u4, + header_length: header_length as u4, + dscp: (bytes[1] >> 2) as u6, + ecn: (bytes[1] & 0x03) as u2, + total_length: total_length as u16be, + identification: u16::from_be_bytes([bytes[4], bytes[5]]) as u16be, + flags: (bytes[6] >> 5) as u3, + fragment_offset: ((u16::from_be_bytes([bytes[6], bytes[7]])) & 0x1FFF) as u13be, + ttl: bytes[8], + next_level_protocol: IpNextProtocol::new(bytes[9]), + checksum: u16::from_be_bytes([bytes[10], bytes[11]]) as u16be, + source: Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15]), + destination: Ipv4Addr::new(bytes[16], bytes[17], bytes[18], bytes[19]), + options, + }, + payload: slice_bytes(ihl_bytes..total_length), + }) +} + /// Represents a mutable IPv4 packet. pub struct MutableIpv4Packet<'a> { buffer: &'a mut [u8], @@ -985,4 +1081,30 @@ mod tests { assert_eq!(recomputed, packet.get_checksum()); assert!(!packet.is_checksum_dirty()); } + + #[test] + fn ipv4_try_from_buf_reports_strict_truncation() { + let raw = [ + 0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 64, 17, 0, 0, 127, 0, 0, 1, 127, 0, 0, + 1, 1, 2, 3, 4, + ]; + + let err = Ipv4Packet::try_from_buf_strict(&raw).expect_err("strict parse should fail"); + assert!(matches!(err, ParseError::Truncated { .. })); + assert!(Ipv4Packet::from_buf(&raw).is_some()); + } + + #[test] + fn ipv4_zero_total_length_uses_captured_length() { + let raw = Bytes::from_static(&[ + 0x45, 0x00, 0x00, 0x00, // total length reported as zero + 0x68, 0x23, 0x40, 0x00, 0x80, 0x06, 0x00, 0x00, 192, 168, 10, 113, 192, 168, 10, 10, + 0xde, 0xad, 0xbe, 0xef, + ]); + + let packet = Ipv4Packet::from_bytes(raw.clone()).expect("TSO-style packet should parse"); + assert_eq!(packet.header.total_length as usize, raw.len()); + assert_eq!(packet.payload.len(), raw.len() - IPV4_HEADER_LEN); + assert_eq!(packet.payload, Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef])); + } } diff --git a/nex-packet/src/ipv6.rs b/nex-packet/src/ipv6.rs index ff4ce05..055cf9c 100644 --- a/nex-packet/src/ipv6.rs +++ b/nex-packet/src/ipv6.rs @@ -1,5 +1,6 @@ use crate::ip::IpNextProtocol; use crate::packet::{MutablePacket, Packet}; +use crate::parse::ParseError; use bytes::{BufMut, Bytes, BytesMut}; use std::net::Ipv6Addr; @@ -32,151 +33,10 @@ impl Packet for Ipv6Packet { type Header = Ipv6Header; fn from_buf(bytes: &[u8]) -> Option { - if bytes.len() < IPV6_HEADER_LEN { - return None; - } - - // --- Parse the header section --- - let version_traffic_flow = &bytes[..4]; - let version = version_traffic_flow[0] >> 4; - let traffic_class = - ((version_traffic_flow[0] & 0x0F) << 4) | (version_traffic_flow[1] >> 4); - let flow_label = u32::from(version_traffic_flow[1] & 0x0F) << 16 - | u32::from(version_traffic_flow[2]) << 8 - | u32::from(version_traffic_flow[3]); - - let payload_length = u16::from_be_bytes([bytes[4], bytes[5]]); - let mut next_header = IpNextProtocol::new(bytes[6]); - let hop_limit = bytes[7]; - - let source = Ipv6Addr::from(<[u8; 16]>::try_from(&bytes[8..24]).ok()?); - let destination = Ipv6Addr::from(<[u8; 16]>::try_from(&bytes[24..40]).ok()?); - - let header = Ipv6Header { - version, - traffic_class, - flow_label, - payload_length, - next_header, - hop_limit, - source, - destination, - }; - - // --- Walk through the extension headers --- - let mut offset = IPV6_HEADER_LEN; - let mut extensions = Vec::new(); - - loop { - match next_header { - IpNextProtocol::Hopopt - | IpNextProtocol::Ipv6Route - | IpNextProtocol::Ipv6Frag - | IpNextProtocol::Ipv6Opts => { - if offset + 2 > bytes.len() { - return None; - } - - let nh = IpNextProtocol::new(bytes[offset]); - let ext_len = bytes[offset + 1] as usize; - - match next_header { - IpNextProtocol::Hopopt | IpNextProtocol::Ipv6Opts => { - let total_len = 8 + ext_len * 8; - if offset + total_len > bytes.len() { - return None; - } - - let data = - Bytes::copy_from_slice(&bytes[offset + 2..offset + total_len]); - let ext = match next_header { - IpNextProtocol::Hopopt => { - Ipv6ExtensionHeader::HopByHop { next: nh, data } - } - IpNextProtocol::Ipv6Opts => { - Ipv6ExtensionHeader::Destination { next: nh, data } - } - _ => Ipv6ExtensionHeader::Raw { - next: nh, - raw: Bytes::copy_from_slice(&bytes[offset..offset + total_len]), - }, - }; - - extensions.push(ext); - next_header = nh; - offset += total_len; - } - - IpNextProtocol::Ipv6Route => { - if offset + 4 > bytes.len() { - return None; - } - - let routing_type = bytes[offset + 2]; - let segments_left = bytes[offset + 3]; - let total_len = 8 + ext_len * 8; - if offset + total_len > bytes.len() { - return None; - } - - let data = - Bytes::copy_from_slice(&bytes[offset + 4..offset + total_len]); - extensions.push(Ipv6ExtensionHeader::Routing { - next: nh, - routing_type, - segments_left, - data, - }); - - next_header = nh; - offset += total_len; - } - - IpNextProtocol::Ipv6Frag => { - if offset + 8 > bytes.len() { - return None; - } - - //let reserved = bytes[offset + 1]; - let frag_off_flags = - u16::from_be_bytes([bytes[offset + 2], bytes[offset + 3]]); - let offset_val = frag_off_flags >> 3; - let more = (frag_off_flags & 0x1) != 0; - let id = u32::from_be_bytes([ - bytes[offset + 4], - bytes[offset + 5], - bytes[offset + 6], - bytes[offset + 7], - ]); - - extensions.push(Ipv6ExtensionHeader::Fragment { - next: nh, - offset: offset_val, - more, - id, - }); - - next_header = nh; - offset += 8; - } - - _ => break, - } - } - - _ => break, - } - } - - let payload = Bytes::copy_from_slice(&bytes[offset..]); - Some(Ipv6Packet { - header, - extensions, - payload, - }) + Self::try_from_buf(bytes).ok() } fn from_bytes(bytes: Bytes) -> Option { - Self::from_buf(&bytes) + Self::try_from_bytes(bytes).ok() } fn to_bytes(&self) -> Bytes { @@ -289,6 +149,26 @@ impl Packet for Ipv6Packet { } impl Ipv6Packet { + /// Parse an IPv6 packet and return a structured error on failure. + pub fn try_from_buf(bytes: &[u8]) -> Result { + parse_ipv6_from_slice(bytes, false) + } + + /// Parse an IPv6 packet from owned bytes while preserving payload slices when possible. + pub fn try_from_bytes(bytes: Bytes) -> Result { + parse_ipv6_from_bytes(bytes, false) + } + + /// Parse an IPv6 packet using validation-oriented strict checks. + pub fn try_from_buf_strict(bytes: &[u8]) -> Result { + parse_ipv6_from_slice(bytes, true) + } + + /// Parse an IPv6 packet from owned bytes using validation-oriented strict checks. + pub fn try_from_bytes_strict(bytes: Bytes) -> Result { + parse_ipv6_from_bytes(bytes, true) + } + pub fn total_len(&self) -> usize { IPV6_HEADER_LEN + self.extensions.iter().map(|ext| ext.len()).sum::() @@ -299,6 +179,187 @@ impl Ipv6Packet { } } +fn parse_ipv6_from_slice(bytes: &[u8], strict: bool) -> Result { + parse_ipv6_parts(bytes, strict, |range| Bytes::copy_from_slice(&bytes[range])) +} + +fn parse_ipv6_from_bytes(bytes: Bytes, strict: bool) -> Result { + parse_ipv6_parts(&bytes, strict, |range| bytes.slice(range)) +} + +fn parse_ipv6_parts( + bytes: &[u8], + strict: bool, + mut slice_bytes: F, +) -> Result +where + F: FnMut(std::ops::Range) -> Bytes, +{ + if bytes.len() < IPV6_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "IPv6 packet", + minimum: IPV6_HEADER_LEN, + actual: bytes.len(), + }); + } + + let version_traffic_flow = &bytes[..4]; + let version = version_traffic_flow[0] >> 4; + if version != 6 { + return Err(ParseError::Malformed { + context: "IPv6 packet version", + }); + } + let traffic_class = ((version_traffic_flow[0] & 0x0F) << 4) | (version_traffic_flow[1] >> 4); + let flow_label = u32::from(version_traffic_flow[1] & 0x0F) << 16 + | u32::from(version_traffic_flow[2]) << 8 + | u32::from(version_traffic_flow[3]); + let payload_length = u16::from_be_bytes([bytes[4], bytes[5]]); + let mut next_header = IpNextProtocol::new(bytes[6]); + let hop_limit = bytes[7]; + let source = + Ipv6Addr::from( + <[u8; 16]>::try_from(&bytes[8..24]).map_err(|_| ParseError::Malformed { + context: "IPv6 source address", + })?, + ); + let destination = Ipv6Addr::from(<[u8; 16]>::try_from(&bytes[24..40]).map_err(|_| { + ParseError::Malformed { + context: "IPv6 destination address", + } + })?); + + let header = Ipv6Header { + version, + traffic_class, + flow_label, + payload_length, + next_header, + hop_limit, + source, + destination, + }; + + let declared_total_len = IPV6_HEADER_LEN + payload_length as usize; + if strict && declared_total_len > bytes.len() { + return Err(ParseError::Truncated { + context: "IPv6 payload", + expected: declared_total_len, + actual: bytes.len(), + }); + } + let available_end = declared_total_len.min(bytes.len()); + + let mut offset = IPV6_HEADER_LEN; + let mut extensions = Vec::new(); + loop { + match next_header { + IpNextProtocol::Hopopt + | IpNextProtocol::Ipv6Route + | IpNextProtocol::Ipv6Frag + | IpNextProtocol::Ipv6Opts => { + if offset + 2 > available_end { + return Err(ParseError::Truncated { + context: "IPv6 extension header", + expected: offset + 2, + actual: available_end, + }); + } + + let nh = IpNextProtocol::new(bytes[offset]); + let ext_len = bytes[offset + 1] as usize; + match next_header { + IpNextProtocol::Hopopt | IpNextProtocol::Ipv6Opts => { + let total_len = 8 + ext_len * 8; + if offset + total_len > available_end { + return Err(ParseError::Truncated { + context: "IPv6 extension header", + expected: offset + total_len, + actual: available_end, + }); + } + let data = slice_bytes(offset + 2..offset + total_len); + let ext = match next_header { + IpNextProtocol::Hopopt => { + Ipv6ExtensionHeader::HopByHop { next: nh, data } + } + IpNextProtocol::Ipv6Opts => { + Ipv6ExtensionHeader::Destination { next: nh, data } + } + _ => unreachable!(), + }; + extensions.push(ext); + next_header = nh; + offset += total_len; + } + IpNextProtocol::Ipv6Route => { + if offset + 4 > available_end { + return Err(ParseError::Truncated { + context: "IPv6 routing header", + expected: offset + 4, + actual: available_end, + }); + } + let routing_type = bytes[offset + 2]; + let segments_left = bytes[offset + 3]; + let total_len = 8 + ext_len * 8; + if offset + total_len > available_end { + return Err(ParseError::Truncated { + context: "IPv6 routing header", + expected: offset + total_len, + actual: available_end, + }); + } + extensions.push(Ipv6ExtensionHeader::Routing { + next: nh, + routing_type, + segments_left, + data: slice_bytes(offset + 4..offset + total_len), + }); + next_header = nh; + offset += total_len; + } + IpNextProtocol::Ipv6Frag => { + if offset + 8 > available_end { + return Err(ParseError::Truncated { + context: "IPv6 fragment header", + expected: offset + 8, + actual: available_end, + }); + } + let frag_off_flags = + u16::from_be_bytes([bytes[offset + 2], bytes[offset + 3]]); + let offset_val = frag_off_flags >> 3; + let more = (frag_off_flags & 0x1) != 0; + let id = u32::from_be_bytes([ + bytes[offset + 4], + bytes[offset + 5], + bytes[offset + 6], + bytes[offset + 7], + ]); + extensions.push(Ipv6ExtensionHeader::Fragment { + next: nh, + offset: offset_val, + more, + id, + }); + next_header = nh; + offset += 8; + } + _ => unreachable!(), + } + } + _ => break, + } + } + + Ok(Ipv6Packet { + header, + extensions, + payload: slice_bytes(offset..available_end), + }) +} + /// Represents a mutable IPv6 packet. pub struct MutableIpv6Packet<'a> { buffer: &'a mut [u8], @@ -754,4 +815,16 @@ mod tests { assert_eq!(frozen.header.source, Ipv6Addr::LOCALHOST); assert_eq!(frozen.payload[0], 0xaa); } + + #[test] + fn ipv6_try_from_buf_reports_strict_truncation() { + let raw = Bytes::from_static(&[ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3, 4, + ]); + + let err = Ipv6Packet::try_from_buf_strict(&raw).expect_err("strict parse should fail"); + assert!(matches!(err, ParseError::Truncated { .. })); + assert!(Ipv6Packet::from_buf(&raw).is_some()); + } } diff --git a/nex-packet/src/lib.rs b/nex-packet/src/lib.rs index 36be4df..a8067c6 100644 --- a/nex-packet/src/lib.rs +++ b/nex-packet/src/lib.rs @@ -15,6 +15,7 @@ pub mod ip; pub mod ipv4; pub mod ipv6; pub mod packet; +pub mod parse; pub mod tcp; pub mod udp; pub mod util; diff --git a/nex-packet/src/parse.rs b/nex-packet/src/parse.rs new file mode 100644 index 0000000..3697b9f --- /dev/null +++ b/nex-packet/src/parse.rs @@ -0,0 +1,92 @@ +//! Structured parse errors for diagnosable packet parsing APIs. + +use core::fmt; + +/// Structured error returned by `try_from_*` parsing APIs. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum ParseError { + /// The input buffer was shorter than the protocol minimum. + BufferTooShort { + /// Human-readable parse context. + context: &'static str, + /// Minimum required number of bytes. + minimum: usize, + /// Actual number of bytes available. + actual: usize, + }, + /// A length-like field contained an invalid value. + InvalidLength { + /// Human-readable parse context. + context: &'static str, + /// Parsed value that failed validation. + value: usize, + }, + /// The packet contains a malformed header field. + Malformed { + /// Human-readable parse context. + context: &'static str, + }, + /// The packet payload was truncated relative to its header lengths. + Truncated { + /// Human-readable parse context. + context: &'static str, + /// Expected number of bytes. + expected: usize, + /// Actual number of bytes available. + actual: usize, + }, + /// Parsing failed because a compression loop or excessive indirection was detected. + CompressionLoop { + /// Human-readable parse context. + context: &'static str, + }, + /// Parsing failed because an unsupported or invalid pointer/compression form was encountered. + InvalidCompression { + /// Human-readable parse context. + context: &'static str, + }, + /// A UTF-8 conversion failed while parsing text-like data. + InvalidUtf8 { + /// Human-readable parse context. + context: &'static str, + }, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseError::BufferTooShort { + context, + minimum, + actual, + } => write!( + f, + "{context}: buffer too short (expected at least {minimum} bytes, got {actual})" + ), + ParseError::InvalidLength { context, value } => { + write!(f, "{context}: invalid length value {value}") + } + ParseError::Malformed { context } => write!(f, "{context}: malformed packet data"), + ParseError::Truncated { + context, + expected, + actual, + } => write!( + f, + "{context}: truncated payload (expected {expected} bytes, got {actual})" + ), + ParseError::CompressionLoop { context } => { + write!(f, "{context}: compression pointer loop detected") + } + ParseError::InvalidCompression { context } => { + write!(f, "{context}: invalid compression pointer") + } + ParseError::InvalidUtf8 { context } => { + write!(f, "{context}: invalid UTF-8 sequence") + } + } + } +} + +impl std::error::Error for ParseError {} diff --git a/nex-packet/src/tcp.rs b/nex-packet/src/tcp.rs index 12f160e..f8c1de9 100644 --- a/nex-packet/src/tcp.rs +++ b/nex-packet/src/tcp.rs @@ -3,6 +3,7 @@ use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext}; use crate::ip::IpNextProtocol; use crate::packet::{MutablePacket, Packet}; +use crate::parse::ParseError; use crate::util::{self, Octets}; use std::net::Ipv6Addr; @@ -466,91 +467,10 @@ impl Packet for TcpPacket { type Header = TcpHeader; fn from_buf(mut bytes: &[u8]) -> Option { - if bytes.len() < TCP_HEADER_LEN { - return None; - } - - let source = bytes.get_u16(); - let destination = bytes.get_u16(); - let sequence = bytes.get_u32(); - let acknowledgement = bytes.get_u32(); - - let offset_reserved = bytes.get_u8(); - let data_offset = offset_reserved >> 4; - let reserved = offset_reserved & 0x0F; - - let flags = bytes.get_u8(); - let window = bytes.get_u16(); - let checksum = bytes.get_u16(); - let urgent_ptr = bytes.get_u16(); - - let header_len = data_offset as usize * 4; - if header_len < TCP_HEADER_LEN || bytes.len() + 20 < header_len { - return None; - } - - let mut options = Vec::new(); - let options_len = header_len - TCP_HEADER_LEN; - let (mut options_bytes, rest) = bytes.split_at(options_len); - bytes = rest; - - while options_bytes.has_remaining() { - let kind = TcpOptionKind::new(options_bytes.get_u8()); - match kind { - TcpOptionKind::EOL => { - options.push(TcpOptionPacket { - kind, - length: None, - data: Bytes::new(), - }); - break; - } - TcpOptionKind::NOP => { - options.push(TcpOptionPacket { - kind, - length: None, - data: Bytes::new(), - }); - } - _ => { - if options_bytes.remaining() < 1 { - return None; - } - let len = options_bytes.get_u8(); - if len < 2 || (len as usize) > options_bytes.remaining() + 2 { - return None; - } - let data_len = (len - 2) as usize; - let (data_slice, rest) = options_bytes.split_at(data_len); - options_bytes = rest; - options.push(TcpOptionPacket { - kind, - length: Some(len), - data: Bytes::copy_from_slice(data_slice), - }); - } - } - } - - Some(TcpPacket { - header: TcpHeader { - source, - destination, - sequence, - acknowledgement, - data_offset: u4::from_be(data_offset), - reserved: u4::from_be(reserved), - flags, - window, - checksum, - urgent_ptr, - options, - }, - payload: Bytes::copy_from_slice(bytes), - }) + Self::try_from_buf(&mut bytes).ok() } fn from_bytes(mut bytes: Bytes) -> Option { - Self::from_buf(&mut bytes) + Self::try_from_bytes(bytes.split_to(bytes.len())).ok() } fn to_bytes(&self) -> Bytes { @@ -657,6 +577,219 @@ impl Packet for TcpPacket { } impl TcpPacket { + /// Parse a TCP packet and return a structured error on failure. + pub fn try_from_buf(mut bytes: &[u8]) -> Result { + if bytes.len() < TCP_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "TCP packet", + minimum: TCP_HEADER_LEN, + actual: bytes.len(), + }); + } + + let source = bytes.get_u16(); + let destination = bytes.get_u16(); + let sequence = bytes.get_u32(); + let acknowledgement = bytes.get_u32(); + + let offset_reserved = bytes.get_u8(); + let data_offset = offset_reserved >> 4; + let reserved = offset_reserved & 0x0F; + + let flags = bytes.get_u8(); + let window = bytes.get_u16(); + let checksum = bytes.get_u16(); + let urgent_ptr = bytes.get_u16(); + + let header_len = data_offset as usize * 4; + if header_len < TCP_HEADER_LEN { + return Err(ParseError::InvalidLength { + context: "TCP data offset", + value: header_len, + }); + } + if bytes.len() + TCP_HEADER_LEN < header_len { + return Err(ParseError::Truncated { + context: "TCP header", + expected: header_len, + actual: bytes.len() + TCP_HEADER_LEN, + }); + } + + let mut options = Vec::new(); + let options_len = header_len - TCP_HEADER_LEN; + let (mut options_bytes, rest) = bytes.split_at(options_len); + bytes = rest; + + while options_bytes.has_remaining() { + let kind = TcpOptionKind::new(options_bytes.get_u8()); + match kind { + TcpOptionKind::EOL => { + options.push(TcpOptionPacket { + kind, + length: None, + data: Bytes::new(), + }); + break; + } + TcpOptionKind::NOP => { + options.push(TcpOptionPacket { + kind, + length: None, + data: Bytes::new(), + }); + } + _ => { + if options_bytes.remaining() < 1 { + return Err(ParseError::Malformed { + context: "TCP options", + }); + } + let len = options_bytes.get_u8(); + if len < 2 || (len as usize) > options_bytes.remaining() + 2 { + return Err(ParseError::InvalidLength { + context: "TCP option length", + value: len as usize, + }); + } + let data_len = (len - 2) as usize; + let (data_slice, rest) = options_bytes.split_at(data_len); + options_bytes = rest; + options.push(TcpOptionPacket { + kind, + length: Some(len), + data: Bytes::copy_from_slice(data_slice), + }); + } + } + } + + Ok(TcpPacket { + header: TcpHeader { + source, + destination, + sequence, + acknowledgement, + data_offset: u4::from_be(data_offset), + reserved: u4::from_be(reserved), + flags, + window, + checksum, + urgent_ptr, + options, + }, + payload: Bytes::copy_from_slice(bytes), + }) + } + + /// Parse a TCP packet from owned bytes while preserving payload slices when possible. + pub fn try_from_bytes(bytes: Bytes) -> Result { + if bytes.len() < TCP_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "TCP packet", + minimum: TCP_HEADER_LEN, + actual: bytes.len(), + }); + } + + let source = u16::from_be_bytes([bytes[0], bytes[1]]); + let destination = u16::from_be_bytes([bytes[2], bytes[3]]); + let sequence = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + let acknowledgement = u32::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]); + let offset_reserved = bytes[12]; + let data_offset = offset_reserved >> 4; + let reserved = offset_reserved & 0x0F; + let flags = bytes[13]; + let window = u16::from_be_bytes([bytes[14], bytes[15]]); + let checksum = u16::from_be_bytes([bytes[16], bytes[17]]); + let urgent_ptr = u16::from_be_bytes([bytes[18], bytes[19]]); + + let header_len = data_offset as usize * 4; + if header_len < TCP_HEADER_LEN { + return Err(ParseError::InvalidLength { + context: "TCP data offset", + value: header_len, + }); + } + if bytes.len() < header_len { + return Err(ParseError::Truncated { + context: "TCP header", + expected: header_len, + actual: bytes.len(), + }); + } + + let mut options = Vec::new(); + let mut offset = TCP_HEADER_LEN; + while offset < header_len { + let kind = TcpOptionKind::new(bytes[offset]); + offset += 1; + match kind { + TcpOptionKind::EOL => { + options.push(TcpOptionPacket { + kind, + length: None, + data: Bytes::new(), + }); + break; + } + TcpOptionKind::NOP => { + options.push(TcpOptionPacket { + kind, + length: None, + data: Bytes::new(), + }); + } + _ => { + if offset >= header_len { + return Err(ParseError::Malformed { + context: "TCP options", + }); + } + let len = bytes[offset]; + offset += 1; + if len < 2 { + return Err(ParseError::InvalidLength { + context: "TCP option length", + value: len as usize, + }); + } + let data_len = (len - 2) as usize; + if offset + data_len > header_len { + return Err(ParseError::Truncated { + context: "TCP option data", + expected: data_len, + actual: header_len.saturating_sub(offset), + }); + } + options.push(TcpOptionPacket { + kind, + length: Some(len), + data: bytes.slice(offset..offset + data_len), + }); + offset += data_len; + } + } + } + + Ok(TcpPacket { + header: TcpHeader { + source, + destination, + sequence, + acknowledgement, + data_offset: u4::from_be(data_offset), + reserved: u4::from_be(reserved), + flags, + window, + checksum, + urgent_ptr, + options, + }, + payload: bytes.slice(header_len..), + }) + } + pub fn tcp_options_length(&self) -> usize { if self.header.data_offset > 5 { self.header.data_offset as usize * 4 - 20 diff --git a/nex-packet/src/udp.rs b/nex-packet/src/udp.rs index 7e72003..f7ae44e 100644 --- a/nex-packet/src/udp.rs +++ b/nex-packet/src/udp.rs @@ -3,6 +3,7 @@ use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext}; use crate::ip::IpNextProtocol; use crate::packet::{MutablePacket, Packet}; +use crate::parse::ParseError; use crate::util; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -35,38 +36,10 @@ pub struct UdpPacket { impl Packet for UdpPacket { type Header = UdpHeader; fn from_buf(mut bytes: &[u8]) -> Option { - if bytes.len() < UDP_HEADER_LEN { - return None; - } - - let source = bytes.get_u16(); - let destination = bytes.get_u16(); - let length = bytes.get_u16(); - let checksum = bytes.get_u16(); - - if length < UDP_HEADER_LEN as u16 { - return None; - } - - let payload_len = length as usize - UDP_HEADER_LEN; - if bytes.len() < payload_len { - return None; - } - - let (payload_slice, _) = bytes.split_at(payload_len); - - Some(UdpPacket { - header: UdpHeader { - source, - destination, - length, - checksum, - }, - payload: Bytes::copy_from_slice(payload_slice), - }) + Self::try_from_buf(&mut bytes).ok() } fn from_bytes(mut bytes: Bytes) -> Option { - Self::from_buf(&mut bytes) + Self::try_from_bytes(bytes.split_to(bytes.len())).ok() } fn to_bytes(&self) -> Bytes { let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN + self.payload.len()); @@ -169,6 +142,92 @@ impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> { } } +impl UdpPacket { + /// Parse a UDP packet and return a structured error on failure. + pub fn try_from_buf(mut bytes: &[u8]) -> Result { + if bytes.len() < UDP_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "UDP packet", + minimum: UDP_HEADER_LEN, + actual: bytes.len(), + }); + } + + let source = bytes.get_u16(); + let destination = bytes.get_u16(); + let length = bytes.get_u16(); + let checksum = bytes.get_u16(); + + if length < UDP_HEADER_LEN as u16 { + return Err(ParseError::InvalidLength { + context: "UDP length", + value: length as usize, + }); + } + + let payload_len = length as usize - UDP_HEADER_LEN; + if bytes.len() < payload_len { + return Err(ParseError::Truncated { + context: "UDP payload", + expected: payload_len, + actual: bytes.len(), + }); + } + + Ok(UdpPacket { + header: UdpHeader { + source, + destination, + length, + checksum, + }, + payload: Bytes::copy_from_slice(&bytes[..payload_len]), + }) + } + + /// Parse a UDP packet from owned bytes while preserving the payload slice. + pub fn try_from_bytes(bytes: Bytes) -> Result { + if bytes.len() < UDP_HEADER_LEN { + return Err(ParseError::BufferTooShort { + context: "UDP packet", + minimum: UDP_HEADER_LEN, + actual: bytes.len(), + }); + } + + let source = u16::from_be_bytes([bytes[0], bytes[1]]); + let destination = u16::from_be_bytes([bytes[2], bytes[3]]); + let length = u16::from_be_bytes([bytes[4], bytes[5]]); + let checksum = u16::from_be_bytes([bytes[6], bytes[7]]); + + if length < UDP_HEADER_LEN as u16 { + return Err(ParseError::InvalidLength { + context: "UDP length", + value: length as usize, + }); + } + + let payload_len = length as usize - UDP_HEADER_LEN; + if bytes.len() < UDP_HEADER_LEN + payload_len { + return Err(ParseError::Truncated { + context: "UDP payload", + expected: payload_len, + actual: bytes.len().saturating_sub(UDP_HEADER_LEN), + }); + } + + Ok(UdpPacket { + header: UdpHeader { + source, + destination, + length, + checksum, + }, + payload: bytes.slice(UDP_HEADER_LEN..UDP_HEADER_LEN + payload_len), + }) + } +} + impl<'a> MutableUdpPacket<'a> { /// Create a new packet without validating length fields. pub fn new_unchecked(buffer: &'a mut [u8]) -> Self { diff --git a/nex-socket/Cargo.toml b/nex-socket/Cargo.toml index 4121053..e4e37d2 100644 --- a/nex-socket/Cargo.toml +++ b/nex-socket/Cargo.toml @@ -21,7 +21,7 @@ libc = { workspace = true } nix = { version = "0.30", features = ["poll", "net", "uio"] } [target.'cfg(windows)'.dependencies.windows-sys] -version = "0.59.0" +version = "0.61" features = [ "Win32_Foundation", "Win32_Networking_WinSock", diff --git a/nex-socket/src/icmp/config.rs b/nex-socket/src/icmp/config.rs index 788ed46..79a9558 100644 --- a/nex-socket/src/icmp/config.rs +++ b/nex-socket/src/icmp/config.rs @@ -154,6 +154,56 @@ impl IcmpConfig { self.fib = Some(fib); self } + + /// Validate the configuration before socket creation. + pub fn validate(&self) -> io::Result<()> { + if let Some(addr) = self.bind { + let addr_family = crate::SocketFamily::from_socket_addr(&addr); + if addr_family != self.socket_family { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bind address family does not match socket_family", + )); + } + } + + if self.socket_family.is_v4() && self.hoplimit.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "hoplimit is only supported for IPv6 ICMP sockets", + )); + } + + if self.socket_family.is_v6() && self.ttl.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "ttl is only supported for IPv4 ICMP sockets", + )); + } + + if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "read_timeout must be greater than zero", + )); + } + + if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "write_timeout must be greater than zero", + )); + } + + if matches!(self.interface.as_deref(), Some("")) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "interface must not be empty", + )); + } + + Ok(()) + } } #[cfg(test)] @@ -182,4 +232,10 @@ mod tests { assert_eq!(v4.socket_family, SocketFamily::IPV4); assert_eq!(v6.socket_family, SocketFamily::IPV6); } + + #[test] + fn icmp_config_validate_rejects_family_mismatch() { + let cfg = IcmpConfig::new(IcmpKind::V4).with_bind("[::1]:0".parse().unwrap()); + assert!(cfg.validate().is_err()); + } } diff --git a/nex-socket/src/icmp/sync_impl.rs b/nex-socket/src/icmp/sync_impl.rs index fcd2eaf..79772ef 100644 --- a/nex-socket/src/icmp/sync_impl.rs +++ b/nex-socket/src/icmp/sync_impl.rs @@ -15,6 +15,8 @@ pub struct IcmpSocket { impl IcmpSocket { /// Create a new synchronous ICMP socket. pub fn new(config: &IcmpConfig) -> io::Result { + config.validate()?; + let (domain, proto) = match config.socket_family { SocketFamily::IPV4 => (Domain::IPV4, Some(Protocol::ICMPV4)), SocketFamily::IPV6 => (Domain::IPV6, Some(Protocol::ICMPV6)), diff --git a/nex-socket/src/tcp/async_impl.rs b/nex-socket/src/tcp/async_impl.rs index 06178a2..c9e48a0 100644 --- a/nex-socket/src/tcp/async_impl.rs +++ b/nex-socket/src/tcp/async_impl.rs @@ -14,6 +14,8 @@ pub struct AsyncTcpSocket { impl AsyncTcpSocket { /// Create a socket from the given configuration without connecting. pub fn from_config(config: &TcpConfig) -> io::Result { + config.validate()?; + let socket = Socket::new( config.socket_family.to_domain(), config.socket_type.to_sock_type(), diff --git a/nex-socket/src/tcp/config.rs b/nex-socket/src/tcp/config.rs index 2170c7e..19830aa 100644 --- a/nex-socket/src/tcp/config.rs +++ b/nex-socket/src/tcp/config.rs @@ -1,4 +1,5 @@ use socket2::Type as SockType; +use std::io; use std::net::SocketAddr; use std::time::Duration; @@ -229,6 +230,84 @@ impl TcpConfig { self.bind_device = Some(iface.into()); self } + + /// Validate the configuration before socket creation. + pub fn validate(&self) -> io::Result<()> { + if let Some(addr) = self.bind_addr { + let addr_family = crate::SocketFamily::from_socket_addr(&addr); + if addr_family != self.socket_family { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bind_addr family does not match socket_family", + )); + } + } + + if self.socket_family.is_v4() { + if self.hoplimit.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "hoplimit is only supported for IPv6 TCP sockets", + )); + } + if self.tclass_v6.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "tclass_v6 is only supported for IPv6 TCP sockets", + )); + } + if self.only_v6.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "only_v6 is only supported for IPv6 TCP sockets", + )); + } + } + + if self.socket_family.is_v6() && self.ttl.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "ttl is only supported for IPv4 TCP sockets", + )); + } + + if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "read_timeout must be greater than zero", + )); + } + + if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "write_timeout must be greater than zero", + )); + } + + if matches!(self.recv_buffer_size, Some(0)) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "recv_buffer_size must be greater than zero", + )); + } + + if matches!(self.send_buffer_size, Some(0)) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "send_buffer_size must be greater than zero", + )); + } + + if matches!(self.bind_device.as_deref(), Some("")) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bind_device must not be empty", + )); + } + + Ok(()) + } } #[cfg(test)] @@ -270,4 +349,10 @@ mod tests { assert_eq!(cfg.socket_family, SocketFamily::IPV6); assert_eq!(cfg.socket_type, TcpSocketType::Stream); } + + #[test] + fn tcp_config_validate_rejects_family_mismatch() { + let cfg = TcpConfig::v4_stream().with_bind("[::1]:0".parse().unwrap()); + assert!(cfg.validate().is_err()); + } } diff --git a/nex-socket/src/tcp/sync_impl.rs b/nex-socket/src/tcp/sync_impl.rs index 0edd329..f43676d 100644 --- a/nex-socket/src/tcp/sync_impl.rs +++ b/nex-socket/src/tcp/sync_impl.rs @@ -9,17 +9,20 @@ use crate::tcp::TcpConfig; use std::os::fd::AsRawFd; #[cfg(unix)] -use nix::poll::{PollFd, PollFlags, poll}; +use nix::poll::{PollFd, PollFlags, PollTimeout, poll}; /// Low level synchronous TCP socket. #[derive(Debug)] pub struct TcpSocket { socket: Socket, + nonblocking: bool, } impl TcpSocket { /// Build a socket according to `TcpSocketConfig`. pub fn from_config(config: &TcpConfig) -> io::Result { + config.validate()?; + let socket = Socket::new( config.socket_family.to_domain(), config.socket_type.to_sock_type(), @@ -111,14 +114,20 @@ impl TcpSocket { socket.bind(&addr.into())?; } - Ok(Self { socket }) + Ok(Self { + socket, + nonblocking: config.nonblocking, + }) } /// Create a socket of arbitrary type (STREAM or RAW). pub fn new(domain: Domain, sock_type: SockType) -> io::Result { let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?; socket.set_nonblocking(false)?; - Ok(Self { socket }) + Ok(Self { + socket, + nonblocking: false, + }) } /// Convenience constructor for an IPv4 STREAM socket. @@ -151,14 +160,17 @@ impl TcpSocket { self.socket.connect(&addr.into()) } - /// Connect to the target address with a timeout. + /// Connect to the target address with a timeout and return the connected stream. + /// + /// The returned `TcpStream` must be used for subsequent I/O. #[cfg(unix)] pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result { - let raw_fd = self.socket.as_raw_fd(); - self.socket.set_nonblocking(true)?; + let socket = self.socket.try_clone()?; + socket.set_nonblocking(true)?; + let raw_fd = socket.as_raw_fd(); // Try to connect first - match self.socket.connect(&target.into()) { + match socket.connect(&target.into()) { Ok(_) => { /* succeeded immediately */ } Err(err) if err.kind() == io::ErrorKind::WouldBlock @@ -170,22 +182,21 @@ impl TcpSocket { } // Wait for the connection using poll - let timeout_ms = timeout.as_millis() as i32; use std::os::unix::io::BorrowedFd; // Safety: raw_fd is valid for the lifetime of this scope let mut fds = [PollFd::new( unsafe { BorrowedFd::borrow_raw(raw_fd) }, PollFlags::POLLOUT, )]; - let n = poll(&mut fds, Some(timeout_ms as u16))?; + let poll_timeout = PollTimeout::try_from(timeout).unwrap_or(PollTimeout::MAX); + let n = poll(&mut fds, poll_timeout)?; if n == 0 { return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")); } // Check the result with `SO_ERROR` - let err: i32 = self - .socket + let err: i32 = socket .take_error()? .map(|e| e.raw_os_error().unwrap_or(0)) .unwrap_or(0); @@ -193,9 +204,9 @@ impl TcpSocket { return Err(io::Error::from_raw_os_error(err)); } - self.socket.set_nonblocking(false)?; + socket.set_nonblocking(self.nonblocking)?; - match self.socket.try_clone() { + match socket.try_clone() { Ok(cloned_socket) => { // Convert the socket into a `std::net::TcpStream` let std_stream: TcpStream = cloned_socket.into(); @@ -205,6 +216,9 @@ impl TcpSocket { } } + /// Connect to the target address with a timeout and return the connected stream. + /// + /// The returned `TcpStream` must be used for subsequent I/O. #[cfg(windows)] pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result { use std::mem::size_of; @@ -213,11 +227,12 @@ impl TcpSocket { POLLWRNORM, SO_ERROR, SOCKET, SOCKET_ERROR, SOL_SOCKET, WSAPOLLFD, WSAPoll, getsockopt, }; - let sock = self.socket.as_raw_socket() as SOCKET; - self.socket.set_nonblocking(true)?; + let socket = self.socket.try_clone()?; + socket.set_nonblocking(true)?; + let sock = socket.as_raw_socket() as SOCKET; // Start connect - match self.socket.connect(&target.into()) { + match socket.connect(&target.into()) { Ok(_) => { /* connection succeeded immediately */ } Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {} Err(e) => return Err(e), @@ -255,9 +270,9 @@ impl TcpSocket { return Err(io::Error::from_raw_os_error(so_error)); } - self.socket.set_nonblocking(false)?; + socket.set_nonblocking(self.nonblocking)?; - let std_stream: TcpStream = self.socket.try_clone()?.into(); + let std_stream: TcpStream = socket.into(); Ok(std_stream) } @@ -520,7 +535,13 @@ impl TcpSocket { /// Construct from a raw `socket2::Socket`. pub fn from_socket(socket: Socket) -> Self { - Self { socket } + Self { + socket, + // `socket2::Socket` does not expose a portable getter for the current + // blocking mode, so externally supplied sockets default to blocking + // expectations in this synchronous wrapper. + nonblocking: false, + } } /// Borrow the inner `socket2::Socket`. @@ -533,3 +554,48 @@ impl TcpSocket { self.socket } } + +#[cfg(test)] +mod tests { + #[cfg(unix)] + use super::*; + #[cfg(unix)] + use libc::{F_GETFL, O_NONBLOCK, fcntl}; + #[cfg(unix)] + use std::net::TcpListener as StdTcpListener; + + #[cfg(unix)] + fn socket_is_nonblocking(socket: &Socket) -> bool { + let flags = unsafe { fcntl(socket.as_raw_fd(), F_GETFL) }; + assert!(flags >= 0, "F_GETFL failed: {}", io::Error::last_os_error()); + (flags & O_NONBLOCK) != 0 + } + + #[cfg(unix)] + #[test] + fn connect_timeout_does_not_mutate_original_nonblocking_state_after_invalid_input() { + let sock = TcpSocket::v4_stream().expect("socket"); + sock.socket.set_nonblocking(true).expect("set nonblocking"); + + let result = sock.connect_timeout("[::1]:80".parse().unwrap(), Duration::from_secs(1)); + assert!(result.is_err()); + assert!(socket_is_nonblocking(&sock.socket)); + } + + #[cfg(unix)] + #[test] + fn connect_timeout_does_not_mutate_original_blocking_state_after_success() { + let listener = StdTcpListener::bind("127.0.0.1:0").expect("listener"); + let addr = listener.local_addr().expect("local addr"); + let handle = std::thread::spawn(move || listener.accept().expect("accept")); + + let sock = TcpSocket::v4_stream().expect("socket"); + sock.socket.set_nonblocking(false).expect("set blocking"); + let _stream = sock + .connect_timeout(addr, Duration::from_secs(1)) + .expect("connect"); + + assert!(!socket_is_nonblocking(&sock.socket)); + let _ = handle.join(); + } +} diff --git a/nex-socket/src/udp/async_impl.rs b/nex-socket/src/udp/async_impl.rs index a889786..a7b5a83 100644 --- a/nex-socket/src/udp/async_impl.rs +++ b/nex-socket/src/udp/async_impl.rs @@ -13,6 +13,8 @@ pub struct AsyncUdpSocket { impl AsyncUdpSocket { /// Create an asynchronous UDP socket from the given configuration. pub fn from_config(config: &UdpConfig) -> io::Result { + config.validate()?; + let socket = Socket::new( config.socket_family.to_domain(), config.socket_type.to_sock_type(), diff --git a/nex-socket/src/udp/config.rs b/nex-socket/src/udp/config.rs index e369c3d..8e9a250 100644 --- a/nex-socket/src/udp/config.rs +++ b/nex-socket/src/udp/config.rs @@ -1,4 +1,4 @@ -use std::{net::SocketAddr, time::Duration}; +use std::{io, net::SocketAddr, time::Duration}; use socket2::Type as SockType; @@ -213,6 +213,92 @@ impl UdpConfig { self.bind_device = Some(iface.into()); self } + + /// Validate the configuration before socket creation. + pub fn validate(&self) -> io::Result<()> { + if let Some(addr) = self.bind_addr { + let addr_family = crate::SocketFamily::from_socket_addr(&addr); + if addr_family != self.socket_family { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bind_addr family does not match socket_family", + )); + } + } + + if self.socket_family.is_v4() { + if self.hoplimit.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "hoplimit is only supported for IPv6 UDP sockets", + )); + } + if self.tclass_v6.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "tclass_v6 is only supported for IPv6 UDP sockets", + )); + } + if self.only_v6.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "only_v6 is only supported for IPv6 UDP sockets", + )); + } + } + + if self.socket_family.is_v6() { + if self.ttl.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "ttl is only supported for IPv4 UDP sockets", + )); + } + if self.broadcast.is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "broadcast is only supported for IPv4 UDP sockets", + )); + } + } + + if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "read_timeout must be greater than zero", + )); + } + + if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "write_timeout must be greater than zero", + )); + } + + if matches!(self.recv_buffer_size, Some(0)) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "recv_buffer_size must be greater than zero", + )); + } + + if matches!(self.send_buffer_size, Some(0)) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "send_buffer_size must be greater than zero", + )); + } + + if matches!(self.bind_device.as_deref(), Some("")) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "bind_device must not be empty", + )); + } + + Ok(()) + } } #[cfg(test)] @@ -243,4 +329,10 @@ mod tests { assert_eq!(cfg.socket_family, SocketFamily::IPV6); assert!(cfg.bind_addr.is_some()); } + + #[test] + fn udp_config_validate_rejects_ipv6_broadcast() { + let cfg = UdpConfig::new_with_family(SocketFamily::IPV6).with_broadcast(true); + assert!(cfg.validate().is_err()); + } } diff --git a/nex-socket/src/udp/sync_impl.rs b/nex-socket/src/udp/sync_impl.rs index 79ab3b0..7db7869 100644 --- a/nex-socket/src/udp/sync_impl.rs +++ b/nex-socket/src/udp/sync_impl.rs @@ -35,6 +35,8 @@ pub struct UdpSendMeta { impl UdpSocket { /// Create a socket from the provided configuration. pub fn from_config(config: &UdpConfig) -> io::Result { + config.validate()?; + let socket = Socket::new( config.socket_family.to_domain(), config.socket_type.to_sock_type(), @@ -670,6 +672,7 @@ mod tests { #[test] fn create_v4_socket() { let sock = UdpSocket::v4_dgram().expect("create socket"); + sock.socket.bind(&"0.0.0.0:0".parse::().unwrap().into()).expect("bind"); let addr = sock.local_addr().expect("addr"); assert!(addr.is_ipv4()); } diff --git a/nex-sys/Cargo.toml b/nex-sys/Cargo.toml index a84d8e1..f68d87d 100644 --- a/nex-sys/Cargo.toml +++ b/nex-sys/Cargo.toml @@ -14,7 +14,7 @@ license = "MIT" libc = { workspace = true } [target.'cfg(windows)'.dependencies.windows-sys] -version = "0.59.0" +version = "0.61" features = [ "Win32_Foundation", "Win32_Networking_WinSock", diff --git a/nex-sys/src/lib.rs b/nex-sys/src/lib.rs index 7b8732f..882aee9 100644 --- a/nex-sys/src/lib.rs +++ b/nex-sys/src/lib.rs @@ -24,6 +24,7 @@ impl Drop for FileDesc { } /// Sends data to a socket, returning the number of bytes sent. +#[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn send_to( socket: CSocket, buffer: &[u8], diff --git a/nex-sys/src/unix.rs b/nex-sys/src/unix.rs index 686a410..e22464a 100644 --- a/nex-sys/src/unix.rs +++ b/nex-sys/src/unix.rs @@ -29,6 +29,12 @@ pub const AF_INET6: libc::c_int = libc::AF_INET6; pub use libc::{IFF_BROADCAST, IFF_LOOPBACK, IFF_MULTICAST, IFF_POINTOPOINT, IFF_UP}; +/// Close a raw socket/file descriptor. +/// +/// # Safety +/// +/// `sock` must be a valid descriptor owned by the caller. It must not be used +/// again after this function returns. pub unsafe fn close(sock: CSocket) { unsafe { let _ = libc::close(sock); @@ -42,7 +48,7 @@ fn ntohs(u: u16) -> u16 { pub fn sockaddr_to_addr(storage: &SockAddrStorage, len: usize) -> io::Result { match storage.ss_family as libc::c_int { AF_INET => { - assert!(len as usize >= mem::size_of::()); + assert!(len >= mem::size_of::()); let storage: &SockAddrIn = unsafe { mem::transmute(storage) }; let ip = ipv4_addr_int(storage.sin_addr); // octets @@ -55,7 +61,7 @@ pub fn sockaddr_to_addr(storage: &SockAddrStorage, len: usize) -> io::Result { - assert!(len as usize >= mem::size_of::()); + assert!(len >= mem::size_of::()); let storage: &SockAddrIn6 = unsafe { mem::transmute(storage) }; let arr: [u16; 8] = unsafe { mem::transmute(storage.sin6_addr.s6_addr) }; // hextets @@ -81,7 +87,7 @@ pub fn sockaddr_to_addr(storage: &SockAddrStorage, len: usize) -> io::Result u32 { - (addr.s_addr as u32).to_be() + addr.s_addr.to_be() } /// Convert a platform specific `timeval` into a Duration. @@ -110,6 +116,12 @@ pub fn duration_to_timespec(dur: Duration) -> libc::timespec { } } +/// Call `sendto(2)` using raw socket arguments. +/// +/// # Safety +/// +/// `buf` must be valid for reads of `len` bytes. `addr` must point to a valid +/// socket address of length `addrlen`. pub unsafe fn sendto( socket: CSocket, buf: Buf, @@ -121,6 +133,12 @@ pub unsafe fn sendto( unsafe { libc::sendto(socket, buf, len, flags, addr, addrlen) } } +/// Call `recvfrom(2)` using raw socket arguments. +/// +/// # Safety +/// +/// `buf` must be valid for writes of `len` bytes. `addr` and `addrlen` must +/// point to writable storage for the returned socket address. pub unsafe fn recvfrom( socket: CSocket, buf: MutBuf,