Skip to content

Commit

Permalink
Fix quadratic behavior of repeated vectored writes
Browse files Browse the repository at this point in the history
Some implementations of `Write::write_vectored` in the standard
library (`BufWriter`, `LineWriter`, `Stdout`, `Stderr`) check all
buffers to calculate the total length. This is O(n) over the number of
buffers.

It's common that only a limited number of buffers is written at a
time (e.g. 1024 for `writev(2)`). `write_vectored_all` will then call
`write_vectored` repeatedly, leading to a runtime of O(n²) over the
number of buffers.

The fix is to only calculate as much as needed if it's needed.
  • Loading branch information
blyxxyz committed Mar 3, 2024
1 parent 3793e5b commit 8212fc5
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
51 changes: 27 additions & 24 deletions library/std/src/io/buffered/bufwriter.rs
Expand Up @@ -554,35 +554,38 @@ impl<W: ?Sized + Write> Write for BufWriter<W> {
// same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
// computation overflows, then surely the input cannot fit in our buffer, so we forward
// to the inner writer's `write_vectored` method to let it handle it appropriately.
let saturated_total_len =
bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));
let mut saturated_total_len: usize = 0;

if saturated_total_len > self.spare_capacity() {
// Flush if the total length of the input exceeds our buffer's spare capacity.
// If we would have overflowed, this condition also holds, and we need to flush.
self.flush_buf()?;
for buf in bufs {
saturated_total_len = saturated_total_len.saturating_add(buf.len());

if saturated_total_len > self.spare_capacity() && !self.buf.is_empty() {
// Flush if the total length of the input exceeds our buffer's spare capacity.
// If we would have overflowed, this condition also holds, and we need to flush.
self.flush_buf()?;
}

if saturated_total_len >= self.buf.capacity() {
// Forward to our inner writer if the total length of the input is greater than or
// equal to our buffer capacity. If we would have overflowed, this condition also
// holds, and we punt to the inner writer.
self.panicked = true;
let r = self.get_mut().write_vectored(bufs);
self.panicked = false;
return r;
}
}

if saturated_total_len >= self.buf.capacity() {
// Forward to our inner writer if the total length of the input is greater than or
// equal to our buffer capacity. If we would have overflowed, this condition also
// holds, and we punt to the inner writer.
self.panicked = true;
let r = self.get_mut().write_vectored(bufs);
self.panicked = false;
r
} else {
// `saturated_total_len < self.buf.capacity()` implies that we did not saturate.
// `saturated_total_len < self.buf.capacity()` implies that we did not saturate.

// SAFETY: We checked whether or not the spare capacity was large enough above. If
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
// room for any input <= the buffer size, which includes this input.
unsafe {
bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
};
// SAFETY: We checked whether or not the spare capacity was large enough above. If
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
// room for any input <= the buffer size, which includes this input.
unsafe {
bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
};

Ok(saturated_total_len)
}
Ok(saturated_total_len)
} else {
let mut iter = bufs.iter();
let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
Expand Down
15 changes: 12 additions & 3 deletions library/std/src/io/buffered/linewritershim.rs
Expand Up @@ -175,6 +175,10 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> {
}

// Find the buffer containing the last newline
// FIXME: This is overly slow if there are very many bufs and none contain
// newlines. e.g. writev() on Linux only writes up to 1024 slices, so
// scanning the rest is wasted effort. This makes write_all_vectored()
// quadratic.
let last_newline_buf_idx = bufs
.iter()
.enumerate()
Expand Down Expand Up @@ -215,9 +219,14 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> {

// Don't try to reconstruct the exact amount written; just bail
// in the event of a partial write
let lines_len = lines.iter().map(|buf| buf.len()).sum();
if flushed < lines_len {
return Ok(flushed);
let mut lines_len: usize = 0;
for buf in lines {
// With overlapping/duplicate slices the total length may in theory
// exceed usize::MAX
lines_len = lines_len.saturating_add(buf.len());
if flushed < lines_len {
return Ok(flushed);
}
}

// Now that the write has succeeded, buffer the rest (or as much of the
Expand Down
15 changes: 11 additions & 4 deletions library/std/src/io/stdio.rs
Expand Up @@ -128,8 +128,8 @@ impl Write for StdoutRaw {
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
let total = bufs.iter().map(|b| b.len()).sum();
handle_ebadf(self.0.write_vectored(bufs), total)
let total = || bufs.iter().map(|b| b.len()).sum();
handle_ebadf_lazy(self.0.write_vectored(bufs), total)
}

#[inline]
Expand Down Expand Up @@ -160,8 +160,8 @@ impl Write for StderrRaw {
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
let total = bufs.iter().map(|b| b.len()).sum();
handle_ebadf(self.0.write_vectored(bufs), total)
let total = || bufs.iter().map(|b| b.len()).sum();
handle_ebadf_lazy(self.0.write_vectored(bufs), total)
}

#[inline]
Expand Down Expand Up @@ -193,6 +193,13 @@ fn handle_ebadf<T>(r: io::Result<T>, default: T) -> io::Result<T> {
}
}

fn handle_ebadf_lazy<T>(r: io::Result<T>, default: impl FnOnce() -> T) -> io::Result<T> {
match r {
Err(ref e) if stdio::is_ebadf(e) => Ok(default()),
r => r,
}
}

/// A handle to the standard input stream of a process.
///
/// Each handle is a shared reference to a global buffer of input data to this
Expand Down

0 comments on commit 8212fc5

Please sign in to comment.