diff --git a/zircon-object/src/ipc/socket.rs b/zircon-object/src/ipc/socket.rs index 601e791f6..36d8e4f09 100644 --- a/zircon-object/src/ipc/socket.rs +++ b/zircon-object/src/ipc/socket.rs @@ -1,7 +1,9 @@ use { crate::object::*, + alloc::collections::VecDeque, alloc::sync::{Arc, Weak}, - alloc::{collections::VecDeque, vec::Vec}, + alloc::vec::Vec, + bitflags::bitflags, spin::Mutex, }; @@ -14,17 +16,22 @@ use { pub struct Socket { base: KObjectBase, peer: Weak, + flags: SocketFlags, // constant value inner: Mutex, } #[derive(Default)] struct SocketInner { - read_disabled: bool, - read_threshold: usize, - write_threshold: usize, // only for core-test + control_msg: Vec, data: VecDeque, + datagram_len: VecDeque, + read_threshold: usize, + write_threshold: usize, + read_disabled: bool, } +const SOCKET_SIZE: usize = 128 * 2048; + impl_kobject!(Socket fn peer(&self) -> ZxResult> { let peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; @@ -35,100 +42,301 @@ impl_kobject!(Socket } ); -// Only support stream mode -// The size of data is unlimited +bitflags! { + /// Signals that waitable kernel objects expose to applications. + #[derive(Default)] + pub struct SocketFlags: u32 { + #[allow(clippy::identity_op)] + // These options can be passed to socket_shutdown(). + const SHUTDOWN_WRITE = 1; + const SHUTDOWN_READ = 1 << 1; + const SHUTDOWN_MASK = Self::SHUTDOWN_WRITE.bits | Self::SHUTDOWN_READ.bits; + + // These can be passed to socket_create(). + // const STREAM = 0; // Don't use contains + const DATAGRAM = 1; + const HAS_CONTROL = 1 << 1; + const HAS_ACCEPT = 1 << 2; + const CREATE_MASK = Self::DATAGRAM.bits | Self::HAS_CONTROL.bits | Self::HAS_ACCEPT.bits; + + // These can be passed to socket_read() and socket_write(). + const SOCKET_CONTROL = 1 << 2; + + // These can be passed to socket_read(). + const SOCKET_PEEK = 1 << 3; + } +} + impl Socket { /// Create a socket. #[allow(unsafe_code)] - pub fn create() -> (Arc, Arc) { + pub fn create(flags: u32) -> ZxResult<(Arc, Arc)> { + let flags = SocketFlags::from_bits(flags).ok_or(ZxError::INVALID_ARGS)?; + if !(flags - SocketFlags::CREATE_MASK).is_empty() { + return Err(ZxError::INVALID_ARGS); + } + let mut starting_signals: Signal = Signal::WRITABLE; + if flags.contains(SocketFlags::HAS_ACCEPT) { + starting_signals |= Signal::SOCKET_SHARE; + } + if flags.contains(SocketFlags::HAS_CONTROL) { + starting_signals |= Signal::SOCKET_CONTROL_WRITABLE; + } let mut end0 = Arc::new(Socket { - base: KObjectBase::with_signal(Signal::WRITABLE), + base: KObjectBase::with_signal(starting_signals), peer: Weak::default(), + flags, inner: Default::default(), }); let end1 = Arc::new(Socket { - base: KObjectBase::with_signal(Signal::WRITABLE), + base: KObjectBase::with_signal(starting_signals), peer: Arc::downgrade(&end0), + flags, inner: Default::default(), }); // no other reference of `end0` unsafe { Arc::get_mut_unchecked(&mut end0).peer = Arc::downgrade(&end1); } - (end0, end1) + Ok((end0, end1)) } - /// Read data from the socket. - pub fn read(&self, size: usize, peek: bool) -> ZxResult> { - let mut inner = self.inner.lock(); - if inner.data.is_empty() { - let _peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; - if inner.read_disabled { + /// Write data to the socket. + pub fn write(&self, options: SocketFlags, data: &[u8], count: usize) -> ZxResult { + assert!(data.len() == count); + if options.contains(SocketFlags::SOCKET_CONTROL) { + if !self.flags.contains(SocketFlags::HAS_CONTROL) { return Err(ZxError::BAD_STATE); } + if count == 0 { + return Err(ZxError::INVALID_ARGS); + } + if count > 1024 { + return Err(ZxError::OUT_OF_RANGE); + } + let peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; + let actual_count = peer.write_control(data)?; + self.base.signal_clear(Signal::SOCKET_CONTROL_WRITABLE); + Ok(actual_count) + } else { + if self.base.signal().contains(Signal::SOCKET_WRITE_DISABLED) { + return Err(ZxError::BAD_STATE); + } + let peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; + let actual_count = peer.write_data(data)?; + if actual_count > 0 { + let mut clear = Signal::empty(); + let peer_inner = peer.inner.lock(); + let inner = self.inner.lock(); + let peer_rest_size = SOCKET_SIZE - peer_inner.data.len(); + if peer_rest_size == 0 { + clear |= Signal::WRITABLE; + } + if inner.write_threshold > 0 && peer_rest_size < inner.write_threshold { + clear |= Signal::SOCKET_WRITE_THRESHOLD; + } + self.base.signal_clear(clear); + } + Ok(actual_count) + } + } + + fn write_control(&self, data: &[u8]) -> ZxResult { + let mut inner = self.inner.lock(); + if !inner.control_msg.is_empty() { + return Err(ZxError::SHOULD_WAIT); + } + let actual_count = data.len(); + inner.control_msg.extend_from_slice(data); + self.base.signal_set(Signal::SOCKET_CONTROL_READABLE); + Ok(actual_count) + } + + fn write_data(&self, data: &[u8]) -> ZxResult { + let data_len = self.inner.lock().data.len(); + let was_empty = data_len == 0; + let rest_size = SOCKET_SIZE - data_len; + if rest_size == 0 { return Err(ZxError::SHOULD_WAIT); } - let size = size.min(inner.data.len()); - let data = if peek { - let (slice0, slice1) = inner.data.as_slices(); - if size <= slice0.len() { - Vec::from(&slice0[..size]) - } else { - let mut v = Vec::from(slice0); - v.extend(&slice1[..size - slice0.len()]); - v + let write_size = data.len().min(rest_size); + let actual_count = if self.flags.contains(SocketFlags::DATAGRAM) { + if data.len() > SOCKET_SIZE { + return Err(ZxError::OUT_OF_RANGE); } + self.write_datagram(&data[..write_size])? } else { - inner.data.drain(..size).collect() + self.write_stream(&data[..write_size])? }; - let mut clear = Signal::empty(); - if inner.read_threshold > 0 && inner.data.len() < inner.read_threshold { - clear |= Signal::SOCKET_READ_THRESHOLD; + if actual_count > 0 { + let mut set = Signal::empty(); + if was_empty { + set |= Signal::READABLE; + } + let inner = self.inner.lock(); + if inner.read_threshold > 0 && inner.data.len() >= inner.read_threshold { + set |= Signal::SOCKET_READ_THRESHOLD; + } + self.base.signal_set(set); } - if inner.data.is_empty() { - clear |= Signal::READABLE; + Ok(actual_count) + } + + fn write_datagram(&self, data: &[u8]) -> ZxResult { + if data.is_empty() { + return Err(ZxError::INVALID_ARGS); } - self.base.signal_clear(clear); - Ok(data) + let mut inner = self.inner.lock(); + let actual_count = data.len(); + inner.data.extend(&data[..]); + inner.datagram_len.push_back(actual_count); + Ok(actual_count) } - /// Write data to the socket. - pub fn write(&self, buffer: &[u8]) -> ZxResult { - let peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; - if self.signal().contains(Signal::SOCKET_WRITE_DISABLED) { - return Err(ZxError::BAD_STATE); + fn write_stream(&self, data: &[u8]) -> ZxResult { + let actual_count = data.len(); + let mut inner = self.inner.lock(); + inner.data.extend(&data[..]); + Ok(actual_count) + } + + /// Read data from the socket. + pub fn read(&self, options: SocketFlags, data: &mut [u8], count: usize) -> ZxResult { + assert!(data.len() == count); + if options.contains(SocketFlags::SOCKET_CONTROL) { + if !self.flags.contains(SocketFlags::HAS_CONTROL) { + return Err(ZxError::BAD_STATE); + } + self.read_control(options, data) + } else { + self.read_data(options, data) } - if buffer.is_empty() { - return Ok(0); + } + + fn read_control(&self, options: SocketFlags, data: &mut [u8]) -> ZxResult { + let mut inner = self.inner.lock(); + if inner.control_msg.is_empty() { + return Err(ZxError::SHOULD_WAIT); } - let buffer_len = buffer.len(); - let mut peer_inner = peer.inner.lock(); - let mut set = Signal::empty(); - if peer_inner.data.is_empty() { - set |= Signal::READABLE; + let read_size = data.len().min(inner.control_msg.len()); + if options.contains(SocketFlags::SOCKET_PEEK) { + for (i, x) in inner.control_msg.iter().take(read_size).enumerate() { + data[i] = *x; + } + } else { + for (i, x) in inner.control_msg.drain(..read_size).enumerate() { + data[i] = x; + } + self.base.signal_clear(Signal::SOCKET_CONTROL_READABLE); + if let Some(peer) = self.peer.upgrade() { + peer.base.signal_set(Signal::SOCKET_CONTROL_WRITABLE); + } + } + Ok(read_size) + } + + fn read_data(&self, options: SocketFlags, data: &mut [u8]) -> ZxResult { + let data_len = self.inner.lock().data.len(); + if data_len == 0 { + let _peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; + let inner = self.inner.lock(); + if inner.read_disabled { + return Err(ZxError::BAD_STATE); + } + return Err(ZxError::SHOULD_WAIT); + } + let was_full = data_len == SOCKET_SIZE; + let peek = options.contains(SocketFlags::SOCKET_PEEK); + let actual_count = if self.flags.contains(SocketFlags::DATAGRAM) { + self.read_datagram(options, data, peek)? + } else { + self.read_stream(options, data, peek)? + }; + if !peek && actual_count > 0 { + let inner = self.inner.lock(); + let mut clear = Signal::empty(); + if inner.read_threshold > 0 && inner.data.len() < inner.read_threshold { + clear |= Signal::SOCKET_READ_THRESHOLD; + } + if inner.data.is_empty() { + clear |= Signal::READABLE; + } + self.base.signal_clear(clear); + if let Ok(peer) = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED) { + let mut set = Signal::empty(); + let peer_inner = peer.inner.lock(); + if peer_inner.write_threshold > 0 + && SOCKET_SIZE - inner.data.len() >= peer_inner.write_threshold + { + set |= Signal::SOCKET_WRITE_THRESHOLD; + } + if was_full { + set |= Signal::WRITABLE; + } + peer.base.signal_set(set); + } } - peer_inner.data.extend(buffer); - if peer_inner.read_threshold > 0 && peer_inner.data.len() >= peer_inner.read_threshold { - set |= Signal::SOCKET_READ_THRESHOLD; + Ok(actual_count) + } + + fn read_datagram(&self, _options: SocketFlags, data: &mut [u8], peek: bool) -> ZxResult { + if data.is_empty() { + return Ok(0); } - peer.base.signal_set(set); - Ok(buffer_len) + let mut inner = self.inner.lock(); + let datagram_len = if peek { + *inner.datagram_len.get(0).unwrap() + } else { + inner.datagram_len.pop_front().unwrap() + }; + let read_size = data.len().min(datagram_len); + if peek { + for (i, x) in inner.data.iter().take(read_size).enumerate() { + data[i] = *x; + } + } else { + for (i, x) in inner.data.drain(..datagram_len).take(read_size).enumerate() { + data[i] = x; + } + }; + Ok(read_size) + } + + fn read_stream(&self, _options: SocketFlags, data: &mut [u8], peek: bool) -> ZxResult { + let mut inner = self.inner.lock(); + let read_size = data.len().min(inner.data.len()); + if peek { + for (i, x) in inner.data.iter().take(read_size).enumerate() { + data[i] = *x; + } + } else { + for (i, x) in inner.data.drain(..read_size).enumerate() { + data[i] = x; + } + }; + Ok(read_size) } /// Get information of the socket. pub fn get_info(&self) -> SocketInfo { - let self_size = self.inner.lock().data.len(); + let inner = self.inner.lock(); + let self_size = inner.data.len(); let peer_size = match self.peer.upgrade() { Some(peer) => peer.inner.lock().data.len(), None => 0, }; + let rx_buf_available = if self.flags.contains(SocketFlags::DATAGRAM) { + *inner.datagram_len.get(0).unwrap_or(&0) + } else { + self_size + }; SocketInfo { - options: 0, + options: self.flags.bits(), padding1: 0, - rx_buf_max: u64::MAX, + rx_buf_max: SOCKET_SIZE as _, rx_buf_size: self_size as _, - rx_buf_available: self_size as _, - tx_buf_max: u64::MAX, + rx_buf_available: rx_buf_available as _, + tx_buf_max: SOCKET_SIZE as _, tx_buf_size: peer_size as _, } } @@ -159,6 +367,9 @@ impl Socket { } pub fn set_read_threshold(&self, threshold: usize) -> ZxResult { + if threshold > SOCKET_SIZE { + return Err(ZxError::INVALID_ARGS); + } let mut inner = self.inner.lock(); inner.read_threshold = threshold; if threshold == 0 { @@ -173,7 +384,17 @@ impl Socket { pub fn set_write_threshold(&self, threshold: usize) -> ZxResult { let peer = self.peer.upgrade().ok_or(ZxError::PEER_CLOSED)?; - peer.inner.lock().write_threshold = threshold; + if threshold > SOCKET_SIZE { + return Err(ZxError::INVALID_ARGS); + } + self.inner.lock().write_threshold = threshold; + if threshold == 0 { + self.base.signal_clear(Signal::SOCKET_WRITE_THRESHOLD); + } else if SOCKET_SIZE - peer.inner.lock().data.len() >= threshold { + self.base.signal_set(Signal::SOCKET_WRITE_THRESHOLD); + } else { + self.base.signal_clear(Signal::SOCKET_WRITE_THRESHOLD); + } Ok(()) } diff --git a/zircon-object/src/object/signal.rs b/zircon-object/src/object/signal.rs index bac1077fe..7538f138e 100644 --- a/zircon-object/src/object/signal.rs +++ b/zircon-object/src/object/signal.rs @@ -18,9 +18,14 @@ bitflags! { const SOCKET_PEER_WRITE_DISABLED = 1 << 4; const SOCKET_WRITE_DISABLED = 1 << 5; + const SOCKET_CONTROL_READABLE = 1 << 6; + const SOCKET_CONTROL_WRITABLE = 1 << 7; + const SCOEKT_ACCEPT = 1 << 8; + const SOCKET_SHARE = 1 << 9; const SOCKET_READ_THRESHOLD = 1 << 10; const SOCKET_WRITE_THRESHOLD = 1 << 11; + const TASK_TERMINATED = Self::SIGNALED.bits; const JOB_TERMINATED = Self::SIGNALED.bits; diff --git a/zircon-syscall/src/socket.rs b/zircon-syscall/src/socket.rs index 20ed38ec6..142da7a4b 100644 --- a/zircon-syscall/src/socket.rs +++ b/zircon-syscall/src/socket.rs @@ -1,4 +1,4 @@ -use {super::*, bitflags::bitflags, zircon_object::ipc::*}; +use {super::*, zircon_object::ipc::Socket, zircon_object::ipc::SocketFlags}; impl Syscall<'_> { pub fn sys_socket_create( @@ -8,11 +8,7 @@ impl Syscall<'_> { mut out1: UserOutPtr, ) -> ZxResult { info!("socket.create: options={:#x?}", options); - if options != 0 { - error!("socket.create: only implemented options=0"); - return Err(ZxError::NOT_SUPPORTED); - } - let (end0, end1) = Socket::create(); + let (end0, end1) = Socket::create(options)?; let proc = self.thread.proc(); let handle0 = proc.add_handle(Handle::new(end0, Rights::DEFAULT_SOCKET)); let handle1 = proc.add_handle(Handle::new(end1, Rights::DEFAULT_SOCKET)); @@ -23,72 +19,70 @@ impl Syscall<'_> { pub fn sys_socket_write( &self, - socket: HandleValue, + handle_value: HandleValue, options: u32, - buffer: UserInPtr, - size: usize, - mut actual_size: UserOutPtr, + user_bytes: UserInPtr, + count: usize, + mut actual_count_ptr: UserOutPtr, ) -> ZxResult { info!( "socket.write: socket={:#x?}, options={:#x?}, buffer={:#x?}, size={:#x?}", - socket, options, buffer, size, + handle_value, options, user_bytes, count, ); - if options != 0 { - unimplemented!(); + if count > 0 && user_bytes.is_null() { + return Err(ZxError::INVALID_ARGS); + } + let options = SocketFlags::from_bits(options).ok_or(ZxError::INVALID_ARGS)?; + if !(options - SocketFlags::SOCKET_CONTROL).is_empty() { + return Err(ZxError::INVALID_ARGS); } let proc = self.thread.proc(); - let socket = proc.get_object_with_rights::(socket, Rights::WRITE)?; - let buffer = buffer.read_array(size)?; - let size = socket.write(&buffer)?; - actual_size.write_if_not_null(size)?; + let socket = proc.get_object_with_rights::(handle_value, Rights::WRITE)?; + let data = user_bytes.read_array(count)?; + let actual_count = socket.write(options, &data, count)?; + actual_count_ptr.write_if_not_null(actual_count)?; Ok(()) } pub fn sys_socket_read( &self, - socket: HandleValue, + handle_value: HandleValue, options: u32, - mut buffer: UserOutPtr, - size: usize, - mut actual_size: UserOutPtr, + mut user_bytes: UserOutPtr, + count: usize, + mut actual_count_ptr: UserOutPtr, ) -> ZxResult { - let options = SocketOptions::from_bits_truncate(options); info!( "socket.read: socket={:#x?}, options={:#x?}, buffer={:#x?}, size={:#x?}", - socket, options, buffer, size, + handle_value, options, user_bytes, count, ); + if count > 0 && user_bytes.is_null() { + return Err(ZxError::INVALID_ARGS); + } + let options = SocketFlags::from_bits(options).ok_or(ZxError::INVALID_ARGS)?; + if !(options - SocketFlags::SOCKET_CONTROL - SocketFlags::SOCKET_PEEK).is_empty() { + return Err(ZxError::INVALID_ARGS); + } let proc = self.thread.proc(); - let socket = proc.get_object_with_rights::(socket, Rights::READ)?; - let peek = options.contains(SocketOptions::PEEK); - let result = socket.read(size, peek)?; - actual_size.write_if_not_null(result.len())?; - buffer.write_array(&result)?; + let socket = proc.get_object_with_rights::(handle_value, Rights::READ)?; + let mut data = vec![0; count]; + let actual_count = socket.read(options, &mut data, count)?; + user_bytes.write_array(&data)?; + actual_count_ptr.write_if_not_null(actual_count)?; Ok(()) } pub fn sys_socket_shutdown(&self, socket: HandleValue, options: u32) -> ZxResult { - let options = SocketOptions::from_bits_truncate(options); + let options = SocketFlags::from_bits_truncate(options); info!( "socket.shutdown: socket={:#x?}, options={:#x?}", socket, options ); let proc = self.thread.proc(); let socket = proc.get_object_with_rights::(socket, Rights::WRITE)?; - let read = options.contains(SocketOptions::SHUTDOWN_READ); - let write = options.contains(SocketOptions::SHUTDOWN_WRITE); + let read = options.contains(SocketFlags::SHUTDOWN_READ); + let write = options.contains(SocketFlags::SHUTDOWN_WRITE); socket.shutdown(read, write)?; Ok(()) } } - -bitflags! { - #[derive(Default)] - struct SocketOptions: u32 { - #[allow(clippy::identity_op)] - const SHUTDOWN_WRITE = 1 << 0; - const SHUTDOWN_READ = 1 << 1; - #[allow(clippy::identity_op)] - const DATAGRAM = 1 << 0; - const PEEK = 1 << 3; - } -}