Skip to content

Commit

Permalink
net: add recv_buf for UdpSocket and UnixDatagram (#5583)
Browse files Browse the repository at this point in the history
  • Loading branch information
newfla committed Mar 28, 2023
1 parent 663e56e commit b31f1a4
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 7 deletions.
122 changes: 115 additions & 7 deletions tokio/src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ impl UdpSocket {
/// address to which it is connected. On success, returns the number of
/// bytes read.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
Expand Down Expand Up @@ -881,10 +881,12 @@ impl UdpSocket {
/// Tries to receive data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
///
Expand Down Expand Up @@ -931,25 +933,75 @@ impl UdpSocket {
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

let n = (*self.io).recv(dst)?;

// Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok(n)
})
}

/// Receives a single datagram message on the socket from the remote address
/// to which it is connected, advancing the buffer's internal cursor,
/// returning how many bytes were read.
///
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UdpSocket;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
/// socket.connect("127.0.0.1:8081").await?;
///
/// let mut buf = Vec::with_capacity(512);
/// let len = socket.recv_buf(&mut buf).await?;
///
/// println!("received {} bytes {:?}", len, &buf[..len]);
///
/// Ok(())
/// }
/// ```
pub async fn recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

let n = (*self.io).recv(dst)?;

// Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok(n)
})
}).await
}

/// Tries to receive a single datagram message on the socket. On success,
/// returns the number of bytes read and the origin.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
///
Expand Down Expand Up @@ -1004,17 +1056,73 @@ impl UdpSocket {
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
let (n, addr) = (*self.io).recv_from(dst)?;

// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok((n, addr))
})
}

/// Receives a single datagram message on the socket, advancing the
/// buffer's internal cursor, returning how many bytes were read and the origin.
///
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Notes
/// Note that the socket address **cannot** be implicitly trusted, because it is relatively
/// trivial to send a UDP datagram with a spoofed origin in a [packet injection attack].
/// Because UDP is stateless and does not validate the origin of a packet,
/// the attacker does not need to be able to intercept traffic in order to interfere.
/// It is important to be aware of this when designing your application-level protocol.
///
/// [packet injection attack]: https://en.wikipedia.org/wiki/Packet_injection
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UdpSocket;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
/// socket.connect("127.0.0.1:8081").await?;
///
/// let mut buf = Vec::with_capacity(512);
/// let (len, addr) = socket.recv_buf_from(&mut buf).await?;
///
/// println!("received {:?} bytes from {:?}", len, addr);
///
/// Ok(())
/// }
/// ```
pub async fn recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

let (n, addr) = (*self.io).recv_from(dst)?;

// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok((n,addr))
}).await
}
}

/// Sends data on the socket to the given address. On success, returns the
Expand Down Expand Up @@ -1252,7 +1360,7 @@ impl UdpSocket {
/// Tries to receive a single datagram message on the socket. On success,
/// returns the number of bytes read and the origin.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
Expand Down
103 changes: 103 additions & 0 deletions tokio/src/net/unix/datagram/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,8 @@ impl UnixDatagram {
cfg_io_util! {
/// Tries to receive data from the socket without waiting.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
///
/// ```no_run
Expand Down Expand Up @@ -866,9 +868,64 @@ impl UnixDatagram {
Ok((n, SocketAddr(addr)))
}

/// Receives from the socket, advances the
/// buffer's internal cursor and returns how many bytes were read and the origin.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
/// ```
/// # use std::error::Error;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn Error>> {
/// use tokio::net::UnixDatagram;
/// use tempfile::tempdir;
///
/// // We use a temporary directory so that the socket
/// // files left by the bound sockets will get cleaned up.
/// let tmp = tempdir()?;
///
/// // Bind each socket to a filesystem path
/// let tx_path = tmp.path().join("tx");
/// let tx = UnixDatagram::bind(&tx_path)?;
/// let rx_path = tmp.path().join("rx");
/// let rx = UnixDatagram::bind(&rx_path)?;
///
/// let bytes = b"hello world";
/// tx.send_to(bytes, &rx_path).await?;
///
/// let mut buf = Vec::with_capacity(24);
/// let (size, addr) = rx.recv_buf_from(&mut buf).await?;
///
/// let dgram = &buf[..size];
/// assert_eq!(dgram, bytes);
/// assert_eq!(addr.as_pathname().unwrap(), &tx_path);
///
/// # Ok(())
/// # }
/// ```
pub async fn recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

// Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the
// buffer.
let (n, addr) = (*self.io).recv_from(dst)?;

unsafe {
buf.advance_mut(n);
}
Ok((n,SocketAddr(addr)))
}).await
}

/// Tries to read data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
///
/// ```no_run
Expand Down Expand Up @@ -926,6 +983,52 @@ impl UnixDatagram {
Ok(n)
})
}

/// Receives data from the socket from the address to which it is connected,
/// advancing the buffer's internal cursor, returning how many bytes were read.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
/// ```
/// # use std::error::Error;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn Error>> {
/// use tokio::net::UnixDatagram;
///
/// // Create the pair of sockets
/// let (sock1, sock2) = UnixDatagram::pair()?;
///
/// // Since the sockets are paired, the paired send/recv
/// // functions can be used
/// let bytes = b"hello world";
/// sock1.send(bytes).await?;
///
/// let mut buff = Vec::with_capacity(24);
/// let size = sock2.recv_buf(&mut buff).await?;
///
/// let dgram = &buff[..size];
/// assert_eq!(dgram, bytes);
///
/// # Ok(())
/// # }
/// ```
pub async fn recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

// Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the
// buffer.
let n = (*self.io).recv(dst)?;

unsafe {
buf.advance_mut(n);
}
Ok(n)
}).await
}
}

/// Sends data on the socket to the specified address.
Expand Down
34 changes: 34 additions & 0 deletions tokio/tests/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,23 @@ async fn try_recv_buf() {
}
}

#[tokio::test]
async fn recv_buf() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

sender.connect(receiver.local_addr()?).await?;
receiver.connect(sender.local_addr()?).await?;

sender.send(MSG).await?;
let mut recv_buf = Vec::with_capacity(32);
let len = receiver.recv_buf(&mut recv_buf).await?;

assert_eq!(len, MSG_LEN);
assert_eq!(&recv_buf[..len], MSG);
Ok(())
}

#[tokio::test]
async fn try_recv_buf_from() {
// Create listener
Expand Down Expand Up @@ -567,6 +584,23 @@ async fn try_recv_buf_from() {
}
}

#[tokio::test]
async fn recv_buf_from() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

sender.connect(receiver.local_addr()?).await?;

sender.send(MSG).await?;
let mut recv_buf = Vec::with_capacity(32);
let (len, caddr) = receiver.recv_buf_from(&mut recv_buf).await?;

assert_eq!(len, MSG_LEN);
assert_eq!(&recv_buf[..len], MSG);
assert_eq!(caddr, sender.local_addr()?);
Ok(())
}

#[tokio::test]
async fn poll_ready() {
// Create listener
Expand Down
Loading

0 comments on commit b31f1a4

Please sign in to comment.