From ea56fba668dc983d860103fc278b7144a63cd272 Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Fri, 24 Feb 2023 16:38:53 -0800 Subject: [PATCH] Rebase on the new wasi-sockets. (#91) * Rebase on the new wasi-sockets. This switches to using the wasi-sockets wit files from WebAssembly/wasi-sockets#16. Many things are still stubbed out with `todo!()` for now. * Fix compilation on Windows. --- host/src/ip_name_lookup.rs | 51 +++ host/src/lib.rs | 7 + host/src/network.rs | 114 +++++ host/src/tcp.rs | 262 ++++------- host/src/udp.rs | 127 ++++++ wasi-common/cap-std-sync/src/lib.rs | 10 +- wasi-common/cap-std-sync/src/net.rs | 662 +++++++++++----------------- wasi-common/src/connection.rs | 69 --- wasi-common/src/ctx.rs | 4 +- wasi-common/src/lib.rs | 10 +- wasi-common/src/listener.rs | 39 -- wasi-common/src/tcp_listener.rs | 44 -- wasi-common/src/tcp_socket.rs | 54 +++ wasi-common/src/udp_socket.rs | 51 +++ 14 files changed, 757 insertions(+), 747 deletions(-) create mode 100644 host/src/ip_name_lookup.rs create mode 100644 host/src/network.rs create mode 100644 host/src/udp.rs delete mode 100644 wasi-common/src/connection.rs delete mode 100644 wasi-common/src/listener.rs delete mode 100644 wasi-common/src/tcp_listener.rs create mode 100644 wasi-common/src/tcp_socket.rs create mode 100644 wasi-common/src/udp_socket.rs diff --git a/host/src/ip_name_lookup.rs b/host/src/ip_name_lookup.rs new file mode 100644 index 000000000000..a1cc6c155fee --- /dev/null +++ b/host/src/ip_name_lookup.rs @@ -0,0 +1,51 @@ +#![allow(unused_variables)] + +use crate::{ + wasi_ip_name_lookup::{ResolveAddressStream, WasiIpNameLookup}, + wasi_network::{Error, IpAddress, IpAddressFamily, Network}, + wasi_poll::Pollable, + HostResult, WasiCtx, +}; + +#[async_trait::async_trait] +impl WasiIpNameLookup for WasiCtx { + async fn resolve_addresses( + &mut self, + network: Network, + name: String, + address_family: Option, + include_unavailable: bool, + ) -> HostResult { + todo!() + } + + async fn resolve_next_address( + &mut self, + stream: ResolveAddressStream, + ) -> HostResult, Error> { + todo!() + } + + async fn drop_resolve_address_stream( + &mut self, + stream: ResolveAddressStream, + ) -> anyhow::Result<()> { + todo!() + } + + async fn non_blocking(&mut self, stream: ResolveAddressStream) -> HostResult { + todo!() + } + + async fn set_non_blocking( + &mut self, + stream: ResolveAddressStream, + value: bool, + ) -> HostResult<(), Error> { + todo!() + } + + async fn subscribe(&mut self, stream: ResolveAddressStream) -> anyhow::Result { + todo!() + } +} diff --git a/host/src/lib.rs b/host/src/lib.rs index ed4d8dbf6728..04ff88fae8d5 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -3,11 +3,14 @@ mod env; mod exit; mod filesystem; mod io; +mod ip_name_lookup; mod logging; +mod network; mod poll; mod random; mod stderr; mod tcp; +mod udp; pub use wasi_common::{table::Table, WasiCtx}; type HostResult = anyhow::Result>; @@ -33,6 +36,10 @@ pub fn add_to_linker( wasi_io::add_to_linker(l, f)?; wasi_random::add_to_linker(l, f)?; wasi_tcp::add_to_linker(l, f)?; + wasi_udp::add_to_linker(l, f)?; + wasi_ip_name_lookup::add_to_linker(l, f)?; + wasi_default_network::add_to_linker(l, f)?; + wasi_network::add_to_linker(l, f)?; wasi_exit::add_to_linker(l, f)?; wasi_environment::add_to_linker(l, f)?; Ok(()) diff --git a/host/src/network.rs b/host/src/network.rs new file mode 100644 index 000000000000..59a211457e85 --- /dev/null +++ b/host/src/network.rs @@ -0,0 +1,114 @@ +use crate::{ + wasi_default_network::WasiDefaultNetwork, + wasi_network::{Network, WasiNetwork}, + WasiCtx, +}; +use crate::{ + //wasi_network::{IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress}, + wasi_tcp, + wasi_udp, +}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +pub(crate) fn convert(_error: wasi_common::Error) -> anyhow::Error { + todo!("convert wasi-common Error to wasi_network::Error") +} + +#[async_trait::async_trait] +impl WasiNetwork for WasiCtx { + async fn drop_network(&mut self, _network: Network) -> anyhow::Result<()> { + todo!() + } +} + +#[async_trait::async_trait] +impl WasiDefaultNetwork for WasiCtx { + async fn default_network(&mut self) -> anyhow::Result { + todo!() + } +} + +impl From for wasi_tcp::IpSocketAddress { + fn from(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(v4) => Self::Ipv4(v4.into()), + SocketAddr::V6(v6) => Self::Ipv6(v6.into()), + } + } +} + +impl From for wasi_udp::IpSocketAddress { + fn from(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(v4) => Self::Ipv4(v4.into()), + SocketAddr::V6(v6) => Self::Ipv6(v6.into()), + } + } +} + +impl From for wasi_tcp::Ipv4SocketAddress { + fn from(addr: SocketAddrV4) -> Self { + Self { + address: MyIpv4Addr::from(addr.ip()).0, + port: addr.port(), + } + } +} + +impl From for wasi_udp::Ipv4SocketAddress { + fn from(addr: SocketAddrV4) -> Self { + Self { + address: MyIpv4Addr::from(addr.ip()).0, + port: addr.port(), + } + } +} + +impl From for wasi_tcp::Ipv6SocketAddress { + fn from(addr: SocketAddrV6) -> Self { + Self { + address: MyIpv6Addr::from(addr.ip()).0, + port: addr.port(), + flow_info: addr.flowinfo(), + scope_id: addr.scope_id(), + } + } +} + +impl From for wasi_udp::Ipv6SocketAddress { + fn from(addr: SocketAddrV6) -> Self { + Self { + address: MyIpv6Addr::from(addr.ip()).0, + port: addr.port(), + flow_info: addr.flowinfo(), + scope_id: addr.scope_id(), + } + } +} + +// Newtypes to guide conversions. +struct MyIpv4Addr((u8, u8, u8, u8)); +struct MyIpv6Addr((u16, u16, u16, u16, u16, u16, u16, u16)); + +impl From<&Ipv4Addr> for MyIpv4Addr { + fn from(addr: &Ipv4Addr) -> Self { + let octets = addr.octets(); + Self((octets[0], octets[1], octets[2], octets[3])) + } +} + +impl From<&Ipv6Addr> for MyIpv6Addr { + fn from(addr: &Ipv6Addr) -> Self { + let segments = addr.segments(); + Self(( + segments[0], + segments[1], + segments[2], + segments[3], + segments[4], + segments[5], + segments[6], + segments[7], + )) + } +} diff --git a/host/src/tcp.rs b/host/src/tcp.rs index a2e84d671b71..90b25807c189 100644 --- a/host/src/tcp.rs +++ b/host/src/tcp.rs @@ -1,101 +1,29 @@ #![allow(unused_variables)] use crate::{ + network::convert, wasi_io::{InputStream, OutputStream}, - wasi_tcp::{ - Connection, ConnectionFlags, Errno, IoSize, IpSocketAddress, Ipv4SocketAddress, - Ipv6SocketAddress, Listener, ListenerFlags, Network, TcpListener, WasiTcp, - }, + wasi_network::{Error, IpAddressFamily, Network}, + wasi_poll::Pollable, + wasi_tcp::{IpSocketAddress, ShutdownType, TcpSocket, WasiTcp}, HostResult, WasiCtx, }; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::ops::BitAnd; -use wasi_common::listener::TableListenerExt; -use wasi_common::tcp_listener::TableTcpListenerExt; - -/// TODO: Remove once wasmtime #5589 lands. -fn contains + Eq + Copy>(flags: T, flag: T) -> bool { - (flags & flag) == flag -} - -fn convert(error: wasi_common::Error) -> anyhow::Error { - if let Some(errno) = error.downcast_ref() { - use wasi_common::Errno::*; - - match errno { - Acces => Errno::Access, - Again => Errno::Again, - Already => Errno::Already, - Badf => Errno::Badf, - Busy => Errno::Busy, - Ilseq => Errno::Ilseq, - Inprogress => Errno::Inprogress, - Intr => Errno::Intr, - Inval => Errno::Inval, - Io => Errno::Io, - Msgsize => Errno::Msgsize, - Nametoolong => Errno::Nametoolong, - Noent => Errno::Noent, - Nomem => Errno::Nomem, - Nosys => Errno::Nosys, - Notrecoverable => Errno::Notrecoverable, - Notsup => Errno::Notsup, - Overflow => Errno::Overflow, - Perm => Errno::Perm, - Addrinuse => Errno::Addrinuse, - Addrnotavail => Errno::Addrnotavail, - Afnosupport => Errno::Afnosupport, - Connaborted => Errno::ConnectionAborted, - Connrefused => Errno::ConnectionRefused, - Connreset => Errno::ConnectionReset, - Destaddrreq => Errno::Destaddrreq, - Hostunreach => Errno::HostUnreachable, - Isconn => Errno::Isconn, - Multihop => Errno::Multihop, - Netreset => Errno::NetworkReset, - Netdown => Errno::NetworkDown, - Netunreach => Errno::NetworkUnreachable, - Nobufs => Errno::Nobufs, - Noprotoopt => Errno::Noprotoopt, - Timedout => Errno::Timedout, - _ => { - panic!("Unexpected errno: {:?}", errno); - } - } - .into() - } else { - error.into() - } -} +use wasi_common::tcp_socket::TableTcpSocketExt; #[async_trait::async_trait] impl WasiTcp for WasiCtx { - async fn listen( - &mut self, - network: Network, - address: IpSocketAddress, - backlog: Option, - flags: ListenerFlags, - ) -> HostResult { + async fn listen(&mut self, socket: TcpSocket, backlog: Option) -> HostResult<(), Error> { todo!() } async fn accept( &mut self, - listener: Listener, - flags: ConnectionFlags, - ) -> HostResult<(Connection, InputStream, OutputStream), Errno> { + socket: TcpSocket, + ) -> HostResult<(TcpSocket, InputStream, OutputStream), Error> { let table = self.table_mut(); - let l = table.get_listener_mut(listener)?; + let l = table.get_tcp_socket_mut(socket)?; - let nonblocking = contains(flags, ConnectionFlags::NONBLOCK); - - if contains(flags, ConnectionFlags::KEEPALIVE) || contains(flags, ConnectionFlags::NODELAY) - { - todo!() - } - - let (connection, input_stream, output_stream) = l.accept(nonblocking).await?; + let (connection, input_stream, output_stream, _addr) = l.accept(false).await?; let connection = table.push(Box::new(connection)).map_err(convert)?; let input_stream = table.push(Box::new(input_stream)).map_err(convert)?; @@ -104,170 +32,132 @@ impl WasiTcp for WasiCtx { Ok(Ok((connection, input_stream, output_stream))) } - async fn accept_tcp( + async fn connect( &mut self, - listener: TcpListener, - flags: ConnectionFlags, - ) -> HostResult<(Connection, InputStream, OutputStream, IpSocketAddress), Errno> { - let table = self.table_mut(); - let l = table.get_tcp_listener_mut(listener)?; - - let nonblocking = contains(flags, ConnectionFlags::NONBLOCK); + socket: TcpSocket, + remote_address: IpSocketAddress, + ) -> HostResult<(InputStream, OutputStream), Error> { + todo!() + } - if contains(flags, ConnectionFlags::KEEPALIVE) || contains(flags, ConnectionFlags::NODELAY) - { - todo!() - } + async fn receive_buffer_size(&mut self, socket: TcpSocket) -> HostResult { + todo!() + } - let (connection, input_stream, output_stream, addr) = l.accept(nonblocking).await?; + async fn set_receive_buffer_size( + &mut self, + socket: TcpSocket, + value: u64, + ) -> HostResult<(), Error> { + todo!() + } - let connection = table.push(Box::new(connection)).map_err(convert)?; - let input_stream = table.push(Box::new(input_stream)).map_err(convert)?; - let output_stream = table.push(Box::new(output_stream)).map_err(convert)?; + async fn send_buffer_size(&mut self, socket: TcpSocket) -> HostResult { + todo!() + } - Ok(Ok((connection, input_stream, output_stream, addr.into()))) + async fn set_send_buffer_size( + &mut self, + socket: TcpSocket, + value: u64, + ) -> HostResult<(), Error> { + todo!() } - async fn connect( + async fn create_tcp_socket( &mut self, network: Network, + address_family: IpAddressFamily, + ) -> HostResult { + todo!() + } + + async fn bind( + &mut self, + this: TcpSocket, local_address: IpSocketAddress, - remote_address: IpSocketAddress, - flags: ConnectionFlags, - ) -> HostResult<(Connection, InputStream, OutputStream), Errno> { + ) -> HostResult<(), Error> { todo!() } - async fn send(&mut self, connection: Connection, bytes: Vec) -> HostResult { + async fn local_address(&mut self, this: TcpSocket) -> HostResult { todo!() } - async fn receive( + async fn shutdown( &mut self, - connection: Connection, - length: IoSize, - ) -> HostResult<(Vec, bool), Errno> { + this: TcpSocket, + shutdown_type: ShutdownType, + ) -> HostResult<(), Error> { todo!() } - async fn is_connected(&mut self, connection: Connection) -> anyhow::Result { - // This should ultimately call `getpeername` and test whether it - // gets a `ENOTCONN` error indicating not-connected. + async fn remote_address(&mut self, this: TcpSocket) -> HostResult { todo!() } - async fn get_flags(&mut self, connection: Connection) -> HostResult { + async fn keep_alive(&mut self, this: TcpSocket) -> HostResult { todo!() } - async fn set_flags( - &mut self, - connection: Connection, - flags: ConnectionFlags, - ) -> HostResult<(), Errno> { + async fn set_keep_alive(&mut self, this: TcpSocket, value: bool) -> HostResult<(), Error> { todo!() } - async fn get_receive_buffer_size( - &mut self, - connection: Connection, - ) -> HostResult { + async fn no_delay(&mut self, this: TcpSocket) -> HostResult { todo!() } - async fn set_receive_buffer_size( - &mut self, - connection: Connection, - value: IoSize, - ) -> HostResult<(), Errno> { + async fn set_no_delay(&mut self, this: TcpSocket, value: bool) -> HostResult<(), Error> { todo!() } - async fn get_send_buffer_size(&mut self, connection: Connection) -> HostResult { + async fn address_family(&mut self, this: TcpSocket) -> anyhow::Result { todo!() } - async fn set_send_buffer_size( - &mut self, - connection: Connection, - value: IoSize, - ) -> HostResult<(), Errno> { + async fn unicast_hop_limit(&mut self, this: TcpSocket) -> HostResult { todo!() } - async fn bytes_readable(&mut self, socket: Connection) -> HostResult<(IoSize, bool), Errno> { - drop(socket); + async fn set_unicast_hop_limit(&mut self, this: TcpSocket, value: u8) -> HostResult<(), Error> { todo!() } - async fn bytes_writable(&mut self, socket: Connection) -> HostResult<(IoSize, bool), Errno> { - drop(socket); + async fn ipv6_only(&mut self, this: TcpSocket) -> HostResult { todo!() } - async fn close_tcp_listener(&mut self, listener: TcpListener) -> anyhow::Result<()> { - drop(listener); + async fn set_ipv6_only(&mut self, this: TcpSocket, value: bool) -> HostResult<(), Error> { todo!() } - async fn close_connection(&mut self, connection: Connection) -> anyhow::Result<()> { - drop(connection); + async fn non_blocking(&mut self, this: TcpSocket) -> HostResult { todo!() } -} -impl From for IpSocketAddress { - fn from(addr: SocketAddr) -> Self { - match addr { - SocketAddr::V4(v4) => Self::Ipv4(v4.into()), - SocketAddr::V6(v6) => Self::Ipv6(v6.into()), - } + async fn set_non_blocking(&mut self, this: TcpSocket, value: bool) -> HostResult<(), Error> { + todo!() } -} -impl From for Ipv4SocketAddress { - fn from(addr: SocketAddrV4) -> Self { - Self { - address: MyIpv4Addr::from(addr.ip()).0, - port: addr.port(), - } + async fn subscribe(&mut self, this: TcpSocket) -> anyhow::Result { + todo!() } -} -impl From for Ipv6SocketAddress { - fn from(addr: SocketAddrV6) -> Self { - Self { - address: MyIpv6Addr::from(addr.ip()).0, - port: addr.port(), - flow_info: addr.flowinfo(), - scope_id: addr.scope_id(), - } + /* TODO: Revisit after https://github.com/WebAssembly/wasi-sockets/issues/17 + async fn bytes_readable(&mut self, socket: Connection) -> HostResult<(u64, bool), Error> { + drop(socket); + todo!() } -} - -// Newtypes to guide conversions. -struct MyIpv4Addr((u8, u8, u8, u8)); -struct MyIpv6Addr((u16, u16, u16, u16, u16, u16, u16, u16)); -impl From<&Ipv4Addr> for MyIpv4Addr { - fn from(addr: &Ipv4Addr) -> Self { - let octets = addr.octets(); - Self((octets[0], octets[1], octets[2], octets[3])) + async fn bytes_writable(&mut self, socket: Connection) -> HostResult<(u64, bool), Error> { + drop(socket); + todo!() } -} + */ -impl From<&Ipv6Addr> for MyIpv6Addr { - fn from(addr: &Ipv6Addr) -> Self { - let segments = addr.segments(); - Self(( - segments[0], - segments[1], - segments[2], - segments[3], - segments[4], - segments[5], - segments[6], - segments[7], - )) + async fn drop_tcp_socket(&mut self, socket: TcpSocket) -> anyhow::Result<()> { + drop(socket); + todo!() } } diff --git a/host/src/udp.rs b/host/src/udp.rs new file mode 100644 index 000000000000..c36d76bcda1e --- /dev/null +++ b/host/src/udp.rs @@ -0,0 +1,127 @@ +#![allow(unused_variables)] + +use crate::{ + wasi_network::{Error, IpAddressFamily, Network}, + wasi_poll::Pollable, + wasi_udp::{Datagram, IpSocketAddress, UdpSocket, WasiUdp}, + HostResult, WasiCtx, +}; +use wasi_common::udp_socket::TableUdpSocketExt; + +#[async_trait::async_trait] +impl WasiUdp for WasiCtx { + async fn connect( + &mut self, + network: Network, + remote_address: IpSocketAddress, + ) -> HostResult<(), Error> { + todo!() + } + + async fn send(&mut self, socket: UdpSocket, datagram: Datagram) -> HostResult<(), Error> { + todo!() + } + + async fn receive(&mut self, socket: UdpSocket) -> HostResult { + todo!() + } + + async fn receive_buffer_size(&mut self, socket: UdpSocket) -> HostResult { + todo!() + } + + async fn set_receive_buffer_size( + &mut self, + socket: UdpSocket, + value: u64, + ) -> HostResult<(), Error> { + todo!() + } + + async fn send_buffer_size(&mut self, socket: UdpSocket) -> HostResult { + todo!() + } + + async fn set_send_buffer_size( + &mut self, + socket: UdpSocket, + value: u64, + ) -> HostResult<(), Error> { + todo!() + } + + async fn create_udp_socket( + &mut self, + network: Network, + address_family: IpAddressFamily, + ) -> HostResult { + todo!() + } + + async fn bind( + &mut self, + this: UdpSocket, + local_address: IpSocketAddress, + ) -> HostResult<(), Error> { + todo!() + } + + async fn local_address(&mut self, this: UdpSocket) -> HostResult { + todo!() + } + + async fn remote_address(&mut self, this: UdpSocket) -> HostResult { + todo!() + } + + async fn address_family(&mut self, this: UdpSocket) -> anyhow::Result { + todo!() + } + + async fn unicast_hop_limit(&mut self, this: UdpSocket) -> HostResult { + todo!() + } + + async fn set_unicast_hop_limit(&mut self, this: UdpSocket, value: u8) -> HostResult<(), Error> { + todo!() + } + + async fn ipv6_only(&mut self, this: UdpSocket) -> HostResult { + todo!() + } + + async fn set_ipv6_only(&mut self, this: UdpSocket, value: bool) -> HostResult<(), Error> { + todo!() + } + + async fn non_blocking(&mut self, this: UdpSocket) -> HostResult { + todo!() + } + + async fn set_non_blocking(&mut self, this: UdpSocket, value: bool) -> HostResult<(), Error> { + let this = self.table.get_udp_socket_mut(this)?; + this.set_nonblocking(value)?; + Ok(Ok(())) + } + + async fn subscribe(&mut self, this: UdpSocket) -> anyhow::Result { + todo!() + } + + /* TODO: Revisit after https://github.com/WebAssembly/wasi-sockets/issues/17 + async fn bytes_readable(&mut self, socket: UdpSocket) -> HostResult<(u64, bool), Error> { + drop(socket); + todo!() + } + + async fn bytes_writable(&mut self, socket: UdpSocket) -> HostResult<(u64, bool), Error> { + drop(socket); + todo!() + } + */ + + async fn drop_udp_socket(&mut self, socket: UdpSocket) -> anyhow::Result<()> { + drop(socket); + todo!() + } +} diff --git a/wasi-common/cap-std-sync/src/lib.rs b/wasi-common/cap-std-sync/src/lib.rs index af35c1b323a3..0bd8413221a3 100644 --- a/wasi-common/cap-std-sync/src/lib.rs +++ b/wasi-common/cap-std-sync/src/lib.rs @@ -45,12 +45,12 @@ pub use cap_std::net::TcpListener; pub use clocks::clocks_ctx; pub use sched::sched_ctx; -use crate::net::Listener; +use crate::net::TcpSocket; use cap_rand::{Rng, RngCore, SeedableRng}; use wasi_common::{ - listener::WasiListener, stream::{InputStream, OutputStream}, table::Table, + tcp_socket::WasiTcpSocket, WasiCtx, }; @@ -94,9 +94,9 @@ impl WasiCtxBuilder { self.0.insert_dir(fd, dir); self } - pub fn preopened_listener(mut self, fd: u32, listener: impl Into) -> Self { - let listener: Listener = listener.into(); - let listener: Box = listener.into(); + pub fn preopened_listener(mut self, fd: u32, listener: impl Into) -> Self { + let listener: TcpSocket = listener.into(); + let listener: Box = Box::new(TcpSocket::from(listener)); self.0.insert_listener(fd, listener); self diff --git a/wasi-common/cap-std-sync/src/net.rs b/wasi-common/cap-std-sync/src/net.rs index c36ff5e7fd81..9f21de8c2d0a 100644 --- a/wasi-common/cap-std-sync/src/net.rs +++ b/wasi-common/cap-std-sync/src/net.rs @@ -1,3 +1,4 @@ +use cap_std::net::{TcpListener, TcpStream}; use io_extras::borrowed::BorrowedReadable; #[cfg(windows)] use io_extras::os::windows::{AsHandleOrSocket, BorrowedHandleOrSocket}; @@ -6,6 +7,7 @@ use io_lifetimes::AsSocketlike; use io_lifetimes::{AsFd, BorrowedFd}; #[cfg(windows)] use io_lifetimes::{AsSocket, BorrowedSocket}; +use rustix::fd::OwnedFd; use std::any::Any; use std::convert::TryInto; use std::io::{self, Read, Write}; @@ -15,471 +17,339 @@ use system_interface::io::IoExt; use system_interface::io::IsReadWrite; use system_interface::io::ReadReady; use wasi_common::{ - connection::{RiFlags, RoFlags, SdFlags, SiFlags, WasiConnection}, - listener::WasiListener, stream::{InputStream, OutputStream}, - tcp_listener::WasiTcpListener, + tcp_socket::{SdFlags, WasiTcpSocket}, + udp_socket::{RiFlags, RoFlags, WasiUdpSocket}, Error, ErrorExt, }; -pub enum Listener { - TcpListener(cap_std::net::TcpListener), - #[cfg(unix)] - UnixListener(cap_std::os::unix::net::UnixListener), -} +pub struct TcpSocket(Arc); +pub struct UdpSocket(Arc); -pub enum Connection { - TcpStream(cap_std::net::TcpStream), - #[cfg(unix)] - UnixStream(cap_std::os::unix::net::UnixStream), -} +impl TcpSocket { + pub fn new(owned: OwnedFd) -> Self { + Self(Arc::new(owned)) + } -impl From for Listener { - fn from(listener: cap_std::net::TcpListener) -> Self { - Self::TcpListener(listener) + pub fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) } } -impl From for Connection { - fn from(stream: cap_std::net::TcpStream) -> Self { - Self::TcpStream(stream) +impl UdpSocket { + pub fn new(owned: OwnedFd) -> Self { + Self(Arc::new(owned)) } -} -#[cfg(unix)] -impl From for Listener { - fn from(listener: cap_std::os::unix::net::UnixListener) -> Self { - Self::UnixListener(listener) + pub fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) } } -#[cfg(unix)] -impl From for Connection { - fn from(stream: cap_std::os::unix::net::UnixStream) -> Self { - Self::UnixStream(stream) +#[async_trait::async_trait] +impl WasiTcpSocket for TcpSocket { + fn as_any(&self) -> &dyn Any { + self } -} -#[cfg(unix)] -impl From for Box { - fn from(listener: Listener) -> Self { - match listener { - Listener::TcpListener(l) => Box::new(crate::net::TcpListener::from_cap_std(l)), - Listener::UnixListener(l) => Box::new(crate::net::UnixListener::from_cap_std(l)), - } + async fn accept( + &mut self, + nonblocking: bool, + ) -> Result< + ( + Box, + Box, + Box, + SocketAddr, + ), + Error, + > { + let (connection, addr) = self.0.as_socketlike_view::().accept()?; + connection.set_nonblocking(nonblocking)?; + let connection = TcpSocket::new(connection.into()); + let input_stream = connection.clone(); + let output_stream = connection.clone(); + Ok(( + Box::new(connection), + Box::new(input_stream), + Box::new(output_stream), + addr, + )) } -} -#[cfg(windows)] -impl From for Box { - fn from(listener: Listener) -> Self { - match listener { - Listener::TcpListener(l) => Box::new(crate::net::TcpListener::from_cap_std(l)), - } + fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { + self.0 + .as_socketlike_view::() + .set_nonblocking(flag)?; + Ok(()) } -} -#[cfg(unix)] -impl From for Box { - fn from(listener: Connection) -> Self { - match listener { - Connection::TcpStream(l) => Box::new(crate::net::TcpStream::from_cap_std(l)), - Connection::UnixStream(l) => Box::new(crate::net::UnixStream::from_cap_std(l)), + async fn sock_shutdown(&mut self, how: SdFlags) -> Result<(), Error> { + let how = if how == SdFlags::READ | SdFlags::WRITE { + cap_std::net::Shutdown::Both + } else if how == SdFlags::READ { + cap_std::net::Shutdown::Read + } else if how == SdFlags::WRITE { + cap_std::net::Shutdown::Write + } else { + return Err(Error::invalid_argument()); + }; + self.0.as_socketlike_view::().shutdown(how)?; + Ok(()) + } + + async fn readable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.0 { + Ok(()) + } else { + Err(Error::badf()) } } -} -#[cfg(windows)] -impl From for Box { - fn from(listener: Connection) -> Self { - match listener { - Connection::TcpStream(l) => Box::new(crate::net::TcpStream::from_cap_std(l)), + async fn writable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.1 { + Ok(()) + } else { + Err(Error::badf()) } } } -macro_rules! wasi_listener_impl { - ($ty:ty, $stream:ty) => { - #[async_trait::async_trait] - impl WasiListener for $ty { - fn as_any(&self) -> &dyn Any { - self - } - - async fn accept( - &mut self, - nonblocking: bool, - ) -> Result< - ( - Box, - Box, - Box, - ), - Error, - > { - let (stream, _) = self.0.accept()?; - stream.set_nonblocking(nonblocking)?; - let connection = <$stream>::from_cap_std(stream); - let input_stream = connection.clone(); - let output_stream = connection.clone(); - Ok(( - Box::new(connection), - Box::new(input_stream), - Box::new(output_stream), - )) - } +#[async_trait::async_trait] +impl WasiUdpSocket for UdpSocket { + fn as_any(&self) -> &dyn Any { + self + } - fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { - self.0.set_nonblocking(flag)?; - Ok(()) - } - } + fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { + self.0 + .as_socketlike_view::() + .set_nonblocking(flag)?; + Ok(()) + } - #[cfg(windows)] - impl AsSocket for $ty { - #[inline] - fn as_socket(&self) -> BorrowedSocket<'_> { - self.0.as_socket() - } + async fn sock_recv<'a>( + &mut self, + ri_data: &mut [io::IoSliceMut<'a>], + ri_flags: RiFlags, + ) -> Result<(u64, RoFlags), Error> { + if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() { + return Err(Error::not_supported()); } - #[cfg(windows)] - impl AsHandleOrSocket for $ty { - #[inline] - fn as_handle_or_socket(&self) -> BorrowedHandleOrSocket { - self.0.as_handle_or_socket() + if ri_flags.contains(RiFlags::RECV_PEEK) { + if let Some(first) = ri_data.iter_mut().next() { + let n = self.0.as_socketlike_view::().peek(first)?; + return Ok((n as u64, RoFlags::empty())); + } else { + return Ok((0, RoFlags::empty())); } } - #[cfg(unix)] - impl AsFd for $ty { - fn as_fd(&self) -> BorrowedFd<'_> { - self.0.as_fd() - } + if ri_flags.contains(RiFlags::RECV_WAITALL) { + let n: usize = ri_data.iter().map(|buf| buf.len()).sum(); + self.0 + .as_socketlike_view::() + .read_exact_vectored(ri_data)?; + return Ok((n as u64, RoFlags::empty())); } - }; -} -macro_rules! wasi_tcp_listener_impl { - ($ty:ty, $stream:ty) => { - #[async_trait::async_trait] - impl WasiTcpListener for $ty { - fn as_any(&self) -> &dyn Any { - self - } - - async fn accept( - &mut self, - nonblocking: bool, - ) -> Result< - ( - Box, - Box, - Box, - SocketAddr, - ), - Error, - > { - let (stream, addr) = self.0.accept()?; - stream.set_nonblocking(nonblocking)?; - let connection = <$stream>::from_cap_std(stream); - let input_stream = connection.clone(); - let output_stream = connection.clone(); - Ok(( - Box::new(connection), - Box::new(input_stream), - Box::new(output_stream), - addr, - )) - } + let n = self + .0 + .as_socketlike_view::() + .read_vectored(ri_data)?; + Ok((n as u64, RoFlags::empty())) + } - fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { - self.0.set_nonblocking(flag)?; - Ok(()) - } + async fn sock_send<'a>(&mut self, si_data: &[io::IoSlice<'a>]) -> Result { + let n = self + .0 + .as_socketlike_view::() + .write_vectored(si_data)?; + Ok(n as u64) + } - fn into_listener(self) -> Box { - Box::new(self) - } + async fn readable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.0 { + Ok(()) + } else { + Err(Error::badf()) } - }; -} - -pub struct TcpListener(cap_std::net::TcpListener); - -impl TcpListener { - pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self { - TcpListener(cap_std) } -} -wasi_listener_impl!(TcpListener, TcpStream); -wasi_tcp_listener_impl!(TcpListener, TcpStream); - -#[cfg(unix)] -pub struct UnixListener(cap_std::os::unix::net::UnixListener); -#[cfg(unix)] -impl UnixListener { - pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self { - UnixListener(cap_std) + async fn writable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.1 { + Ok(()) + } else { + Err(Error::badf()) + } } } -#[cfg(unix)] -wasi_listener_impl!(UnixListener, UnixStream); - -macro_rules! wasi_stream_write_impl { - ($ty:ty, $std_ty:ty) => { - #[async_trait::async_trait] - impl WasiConnection for $ty { - fn as_any(&self) -> &dyn Any { - self - } - - async fn sock_recv<'a>( - &mut self, - ri_data: &mut [io::IoSliceMut<'a>], - ri_flags: RiFlags, - ) -> Result<(u64, RoFlags), Error> { - if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() { - return Err(Error::not_supported()); - } - - if ri_flags.contains(RiFlags::RECV_PEEK) { - if let Some(first) = ri_data.iter_mut().next() { - let n = self.0.peek(first)?; - return Ok((n as u64, RoFlags::empty())); - } else { - return Ok((0, RoFlags::empty())); - } - } - - if ri_flags.contains(RiFlags::RECV_WAITALL) { - let n: usize = ri_data.iter().map(|buf| buf.len()).sum(); - self.0.read_exact_vectored(ri_data)?; - return Ok((n as u64, RoFlags::empty())); - } - - let n = self.0.read_vectored(ri_data)?; - Ok((n as u64, RoFlags::empty())) - } - - async fn sock_send<'a>( - &mut self, - si_data: &[io::IoSlice<'a>], - si_flags: SiFlags, - ) -> Result { - if si_flags != SiFlags::empty() { - return Err(Error::not_supported()); - } - - let n = self.0.write_vectored(si_data)?; - Ok(n as u64) - } - - async fn sock_shutdown(&mut self, how: SdFlags) -> Result<(), Error> { - let how = if how == SdFlags::RD | SdFlags::WR { - cap_std::net::Shutdown::Both - } else if how == SdFlags::RD { - cap_std::net::Shutdown::Read - } else if how == SdFlags::WR { - cap_std::net::Shutdown::Write - } else { - return Err(Error::invalid_argument()); - }; - self.0.shutdown(how)?; - Ok(()) - } - - fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error> { - self.0.set_nonblocking(flag)?; - Ok(()) - } +#[async_trait::async_trait] +impl InputStream for TcpSocket { + fn as_any(&self) -> &dyn Any { + self + } + #[cfg(unix)] + fn pollable_read(&self) -> Option { + Some(self.0.as_fd()) + } - async fn readable(&self) -> Result<(), Error> { - if is_read_write(&*self.0)?.0 { - Ok(()) - } else { - Err(Error::badf()) - } - } + #[cfg(windows)] + fn pollable_read(&self) -> Option { + Some(BorrowedHandleOrSocket::from_socket(self.0.as_socket())) + } - async fn writable(&self) -> Result<(), Error> { - if is_read_write(&*self.0)?.1 { - Ok(()) - } else { - Err(Error::badf()) - } - } + async fn read(&mut self, buf: &mut [u8]) -> Result<(u64, bool), Error> { + match Read::read(&mut &*self.as_socketlike_view::(), buf) { + Ok(0) => Ok((0, true)), + Ok(n) => Ok((n as u64, false)), + Err(err) if err.kind() == io::ErrorKind::Interrupted => Ok((0, false)), + Err(err) => Err(err.into()), } + } + async fn read_vectored<'a>( + &mut self, + bufs: &mut [io::IoSliceMut<'a>], + ) -> Result<(u64, bool), Error> { + match Read::read_vectored(&mut &*self.as_socketlike_view::(), bufs) { + Ok(0) => Ok((0, true)), + Ok(n) => Ok((n as u64, false)), + Err(err) if err.kind() == io::ErrorKind::Interrupted => Ok((0, false)), + Err(err) => Err(err.into()), + } + } + #[cfg(can_vector)] + fn is_read_vectored(&self) -> bool { + Read::is_read_vectored(&mut &*self.as_socketlike_view::()) + } - #[async_trait::async_trait] - impl InputStream for $ty { - fn as_any(&self) -> &dyn Any { - self - } - #[cfg(unix)] - fn pollable_read(&self) -> Option { - Some(self.0.as_fd()) - } - - #[cfg(windows)] - fn pollable_read(&self) -> Option { - Some(self.0.as_handle_or_socket()) - } - - async fn read(&mut self, buf: &mut [u8]) -> Result<(u64, bool), Error> { - match Read::read(&mut &*self.as_socketlike_view::<$std_ty>(), buf) { - Ok(0) => Ok((0, true)), - Ok(n) => Ok((n as u64, false)), - Err(err) if err.kind() == io::ErrorKind::Interrupted => Ok((0, false)), - Err(err) => Err(err.into()), - } - } - async fn read_vectored<'a>( - &mut self, - bufs: &mut [io::IoSliceMut<'a>], - ) -> Result<(u64, bool), Error> { - match Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs) { - Ok(0) => Ok((0, true)), - Ok(n) => Ok((n as u64, false)), - Err(err) if err.kind() == io::ErrorKind::Interrupted => Ok((0, false)), - Err(err) => Err(err.into()), - } - } - #[cfg(can_vector)] - fn is_read_vectored(&self) -> bool { - Read::is_read_vectored(&mut &*self.as_socketlike_view::<$std_ty>()) - } - - async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { - let num = io::copy(&mut io::Read::take(&*self.0, nelem), &mut io::sink())?; - Ok((num, num < nelem)) - } + async fn skip(&mut self, nelem: u64) -> Result<(u64, bool), Error> { + let num = io::copy( + &mut io::Read::take(&*self.0.as_socketlike_view::(), nelem), + &mut io::sink(), + )?; + Ok((num, num < nelem)) + } - async fn num_ready_bytes(&self) -> Result { - let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?; - Ok(val) - } + async fn num_ready_bytes(&self) -> Result { + let val = self.as_socketlike_view::().num_ready_bytes()?; + Ok(val) + } - async fn readable(&self) -> Result<(), Error> { - if is_read_write(&*self.0)?.0 { - Ok(()) - } else { - Err(Error::badf()) - } - } + async fn readable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.0 { + Ok(()) + } else { + Err(Error::badf()) } - #[async_trait::async_trait] - impl OutputStream for $ty { - fn as_any(&self) -> &dyn Any { - self - } + } +} - #[cfg(unix)] - fn pollable_write(&self) -> Option { - Some(self.0.as_fd()) - } +#[async_trait::async_trait] +impl OutputStream for TcpSocket { + fn as_any(&self) -> &dyn Any { + self + } - #[cfg(windows)] - fn pollable_write(&self) -> Option { - Some(self.0.as_handle_or_socket()) - } + #[cfg(unix)] + fn pollable_write(&self) -> Option { + Some(self.0.as_fd()) + } - async fn write(&mut self, buf: &[u8]) -> Result { - let n = Write::write(&mut &*self.as_socketlike_view::<$std_ty>(), buf)?; - Ok(n.try_into()?) - } - async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { - let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?; - Ok(n.try_into()?) - } - #[cfg(can_vector)] - fn is_write_vectored(&self) -> bool { - Write::is_write_vectored(&mut &*self.as_socketlike_view::<$std_ty>()) - } - async fn splice( - &mut self, - src: &mut dyn InputStream, - nelem: u64, - ) -> Result<(u64, bool), Error> { - if let Some(readable) = src.pollable_read() { - let num = io::copy( - &mut io::Read::take(BorrowedReadable::borrow(readable), nelem), - &mut &*self.0, - )?; - Ok((num, num < nelem)) - } else { - OutputStream::splice(self, src, nelem).await - } - } - async fn write_zeroes(&mut self, nelem: u64) -> Result { - let num = io::copy(&mut io::Read::take(io::repeat(0), nelem), &mut &*self.0)?; - Ok(num) - } - async fn writable(&self) -> Result<(), Error> { - if is_read_write(&*self.0)?.1 { - Ok(()) - } else { - Err(Error::badf()) - } - } - } - #[cfg(unix)] - impl AsFd for $ty { - fn as_fd(&self) -> BorrowedFd<'_> { - self.0.as_fd() - } - } + #[cfg(windows)] + fn pollable_write(&self) -> Option { + Some(BorrowedHandleOrSocket::from_socket(self.0.as_socket())) + } - #[cfg(windows)] - impl AsSocket for $ty { - /// Borrows the socket. - fn as_socket(&self) -> BorrowedSocket<'_> { - self.0.as_socket() - } + async fn write(&mut self, buf: &[u8]) -> Result { + let n = Write::write(&mut &*self.as_socketlike_view::(), buf)?; + Ok(n.try_into()?) + } + async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { + let n = Write::write_vectored(&mut &*self.as_socketlike_view::(), bufs)?; + Ok(n.try_into()?) + } + #[cfg(can_vector)] + fn is_write_vectored(&self) -> bool { + Write::is_write_vectored(&mut &*self.as_socketlike_view::()) + } + async fn splice( + &mut self, + src: &mut dyn InputStream, + nelem: u64, + ) -> Result<(u64, bool), Error> { + if let Some(readable) = src.pollable_read() { + let num = io::copy( + &mut io::Read::take(BorrowedReadable::borrow(readable), nelem), + &mut &*self.0.as_socketlike_view::(), + )?; + Ok((num, num < nelem)) + } else { + OutputStream::splice(self, src, nelem).await } - - #[cfg(windows)] - impl AsHandleOrSocket for TcpStream { - #[inline] - fn as_handle_or_socket(&self) -> BorrowedHandleOrSocket { - self.0.as_handle_or_socket() - } + } + async fn write_zeroes(&mut self, nelem: u64) -> Result { + let num = io::copy( + &mut io::Read::take(io::repeat(0), nelem), + &mut &*self.0.as_socketlike_view::(), + )?; + Ok(num) + } + async fn writable(&self) -> Result<(), Error> { + if is_read_write(&*self.0)?.1 { + Ok(()) + } else { + Err(Error::badf()) } - }; -} - -pub struct TcpStream(Arc); - -impl TcpStream { - pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self { - Self(Arc::new(socket)) } +} - pub fn clone(&self) -> Self { - Self(Arc::clone(&self.0)) +#[cfg(unix)] +impl AsFd for TcpSocket { + fn as_fd(&self) -> BorrowedFd<'_> { + self.0.as_fd() } } -wasi_stream_write_impl!(TcpStream, std::net::TcpStream); - #[cfg(unix)] -pub struct UnixStream(Arc); +impl AsFd for UdpSocket { + fn as_fd(&self) -> BorrowedFd<'_> { + self.0.as_fd() + } +} -#[cfg(unix)] -impl UnixStream { - pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self { - Self(Arc::new(socket)) +#[cfg(windows)] +impl AsSocket for TcpSocket { + /// Borrows the socket. + fn as_socket(&self) -> BorrowedSocket<'_> { + self.0.as_socket() } +} - pub fn clone(&self) -> Self { - Self(Arc::clone(&self.0)) +#[cfg(windows)] +impl AsHandleOrSocket for TcpSocket { + #[inline] + fn as_handle_or_socket(&self) -> BorrowedHandleOrSocket { + BorrowedHandleOrSocket::from_socket(self.0.as_socket()) + } +} +#[cfg(windows)] +impl AsSocket for UdpSocket { + /// Borrows the socket. + fn as_socket(&self) -> BorrowedSocket<'_> { + self.0.as_socket() } } -#[cfg(unix)] -wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream); +#[cfg(windows)] +impl AsHandleOrSocket for UdpSocket { + #[inline] + fn as_handle_or_socket(&self) -> BorrowedHandleOrSocket { + BorrowedHandleOrSocket::from_socket(self.0.as_socket()) + } +} /// Return the file-descriptor flags for a given file-like object. /// diff --git a/wasi-common/src/connection.rs b/wasi-common/src/connection.rs deleted file mode 100644 index 1c41f704ea3f..000000000000 --- a/wasi-common/src/connection.rs +++ /dev/null @@ -1,69 +0,0 @@ -//! Socket connections. - -use crate::Error; -use bitflags::bitflags; -use std::any::Any; - -/// A socket connection. -#[async_trait::async_trait] -pub trait WasiConnection: Send + Sync { - fn as_any(&self) -> &dyn Any; - - async fn sock_recv<'a>( - &mut self, - ri_data: &mut [std::io::IoSliceMut<'a>], - ri_flags: RiFlags, - ) -> Result<(u64, RoFlags), Error>; - - async fn sock_send<'a>( - &mut self, - si_data: &[std::io::IoSlice<'a>], - si_flags: SiFlags, - ) -> Result; - - async fn sock_shutdown(&mut self, _how: SdFlags) -> Result<(), Error>; - - fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>; - - async fn readable(&self) -> Result<(), Error>; - - async fn writable(&self) -> Result<(), Error>; -} - -bitflags! { - pub struct SdFlags: u32 { - const RD = 0b1; - const WR = 0b10; - } -} - -bitflags! { - pub struct SiFlags: u32 { - } -} - -bitflags! { - pub struct RiFlags: u32 { - const RECV_PEEK = 0b1; - const RECV_WAITALL = 0b10; - } -} - -bitflags! { - pub struct RoFlags: u32 { - const RECV_DATA_TRUNCATED = 0b1; - } -} - -pub trait TableConnectionExt { - fn get_connection(&self, fd: u32) -> Result<&dyn WasiConnection, Error>; - fn get_connection_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; -} -impl TableConnectionExt for crate::table::Table { - fn get_connection(&self, fd: u32) -> Result<&dyn WasiConnection, Error> { - self.get::>(fd).map(|f| f.as_ref()) - } - fn get_connection_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { - self.get_mut::>(fd) - } -} diff --git a/wasi-common/src/ctx.rs b/wasi-common/src/ctx.rs index b538e53c81b9..d0c0b86b7cd9 100644 --- a/wasi-common/src/ctx.rs +++ b/wasi-common/src/ctx.rs @@ -1,10 +1,10 @@ use crate::clocks::WasiClocks; use crate::dir::WasiDir; use crate::file::WasiFile; -use crate::listener::WasiListener; use crate::sched::WasiSched; use crate::stream::{InputStream, OutputStream}; use crate::table::Table; +use crate::tcp_socket::WasiTcpSocket; use crate::Error; use cap_rand::RngCore; @@ -50,7 +50,7 @@ impl WasiCtx { self.table_mut().insert_at(fd, Box::new(stream)); } - pub fn insert_listener(&mut self, fd: u32, listener: Box) { + pub fn insert_listener(&mut self, fd: u32, listener: Box) { self.table_mut().insert_at(fd, Box::new(listener)); } diff --git a/wasi-common/src/lib.rs b/wasi-common/src/lib.rs index c3d70f7d7a6f..9f7763ca21cd 100644 --- a/wasi-common/src/lib.rs +++ b/wasi-common/src/lib.rs @@ -52,29 +52,27 @@ //! `wasi_cap_std_sync::WasiCtxBuilder::new()` function uses this public //! interface to plug in its own implementations of each of these resources. pub mod clocks; -pub mod connection; mod ctx; pub mod dir; mod error; pub mod file; -pub mod listener; pub mod pipe; pub mod random; pub mod sched; pub mod stream; pub mod table; -pub mod tcp_listener; +pub mod tcp_socket; +pub mod udp_socket; pub use cap_fs_ext::SystemTimeSpec; pub use cap_rand::RngCore; pub use clocks::{WasiClocks, WasiMonotonicClock, WasiWallClock}; -pub use connection::WasiConnection; pub use ctx::WasiCtx; pub use dir::WasiDir; pub use error::{Errno, Error, ErrorExt, I32Exit}; pub use file::WasiFile; -pub use listener::WasiListener; pub use sched::{Poll, WasiSched}; pub use stream::{InputStream, OutputStream}; pub use table::Table; -pub use tcp_listener::WasiTcpListener; +pub use tcp_socket::WasiTcpSocket; +pub use udp_socket::WasiUdpSocket; diff --git a/wasi-common/src/listener.rs b/wasi-common/src/listener.rs deleted file mode 100644 index b8eb34a71741..000000000000 --- a/wasi-common/src/listener.rs +++ /dev/null @@ -1,39 +0,0 @@ -//! Socket listeners. - -use crate::connection::WasiConnection; -use crate::Error; -use crate::{InputStream, OutputStream}; -use std::any::Any; - -/// A socket listener. -#[async_trait::async_trait] -pub trait WasiListener: Send + Sync { - fn as_any(&self) -> &dyn Any; - - async fn accept( - &mut self, - nonblocking: bool, - ) -> Result< - ( - Box, - Box, - Box, - ), - Error, - >; - - fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>; -} - -pub trait TableListenerExt { - fn get_listener(&self, fd: u32) -> Result<&dyn WasiListener, Error>; - fn get_listener_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; -} -impl TableListenerExt for crate::table::Table { - fn get_listener(&self, fd: u32) -> Result<&dyn WasiListener, Error> { - self.get::>(fd).map(|f| f.as_ref()) - } - fn get_listener_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { - self.get_mut::>(fd) - } -} diff --git a/wasi-common/src/tcp_listener.rs b/wasi-common/src/tcp_listener.rs deleted file mode 100644 index 337567a838dc..000000000000 --- a/wasi-common/src/tcp_listener.rs +++ /dev/null @@ -1,44 +0,0 @@ -//! TCP socket listeners. - -use crate::connection::WasiConnection; -use crate::Error; -use crate::WasiListener; -use crate::{InputStream, OutputStream}; -use std::any::Any; -use std::net::SocketAddr; - -/// A TCP socket listener. -#[async_trait::async_trait] -pub trait WasiTcpListener: Send + Sync { - fn as_any(&self) -> &dyn Any; - - async fn accept( - &mut self, - nonblocking: bool, - ) -> Result< - ( - Box, - Box, - Box, - SocketAddr, - ), - Error, - >; - - fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>; - - fn into_listener(self) -> Box; -} - -pub trait TableTcpListenerExt { - fn get_tcp_listener(&self, fd: u32) -> Result<&dyn WasiTcpListener, Error>; - fn get_tcp_listener_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; -} -impl TableTcpListenerExt for crate::table::Table { - fn get_tcp_listener(&self, fd: u32) -> Result<&dyn WasiTcpListener, Error> { - self.get::>(fd).map(|f| f.as_ref()) - } - fn get_tcp_listener_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { - self.get_mut::>(fd) - } -} diff --git a/wasi-common/src/tcp_socket.rs b/wasi-common/src/tcp_socket.rs new file mode 100644 index 000000000000..4d6b00a5a684 --- /dev/null +++ b/wasi-common/src/tcp_socket.rs @@ -0,0 +1,54 @@ +//! TCP sockets. + +use crate::Error; +use crate::{InputStream, OutputStream}; +use bitflags::bitflags; +use std::any::Any; +use std::net::SocketAddr; + +/// A TCP socket. +#[async_trait::async_trait] +pub trait WasiTcpSocket: Send + Sync { + fn as_any(&self) -> &dyn Any; + + async fn accept( + &mut self, + nonblocking: bool, + ) -> Result< + ( + Box, + Box, + Box, + SocketAddr, + ), + Error, + >; + + async fn sock_shutdown(&mut self, how: SdFlags) -> Result<(), Error>; + + fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>; + + async fn readable(&self) -> Result<(), Error>; + + async fn writable(&self) -> Result<(), Error>; +} + +bitflags! { + pub struct SdFlags: u32 { + const READ = 0b1; + const WRITE = 0b10; + } +} + +pub trait TableTcpSocketExt { + fn get_tcp_socket(&self, fd: u32) -> Result<&dyn WasiTcpSocket, Error>; + fn get_tcp_socket_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; +} +impl TableTcpSocketExt for crate::table::Table { + fn get_tcp_socket(&self, fd: u32) -> Result<&dyn WasiTcpSocket, Error> { + self.get::>(fd).map(|f| f.as_ref()) + } + fn get_tcp_socket_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { + self.get_mut::>(fd) + } +} diff --git a/wasi-common/src/udp_socket.rs b/wasi-common/src/udp_socket.rs new file mode 100644 index 000000000000..f849c656c1c1 --- /dev/null +++ b/wasi-common/src/udp_socket.rs @@ -0,0 +1,51 @@ +//! UDP sockets. + +use crate::Error; +use bitflags::bitflags; +use std::any::Any; + +/// A UDP socket. +#[async_trait::async_trait] +pub trait WasiUdpSocket: Send + Sync { + fn as_any(&self) -> &dyn Any; + + async fn sock_recv<'a>( + &mut self, + ri_data: &mut [std::io::IoSliceMut<'a>], + ri_flags: RiFlags, + ) -> Result<(u64, RoFlags), Error>; + + async fn sock_send<'a>(&mut self, si_data: &[std::io::IoSlice<'a>]) -> Result; + + fn set_nonblocking(&mut self, flag: bool) -> Result<(), Error>; + + async fn readable(&self) -> Result<(), Error>; + + async fn writable(&self) -> Result<(), Error>; +} + +bitflags! { + pub struct RoFlags: u32 { + const RECV_DATA_TRUNCATED = 0b1; + } +} + +bitflags! { + pub struct RiFlags: u32 { + const RECV_PEEK = 0b1; + const RECV_WAITALL = 0b10; + } +} + +pub trait TableUdpSocketExt { + fn get_udp_socket(&self, fd: u32) -> Result<&dyn WasiUdpSocket, Error>; + fn get_udp_socket_mut(&mut self, fd: u32) -> Result<&mut Box, Error>; +} +impl TableUdpSocketExt for crate::table::Table { + fn get_udp_socket(&self, fd: u32) -> Result<&dyn WasiUdpSocket, Error> { + self.get::>(fd).map(|f| f.as_ref()) + } + fn get_udp_socket_mut(&mut self, fd: u32) -> Result<&mut Box, Error> { + self.get_mut::>(fd) + } +}