Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: Add recv_buf for UdpSocket and UnixDatagram #5583

Merged
merged 3 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 115 additions & 7 deletions tokio/src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ impl UdpSocket {
/// address to which it is connected. On success, returns the number of
/// bytes read.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
Expand Down Expand Up @@ -881,10 +881,12 @@ impl UdpSocket {
/// Tries to receive data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
///
Expand Down Expand Up @@ -931,25 +933,75 @@ impl UdpSocket {
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

let n = (*self.io).recv(dst)?;

// Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok(n)
})
}

/// Receives a single datagram message on the socket from the remote address
/// to which it is connected, advancing the buffer's internal cursor,
/// returning how many bytes were read.
///
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UdpSocket;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
/// socket.connect("127.0.0.1:8081").await?;
///
/// let mut buf = Vec::with_capacity(512);
/// let len = socket.recv_buf(&mut buf).await?;
///
/// println!("received {} bytes {:?}", len, &buf[..len]);
///
/// Ok(())
/// }
/// ```
pub async fn recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

let n = (*self.io).recv(dst)?;

// Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok(n)
})
}).await
}

/// Tries to receive a single datagram message on the socket. On success,
/// returns the number of bytes read and the origin.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
///
Expand Down Expand Up @@ -1004,17 +1056,73 @@ impl UdpSocket {
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
let (n, addr) = (*self.io).recv_from(dst)?;

// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok((n, addr))
})
}

/// Receives a single datagram message on the socket, advancing the
/// buffer's internal cursor, returning how many bytes were read and the origin.
///
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Notes
/// Note that the socket address **cannot** be implicitly trusted, because it is relatively
/// trivial to send a UDP datagram with a spoofed origin in a [packet injection attack].
/// Because UDP is stateless and does not validate the origin of a packet,
/// the attacker does not need to be able to intercept traffic in order to interfere.
/// It is important to be aware of this when designing your application-level protocol.
///
/// [packet injection attack]: https://en.wikipedia.org/wiki/Packet_injection
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UdpSocket;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
/// socket.connect("127.0.0.1:8081").await?;
///
/// let mut buf = Vec::with_capacity(512);
/// let (len, addr) = socket.recv_buf_from(&mut buf).await?;
///
/// println!("received {:?} bytes from {:?}", len, addr);
///
/// Ok(())
/// }
/// ```
pub async fn recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

let (n, addr) = (*self.io).recv_from(dst)?;

// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
unsafe {
buf.advance_mut(n);
}

Ok((n,addr))
}).await
}
}

/// Sends data on the socket to the given address. On success, returns the
Expand Down Expand Up @@ -1252,7 +1360,7 @@ impl UdpSocket {
/// Tries to receive a single datagram message on the socket. On success,
/// returns the number of bytes read and the origin.
///
/// The function must be called with valid byte array buf of sufficient size
/// This method must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
Expand Down
103 changes: 103 additions & 0 deletions tokio/src/net/unix/datagram/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,8 @@ impl UnixDatagram {
cfg_io_util! {
/// Tries to receive data from the socket without waiting.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
///
/// ```no_run
Expand Down Expand Up @@ -866,9 +868,64 @@ impl UnixDatagram {
Ok((n, SocketAddr(addr)))
}

/// Receives from the socket, advances the
/// buffer's internal cursor and returns how many bytes were read and the origin.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
/// ```
/// # use std::error::Error;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn Error>> {
/// use tokio::net::UnixDatagram;
/// use tempfile::tempdir;
///
/// // We use a temporary directory so that the socket
/// // files left by the bound sockets will get cleaned up.
/// let tmp = tempdir()?;
///
/// // Bind each socket to a filesystem path
/// let tx_path = tmp.path().join("tx");
/// let tx = UnixDatagram::bind(&tx_path)?;
/// let rx_path = tmp.path().join("rx");
/// let rx = UnixDatagram::bind(&rx_path)?;
///
/// let bytes = b"hello world";
/// tx.send_to(bytes, &rx_path).await?;
///
/// let mut buf = Vec::with_capacity(24);
/// let (size, addr) = rx.recv_buf_from(&mut buf).await?;
///
/// let dgram = &buf[..size];
/// assert_eq!(dgram, bytes);
/// assert_eq!(addr.as_pathname().unwrap(), &tx_path);
///
/// # Ok(())
/// # }
/// ```
pub async fn recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

// Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the
// buffer.
let (n, addr) = (*self.io).recv_from(dst)?;

unsafe {
buf.advance_mut(n);
}
Ok((n,SocketAddr(addr)))
}).await
}

/// Tries to read data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
///
/// ```no_run
Expand Down Expand Up @@ -926,6 +983,52 @@ impl UnixDatagram {
Ok(n)
})
}

/// Receives data from the socket from the address to which it is connected,
/// advancing the buffer's internal cursor, returning how many bytes were read.
///
/// This method can be used even if `buf` is uninitialized.
///
/// # Examples
/// ```
/// # use std::error::Error;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn Error>> {
/// use tokio::net::UnixDatagram;
///
/// // Create the pair of sockets
/// let (sock1, sock2) = UnixDatagram::pair()?;
///
/// // Since the sockets are paired, the paired send/recv
/// // functions can be used
/// let bytes = b"hello world";
/// sock1.send(bytes).await?;
///
/// let mut buff = Vec::with_capacity(24);
/// let size = sock2.recv_buf(&mut buff).await?;
///
/// let dgram = &buff[..size];
/// assert_eq!(dgram, bytes);
///
/// # Ok(())
/// # }
/// ```
pub async fn recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().async_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

// Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the
// buffer.
let n = (*self.io).recv(dst)?;

unsafe {
buf.advance_mut(n);
}
Ok(n)
}).await
}
}

/// Sends data on the socket to the specified address.
Expand Down
34 changes: 34 additions & 0 deletions tokio/tests/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,23 @@ async fn try_recv_buf() {
}
}

#[tokio::test]
async fn recv_buf() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

sender.connect(receiver.local_addr()?).await?;
receiver.connect(sender.local_addr()?).await?;

sender.send(MSG).await?;
let mut recv_buf = Vec::with_capacity(32);
let len = receiver.recv_buf(&mut recv_buf).await?;

assert_eq!(len, MSG_LEN);
assert_eq!(&recv_buf[..len], MSG);
Ok(())
}

#[tokio::test]
async fn try_recv_buf_from() {
// Create listener
Expand Down Expand Up @@ -567,6 +584,23 @@ async fn try_recv_buf_from() {
}
}

#[tokio::test]
async fn recv_buf_from() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

sender.connect(receiver.local_addr()?).await?;

sender.send(MSG).await?;
let mut recv_buf = Vec::with_capacity(32);
let (len, caddr) = receiver.recv_buf_from(&mut recv_buf).await?;

assert_eq!(len, MSG_LEN);
assert_eq!(&recv_buf[..len], MSG);
assert_eq!(caddr, sender.local_addr()?);
Ok(())
}

#[tokio::test]
async fn poll_ready() {
// Create listener
Expand Down
Loading